keras.mixed_precision not working with TorchModuleWrapper #20726

Open
yonigottesman opened this issue Jan 5, 2025 · 1 comment

@yonigottesman

When a torch model is wrapped with TorchModuleWrapper, keras.mixed_precision does not work.
I suspect that TorchModuleWrapper's call should run the wrapped torch model inside
with torch.cuda.amp.autocast():
Here is a minimal example that fails:

import os
os.environ["KERAS_BACKEND"] = "torch"
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import keras

keras.mixed_precision.set_global_policy("mixed_float16")  # float16 compute, float32 variables
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)
train_dataloader = DataLoader(training_data, batch_size=64)

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to("cuda")

inputs = keras.layers.Input(shape=(1, 28, 28))
outputs = keras.layers.TorchModuleWrapper(model)(inputs)
keras_model = keras.models.Model(inputs, outputs)

keras_model.compile(
    optimizer=keras.optimizers.SGD(learning_rate=1e-3),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
keras_model.fit(train_dataloader)
@sonali-kumari1
Contributor

Hi @yonigottesman,

Thanks for reporting this issue.
The error you are getting is "Error encountered: mat1 and mat2 must have the same dtype, but got Half and Float". It occurs because the input is Half (float16) while the weights and operations inside the model are Float (float32). Since you are using keras.mixed_precision.set_global_policy("mixed_float16"), you can explicitly call model = model.half(), which converts all the model parameters to float16.
Attaching a gist for your reference.
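
A minimal sketch of the dtype mismatch described above and of the .half() workaround, using a small hypothetical stand-in model instead of the issue's NeuralNetwork and staying on CPU so it runs without a GPU. The exact error text on CPU may differ slightly from the CUDA message quoted above, but it is raised as a RuntimeError either way.

```python
import torch
from torch import nn

# Hypothetical stand-in for the issue's NeuralNetwork: float32 weights by default.
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

# Under mixed_float16, Keras hands the wrapped module float16 inputs.
x = torch.randn(2, 1, 28, 28).half()

try:
    model(x)
    raised_before = False
except RuntimeError as err:
    # float16 input vs float32 weight in the Linear matmul
    raised_before = True
    print("float32 weights:", err)

# .half() casts floating-point parameters and buffers to float16 in place
# and returns the module, so the weights now match the input dtype.
model = model.half()
print({p.dtype for p in model.parameters()})
```

Note that .half() stores the parameters themselves in float16, so training then runs in pure half precision rather than keeping float32 weights, which is what the mixed_float16 policy otherwise does for Keras variables.
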
