Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Neater use off nn.Sequential in controlnet #7754

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 33 additions & 46 deletions monai/networks/nets/controlnet.py
KumoLiu marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -34,84 +34,71 @@
from collections.abc import Sequence

import torch
import torch.nn.functional as F
from torch import nn

from monai.networks.blocks import Convolution
from monai.networks.nets.diffusion_model_unet import get_down_block, get_mid_block, get_timestep_embedding
from monai.utils import ensure_tuple_rep


class ControlNetConditioningEmbedding(nn.Module):
class ControlNetConditioningEmbedding(nn.Sequential):
"""
Network to encode the conditioning into a latent space.
"""

def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, channels: Sequence[int]):
super().__init__()

self.conv_in = Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=channels[0],
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)

self.blocks = nn.ModuleList([])

convs = [
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
Convolution(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=channels[0],
strides=1,
kernel_size=3,
padding=1,
adn_ordering="A",
act="SWISH",
)
]
for i in range(len(channels) - 1):
channel_in = channels[i]
channel_out = channels[i + 1]
self.blocks.append(
convs += [
Convolution(
spatial_dims=spatial_dims,
in_channels=channel_in,
out_channels=channel_in,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
)

self.blocks.append(
adn_ordering="A",
act="SWISH",
),
Convolution(
spatial_dims=spatial_dims,
in_channels=channel_in,
out_channels=channel_out,
strides=2,
kernel_size=3,
padding=1,
conv_only=True,
adn_ordering="A",
act="SWISH",
),
]
convs.append(
zero_module(
Convolution(
spatial_dims=spatial_dims,
in_channels=channels[-1],
out_channels=out_channels,
strides=1,
kernel_size=3,
padding=1,
adn_ordering="A",
act="SWISH",
marksgraham marked this conversation as resolved.
Show resolved Hide resolved
)
)

self.conv_out = zero_module(
Convolution(
spatial_dims=spatial_dims,
in_channels=channels[-1],
out_channels=out_channels,
strides=1,
kernel_size=3,
padding=1,
conv_only=True,
)
)

def forward(self, conditioning):
embedding = self.conv_in(conditioning)
embedding = F.silu(embedding)

for block in self.blocks:
embedding = block(embedding)
embedding = F.silu(embedding)

embedding = self.conv_out(embedding)

return embedding
super().__init__(*convs)


def zero_module(module):
Expand Down
Loading