diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/0_run_train_dit_trainer.sh b/ppdiffusers/examples/class_conditional_image_generation/DiT/0_run_train_dit_trainer.sh index 6c9c1337a..3ff6466c2 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/0_run_train_dit_trainer.sh +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/0_run_train_dit_trainer.sh @@ -17,7 +17,7 @@ TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' TRAINERS_NUM=1 # nnodes, machine num TRAINING_GPUS_PER_NODE=8 # nproc_per_node -DP_DEGREE=1 # dp_parallel_degree +DP_DEGREE=8 # dp_parallel_degree MP_DEGREE=1 # tensor_parallel_degree SHARDING_DEGREE=1 # sharding_parallel_degree @@ -30,7 +30,7 @@ num_workers=8 max_steps=7000000 logging_steps=50 save_steps=5000 -image_logging_steps=5000 +image_logging_steps=-1 seed=0 USE_AMP=True @@ -68,4 +68,9 @@ ${TRAINING_PYTHON} train_image_generation_trainer.py \ --seed ${seed} \ --recompute ${recompute} \ --enable_xformers_memory_efficient_attention ${enable_xformers} \ - --bf16 ${USE_AMP} + --bf16 ${USE_AMP} \ + --dp_degree ${DP_DEGREE} \ + --tensor_parallel_degree ${MP_DEGREE} \ + --sharding_parallel_degree ${SHARDING_DEGREE} \ + --pipeline_parallel_degree 1 \ + --sep_parallel_degree 1 \ diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/2_run_train_sit_trainer.sh b/ppdiffusers/examples/class_conditional_image_generation/DiT/2_run_train_sit_trainer.sh index 57a2d665e..22f6d318a 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/2_run_train_sit_trainer.sh +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/2_run_train_sit_trainer.sh @@ -17,7 +17,7 @@ TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' TRAINERS_NUM=1 # nnodes, machine num TRAINING_GPUS_PER_NODE=8 # nproc_per_node -DP_DEGREE=1 # dp_parallel_degree +DP_DEGREE=8 # dp_parallel_degree MP_DEGREE=1 # tensor_parallel_degree SHARDING_DEGREE=1 # sharding_parallel_degree @@ -68,4 +68,9 @@ ${TRAINING_PYTHON} train_image_generation_trainer.py \ --seed ${seed} \ --recompute ${recompute} \ --enable_xformers_memory_efficient_attention ${enable_xformers} \ - --bf16 ${USE_AMP} + --bf16 ${USE_AMP} \ + --dp_degree ${DP_DEGREE} \ + --tensor_parallel_degree ${MP_DEGREE} \ + --sharding_parallel_degree ${SHARDING_DEGREE} \ + --pipeline_parallel_degree 1 \ + --sep_parallel_degree 1 \ diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/4_run_train_largedit_3b_trainer.sh b/ppdiffusers/examples/class_conditional_image_generation/DiT/4_run_train_largedit_3b_trainer.sh index 8b5e0c73c..c8bb253a7 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/4_run_train_largedit_3b_trainer.sh +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/4_run_train_largedit_3b_trainer.sh @@ -17,20 +17,22 @@ TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' TRAINERS_NUM=1 # nnodes, machine num TRAINING_GPUS_PER_NODE=8 # nproc_per_node -DP_DEGREE=1 # dp_parallel_degree -MP_DEGREE=1 # tensor_parallel_degree +DP_DEGREE=8 # dp_parallel_degree +MP_DEGREE=2 # tensor_parallel_degree SHARDING_DEGREE=1 # sharding_parallel_degree +accumulation_steps=2 # gradient_accumulation_steps + config_file=config/LargeDiT_3B_patch2.json OUTPUT_DIR=./output_trainer/LargeDiT_3B_patch2_trainer feature_path=./data/fastdit_imagenet256 -batch_size=32 # per gpu +batch_size=16 # per gpu num_workers=8 max_steps=7000000 logging_steps=50 save_steps=5000 -image_logging_steps=5000 +image_logging_steps=-1 seed=0 USE_AMP=True @@ 
-45,7 +47,7 @@ ${TRAINING_PYTHON} train_image_generation_trainer.py \ --feature_path ${feature_path} \ --output_dir ${OUTPUT_DIR} \ --per_device_train_batch_size ${batch_size} \ - --gradient_accumulation_steps 1 \ + --gradient_accumulation_steps ${accumulation_steps} \ --learning_rate 1e-4 \ --weight_decay 0.0 \ --max_steps ${max_steps} \ @@ -68,4 +70,8 @@ ${TRAINING_PYTHON} train_image_generation_trainer.py \ --seed ${seed} \ --recompute ${recompute} \ --enable_xformers_memory_efficient_attention ${enable_xformers} \ - --bf16 ${USE_AMP} + --bf16 ${USE_AMP} \ + --dp_degree ${DP_DEGREE} \ + --tensor_parallel_degree ${MP_DEGREE} \ + --sharding_parallel_degree ${SHARDING_DEGREE} \ + --pipeline_parallel_degree 1 \ diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md index 334e226e7..759c09ac2 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/README.md @@ -36,17 +36,25 @@ Tips: #### 1.3.2 单机多卡训练 ```bash -config_file=config/DiT_XL_patch2.json -OUTPUT_DIR=./output/DiT_XL_patch2_trainer +TRAINING_MODEL_RESUME="None" +TRAINER_INSTANCES='127.0.0.1' +MASTER='127.0.0.1:8080' +TRAINERS_NUM=1 # nnodes, machine num +TRAINING_GPUS_PER_NODE=8 # nproc_per_node +DP_DEGREE=8 # dp_parallel_degree +MP_DEGREE=1 # tensor_parallel_degree +SHARDING_DEGREE=1 # sharding_parallel_degree -# config_file=config/SiT_XL_patch2.json -# OUTPUT_DIR=./output/SiT_XL_patch2_trainer +config_file=config/DiT_XL_patch2.json +OUTPUT_DIR=./output_trainer/DiT_XL_patch2_trainer feature_path=./data/fastdit_imagenet256 batch_size=32 # per gpu num_workers=8 max_steps=7000000 logging_steps=50 +save_steps=5000 +image_logging_steps=-1 seed=0 USE_AMP=True @@ -55,7 +63,8 @@ enable_tensorboard=True recompute=True enable_xformers=True -python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train_image_generation_trainer.py \ +TRAINING_PYTHON="python -m paddle.distributed.launch --master ${MASTER} --nnodes ${TRAINERS_NUM} --nproc_per_node ${TRAINING_GPUS_PER_NODE} --ips ${TRAINER_INSTANCES}" +${TRAINING_PYTHON} train_image_generation_trainer.py \ --do_train \ --feature_path ${feature_path} \ --output_dir ${OUTPUT_DIR} \ @@ -66,10 +75,10 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train_image_gene --max_steps ${max_steps} \ --lr_scheduler_type "constant" \ --warmup_steps 0 \ - --image_logging_steps 1000 \ + --image_logging_steps ${image_logging_steps} \ --logging_dir ${OUTPUT_DIR}/tb_log \ --logging_steps ${logging_steps} \ - --save_steps 10000 \ + --save_steps ${save_steps} \ --save_total_limit 50 \ --dataloader_num_workers ${num_workers} \ --vae_name_or_path stabilityai/sd-vae-ft-mse \ @@ -83,24 +92,42 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" train_image_gene --seed ${seed} \ --recompute ${recompute} \ --enable_xformers_memory_efficient_attention ${enable_xformers} \ - --bf16 ${USE_AMP} + --bf16 ${USE_AMP} \ + --dp_degree ${DP_DEGREE} \ + --tensor_parallel_degree ${MP_DEGREE} \ + --sharding_parallel_degree ${SHARDING_DEGREE} \ + --pipeline_parallel_degree 1 \ + --sep_parallel_degree 1 \ ``` ### 1.4 自定义训练逻辑开启训练 #### 1.4.1 单机多卡训练 ```bash +config_file=config/DiT_XL_patch2.json +results_dir=./output_notrainer/DiT_XL_patch2_notrainer + +feature_path=./data/fastdit_imagenet256 +global_batch_size=256 +num_workers=8 +max_steps=7000000 +logging_steps=50 +save_steps=5000 + python -u -m 
paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" \ train_image_generation_notrainer.py \ - --config_file config/DiT_XL_patch2.json \ - --feature_path ./data/fastdit_imagenet256 \ - --global_batch_size 256 + --config_file ${config_file} \ + --feature_path ${feature_path} \ + --global_batch_size ${global_batch_size} \ + --num_workers ${num_workers} \ + --log_every ${logging_steps} \ + --ckpt_every ${save_steps} \ ``` ## 2 模型推理 -待模型训练完毕,会在`output_dir`保存训练好的模型权重。注意DiT模型推理可以使用ppdiffusers中的DiTPipeline,但是SiT模型推理暂时不支持生成`Pipeline`。 +待模型训练完毕,会在`output_dir`保存训练好的模型权重。注意DiT模型推理可以使用ppdiffusers中的DiTPipeline,**但是SiT模型推理暂时不支持生成`Pipeline`**。 可以参照运行`python infer_demo_dit.py`或者`python infer_demo_dit.py`。 DiT可以使用`tools/convert_dit_to_ppdiffusers.py`生成推理所使用的`Pipeline`。 @@ -128,22 +155,22 @@ python tools/convert_dit_to_ppdiffusers.py 在生成`Pipeline`的权重后,我们可以使用如下的代码进行推理。 ```python -from ppdiffusers import DiTPipeline, DPMSolverMultistepScheduler, DDIMScheduler import paddle from paddlenlp.trainer import set_seed -dtype=paddle.float32 -pipe=DiTPipeline.from_pretrained("./DiT_XL_2_256", paddle_dtype=dtype) -#pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + +from ppdiffusers import DDIMScheduler, DiTPipeline + +dtype = paddle.float32 +pipe = DiTPipeline.from_pretrained("./DiT_XL_2_256", paddle_dtype=dtype) pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) -words = ["white shark"] +words = ["golden retriever"] # class_ids [207] class_ids = pipe.get_label_ids(words) set_seed(42) generator = paddle.Generator().manual_seed(0) image = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator).images[0] -image.save("white_shark.png") -print(f'\nGPU memory usage: {paddle.device.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB') +image.save("result_DiT_golden_retriever.png") ``` diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/__init__.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/__init__.py index 0aba18263..e9283b2d6 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/__init__.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/__init__.py @@ -13,11 +13,17 @@ # limitations under the License. from . import gaussian_diffusion as gd +from .dist_env import setdistenv from .dit import DiT from .dit_llama import DiT_Llama from .respace import SpacedDiffusion, space_timesteps from .trainer import LatentDiffusionTrainer -from .trainer_args import DataArguments, ModelArguments, NoTrainerTrainingArguments +from .trainer_args import ( + DataArguments, + ModelArguments, + NoTrainerTrainingArguments, + TrainerArguments, +) from .trainer_model import DiTDiffusionModel # Modified from OpenAI's diffusion repos diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dist_env.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dist_env.py new file mode 100644 index 000000000..63377dbda --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dist_env.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker + + +def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank=0): + device_id = paddle.device.get_device() + assert "gpu" in device_id + + random.seed(basic_seed + data_world_rank) + np.random.seed(basic_seed + data_world_rank) + paddle.seed(basic_seed + data_world_rank) + + # local_seed / global_seed are used to control dropout in ModelParallel + local_seed = 1024 + basic_seed + mp_rank * 100 + data_world_rank + global_seed = 2048 + basic_seed + data_world_rank + tracker = get_rng_state_tracker() + tracker.add("global_seed", global_seed) + tracker.add("local_seed", local_seed) + + +def setdistenv(args): + world_size = dist.get_world_size() + if world_size > 1: + args.dp_degree = max(args.dp_degree, 1) + args.sharding_parallel_degree = max(args.sharding_parallel_degree, 1) + args.tensor_parallel_degree = max(args.tensor_parallel_degree, 1) + args.sep_parallel_degree = max(args.sep_parallel_degree, 1) + args.pipeline_parallel_degree = max(args.pipeline_parallel_degree, 1) + + assert ( + world_size % (args.tensor_parallel_degree * args.pipeline_parallel_degree) == 0 + ), f"Total world_size:{world_size} should be divisible by tensor_parallel_degree: {args.tensor_parallel_degree} and pipeline_parallel_degree: {args.pipeline_parallel_degree}."
+ + args.dp_degree = world_size // ( + args.tensor_parallel_degree * args.sharding_parallel_degree * args.pipeline_parallel_degree + ) + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": args.dp_degree, + "mp_degree": args.tensor_parallel_degree, + "sharding_degree": args.sharding_parallel_degree, + "pp_degree": args.pipeline_parallel_degree, + } + # strategy.find_unused_parameters = True + + # set control in tensor parallel + strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed} + + fleet.init(is_collective=True, strategy=strategy) + + args.rank = dist.get_rank() + # obtain rank message of hybrid parallel + hcg = fleet.get_hybrid_communicate_group() + args.mp_rank = hcg.get_model_parallel_rank() + args.dp_rank = hcg.get_data_parallel_rank() + args.sharding_rank = hcg.get_sharding_parallel_rank() + + args.data_world_rank = args.dp_rank * args.sharding_parallel_degree + args.sharding_rank + args.data_world_size = world_size // abs(args.tensor_parallel_degree * args.pipeline_parallel_degree) + else: + args.data_world_rank = 0 + args.data_world_size = 1 + args.mp_rank = 0 + args.rank = 0 + + # seed control in hybrid parallel + set_hyrbid_parallel_seed(args.seed, args.data_world_rank, args.mp_rank) diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit.py index 2eec7ea53..5783b091f 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit.py @@ -14,7 +14,6 @@ import collections.abc import math -from functools import partial from itertools import repeat from typing import Optional @@ -22,13 +21,24 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.nn.initializer import Constant, Normal, TruncatedNormal - -trunc_normal_ = TruncatedNormal(std=0.02) -normal_ = Normal -zeros_ = Constant(value=0.0) -ones_ = Constant(value=1.0) import paddle.nn.initializer as initializer +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.nn.functional.flash_attention import flash_attention + +from ppdiffusers.configuration_utils import ConfigMixin +from ppdiffusers.models.modeling_utils import ModelMixin + + +def is_model_parrallel(): + if paddle.distributed.get_world_size() > 1: + hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + if hcg.get_model_parallel_world_size() > 1: + return True + else: + return False + else: + return False def _ntuple(n): @@ -82,29 +92,44 @@ def __init__( norm_layer=None, bias=True, drop=0.0, - use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - bias = to_2tuple(bias) - drop_probs = to_2tuple(drop) - linear_layer = partial(nn.Conv2D, kernel_size=1) if use_conv else nn.Linear - self.fc1 = linear_layer(in_features, hidden_features, bias_attr=bias[0]) + if is_model_parrallel(): + self.fc1 = fleet.meta_parallel.ColumnParallelLinear( + in_features, + hidden_features, + weight_attr=None, + has_bias=bias, + gather_output=True, + ) + self.fc2 = fleet.meta_parallel.ColumnParallelLinear( + hidden_features, + out_features, + weight_attr=None, + has_bias=bias, + gather_output=True, + ) + else: + self.fc1 = nn.Linear(in_features, hidden_features, bias_attr=bias) + self.fc2 = nn.Linear(hidden_features, out_features, bias_attr=bias) + 
self.act = act_layer() - self.drop1 = nn.Dropout(drop_probs[0]) + self.drop1 = nn.Dropout(drop) self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_features, out_features, bias_attr=bias[1]) - self.drop2 = nn.Dropout(drop_probs[1]) + self.drop2 = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) - x = self.drop1(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.drop1(x) x = self.norm(x) x = self.fc2(x) - x = self.drop2(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.drop2(x) return x @@ -127,51 +152,92 @@ def __init__( self.scale = self.head_dim**-0.5 self.fused_attn = fused_attn - self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + if is_model_parrallel(): + self.qkv = fleet.meta_parallel.ColumnParallelLinear( + dim, dim * 3, weight_attr=None, has_bias=qkv_bias, gather_output=True + ) + else: + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + if is_model_parrallel(): + self.proj = fleet.meta_parallel.ColumnParallelLinear( + dim, dim, weight_attr=None, has_bias=True, gather_output=True + ) + else: + self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: paddle.Tensor) -> paddle.Tensor: B, N, C = x.shape + dtype = x.dtype qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, self.head_dim]).transpose([2, 0, 3, 1, 4]) q, k, v = qkv[0], qkv[1], qkv[2] q, k = self.q_norm(q), self.k_norm(k) - if self.fused_attn: - x = F.scaled_dot_product_attention_( + if dtype in [paddle.float16, paddle.bfloat16]: + x, _ = flash_attention( q, k, v, - dropout_p=self.attn_drop.p if self.training else 0.0, + dropout=self.attn_drop.p, + return_softmax=False, ) else: - q = q * self.scale - attn = q @ k.transpose([0, 1, 3, 2]) - attn = F.softmax(attn, axis=-1) - attn = self.attn_drop(attn) - x = attn @ v + if self.fused_attn: + x = F.scaled_dot_product_attention_( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose([0, 1, 3, 2]) + attn = F.softmax(attn, axis=-1) + with get_rng_state_tracker().rng_state("global_seed"): + attn = self.attn_drop(attn) + x = attn @ v x = x.transpose([0, 2, 1, 3]).reshape([B, N, C]) x = self.proj(x) - x = self.proj_drop(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.proj_drop(x) return x -class TimestepEmbedder(nn.Layer): +class ParallelTimestepEmbedder(nn.Layer): """ Embeds scalar timesteps into vector representations. 
""" def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias_attr=True), - nn.Silu(), - nn.Linear(hidden_size, hidden_size, bias_attr=True), - ) + if is_model_parrallel(): + self.mlp = nn.Sequential( + fleet.meta_parallel.ColumnParallelLinear( + frequency_embedding_size, + hidden_size, + weight_attr=None, + has_bias=True, + gather_output=False, # True + ), + nn.Silu(), + fleet.meta_parallel.RowParallelLinear( + hidden_size, + hidden_size, + weight_attr=None, + has_bias=True, + input_is_parallel=True, # + ), + ) + else: + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size), + nn.Silu(), + nn.Linear(hidden_size, hidden_size), + ) self.frequency_embedding_size = frequency_embedding_size @staticmethod @@ -184,22 +250,21 @@ def timestep_embedding(t, dim, max_period=10000): :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. """ - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 - freqs = paddle.exp(-math.log(max_period) * paddle.arange(start=0, end=half, dtype=paddle.float32) / half) - args = t[:, None].cast("float32") * freqs[None] - embedding = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1) + freqs = paddle.exp(x=-math.log(max_period) * paddle.arange(start=0, end=half, dtype="float32") / half) + args = t[:, (None)].astype(dtype="float32") * freqs[None] + embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1) if dim % 2: - embedding = paddle.concat([embedding, paddle.zeros_like(embedding[:, :1])], axis=-1) + embedding = paddle.concat(x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) + t_emb = self.mlp(t_freq.cast(self.mlp[0].weight.dtype)) return t_emb -class LabelEmbedder(nn.Layer): +class ParallelLabelEmbedder(nn.Layer): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
""" @@ -207,7 +272,11 @@ class LabelEmbedder(nn.Layer): def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + embedding_dim = num_classes + use_cfg_embedding + if is_model_parrallel: + self.embedding_table = fleet.meta_parallel.VocabParallelEmbedding(embedding_dim, hidden_size) + else: + self.embedding_table = nn.Embedding(embedding_dim, hidden_size) self.num_classes = num_classes self.dropout_prob = dropout_prob @@ -230,7 +299,8 @@ def token_drop(self, labels, force_drop_ids=None): def forward(self, labels, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) + with get_rng_state_tracker().rng_state("global_seed"): + labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings @@ -247,8 +317,25 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, fused_attn=False, **bl self.norm2 = nn.LayerNorm(hidden_size, epsilon=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate=True) # 'tanh' - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) - self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(hidden_size, 6 * hidden_size, bias_attr=True)) + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + + if is_model_parrallel(): + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + fleet.meta_parallel.ColumnParallelLinear( + hidden_size, 6 * hidden_size, weight_attr=None, has_bias=True, gather_output=True + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + nn.Linear(hidden_size, 6 * hidden_size, bias_attr=True), + ) def forward(self, x, c): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, axis=1) @@ -257,16 +344,27 @@ def forward(self, x, c): return x -class FinalLayer(nn.Layer): - """ - The final layer of DiT. - """ - +class ParallelFinalLayer(nn.Layer): def __init__(self, hidden_size, patch_size, out_channels): super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, epsilon=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias_attr=True) - self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(hidden_size, 2 * hidden_size, bias_attr=True)) + self.norm_final = nn.LayerNorm(hidden_size, weight_attr=False, bias_attr=False, epsilon=1e-06) + if is_model_parrallel(): + self.linear = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + patch_size * patch_size * out_channels, + weight_attr=None, + has_bias=True, + gather_output=True, + ) + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + fleet.meta_parallel.ColumnParallelLinear( + hidden_size, 2 * hidden_size, weight_attr=None, has_bias=True, gather_output=True + ), + ) + else: + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(hidden_size, 2 * hidden_size)) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, axis=1) @@ -275,7 +373,7 @@ def forward(self, x, c): return x -class DiT(nn.Layer): +class DiT(ModelMixin, ConfigMixin): """ Diffusion model with a Transformer backbone. 
""" @@ -285,17 +383,17 @@ class DiT(nn.Layer): def __init__( self, - sample_size=32, # image_size // 8 - patch_size=2, - in_channels=4, - out_channels=8, - num_layers=28, - num_attention_heads=16, - attention_head_dim=72, - mlp_ratio=4.0, - class_dropout_prob=0.1, - num_classes=1000, - learn_sigma=True, + sample_size: int = 32, # image_size // 8 + patch_size: int = 2, + in_channels: int = 4, + out_channels: int = 8, + num_layers: int = 28, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + mlp_ratio: float = 4.0, + class_dropout_prob: float = 0.0, # for tensor parallel + num_classes: int = 1000, + learn_sigma: bool = True, ): super().__init__() self.sample_size = sample_size @@ -311,16 +409,19 @@ def __init__( self.num_classes = num_classes self.learn_sigma = learn_sigma - self.gradient_checkpointing = False - self.fused_attn = False + self.gradient_checkpointing = True + self.fused_attn = True # 1. Define input layers self.x_embedder = PatchEmbed(sample_size, patch_size, in_channels, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + self.t_embedder = ParallelTimestepEmbedder(hidden_size) + self.y_embedder = ParallelLabelEmbedder(num_classes, hidden_size, class_dropout_prob) num_patches = self.x_embedder.num_patches # Will use fixed sin-cos embedding: - self.pos_embed = self.create_parameter(shape=(1, num_patches, hidden_size), default_initializer=zeros_) + self.pos_embed = self.create_parameter( + shape=(1, num_patches, hidden_size), + default_initializer=initializer.Constant(0.0), + ) self.add_parameter("pos_embed", self.pos_embed) # 2. Define transformers blocks @@ -332,14 +433,16 @@ def __init__( ) # 3. Define output layers - self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.final_layer = ParallelFinalLayer(hidden_size, patch_size, self.out_channels) self.initialize_weights() def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): - if isinstance(module, nn.Linear): + if isinstance( + module, (nn.Linear, fleet.meta_parallel.ColumnParallelLinear, fleet.meta_parallel.RowParallelLinear) + ): initializer.XavierUniform()(module.weight) if module.bias is not None: initializer.Constant(value=0)(module.bias) diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit_llama.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit_llama.py index 23b892117..d29e2c5a7 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit_llama.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/dit_llama.py @@ -12,20 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import math +from typing import Optional import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.nn.initializer import Constant, Normal, TruncatedNormal +import paddle.nn.initializer as initializer +from paddle.distributed import fleet +from paddle.nn.functional.flash_attention import flash_attention -from .dit import LabelEmbedder, modulate +from ppdiffusers.configuration_utils import ConfigMixin +from ppdiffusers.models.modeling_utils import ModelMixin -trunc_normal_ = TruncatedNormal(std=0.02) -normal_ = Normal -zeros_ = Constant(value=0.0) -ones_ = Constant(value=1.0) -import paddle.nn.initializer as initializer +from .dit import ( + ParallelLabelEmbedder, + ParallelTimestepEmbedder, + is_model_parrallel, + modulate, +) def TypePromote(x, y): @@ -46,49 +50,11 @@ def TypePromote(x, y): promote_type = TYPE_PROMOTE_DICT[y.dtype.name + x.dtype.name] else: return x, y - return x.astype(promote_type), y.astype(promote_type) + return x.cast(promote_type), y.cast(promote_type) -class TimestepEmbedder(paddle.nn.Layer): - """ - Embeds scalar timesteps into vector representations. - """ - - def __init__(self, hidden_size, frequency_embedding_size=256): - super().__init__() - self.mlp = paddle.nn.Sequential( - paddle.nn.Linear(in_features=frequency_embedding_size, out_features=hidden_size, bias_attr=True), - paddle.nn.Silu(), - paddle.nn.Linear(in_features=hidden_size, out_features=hidden_size, bias_attr=True), - ) - self.frequency_embedding_size = frequency_embedding_size - - @staticmethod - def timestep_embedding(t, dim, max_period=10000): - """ - Create sinusoidal timestep embeddings. - :param t: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an (N, D) Tensor of positional embeddings. - """ - half = dim // 2 - freqs = paddle.exp(x=-math.log(max_period) * paddle.arange(start=0, end=half, dtype="float32") / half) - args = t[:, (None)].astype(dtype="float32") * freqs[None] - embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1) - if dim % 2: - embedding = paddle.concat(x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1) - return embedding - - def forward(self, t): - t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) - return t_emb - - -class Attention(paddle.nn.Layer): - def __init__(self, dim: int, n_heads: int, n_kv_heads, qk_norm: bool): +class Attention(nn.Layer): + def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): """ Initialize the Attention module. 
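(Illustrative note, not part of the diff.) The Attention and FeedForward hunks below reuse the gating this PR introduces in dit.py: when a model-parallel group is active, plain `nn.Linear` layers are replaced by `fleet.meta_parallel.ColumnParallelLinear` with `gather_output=True`, so callers still see full-width activations. A minimal standalone sketch of that pattern follows; the helper names `is_model_parallel_active` and `make_linear` are illustrative, not from the PR.

```python
import paddle
import paddle.distributed as dist
import paddle.nn as nn
from paddle.distributed import fleet


def is_model_parallel_active():
    # Mirrors the PR's is_model_parrallel(): True only when fleet has been
    # initialized with a tensor (model) parallel group of size > 1.
    if dist.get_world_size() > 1:
        hcg = fleet.get_hybrid_communicate_group()
        return hcg.get_model_parallel_world_size() > 1
    return False


def make_linear(in_features, out_features, bias=True):
    # Under tensor parallelism the weight is split column-wise across ranks;
    # gather_output=True gathers the shards so the output shape is unchanged.
    if is_model_parallel_active():
        return fleet.meta_parallel.ColumnParallelLinear(
            in_features, out_features, weight_attr=None, has_bias=bias, gather_output=True
        )
    return nn.Linear(in_features, out_features, bias_attr=bias)


# A single-process run takes the nn.Linear fallback.
print(make_linear(1536, 3 * 1536))
```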
@@ -117,23 +83,36 @@ def __init__(self, dim: int, n_heads: int, n_kv_heads, qk_norm: bool): self.n_local_kv_heads = self.n_kv_heads self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - - self.wq = paddle.nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) - self.wk = paddle.nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - self.wv = paddle.nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - self.wo = paddle.nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) + if is_model_parrallel(): + self.wq = fleet.meta_parallel.ColumnParallelLinear( + dim, n_heads * self.head_dim, weight_attr=None, has_bias=False, gather_output=True + ) + self.wk = fleet.meta_parallel.ColumnParallelLinear( + dim, self.n_kv_heads * self.head_dim, weight_attr=None, has_bias=False, gather_output=True + ) + self.wv = fleet.meta_parallel.ColumnParallelLinear( + dim, self.n_kv_heads * self.head_dim, weight_attr=None, has_bias=False, gather_output=True + ) + self.wo = fleet.meta_parallel.ColumnParallelLinear( + n_heads * self.head_dim, dim, weight_attr=None, has_bias=False, gather_output=True + ) + else: + self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) + self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) if qk_norm: - self.q_norm = paddle.nn.LayerNorm(self.n_local_heads * self.head_dim) - self.k_norm = paddle.nn.LayerNorm(self.n_local_kv_heads * self.head_dim) + self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim) + self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim) else: - self.q_norm = self.k_norm = paddle.nn.Identity() + self.q_norm = self.k_norm = nn.Identity() - self.fused_attn = False + self.fused_attn = fused_attn self.scale = self.head_dim**-0.5 @staticmethod - def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor): + def reshape_for_broadcast(freqs_cis, x): """ Reshape frequency tensor for broadcasting it with another tensor. @@ -161,7 +140,7 @@ def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor): return freqs_cis.reshape([*shape]) @staticmethod - def apply_rotary_emb(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis): + def apply_rotary_emb(xq, xk, freqs_cis): """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -183,14 +162,14 @@ def apply_rotary_emb(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis): and key tensor with rotary embeddings. 
""" with paddle.amp.auto_cast(enable=False): - xq_ = paddle.as_complex(x=xq.astype(dtype="float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) - xk_ = paddle.as_complex(x=xk.astype(dtype="float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) - xq_out = paddle.as_real(x=xq_ * freqs_cis).flatten(start_axis=3) - xk_out = paddle.as_real(x=xk_ * freqs_cis).flatten(start_axis=3) - return xq_out.astype(dtype=xq.dtype), xk_out.astype(dtype=xk.dtype) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) - def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor) -> paddle.Tensor: + def forward(self, x, freqs_cis): """ Forward pass of the attention module. @@ -214,34 +193,44 @@ def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor) -> paddle.Tensor: xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - xq, xk = xq.to(dtype), xk.to(dtype) + xq, xk = xq.cast(dtype), xk.cast(dtype) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: xk = xk.unsqueeze(axis=3).tile([1, 1, 1, n_rep, 1]).flatten(start_axis=2, stop_axis=3) xv = xv.unsqueeze(axis=3).tile([1, 1, 1, n_rep, 1]).flatten(start_axis=2, stop_axis=3) - if self.fused_attn: - output = F.scaled_dot_product_attention( - xq.transpose([0, 2, 1, 3]), - xk.transpose([0, 2, 1, 3]), - xv.transpose([0, 2, 1, 3]), - dropout_p=0.0, - is_causal=False, - ).transpose([0, 2, 1, 3]) + if dtype in [paddle.float16, paddle.bfloat16]: + output, _ = flash_attention( + xq, + xk, + xv, + dropout=0.0, + causal=False, + return_softmax=False, + ) else: - q = xq.transpose([0, 2, 1, 3]) * self.scale - attn = q @ xk.transpose([0, 2, 1, 3]).transpose([0, 1, 3, 2]) - attn = F.softmax(attn, axis=-1) - output = attn @ xv.transpose([0, 2, 1, 3]) - output = output.transpose([0, 2, 1, 3]) + if self.fused_attn: + output = F.scaled_dot_product_attention_( + xq, + xk, + xv, + dropout_p=0.0, + is_causal=False, + ) + else: + q = xq.transpose([0, 2, 1, 3]) * self.scale + attn = q @ xk.transpose([0, 2, 1, 3]).transpose([0, 1, 3, 2]) + attn = F.softmax(attn, axis=-1) + output = attn @ xv.transpose([0, 2, 1, 3]) + output = output.transpose([0, 2, 1, 3]) output = output.flatten(start_axis=-2) return self.wo(output) -class FeedForward(paddle.nn.Layer): - def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier): +class FeedForward(nn.Layer): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): """ Initialize the FeedForward module. 
@@ -265,19 +254,30 @@ def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multipli hidden_dim = int(2 * hidden_dim / 3) if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.w1 = paddle.nn.Linear(in_features=dim, out_features=hidden_dim, bias_attr=False) - self.w2 = paddle.nn.Linear(in_features=hidden_dim, out_features=dim, bias_attr=False) - self.w3 = paddle.nn.Linear(in_features=dim, out_features=hidden_dim, bias_attr=False) - - def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) + if is_model_parrallel(): + self.w1 = fleet.meta_parallel.ColumnParallelLinear( + dim, hidden_dim, weight_attr=None, has_bias=False, gather_output=True + ) + self.w2 = fleet.meta_parallel.ColumnParallelLinear( + hidden_dim, dim, weight_attr=None, has_bias=False, gather_output=True + ) + self.w3 = fleet.meta_parallel.ColumnParallelLinear( + dim, hidden_dim, weight_attr=None, has_bias=False, gather_output=True + ) + else: + self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) + self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) + self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) def forward(self, x): - return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + xw1 = F.silu(self.w1(x)) + xw3 = self.w3(x) + output = self.w2(xw1 * xw3) + return output -class TransformerBlock(paddle.nn.Layer): +class TransformerBlock(nn.Layer): def __init__( self, layer_id: int, @@ -285,9 +285,11 @@ def __init__( n_heads: int, n_kv_heads: int, multiple_of: int, + mlp_ratio: float, ffn_dim_multiplier: float, norm_eps: float, qk_norm: bool, + fused_attn: bool, ) -> None: """ Initialize a TransformerBlock. @@ -322,18 +324,29 @@ def __init__( super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm) + self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) + mlp_hidden_dim = int(dim * mlp_ratio) self.feed_forward = FeedForward( - dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier + dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id - self.attention_norm = paddle.nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) - self.ffn_norm = paddle.nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) - self.adaLN_modulation = paddle.nn.Sequential( - paddle.nn.Silu(), paddle.nn.Linear(in_features=min(dim, 1024), out_features=6 * dim, bias_attr=True) - ) + self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) + self.ffn_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) + + if is_model_parrallel(): + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + fleet.meta_parallel.ColumnParallelLinear( + min(dim, 1024), 6 * dim, weight_attr=None, has_bias=True, gather_output=True + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + nn.Linear(min(dim, 1024), 6 * dim), + ) - def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor, adaln_input=None): + def forward(self, x, freqs_cis, adaln_input=None): """ Perform a forward pass through the TransformerBlock. 
@@ -349,36 +362,49 @@ def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor, adaln_input=None): """ if adaln_input is not None: - (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation( - adaln_input - ).chunk(chunks=6, axis=1) - h = x + gate_msa.unsqueeze(axis=1) * self.attention( + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( + 6, axis=1 + ) + h = x + gate_msa.unsqueeze(1) * self.attention( modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis ) - out = h + gate_mlp.unsqueeze(axis=1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) + out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) else: h = x + self.attention(self.attention_norm(x), freqs_cis) out = h + self.feed_forward(self.ffn_norm(h)) return out -class FinalLayer(paddle.nn.Layer): +class ParallelFinalLayer(nn.Layer): def __init__(self, hidden_size, patch_size, out_channels): super().__init__() - self.norm_final = paddle.nn.LayerNorm(hidden_size, weight_attr=False, bias_attr=False, epsilon=1e-06) - self.linear = paddle.nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias_attr=True) - self.adaLN_modulation = paddle.nn.Sequential( - paddle.nn.Silu(), paddle.nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias_attr=True) - ) + self.norm_final = nn.LayerNorm(hidden_size, weight_attr=False, bias_attr=False, epsilon=1e-06) + if is_model_parrallel(): + self.linear = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + patch_size * patch_size * out_channels, + weight_attr=None, + has_bias=True, + gather_output=True, + ) + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + fleet.meta_parallel.ColumnParallelLinear( + min(hidden_size, 1024), 2 * hidden_size, weight_attr=None, has_bias=True, gather_output=True + ), + ) + else: + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(min(hidden_size, 1024), 2 * hidden_size)) def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(chunks=2, axis=1) + shift, scale = self.adaLN_modulation(c).chunk(2, axis=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x -class DiT_Llama(nn.Layer): +class DiT_Llama(ModelMixin, ConfigMixin): """ Diffusion model with a Transformer backbone. 
""" @@ -388,55 +414,89 @@ class DiT_Llama(nn.Layer): def __init__( self, - sample_size: int = 32, + sample_size: int = 32, # image_size // 8 patch_size: int = 2, in_channels: int = 4, out_channels: int = 8, + num_layers: int = 32, num_attention_heads: int = 16, attention_head_dim: int = 96, - num_layers: int = 32, + mlp_ratio: float = 4.0, n_kv_heads=None, multiple_of: int = 256, ffn_dim_multiplier=None, norm_eps: float = 1e-05, - class_dropout_prob: float = 0.1, + class_dropout_prob: float = 0.0, # for tensor parallel num_classes: int = 1000, learn_sigma: bool = True, qk_norm: bool = True, ): super().__init__() - self.learn_sigma = learn_sigma - self.in_channels = in_channels - self.out_channels = in_channels * 2 if learn_sigma else in_channels - self.sample_size = sample_size self.patch_size = patch_size - + self.in_channels = in_channels + self.out_channels = in_channels * 2 if learn_sigma else in_channels dim = attention_head_dim * num_attention_heads - self.x_embedder = paddle.nn.Linear(in_channels * patch_size**2, dim, bias_attr=True) - self.t_embedder = TimestepEmbedder(min(dim, 1024)) - self.y_embedder = LabelEmbedder(num_classes, min(dim, 1024), class_dropout_prob) + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.multiple_of = multiple_of + self.ffn_dim_multiplier = ffn_dim_multiplier + self.norm_eps = norm_eps + self.class_dropout_prob = class_dropout_prob + self.num_classes = num_classes + self.learn_sigma = learn_sigma + self.qk_norm = qk_norm - self.layers = paddle.nn.LayerList( + self.gradient_checkpointing = True + self.fused_attn = True + + # 1. Define input layers + if is_model_parrallel(): + self.x_embedder = fleet.meta_parallel.ColumnParallelLinear( + in_channels * patch_size**2, + dim, + weight_attr=None, + has_bias=True, + gather_output=True, + ) + else: + self.x_embedder = nn.Linear(in_channels * patch_size**2, dim) + self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024)) + self.y_embedder = ParallelLabelEmbedder(num_classes, min(dim, 1024), class_dropout_prob) + + # 2. Define transformers blocks + self.layers = nn.LayerList( [ TransformerBlock( - layer_id, dim, num_attention_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm + layer_id=idx, + dim=dim, + n_heads=num_attention_heads, + n_kv_heads=n_kv_heads, + multiple_of=multiple_of, + mlp_ratio=mlp_ratio, + ffn_dim_multiplier=ffn_dim_multiplier, + norm_eps=norm_eps, + qk_norm=qk_norm, + fused_attn=self.fused_attn, ) - for layer_id in range(num_layers) + for idx in range(num_layers) ] ) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels) + # 3. 
Define output layers + self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = self.precompute_freqs_cis(dim // num_attention_heads, 4096) - self.gradient_checkpointing = True self.initialize_weights() def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): - if isinstance(module, nn.Linear): + if isinstance( + module, (nn.Linear, fleet.meta_parallel.ColumnParallelLinear, fleet.meta_parallel.RowParallelLinear) + ): initializer.XavierUniform()(module.weight) if module.bias is not None: initializer.Constant(value=0)(module.bias) @@ -461,7 +521,7 @@ def _basic_init(module): initializer.Constant(value=0)(block.adaLN_modulation[-1].weight) initializer.Constant(value=0)(block.adaLN_modulation[-1].bias) - # Zero-out output layers: + # Zero-out final_layer: initializer.Constant(value=0)(self.final_layer.adaLN_modulation[-1].weight) initializer.Constant(value=0)(self.final_layer.adaLN_modulation[-1].bias) initializer.Constant(value=0)(self.final_layer.linear.weight) @@ -470,11 +530,11 @@ def _basic_init(module): def enable_gradient_checkpointing(self, value=True): self.gradient_checkpointing = value - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[str] = None): + self._use_memory_efficient_attention_xformers = True + self.fused_attn = True - def unpatchify(self, x: paddle.Tensor) -> paddle.Tensor: + def unpatchify(self, x): """ Args: x: (N, T, patch_size**2 * C) @@ -484,12 +544,13 @@ def unpatchify(self, x: paddle.Tensor) -> paddle.Tensor: p = self.patch_size h = w = int(tuple(x.shape)[1] ** 0.5) assert h * w == tuple(x.shape)[1] + x = x.reshape(shape=([tuple(x.shape)[0], h, w, p, p, c])) x = paddle.einsum("nhwpqc->nchpwq", x) imgs = x.reshape(shape=([tuple(x.shape)[0], c, h * p, h * p])) return imgs - def patchify(self, x: paddle.Tensor) -> paddle.Tensor: + def patchify(self, x): B, C, H, W = tuple(x.shape) assert (H, W) == (self.sample_size, self.sample_size) pH = pW = self.patch_size @@ -518,16 +579,15 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): paddle.Tensor: Precomputed frequency tensor with complex exponentials. """ - freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].astype(dtype="float32") / dim) + freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim) t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) - freqs = paddle.outer(x=input_0, y=vec2_0).astype(dtype="float32") + freqs = paddle.outer(input_0, vec2_0).cast("float32") freqs_cis = paddle.complex( - paddle.ones_like(x=freqs) * paddle.cos(freqs), paddle.ones_like(x=freqs) * paddle.sin(freqs) + paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) ) return freqs_cis - # def forward(self, hidden_states, timestep, class_labels): def forward(self, x, t, y): """ Args: @@ -537,28 +597,29 @@ def forward(self, x, t, y): class_labels: (N,) tensor of class labels """ hidden_states, timestep, class_labels = x, t, y - hidden_states = hidden_states.cast(x.dtype) + dtype = hidden_states.dtype # 1. Input hidden_states = self.patchify(hidden_states) x = self.x_embedder(hidden_states) t = self.t_embedder(timestep) - y = self.y_embedder(class_labels, self.training) + y = self.y_embedder(class_labels) adaln_input = t + y # 2. 
Blocks - for layer in self.layers: + for i, layer in enumerate(self.layers): if self.gradient_checkpointing: - x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) + x = paddle.distributed.fleet.utils.recompute( + layer, x, self.freqs_cis[: x.shape[1]].cast(dtype), adaln_input, use_reentrant=False + ) else: x = layer( x, - self.freqs_cis[: x.shape[1]], + self.freqs_cis[: x.shape[1]].cast(dtype), adaln_input, ) # 3. Output hidden_states = self.final_layer(x, adaln_input) output = self.unpatchify(hidden_states) - return output diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer.py index 15f43f76a..8d020e582 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib +import os import sys import time @@ -27,9 +28,15 @@ VisualDLCallback, rewrite_logs, ) +from paddlenlp.transformers.model_utils import _add_variant from paddlenlp.utils import profiler from paddlenlp.utils.log import logger +from ppdiffusers.training_utils import unwrap_model + +PADDLE_WEIGHTS_NAME = "model_state.pdparams" +TRAINING_ARGS_NAME = "training_args.bin" + def worker_init_fn(_): worker_info = paddle.io.get_worker_info() @@ -253,3 +260,39 @@ def get_train_dataloader(self): worker_init_fn=worker_init_fn, ) return train_dataloader + + def _save_todo(self, output_dir=None, state_dict=None, merge_tensor_parallel=False): + # TODO: merge_tensor_parallel + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + + if self.args.only_save_updated_model: + unwraped_model = unwrap_model(self.model) + logger.info(f"Saving transformer DiT checkpoint to {output_dir}/transformer") + unwraped_model.transformer.save_pretrained( + os.path.join(output_dir, "transformer"), + # merge_tensor_parallel=merge_tensor_parallel, + ) + + if unwraped_model.use_ema: + logger.info(f"Saving ema transformer DiT checkpoint to {output_dir}/transformer") + with unwraped_model.ema_scope(): + unwraped_model.transformer.save_pretrained( + os.path.join(output_dir, "transformer"), + # merge_tensor_parallel=merge_tensor_parallel, + variant="ema", + ) + + else: + logger.info(f"Saving model checkpoint to {output_dir}") + if state_dict is None: + state_dict = self.model.state_dict() + paddle.save( + state_dict, + os.path.join( + output_dir, + _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix), + ), + ) + if self.args.should_save: + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer_args.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer_args.py index f70c302bd..275b1271f 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer_args.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/diffusion/trainer_args.py @@ -17,6 +17,7 @@ from typing import Optional import paddle +from paddlenlp.trainer import TrainingArguments from paddlenlp.utils.log import logger @@ -42,6 +43,9 @@ class ModelArguments: enable_xformers_memory_efficient_attention: bool = field( default=False, metadata={"help": "enable_xformers_memory_efficient_attention."} ) + 
only_save_updated_model: bool = field( + default=True, metadata={"help": "Whether or not save only_save_updated_model"} + ) prediction_type: Optional[str] = field( default="epsilon", metadata={ @@ -86,6 +90,149 @@ class DataArguments: ) +@dataclass +class TrainerArguments(TrainingArguments): + """ + Arguments pertaining to what training options we are going to use during pretraining. + """ + + pretrained_model_path: Optional[str] = field( + default=None, + metadata={"help": "Whether to use pretrained checkpoint weights."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "The path to a folder with a valid checkpoint for your model."}, + ) + + optim: str = field(default="adamw", metadata={"help": "optimizer setting, [lamb/adamw]"}) + learning_rate: float = field(default=1e-4, metadata={"help": "The initial learning rate for AdamW."}) + weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) + adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) + adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) + adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) + max_grad_norm: float = field(default=0.0, metadata={"help": "Max gradient norm."}) # clip_grad + + # new added + warmup_lr: float = field(default=0.0, metadata={"help": "The initial learning rate for AdamW."}) + min_lr: float = field(default=0.0, metadata={"help": "The initial learning rate for AdamW."}) + warmup_steps: int = field(default=-1, metadata={"help": "Linear warmup over warmup_steps."}) + warmup_epochs: int = field(default=1, metadata={"help": "Linear warmup over warmup_epochs."}) + + output_dir: str = field( + default="output_dir", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + logging_dir: str = field( + default="output_dir/tb_ft_log", + metadata={"help": "The output directory where logs saved."}, + ) + logging_steps: int = field(default=10, metadata={"help": "logging_steps print frequency (default: 10)"}) + + do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) + do_export: bool = field(default=False, metadata={"help": "Whether to export infernece model."}) + per_device_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU core/CPU for training."}) + per_device_eval_batch_size: int = field( + default=8, metadata={"help": "Batch size per GPU core/CPU for evaluation."} + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + accum_freq: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + + num_train_epochs: float = field(default=-1, metadata={"help": "Total number of training epochs to perform."}) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, + ) + lr_scheduler_type: str = field( + default="cosine", + metadata={"help": "The scheduler type to use. 
suppor linear, cosine, constant, constant_with_warmup"}, + ) + warmup_ratio: float = field( + default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} + ) + warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + num_cycles: float = field(default=0.5, metadata={"help": "The number of waves in the cosine scheduler."}) + lr_end: float = field(default=1e-7, metadata={"help": "The end LR in the polynomial scheduler."}) + power: float = field(default=1.0, metadata={"help": "The power factor in the polynomial scheduler."}) + + save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X updates steps."}) + save_epochs: int = field(default=1, metadata={"help": "Save checkpoint every X updates epochs."}) + + seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) + + bf16: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA" + " architecture or using CPU (no_cuda). This is an experimental API and it may change." + ) + }, + ) + fp16: bool = field( + default=False, + metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"}, + ) + fp16_opt_level: str = field( + default="O1", + metadata={ + "help": ( + "For fp16: AMP optimization level selected in ['O0', 'O1', and 'O2']. " + "See details at https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/amp/auto_cast_cn.html" + ) + }, + ) + + dp_degree: int = field( + default=1, + metadata={"help": " data parallel degrees."}, + ) + sharding_parallel_degree: int = field( + default=1, + metadata={"help": " sharding parallel degrees."}, + ) + tensor_parallel_degree: int = field( + default=1, + metadata={"help": " tensor parallel degrees."}, + ) + pipeline_parallel_degree: int = field( + default=1, + metadata={"help": " pipeline parallel degrees."}, + ) + sep_parallel_degree: int = field( + default=1, + metadata={"help": ("sequence parallel strategy.")}, + ) + + last_epoch: int = field(default=-1, metadata={"help": "the last epoch to resume"}) + + dataloader_drop_last: bool = field( + default=True, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} + ) + dataloader_num_workers: int = field( + default=1, + metadata={ + "help": "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
+ }, + ) + + disable_tqdm: Optional[bool] = field( + default=True, metadata={"help": "Whether or not to disable the tqdm progress bars."} + ) + tensorboard: bool = field( + default=False, + metadata={"help": "Whether to use tensorboard to record loss."}, + ) + + @dataclass class NoTrainerTrainingArguments: output_dir: str = field( diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_notrainer.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_notrainer.py index 0ecbbd0de..99755b1b1 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_notrainer.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_notrainer.py @@ -24,8 +24,10 @@ import paddle import paddle.distributed as dist from diffusion import create_diffusion +from diffusion.dist_env import set_hyrbid_parallel_seed from diffusion.dit import DiT from diffusion.dit_llama import DiT_Llama +from paddle.distributed import fleet from transport import create_transport from transport.sit import SiT from transport.utils import parse_transport_args @@ -246,4 +248,17 @@ def main(args): parse_transport_args(parser) args = parser.parse_args() print(args) + + strategy = fleet.DistributedStrategy() + fleet.init(is_collective=True, strategy=strategy) + + sharding_parallel_degree = 1 + hcg = fleet.get_hybrid_communicate_group() + mp_rank = hcg.get_model_parallel_rank() + dp_rank = hcg.get_data_parallel_rank() + sharding_rank = hcg.get_sharding_parallel_rank() + data_world_rank = dp_rank * sharding_parallel_degree + sharding_rank + + # seed control in hybrid parallel + set_hyrbid_parallel_seed(args.global_seed, data_world_rank, mp_rank) main(args) diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_trainer.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_trainer.py index a04927789..6938f06a8 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_trainer.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/train_image_generation_trainer.py @@ -14,6 +14,8 @@ import itertools import math import os +import pprint +import socket import numpy as np import paddle @@ -22,13 +24,10 @@ DiTDiffusionModel, LatentDiffusionTrainer, ModelArguments, + TrainerArguments, + setdistenv, ) -from paddlenlp.trainer import ( - PdArgumentParser, - TrainingArguments, - get_last_checkpoint, - set_seed, -) +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed from paddlenlp.utils.log import logger from transport import SiTDiffusionModel @@ -56,12 +55,24 @@ def __getitem__(self, idx): def main(): - parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + parser = PdArgumentParser((ModelArguments, DataArguments, TrainerArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() + training_args.hostname = socket.gethostname() + pprint.pprint(data_args) + pprint.pprint(model_args) + pprint.pprint(training_args) + setdistenv(training_args) + model_args.data_world_rank = training_args.data_world_rank + model_args.data_world_size = training_args.data_world_size + # report to custom_visualdl training_args.report_to = ["custom_visualdl"] training_args.resolution = data_args.resolution training_args.benchmark = model_args.benchmark + training_args.use_ema = model_args.use_ema + 
training_args.enable_xformers_memory_efficient_attention = model_args.enable_xformers_memory_efficient_attention + training_args.only_save_updated_model = model_args.only_save_updated_model + training_args.profiler_options = model_args.profiler_options training_args.image_logging_steps = model_args.image_logging_steps = ( (math.ceil(model_args.image_logging_steps / training_args.logging_steps) * training_args.logging_steps) diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/transport/sit.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/transport/sit.py index 44cf02ebd..102c25436 100644 --- a/ppdiffusers/examples/class_conditional_image_generation/DiT/transport/sit.py +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/transport/sit.py @@ -14,7 +14,6 @@ import collections.abc import math -from functools import partial from itertools import repeat from typing import Optional @@ -22,13 +21,24 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.nn.initializer import Constant, Normal, TruncatedNormal - -trunc_normal_ = TruncatedNormal(std=0.02) -normal_ = Normal -zeros_ = Constant(value=0.0) -ones_ = Constant(value=1.0) import paddle.nn.initializer as initializer +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from paddle.nn.functional.flash_attention import flash_attention + +from ppdiffusers.configuration_utils import ConfigMixin +from ppdiffusers.models.modeling_utils import ModelMixin + + +def is_model_parrallel(): + if paddle.distributed.get_world_size() > 1: + hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + if hcg.get_model_parallel_world_size() > 1: + return True + else: + return False + else: + return False def _ntuple(n): @@ -82,29 +92,44 @@ def __init__( norm_layer=None, bias=True, drop=0.0, - use_conv=False, ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - bias = to_2tuple(bias) - drop_probs = to_2tuple(drop) - linear_layer = partial(nn.Conv2D, kernel_size=1) if use_conv else nn.Linear - self.fc1 = linear_layer(in_features, hidden_features, bias_attr=bias[0]) + if is_model_parrallel(): + self.fc1 = fleet.meta_parallel.ColumnParallelLinear( + in_features, + hidden_features, + weight_attr=None, + has_bias=bias, + gather_output=True, + ) + self.fc2 = fleet.meta_parallel.ColumnParallelLinear( + hidden_features, + out_features, + weight_attr=None, + has_bias=bias, + gather_output=True, + ) + else: + self.fc1 = nn.Linear(in_features, hidden_features, bias_attr=bias) + self.fc2 = nn.Linear(hidden_features, out_features, bias_attr=bias) + self.act = act_layer() - self.drop1 = nn.Dropout(drop_probs[0]) + self.drop1 = nn.Dropout(drop) self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() - self.fc2 = linear_layer(hidden_features, out_features, bias_attr=bias[1]) - self.drop2 = nn.Dropout(drop_probs[1]) + self.drop2 = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) - x = self.drop1(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.drop1(x) x = self.norm(x) x = self.fc2(x) - x = self.drop2(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.drop2(x) return x @@ -127,51 +152,92 @@ def __init__( self.scale = self.head_dim**-0.5 self.fused_attn = fused_attn - self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + if is_model_parrallel(): + self.qkv = 
fleet.meta_parallel.ColumnParallelLinear( + dim, dim * 3, weight_attr=None, has_bias=qkv_bias, gather_output=True + ) + else: + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) + if is_model_parrallel(): + self.proj = fleet.meta_parallel.ColumnParallelLinear( + dim, dim, weight_attr=None, has_bias=True, gather_output=True + ) + else: + self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) def forward(self, x: paddle.Tensor) -> paddle.Tensor: B, N, C = x.shape + dtype = x.dtype qkv = self.qkv(x).reshape([B, N, 3, self.num_heads, self.head_dim]).transpose([2, 0, 3, 1, 4]) q, k, v = qkv[0], qkv[1], qkv[2] q, k = self.q_norm(q), self.k_norm(k) - if self.fused_attn: - x = F.scaled_dot_product_attention_( + if dtype in [paddle.float16, paddle.bfloat16]: + x, _ = flash_attention( q, k, v, - dropout_p=self.attn_drop.p if self.training else 0.0, + dropout=self.attn_drop.p, + return_softmax=False, ) else: - q = q * self.scale - attn = q @ k.transpose([0, 1, 3, 2]) - attn = F.softmax(attn, axis=-1) - attn = self.attn_drop(attn) - x = attn @ v + if self.fused_attn: + x = F.scaled_dot_product_attention_( + q, + k, + v, + dropout_p=self.attn_drop.p if self.training else 0.0, + ) + else: + q = q * self.scale + attn = q @ k.transpose([0, 1, 3, 2]) + attn = F.softmax(attn, axis=-1) + with get_rng_state_tracker().rng_state("global_seed"): + attn = self.attn_drop(attn) + x = attn @ v x = x.transpose([0, 2, 1, 3]).reshape([B, N, C]) x = self.proj(x) - x = self.proj_drop(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.proj_drop(x) return x -class TimestepEmbedder(nn.Layer): +class ParallelTimestepEmbedder(nn.Layer): """ Embeds scalar timesteps into vector representations. """ def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() - self.mlp = nn.Sequential( - nn.Linear(frequency_embedding_size, hidden_size, bias_attr=True), - nn.Silu(), - nn.Linear(hidden_size, hidden_size, bias_attr=True), - ) + if is_model_parrallel(): + self.mlp = nn.Sequential( + fleet.meta_parallel.ColumnParallelLinear( + frequency_embedding_size, + hidden_size, + weight_attr=None, + has_bias=True, + gather_output=False, # True + ), + nn.Silu(), + fleet.meta_parallel.RowParallelLinear( + hidden_size, + hidden_size, + weight_attr=None, + has_bias=True, + input_is_parallel=True, # + ), + ) + else: + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size), + nn.Silu(), + nn.Linear(hidden_size, hidden_size), + ) self.frequency_embedding_size = frequency_embedding_size @staticmethod @@ -184,22 +250,21 @@ def timestep_embedding(t, dim, max_period=10000): :param max_period: controls the minimum frequency of the embeddings. :return: an (N, D) Tensor of positional embeddings. 
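+        Concretely, the first half of the embedding holds cos(t * f) and the second half sin(t * f), with f = exp(-log(max_period) * k / (dim // 2)) for k in [0, dim // 2), computed in float32.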
""" - # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 - freqs = paddle.exp(-math.log(max_period) * paddle.arange(start=0, end=half, dtype=paddle.float32) / half) - args = t[:, None].cast("float32") * freqs[None] - embedding = paddle.concat([paddle.cos(args), paddle.sin(args)], axis=-1) + freqs = paddle.exp(x=-math.log(max_period) * paddle.arange(start=0, end=half, dtype="float32") / half) + args = t[:, (None)].astype(dtype="float32") * freqs[None] + embedding = paddle.concat(x=[paddle.cos(x=args), paddle.sin(x=args)], axis=-1) if dim % 2: - embedding = paddle.concat([embedding, paddle.zeros_like(embedding[:, :1])], axis=-1) + embedding = paddle.concat(x=[embedding, paddle.zeros_like(x=embedding[:, :1])], axis=-1) return embedding def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq) + t_emb = self.mlp(t_freq.cast(self.mlp[0].weight.dtype)) return t_emb -class LabelEmbedder(nn.Layer): +class ParallelLabelEmbedder(nn.Layer): """ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. """ @@ -207,7 +272,11 @@ class LabelEmbedder(nn.Layer): def __init__(self, num_classes, hidden_size, dropout_prob): super().__init__() use_cfg_embedding = dropout_prob > 0 - self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) + embedding_dim = num_classes + use_cfg_embedding + if is_model_parrallel: + self.embedding_table = fleet.meta_parallel.VocabParallelEmbedding(embedding_dim, hidden_size) + else: + self.embedding_table = nn.Embedding(embedding_dim, hidden_size) self.num_classes = num_classes self.dropout_prob = dropout_prob @@ -230,7 +299,8 @@ def token_drop(self, labels, force_drop_ids=None): def forward(self, labels, force_drop_ids=None): use_dropout = self.dropout_prob > 0 if (self.training and use_dropout) or (force_drop_ids is not None): - labels = self.token_drop(labels, force_drop_ids) + with get_rng_state_tracker().rng_state("global_seed"): + labels = self.token_drop(labels, force_drop_ids) embeddings = self.embedding_table(labels) return embeddings @@ -247,8 +317,25 @@ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, fused_attn=False, **bl self.norm2 = nn.LayerNorm(hidden_size, epsilon=1e-6) mlp_hidden_dim = int(hidden_size * mlp_ratio) approx_gelu = lambda: nn.GELU(approximate=True) # 'tanh' - self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) - self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(hidden_size, 6 * hidden_size, bias_attr=True)) + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=approx_gelu, + drop=0, + ) + + if is_model_parrallel(): + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + fleet.meta_parallel.ColumnParallelLinear( + hidden_size, 6 * hidden_size, weight_attr=None, has_bias=True, gather_output=True + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + nn.Linear(hidden_size, 6 * hidden_size, bias_attr=True), + ) def forward(self, x, c): shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, axis=1) @@ -257,16 +344,27 @@ def forward(self, x, c): return x -class FinalLayer(nn.Layer): - """ - The final layer of DiT. 
- """ - +class ParallelFinalLayer(nn.Layer): def __init__(self, hidden_size, patch_size, out_channels): super().__init__() - self.norm_final = nn.LayerNorm(hidden_size, epsilon=1e-6) - self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias_attr=True) - self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(hidden_size, 2 * hidden_size, bias_attr=True)) + self.norm_final = nn.LayerNorm(hidden_size, weight_attr=False, bias_attr=False, epsilon=1e-06) + if is_model_parrallel(): + self.linear = fleet.meta_parallel.ColumnParallelLinear( + hidden_size, + patch_size * patch_size * out_channels, + weight_attr=None, + has_bias=True, + gather_output=True, + ) + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + fleet.meta_parallel.ColumnParallelLinear( + hidden_size, 2 * hidden_size, weight_attr=None, has_bias=True, gather_output=True + ), + ) + else: + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(hidden_size, 2 * hidden_size)) def forward(self, x, c): shift, scale = self.adaLN_modulation(c).chunk(2, axis=1) @@ -275,7 +373,7 @@ def forward(self, x, c): return x -class SiT(nn.Layer): +class SiT(ModelMixin, ConfigMixin): """ Diffusion model with a Transformer backbone. """ @@ -285,17 +383,17 @@ class SiT(nn.Layer): def __init__( self, - sample_size=32, # image_size // 8 - patch_size=2, - in_channels=4, - out_channels=8, - num_layers=28, - num_attention_heads=16, - attention_head_dim=72, - mlp_ratio=4.0, - class_dropout_prob=0.1, - num_classes=1000, - learn_sigma=True, + sample_size: int = 32, # image_size // 8 + patch_size: int = 2, + in_channels: int = 4, + out_channels: int = 8, + num_layers: int = 28, + num_attention_heads: int = 16, + attention_head_dim: int = 72, + mlp_ratio: float = 4.0, + class_dropout_prob: float = 0.0, # for tensor parallel + num_classes: int = 1000, + learn_sigma: bool = True, ): super().__init__() self.sample_size = sample_size @@ -311,16 +409,19 @@ def __init__( self.num_classes = num_classes self.learn_sigma = learn_sigma - self.gradient_checkpointing = False - self.fused_attn = False + self.gradient_checkpointing = True + self.fused_attn = True # 1. Define input layers self.x_embedder = PatchEmbed(sample_size, patch_size, in_channels, hidden_size, bias=True) - self.t_embedder = TimestepEmbedder(hidden_size) - self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob) + self.t_embedder = ParallelTimestepEmbedder(hidden_size) + self.y_embedder = ParallelLabelEmbedder(num_classes, hidden_size, class_dropout_prob) num_patches = self.x_embedder.num_patches # Will use fixed sin-cos embedding: - self.pos_embed = self.create_parameter(shape=(1, num_patches, hidden_size), default_initializer=zeros_) + self.pos_embed = self.create_parameter( + shape=(1, num_patches, hidden_size), + default_initializer=initializer.Constant(0.0), + ) self.add_parameter("pos_embed", self.pos_embed) # 2. Define transformers blocks @@ -332,14 +433,16 @@ def __init__( ) # 3. 
Define output layers - self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.final_layer = ParallelFinalLayer(hidden_size, patch_size, self.out_channels) self.initialize_weights() def initialize_weights(self): # Initialize transformer layers: def _basic_init(module): - if isinstance(module, nn.Linear): + if isinstance( + module, (nn.Linear, fleet.meta_parallel.ColumnParallelLinear, fleet.meta_parallel.RowParallelLinear) + ): initializer.XavierUniform()(module.weight) if module.bias is not None: initializer.Constant(value=0)(module.bias) diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/README.md b/ppdiffusers/examples/text_to_image_mscoco_uvit/README.md index 14d0edf76..e7627e11c 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/README.md +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/README.md @@ -52,12 +52,12 @@ TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' TRAINERS_NUM=1 # nnodes, machine num TRAINING_GPUS_PER_NODE=8 # nproc_per_node -DP_DEGREE=1 # dp_parallel_degree +DP_DEGREE=8 # dp_parallel_degree MP_DEGREE=1 # tensor_parallel_degree SHARDING_DEGREE=1 # sharding_parallel_degree uvit_config_file=config/uvit_t2i_small.json -output_dir=output_dir/uvit_t2i_small +output_dir=output_trainer/uvit_t2i_small_trainer feature_path=./datasets/coco256_features per_device_train_batch_size=32 @@ -65,8 +65,8 @@ dataloader_num_workers=8 max_steps=1000000 save_steps=5000 warmup_steps=5000 -logging_steps=20 -image_logging_steps=10000 +logging_steps=50 +image_logging_steps=-1 seed=1234 USE_AMP=True @@ -92,7 +92,7 @@ ${TRAINING_PYTHON} train_txt2img_mscoco_uvit_trainer.py \ --image_logging_steps ${image_logging_steps} \ --logging_steps ${logging_steps} \ --save_steps ${save_steps} \ - --seed ${seed}\ + --seed ${seed} \ --dataloader_num_workers ${dataloader_num_workers} \ --max_grad_norm -1 \ --uvit_config_file ${uvit_config_file} \ @@ -105,6 +105,10 @@ ${TRAINING_PYTHON} train_txt2img_mscoco_uvit_trainer.py \ --fp16 ${USE_AMP} \ --fp16_opt_level=${fp16_opt_level} \ --enable_xformers_memory_efficient_attention ${enable_xformers} \ + --dp_degree ${DP_DEGREE} \ + --tensor_parallel_degree ${MP_DEGREE} \ + --sharding_parallel_degree ${SHARDING_DEGREE} \ + --pipeline_parallel_degree 1 \ ``` diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/__init__.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/__init__.py index 7fc75f9a2..6b8c8debc 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/__init__.py +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/__init__.py @@ -13,7 +13,13 @@ # limitations under the License. # flake8: noqa -from .ldm_args import DataArguments, ModelArguments, NoTrainerTrainingArguments +from .dist_env import setdistenv +from .ldm_args import ( + DataArguments, + ModelArguments, + NoTrainerTrainingArguments, + TrainerArguments, +) from .ldm_trainer import LatentDiffusionTrainer from .model import LatentDiffusionModel from .text_to_image_dataset import MSCOCO256Features, worker_init_fn diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/dist_env.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/dist_env.py new file mode 100644 index 000000000..63377dbda --- /dev/null +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/dist_env.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker + + +def set_hyrbid_parallel_seed(basic_seed, data_world_rank, mp_rank, pp_rank=0): + device_id = paddle.device.get_device() + assert "gpu" in device_id + + random.seed(basic_seed + data_world_rank) + np.random.seed(basic_seed + data_world_rank) + paddle.seed(basic_seed + data_world_rank) + + # local_seed / global_seed are used to control dropout in ModelParallel + local_seed = 1024 + basic_seed + mp_rank * 100 + data_world_rank + global_seed = 2048 + basic_seed + data_world_rank + tracker = get_rng_state_tracker() + tracker.add("global_seed", global_seed) + tracker.add("local_seed", local_seed) + + +def setdistenv(args): + world_size = dist.get_world_size() + if world_size > 1: + args.dp_degree = max(args.dp_degree, 1) + args.sharding_parallel_degree = max(args.sharding_parallel_degree, 1) + args.tensor_parallel_degree = max(args.tensor_parallel_degree, 1) + args.sep_parallel_degree = max(args.sep_parallel_degree, 1) + args.pipeline_parallel_degree = max(args.pipeline_parallel_degree, 1) + + assert ( + world_size % (args.tensor_parallel_degree * args.pipeline_parallel_degree) == 0 + ), f"Total world_size:{world_size} should be divisible by tensor_parallel_degree: {args.tensor_parallel_degree} and pipeline_parallel_degree: {args.pipeline_parallel_degree}."
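+        # NOTE: dp_degree is re-derived below from world_size, so dp_degree * tensor_parallel_degree * sharding_parallel_degree * pipeline_parallel_degree == world_size.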
+ + args.dp_degree = world_size // ( + args.tensor_parallel_degree * args.sharding_parallel_degree * args.pipeline_parallel_degree + ) + strategy = fleet.DistributedStrategy() + strategy.hybrid_configs = { + "dp_degree": args.dp_degree, + "mp_degree": args.tensor_parallel_degree, + "sharding_degree": args.sharding_parallel_degree, + "pp_degree": args.pipeline_parallel_degree, + } + # strategy.find_unused_parameters = True + + # set control in tensor parallel + strategy.tensor_parallel_configs = {"tensor_init_seed": args.seed} + + fleet.init(is_collective=True, strategy=strategy) + + args.rank = dist.get_rank() + # obtain rank message of hybrid parallel + hcg = fleet.get_hybrid_communicate_group() + args.mp_rank = hcg.get_model_parallel_rank() + args.dp_rank = hcg.get_data_parallel_rank() + args.sharding_rank = hcg.get_sharding_parallel_rank() + + args.data_world_rank = args.dp_rank * args.sharding_parallel_degree + args.sharding_rank + args.data_world_size = world_size // abs(args.tensor_parallel_degree * args.pipeline_parallel_degree) + else: + args.data_world_rank = 0 + args.data_world_size = 1 + args.mp_rank = 0 + args.rank = 0 + + # seed control in hybrid parallel + set_hyrbid_parallel_seed(args.seed, args.data_world_rank, args.mp_rank) diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_args.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_args.py index 3d5ce9e24..f27d63acf 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_args.py +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_args.py @@ -17,6 +17,7 @@ from typing import Optional import paddle +from paddlenlp.trainer import TrainingArguments from paddlenlp.utils.log import logger @@ -109,6 +110,149 @@ class DataArguments: ) +@dataclass +class TrainerArguments(TrainingArguments): + """ + Arguments pertaining to what training options we are going to use during pretraining. 
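+    Extends paddlenlp's TrainingArguments with hybrid-parallel degrees (dp_degree, tensor_parallel_degree, sharding_parallel_degree, pipeline_parallel_degree, sep_parallel_degree) that are consumed by setdistenv.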
+ """ + + pretrained_model_path: Optional[str] = field( + default=None, + metadata={"help": "Whether to use pretrained checkpoint weights."}, + ) + resume_from_checkpoint: Optional[str] = field( + default=None, + metadata={"help": "The path to a folder with a valid checkpoint for your model."}, + ) + + optim: str = field(default="adamw", metadata={"help": "optimizer setting, [lamb/adamw]"}) + learning_rate: float = field(default=1e-4, metadata={"help": "The initial learning rate for AdamW."}) + weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."}) + adam_beta1: float = field(default=0.9, metadata={"help": "Beta1 for AdamW optimizer"}) + adam_beta2: float = field(default=0.999, metadata={"help": "Beta2 for AdamW optimizer"}) + adam_epsilon: float = field(default=1e-8, metadata={"help": "Epsilon for AdamW optimizer."}) + max_grad_norm: float = field(default=0.0, metadata={"help": "Max gradient norm."}) # clip_grad + + # new added + warmup_lr: float = field(default=0.0, metadata={"help": "The initial learning rate for AdamW."}) + min_lr: float = field(default=0.0, metadata={"help": "The initial learning rate for AdamW."}) + warmup_steps: int = field(default=-1, metadata={"help": "Linear warmup over warmup_steps."}) + warmup_epochs: int = field(default=1, metadata={"help": "Linear warmup over warmup_epochs."}) + + output_dir: str = field( + default="output_dir", + metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, + ) + logging_dir: str = field( + default="output_dir/tb_ft_log", + metadata={"help": "The output directory where logs saved."}, + ) + logging_steps: int = field(default=10, metadata={"help": "logging_steps print frequency (default: 10)"}) + + do_train: bool = field(default=False, metadata={"help": "Whether to run training."}) + do_eval: bool = field(default=False, metadata={"help": "Whether to run eval on the dev set."}) + do_predict: bool = field(default=False, metadata={"help": "Whether to run predictions on the test set."}) + do_export: bool = field(default=False, metadata={"help": "Whether to export infernece model."}) + per_device_train_batch_size: int = field(default=8, metadata={"help": "Batch size per GPU core/CPU for training."}) + per_device_eval_batch_size: int = field( + default=8, metadata={"help": "Batch size per GPU core/CPU for evaluation."} + ) + gradient_accumulation_steps: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + accum_freq: int = field( + default=1, + metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."}, + ) + + num_train_epochs: float = field(default=-1, metadata={"help": "Total number of training epochs to perform."}) + max_steps: int = field( + default=-1, + metadata={"help": "If > 0: set total number of training steps to perform. Override num_train_epochs."}, + ) + lr_scheduler_type: str = field( + default="cosine", + metadata={"help": "The scheduler type to use. 
supports linear, cosine, constant, constant_with_warmup"}, + ) + warmup_ratio: float = field( + default=0.0, metadata={"help": "Linear warmup over warmup_ratio fraction of total steps."} + ) + warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."}) + num_cycles: float = field(default=0.5, metadata={"help": "The number of waves in the cosine scheduler."}) + lr_end: float = field(default=1e-7, metadata={"help": "The end LR in the polynomial scheduler."}) + power: float = field(default=1.0, metadata={"help": "The power factor in the polynomial scheduler."}) + + save_steps: int = field(default=500, metadata={"help": "Save checkpoint every X update steps."}) + save_epochs: int = field(default=1, metadata={"help": "Save checkpoint every X update epochs."}) + + seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) + + bf16: bool = field( + default=False, + metadata={ + "help": ( + "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA" + " architecture or using CPU (no_cuda). This is an experimental API and it may change." + ) + }, + ) + fp16: bool = field( + default=False, + metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit."}, + ) + fp16_opt_level: str = field( + default="O1", + metadata={ + "help": ( + "For fp16: AMP optimization level selected in ['O0', 'O1', and 'O2']. " + "See details at https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/amp/auto_cast_cn.html" + ) + }, + ) + + dp_degree: int = field( + default=1, + metadata={"help": "data parallel degree."}, + ) + sharding_parallel_degree: int = field( + default=1, + metadata={"help": "sharding parallel degree."}, + ) + tensor_parallel_degree: int = field( + default=1, + metadata={"help": "tensor parallel degree."}, + ) + pipeline_parallel_degree: int = field( + default=1, + metadata={"help": "pipeline parallel degree."}, + ) + sep_parallel_degree: int = field( + default=1, + metadata={"help": ("sequence parallel degree.")}, + ) + + last_epoch: int = field(default=-1, metadata={"help": "the last epoch to resume from."}) + + dataloader_drop_last: bool = field( + default=True, metadata={"help": "Drop the last incomplete batch if it is not divisible by the batch size."} + ) + dataloader_num_workers: int = field( + default=1, + metadata={ + "help": "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + }, + ) + + disable_tqdm: Optional[bool] = field( + default=True, metadata={"help": "Whether or not to disable the tqdm progress bars."} + ) + tensorboard: bool = field( + default=False, + metadata={"help": "Whether to use tensorboard to record loss."}, + ) + + @dataclass class NoTrainerTrainingArguments: output_dir: str = field( diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_trainer.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_trainer.py index 83a29b93a..9d67e05f5 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_trainer.py +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/ldm_trainer.py @@ -13,6 +13,7 @@ # limitations under the License. 
import contextlib +import os import sys import time @@ -25,6 +26,7 @@ VisualDLCallback, rewrite_logs, ) +from paddlenlp.transformers.model_utils import _add_variant from paddlenlp.utils import profiler from paddlenlp.utils.log import logger @@ -352,6 +354,49 @@ def get_train_dataloader(self): ) return train_dataloader + def _save_todo(self, output_dir=None, state_dict=None, merge_tensor_parallel=False): + # TODO: merge_tensor_parallel + output_dir = output_dir if output_dir is not None else self.args.output_dir + os.makedirs(output_dir, exist_ok=True) + if self.args.only_save_updated_model: + unwraped_model = unwrap_model(self.model) + logger.info(f"Saving unet checkpoint to {output_dir}/unet") + unwraped_model.unet.save_pretrained( + os.path.join(output_dir, "unet"), + # merge_tensor_parallel=merge_tensor_parallel, + ) + + if unwraped_model.use_ema: + logger.info(f"Saving ema unet checkpoint to {output_dir}/unet") + with unwraped_model.ema_scope(): + unwraped_model.unet.save_pretrained( + os.path.join(output_dir, "unet"), + # merge_tensor_parallel=merge_tensor_parallel, + variant="ema", + ) + + if unwraped_model.train_text_encoder: + logger.info(f"Saving text encoder checkpoint to {output_dir}/text_encoder") + unwraped_model.text_encoder.save_pretrained( + os.path.join(output_dir, "text_encoder"), + # merge_tensor_parallel=merge_tensor_parallel, + ) + else: + logger.info(f"Saving model checkpoint to {output_dir}") + if state_dict is None: + state_dict = self.model.state_dict() + paddle.save( + state_dict, + os.path.join( + output_dir, + _add_variant(PADDLE_WEIGHTS_NAME, self.args.weight_name_suffix), + ), + ) + if self.args.should_save: + if self.tokenizer is not None: + self.tokenizer.save_pretrained(output_dir, merge_tensor_parallel=merge_tensor_parallel) + paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) + def get_grad_norm_and_clip(model, max_norm, norm_type=2.0, error_if_nonfinite=False): r"""Clips gradient norm of an iterable of parameters. 
diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/model.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/model.py index 68d680a4f..374a44510 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/model.py +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/model.py @@ -22,25 +22,17 @@ import paddle import paddle.nn as nn import paddle.nn.functional as F +from paddle.distributed import fleet from paddle.nn.initializer import TruncatedNormal from paddlenlp.transformers import AutoTokenizer, CLIPTextModel from paddlenlp.utils.log import logger from ppdiffusers import AutoencoderKL, DDIMScheduler, is_ppxformers_available +from ppdiffusers.initializer import reset_initialized_parameter, zeros_ from ppdiffusers.models.attention_processor import Attention from ppdiffusers.models.ema import LitEma -from ppdiffusers.training_utils import freeze_params - -try: - from ppdiffusers.models.attention import SpatialTransformer -except ImportError: - from ppdiffusers.models.transformer_2d import ( - Transformer2DModel as SpatialTransformer, - ) - -from ppdiffusers.initializer import reset_initialized_parameter, zeros_ -from ppdiffusers.models.resnet import ResnetBlock2D from ppdiffusers.models.vae import DiagonalGaussianDistribution +from ppdiffusers.training_utils import freeze_params from .uvit_t2i import UViTT2IModel @@ -211,13 +203,9 @@ def init_uvit_weights(self): if isinstance(m, Attention) and getattr(m, "group_norm", None) is not None: zeros_(m.to_out[0].weight) zeros_(m.to_out[0].bias) - if isinstance(m, ResnetBlock2D): - zeros_(m.conv2.weight) - zeros_(m.conv2.bias) - if isinstance(m, SpatialTransformer): - zeros_(m.proj_out.weight) - zeros_(m.proj_out.bias) - if isinstance(m, nn.Linear): + if isinstance( + m, (nn.Linear, fleet.meta_parallel.ColumnParallelLinear, fleet.meta_parallel.RowParallelLinear) + ): trunc_normal_(m.weight) if isinstance(m, nn.Linear) and m.bias is not None: zeros_(m.bias) diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/uvit_t2i.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/uvit_t2i.py index c2844d055..9d6747a0d 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/uvit_t2i.py +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/ldm/uvit_t2i.py @@ -13,54 +13,30 @@ # limitations under the License. import math +from typing import Optional import einops import numpy as np import paddle import paddle.nn as nn import paddle.nn.functional as F -from paddle.nn.initializer import Constant +from paddle.distributed import fleet +from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker +from ppdiffusers.configuration_utils import ConfigMixin +from ppdiffusers.models.modeling_utils import ModelMixin from ppdiffusers.utils import is_ppxformers_available -ones_ = Constant(value=1.0) -zeros_ = Constant(value=0.0) - -from typing import Optional - - -def drop_path(input, drop_prob: float = 0.0, training: bool = False): - """ - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, - however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... - See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the - layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the - argument. 
- """ - if drop_prob == 0.0 or not training: - return input - keep_prob = 1 - drop_prob - shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets - random_tensor = keep_prob + paddle.rand(shape, dtype=input.dtype) - random_tensor = paddle.floor(random_tensor) # binarize - output = (input / keep_prob) * random_tensor - return output - -class DropPath(nn.Layer): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob: Optional[float] = None) -> None: - super().__init__() - self.drop_prob = drop_prob - - def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: - return drop_path(hidden_states, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return "p={}".format(self.drop_prob) +def is_model_parrallel(): + if paddle.distributed.get_world_size() > 1: + hcg = paddle.distributed.fleet.get_hybrid_communicate_group() + if hcg.get_model_parallel_world_size() > 1: + return True + else: + return False + else: + return False class Mlp(nn.Layer): @@ -68,17 +44,36 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) + if is_model_parrallel(): + self.fc1 = fleet.meta_parallel.ColumnParallelLinear( + in_features, + hidden_features, + weight_attr=None, + has_bias=True, + gather_output=True, + ) + self.fc2 = fleet.meta_parallel.ColumnParallelLinear( + hidden_features, + out_features, + weight_attr=None, + has_bias=True, + gather_output=True, + ) + else: + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) - x = self.drop(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.drop(x) x = self.fc2(x) - x = self.drop(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.drop(x) return x @@ -117,9 +112,19 @@ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0. 
self.head_size = head_dim self.scale = qk_scale or head_dim**-0.5 - self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) - self.attn_drop = attn_drop - self.proj = nn.Linear(dim, dim) + if is_model_parrallel(): + self.qkv = fleet.meta_parallel.ColumnParallelLinear( + dim, dim * 3, weight_attr=None, has_bias=qkv_bias, gather_output=True + ) + self.proj = fleet.meta_parallel.ColumnParallelLinear( + dim, dim, weight_attr=None, has_bias=True, gather_output=True + ) + else: + self.qkv = nn.Linear(dim, dim * 3, bias_attr=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) self.proj_drop = nn.Dropout(proj_drop) self._use_memory_efficient_attention_xformers = True @@ -183,7 +188,7 @@ def forward(self, x): value_proj, attn_mask=None, scale=self.scale, - dropout_p=self.attn_drop, + dropout_p=self.attn_drop_p, training=self.training, attention_op=self._attention_op, ) @@ -191,14 +196,15 @@ def forward(self, x): with paddle.amp.auto_cast(enable=False): attention_scores = paddle.matmul(query_proj * self.scale, key_proj, transpose_y=True) attention_probs = F.softmax(attention_scores, axis=-1) + with get_rng_state_tracker().rng_state("global_seed"): + attention_probs = self.attn_drop(attention_probs) hidden_states = paddle.matmul(attention_probs, value_proj).cast(x.dtype) - hidden_states = self.reshape_batch_dim_to_heads( - hidden_states, transpose=not self._use_memory_efficient_attention_xformers - ) - - hidden_states = self.proj_drop(self.proj(hidden_states)) - return hidden_states + x = self.reshape_batch_dim_to_heads(hidden_states, transpose=not self._use_memory_efficient_attention_xformers) + x = self.proj(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.proj_drop(x) + return x class Block(nn.Layer): @@ -211,7 +217,6 @@ def __init__( qk_scale=None, drop=0.0, attn_drop=0.0, - drop_path=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, @@ -222,11 +227,17 @@ def __init__( self.attn = Attention( dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop ) - self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - self.skip_linear = nn.Linear(2 * dim, dim) if skip else None + if skip: + if is_model_parrallel(): + self.skip_linear = fleet.meta_parallel.ColumnParallelLinear( + 2 * dim, dim, weight_attr=None, has_bias=True, gather_output=True + ) + else: + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None def forward(self, x, skip=None): if self.skip_linear is not None: @@ -251,7 +262,7 @@ def forward(self, x): return x -class UViTT2IModel(nn.Layer): +class UViTT2IModel(ModelMixin, ConfigMixin): def __init__( self, sample_size=32, @@ -283,7 +294,12 @@ def __init__( self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim) num_patches = (sample_size // patch_size) ** 2 - self.context_embed = nn.Linear(clip_dim, embed_dim) + if is_model_parrallel(): + self.context_embed = fleet.meta_parallel.ColumnParallelLinear( + clip_dim, embed_dim, weight_attr=None, has_bias=True, gather_output=True + ) + else: + self.context_embed = nn.Linear(clip_dim, embed_dim) self.extras = 1 + num_text_tokens self.pos_embed = self.create_parameter( shape=(1, self.extras + num_patches, embed_dim), @@ -337,8 +353,15 @@ def __init__( ) self.norm = norm_layer(embed_dim, 
weight_attr=False, bias_attr=False) self.patch_dim = patch_size**2 * in_channels - self.decoder_pred = nn.Linear(embed_dim, self.patch_dim, bias_attr=True) + + if is_model_parrallel(): + self.decoder_pred = fleet.meta_parallel.ColumnParallelLinear( + embed_dim, self.patch_dim, weight_attr=None, has_bias=True, gather_output=True + ) + else: + self.decoder_pred = nn.Linear(embed_dim, self.patch_dim) self.final_layer = nn.Conv2D(self.in_channels, self.in_channels, 3, padding=1) if conv else nn.Identity() + self.gradient_checkpointing = False self.fused_attn = False @@ -370,7 +393,8 @@ def forward(self, x, timesteps, encoder_hidden_states): x = paddle.concat((time_token, context_token, x), 1) x = x + self.pos_embed - x = self.pos_drop(x) + with get_rng_state_tracker().rng_state("global_seed"): + x = self.pos_drop(x) skips = [] for i, blk in enumerate(self.in_blocks): @@ -398,5 +422,4 @@ def forward(self, x, timesteps, encoder_hidden_states): x = x[:, self.extras :, :] x = unpatchify(x, self.in_channels) x = self.final_layer(x) # conv - return x diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/run_train_uvit_t2i_small.sh b/ppdiffusers/examples/text_to_image_mscoco_uvit/run_train_uvit_t2i_small.sh index d6ba78c85..0f13e24ed 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/run_train_uvit_t2i_small.sh +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/run_train_uvit_t2i_small.sh @@ -17,12 +17,12 @@ TRAINER_INSTANCES='127.0.0.1' MASTER='127.0.0.1:8080' TRAINERS_NUM=1 # nnodes, machine num TRAINING_GPUS_PER_NODE=8 # nproc_per_node -DP_DEGREE=1 # dp_parallel_degree +DP_DEGREE=8 # dp_parallel_degree MP_DEGREE=1 # tensor_parallel_degree SHARDING_DEGREE=1 # sharding_parallel_degree uvit_config_file=config/uvit_t2i_small.json -output_dir=output_dir/uvit_t2i_small +output_dir=output_trainer/uvit_t2i_small_trainer feature_path=./datasets/coco256_features per_device_train_batch_size=32 @@ -30,12 +30,12 @@ dataloader_num_workers=8 max_steps=1000000 save_steps=5000 warmup_steps=5000 -logging_steps=20 -image_logging_steps=10000 +logging_steps=50 +image_logging_steps=-1 seed=1234 USE_AMP=True -fp16_opt_level="O1" # "O2" bf16 bug now +fp16_opt_level="O1" enable_tensorboard=True recompute=True enable_xformers=True @@ -57,7 +57,7 @@ ${TRAINING_PYTHON} train_txt2img_mscoco_uvit_trainer.py \ --image_logging_steps ${image_logging_steps} \ --logging_steps ${logging_steps} \ --save_steps ${save_steps} \ - --seed ${seed}\ + --seed ${seed} \ --dataloader_num_workers ${dataloader_num_workers} \ --max_grad_norm -1 \ --uvit_config_file ${uvit_config_file} \ @@ -70,5 +70,7 @@ ${TRAINING_PYTHON} train_txt2img_mscoco_uvit_trainer.py \ --fp16 ${USE_AMP} \ --fp16_opt_level=${fp16_opt_level} \ --enable_xformers_memory_efficient_attention ${enable_xformers} \ - - + --dp_degree ${DP_DEGREE} \ + --tensor_parallel_degree ${MP_DEGREE} \ + --sharding_parallel_degree ${SHARDING_DEGREE} \ + --pipeline_parallel_degree 1 \ diff --git a/ppdiffusers/examples/text_to_image_mscoco_uvit/train_txt2img_mscoco_uvit_trainer.py b/ppdiffusers/examples/text_to_image_mscoco_uvit/train_txt2img_mscoco_uvit_trainer.py index 2b0df0c11..da312895a 100644 --- a/ppdiffusers/examples/text_to_image_mscoco_uvit/train_txt2img_mscoco_uvit_trainer.py +++ b/ppdiffusers/examples/text_to_image_mscoco_uvit/train_txt2img_mscoco_uvit_trainer.py @@ -14,6 +14,8 @@ import itertools import math import os +import pprint +import socket import paddle from ldm import ( @@ -22,19 +24,24 @@ LatentDiffusionTrainer, ModelArguments, MSCOCO256Features, + 
TrainerArguments, + setdistenv, ) -from paddlenlp.trainer import ( - PdArgumentParser, - TrainingArguments, - get_last_checkpoint, - set_seed, -) +from paddlenlp.trainer import PdArgumentParser, get_last_checkpoint, set_seed from paddlenlp.utils.log import logger def main(): - parser = PdArgumentParser((ModelArguments, DataArguments, TrainingArguments)) + parser = PdArgumentParser((ModelArguments, DataArguments, TrainerArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() + training_args.hostname = socket.gethostname() + pprint.pprint(data_args) + pprint.pprint(model_args) + pprint.pprint(training_args) + setdistenv(training_args) + model_args.data_world_rank = training_args.data_world_rank + model_args.data_world_size = training_args.data_world_size + training_args.report_to = ["visualdl"] training_args.resolution = data_args.resolution training_args.feature_path = data_args.feature_path diff --git a/ppdiffusers/ppdiffusers/models/dit_llama.py b/ppdiffusers/ppdiffusers/models/dit_llama.py index 7d0c5ca5d..004ac5034 100644 --- a/ppdiffusers/ppdiffusers/models/dit_llama.py +++ b/ppdiffusers/ppdiffusers/models/dit_llama.py @@ -13,15 +13,14 @@ # limitations under the License. import math -from dataclasses import dataclass -from typing import Any, Dict, Optional +from typing import Optional import paddle +import paddle.nn as nn import paddle.nn.functional as F -from paddle.distributed.fleet.utils import recompute +from paddle.nn.functional.flash_attention import flash_attention from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import recompute_use_reentrant, use_old_recompute from .embeddings import LabelEmbedding from .modeling_utils import ModelMixin from .transformer_2d import Transformer2DModelOutput @@ -45,24 +44,24 @@ def TypePromote(x, y): promote_type = TYPE_PROMOTE_DICT[y.dtype.name + x.dtype.name] else: return x, y - return x.astype(promote_type), y.astype(promote_type) + return x.cast(promote_type), y.cast(promote_type) def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(axis=1)) + shift.unsqueeze(axis=1) -class TimestepEmbedder(paddle.nn.Layer): +class TimestepEmbedder(nn.Layer): """ Embeds scalar timesteps into vector representations. """ def __init__(self, hidden_size, frequency_embedding_size=256): super().__init__() - self.mlp = paddle.nn.Sequential( - paddle.nn.Linear(in_features=frequency_embedding_size, out_features=hidden_size, bias_attr=True), - paddle.nn.Silu(), - paddle.nn.Linear(in_features=hidden_size, out_features=hidden_size, bias_attr=True), + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size), + nn.Silu(), + nn.Linear(hidden_size, hidden_size), ) self.frequency_embedding_size = frequency_embedding_size @@ -86,12 +85,12 @@ def timestep_embedding(t, dim, max_period=10000): def forward(self, t): t_freq = self.timestep_embedding(t, self.frequency_embedding_size) - t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype)) + t_emb = self.mlp(t_freq.cast(self.mlp[0].weight.dtype)) return t_emb -class Attention(paddle.nn.Layer): - def __init__(self, dim: int, n_heads: int, n_kv_heads, qk_norm: bool): +class Attention(nn.Layer): + def __init__(self, dim, n_heads, n_kv_heads, qk_norm=True, fused_attn=True): """ Initialize the Attention module. 
@@ -121,22 +120,22 @@ def __init__(self, dim: int, n_heads: int, n_kv_heads, qk_norm: bool): self.n_rep = self.n_local_heads // self.n_local_kv_heads self.head_dim = dim // n_heads - self.wq = paddle.nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) - self.wk = paddle.nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - self.wv = paddle.nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) - self.wo = paddle.nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) + self.wq = nn.Linear(dim, n_heads * self.head_dim, bias_attr=False) + self.wk = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.wv = nn.Linear(dim, self.n_kv_heads * self.head_dim, bias_attr=False) + self.wo = nn.Linear(n_heads * self.head_dim, dim, bias_attr=False) if qk_norm: - self.q_norm = paddle.nn.LayerNorm(self.n_local_heads * self.head_dim) - self.k_norm = paddle.nn.LayerNorm(self.n_local_kv_heads * self.head_dim) + self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim) + self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim) else: - self.q_norm = self.k_norm = paddle.nn.Identity() + self.q_norm = self.k_norm = nn.Identity() - self.fused_attn = False + self.fused_attn = fused_attn self.scale = self.head_dim**-0.5 @staticmethod - def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor): + def reshape_for_broadcast(freqs_cis, x): """ Reshape frequency tensor for broadcasting it with another tensor. @@ -164,7 +163,7 @@ def reshape_for_broadcast(freqs_cis: paddle.Tensor, x: paddle.Tensor): return freqs_cis.reshape([*shape]) @staticmethod - def apply_rotary_emb(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis): + def apply_rotary_emb(xq, xk, freqs_cis): """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -186,14 +185,14 @@ def apply_rotary_emb(xq: paddle.Tensor, xk: paddle.Tensor, freqs_cis): and key tensor with rotary embeddings. """ with paddle.amp.auto_cast(enable=False): - xq_ = paddle.as_complex(x=xq.astype(dtype="float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) - xk_ = paddle.as_complex(x=xk.astype(dtype="float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) + xq_ = paddle.as_complex(xq.cast("float32").reshape([*tuple(xq.shape)[:-1], -1, 2])) + xk_ = paddle.as_complex(xk.cast("float32").reshape([*tuple(xk.shape)[:-1], -1, 2])) freqs_cis = Attention.reshape_for_broadcast(freqs_cis, xq_) - xq_out = paddle.as_real(x=xq_ * freqs_cis).flatten(start_axis=3) - xk_out = paddle.as_real(x=xk_ * freqs_cis).flatten(start_axis=3) - return xq_out.astype(dtype=xq.dtype), xk_out.astype(dtype=xk.dtype) + xq_out = paddle.as_real(xq_ * freqs_cis).flatten(start_axis=3) + xk_out = paddle.as_real(xk_ * freqs_cis).flatten(start_axis=3) + return xq_out.cast(xq.dtype), xk_out.cast(xk.dtype) - def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor) -> paddle.Tensor: + def forward(self, x, freqs_cis): """ Forward pass of the attention module. 
@@ -217,34 +216,44 @@ def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor) -> paddle.Tensor: xv = xv.reshape([bsz, seqlen, self.n_local_kv_heads, self.head_dim]) xq, xk = Attention.apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) - xq, xk = xq.to(dtype), xk.to(dtype) + xq, xk = xq.cast(dtype), xk.cast(dtype) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: xk = xk.unsqueeze(axis=3).tile([1, 1, 1, n_rep, 1]).flatten(start_axis=2, stop_axis=3) xv = xv.unsqueeze(axis=3).tile([1, 1, 1, n_rep, 1]).flatten(start_axis=2, stop_axis=3) - if self.fused_attn: - output = F.scaled_dot_product_attention( - xq.transpose([0, 2, 1, 3]), - xk.transpose([0, 2, 1, 3]), - xv.transpose([0, 2, 1, 3]), - dropout_p=0.0, - is_causal=False, - ).transpose([0, 2, 1, 3]) + if dtype in [paddle.float16, paddle.bfloat16]: + output, _ = flash_attention( + xq, + xk, + xv, + dropout=0.0, + causal=False, + return_softmax=False, + ) else: - q = xq.transpose([0, 2, 1, 3]) * self.scale - attn = q @ xk.transpose([0, 2, 1, 3]).transpose([0, 1, 3, 2]) - attn = F.softmax(attn, axis=-1) - output = attn @ xv.transpose([0, 2, 1, 3]) - output = output.transpose([0, 2, 1, 3]) + if self.fused_attn: + output = F.scaled_dot_product_attention_( + xq, + xk, + xv, + dropout_p=0.0, + is_causal=False, + ) + else: + q = xq.transpose([0, 2, 1, 3]) * self.scale + attn = q @ xk.transpose([0, 2, 1, 3]).transpose([0, 1, 3, 2]) + attn = F.softmax(attn, axis=-1) + output = attn @ xv.transpose([0, 2, 1, 3]) + output = output.transpose([0, 2, 1, 3]) output = output.flatten(start_axis=-2) return self.wo(output) -class FeedForward(paddle.nn.Layer): - def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multiplier): +class FeedForward(nn.Layer): + def __init__(self, dim, hidden_dim, multiple_of=256, ffn_dim_multiplier=None): """ Initialize the FeedForward module. @@ -268,19 +277,20 @@ def __init__(self, dim: int, hidden_dim: int, multiple_of: int, ffn_dim_multipli hidden_dim = int(2 * hidden_dim / 3) if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) - hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) - self.w1 = paddle.nn.Linear(in_features=dim, out_features=hidden_dim, bias_attr=False) - self.w2 = paddle.nn.Linear(in_features=hidden_dim, out_features=dim, bias_attr=False) - self.w3 = paddle.nn.Linear(in_features=dim, out_features=hidden_dim, bias_attr=False) + hidden_dim = int(multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)) - def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 + self.w1 = nn.Linear(dim, hidden_dim, bias_attr=False) + self.w2 = nn.Linear(hidden_dim, dim, bias_attr=False) + self.w3 = nn.Linear(dim, hidden_dim, bias_attr=False) def forward(self, x): - return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) + xw1 = F.silu(self.w1(x)) + xw3 = self.w3(x) + output = self.w2(xw1 * xw3) + return output -class TransformerBlock(paddle.nn.Layer): +class TransformerBlock(nn.Layer): def __init__( self, layer_id: int, @@ -288,9 +298,11 @@ def __init__( n_heads: int, n_kv_heads: int, multiple_of: int, + mlp_ratio: float, ffn_dim_multiplier: float, norm_eps: float, qk_norm: bool, + fused_attn: bool, ) -> None: """ Initialize a TransformerBlock. 
@@ -325,18 +337,21 @@ def __init__( super().__init__() self.dim = dim self.head_dim = dim // n_heads - self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm) + self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, fused_attn) + mlp_hidden_dim = int(dim * mlp_ratio) self.feed_forward = FeedForward( - dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier + dim=dim, hidden_dim=mlp_hidden_dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier ) self.layer_id = layer_id - self.attention_norm = paddle.nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) - self.ffn_norm = paddle.nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) - self.adaLN_modulation = paddle.nn.Sequential( - paddle.nn.Silu(), paddle.nn.Linear(in_features=min(dim, 1024), out_features=6 * dim, bias_attr=True) + self.attention_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) + self.ffn_norm = nn.LayerNorm(dim, epsilon=norm_eps, bias_attr=False) + + self.adaLN_modulation = nn.Sequential( + nn.Silu(), + nn.Linear(min(dim, 1024), 6 * dim), ) - def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor, adaln_input=None): + def forward(self, x, freqs_cis, adaln_input=None): """ Perform a forward pass through the TransformerBlock. @@ -352,13 +367,13 @@ def forward(self, x: paddle.Tensor, freqs_cis: paddle.Tensor, adaln_input=None): """ if adaln_input is not None: - (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation( - adaln_input - ).chunk(chunks=6, axis=1) - h = x + gate_msa.unsqueeze(axis=1) * self.attention( + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk( + 6, axis=1 + ) + h = x + gate_msa.unsqueeze(1) * self.attention( modulate(self.attention_norm(x), shift_msa, scale_msa), freqs_cis ) - out = h + gate_mlp.unsqueeze(axis=1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) + out = h + gate_mlp.unsqueeze(1) * self.feed_forward(modulate(self.ffn_norm(h), shift_mlp, scale_mlp)) else: h = x + self.attention(self.attention_norm(x), freqs_cis) out = h + self.feed_forward(self.ffn_norm(h)) @@ -369,13 +384,11 @@ class FinalLayer(paddle.nn.Layer): def __init__(self, hidden_size, patch_size, out_channels): super().__init__() self.norm_final = paddle.nn.LayerNorm(hidden_size, weight_attr=False, bias_attr=False, epsilon=1e-06) - self.linear = paddle.nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias_attr=True) - self.adaLN_modulation = paddle.nn.Sequential( - paddle.nn.Silu(), paddle.nn.Linear(min(hidden_size, 1024), 2 * hidden_size, bias_attr=True) - ) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels) + self.adaLN_modulation = nn.Sequential(nn.Silu(), nn.Linear(min(hidden_size, 1024), 2 * hidden_size)) def forward(self, x, c): - shift, scale = self.adaLN_modulation(c).chunk(chunks=2, axis=1) + shift, scale = self.adaLN_modulation(c).chunk(2, axis=1) x = modulate(self.norm_final(x), shift, scale) x = self.linear(x) return x @@ -383,17 +396,19 @@ def forward(self, x, c): class DiTLLaMA2DModel(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True + _use_memory_efficient_attention_xformers = True @register_to_config def __init__( self, - sample_size: int = 32, + sample_size: int = 32, # image_size // 8 patch_size: int = 2, in_channels: int = 4, - out_channels: int = 4, - num_attention_heads: int = 16, - attention_head_dim: int = 88, + out_channels: int = 8, num_layers: int = 32, + 
num_attention_heads: int = 16, + attention_head_dim: int = 96, + mlp_ratio: float = 4.0, n_kv_heads=None, multiple_of: int = 256, ffn_dim_multiplier=None, @@ -404,38 +419,65 @@ def __init__( qk_norm: bool = True, ): super().__init__() - self.learn_sigma = learn_sigma + self.sample_size = sample_size + self.patch_size = patch_size self.in_channels = in_channels self.out_channels = in_channels * 2 if learn_sigma else in_channels + dim = attention_head_dim * num_attention_heads - self.sample_size = sample_size - self.patch_size = patch_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.mlp_ratio = mlp_ratio + self.multiple_of = multiple_of + self.ffn_dim_multiplier = ffn_dim_multiplier + self.norm_eps = norm_eps + self.class_dropout_prob = class_dropout_prob + self.num_classes = num_classes + self.learn_sigma = learn_sigma + self.qk_norm = qk_norm - dim = attention_head_dim * num_attention_heads + self.gradient_checkpointing = True + self.fused_attn = True - self.x_embedder = paddle.nn.Linear(in_channels * patch_size**2, dim, bias_attr=True) + self.x_embedder = nn.Linear(in_channels * patch_size**2, dim) self.t_embedder = TimestepEmbedder(min(dim, 1024)) self.y_embedder = LabelEmbedding(num_classes, min(dim, 1024), class_dropout_prob) - self.layers = paddle.nn.LayerList( + # 2. Define transformers blocks + self.layers = nn.LayerList( [ TransformerBlock( - layer_id, dim, num_attention_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm + layer_id=idx, + dim=dim, + n_heads=num_attention_heads, + n_kv_heads=n_kv_heads, + multiple_of=multiple_of, + mlp_ratio=mlp_ratio, + ffn_dim_multiplier=ffn_dim_multiplier, + norm_eps=norm_eps, + qk_norm=qk_norm, + fused_attn=self.fused_attn, ) - for layer_id in range(num_layers) + for idx in range(num_layers) ] ) + # 3. Define output layers self.final_layer = FinalLayer(dim, patch_size, self.out_channels) self.freqs_cis = self.precompute_freqs_cis(dim // num_attention_heads, 4096) - self.gradient_checkpointing = True - def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value - def unpatchify(self, x: paddle.Tensor) -> paddle.Tensor: + def enable_gradient_checkpointing(self, value=True): + self.gradient_checkpointing = value + + def enable_xformers_memory_efficient_attention(self, attention_op: Optional[str] = None): + self._use_memory_efficient_attention_xformers = True + self.fused_attn = True + + def unpatchify(self, x): """ Args: x: (N, T, patch_size**2 * C) @@ -445,12 +487,13 @@ def unpatchify(self, x: paddle.Tensor) -> paddle.Tensor: p = self.patch_size h = w = int(tuple(x.shape)[1] ** 0.5) assert h * w == tuple(x.shape)[1] + x = x.reshape(shape=([tuple(x.shape)[0], h, w, p, p, c])) x = paddle.einsum("nhwpqc->nchpwq", x) imgs = x.reshape(shape=([tuple(x.shape)[0], c, h * p, h * p])) return imgs - def patchify(self, x: paddle.Tensor) -> paddle.Tensor: + def patchify(self, x): B, C, H, W = tuple(x.shape) assert (H, W) == (self.sample_size, self.sample_size) pH = pW = self.patch_size @@ -479,12 +522,12 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): paddle.Tensor: Precomputed frequency tensor with complex exponentials. 
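+            Each entry equals exp(1j * t * theta ** (-2 * k / dim)) for position t in [0, end) and frequency index k in [0, dim // 2), stored as a complex tensor of cos/sin pairs.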
""" - freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].astype(dtype="float32") / dim) + freqs = 1.0 / theta ** (paddle.arange(start=0, end=dim, step=2)[: dim // 2].cast("float32") / dim) t = paddle.arange(end=end) input_0, vec2_0 = TypePromote(t, freqs) - freqs = paddle.outer(x=input_0, y=vec2_0).astype(dtype="float32") + freqs = paddle.outer(input_0, vec2_0).cast("float32") freqs_cis = paddle.complex( - paddle.ones_like(x=freqs) * paddle.cos(freqs), paddle.ones_like(x=freqs) * paddle.sin(freqs) + paddle.ones_like(freqs) * paddle.cos(freqs), paddle.ones_like(freqs) * paddle.sin(freqs) ) return freqs_cis @@ -503,36 +546,19 @@ def forward( class_labels: (N,) tensor of class labels """ hidden_states = hidden_states.cast(self.dtype) + timestep = timestep.cast(self.dtype) # 1. Input hidden_states = self.patchify(hidden_states) x = self.x_embedder(hidden_states) t = self.t_embedder(timestep) - y = self.y_embedder(class_labels, self.training) + y = self.y_embedder(class_labels) adaln_input = t + y # 2. Blocks - for layer in self.layers: - if self.gradient_checkpointing and not hidden_states.stop_gradient and not use_old_recompute(): - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - ckpt_kwargs = {} if recompute_use_reentrant() else {"use_reentrant": False} - - x = recompute( - create_custom_forward(layer), - x, - self.freqs_cis[: x.shape[1]], - adaln_input, - **ckpt_kwargs, - ) + for i, layer in enumerate(self.layers): + if self.gradient_checkpointing: + x = paddle.distributed.fleet.utils.recompute(layer, x, self.freqs_cis[: x.shape[1]], adaln_input) else: x = layer( x,