diff --git a/src/tensor_parallel/wrapper.py b/src/tensor_parallel/wrapper.py index 3d907c9..ab1fd39 100644 --- a/src/tensor_parallel/wrapper.py +++ b/src/tensor_parallel/wrapper.py @@ -73,3 +73,8 @@ def forward(self, *args, **kwargs): def __getattr__(self, attr): return getattr(self.tp_wrapped_module, attr) + + def __setattr__(self, attr, value): + super().__setattr__(attr, value) + if attr == "tp_wrapped_module": + self.__dict__["tp_wrapped_module"] = value # to access without getattr, nn.Module removed it from __dict__