
Simplify EMA to use Pytorch's update_parameters #5469

Merged: 4 commits from xiaohu2015:patch-1 merged into pytorch:main on Feb 27, 2022

Conversation

xiaohu2015 (Contributor) commented Feb 24, 2022

This PR addresses #5284.

With the torch nightly version 1.12.0.dev20220225+cu113, the following snippet runs OK and shows that the simplified implementation matches the previous one:

import torch                                                                                                         


class ExponentialMovingAverageV1(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """
    def __init__(self, model, decay, device='cpu'):
        ema_avg = (lambda avg_model_param, model_param, num_averaged:
                   decay * avg_model_param + (1 - decay) * model_param)
        super().__init__(model, device, ema_avg) 

    def update_parameters(self, model):
        # Iterate over the full state_dict so that buffers are averaged together with
        # parameters, matching the previous torchvision EMA behaviour.
        for p_swa, p_model in zip(self.module.state_dict().values(), model.state_dict().values()):
            device = p_swa.device 
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(self.avg_fn(p_swa.detach(), p_model_,
                                     self.n_averaged.to(device)))
        self.n_averaged += 1

class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
    """Maintains moving averages of model parameters using an exponential decay.
    ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
    `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
    is used to compute the EMA.
    """
    def __init__(self, model, decay, device='cpu'):
        ema_avg = (lambda avg_model_param, model_param, num_averaged:
                   decay * avg_model_param + (1 - decay) * model_param)
        # use_buffers=True makes AveragedModel average buffers as well, so the custom
        # update_parameters override above is no longer needed.
        super().__init__(model, device, ema_avg, use_buffers=True)



class ToyModel(torch.nn.Module):
    """Toy model with one parameter and one buffer, to check that both are averaged."""

    def __init__(self):
        super().__init__()
        self.x = torch.nn.Parameter(torch.zeros(5))
        self.register_buffer('y', torch.zeros(5))

    def forward(self, input):
        self.x += input
        self.y += input
        return self.x, self.y

decay = 0.9
model1 = ToyModel()
ema1 = ExponentialMovingAverageV1(model1, decay)
model2 = ToyModel()
ema2 = ExponentialMovingAverage(model2, decay)

x = torch.ones(5)

# Both implementations should produce identical parameter (x) and buffer (y) averages.
for _ in range(10):
    with torch.no_grad():
        model1(x)
        model2(x)
        ema1.update_parameters(model1)
        ema2.update_parameters(model2)

        assert torch.equal(ema1.module.x, ema2.module.x)
        assert torch.equal(ema1.module.y, ema2.module.y)
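
Below is a minimal usage sketch, assuming a PyTorch build whose AveragedModel supports use_buffers; the Linear model, optimizer, and synthetic data are illustrative assumptions and not part of this PR:

# Minimal usage sketch (illustrative, not part of this PR): how such an EMA wrapper
# is typically driven alongside a regular training loop.
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 1)
ema_model = ExponentialMovingAverage(model, decay=0.999)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(100):
    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    optimizer.zero_grad()
    loss = F.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()

    # With use_buffers=True the stock AveragedModel.update_parameters averages
    # buffers along with parameters, so no override is needed here.
    ema_model.update_parameters(model)

# The averaged weights live in ema_model.module; evaluate with ema_model(inputs).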

facebook-github-bot commented Feb 24, 2022

💊 CI failures summary and remediations

As of commit 31b0e04 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@xiaohu2015 xiaohu2015 marked this pull request as draft February 24, 2022 11:29
@xiaohu2015 xiaohu2015 marked this pull request as ready for review February 26, 2022 04:45
datumbox (Contributor) left a comment

@xiaohu2015 Thanks a lot for the contribution and for providing a snippet that verifies that the implementation works fine. I think we can merge this. :)

@datumbox datumbox merged commit f40c8df into pytorch:main Feb 27, 2022
@datumbox datumbox linked an issue Feb 27, 2022 that may be closed by this pull request
@xiaohu2015 xiaohu2015 deleted the patch-1 branch February 27, 2022 14:26
facebook-github-bot pushed a commit that referenced this pull request Mar 3, 2022
Summary: Co-authored-by: Vasilis Vryniotis <[email protected]>

Reviewed By: datumbox

Differential Revision: D34579515

fbshipit-source-id: 6f563a48305dc1c9d99274d40c15416075c9b20f
Linked issue that may be closed by this pull request: Simplify EMA to use Pytorch's update_parameters

3 participants