[Bug]: error when training using mps #3506

Open
joprice opened this issue Jul 23, 2024 · 1 comment
Labels
bug Something isn't working

Comments

joprice commented Jul 23, 2024

Describe the bug

When setting flair.device to mps, the following error is thrown during training:

RuntimeError: User specified an unsupported autocast device_type 'mps'

To Reproduce

import flair
import torch

flair.device = torch.device("mps")
# ... build and train model
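To take Flair out of the picture, the failing call can probably be reproduced on its own. The traceback in the second comment shows the trainer entering `torch.autocast(device_type=flair.device.type, enabled=use_amp)`; the sketch below assumes that is the same call hit here and wraps it in a hypothetical helper, `try_autocast` (not part of Flair or PyTorch), which returns the error text instead of raising:

```python
from typing import Optional

import torch


def try_autocast(device_type: str, use_amp: bool = False) -> Optional[str]:
    """Hypothetical helper: enter torch.autocast the way Flair's trainer does.

    Returns the RuntimeError message if the context cannot be entered,
    otherwise None.
    """
    try:
        # Mirrors the `with torch.autocast(...)` line in flair/trainers/trainer.py
        # from the traceback in the second comment; AMP is disabled by default.
        with torch.autocast(device_type=device_type, enabled=use_amp):
            pass
    except RuntimeError as err:
        return str(err)
    return None
```

`try_autocast("cpu")` should return `None`. On the reporter's setup (PyTorch 2.3.1), `try_autocast("mps")` would presumably return the "unsupported autocast device_type 'mps'" message quoted above, even though `enabled` is `False`.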

Expected behavior

PyTorch's MPS support should be usable for training through Flair.

Logs and Stack traces

No response

Screenshots

No response

Additional Context

No response

Environment

Versions:

Flair: 0.13.1
PyTorch: 2.3.1
Transformers: 4.42.4
GPU: False

joprice added the bug label Jul 23, 2024
BoilerToad commented

I got past that error by upgrading to torch 2.5.0 and transformers 4.43.3. With Flair 0.13.1 or 0.14.0, training a model then fails with the following error:

Traceback (most recent call last):
  File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 165, in <module>
    main()
  File "/Users/xxxxxx/Dev/flairNLP/train-models/train_model.py", line 138, in main
    trainer.train(
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/flair/trainers/trainer.py", line 200, in train
    return self.train_custom(**local_variables, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/flair/trainers/trainer.py", line 600, in train_custom
    with torch.autocast(device_type=flair.device.type, enabled=use_amp):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxxx/VirtualEnvs/venv-flair-311/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 230, in __init__
    dtype = torch.get_autocast_dtype(device_type)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: unsupported scalarType
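One possible local workaround, sketched under the assumption that mixed precision isn't needed on MPS, is to avoid constructing `torch.autocast` at all when `use_amp` is false. In both reports the error appears to come from the autocast constructor (`__init__` in the traceback above), which seems to run its device-type and dtype checks before `enabled` has any effect. `maybe_autocast` is a hypothetical helper, not part of Flair or PyTorch:

```python
from contextlib import AbstractContextManager, nullcontext


def maybe_autocast(device_type: str, use_amp: bool) -> AbstractContextManager:
    """Hypothetical helper: return torch.autocast only when AMP is requested.

    When use_amp is False, a no-op nullcontext is returned, so torch.autocast's
    constructor (where the errors in this issue are raised) never runs.
    """
    if not use_amp:
        return nullcontext()
    # Imported here so the AMP-disabled path does not touch torch at all.
    import torch

    # With AMP requested, this still needs a PyTorch build whose autocast
    # accepts the given device type.
    return torch.autocast(device_type=device_type)
```

In `train_custom`, the `with` line would then read `with maybe_autocast(flair.device.type, use_amp):`. With `use_amp=True` on MPS the call still goes through `torch.autocast`, so this only helps when mixed precision is off.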

I validated that my Torch install works and accepts "mps" using the sample code from https://github.com/mrdbourke/pytorch-apple-silicon, which prints:

Is MPS (Metal Performance Shader) built? True
Is MPS available? True
Using device: mps
The model in that sample also works as expected on mps.
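A check along those lines might look like the sketch below. It is an assumption about what the linked sample does, not a copy of it; it uses `torch.backends.mps.is_built()` and `torch.backends.mps.is_available()` and prints lines in the same format as the output above. `pick_device` is a hypothetical name:

```python
import torch


def pick_device() -> torch.device:
    """Hypothetical helper: report MPS status and prefer MPS when it is usable."""
    # is_built(): this PyTorch binary was compiled with MPS support.
    # is_available(): MPS can actually be used on this machine right now.
    print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
    print(f"Is MPS available? {torch.backends.mps.is_available()}")
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"Using device: {device}")
    return device
```

A check like this only exercises ordinary MPS execution. It never calls `torch.autocast(device_type="mps")`, which is where the trainer fails in the traceback above, so a passing check here is consistent with the bug still occurring in Flair.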
