diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index ff95e89d1d2cf..6b8e44b610352 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -114,8 +114,8 @@ def _wrap_init(f: Callable) -> Callable: @functools.wraps(f) def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None: params = dict(inspect.signature(module._old_init).parameters) - params.pop("args") - params.pop("kwargs") + params.pop("args", None) + params.pop("kwargs", None) for init_name, init_arg in chain(zip(params, args), kwargs.items()): setattr(module, init_name, init_arg) f(module, *args, **kwargs)