Update SD3 init parameters (replacing height, width with image_shape) #1951

Merged

Changes from 6 commits
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py

@@ -215,8 +215,8 @@ class StableDiffusion3Backbone(Backbone):
             model. Defaults to `1000`.
         shift: float. The shift value for the timestep schedule. Defaults to
             `3.0`.
-        height: optional int. The output height of the image.
-        width: optional int. The output width of the image.
+        image_shape: tuple. The input shape without the batch size. Defaults to
+            `(1024, 1024, 3)`.
Collaborator:

Previously these were `None`, right? Should we keep it that way?

Collaborator Author (@james77777778, Oct 23, 2024):

Unfortunately, if we want a compilable function in MMDiT, we need to specify the height and width at instantiation.
The main barrier is `Unpatch`:
https://github.com/keras-team/keras-hub/blob/master/keras_hub/src/models/stable_diffusion_3/mmdit.py#L660-L672

We can't have two `None`s in `ops.reshape`.

EDITED:
Previously, we defaulted to 1024 when `height` and/or `width` was not provided.
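
For intuition, here is a minimal sketch of that constraint (illustrative only, not the actual `Unpatch` implementation; the function name, 2x2 patch size, and shapes are assumptions): `ops.reshape` accepts at most one unknown dimension, so the patch-to-image reshape needs concrete `height` and `width` at build time.

```python
import numpy as np
from keras import ops

# Illustrative stand-in for the Unpatch step in mmdit.py; the real layer
# differs, but it runs into the same reshape constraint.
def unpatch(x, height, width, patch_size=2, channels=3):
    batch_size = ops.shape(x)[0]
    h, w = height // patch_size, width // patch_size
    # Every dimension below except the batch must be a concrete int. If
    # both `height` and `width` were None, the target shape would carry
    # two unknown dims, and reshape tolerates at most one -1.
    x = ops.reshape(x, (batch_size, h, w, patch_size, patch_size, channels))
    x = ops.transpose(x, (0, 1, 3, 2, 4, 5))
    return ops.reshape(
        x, (batch_size, h * patch_size, w * patch_size, channels)
    )

# A (64, 64, 3) image cut into 2x2 patches -> 1024 patches of 12 values.
patches = np.ones((2, 32 * 32, 2 * 2 * 3), dtype="float32")
print(unpatch(patches, height=64, width=64).shape)  # (2, 64, 64, 3)
```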

         data_format: `None` or str. If specified, either `"channels_last"` or
             `"channels_first"`. The ordering of the dimensions in the
             inputs. `"channels_last"` corresponds to inputs with shape
@@ -270,23 +270,21 @@ def __init__(
         output_channels=3,
         num_train_timesteps=1000,
         shift=3.0,
-        height=None,
-        width=None,
+        image_shape=(1024, 1024, 3),
         data_format=None,
         dtype=None,
         **kwargs,
     ):
-        height = int(height or 1024)
-        width = int(width or 1024)
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(
-                "`height` and `width` must be divisible by 8. "
-                f"Received: height={height}, width={width}"
-            )
         data_format = standardize_data_format(data_format)
         if data_format != "channels_last":
             raise NotImplementedError
-        image_shape = (height, width, int(vae.input_channels))
+        height = image_shape[0]
+        width = image_shape[1]
+        if height % 8 != 0 or width % 8 != 0:
+            raise ValueError(
+                "height and width in `image_shape` must be divisible by 8. "
+                f"Received: image_shape={image_shape}"
+            )
         latent_shape = (height // 8, width // 8, int(latent_channels))
         context_shape = (None, 4096 if t5 is None else t5.hidden_dim)
         pooled_projection_shape = (clip_l.hidden_dim + clip_g.hidden_dim,)
@@ -452,8 +450,7 @@ def __init__(
         self.output_channels = output_channels
         self.num_train_timesteps = num_train_timesteps
         self.shift = shift
-        self.height = height
-        self.width = width
+        self.image_shape = image_shape

     @property
     def latent_shape(self):
@@ -585,8 +582,7 @@ def get_config(self):
                 "output_channels": self.output_channels,
                 "num_train_timesteps": self.num_train_timesteps,
                 "shift": self.shift,
-                "height": self.height,
-                "width": self.width,
+                "image_shape": self.image_shape,
             }
         )
         return config
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py

@@ -11,7 +11,8 @@

 class StableDiffusion3BackboneTest(TestCase):
     def setUp(self):
-        height, width = 64, 64
+        image_shape = (64, 64, 3)
+        height, width = image_shape[0], image_shape[1]
         vae = VAEBackbone(
             [32, 32, 32, 32],
             [1, 1, 1, 1],
@@ -36,8 +37,7 @@ def setUp(self):
             "vae": vae,
             "clip_l": clip_l,
             "clip_g": clip_g,
-            "height": height,
-            "width": width,
+            "image_shape": image_shape,
         }
         self.input_data = {
             "images": ops.ones((2, height, width, 3)),
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py

@@ -27,7 +27,7 @@ class StableDiffusion3ImageToImage(ImageToImage):
     Use `generate()` to do image generation.
     ```python
     image_to_image = keras_hub.models.StableDiffusion3ImageToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     image_to_image.generate(
         {
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py

@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py

@@ -29,7 +29,7 @@ class StableDiffusion3Inpaint(Inpaint):
     reference_image = np.ones((1024, 1024, 3), dtype="float32")
     reference_mask = np.ones((1024, 1024), dtype="float32")
     inpaint = keras_hub.models.StableDiffusion3Inpaint.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     inpaint.generate(
         reference_image,
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py

@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py

@@ -13,6 +13,6 @@
             "path": "stable_diffusion_3",
             "model_card": "https://arxiv.org/abs/2110.00476",
         },
-        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/2",
+        "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3",
     }
 }
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image.py

@@ -27,7 +27,7 @@ class StableDiffusion3TextToImage(TextToImage):
     Use `generate()` to do image generation.
     ```python
     text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
-        "stable_diffusion_3_medium", height=512, width=512
+        "stable_diffusion_3_medium", image_shape=(512, 512, 3)
     )
     text_to_image.generate(
         "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py

@@ -55,8 +55,7 @@ def setUp(self):
             clip_g=CLIPTextEncoder(
                 20, 128, 128, 2, 2, 256, "gelu", -2, name="clip_g"
             ),
-            height=64,
-            width=64,
+            image_shape=(64, 64, 3),
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
keras_hub/src/utils/preset_utils.py (2 additions & 4 deletions)

@@ -563,10 +563,8 @@ def get_backbone_kwargs(self, **kwargs):
         backbone_kwargs["dtype"] = kwargs.pop("dtype", None)

         # Forward `height` and `width` to backbone when using `TextToImage`.
-        if "height" in kwargs:
-            backbone_kwargs["height"] = kwargs.pop("height", None)
-        if "width" in kwargs:
-            backbone_kwargs["width"] = kwargs.pop("width", None)
+        if "image_shape" in kwargs:
+            backbone_kwargs["image_shape"] = kwargs.pop("image_shape", None)

         return backbone_kwargs, kwargs
tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py

@@ -113,8 +113,7 @@ def convert_model(preset, height, width):
         vae,
         clip_l,
         clip_g,
-        height=height,
-        width=width,
+        image_shape=(height, width, 3),
         name="stable_diffusion_3_backbone",
     )
     return backbone
@@ -532,8 +531,7 @@ def main(_):

     keras_preprocessor.save_to_preset(preset)
     # Set the image size to 1024, the same as in huggingface/diffusers.
-    keras_model.height = 1024
-    keras_model.width = 1024
+    keras_model.image_shape = (1024, 1024, 3)
     keras_model.save_to_preset(preset)
     print(f"🏁 Preset saved to ./{preset}.")
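
In short, the static image size is now configured with a single `image_shape` tuple across the backbone, the tasks, and the presets. A usage sketch mirroring the updated docstrings (both spatial dimensions must be divisible by 8, per the new `__init__` check; running this assumes access to the Kaggle preset above):

```python
import keras_hub

# `image_shape` is forwarded from the task kwargs to the backbone, which
# is then built for a fixed (height, width, channels).
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3_medium", image_shape=(512, 512, 3)
)
text_to_image.generate(
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
)
```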