Question about Group DRO implementation #33

Open
NicholasCorrado opened this issue Oct 16, 2024 · 0 comments

@NicholasCorrado

First, thank you for the solid work and for making this code public. The paper offers some great insights, and the code is very clean! I'm using this codebase as a reference for implementing a variant of Group DRO, and I have a clarifying question about the loss computation.

The Group DRO loss stated in the DoReMi paper is (Eq. 1):

$$\min_{\theta} \max_{\alpha \in \Delta^k} \mathcal L(\theta, \alpha) := \sum_{i=1}^k \alpha_i \left[\frac{1}{\sum_{x\in D_i}|x|} \sum_{x \in D_i}\ell_\theta(x) - \ell_\text{ref}(x)\right]$$
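For a fixed $\alpha$, Eq. 1 can be evaluated as in the minimal sketch below. It assumes each domain is given as a list of `(loss, ref_loss, num_tokens)` triples, where `loss` and `ref_loss` are the per-sequence summed token losses $\ell_\theta(x)$ and $\ell_\text{ref}(x)$ and `num_tokens` is $|x|$. The function name and data layout are hypothetical and are not taken from the repository.

```python
from typing import Sequence, Tuple

# Hypothetical layout: one (loss, ref_loss, num_tokens) triple per sequence x,
# where loss = l_theta(x) and ref_loss = l_ref(x) are summed over the tokens of x.
Sample = Tuple[float, float, int]


def group_dro_objective(alpha: Sequence[float],
                        domains: Sequence[Sequence[Sample]]) -> float:
    """Eq. 1: sum_i alpha_i * (domain i's summed excess loss / domain i's token count)."""
    total = 0.0
    for a_i, samples in zip(alpha, domains):
        excess = sum(loss - ref for loss, ref, _ in samples)
        tokens = sum(n for _, _, n in samples)
        # Each domain is normalized by its OWN token count before reweighting.
        total += a_i * excess / tokens
    return total


alpha = [0.5, 0.5]
domains = [
    [(12.0, 10.0, 4)],                 # domain 0: 4 tokens, excess loss 2.0
    [(30.0, 26.0, 8), (9.0, 9.0, 8)],  # domain 1: 16 tokens, excess loss 4.0
]
print(group_dro_objective(alpha, domains))  # 0.375
```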

However, it looks like the code is actually optimizing

$$\min_{\theta} \max_{\alpha \in \Delta^k} \mathcal L(\theta, \alpha) := \frac{1}{\sum_{x\in D}|x|}\sum_{i=1}^k \alpha_i \left[\sum_{x \in D_i}\ell_\theta(x) - \ell_\text{ref}(x)\right]$$

In particular, the loss appears to be a reweighted average over all samples across all domains, rather than a reweighted sum of per-domain average losses.
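Writing $N_i = \sum_{x\in D_i}|x|$ for the token count of domain $i$ and $N = \sum_{x\in D}|x| = \sum_{i=1}^k N_i$ (assuming $D$ is the disjoint union of the $D_i$), the second objective can be rewritten in the form of Eq. 1 with rescaled weights, which makes the difference concrete:

$$\frac{1}{N}\sum_{i=1}^k \alpha_i \sum_{x \in D_i}\big(\ell_\theta(x) - \ell_\text{ref}(x)\big) = \sum_{i=1}^k \frac{\alpha_i N_i}{N}\left[\frac{1}{N_i}\sum_{x \in D_i}\big(\ell_\theta(x) - \ell_\text{ref}(x)\big)\right]$$

For a fixed $\alpha$, the pooled objective is therefore Eq. 1 with effective weights $\tilde\alpha_i = \alpha_i N_i / N$, which upweight domains that contribute more tokens. Only when all $N_i$ are equal does it reduce to Eq. 1 scaled by the constant $1/k$.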

The domain weight update computes the average domain-specific losses here: https://github.com/sangmichaelxie/doremi/blob/7cde52d1848737aa967ecbdb9e643cf334de160d/doremi/trainer.py#L252C22-L252C110
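The per-domain averaging step can be sketched on a flat batch, assuming each sequence carries an integer domain id, a summed excess loss, and a token count. This is a hypothetical pure-Python stand-in for a scatter-add over domain ids, not the code at the linked line; returning `0.0` for a domain absent from the batch is an assumption of this sketch.

```python
from typing import List, Sequence


def per_domain_average(domain_ids: Sequence[int],
                       excess_losses: Sequence[float],
                       num_tokens: Sequence[int],
                       num_domains: int) -> List[float]:
    """Per-token average excess loss within each domain (0.0 if a domain is absent)."""
    sums = [0.0] * num_domains
    counts = [0] * num_domains
    # Scatter-add each sequence's summed loss and token count into its domain bucket.
    for d, loss, n in zip(domain_ids, excess_losses, num_tokens):
        sums[d] += loss
        counts[d] += n
    return [s / c if c > 0 else 0.0 for s, c in zip(sums, counts)]


# Domain 0 has 4 tokens in this batch, domain 1 has 16.
print(per_domain_average([0, 1, 1], [2.0, 4.0, 0.0], [4, 8, 8], num_domains=2))
# [0.5, 0.25]
```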

I would expect to see a similar computation for the model parameter updates, but the code appears to compute each domain's total (summed) loss, reweight these totals by the domain weights, and then divide by a single constant normalizer, which yields a reweighted average loss over all samples in all domains.
https://github.com/sangmichaelxie/doremi/blob/7cde52d1848737aa967ecbdb9e643cf334de160d/doremi/trainer.py#L363C17-L363C89
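To make the contrast concrete, here is a hedged sketch of both ways to form the model-update loss from a flat batch of hypothetical inputs (`domain_ids`, summed excess losses, token counts) plus domain weights `alpha`. `pooled_loss` follows the pattern described above, under the assumption that the constant normalizer is the batch's total token count. `per_domain_loss` normalizes each domain by its own token count, as in Eq. 1. Neither function is copied from the repository.

```python
from typing import Sequence


def pooled_loss(alpha: Sequence[float], domain_ids: Sequence[int],
                excess_losses: Sequence[float], num_tokens: Sequence[int]) -> float:
    """Reweight each sequence by alpha[domain], then divide by ONE shared normalizer."""
    # Assumption: the constant normalizer is the total token count of the batch.
    normalizer = sum(num_tokens)
    weighted = sum(alpha[d] * loss for d, loss in zip(domain_ids, excess_losses))
    return weighted / normalizer


def per_domain_loss(alpha: Sequence[float], domain_ids: Sequence[int],
                    excess_losses: Sequence[float], num_tokens: Sequence[int]) -> float:
    """Eq. 1 on a batch: normalize each domain by its OWN token count, then reweight."""
    k = len(alpha)
    sums = [0.0] * k
    counts = [0] * k
    for d, loss, n in zip(domain_ids, excess_losses, num_tokens):
        sums[d] += loss
        counts[d] += n
    # Domains with no tokens in this batch contribute nothing.
    return sum(a * s / c for a, s, c in zip(alpha, sums, counts) if c > 0)


alpha = [0.5, 0.5]
batch = dict(domain_ids=[0, 1, 1], excess_losses=[2.0, 4.0, 0.0], num_tokens=[4, 8, 8])
print(pooled_loss(alpha, **batch))      # 0.15
print(per_domain_loss(alpha, **batch))  # 0.375
```

With these numbers, the pooled loss effectively gives the 16-token domain weight $0.5 \cdot 16/20 = 0.4$ and the 4-token domain weight $0.5 \cdot 4/20 = 0.1$, even though `alpha` weights the two domains equally.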

Could you please clarify whether the code implements the Group DRO loss in Eq. 1 (a reweighted sum of per-domain average losses), or whether it instead optimizes the pooled, reweighted average over all samples above? Thank you!

@NicholasCorrado NicholasCorrado changed the title Question about DRO Question about DRO implementation Oct 16, 2024
@NicholasCorrado NicholasCorrado changed the title Question about DRO implementation Question about Group DRO implementation Oct 16, 2024