Skip to content

Commit

Permalink
feat: stereo transform, convert always to 2 channels
Browse files Browse the repository at this point in the history
  • Loading branch information
flavioschneider committed Sep 4, 2022
1 parent fbe0932 commit cdb7e1a
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 16 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@ resample = Resample(source=48000, target=22050), # Resamples from 48kHz to 22kHz
from audio_data_pytorch import OverlapChannels
overlap = OverlapChannels() # Overap channels by sum (C, N) -> (1, N)

from audio_data_pytorch import DuplicateChannels
duplicate = DuplicateChannels() # Duplicate channels (1, N) -> (2, N) or (2, N) -> (2, N)
from audio_data_pytorch import Stereo
stereo = Stereo() # Duplicate channels (1, N) -> (2, N) or (2, N) -> (2, N)

from audio_data_pytorch import Scale
scale = Scale(scale=0.8) # Scale waveform amplitude by 0.8
Expand Down
2 changes: 1 addition & 1 deletion audio_data_pytorch/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from .all import AllTransform
from .crop import Crop
from .duplicate_channels import DuplicateChannels
from .loudness import Loudness
from .overlap_channels import OverlapChannels
from .randomcrop import RandomCrop
from .resample import Resample
from .scale import Scale
from .stereo import Stereo
6 changes: 3 additions & 3 deletions audio_data_pytorch/transforms/all.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

from ..utils import exists
from .crop import Crop
from .duplicate_channels import DuplicateChannels
from .loudness import Loudness
from .overlap_channels import OverlapChannels
from .randomcrop import RandomCrop
from .resample import Resample
from .scale import Scale
from .stereo import Stereo


class AllTransform(nn.Module):
Expand All @@ -21,7 +21,7 @@ def __init__(
random_crop_size: Optional[int] = None,
loudness: Optional[int] = None,
scale: Optional[float] = None,
duplicate_channels: bool = False,
use_stereo: bool = False,
overlap_channels: bool = False,
):
super().__init__()
Expand All @@ -39,7 +39,7 @@ def __init__(
RandomCrop(random_crop_size) if exists(random_crop_size) else nn.Identity(),
Crop(crop_size) if exists(crop_size) else nn.Identity(),
OverlapChannels() if overlap_channels else nn.Identity(),
DuplicateChannels() if duplicate_channels else nn.Identity(),
Stereo() if use_stereo else nn.Identity(),
Loudness(sampling_rate=target_rate, target=loudness) # type: ignore
if exists(loudness)
else nn.Identity(),
Expand Down
9 changes: 0 additions & 9 deletions audio_data_pytorch/transforms/duplicate_channels.py

This file was deleted.

15 changes: 15 additions & 0 deletions audio_data_pytorch/transforms/stereo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from torch import Tensor, nn


class Stereo(nn.Module):
def forward(self, x: Tensor) -> Tensor:
shape = x.shape
channels = shape[0]
if len(shape) == 1: # s -> 2, s
x = x.unsqueeze(0).repeat(2, 1)
elif len(shape) == 2:
if channels == 1: # 1, s -> 2, s
x = x.repeat(2, 1)
elif channels > 2: # ?, s -> 2,s
x = x[:2, :]
return x
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name="audio-data-pytorch",
packages=find_packages(exclude=[]),
version="0.0.8",
version="0.0.9",
license="MIT",
description="Audio Data - PyTorch",
long_description_content_type="text/markdown",
Expand Down

0 comments on commit cdb7e1a

Please sign in to comment.