Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to use Scheduled Huber Loss in all training pipelines to improve resilience to data corruption #1228

Merged
merged 14 commits into from
Apr 7, 2024

Conversation

kabachuha
Copy link
Contributor

As heavily discussed in #294, the presence even of a small number of outliers can heavily distort the image output.

While there have been proposed alternative losses, such as Huber loss, the problem was the loss of fine details when training the pictures. After researching the mathematical formulations of diffusion models, we came to the conclusion that it's possible to schedule the Huber loss, making it smoothly transition from Huber loss with L1 asymptotics on the first reverse-diffusion timesteps (when the image only begins to form and is most vulnerable to concept outliers) to the standard L2 MSE loss, for which diffusion models are originally formulated, on the last reverse-diffusion timesteps, when the fine-details of the image are forming.

Our method shows greater stability and resilience (similarity to clean pictures on corrupted runs - similarity to clean pictures on clean runs), than both pure Huber and L2 losses.

image

The experiments confirm that indeed this schedule improves the resilience greatly

image

Our paper: https://arxiv.org/abs/2403.16728

Diffusers discussion huggingface/diffusers#7488

Most importantly, this approach has virtually no computational costs over the standard L2 computation. (and minimal code changes to the training scripts)

cc @cheald @kohya-ss

@kohya-ss
Copy link
Owner

Thank you for submitting this pull request! Although I don't fully understand the contents of the discussion in #294, it seems that a very interesting discussion is taking place.

This PR appears to have the objective of learning while balancing the large features and fine details of the image.

If I understand correctly, this PR has a similar objective to the method proposed by cheald in #294, but takes a different approach. Is this correct?

@kabachuha
Copy link
Contributor Author

These are quite different approaches. @cheald makes adjustments to latents, then feeds them into the standard loss.

And as cheald is using mse_loss in the end, I believe these approaches can exist in synergy!

@kohya-ss
Copy link
Owner

Thank you clarification! I will merge this sooner.

@cheald
Copy link

cheald commented Mar 31, 2024

This is very exciting! I've been experimenting with manipulating the mean/stddev of slices of the noise prior to forward noising, and have found that I can directly manipulate the level of detail trained with certain permutations of noise. However, it still suffers from outliers early in training having too large an impact, which I've had to manage through very careful tuning of noise. I'm very excited to try adding this into my experiments - if it performs how I imagine, then it could solve a number of problems which could lead to faster and more controlled training.

@gesen2egee
Copy link
Contributor

Tried a training, the results were impressive.

@kabachuha
Copy link
Contributor Author

@gesen2egee Thank you for giving a test run! Would you mind sharing some of them?

@cheald
Copy link

cheald commented Mar 31, 2024

Here's my quick training experiment. I just ran each for 9 epochs. Overall, results look very promising.

Ground truth:

ComfyUI_00049_

General settings: adamw8bit, CosineAnnealingWarmRestarts w/ restart every 2 epochs, unet_lr 1e-4, and I am using masks and masked loss here.

Each image pair is l2 on the left, huber_scheduled on the right.

out1

Marginal improvement, I think!


Here's where it gets fun:

I'm experimenting with recentering and rescaling each noise channel individually, and then also experimenting with shifting and rescaling all noise channels together, as well as independently. This is conceptually similar to an expansion of the idea encoded in the noise_offset routine. It's yielding some very interesting results.

Dependent scale, dependent shift:

out2

Independent channel scaling, dependent channel shift:

out3

Dependent scaling, independent shift:

out4

and finally, independent scaling, independent shift:

out5


My general observation here is that the huber_scheduled loss definitely does improve detail retention (look at the brick in the background!), but isn't quite learning as well. However, I suspect that is likely due to nothing more than the lower loss values damping the rate of change in the learned weights. It might be that using huber_loss would permit a larger learning rate relative to l2 loss, which would be great if it can retain the same improvement in details.

@cheald
Copy link

cheald commented Apr 1, 2024

Here's another example. I'm working on extending the dynamic range of my trained samples, which I'm doing by pushing the mean of the first channel in the noise. I also cranked my LR up to 4e-4 to see if this would give the huber noise an edge, but it appears to actually not be working that way.

Same seed and training params for all 4 images, the left images are a "night time photo" and the right images are a "bright daytime photo". Upper is l2, lower is huber.

The huber examples have distorted less, but are certainly less like the ground truth overall. Is there guidance on how to set the delta parameter (huber_c?) to achieve a middle ground?

