feat: add diffusion ar, remove multiblock patching, move pretrained models
flavioschneider committed Nov 25, 2022
1 parent c0020d5 commit c937e49
Showing 5 changed files with 118 additions and 103 deletions.
23 changes: 1 addition & 22 deletions README.md
@@ -1,7 +1,7 @@
<img src="./LOGO.png"></img>

Unconditional audio generation using diffusion models, in PyTorch. The goal of this repository is to explore different architectures and diffusion models to generate audio (speech and music) directly from/to the waveform.
Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page!
Progress will be documented in the [experiments](#experiments) section. You can use the [`audio-diffusion-pytorch-trainer`](https://github.com/archinetai/audio-diffusion-pytorch-trainer) to run your own experiments – please share your findings in the [discussions](https://github.com/archinetai/audio-diffusion-pytorch/discussions) page! Pretrained models can be found at [`archisound`](https://github.com/archinetai/archisound).

## Install

@@ -241,27 +241,6 @@ composer = SpanBySpanComposer(
y_long = composer(y, keep_start=True) # [1, 1, 98304]
```

## Pretrained Models

### Diffusion (Magnitude) AutoEncoder ([`dmae1d-ATC64-v1`](https://huggingface.co/archinetai/dmae1d-ATC64-v1/tree/main))
```py
import torch
from audio_diffusion_pytorch import AudioModel

autoencoder = AudioModel.from_pretrained("dmae1d-ATC64-v1")

x = torch.randn(1, 2, 2**18)
z = autoencoder.encode(x) # [1, 32, 256]
y = autoencoder.decode(z, num_steps=20) # [1, 2, 262144]
```

| Info | Value |
| ------------- | ------------- |
| Input type | Audio (stereo @ 48kHz) |
| Number of parameters | 234.2M |
| Compression Factor | 64x |
| Downsampling Factor | 1024x |
| Bottleneck Type | Tanh |
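
This commit removes the section above from the README; pretrained weights now live in `archisound`, as the updated intro notes. A minimal loading sketch, assuming `archisound` exposes an `ArchiSound.from_pretrained` entry point and hosts this model (check that repository's README for the actual API):

```py
import torch
from archisound import ArchiSound  # assumed entry point

# Load the diffusion magnitude autoencoder previously wrapped by AudioModel
autoencoder = ArchiSound.from_pretrained("dmae1d-ATC64-v1")

x = torch.randn(1, 2, 2**18)             # stereo @ 48kHz
z = autoencoder.encode(x)                # [1, 32, 256]
y = autoencoder.decode(z, num_steps=20)  # [1, 2, 262144]
```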


## Experiments

2 changes: 1 addition & 1 deletion audio_diffusion_pytorch/__init__.py
@@ -27,7 +27,7 @@
    AudioDiffusionUpphaser,
    AudioDiffusionUpsampler,
    AudioDiffusionVocoder,
    AudioModel,
    DiffusionAR1d,
    DiffusionAutoencoder1d,
    DiffusionMAE1d,
    DiffusionUpphaser1d,
120 changes: 103 additions & 17 deletions audio_diffusion_pytorch/model.py
@@ -1,13 +1,15 @@
from math import pi
from random import randint
from typing import Any, Optional, Sequence, Tuple, Union

import torch
from audio_encoders_pytorch import Bottleneck, Encoder1d
from einops import rearrange
from torch import Tensor, nn
from tqdm import tqdm

from .diffusion import LinearSchedule, UniformDistribution, VSampler, XDiffusion
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d
from .modules import STFT, SinusoidalEmbedding, UNet1d, UNetConditional1d, rand_bool
from .utils import (
    closest_power_2,
    default,
@@ -355,6 +357,105 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
        return self.diffusion(x, channels_list=[resampled], features=features, **kwargs)


class DiffusionAR1d(Model1d):
    def __init__(
        self,
        in_channels: int,
        chunk_length: int,
        upsample: int = 0,
        dropout: float = 0.05,
        verbose: int = 0,
        **kwargs,
    ):
        self.in_channels = in_channels
        self.chunk_length = chunk_length
        self.dropout = dropout
        self.upsample = upsample
        self.verbose = verbose
        super().__init__(
            in_channels=in_channels,
            context_channels=[in_channels * (2 if upsample > 0 else 1)],
            **kwargs,
        )

    def reupsample(self, x: Tensor) -> Tensor:
        x = x.clone()
        x = downsample(x, factor=self.upsample)
        x = upsample(x, factor=self.upsample)
        return x

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        b, _, t, device = *x.shape, x.device
        cl, num_chunks = self.chunk_length, t // self.chunk_length
        assert num_chunks >= 2, "Input tensor length must be >= chunk_length * 2"

        # Get prev and current target chunks
        chunk_index = randint(0, num_chunks - 2)
        chunk_pos = cl * (chunk_index + 1)
        chunk_prev = x[:, :, cl * chunk_index : chunk_pos]
        chunk_curr = x[:, :, chunk_pos : cl * (chunk_index + 2)]

        # Randomly dropout source chunks to allow for zero AR start
        if self.dropout > 0:
            batch_mask = rand_bool(shape=(b, 1, 1), proba=self.dropout, device=device)
            chunk_zeros = torch.zeros_like(chunk_prev)
            chunk_prev = torch.where(batch_mask, chunk_zeros, chunk_prev)

        # Condition on previous chunk and reupsampled current if required
        if self.upsample > 0:
            chunk_reupsampled = self.reupsample(chunk_curr)
            channels_list = [torch.cat([chunk_prev, chunk_reupsampled], dim=1)]
        else:
            channels_list = [chunk_prev]

        # Diffuse current chunk
        return self.diffusion(chunk_curr, channels_list=channels_list, **kwargs)

    def sample(self, x: Tensor, start: Optional[Tensor] = None, **kwargs) -> Tensor:  # type: ignore # noqa
        noise = x

        if self.upsample > 0:
            # In this case we assume that x is the downsampled audio instead of noise
            upsampled = upsample(x, factor=self.upsample)
            noise = torch.randn_like(upsampled)

        b, c, t, device = *noise.shape, noise.device
        cl, num_chunks = self.chunk_length, t // self.chunk_length
        assert c == self.in_channels
        assert t % cl == 0, "noise length must be divisible by chunk_length"

        # Initialize previous chunk
        if exists(start):
            chunk_prev = start[:, :, -cl:]
        else:
            chunk_prev = torch.zeros(b, c, cl).to(device)

        # Computed chunks
        chunks = []

        for i in tqdm(range(num_chunks), disable=(self.verbose == 0)):
            # Chunk noise
            chunk_start, chunk_end = cl * i, cl * (i + 1)
            noise_curr = noise[:, :, chunk_start:chunk_end]

            # Condition on previous chunk and artificially upsampled current if required
            if self.upsample > 0:
                chunk_upsampled = upsampled[:, :, chunk_start:chunk_end]
                channels_list = [torch.cat([chunk_prev, chunk_upsampled], dim=1)]
            else:
                channels_list = [chunk_prev]
            default_kwargs = dict(channels_list=channels_list)

            # Sample current chunk
            chunk_curr = super().sample(noise_curr, **{**default_kwargs, **kwargs})

            # Save chunk and use current as prev
            chunks += [chunk_curr]
            chunk_prev = chunk_curr

        return rearrange(chunks, "l b c t -> b c (l t)")
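
For orientation, a schematic usage sketch of the new class (not from the repository: the elided constructor kwargs follow `get_default_model_kwargs` defined later in this file, and `num_steps` is assumed to be forwarded to the diffusion sampler as in the other models):

```py
import torch
from audio_diffusion_pytorch import DiffusionAR1d

model = DiffusionAR1d(
    in_channels=1,
    chunk_length=2**15,  # each autoregressive step denoises one 32768-sample chunk
    upsample=0,          # > 0 would also condition on a reupsampled low-res input
    # ... plus the usual Model1d UNet/diffusion kwargs (channels, multipliers, ...)
)

# Training: the input must span at least two chunks; forward() picks a random
# (previous, current) pair and diffuses the current chunk conditioned on the previous
x = torch.randn(1, 1, 2**17)
loss = model(x)
loss.backward()

# Sampling: chunks are generated left to right, each conditioned on the last;
# the noise length must be a multiple of chunk_length
noise = torch.randn(1, 1, 2**17)
audio = model.sample(noise, num_steps=50)  # [1, 1, 2**17]
```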


"""
Audio Diffusion Classes (specific for 1d audio data)
"""
@@ -363,7 +464,7 @@ def forward(self, x: Tensor, **kwargs) -> Tensor:
def get_default_model_kwargs():
    return dict(
        channels=128,
        patch_factor=16,
        patch_size=16,
        multipliers=[1, 2, 4, 4, 4, 4, 4],
        factors=[4, 4, 4, 2, 2, 2],
        num_blocks=[2, 2, 2, 2, 2, 2],
@@ -500,18 +601,3 @@ def __init__(self, in_channels: int, **kwargs):

    def sample(self, *args, **kwargs):
        return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})


""" Pretrained Models Helper """

REVISION = {"dmae1d-ATC64-v1": "07885065867977af43b460bb9c1422bdc90c29a0"}


class AudioModel:
    @staticmethod
    def from_pretrained(name: str) -> nn.Module:
        from transformers import AutoModel

        return AutoModel.from_pretrained(
            f"archinetai/{name}", trust_remote_code=True, revision=REVISION[name]
        )
73 changes: 11 additions & 62 deletions audio_diffusion_pytorch/modules.py
@@ -207,12 +207,12 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
        return h + self.to_out(x)


class PatchBlock(nn.Module):
class Patcher(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int = 2,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
    ):
        super().__init__()
@@ -223,7 +223,7 @@ def __init__(
        self.block = ResnetBlock1d(
            in_channels=in_channels,
            out_channels=out_channels // patch_size,
            num_groups=min(patch_size, in_channels),
            num_groups=1,
            context_mapping_features=context_mapping_features,
        )

@@ -233,12 +233,12 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
        return x


class UnpatchBlock(nn.Module):
class Unpatcher(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        patch_size: int = 2,
        patch_size: int,
        context_mapping_features: Optional[int] = None,
    ):
        super().__init__()
@@ -249,7 +249,7 @@ def __init__(
        self.block = ResnetBlock1d(
            in_channels=in_channels // patch_size,
            out_channels=out_channels,
            num_groups=min(patch_size, out_channels),
            num_groups=1,
            context_mapping_features=context_mapping_features,
        )

@@ -259,56 +259,6 @@ def forward(self, x: Tensor, mapping: Optional[Tensor] = None) -> Tensor:
        return x


class Patcher(ConditionedSequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        blocks: int,
        factor: int,
        context_mapping_features: Optional[int] = None,
    ):
        channels_pre = [in_channels * (factor ** i) for i in range(blocks)]
        channels_post = [in_channels * (factor ** (i + 1)) for i in range(blocks - 1)]
        channels_post += [out_channels]

        super().__init__(
            PatchBlock(
                in_channels=channels_pre[i],
                out_channels=channels_post[i],
                patch_size=factor,
                context_mapping_features=context_mapping_features,
            )
            for i in range(blocks)
        )


class Unpatcher(ConditionedSequential):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        blocks: int,
        factor: int,
        context_mapping_features: Optional[int] = None,
    ):
        channels_pre = [in_channels]
        channels_pre += [
            out_channels * (factor ** (i + 1)) for i in reversed(range(blocks - 1))
        ]
        channels_post = [out_channels * (factor ** i) for i in reversed(range(blocks))]

        super().__init__(
            UnpatchBlock(
                in_channels=channels_pre[i],
                out_channels=channels_post[i],
                patch_size=factor,
                context_mapping_features=context_mapping_features,
            )
            for i in range(blocks)
        )
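
With the multiblock `Patcher`/`Unpatcher` stacks above removed, patching is now a single step: fold `patch_size` consecutive time steps into channels on the way in, and unfold on the way out. A minimal shape sketch of that idea in einops (the exact rearrange pattern inside the new `Patcher` is an assumption, not copied from the repository):

```py
import torch
from einops import rearrange

patch_size = 16
x = torch.randn(1, 2, 2**18)  # [batch, channels, time]

# Patch: time shrinks by patch_size, channels grow by patch_size
patched = rearrange(x, "b c (l p) -> b (c p) l", p=patch_size)  # [1, 32, 16384]

# Unpatch: the exact inverse, recovering the original layout
restored = rearrange(patched, "b (c p) l -> b c (l p)", p=patch_size)
assert torch.equal(x, restored)
```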


"""
Attention Components
"""
@@ -927,8 +877,7 @@ def __init__(
        factors: Sequence[int],
        num_blocks: Sequence[int],
        attentions: Sequence[int],
        patch_blocks: int = 1,
        patch_factor: int = 1,
        patch_size: int = 1,
        resnet_groups: int = 8,
        use_context_time: bool = True,
        kernel_multiplier_downsample: int = 2,
@@ -1013,11 +962,12 @@ def __init__(
        assert exists(in_channels) and exists(out_channels)
        self.stft = STFT(**stft_kwargs)

        assert not kwargs, f"Unknown arguments: {', '.join(list(kwargs.keys()))}"

        self.to_in = Patcher(
            in_channels=in_channels + context_channels[0],
            out_channels=channels * multipliers[0],
            blocks=patch_blocks,
            factor=patch_factor,
            patch_size=patch_size,
            context_mapping_features=context_mapping_features,
        )

@@ -1076,8 +1026,7 @@ def __init__(
        self.to_out = Unpatcher(
            in_channels=channels * multipliers[0],
            out_channels=out_channels,
            blocks=patch_blocks,
            factor=patch_factor,
            patch_size=patch_size,
            context_mapping_features=context_mapping_features,
        )

3 changes: 2 additions & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name="audio-diffusion-pytorch",
packages=find_packages(exclude=[]),
version="0.0.92",
version="0.0.93",
license="MIT",
description="Audio Diffusion - PyTorch",
long_description_content_type="text/markdown",
@@ -12,6 +12,7 @@
url="https://github.com/archinetai/audio-diffusion-pytorch",
keywords=["artificial intelligence", "deep learning", "audio generation"],
install_requires=[
"tqdm",
"torch>=1.6",
"data-science-types>=0.2",
"einops>=0.4",
