-
Notifications
You must be signed in to change notification settings - Fork 4.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use device_name instead of device index to support other device #3933
Conversation
@hipudding, can you please share more details on the failure this is supposed to fix? Each accelerator can implement |
0272639
to
ce41575
Compare
Create tensor with device=Integer will always choose cuda as its deivce for current pytorch version(2.1), other device should use device={device}:{index}. Change get_accelerator().current_device() to get_accelerator().current_device_name() to support other devices.
Yes. For example, I'm using NPU as the backend and want to create a empty tensor.
For every accelerator, current_device() will return the current index of the backend. Suppose we are using npu:1, current_device() will return Integer 1. Then, the code above equals to:
But pytorch will use Cuda as it's backend if device is a Integer. If these code want to work with every backend, it should specify the device name, so I changed current_device() to current_device_name(), which will return a device name and its index.
==
|
@hipudding, I see your point. I agree that this quite an incovenience of torch, but I was suggesting that rather than changing deepspeed code, you could follow xpu_accelerator implementation. That is working without needing this PR. |
On second thought, perhaps I should confirm what xpu_accelerator is actually doing. @delock, how do you avoid the problem solved by this PR? Thanks. |
I see these codes are introduced from zero++ (#3784), it should be misuse of |
@delock, thanks for the explanation. That makes sense. |
Create tensor with device=Integer will always choose cuda as its deivce for current pytorch version(2.1), other device should use device={device}:{index}.
Change get_accelerator().current_device() to
get_accelerator().current_device_name() to support other devices.