From 0dee0b45c736bdd1e4e333ef87e59ca8d1b26799 Mon Sep 17 00:00:00 2001
From: MengqingCao
Date: Thu, 17 Oct 2024 12:53:02 +0000
Subject: [PATCH] fix device check

---
 src/open_clip_train/distributed.py | 2 +-
 src/open_clip_train/main.py        | 2 +-
 src/open_clip_train/params.py      | 3 +++
 src/open_clip_train/profiler.py    | 6 +++---
 4 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/open_clip_train/distributed.py b/src/open_clip_train/distributed.py
index dd18cb643..1438c9f5f 100644
--- a/src/open_clip_train/distributed.py
+++ b/src/open_clip_train/distributed.py
@@ -107,7 +107,7 @@ def init_distributed_device(args):
         else:
             device = 'cuda:0'
         torch.cuda.set_device(device)
-    elif torch.npu.is_available():
+    elif args.device == "npu" and torch.npu.is_available():
         if args.distributed and not args.no_set_device_rank:
             device = 'npu:%d' % args.local_rank
         else:
             device = 'npu:0'
diff --git a/src/open_clip_train/main.py b/src/open_clip_train/main.py
index ff1cf2003..1089ade53 100644
--- a/src/open_clip_train/main.py
+++ b/src/open_clip_train/main.py
@@ -329,7 +329,7 @@ def main(args):
             hvd.broadcast_optimizer_state(optimizer, root_rank=0)
 
     if args.precision == "amp":
-        if torch.npu.is_available():
+        if args.device == "npu" and torch.npu.is_available():
             from torch.npu.amp import GradScaler
         else:
             from torch.cuda.amp import GradScaler
diff --git a/src/open_clip_train/params.py b/src/open_clip_train/params.py
index 829b63817..c5b7804f2 100644
--- a/src/open_clip_train/params.py
+++ b/src/open_clip_train/params.py
@@ -306,6 +306,9 @@ def parse_args(args):
     parser.add_argument(
         "--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps."
     )
+    parser.add_argument(
+        "--device", default="cuda", type=str, choices=["cpu", "cuda", "npu"], help="Accelerator to use."
+    )
     # arguments for distributed training
     parser.add_argument(
         "--dist-url",
diff --git a/src/open_clip_train/profiler.py b/src/open_clip_train/profiler.py
index cd2a588d5..17c302201 100644
--- a/src/open_clip_train/profiler.py
+++ b/src/open_clip_train/profiler.py
@@ -125,7 +125,7 @@ def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_
 def count_params(model):
     return sum(m.numel() for m in model.parameters())
 
-def profile_model(model_name, batch_size=1, profiler='torch'):
+def profile_model(model_name, batch_size=1, profiler='torch', device="cuda"):
     assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
     if profiler == 'fvcore':
         assert fvcore is not None, 'Please install fvcore.'
@@ -133,7 +133,7 @@ def profile_model(model_name, batch_size=1, profiler='torch'):
     model.eval()
     if torch.cuda.is_available():
         model = model.cuda()
-    elif torch.npu.is_available():
+    elif device == "npu" and torch.npu.is_available():
         model = model.npu()
 
     if isinstance(model.visual.image_size, (tuple, list)):
@@ -219,7 +219,7 @@ def main():
         print('='*100)
         print(f'Profiling {m}')
         try:
-            row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
+            row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler, device=args.device)
             results.append(row)
         except Exception as e:
             print(f'Error profiling {m}: {e}')