Hi, authors. Thanks for providing this repo.
I'm currently using the Orthogonal module and defined it as part of my model weights. When I tried to resume training from a checkpoint, an unexpected error occurred when I executed load_state_dict:
Unexpected key(s) in state_dict: "rotation_matrices._B"
rotation_matrices is the name of the Orthogonal object. I think the error occurred because when the model is initialized, rotation_matrices._B is None, so the _B weights in the state_dict cannot be loaded.
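For reference, the mismatch can be seen by comparing keys (a minimal sketch; model and checkpoint are placeholder names for my model and the loaded checkpoint):
# Sketch with placeholder names: the saved checkpoint contains the lazily
# created buffer, but the freshly initialized model does not.
extra = set(checkpoint.keys()) - set(model.state_dict().keys())
print(extra)  # expected to contain 'rotation_matrices._B'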
I tried two methods to solve this problem, but both failed.
Retract _B before load_state_dict:
mod = rotation_matrices
not_B = mod._B is None
if not_B or (not mod._B.grad_fn and torch.is_grad_enabled()):
    B = mod.retraction(mod.A, mod.base)
    mod._B = B.detach()
    # Just to be safe
    mod._B.requires_grad_()
    # Make sure gradients are retained on the re-created _B
    mod._B.retain_grad()
... ...
# Then, in main.py, I run
model.load_state_dict()
At this point, load_state_dict did not raise an error. The error occurred during backpropagation, when running loss.backward():
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: output with shape [64] doesn't match the broadcast shape [128, 768, 1, 64]
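For what it's worth, that traceback points into Adam's per-parameter state update, so a rough check along these lines (optimizer is a placeholder name for my Adam instance) should reveal which parameter's state no longer lines up:
# Rough check (placeholder names): Adam's exp_avg buffers should have the same
# shape as the parameters they belong to; any mismatch here would explain the
# broadcast error above.
for group in optimizer.param_groups:
    for p in group["params"]:
        state = optimizer.state.get(p, {})
        if "exp_avg" in state and state["exp_avg"].shape != p.shape:
            print("mismatch:", state["exp_avg"].shape, "vs", p.shape)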
load_state_dict(state_dict, strict=False)
Instead of re-defining _B, I changed the strict argument passed to load_state_dict. The error again occurred when executing loss.backward():
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (64) must match the size of tensor b (32) at non-singleton dimension 0
I feel like it has something to do with the optimizer. Could you give me some suggestions?
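In case it clarifies what I mean: one thing I'm considering is rebuilding the optimizer from the freshly loaded parameters instead of restoring its saved state, so that Adam re-creates its exp_avg/exp_avg_sq buffers with the current shapes (a sketch only; torch.optim.Adam and the learning rate are just an example):
# Sketch only: load the model weights, then create a brand-new optimizer so its
# per-parameter state (exp_avg, exp_avg_sq, ...) is rebuilt for the current
# parameters rather than loaded from the checkpoint.
model.load_state_dict(state_dict, strict=False)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # example lr
# note: restoring the optimizer's saved state_dict is intentionally skipped here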
What's more, torch.nn.utils.parametrizations.orthogonal in master (to be released in PyTorch 1.11 soon) will bring an improved version of this as well.
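A minimal usage sketch of that built-in parametrization (PyTorch >= 1.11; the Linear layer and sizes are just an example):
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import orthogonal

# Example only: register the orthogonal constraint on an existing layer's weight.
layer = orthogonal(nn.Linear(64, 64))
# For a square weight, W @ W.T should be (numerically close to) the identity.
print(torch.allclose(layer.weight @ layer.weight.T, torch.eye(64), atol=1e-5))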