
MPS (Mac M1) device support #13102

Closed
carmocca opened this issue May 18, 2022 · 15 comments · Fixed by #13642
carmocca (Contributor) commented May 18, 2022:

🚀 Feature

https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/
Docs: https://pytorch.org/docs/master/notes/mps.html
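A quick way to try the new backend is the snippet below — a minimal sketch that probes MPS availability defensively and falls back to CPU on machines without it:

```python
import torch

# Probe MPS support defensively: older PyTorch builds have no
# `torch.backends.mps` attribute at all, so guard with getattr.
mps = getattr(torch.backends, "mps", None)
use_mps = mps is not None and torch.backends.mps.is_available()

# Fall back to CPU so the same script runs on any machine.
device = torch.device("mps" if use_mps else "cpu")

x = torch.ones(3, device=device)
print(device.type, x.sum().item())  # three ones sum to 3.0
```

Note that `is_available()` requires both an MPS-enabled PyTorch build (nightly at the time of writing) and Apple Silicon hardware on macOS 12.3+.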


If you enjoy Lightning, check out our other projects! ⚡

  • Metrics: Machine learning metrics for distributed, scalable PyTorch applications.

  • Lite: enables pure PyTorch users to scale their existing code on any kind of device while retaining full control over their own loops and optimization logic.

  • Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, fine-tuning, and solving problems with deep learning.

  • Bolts: Pretrained SOTA Deep Learning models, callbacks, and more for research and production with PyTorch Lightning and PyTorch.

  • Lightning Transformers: Flexible interface for high-performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.

cc @Borda @akihironitta @rohitgr7 @justusschock

@carmocca carmocca added feature Is an improvement or enhancement accelerator labels May 18, 2022
@carmocca carmocca added this to the 1.7 milestone May 18, 2022
carmocca (Contributor, Author) commented May 18, 2022:

I heard @awaelchli just got an M1 😈

awaelchli (Contributor) commented:

@justusschock also has one if I remember correctly 😃

justusschock (Member) commented:

I'll investigate on the weekend

carmocca (Contributor, Author) commented:

After a bit of offline discussion, we thought about this API:

  • The MPSAccelerator is introduced.
  • The GPUAccelerator is deprecated in favor of the new CUDAAccelerator.
  • Two new accelerator options: Trainer(accelerator="cuda"|"mps") to explicitly choose one of the above.
  • Trainer(accelerator="gpu") now chooses from the 2 accelerators above based on the available hardware. (side note: this is how we would also support ROCm).
  • Trainer(accelerator="auto") is different in that it also considers IPUs, TPUs, HPUs...

This shouldn't be a problem because a machine cannot have both accelerators available.
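The proposed resolution logic can be sketched as plain Python — a hypothetical helper for illustration only, not the actual Lightning implementation (availability flags are passed in so the dispatch is explicit):

```python
# Hypothetical sketch of the proposed resolution logic for
# Trainer(accelerator=...). Names are illustrative.
def resolve_accelerator(flag: str, cuda_available: bool, mps_available: bool) -> str:
    """Map the user-facing accelerator flag to a concrete backend."""
    if flag in ("cuda", "mps"):
        # Explicit choice: use exactly what the user asked for.
        return flag
    if flag == "gpu":
        # "gpu" picks whichever GPU backend the machine actually has;
        # a machine cannot expose both CUDA and MPS at once.
        if cuda_available:
            return "cuda"
        if mps_available:
            return "mps"
        raise RuntimeError("accelerator='gpu' requested but no GPU found")
    raise ValueError(f"unknown accelerator: {flag!r}")

print(resolve_accelerator("gpu", cuda_available=False, mps_available=True))  # -> mps
```

The same pattern extends naturally to ROCm, which also surfaces through the CUDA device type in PyTorch.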

eubinecto commented:

I was just about to open an issue for this myself! Really rooting for it.

ananthsub (Contributor) commented:

I'm glad we did the accelerator refactor to make supporting new features like this much easier :)
"GPU" is ambiguous, whereas the torch.device type is not: #10410 (comment)

scalastic commented May 29, 2022:

Note that, on an M1 Max, with a basic example like the LitAutoEncoder on MNIST (as described on the PyTorch Lightning homepage), I get better throughput without the MPS settings:

Results:
  • With accelerator="mps":
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
...
Epoch 0: | 7936/55000 [01:22<08:06, **96.77it/s**, loss=0.00155, v_num=0]
  • Without "mps" settings:
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
...
Epoch 0:   | 26081/55000 [00:52<00:57, **498.76it/s**, loss=0.0416, v_num=1]
Environment:

My configuration (numpy is installed with blas=*=*accelerate* in both cases):

% conda list
# packages in environment at /opt/homebrew/Caskroom/miniforge/base/envs/pytorch-lightning-m1:
#
# Name                    Version                   Build  Channel
absl-py                   1.0.0                    pypi_0    pypi
aiohttp                   3.8.1                    pypi_0    pypi
aiosignal                 1.2.0                    pypi_0    pypi
async-timeout             4.0.2                    pypi_0    pypi
attrs                     21.4.0                   pypi_0    pypi
blas                      2.114                accelerate    conda-forge
blas-devel                3.9.0           14_osxarm64_accelerate    conda-forge
bzip2                     1.0.8                h3422bc3_4    conda-forge
ca-certificates           2022.5.18.1          h4653dfc_0    conda-forge
cachetools                5.1.0                    pypi_0    pypi
certifi                   2022.5.18.1              pypi_0    pypi
charset-normalizer        2.0.12                   pypi_0    pypi
frozenlist                1.3.0                    pypi_0    pypi
fsspec                    2022.2.0                 pypi_0    pypi
google-auth               2.6.6                    pypi_0    pypi
google-auth-oauthlib      0.4.6                    pypi_0    pypi
grpcio                    1.46.3                   pypi_0    pypi
idna                      3.3                      pypi_0    pypi
importlib-metadata        4.11.4                   pypi_0    pypi
libblas                   3.9.0           14_osxarm64_accelerate    conda-forge
libcblas                  3.9.0           14_osxarm64_accelerate    conda-forge
libcxx                    14.0.3               h6a5c8ee_0    conda-forge
libffi                    3.4.2                h3422bc3_5    conda-forge
libgfortran               5.0.0.dev0      11_0_1_hf114ba7_23    conda-forge
libgfortran5              11.0.1.dev0         hf114ba7_23    conda-forge
liblapack                 3.9.0           14_osxarm64_accelerate    conda-forge
liblapacke                3.9.0           14_osxarm64_accelerate    conda-forge
libprotobuf               3.20.1               h332123e_0    conda-forge
libzlib                   1.2.11            h90dfc92_1014    conda-forge
llvm-openmp               14.0.4               hd125106_0    conda-forge
markdown                  3.3.7                    pypi_0    pypi
multidict                 6.0.2                    pypi_0    pypi
ncurses                   6.3                  h07bb92c_1    conda-forge
numpy                     1.22.3                   pypi_0    pypi
oauthlib                  3.2.0                    pypi_0    pypi
openssl                   3.0.3                ha287fd2_0    conda-forge
packaging                 21.3                     pypi_0    pypi
pandas                    1.4.2                    pypi_0    pypi
pillow                    9.1.1                    pypi_0    pypi
pip                       22.1.1             pyhd8ed1ab_0    conda-forge
protobuf                  4.21.0                   pypi_0    pypi
pyasn1                    0.4.8                    pypi_0    pypi
pyasn1-modules            0.2.8                    pypi_0    pypi
pydeprecate               0.3.2                    pypi_0    pypi
pyparsing                 3.0.9                    pypi_0    pypi
python                    3.8.13          hd3575e6_0_cpython    conda-forge
python-dateutil           2.8.2                    pypi_0    pypi
python_abi                3.8                      2_cp38    conda-forge
pytorch-lightning         1.7.0.dev0                dev_0    <develop>
pytz                      2022.1                   pypi_0    pypi
pyyaml                    6.0                      pypi_0    pypi
readline                  8.1                  hedafd6a_0    conda-forge
requests                  2.27.1                   pypi_0    pypi
requests-oauthlib         1.3.1                    pypi_0    pypi
rsa                       4.8                      pypi_0    pypi
setuptools                62.3.2           py38h10201cd_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
sqlite                    3.38.5               h40dfcc0_0    conda-forge
tensorboard               2.8.0                    pypi_0    pypi
tensorboard-data-server   0.6.1                    pypi_0    pypi
tensorboard-plugin-wit    1.8.1                    pypi_0    pypi
tk                        8.6.12               he1e0b03_0    conda-forge
torch                     1.13.0.dev20220525          pypi_0    pypi
torch-tb-profiler         0.4.0                    pypi_0    pypi
torchmetrics              0.7.2                    pypi_0    pypi
torchvision               0.12.0                   pypi_0    pypi
tqdm                      4.63.0                   pypi_0    pypi
typing-extensions         4.1.1                    pypi_0    pypi
urllib3                   1.26.9                   pypi_0    pypi
werkzeug                  2.1.2                    pypi_0    pypi
wheel                     0.37.1             pyhd8ed1ab_0    conda-forge
xz                        5.2.5                h642e427_1    conda-forge
yarl                      1.7.2                    pypi_0    pypi
zipp                      3.8.0                    pypi_0    pypi
zlib                      1.2.11            h90dfc92_1014    conda-forge

awaelchli (Contributor) commented:

We could think about adding an M1 runner to GitHub Actions, or self-hosting one.

paantya commented May 29, 2022:

Is training with accelerator="mps" slower than without it?

scalastic commented:

@paantya In my case, yes it is.

  • With the mps accelerator: 96.77it/s, GPU at ~100%, CPU at ~30%
  • Without mps: 498.76it/s, GPU at ~25%, CPU at ~30%
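One plausible explanation for numbers like these is per-op dispatch overhead: small tensors can run slower on MPS than on CPU because each kernel launch costs more than the compute it saves. The snippet below is an illustrative micro-benchmark sketch, not a rigorous measurement, that runs the comparison only when MPS is actually available:

```python
import time
import torch

def time_matmul(device: torch.device, n: int = 256, iters: int = 50) -> float:
    """Time `iters` small matmuls on `device`, forcing completion at the end."""
    a = torch.randn(n, n, device=device)
    b = torch.randn(n, n, device=device)
    start = time.perf_counter()
    for _ in range(iters):
        c = a @ b
    # Reading a value back forces any asynchronous device work to finish
    # before the clock stops.
    _ = c.sum().item()
    return time.perf_counter() - start

cpu_t = time_matmul(torch.device("cpu"))
print(f"cpu: {cpu_t:.4f}s")
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    print(f"mps: {time_matmul(torch.device('mps')):.4f}s")
```

With larger tensors and heavier models the balance typically tips the other way, so results like the ones above are workload-dependent.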

babaniyi commented May 31, 2022:

@carmocca How do I resolve pytorch_lightning not working on my Mac M1, even though torch and torchvision work? After installing it, running import pytorch_lightning as pl gives the error below:

TypeError: Descriptors cannot not be created directly.
If this call came from a _pb2.py file, your generated code is out of date and must be regenerated with protoc >= 3.19.0.
If you cannot immediately regenerate your protos, some other possible workarounds are:
 1. Downgrade the protobuf package to 3.20.x or lower.
 2. Set PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python (but this will use pure-Python parsing and will be much slower).

More information: https://developers.google.com/protocol-buffers/docs/news/2022-05-06#python-updates
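For reference, workaround 2 from that message can be applied from Python before the first import — a sketch that trades the faster C++ parser for the pure-Python one:

```python
import os

# Workaround 2 from the protobuf error above: force the pure-Python
# protobuf implementation. This must run before anything imports
# protobuf (i.e. before `import pytorch_lightning`), and parsing
# will be noticeably slower.
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

# import pytorch_lightning as pl  # safe to import after the line above
```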

awaelchli (Contributor) commented:

@babaniyi I think you need to downgrade protobuf to < 4.21.0, e.g., 3.20

scalastic commented:

@babaniyi Here are the commands I use to set up a PyTorch Lightning environment for Apple Silicon.

Run them at the root of the PyTorch Lightning repository, on the mps_accelerator branch:

% conda create -n pytorch-lightning-m1
% conda activate pytorch-lightning-m1
% conda install python=3.8
% pip install --upgrade --force-reinstall --pre torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu
% pip install -r requirements.txt
% pip install -e .
% pip install --pre torchvision --extra-index-url https://download.pytorch.org/whl/nightly/cpu --force-reinstall
% conda install -c conda-forge protobuf=3.20

Hope this helps.

babaniyi commented:

@awaelchli @scalastic
Thanks for your suggestions, they both work.
