Record Deserializer not implemented for enum #1431

Closed
laggui opened this issue Mar 7, 2024 · 3 comments · Fixed by #1436
Assignees: antimora
Labels: enhancement (Enhance existing features)

Comments

laggui (Member) commented Mar 7, 2024

Describe the bug
I stumbled upon this issue while using PyTorchFileRecorder to import a PyTorch model into its equivalent Burn definition.

The PyTorch model defines a module whose structure depends on a condition (a regular conv block in one case, a depthwise separable conv block in the other). On the Burn side, I implemented the block as an enum module, which is possible since #1337 landed. But loading fails with this error:

```
thread 'main' panicked at burn/crates/burn-core/src/record/serde/de.rs:325:9:
not implemented: deserialize_enum is not implemented
note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
```

This comes straight from the `unimplemented!("deserialize_enum is not implemented")` call in de.rs.

It would be nice to add deserialization support for enums now that they are supported for modules.
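
For reference, serde routes enum deserialization through `Deserializer::deserialize_enum` plus the `EnumAccess`/`VariantAccess` traits. Below is a minimal sketch of that pattern over a toy value tree; every name in it (`Value`, `ValueDeserializer`, etc.) is made up for illustration, not burn's actual NestedValue or de.rs internals, and it only handles the externally tagged `{ "Variant": content }` layout:

```rust
// Minimal sketch of serde's enum deserialization pattern over a toy value
// tree. `Value` is an illustrative stand-in, NOT burn's actual internal type:
// an externally tagged enum arrives as a one-entry map { "Variant": content }.
use std::collections::HashMap;

use serde::de::value::{Error, StringDeserializer};
use serde::de::{
    self, DeserializeSeed, Deserializer, EnumAccess, IntoDeserializer, VariantAccess, Visitor,
};
use serde::forward_to_deserialize_any;

enum Value {
    I64(i64),
    Map(HashMap<String, Value>),
}

struct ValueDeserializer(Value);

impl<'de> Deserializer<'de> for ValueDeserializer {
    type Error = Error;

    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        match self.0 {
            Value::I64(n) => visitor.visit_i64(n),
            Value::Map(_) => Err(de::Error::custom("maps not needed for this sketch")),
        }
    }

    // The method that is `unimplemented!` in de.rs: route a one-entry map
    // into serde's enum machinery.
    fn deserialize_enum<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Error> {
        match self.0 {
            Value::Map(map) if map.len() == 1 => {
                let (variant, content) = map.into_iter().next().unwrap();
                visitor.visit_enum(EnumValue { variant, content })
            }
            _ => Err(de::Error::custom("expected a single-entry map for an enum")),
        }
    }

    forward_to_deserialize_any! {
        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
        byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
        map struct identifier ignored_any
    }
}

struct EnumValue {
    variant: String,
    content: Value,
}

impl<'de> EnumAccess<'de> for EnumValue {
    type Error = Error;
    type Variant = VariantValue;

    fn variant_seed<S: DeserializeSeed<'de>>(self, seed: S) -> Result<(S::Value, VariantValue), Error> {
        // The map key is deserialized as the variant identifier.
        let key: StringDeserializer<Error> = self.variant.into_deserializer();
        Ok((seed.deserialize(key)?, VariantValue(self.content)))
    }
}

struct VariantValue(Value);

impl<'de> VariantAccess<'de> for VariantValue {
    type Error = Error;

    fn unit_variant(self) -> Result<(), Error> {
        Ok(())
    }

    fn newtype_variant_seed<S: DeserializeSeed<'de>>(self, seed: S) -> Result<S::Value, Error> {
        // Hand the variant's content to the inner deserializer.
        seed.deserialize(ValueDeserializer(self.0))
    }

    fn tuple_variant<V: Visitor<'de>>(self, _len: usize, _v: V) -> Result<V::Value, Error> {
        Err(de::Error::custom("tuple variants omitted in this sketch"))
    }

    fn struct_variant<V: Visitor<'de>>(self, _f: &'static [&'static str], _v: V) -> Result<V::Value, Error> {
        Err(de::Error::custom("struct variants omitted in this sketch"))
    }
}
```

With this in place, a value like `Value::Map({ "DwsConv": content })` drives a `#[derive(Deserialize)]` enum to its `DwsConv` newtype variant: the map key selects the variant, and the map value deserializes its content.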

MWE

Python script to export the weights:

```python
import torch
from torch import nn, Tensor

class DwsConv(nn.Module):
    """Depthwise separable convolution."""

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int) -> None:
        super().__init__()
        # Depthwise conv
        self.dconv = nn.Conv2d(in_channels, in_channels, kernel_size, groups=in_channels)
        # Pointwise conv
        self.pconv = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=1)

    def forward(self, x: Tensor) -> Tensor:
        x = self.dconv(x)
        return self.pconv(x)


class Net(nn.Module):
    def __init__(self, depthwise: bool = False) -> None:
        super().__init__()
        self.conv = DwsConv(3, 64, 3) if depthwise else nn.Conv2d(3, 64, 3)


if __name__ == "__main__":
    torch.manual_seed(42)  # To make it reproducible
    model = Net(depthwise=True)
    model_weights = model.state_dict()
    torch.save(model_weights, "conv-dws.pt")
```

Rust code to load the weights into the Burn model:

```rust
use burn::{
    backend::{ndarray::NdArrayDevice, NdArray},
    module::Module,
    nn::conv::{Conv2d, Conv2dConfig},
    record::{FullPrecisionSettings, Recorder},
    tensor::backend::Backend,
};
use burn_import::pytorch::PyTorchFileRecorder;

#[derive(Module, Debug)]
pub enum Conv<B: Backend> {
    DwsConv(DwsConv<B>),
    Conv(Conv2d<B>),
}

#[derive(Module, Debug)]
pub struct DwsConv<B: Backend> {
    dconv: Conv2d<B>,
    pconv: Conv2d<B>,
}

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv: Conv<B>,
}

fn main() {
    let device = NdArrayDevice::default();
    let model: Net<NdArray> = Net {
        conv: Conv::Conv(Conv2dConfig::new([3, 64], [3, 3]).init(&device)),
    };

    let record = PyTorchFileRecorder::<FullPrecisionSettings>::new()
        .load("conv-dws.pt".into(), &device)
        .unwrap();

    model.load_record(record);
}
```
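
As an aside (my assumption about the file layout, not something the recorder documents): the saved state dict is a flat map of keys such as conv.dconv.weight and conv.pconv.bias with no explicit variant tag, so the recorder presumably has to infer which enum variant (DwsConv vs Conv) to construct from the nested field names alone, which is a bit more involved than the externally tagged sketch above.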
laggui (Member, Author) commented Mar 7, 2024

Tagging @antimora since you worked on this in the past. We can discuss this offline; I'd like to get your opinion.

antimora (Collaborator) commented Mar 7, 2024

@laggui, yes, I agree we should support it.

@antimora antimora added the enhancement Enhance existing features label Mar 7, 2024
@antimora antimora assigned antimora and unassigned laggui Mar 7, 2024
antimora (Collaborator) commented Mar 7, 2024

antimora added a commit to antimora/burn that referenced this issue Mar 8, 2024
antimora added a commit that referenced this issue Mar 11, 2024
* Add Enum module support in PyTorchFileRecorder

Fixes #1431

* Fix wording/typos per PR feedback