Skip to content

Commit

Permalink
Better messaging around PyTorch workaround (#509)
Browse files Browse the repository at this point in the history
Priority is compatibility w/ PyTorch; left a link for folks to learn more if they want to customize.
  • Loading branch information
sdatkinson authored Nov 24, 2024
1 parent 9a1c72e commit f2c3ff9
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f environments/requirements.txt ]; then pip install -r environments/requirements.txt; fi
python -m pip install .
- name: Lint with flake8
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# Created Date: Saturday February 13th 2021
# Author: Steven Atkinson ([email protected])

# Environment for CPU and macOS (Intel and Apple Silicon)

name: nam
channels:
- conda-forge # pytest-mock
Expand All @@ -19,6 +21,8 @@ dependencies:
- pydantic
- pytest
- pytest-mock
# Performance note:
# https://github.com/sdatkinson/neural-amp-modeler/issues/505
- pytorch
- scipy
- semver
Expand Down
File renamed without changes.
5 changes: 3 additions & 2 deletions requirements.txt → environments/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ pytest-mock
pytorch_lightning
scipy
sounddevice
# Performance note: https://github.com/sdatkinson/neural-amp-modeler/issues/505
torch
# Not required, but if you have it, it needs to be recent enough so I'm adding
# it.
# `transformers` is not required, but if you have it, it needs to be recent
# enough so I'm adding it.
transformers>=4
tqdm
wavio
Expand Down
11 changes: 11 additions & 0 deletions nam/models/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,17 @@ def _forward_mps_safe(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
return self._forward(x, **kwargs)
except NotImplementedError as e:
if "Output channels > 65536 not supported at the MPS device." in str(e):
print(
"===WARNING===\n"
"NAM encountered a bug in PyTorch's MPS backend and will "
"switch to a fallback.\n"
f"Your version of PyTorch is {torch.__version__}.\n"
"Please report this in an Issue at:\n"
"https://github.com/sdatkinson/neural-amp-modeler/issues/new/choose"
"\n"
"so that NAM's dependencies can avoid buggy versions of "
"PyTorch and the associated performance hit."
)
self._mps_65536_fallback = True
return self._forward_mps_safe(x, **kwargs)
else:
Expand Down
2 changes: 1 addition & 1 deletion nam/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Author: Steven Atkinson ([email protected])

"""
Implements the base PyTorch Lightning model.
Implements the base PyTorch Lightning module.
This is meant to combine an actual model (subclassed from `._base.BaseNet`)
along with loss function boilerplate.
Expand Down

0 comments on commit f2c3ff9

Please sign in to comment.