diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index 2d387d156..0c527c352 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -346,14 +346,14 @@ def get_weighted_text_embeddings(
 # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
-def pyramid_noise_like(noise, device, iterations=6, discount=0.3):
-    b, c, w, h = noise.shape
+def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
+    b, c, w, h = noise.shape  # EDIT: w and h get over-written, rename for a different variant!
     u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
     for i in range(iterations):
         r = random.random() * 2 + 2  # Rather than always going 2x,
-        w, h = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
-        noise += u(torch.randn(b, c, w, h).to(device)) * discount**i
-        if w == 1 or h == 1:
+        wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
+        noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
+        if wn == 1 or hn == 1:
             break  # Lowest resolution is 1x1
     return noise / noise.std()  # Scaled back to roughly unit variance
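
For reference, a minimal sketch of how the patched helper might be invoked. The latent shape, the import path, and the call site are illustrative assumptions, not part of this diff; the only facts taken from it are the function name, its signature, and the new default discount of 0.4.

    import torch

    # Assumes the repo root is on PYTHONPATH, as in the training scripts.
    from library.custom_train_functions import pyramid_noise_like

    # Hypothetical SD latent batch: (batch, channels, height, width).
    latents = torch.randn(4, 4, 64, 64)

    # Start from ordinary Gaussian noise, then mix in progressively coarser
    # resolutions; each level is upsampled back to full size and weighted by
    # discount**i before being added.
    noise = torch.randn_like(latents)
    noise = pyramid_noise_like(noise, latents.device, iterations=6, discount=0.4)

    print(noise.shape)  # unchanged: torch.Size([4, 4, 64, 64])
    print(noise.std())  # roughly 1.0 after the final rescale inside the function

With the wn/hn rename, the Upsample target (w, h) is no longer clobbered inside the loop, so the returned tensor keeps the shape of the input noise.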