Skip to content

Commit

Permalink
fix device check
Browse files: browse the repository at this point in the history
  • Loading branch information
MengqingCao committed Oct 17, 2024
1 parent b92a266 commit 0dee0b4
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/open_clip_train/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def init_distributed_device(args):
else:
device = 'cuda:0'
torch.cuda.set_device(device)
elif torch.npu.is_available():
elif args.device == "npu" and torch.npu.is_available():
if args.distributed and not args.no_set_device_rank:
device = 'npu:%d' % args.local_rank
else:
Expand Down
2 changes: 1 addition & 1 deletion src/open_clip_train/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def main(args):
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

if args.precision == "amp":
if torch.npu.is_available():
if args.device == "npu" and torch.npu.is_available():
from torch.npu.amp import GradScaler
else:
from torch.cuda.amp import GradScaler
Expand Down
3 changes: 3 additions & 0 deletions src/open_clip_train/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ def parse_args(args):
parser.add_argument(
"--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps."
)
parser.add_argument(
"--device", default="cuda", type=str, choices=["cpu", "cuda", "npu"], help="Accelerator to use."
)
# arguments for distributed training
parser.add_argument(
"--dist-url",
Expand Down
6 changes: 3 additions & 3 deletions src/open_clip_train/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ def profile_torch(model, text_input_size, image_input_size, batch_size=1, force_
def count_params(model):
return sum(m.numel() for m in model.parameters())

def profile_model(model_name, batch_size=1, profiler='torch'):
def profile_model(model_name, batch_size=1, profiler='torch', device="cuda"):
assert profiler in ['torch', 'fvcore'], 'Only torch and fvcore profilers are supported'
if profiler == 'fvcore':
assert fvcore is not None, 'Please install fvcore.'
model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
model.eval()
if torch.cuda.is_available():
model = model.cuda()
elif torch.npu.is_available():
elif device == "npu" and torch.npu.is_available():
model = model.npu()

if isinstance(model.visual.image_size, (tuple, list)):
Expand Down Expand Up @@ -219,7 +219,7 @@ def main():
print('='*100)
print(f'Profiling {m}')
try:
row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler)
row = profile_model(m, batch_size=args.batch_size, profiler=args.profiler, device=args.device)
results.append(row)
except Exception as e:
print(f'Error profiling {m}: {e}')
Expand Down

0 comments on commit 0dee0b4

Please sign in to comment.