MPS Inf/Nan Loss #13285
Hi @gloryVine, I created a separate issue for this :) Thanks for the report. To be honest, I suspect this is an issue with core PyTorch rather than our code. We only do the device mapping, and as long as everything runs on MPS, we are doing our job correctly. Because this is still an experimental feature (on both the PyTorch side and ours), it may well be that some operations do not yet work as expected. It would therefore be very helpful to pin down which operation specifically misbehaves. To do that, it would help to run the network with the Trainer flag |
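Another generic way to find the first module that produces NaN or inf, separate from any Trainer flag, is to register a forward hook on every submodule and record which ones emit non-finite outputs. Below is a minimal sketch. `attach_nonfinite_hooks` is a hypothetical helper and is not part of PyTorch or Lightning. It only inspects module outputs, so functional ops called inside `forward` and the loss itself are not covered.

```python
import torch
from torch import nn


def attach_nonfinite_hooks(model: nn.Module):
    """Record, in call order, the names of submodules whose forward output
    contains NaN or +/-inf. Returns (offenders, handles)."""
    offenders = []
    handles = []
    for name, module in model.named_modules():
        if not name:
            continue  # skip the root module itself
        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(mod, inputs, output, name=name):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                offenders.append(name)
        handles.append(module.register_forward_hook(hook))
    return offenders, handles
```

After a forward pass, `offenders[0]` names the first submodule with a non-finite output. Call `h.remove()` on every handle once you are done debugging.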
@gloryVine Here is where the batch gets fetched (and |
@justusschock I ran the Trainer with the flag; nothing changed, including the output. @carmocca I did the following at the position in your first link: Immediately after |
@carmocca The issue is caused by PyTorch Lightning, not PyTorch. Here is a minimal MNIST example that works with native PyTorch. However, when I refactor the exact same code into the Lightning format, it does not work due to overflows/underflows in the targets.

Native PyTorch version:

```python
import torch
import torchvision.transforms as transforms
import tqdm
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

batch_size = 128
device = torch.device("mps")

network = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(64, 64),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(64, 10),
)
optimizer = torch.optim.Adam(network.parameters(), lr=0.001)
loss_func = torch.nn.CrossEntropyLoss()

transform = transforms.ToTensor()
dataset_train = MNIST("/tmp/data", train=True, download=True, transform=transform)
dataset_test = MNIST("/tmp/data", train=False, download=True, transform=transform)
train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

network.to(device)

for epoch_idx in range(10):
    train_losses = []
    train_accs = []
    network.train()
    for x, y in tqdm.tqdm(train_dataloader):
        # MNIST labels lie in [0, 9]; anything larger means the targets are corrupted
        assert torch.sum(y >= 10) == 0, f"y has more than 10 unique values but got {y}"
        x = x.to(device)
        y = y.to(device)
        logits = network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = loss_func(logits, y)
        loss.backward()
        accuracy = torch.mean((y_pred == y).float())
        optimizer.step()
        optimizer.zero_grad()
        train_losses.append(loss.detach())
        train_accs.append(accuracy)
    print(
        f"Epoch {epoch_idx} train loss: {torch.mean(torch.stack(train_losses))} train acc: {torch.mean(torch.stack(train_accs))}"
    )

    val_losses = []
    val_accs = []
    network.eval()
    with torch.no_grad():  # no gradients are needed for validation
        for x, y in tqdm.tqdm(test_dataloader):
            x = x.to(device)
            y = y.to(device)
            logits = network(x)
            y_pred = torch.argmax(logits, dim=1)
            loss = loss_func(logits, y)
            accuracy = torch.mean((y_pred == y).float())
            val_losses.append(loss.detach())
            val_accs.append(accuracy)
    print(
        f"Epoch {epoch_idx} val loss: {torch.mean(torch.stack(val_losses))} val acc: {torch.mean(torch.stack(val_accs))}"
    )
```

PyTorch Lightning version:

```python
import pytorch_lightning as pl
import torch
import torchvision.transforms as transforms
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            torch.nn.Dropout2d(0.1),
            torch.nn.Flatten(),
            torch.nn.Linear(9216, 128),  # 64 channels * 12 * 12 after pooling
            torch.nn.ReLU(),
            torch.nn.Dropout1d(0.1),
            torch.nn.Linear(128, 10),
        )
        self.loss_func = torch.nn.CrossEntropyLoss()

    def training_step(self, batch, batch_nb):
        x, y = batch
        # I did not use torch.unique() so that you do not need PYTORCH_ENABLE_MPS_FALLBACK=1
        assert torch.sum(y >= 10) == 0, f"y has more than 10 unique values but got {y}"
        logits = self.network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = self.loss_func(logits, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", torch.mean((y_pred == y).float()), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_nb):
        x, y = batch
        print(x.device)
        logits = self.network(x)
        y_pred = torch.argmax(logits, dim=1)
        loss = self.loss_func(logits, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", torch.mean((y_pred == y).float()), prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.network.parameters(), lr=0.001)


batch_size = 128
network = Model()
dataset_train = MNIST(
    "/tmp/data", train=True, download=True, transform=transforms.ToTensor()
)
dataset_test = MNIST(
    "/tmp/data", train=False, download=True, transform=transforms.ToTensor()
)
train_dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False)

trainer = Trainer(
    accelerator="mps",
    logger=False,
    devices=1,
    enable_checkpointing=False,
)
trainer.fit(
    network, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader
)
```

macOS: Monterey 12.5 (21G72)
|
Hey @j0rd1smit, thanks for the MWE and sorry for the inconvenience. I could indeed reproduce the issue on my side, but so far I have no idea what we could be doing that causes overflows. I will investigate, though, and get back to you once I find something! cc @akihironitta @awaelchli, who might have ideas where this could come from |
I was able to solve this issue for myself by making the following change here:

```python
_MPS_DEVICES = ("mps", torch.device("mps:0"))
...
# only request an asynchronous copy when the target is neither CPU nor MPS
if isinstance(data, Tensor) and device not in _CPU_DEVICES and device not in _MPS_DEVICES:
    kwargs["non_blocking"] = True
data_output = data.to(device, **kwargs)
```

I'm happy to make a PR for it. However, I wanted to discuss the solution first because I'm not sure whether it solves the real problem or just one of the symptoms. What do you think, @justusschock? |
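The patch above compares the target against specific device values. Those comparisons may not match a device that is spelled differently, for example `torch.device("mps")` with no index. Checking only the device type avoids that question. Below is a minimal sketch under that assumption. `copy_kwargs` and `_BLOCKING_DEVICE_TYPES` are hypothetical names, and this is not necessarily how Lightning resolved the issue.

```python
from typing import Dict

# Device types whose host-to-device copies stay blocking in this sketch:
# "cpu" mirrors the existing _CPU_DEVICES check, "mps" the proposed workaround.
_BLOCKING_DEVICE_TYPES = ("cpu", "mps")


def copy_kwargs(device: object) -> Dict[str, bool]:
    """Return the extra kwargs for `tensor.to(device, **kwargs)`.

    `device` may be a string such as "mps:0" or a torch.device; both render
    as "<type>" or "<type>:<index>" under str(), so only the type is compared.
    """
    device_type = str(device).split(":", 1)[0]
    if device_type in _BLOCKING_DEVICE_TYPES:
        return {}
    return {"non_blocking": True}
```

It would then be used as `data.to(device, **copy_kwargs(device))`, so asynchronous copies are requested only for device types outside the blocking list, such as CUDA.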
@j0rd1smit Great finding! A quick search led me to pytorch/pytorch#83015, and I've just confirmed that, with |
@j0rd1smit That's indeed a great finding! The solution does look reasonable. Please go ahead with a PR! |
I am encountering a bug: some of my neural network's targets are corrupted when using the M1 GPU. This does not happen on the CPU. Specifically, some targets are set to large values (~ -2e+25), resulting in inf/NaN loss.
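A rough worked check shows why one target of about $-2 \times 10^{25}$ is enough to break the loss. This assumes a squared-error loss on float32 tensors; the post does not say which loss is used. With predictions $\hat{y}$ in $(-1, 1)$:

$$
(\hat{y} - y)^2 \approx \left(2 \times 10^{25}\right)^2 = 4 \times 10^{50} \gg 3.4 \times 10^{38} \approx \text{largest finite float32}
$$

The squared error therefore overflows to $+\infty$. Once an $\infty$ is in the loss, later operations such as $\infty - \infty$ or $0 \cdot \infty$ in the backward pass or the optimizer update can produce NaN.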
I have isolated the behavior by stepping through the code and verifying at which steps the targets are still in the range (-1,1) as intended. The corrupted values first occur in the batch argument of validation_step(). I implemented a custom collate_fn to verify that nothing is wrong with my datasets and dataloaders. The correct targets leave the collate_fn, but then some of them appear corrupted within validation_step().
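The kind of checking collate_fn described above can be sketched as a wrapper around an existing collate function. The sketch assumes float targets that should lie strictly inside $(-1, 1)$. `make_checked_collate` and its parameters are hypothetical names. In practice, `inner` would be PyTorch's default collate, which recent versions export as `torch.utils.data.default_collate`.

```python
import math
from typing import Any, Callable, Sequence


def make_checked_collate(
    inner: Callable[[Sequence[Any]], Any], low: float = -1.0, high: float = 1.0
) -> Callable[[Sequence[Any]], Any]:
    """Wrap `inner` and verify that every collated target is finite and
    strictly inside (low, high). Hypothetical debugging helper."""

    def collate(samples: Sequence[Any]):
        x, y = inner(samples)
        # Tensors are flattened to a list of Python numbers; plain lists pass through.
        values = y.flatten().tolist() if hasattr(y, "flatten") else y
        for value in values:
            v = float(value)
            if not (math.isfinite(v) and low < v < high):
                raise ValueError(f"target {v!r} is outside ({low}, {high}) after collate")
        return x, y

    return collate
```

Running the same range check on `batch` at the top of `validation_step` brackets the corruption between the two points.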
As I wrote above, the corrupted targets appear only when using the M1 GPU, whereas on the CPU everything works correctly. I could investigate this issue further if someone could tell me what happens between collate_fn and validation_step, i.e. where else I should step through the code to identify the source of corruption.
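Between `collate_fn` and `validation_step`, Lightning moves the batch to the accelerator. The `move_data_to_device` code discussed later in this thread does this with `data.to(device, **kwargs)`. The sketch below is a hypothetical diagnostic that isolates that copy. It moves each tensor of a batch to the device, copies it back to the CPU with a blocking copy, and compares the result. `batch_survives_transfer` is an illustrative name. An isolated copy may not reproduce the corruption if it depends on how the data loader reuses its buffers, so a passing check does not rule the copy out.

```python
import torch


def batch_survives_transfer(batch, device, non_blocking: bool) -> bool:
    """Hypothetical diagnostic: copy each tensor in `batch` to `device` and back
    to the CPU, returning False if any value changed along the way."""
    for tensor in batch:
        moved = tensor.to(device, non_blocking=non_blocking)
        # Compare against the original CPU tensor after a blocking copy back.
        if not torch.equal(moved.to("cpu"), tensor):
            return False
    return True
```

On an Apple-silicon machine, you might loop over the validation DataLoader and compare `batch_survives_transfer(batch, "mps", non_blocking=True)` with the same call using `non_blocking=False`.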
Originally posted by @gloryVine in #13102 (comment)
cc @akihironitta @justusschock