diff --git a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc index eab50796212f1c..8db7ac8d0806f7 100644 --- a/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/reshard/reshard_utils.cc @@ -189,7 +189,8 @@ phi::DeviceContext* GetDistTensorDeviceContext( phi::distributed::DistTensor* input) { // TODO(GhostScreaming): pipeline parallel may create an undefined middle grad // tensor. In such case, we need to get default place. - auto place = input && input->defined() ? input->place() : GetDefaultPlace(); + auto place = + input && input->initialized() ? input->place() : GetDefaultPlace(); return phi::DeviceContextPool::Instance().Get(place); } diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 375be055f92aa8..d0ee6476ddd24f 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1043,41 +1043,21 @@ def __init__( # convert dygraph model to static model if isinstance(loader, ShardDataloader): ( - inputs_spec, - labels_spec, + self._engine._inputs_spec, + self._engine._labels_spec, ) = self._engine._prepare_data_spec_from_dataloader(loader) else: batch_size = loader.batch_sampler.batch_size - inputs_spec, labels_spec = self._engine._prepare_data_spec( + ( + self._engine._inputs_spec, + self._engine._labels_spec, + ) = self._engine._prepare_data_spec( loader.dataset, None, batch_size ) - if optimizer is not None and loss is not None: - # get the static graph in train mode - self._engine.prepare( - copy.deepcopy(inputs_spec), - copy.deepcopy(labels_spec), - mode="train", - init_parameters=False, - ) - if loss is not None: - # get the static graph in eval mode - self._engine.prepare( - copy.deepcopy(inputs_spec), - copy.deepcopy(labels_spec), - mode="eval", - init_parameters=False, - ) - # get the static graph in predict mode - self._engine.prepare( - copy.deepcopy(inputs_spec), - None, - mode="predict", - init_parameters=False, - ) # paddle.enable_static() will be called implicitly in self._engine.prepare. # call paddle.disable_static to keep the outside of DistModel in dynamic graph mode - paddle.disable_static() + # set the default mode if optimizer is not None and loss is not None: self.train() @@ -1093,11 +1073,11 @@ def train(self): parameters of the model and return the loss. """ if not self._engine._has_prepared["train"]: - raise RuntimeError( - "The model for training has not been prepared, please set 'loss' and 'optimizer' when constructing DistModel." - ) + self._engine._prepare_program(mode="train", init_parameters=False) + self._mode = "train" self._engine.to_mode("train") + paddle.disable_static() def eval(self): """ @@ -1105,11 +1085,11 @@ def eval(self): executing ``__call__`` will return the loss. """ if not self._engine._has_prepared["eval"]: - raise RuntimeError( - "The model for evaluation has not been prepared, please set 'loss' when constructing DistModel." - ) + self._engine._prepare_program(mode="eval", init_parameters=False) + self._mode = "eval" self._engine.to_mode("eval") + paddle.disable_static() def predict(self): """ @@ -1118,11 +1098,16 @@ def predict(self): outputs of the model. """ if not self._engine._has_prepared["predict"]: - raise RuntimeError( - "The model for prediction has not been prepared." + self._engine.prepare( + copy.deepcopy(self._engine._inputs_spec), + None, + mode="predict", + init_parameters=False, ) + self._mode = "predict" self._engine.to_mode("predict") + paddle.disable_static() def __validate_mode(self, mode): if mode is None and self._mode is None: diff --git a/python/paddle/distributed/auto_parallel/static/engine.py b/python/paddle/distributed/auto_parallel/static/engine.py index 3afca065238f47..422c1a76c1d54d 100644 --- a/python/paddle/distributed/auto_parallel/static/engine.py +++ b/python/paddle/distributed/auto_parallel/static/engine.py @@ -825,14 +825,19 @@ def _plan(self, mode): if var.name in block.vars: feed_list.append(block.vars[var.name]) - self._dp_world_sizes = [] - self._dp_ranks = [] - for feed_var in feed_list: - dp_world_size, dp_rank = auto_utils.get_input_split_info( - self._cur_rank, feed_var, self._dist_contexts[mode] - ) - self._dp_world_sizes.append(dp_world_size) - self._dp_ranks.append(dp_rank) + self._dp_world_sizes = getattr(self, "_dp_world_sizes", []) + self._dp_ranks = getattr(self, "_dp_ranks", []) + if mode in ['eval', 'predice'] or ( + not self._dp_world_sizes and not self._dp_ranks + ): + self._dp_world_sizes = [] + self._dp_ranks = [] + for feed_var in feed_list: + dp_world_size, dp_rank = auto_utils.get_input_split_info( + self._cur_rank, feed_var, self._dist_contexts[mode] + ) + self._dp_world_sizes.append(dp_world_size) + self._dp_ranks.append(dp_rank) def _parallel(self, mode, all_ranks=False): # Parallelize program based on the planner's results diff --git a/python/paddle/distributed/auto_parallel/static/helper.py b/python/paddle/distributed/auto_parallel/static/helper.py index b0ea25a492e6c1..436e53a42e5bad 100644 --- a/python/paddle/distributed/auto_parallel/static/helper.py +++ b/python/paddle/distributed/auto_parallel/static/helper.py @@ -16,6 +16,7 @@ import logging from collections import defaultdict +import paddle from paddle.jit import not_to_static, to_static from paddle.jit.dy2static.program_translator import ( ProgramTranslator, @@ -30,6 +31,7 @@ ) from .converter import Converter +from .process_group import get_world_process_group from .utils import get_logger, to_list @@ -251,7 +253,9 @@ def build_program(self, mode): # NOTE(dev): Because @to_static is a Lazy mechanism, so we explicitly call this to trigger # generating Program IR immediately. - concrete_program = getattr(self.proxy_layer, func_name).concrete_program + concrete_program = getattr( + self.proxy_layer, func_name + ).concrete_program # noqa: B018 prepare_op_amp_options( concrete_program.main_program, ProgramTranslator.get_instance()._amp_records, @@ -335,8 +339,26 @@ def static_func(self): def init(self, main_program, place, dist_context): if self.lazy_init: return + + is_comm = False for param in self.concrete_program.parameters: + if param.is_dist(): + serial_main_program = self.concrete_program.main_program + var = serial_main_program.global_block().vars[param.name] + var_dist_attr = dist_context.get_tensor_dist_attr_for_program( + var + ) + is_comm = True + tmp = paddle.base.core.reshard(param, var_dist_attr) + if tmp._is_initialized(): + param.get_tensor()._share_data_with(tmp.get_tensor()) + else: + param = None + paddle.device.synchronize() + # create var in scope and share parameters to scope + if param is None: + continue if param.name not in main_program.global_block().vars: continue if param.is_dense(): @@ -361,6 +383,19 @@ def init(self, main_program, place, dist_context): dense_tensor = global_scope().var(param.name).get_tensor() dense_tensor._share_data_with(param.get_tensor().get_tensor()) + world_group = get_world_process_group() + if ( + is_comm + and world_group.nranks > 1 + and paddle.distributed.get_world_size() > 1 + ): + paddle.disable_static() + barrier_tensor = paddle.full([1], 1, dtype="int32") + paddle._legacy_C_ops.barrier( + barrier_tensor, barrier_tensor, 'ring_id', 0 + ) + paddle.enable_static() + @property def concrete_program(self): return self.static_func().concrete_program diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt index 71716897874b24..4af5c3e2cc37a4 100644 --- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt +++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt @@ -42,3 +42,11 @@ if((WITH_GPU) AND (LINUX)) set_tests_properties(test_cross_mesh_reshard PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=HYBRID") endif() +if((WITH_GPU) AND (LINUX)) + py_test_modules( + test_semi_auto_parallel_llama_model_vpp MODULES + test_semi_auto_parallel_llama_model_vpp ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") + set_tests_properties(test_semi_auto_parallel_llama_model_vpp + PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID") +endif() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_llama_pp_gradmerge.py b/test/auto_parallel/hybrid_strategy/semi_auto_llama_pp_gradmerge.py new file mode 100644 index 00000000000000..86943d2e47c6d3 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/semi_auto_llama_pp_gradmerge.py @@ -0,0 +1,289 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import reduce + +import numpy as np +from semi_auto_parallel_llama_model import ( + LlamaForCausalLMAuto, + LlamaPretrainingCriterionAuto, + set_global_mesh, +) + +import paddle +import paddle.distributed as dist +from paddle import LazyGuard +from paddle.io import BatchSampler, DataLoader, Dataset + + +class Config: + vocab_size = 32000 + hidden_size = 4096 + intermediate_size = 11008 + max_position_embeddings = 2048 + seq_length = 2048 + num_hidden_layers = 4 + num_attention_heads = 32 + num_key_value_heads = 32 + initializer_range = 0.02 + rms_norm_eps = 1e-6 + use_cache = True + use_flash_attention = False + sequence_parallel = False + rope = True + recompute = False + recompute_granularity = None + use_lazy_init = False + virtual_pp_degree = 1 + + +class RandomDataset(Dataset): + def __init__(self, seq_len, num_samples=100): + super().__init__() + self.seq_len = seq_len + self.num_samples = num_samples + + def __getitem__(self, index): + input = np.random.uniform(size=[self.seq_len]).astype("int64") + label = (np.random.uniform(size=[self.seq_len]) * 10).astype("int64") + return input, label + + def __len__(self): + return self.num_samples + + +def create_optimizer(model, lr_scheduler): + decay_parameters = [ + p.name + for n, p in model.named_parameters() + if not any(nd in n for nd in ["bias", "norm"]) + ] + + def apply_decay_param_fun(x): + return x in decay_parameters + + # test global_clip in auto_parallel + if os.getenv("use_param_group") == "true": + param_group = {} + param_group["params"] = list(model.parameters()) + param_group["weight_decay"] = 0.01 + param_group["grad_clip"] = paddle.nn.ClipGradByGlobalNorm(1.0) + optimizer = paddle.optimizer.adamw.AdamW( + learning_rate=lr_scheduler, + apply_decay_param_fun=apply_decay_param_fun, + parameters=[param_group], + ) + else: + optimizer = paddle.optimizer.adamw.AdamW( + learning_rate=lr_scheduler, + apply_decay_param_fun=apply_decay_param_fun, + parameters=model.parameters(), + weight_decay=0.01, + grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0), + ) + return optimizer + + +class TestLlamaAuto: + def __init__(self): + self.config = Config() + self.dp = int(os.getenv("dp")) + self.mp = int(os.getenv("mp")) + self.pp = int(os.getenv("pp")) + if os.getenv("virtual_pp_degree"): + self.config.virtual_pp_degree = int(os.getenv("virtual_pp_degree")) + if os.getenv("use_sp") == "true": + self.config.sequence_parallel = True + if os.getenv("recompute") == "true": + self.config.recompute = True + self.config.recompute_granularity = os.getenv("recompute_granularity") + self.gradient_accumulation_steps = int(os.getenv("acc_step")) + self.only_static = os.getenv("only_static") + + if self.config.virtual_pp_degree == 1: + self.schedule_mode = "1F1B" + elif self.config.virtual_pp_degree > 1: + self.schedule_mode = "VPP" + + self.init_dist_env() + + def init_dist_env(self): + order = ["dp", "pp", "mp"] + dp_degree = self.dp + mp_degree = self.mp + pp_degree = self.pp + degree = [dp_degree, pp_degree, mp_degree] + mesh_dims = list(filter(lambda x: x[1] > 1, list(zip(order, degree)))) + if not mesh_dims: + mesh_dims = [("dp", 1)] + dim_names = [mesh_dim[0] for mesh_dim in mesh_dims] + mesh_shape = [mesh_dim[1] for mesh_dim in mesh_dims] + mesh_arr = np.arange( + 0, reduce(lambda x, y: x * y, mesh_shape, 1) + ).reshape(mesh_shape) + global_mesh = dist.ProcessMesh(mesh_arr, dim_names) + set_global_mesh(global_mesh) + + def run_llama(self, to_static=0): + if self.only_static and to_static == 0: + return + + if self.config.use_lazy_init: + with LazyGuard(): + model = LlamaForCausalLMAuto(self.config) + for param in model.parameters(): + assert not param._is_initialized() + param.initialize() + else: + model = LlamaForCausalLMAuto(self.config) + model = LlamaForCausalLMAuto(self.config) + criterion = LlamaPretrainingCriterionAuto(self.config) + + lr_scheduler = paddle.optimizer.lr.LinearWarmup( + learning_rate=0.0001, warmup_steps=2, start_lr=0, end_lr=0.0001 + ) + optimizer = create_optimizer(model, lr_scheduler) + optimizer = dist.shard_optimizer(optimizer) + + micro_bsz = 2 + global_bsz = micro_bsz * self.dp * self.gradient_accumulation_steps + + global_step = 1 + tr_loss = float(0) + + if not to_static: + train_dataset = RandomDataset(self.config.seq_length) + train_sampler = BatchSampler( + train_dataset, + batch_size=micro_bsz, + shuffle=True, + drop_last=True, + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=0, + ) + + model.train() + for epoch_idx in range(1): + for step, inputs in enumerate(train_dataloader): + input_ids, labels = inputs + logits = model(input_ids) + tr_loss_step = criterion(logits, labels) + + if self.gradient_accumulation_steps > 1: + tr_loss_step /= self.gradient_accumulation_steps + + tr_loss_step.backward() + tr_loss += tr_loss_step + + if global_step % self.gradient_accumulation_steps == 0: + print( + f"step: {global_step // self.gradient_accumulation_steps} loss: {tr_loss.numpy()}" + ) + optimizer.step() + optimizer.clear_grad() + lr_scheduler.step() + tr_loss = 0 + + global_step += 1 + if global_step // self.gradient_accumulation_steps >= 10: + break + else: + strategy = dist.Strategy() + if self.pp > 1 and self.gradient_accumulation_steps > 1: + strategy.pipeline.enable = True + strategy.pipeline.accumulate_steps = ( + self.gradient_accumulation_steps + ) + strategy.pipeline.micro_batch_size = micro_bsz + strategy.pipeline.schedule_mode = self.schedule_mode + strategy.pipeline.vpp_degree = self.config.virtual_pp_degree + strategy.pipeline.vpp_seg_method = "LlamaDecoderLayerAuto" + elif self.gradient_accumulation_steps > 1: + strategy.gradient_merge.enable = True + strategy.gradient_merge.k_steps = ( + self.gradient_accumulation_steps + ) + strategy.gradient_merge.avg = True + + train_dataset = RandomDataset(self.config.seq_length) + train_sampler = BatchSampler( + train_dataset, + batch_size=global_bsz, + shuffle=True, + drop_last=True, + ) + train_dataloader = DataLoader( + train_dataset, + batch_sampler=train_sampler, + num_workers=0, + ) + + dist_model, dist_loader = dist.to_static( + model, train_dataloader, criterion, optimizer, strategy=strategy + ) + + def validate_batch(batch): + if self.gradient_accumulation_steps == 1 or self.pp > 1: + batches = [batch] + else: + split_batches = [ + np.split( + np.array(b), self.gradient_accumulation_steps, 0 + ) + for b in batch + ] + batches = [] + for i in range(len(split_batches[0])): + micro_batch = [ + split_batch[i] for split_batch in split_batches + ] + batches.append(micro_batch) + return batches + + dist_model.train() + for epoch_idx in range(1): + for step, inputs in enumerate(dist_loader()): + batches = validate_batch(inputs) + for micro_batch in batches: + input_ids, labels = micro_batch + tr_loss_step = dist_model(input_ids, labels) + + if ( + tr_loss_step is not None + and self.gradient_accumulation_steps > 1 + ): + tr_loss_step = np.sum(tr_loss_step) + tr_loss_step /= self.gradient_accumulation_steps + + if tr_loss_step: + tr_loss += tr_loss_step + + print(f"step: {step} loss: {np.array(tr_loss)}") + lr_scheduler.step() + tr_loss = float(0) + + if step >= 10: + break + + def run_test_cases(self): + self.run_llama(to_static=0) + self.run_llama(to_static=1) + + +if __name__ == '__main__': + TestLlamaAuto().run_test_cases() diff --git a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py index cd9c29b4a7d4df..c729fc86e67228 100644 --- a/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py +++ b/test/auto_parallel/hybrid_strategy/semi_auto_parallel_llama_model.py @@ -35,10 +35,12 @@ def set_global_mesh(mesh): _global_mesh = mesh -def get_mesh(pp_idx=0): +def get_mesh(pp_idx=None): global _global_mesh mesh = _global_mesh assert _global_mesh is not None, "_global_mesh is not initialized!" + if pp_idx is None: + return mesh if "pp" in _global_mesh.dim_names: mesh = _global_mesh.get_mesh_with_dim("pp")[pp_idx] return mesh @@ -483,7 +485,7 @@ def __init__(self, config): ) self.embed_tokens.weight = dist.shard_tensor( self.embed_tokens.weight, - get_mesh(), + get_mesh(0), [dist.Replicate(), dist.Shard(1)], ) @@ -517,7 +519,7 @@ def get_layer_ipp(layer_index): @staticmethod def _prepare_decoder_attention_mask( - attention_mask, input_shape, past_key_values_length, dtype + attention_mask, input_shape, past_key_values_length, dtype, mesh ): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -533,7 +535,7 @@ def _prepare_decoder_attention_mask( ) combined_attention_mask = dist.shard_tensor( combined_attention_mask, - get_mesh(), + mesh, [dist.Replicate(), dist.Replicate()], ) expanded_attn_mask = ( @@ -577,6 +579,14 @@ def forward( use_cache if use_cache is not None else self.config.use_cache ) + if ( + not paddle.in_dynamic_mode() + and getattr(self.config, "virtual_pp_degree", 1) > 1 + ): + # NOTE: temprorary method to guarantee the later ops are placed on all ranks until meeting new annotaion. + full = dist.shard_op(paddle.full, get_mesh()) + full(shape=[1], fill_value=0) + # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( @@ -600,9 +610,32 @@ def forward( cache_length = paddle.shape(past_key_values[0][0])[1] seq_length_with_past += cache_length + if ( + not paddle.in_dynamic_mode() + and getattr(self.config, "virtual_pp_degree", 1) > 1 + ): + # NOTE: temprorary method to guarantee the later ops are placed on pp stage 0 until meeting new annotaion. + full = dist.shard_op(paddle.full, get_mesh(0)) + full(shape=[1], fill_value=0) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + if self.config.sequence_parallel: + # [B, S, H] -> [S, B, H] + inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) + + if ( + not paddle.in_dynamic_mode() + and getattr(self.config, "virtual_pp_degree", 1) > 1 + ): + # NOTE: temprorary method to guarantee the later ops are placed on all ranks until meeting new annotaion. + full = dist.shard_op(paddle.full, get_mesh()) + full(shape=[1], fill_value=0) + mesh = get_mesh() + else: + mesh = get_mesh(0) + # embed positions if attention_mask is None: # [bs, seq_len] @@ -615,25 +648,24 @@ def forward( (batch_size, seq_length) ) position_ids = dist.shard_tensor( - position_ids, get_mesh(), [dist.Replicate(), dist.Replicate()] + position_ids, mesh, [dist.Replicate(), dist.Replicate()] ) - if self.config.sequence_parallel: - # [B, S, H] -> [S, B, H] - inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2]) - attention_mask = self._prepare_decoder_attention_mask( attention_mask, (batch_size, seq_length), cache_length, inputs_embeds.dtype, + mesh, ) # [bs, 1, seq_len, seq_len] if self.config.use_flash_attention: is_casual = is_casual_mask(attention_mask) if is_casual: attention_mask = None hidden_states = inputs_embeds - hidden_states = dist.reshard(hidden_states, get_mesh(), self.placements) + hidden_states = dist.reshard( + hidden_states, get_mesh(0), self.placements + ) # decoder layers all_hidden_states = () if output_hidden_states else None @@ -649,23 +681,35 @@ def forward( ) has_gradient = not hidden_states.stop_gradient - - if decoder_layer.ipp is not None and pre_ipp != decoder_layer.ipp: - hidden_states = dist.reshard( - hidden_states, - get_mesh(decoder_layer.ipp), - self.placements, - ) - position_ids = dist.reshard( - position_ids, - get_mesh(decoder_layer.ipp), - [dist.Shard(0), dist.Replicate()], - ) - attention_mask = dist.reshard( - attention_mask, - get_mesh(decoder_layer.ipp), - [dist.Shard(0), dist.Replicate()], - ) + ipp = decoder_layer.ipp + + if ipp is not None and pre_ipp != ipp: + if ( + not paddle.in_dynamic_mode() + and getattr(self.config, "virtual_pp_degree", 1) > 1 + ): + hidden_states = dist.reshard( + hidden_states, + get_mesh(ipp), + self.placements, + ) + decoder_layer = dist.shard_op(decoder_layer, get_mesh(ipp)) + else: + hidden_states = dist.reshard( + hidden_states, + get_mesh(ipp), + self.placements, + ) + position_ids = dist.reshard( + position_ids, + get_mesh(ipp), + [dist.Shard(0), dist.Replicate()], + ) + attention_mask = dist.reshard( + attention_mask, + get_mesh(ipp), + [dist.Shard(0), dist.Replicate()], + ) if ( self.config.recompute @@ -690,7 +734,7 @@ def forward( past_key_value, use_cache, ) - pre_ipp = decoder_layer.ipp + pre_ipp = ipp if type(layer_outputs) is tuple: hidden_states = layer_outputs[0] diff --git a/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model_vpp.py b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model_vpp.py new file mode 100644 index 00000000000000..447b4c9705497c --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/test_semi_auto_parallel_llama_model_vpp.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import unittest + +sys.path.append("../../") +import collective.test_communication_api_base as test_base + + +class TestSemiAutoParallelLlama3DVPP(test_base.CommunicationTestDistBase): + def setUp(self): + super().setUp(num_of_devices=8, timeout=200, nnode=1) + self._default_envs = { + "seed": "2023", + "dp": "2", + "mp": "2", + "pp": "2", + "FLAGS_embedding_deterministic": "1", + "FLAGS_cudnn_deterministic": "1", + "acc_step": "2", + "only_static": "true", + } + self._changeable_envs = { + "backend": ["gpu"], + "use_sp": ["true"], + "use_param_group": ["true"], + "recompute": ["true"], + "recompute_granularity": ["full"], + "virtual_pp_degree": ["2"], + } + + def test_simple_net_hybrid_strategy(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + self._log_dir.name = "./vpp_log" + for envs in envs_list: + self.run_test_case( + "semi_auto_llama_pp_gradmerge.py", + user_defined_envs=envs, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/auto_parallel/semi_auto_parallel_dist_to_static_api.py b/test/auto_parallel/semi_auto_parallel_dist_to_static_api.py index 886e68f27b676d..fd6ec758086d97 100644 --- a/test/auto_parallel/semi_auto_parallel_dist_to_static_api.py +++ b/test/auto_parallel/semi_auto_parallel_dist_to_static_api.py @@ -120,6 +120,7 @@ def get_program_test(self, dist_model): self.assertNotEqual(main_program, None) self.assertNotEqual(startup_program, None) + dist_model.eval() main_program = dist_model.dist_main_program("eval") startup_program = dist_model.dist_startup_program("eval") self.assertNotEqual(main_program, None) @@ -212,12 +213,12 @@ def run_test(self): dist_model._engine._has_prepared["train"] = False dist_model._engine._has_prepared["eval"] = False dist_model._engine._has_prepared["predict"] = False - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): dist_model.train() - with self.assertRaises(RuntimeError): + with self.assertRaises(TypeError): dist_model.eval() - with self.assertRaises(RuntimeError): - dist_model.predict() + # with self.assertRaises(TypeError): + dist_model.predict() dist_model._engine._has_prepared["train"] = True dist_model._engine._has_prepared["eval"] = True dist_model._engine._has_prepared["predict"] = True