Merge pull request #1125 from bghira/feature/sd3-skip-layer
sd3: add skip layer guidance
bghira authored Nov 8, 2024
2 parents 33ad3ca + e33d588 commit bf6c23a
Showing 10 changed files with 1,128 additions and 115 deletions.
35 changes: 32 additions & 3 deletions documentation/quickstart/SD3.md
@@ -264,6 +264,33 @@ For more information, see the [dataloader](/documentation/DATALOADER.md) and [tu

## Notes & troubleshooting tips

### Skip-layer guidance (SD3.5 Medium)

StabilityAI recommends enabling SLG (skip-layer guidance) for SD 3.5 Medium inference. It does not affect training results, only the quality of validation samples.

The following values are recommended for `config.json`:

```json
{
"--validation_guidance_skip_layers": [7, 8, 9],
"--validation_guidance_skip_layers_start": 0.01,
"--validation_guidance_skip_layers_stop": 0.2,
"--validation_guidance_skip_scale": 2.8,
"--validation_guidance": 4.0
}
```

- `..skip_scale` determines how much to scale the positive prompt prediction during skip-layer guidance. The default value of 2.8 is safe for the base model's skip layers of `7, 8, 9`, but it must be increased when more layers are skipped, roughly doubling for each additional layer.
- `..skip_layers` specifies which layers to skip during the negative prompt prediction.
- `..skip_layers_start` determines the fraction of the inference steps at which skip-layer guidance begins to be applied.
- `..skip_layers_stop` sets the fraction of the total inference steps after which SLG is no longer applied.

SLG can be applied for fewer steps for a weaker effect or a smaller inference-speed penalty; the sketch below shows how the start/stop fractions become a step window.
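
The following is a minimal sketch, assuming a boundary convention similar to the diffusers implementation (the exact inclusive/exclusive handling there may differ), of how the start/stop fractions translate into active steps:

```python
def slg_window(num_inference_steps: int, start: float = 0.01, stop: float = 0.2) -> list:
    """Return the step indices for which skip-layer guidance would be active."""
    return [
        i
        for i in range(num_inference_steps)
        if start < (i + 1) / num_inference_steps <= stop
    ]

# With 28 steps and the recommended fractions, SLG only runs for the first
# five steps, so the extra forward pass costs little overall:
print(slg_window(28))  # [0, 1, 2, 3, 4]
```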

Extensive training of a LoRA or LyCORIS model seems to require adjusting these values, though it is not yet clear exactly how they should change.

**A lower CFG value must be used during inference.**
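
For standalone inference outside SimpleTuner, recent `diffusers` releases expose equivalent parameters on `StableDiffusion3Pipeline`; the sketch below assumes the parameter names from diffusers' SD3.5 support and is worth verifying against your installed version:

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of a cat",
    num_inference_steps=28,
    guidance_scale=4.0,  # note the lowered CFG
    skip_guidance_layers=[7, 8, 9],
    skip_layer_guidance_scale=2.8,
    skip_layer_guidance_start=0.01,
    skip_layer_guidance_stop=0.2,
).images[0]
image.save("slg_sample.png")
```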

### Model instability

The SD 3.5 Large 8B model has potential instabilities during training:
@@ -288,12 +315,14 @@ Some changes were made to SimpleTuner's SD3.5 support:
#### Stable configuration values

These options have been known to keep SD3.5 intact for as long as possible (this commit revises several of the recommendations; the superseded value is noted where one changed):
- optimizer=adamw_bf16 (previously optimi-stableadamw)
- flux_schedule_shift=1
- learning_rate=1e-4 (previously 1e-5)
- batch_size=4 * 3 GPUs
- max_grad_norm=0.1 (previously 0.01)
- base_model_precision=int8-quanto
- No loss masking or dataset regularisation, as their contribution to this instability is unknown
- `validation_guidance_skip_layers=[7,8,9]`

### Lowest VRAM config

43 changes: 43 additions & 0 deletions helpers/configuration/cmd_args.py
@@ -6,6 +6,7 @@
from typing import Dict, List, Optional, Tuple
import random
import time
import json
import logging
import sys
import torch
@@ -1350,6 +1351,37 @@ def get_argument_parser():
" the default mode, provides the most benefit."
),
)
parser.add_argument(
"--validation_guidance_skip_layers",
type=str,
default=None,
help=(
"StabilityAI recommends a value of [7, 8, 9] for Stable Diffusion 3.5 Medium."
),
)
parser.add_argument(
"--validation_guidance_skip_layers_start",
type=float,
default=0.01,
help=("StabilityAI recommends a value of 0.01 for SLG start."),
)
    parser.add_argument(
        "--validation_guidance_skip_layers_stop",
        type=float,
        default=0.2,
        help=("StabilityAI recommends a value of 0.2 for SLG stop."),
    )
    parser.add_argument(
        "--validation_guidance_skip_scale",
        type=float,
        default=2.8,
        help=(
            "StabilityAI recommends a value of 2.8 for SLG guidance skip scaling."
            " When adding more layers, you must increase the scale, e.g. adding one more layer requires doubling"
            " the value given."
        ),
    )

    parser.add_argument(
        "--allow_tf32",
        action="store_true",
@@ -2391,4 +2423,15 @@ def parse_cmdline_args(input_args=None):
f"Invalid gradient_accumulation_steps parameter: {args.gradient_accumulation_steps}, should be >= 1"
)

    if args.validation_guidance_skip_layers is not None:
        try:
            # json is already imported at the top of this module.
            args.validation_guidance_skip_layers = json.loads(
                args.validation_guidance_skip_layers
            )
        except Exception as e:
            logger.error(f"Could not load skip layers: {e}")
            raise

    return args
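
Because `--validation_guidance_skip_layers` arrives as a string and is decoded with `json.loads`, the value must be valid JSON rather than an arbitrary Python literal. A hypothetical round-trip, mirroring the validation above:

```python
import json

# Double-quoted JSON list syntax is required on the command line.
layers = json.loads("[7, 8, 9]")
assert layers == [7, 8, 9]

# Python-literal syntax such as a tuple is rejected:
try:
    json.loads("(7, 8, 9)")
except json.JSONDecodeError as e:
    print(f"Could not load skip layers: {e}")
```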