fix cpu infer device_map (#2103)
Jintao-Huang committed Sep 23, 2024
1 parent 4bd62ea commit 3b82b5d
Showing 3 changed files with 39 additions and 28 deletions.
swift/llm/infer.py (15 changes: 7 additions & 8 deletions)
@@ -135,16 +135,15 @@ def prepare_model_template(args: InferArguments,
                           device_map: Optional[str] = None,
                           verbose: bool = True,
                           automodel_class=None) -> Tuple[PreTrainedModel, Template]:

    model_kwargs = {}
    from .sft import get_default_device_map
    if is_torch_npu_available():
        logger.info(f'device_count: {torch.npu.device_count()}')
        if device_map is None:
            device_map = 'npu:0'
        print(f'device_count: {torch.npu.device_count()}')
    else:
        logger.info(f'device_count: {torch.cuda.device_count()}')
        if device_map is None:
            device_map = 'auto' if torch.cuda.device_count() > 1 else 'cuda:0'
        print(f'device_count: {torch.cuda.device_count()}')
    model_kwargs = {}
    if device_map is not None:
        device_map = get_default_device_map()
    model_kwargs['device_map'] = device_map
    if device_map == 'auto':
        model_kwargs['low_cpu_mem_usage'] = True
    model_kwargs['device_map'] = device_map
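
Note: the inference path now defers its default placement to get_default_device_map() (added to swift/llm/sft.py below), which returns 'cpu' when no CUDA or NPU device is visible, instead of the previous hard-coded 'cuda:0' / 'auto' defaults that do not work on CPU-only hosts. A rough usage sketch, not part of the diff, of what the resulting kwargs look like on a CPU-only machine:

    from swift.llm.sft import get_default_device_map  # helper introduced by this commit

    device_map = get_default_device_map()     # -> 'cpu' when torch.cuda.device_count() == 0
    model_kwargs = {'device_map': device_map}
    if device_map == 'auto':                  # only when several GPUs are visible to one process
        model_kwargs['low_cpu_mem_usage'] = True
    print(model_kwargs)                       # {'device_map': 'cpu'} on a CPU-only host
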
swift/llm/sft.py (48 changes: 30 additions & 18 deletions)
@@ -15,8 +15,8 @@
from swift.torchacc_utils import patch_acc_model
from swift.trainers import TrainerFactory
from swift.trainers.utils import can_return_loss, find_labels
from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_logger,
                         get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images,
from swift.utils import (append_to_jsonl, check_json_format, compute_acc_metrics, compute_nlg_metrics, get_dist_setting,
                         get_logger, get_main, get_model_info, is_ddp_plus_mp, is_dist, is_master, plot_images,
                         preprocess_logits_for_metrics, seed_everything, show_layers, use_torchacc)
from .accelerator import ta_accelerate
from .tuner import prepare_model
@@ -114,6 +114,25 @@ def llm_sft_megatron(args: SftArguments) -> Dict[str, Any]:
    return {}


def get_default_device_map():
    if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
        return None
    local_rank = get_dist_setting()[1]
    if is_torch_npu_available():
        if local_rank >= 0:
            return f'npu:{local_rank}'
        else:
            return 'npu:0'
    if torch.cuda.device_count() == 0:
        return 'cpu'
    elif torch.cuda.device_count() == 1:
        return 'cuda:0'
    elif is_dist() and not is_ddp_plus_mp():
        return f'cuda:{local_rank}'
    else:
        return 'auto'
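
For readability, the fallback order above can be restated as a small pure function with the environment probes made explicit; pick_device_map and its parameter names are illustrative only, not part of the swift API:

    from typing import Optional

    def pick_device_map(zero3_or_fsdp: bool, npu_available: bool, local_rank: int,
                        cuda_count: int, dist_without_mp: bool) -> Optional[str]:
        if zero3_or_fsdp:
            return None                   # let DeepSpeed ZeRO-3 / FSDP handle placement
        if npu_available:
            return f'npu:{max(local_rank, 0)}'
        if cuda_count == 0:
            return 'cpu'                  # the CPU-only case this commit fixes
        if cuda_count == 1:
            return 'cuda:0'
        if dist_without_mp:
            return f'cuda:{local_rank}'   # plain DDP: bind each rank to its local GPU
        return 'auto'                     # several GPUs, one process: automatic placement

    # e.g. a CPU-only host with no distributed setup:
    assert pick_device_map(False, False, -1, 0, False) == 'cpu'
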


def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):

    if args.gpu_memory_fraction is not None:
@@ -128,21 +147,15 @@ def prepare_model_template_train(args, msg: Optional[Dict[str, Any]] = None):
                f'world_size: {args.world_size}, local_world_size: {args.local_world_size}')

    # Loading Model and Tokenizer
    if is_deepspeed_zero3_enabled() or os.environ.get('ACCELERATE_USE_FSDP', 'False') == 'true':
        model_kwargs = {'device_map': None}
    elif is_torch_npu_available():
        model_kwargs = {'device_map': args.local_rank if args.local_rank >= 0 else 0}
    elif args.device_map_config is not None:
        model_kwargs = {'device_map': args.device_map_config}
    else:
        model_kwargs = {'low_cpu_mem_usage': True}
        if is_dist() and not is_ddp_plus_mp():
            model_kwargs['device_map'] = {'': args.local_rank}
        elif torch.cuda.device_count() == 1:
            model_kwargs['device_map'] = 'cuda:0'
        elif not use_torchacc():
            model_kwargs['device_map'] = 'auto'

    model_kwargs = {}
    if not use_torchacc():
        if args.device_map_config is not None:
            device_map = args.device_map_config
        else:
            device_map = get_default_device_map()
        model_kwargs['device_map'] = device_map
        if device_map == 'auto':
            model_kwargs['low_cpu_mem_usage'] = True
    if args.device_max_memory:
        n_gpu = torch.cuda.device_count()
        assert len(args.device_max_memory) == n_gpu // args.local_world_size
@@ -354,7 +367,6 @@ def prepare_dataset(args, template: Template, msg: Optional[Dict[str, Any]] = None
                    f'Setting args.preprocess_num_proc to: {args.preprocess_num_proc}')
    else:
        template.model = None
    logger.info(f'Using num_proc: {args.preprocess_num_proc}')
    td0, tkwargs0 = template.encode(train_dataset[0])
    print_example(td0, tokenizer, tkwargs0)
    train_dataset = dataset_map(train_dataset, template.encode, args.preprocess_num_proc, streaming=args.streaming)
swift/llm/utils/utils.py (4 changes: 2 additions & 2 deletions)

@@ -299,7 +299,7 @@ def _map_mp(dataset: HfDataset, map_func: MapFunc, num_proc: int) -> List[Dict[s
    # Solving the unordered problem
    data = [None] * len(dataset)
    num_proc = min(num_proc, len(dataset))
    for d in tqdm(_map_mp_i(dataset, map_func, num_proc), total=len(dataset)):
    for d in tqdm(_map_mp_i(dataset, map_func, num_proc), total=len(dataset), desc=f'Map (num_proc={num_proc})'):
        data[d[0]] = d[1]
    return data

@@ -314,7 +314,7 @@ def dataset_map(dataset: DATASET_TYPE,
    single_map = partial(_single_map, map_func=map_func)
    if num_proc == 1:
        data = []
        for d in tqdm(dataset):
        for d in tqdm(dataset, desc='Map'):
            d = single_map(d)
            data.append(d)
    else:
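
The utils.py changes only label the progress bars: tqdm's desc argument prefixes the bar, so dataset mapping now reports whether it runs single-process ('Map') or with a worker pool ('Map (num_proc=N)'). A standalone illustration, not swift code:

    from tqdm import tqdm

    for _ in tqdm(range(1000), desc='Map (num_proc=4)'):
        pass
    # renders roughly as: Map (num_proc=4): 100%|##########| 1000/1000 [...]
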