Night L2 Day L2
Night Huber Day Huber

out

Loss curves for each:

2024-03-31_17-26

@feffy380
Copy link
Contributor

feffy380 commented Apr 1, 2024

Excited to try this out. Any idea how this interacts with Min-SNR-gamma, which weights the loss based on timestep?

@drhead
Copy link

drhead commented Apr 1, 2024

I'm looking forward to trying this out. I do have a concern about part of the paper though, namely the function defining the schedule:
image
I don't think it makes much sense to define the schedule in terms of the indices of timesteps since they will vary between models, and even within SD1.5 there's zero terminal SNR which does modify the noise schedule. If the reasoning is that loss should behave more like MAE at early timesteps/high noise levels and more like MSE at late timesteps/low noise levels, wouldn't it make more sense to define the schedule in terms of something like SNR directly? From playing around in a notebook quickly, it seems that 1 - sqrt(1.0 - alphas_cumprod) is fairly close (though I can't say with any certainty that it is the value that makes the most sense to use),

edit: 1 / (1 + sigmas)**2 is closer and is directly derived from noise levels (where sigmas = (((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5)), and also behaves fine when snr is zero (sigma will be infinity, resulting delta will be zero)

@kabachuha
Copy link
Contributor Author

kabachuha commented Apr 1, 2024

@cheald @drhead we chose the exponential because of its simplicity, to test the claim that it should decrease with the (forward) diffusion timestep. It indeed quite likely may be suboptimal, and the idea about the snr-scheduling sounds very reasonable!

Adding some selectable huber_c schedules, I think, would be a great next addition


Also, interesting case with delta===0. With delta arbitrary close to zero, the pseudo-Huber loss function will converge to MAE, while with delta exactly zero (zero snr), the entire function will collapse to the flat line. So we may need to handle this specific case and limit delta somehow? It seems like it's ok (see https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html#torch.nn.SmoothL1Loss, as we need L2 convergence more than L1)

@kabachuha
Copy link
Contributor Author

kabachuha commented Apr 1, 2024

There is another alternative for huber function: pseudo-Huber loss divided by delta:

$$\sqrt{\delta^2 + a^2} - \delta$$

whereas the "math" version is:

$$\delta^2*(\sqrt{1 + \frac{a^2}{\delta^2}} - 1)$$

The difference is math version converging to zero when the delta goes to zero, while the modified version transitions between L1 and L2

While the divided version gave us worse results for resilience, it may be better perceptually. (didn't analyze the perceptual part much, unfortunately) The former version, suggested by OpenAI, is used in Diffusers for LCM training.

Here's the gif comparison:

outfile.mp4

I think it may be worth adding it here too and making some more experiments

Edit: fixed parabola's coefficient to suit the losses (1/2 a^2)


The problem may lie in the beginning of the delta values: when delta is zero (snr = 0), as the OpenAI's Huber function is above L2, it will make the model actively learn, and thus vulnerable to outliers, while our math variant will not really take into account what is happening at pure noise.

It may explain, why the OpenAI's loss function fails in our resilience experiments greatly (huber_scheduled_old)

A good compromise may be adding a minimal delta value to pad it at zero snr, and that's quite what we did in our experiments


Another Edit: pseudo huber loss computation needs to get *2 multiplier to to better correspond to MSE's coeffs, as from Taylor expansion, the pseudo-huber loss's asymptotic is ~1/2 * a^2, that leads to discrepancy when the MSE's a^2 parabola is far away from the formed curve at a=1.

@kabachuha
Copy link
Contributor Author

kabachuha commented Apr 1, 2024

@drhead the second main guy in my team agreed that 1 / (1 + sigmas)**2 makes sense, so I think delta = (1-delta_0) / (1 + sigmas)**2 + delta_0, accounting for the zero-snr padding, may do a good gob. Will make some tests once I'll get home

@kabachuha kabachuha marked this pull request as draft April 1, 2024 13:18
The Taylor expansion of sqrt near zero gives 1/2 a^2, which differs from a^2 of the standard MSE loss. This change scales them better against one another
@drhead
Copy link

drhead commented Apr 3, 2024

Interesting... so, playing around with those functions, the main difference seems to be that the math version is implicitly weighting timesteps by delta? It may in fact be the primary mechanism at play here, and is still a useful finding on its own -- in fact, I've been experimenting with learned timestep weighting lately (through a reapplication of the methods suggested in the EDM2 paper, which makes learned timestep weightings viable on lower batch sizes), and the timestep weight schedule that my model has been learning is in fact very similar to what this formula is implying (though my model uses v-prediction, so the timestep weights might not be what it converges to on epsilon prediction, and it may also be due to the dataset I am using being mostly digital art):
loss_weights_20299_ec3418efd5fccaa41eb7
I will be testing this out soon (under data that I expect is "clean" but may have some bad labeling, and under the OpenAI formula for simplicity, since my adaptive timesteps will end up making them the same anyways), but I would suggest doing an ablation test on simply weighting L2 loss timesteps by your delta schedule so you can see what part of it is the timestep weighting and what part of it is the huber loss.

@kabachuha
Copy link
Contributor Author

Good suggestion! Btw I'm making the experiments with different losses/schedules right now. Will post it here soon

@kabachuha
Copy link
Contributor Author

kabachuha commented Apr 3, 2024

@drhead (on David Revoy dataset https://drive.google.com/drive/folders/1Z4gVNugFK2RXQEO2yiohFrIbhP00tIOo?usp=drive_link)

I think, subjectively, I divide between SNR Smooth L1 and SNR Huber (they both have strong and weak sides)

Constant L2's parameter is 1, because it's the final delta value

image

Robustness is another thing, and it needs its own tests


All the generation samples from the schedules/loss types experiments above (16 per each):

https://drive.google.com/drive/folders/1DnU-o_TT9JH8l1JS_WuQfeZ-k6uafCJo?usp=drive_link

@kabachuha kabachuha marked this pull request as ready for review April 3, 2024 20:49
@drhead
Copy link

drhead commented Apr 6, 2024

Well, I've completed a small-ish ablation test of my own on a larger finetune of data that is in theory clean, and the results look pretty promising!

Here's regular MSE loss with my adaptive loss weights:
huber_ablation_control

And here's with scheduled pseudo-huber loss:
huber_ablation_test

Backgrounds are much sharper, and most significantly the character's form is more consistent. Very happy with how this turned out and I think I will continue to use this.

Interestingly, my loss weight model chose a different weighting (ignore the fact that it looks weird on the tails, I don't think this is a true reflection of the ideal timestep weighting and is instead an artifact of the model used to train timestep weights which I am still trying to nail down the correct hyperparameters for).
loss_weights_20099_a2cade8c70c640af9afd
This effectively reflects the relative difficulty of each timestep (in form of the weighting required to equalize the model's capability at each timestep). So, these changes indicate that scheduled pseudo huber loss made the high-noise timesteps relatively easier, and later timesteps relatively harder. Again, my model is v-prediction, so the high noise timesteps are closer to x0 prediction and lower noise timesteps are closer to epsilon prediction -- you would probably see different results on epsilon prediction. But I do find this to be an interesting result.

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 7, 2024

Thanks again for the PR and great discussion! I have created a brief description to add to the release notes to explain the new features that this PR offers. Any comments would be appreciated.


Scheduled Huber Loss has been introduced to each training scripts. This is a method to improve robustness against outliers or anomalies (data corruption) in the training data.

With the traditional MSE (L2) loss function, the impact of outliers could be significant, potentially leading to a degradation in the quality of generated images. On the other hand, while the Huber loss function can suppress the influence of outliers, it tends to compromise the reproduction of fine details in images.

To address this, the proposed method employs a clever application of the Huber loss function. By scheduling the use of Huber loss in the early stages of training (when noise is high) and MSE in the later stages, it strikes a balance between outlier robustness and fine detail reproduction.

Experimental results have confirmed that this method achieves higher accuracy on data containing outliers compared to pure Huber loss or MSE. The increase in computational cost is minimal.

The newly added arguments loss_type, huber_schedule, and huber_c allow for the selection of the loss function type (Huber, smooth L1, MSE), scheduling method (exponential, constant, SNR), and Huber's parameter. This enables optimization based on the characteristics of the dataset.

See #1228 for details.

@kohya-ss kohya-ss changed the base branch from main to dev April 7, 2024 04:49
@kohya-ss kohya-ss merged commit 90b1879 into kohya-ss:dev Apr 7, 2024
1 check passed
@kabachuha
Copy link
Contributor Author

The default is exponential

I think it will be better to use 'snr' by default as it have been shown to have better quality

@dill-shower
Copy link

Is there any information about when different scheduling method and loss function type should be applied? For example, for a large dataset one, for a small dataset another. For training on photos one, for training on anime artwork another

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 7, 2024

The default is exponential

I think it will be better to use 'snr' by default as it have been shown to have better quality

Thank you! I have updated it.

@kabachuha
Copy link
Contributor Author

kabachuha commented Apr 7, 2024

It's still "exponential" in the code itself

default="exponential",

Edit: Oops, I refered to the dev branch, and it's corrected on the main

@kohya-ss
Copy link
Owner

kohya-ss commented Apr 7, 2024

Edit: Oops, I refered to the dev branch, and it's corrected on the main

Sorry, I updated directry the main🙇‍♂️ Thank you for your confirmation!

@drhead
Copy link

drhead commented Apr 14, 2024

I did want to follow up on this since some further experiments led me to notice some tradeoffs with the outcomes of training on this loss function.

While the benefits to image coherency still seem very clear, it seems that on my training setup, one of the side effects of this loss function is a loss of control over image contrast -- manifesting as things like "dark" in a prompt being less powerful on zero terminal SNR models. Incidentally, I have a very good idea of what the root cause might be, and it makes perfect sense given what this loss function does.

TL;DR of the above post is that the VAE used by SD 1.5 (and SD 2.1 and Midjourney and the Pixart series of models and DALL-E 3 and EDM2 and probably many others), tends to produce a high magnitude artifact, especially on large images and desaturated images (noting that desaturated in latent space is closer to sepia). My theory is that the purpose of this artifact is to abuse the normalization layers of the model by placing an artifact that serves to desaturate the image by having a high magnitude that will cause other values to decrease significantly as they pass through the model normalization layers. @madebyollin made some very helpful graphics demonstrating some of these effects: https://gist.github.com/madebyollin/ff6aeadf27b2edbc51d05d5f97a595d9

With this in mind, it does make perfect sense why this would happen, especially on my training setup. I'm training a high resolution model (actually with several different resolution groups) with v-prediction. With v-prediction, higher noise timesteps are closer to x_0 prediction, and the terminal timestep is outright x_0 prediction. This means that the prediction target has something that is both high magnitude and extremely important for image reconstruction. Errors are generally likely to be proportional to the magnitude of the target, which is not good when our loss function is more relaxed on large errors at this point. The loss function can't take into account the fact that this will cause a massive error in pixel space -- for MSE loss, this would be less of a problem, since larger errors would be pulled in much harder. While the loss objective does effectively ignore large outliers from the data in this way, the VAE artifact is sadly not separable from this. It is possible that it could be mitigated by finding some metric for a latent that would indicate consistency in its saturation level, but the solution that makes more sense is to just move to a model that doesn't use the SD1.5 VAE.

As a final note, I have found the effects of this to be very similar to the effects of FreeU, both in its improvements of image coherence and in having a similar issue with weakening of saturation (and they do combine together well, I was using it with my samples above). FreeU's underlying theory is that high-frequency information in the skip connections converges too fast in the decoder and therefore the authors choose to modulate backbone features and skip connections to place more emphasis on low frequency feature maps. That might be related to why this improves coherency -- mitigating the impact of outliers allowing for low frequency features to be better represented. I'm sure there's a lot more research to be done on optimizing hierarchical denoising with things like this.

All of that being said, I do think that these issues are not likely to be a problem for epsilon prediction since the VAE artifact is not directly part of the prediction target, and it is not likely to be much of an issue on SDXL or for several other models using a different VAE or for pixel-space models. I'll be trying it out on SD3 when I can, hopefully their VAE doesn't have the same issues.

@Deathawaits4
Copy link

Can i ask which function im supposed to use to actually use scheduled huber loss in kohya? there is no explanation.. Do i need to set smooth_l1? or huber? What are the settings im supposed to use?

@rockerBOO
Copy link
Contributor

Can i ask which function im supposed to use to actually use scheduled huber loss in kohya? there is no explanation.. Do i need to set smooth_l1? or huber? What are the settings im supposed to use?

There is some details explained on the main README about-scheduled-huber-loss

@araleza
Copy link

araleza commented May 12, 2024

Hey, I just tried out this new Huber loss parameter... and it's amazing.

Thanks for all the hard work by the people here!

@sangoi-exe
Copy link

What’s the practical effect of increasing or decreasing the value of huber_c?

@DarkAlchy
Copy link

What’s the practical effect of increasing or decreasing the value of huber_c?

My experience is total destruction of the training almost instantly. Even a 0.01 in either direction resulted in total devastation/noise.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.