diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3b4bc09..d3fc164 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -22,7 +22,8 @@ repos:
     rev: v0.991
     hooks:
       - id: mypy
-        args: [
+        args:
+          [
             "--ignore-missing-imports",
             "--scripts-are-modules",
             "--disallow-incomplete-defs",
@@ -35,5 +36,5 @@ repos:
             "--disallow-untyped-calls",
             "--install-types",
             "--non-interactive",
-            "--follow-imports=skip", # This is temporary until the mbi directory is not excluded
+            "--follow-imports=skip",
           ]
diff --git a/README.md b/README.md
index e9b8228..c2998a4 100644
--- a/README.md
+++ b/README.md
@@ -19,6 +19,7 @@ conda activate fdiff
 pip install -e .
 ```
 4. If you intend to train models, make sure that wandb is correctly configured on your machine by following [this guide](https://docs.wandb.ai/quickstart).
+5. Some of the datasets are automatically downloaded by our scripts via the Kaggle API. Make sure to create a Kaggle token, as explained [here](https://towardsdatascience.com/downloading-datasets-from-kaggle-for-your-ml-project-b9120d405ea4).
 
 When the packages are installed, you are ready to train diffusion models!
diff --git a/cmd/conf/datamodule/ecg.yaml b/cmd/conf/datamodule/ecg.yaml
index 29bbc2d..fced0cc 100644
--- a/cmd/conf/datamodule/ecg.yaml
+++ b/cmd/conf/datamodule/ecg.yaml
@@ -1,4 +1,5 @@
 _target_: fdiff.dataloaders.datamodules.ECGDatamodule
 data_dir: ${hydra:runtime.cwd}/data
 random_seed: ${random_seed}
+fourier_transform: ${fourier_transform}
 batch_size: 64
diff --git a/cmd/conf/train.yaml b/cmd/conf/train.yaml
index 3316e23..4c1ac8d 100644
--- a/cmd/conf/train.yaml
+++ b/cmd/conf/train.yaml
@@ -1,4 +1,5 @@
 random_seed: 42
+fourier_transform: false
 defaults:
   - _self_
   - score_model: default
diff --git a/cmd/sample.py b/cmd/sample.py
index 232fb9b..7c0264f 100644
--- a/cmd/sample.py
+++ b/cmd/sample.py
@@ -12,6 +12,7 @@
 from fdiff.sampling.metrics import MetricCollection
 from fdiff.sampling.sampler import DiffusionSampler
 from fdiff.utils.extraction import dict_to_str, get_best_checkpoint
+from fdiff.utils.fourier import idft
 
 
 class SamplingRunner:
@@ -38,6 +39,7 @@ def __init__(self, cfg: DictConfig) -> None:
         # Read training config from model directory and instantiate the right datamodule
         train_cfg = OmegaConf.load(self.save_dir / "train_config.yaml")
         self.datamodule: Datamodule = instantiate(train_cfg.datamodule)
+        self.fourier_transform: bool = self.datamodule.fourier_transform
         self.datamodule.prepare_data()
         self.datamodule.setup()
@@ -69,6 +71,10 @@ def sample(self) -> None:
             num_samples=self.num_samples, num_diffusion_steps=self.num_diffusion_steps
         )
 
+        # If sampling happened in the frequency domain, bring the samples back to the time domain
+        if self.fourier_transform:
+            X = idft(X)
+
         # Compute metrics
         results = self.metrics(X)
         logging.info(f"Metrics:\n{dict_to_str(results)}")
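# --- Commentary (not part of the patch) ---------------------------------------
# A minimal sketch of how the new `fourier_transform` flag is plumbed through the
# configs above: the `${fourier_transform}` interpolation added to
# cmd/conf/datamodule/ecg.yaml resolves against the top-level flag added to
# cmd/conf/train.yaml, so a single switch toggles the whole pipeline. Key names
# are taken from the diff; the standalone OmegaConf usage is for illustration.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "fourier_transform": False,  # top-level flag, as in train.yaml
        "datamodule": {"fourier_transform": "${fourier_transform}"},  # as in ecg.yaml
    }
)
assert cfg.datamodule.fourier_transform is False  # interpolation resolves on access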
diff --git a/src/fdiff/dataloaders/datamodules.py b/src/fdiff/dataloaders/datamodules.py
index 4834faa..f36c926 100644
--- a/src/fdiff/dataloaders/datamodules.py
+++ b/src/fdiff/dataloaders/datamodules.py
@@ -10,18 +10,26 @@
 from torch.utils.data import DataLoader, Dataset
 
 from fdiff.utils.dataclasses import collate_batch
+from fdiff.utils.fourier import dft
 
 
 class DiffusionDataset(Dataset):
-    def __init__(self, X: torch.Tensor, y: Optional[torch.Tensor] = None):
+    def __init__(
+        self,
+        X: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        fourier_transform: bool = False,
+    ) -> None:
         super().__init__()
+        if fourier_transform:
+            X = dft(X).detach()
         self.X = X
         self.y = y
 
     def __len__(self) -> int:
         return len(self.X)
 
-    def __getitem__(self, index) -> dict[str, torch.Tensor]:
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
         data = {}
         data["X"] = self.X[index]
         if self.y is not None:
@@ -35,6 +43,7 @@ def __init__(
         data_dir: Path | str = Path.cwd() / "data",
         random_seed: int = 42,
         batch_size: int = 32,
+        fourier_transform: bool = False,
     ) -> None:
         super().__init__()
         # Cast data_dir to Path type
@@ -43,6 +52,7 @@ def __init__(
         self.data_dir = data_dir / self.dataset_name
         self.random_seed = random_seed
         self.batch_size = batch_size
+        self.fourier_transform = fourier_transform
         self.X_train = torch.Tensor()
         self.y_train: Optional[torch.Tensor] = None
         self.X_test = torch.Tensor()
@@ -61,7 +71,9 @@ def download_data(self) -> None:
         ...
 
     def train_dataloader(self) -> DataLoader:
-        train_set = DiffusionDataset(X=self.X_train, y=self.y_train)
+        train_set = DiffusionDataset(
+            X=self.X_train, y=self.y_train, fourier_transform=self.fourier_transform
+        )
         return DataLoader(
             train_set,
             batch_size=self.batch_size,
@@ -70,7 +82,9 @@
         )
 
     def test_dataloader(self) -> DataLoader:
-        test_set = DiffusionDataset(X=self.X_test, y=self.y_test)
+        test_set = DiffusionDataset(
+            X=self.X_test, y=self.y_test, fourier_transform=self.fourier_transform
+        )
         return DataLoader(
             test_set,
             batch_size=self.batch_size,
@@ -79,7 +93,9 @@
         )
 
     def val_dataloader(self) -> DataLoader:
-        test_set = DiffusionDataset(X=self.X_test, y=self.y_test)
+        test_set = DiffusionDataset(
+            X=self.X_test, y=self.y_test, fourier_transform=self.fourier_transform
+        )
         return DataLoader(
             test_set,
             batch_size=self.batch_size,
@@ -106,9 +122,13 @@ def __init__(
         data_dir: Path | str = Path.cwd() / "data",
         random_seed: int = 42,
         batch_size: int = 32,
+        fourier_transform: bool = False,
     ) -> None:
         super().__init__(
-            data_dir=data_dir, random_seed=random_seed, batch_size=batch_size
+            data_dir=data_dir,
+            random_seed=random_seed,
+            batch_size=batch_size,
+            fourier_transform=fourier_transform,
        )
 
     def setup(self, stage: str = "fit") -> None:
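# --- Commentary (not part of the patch) ---------------------------------------
# A sketch of what the new flag does inside DiffusionDataset: when enabled, the
# dataset stores the spectral representation dft(X) once at construction, and
# idft recovers the original series. Assumes the package is importable as
# `fdiff` (editable install per the README); shapes follow the rest of the diff.
import torch

from fdiff.dataloaders.datamodules import DiffusionDataset
from fdiff.utils.fourier import idft

X = torch.randn(16, 30, 3)  # (batch_size, max_len, n_channels)
dataset = DiffusionDataset(X=X, fourier_transform=True)
assert torch.allclose(idft(dataset.X), X, atol=1e-5)  # round trip back to the time domain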
diff --git a/src/fdiff/sampling/metrics.py b/src/fdiff/sampling/metrics.py
index 2c92fee..3c27a95 100644
--- a/src/fdiff/sampling/metrics.py
+++ b/src/fdiff/sampling/metrics.py
@@ -5,6 +5,7 @@
 import numpy as np
 import torch
 
+from fdiff.utils.fourier import dft
 from fdiff.utils.tensors import check_flat_array
 from fdiff.utils.wasserstein import WassersteinDistances
@@ -33,20 +34,35 @@ def __init__(
         original_samples: Optional[np.ndarray | torch.Tensor] = None,
         include_baselines: bool = True,
     ) -> None:
-        for i, metric in enumerate(metrics):
+        metrics_time: list[Metric] = []
+        metrics_freq: list[Metric] = []
+
+        original_samples_freq = (
+            dft(original_samples) if original_samples is not None else None
+        )
+
+        for metric in metrics:
             # If metric is partially instantiated, instantiate it with original samples
             if isinstance(metric, partial):
                 assert (
                     original_samples is not None
                 ), f"Original samples must be provided for metric {metric.name} to be instantiated."
-                metrics[i] = metric(original_samples=original_samples)  # type: ignore
-        self.metrics = metrics
+                metrics_time.append(metric(original_samples=original_samples))  # type: ignore
+                metrics_freq.append(metric(original_samples=original_samples_freq))  # type: ignore
+        self.metrics_time = metrics_time
+        self.metrics_freq = metrics_freq
         self.include_baselines = include_baselines
 
     def __call__(self, other_samples: np.ndarray | torch.Tensor) -> dict[str, float]:
         metric_dict = {}
-        for metric in self.metrics:
-            metric_dict.update(metric(other_samples))
+        other_samples_freq = dft(other_samples)
+        for metric_time, metric_freq in zip(self.metrics_time, self.metrics_freq):
+            metric_dict.update(
+                {f"time_{k}": v for k, v in metric_time(other_samples).items()}
+            )
+            metric_dict.update(
+                {f"freq_{k}": v for k, v in metric_freq(other_samples_freq).items()}
+            )
         if self.include_baselines:
             metric_dict.update(self.baseline_metrics)
         return dict(sorted(metric_dict.items(), key=lambda item: item[0]))
@@ -54,8 +70,13 @@
     @property
     def baseline_metrics(self) -> dict[str, float]:
         metric_dict = {}
-        for metric in self.metrics:
-            metric_dict.update(metric.baseline_metrics)
+        for metric_time, metric_freq in zip(self.metrics_time, self.metrics_freq):
+            metric_dict.update(
+                {f"time_{k}": v for k, v in metric_time.baseline_metrics.items()}
+            )
+            metric_dict.update(
+                {f"freq_{k}": v for k, v in metric_freq.baseline_metrics.items()}
+            )
         return metric_dict
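# --- Commentary (not part of the patch) ---------------------------------------
# The reworked MetricCollection now evaluates every metric twice, once on the raw
# samples and once on their DFT, and disambiguates the result keys with time_/freq_
# prefixes. The merge logic boils down to the following (metric name and values
# are illustrative, not taken from the repository):
time_scores = {"sliced_wasserstein": 0.12}
freq_scores = {"sliced_wasserstein": 0.34}
merged = {f"time_{k}": v for k, v in time_scores.items()}
merged.update({f"freq_{k}": v for k, v in freq_scores.items()})
assert merged == {"time_sliced_wasserstein": 0.12, "freq_sliced_wasserstein": 0.34}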
+ """ + + max_len = x.size(1) + n_real = math.ceil((max_len + 1) / 2) + + # Extract real and imaginary parts + x_re = x[:, :n_real, :] + x_im = x[:, n_real:, :] + + # Create imaginary tensor + zero_padding = torch.zeros(size=(x.size(0), 1, x.size(2))) + x_im = torch.cat((zero_padding, x_im), dim=1) + + # If number of time steps is even, put the null imaginary part + if max_len % 2 == 0: + x_im = torch.cat((x_im, zero_padding), dim=1) + + assert ( + x_im.size() == x_re.size() + ), f"The real and imaginary parts should have the same shape, got {x_re.size()} and {x_im.size()} instead." + + x_freq = torch.complex(x_re, x_im) + + # Apply IFFT + x_time = irfft(x_freq, n=max_len, dim=1, norm="ortho") + + assert isinstance(x_time, torch.Tensor) + assert ( + x_time.size() == x.size() + ), f"The inverse DFT and the input should have the same size. Got {x_time.size()} and {x.size()} instead." + + return x_time.detach() diff --git a/src/fdiff/utils/wandb.py b/src/fdiff/utils/wandb.py index 428ccc9..0510528 100644 --- a/src/fdiff/utils/wandb.py +++ b/src/fdiff/utils/wandb.py @@ -8,10 +8,7 @@ def maybe_initialize_wandb(cfg: DictConfig) -> str | None: """Initialize wandb if necessary.""" cfg_flat = flatten_config(cfg) if "pytorch_lightning.loggers.WandbLogger" in cfg_flat.values(): - wandb.init( - project="FourierDiffusion", - config=cfg_flat, - ) + wandb.init(project="FourierDiffusion", config=cfg_flat, entity="fdiff") assert wandb.run is not None run_id = wandb.run.id assert isinstance(run_id, str) diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py index 18237a8..fdb612c 100644 --- a/tests/test_datamodules.py +++ b/tests/test_datamodules.py @@ -4,6 +4,7 @@ from fdiff.dataloaders.datamodules import Datamodule from fdiff.utils.dataclasses import DiffusableBatch +from fdiff.utils.fourier import idft max_len = 30 n_channels = 3 @@ -20,15 +21,20 @@ def __init__( batch_size: int = batch_size, max_len: int = max_len, n_channels: int = n_channels, + fourier_transform: bool = False, ) -> None: super().__init__( - data_dir=data_dir, random_seed=random_seed, batch_size=batch_size + data_dir=data_dir, + random_seed=random_seed, + batch_size=batch_size, + fourier_transform=fourier_transform, ) self.max_len = max_len self.n_channels = n_channels self.batch_size = batch_size def setup(self, stage: str = "fit") -> None: + torch.manual_seed(self.random_seed) self.X_train = torch.randn( (10 * self.batch_size, self.max_len, self.n_channels), dtype=torch.float32 ) @@ -46,7 +52,7 @@ def dataset_name(self) -> str: return "dummy" -def test_dataloader(): +def test_dataloader() -> None: datamodule = DummyDatamodule() datamodule.prepare_data() datamodule.setup() @@ -55,3 +61,20 @@ def test_dataloader(): assert isinstance(batch, DiffusableBatch) assert batch.X.shape == (batch_size, max_len, n_channels) assert batch.y.shape == (batch_size,) + + +def test_fourier_transform() -> None: + # Default datamodule + datamodule = DummyDatamodule() + datamodule.prepare_data() + datamodule.setup() + + # Fourier datamodule + datamodule_fourier = DummyDatamodule(fourier_transform=True) + datamodule_fourier.prepare_data() + datamodule_fourier.setup() + + X = datamodule.train_dataloader().dataset.X + X_tilde = datamodule_fourier.train_dataloader().dataset.X + + assert torch.allclose(X, idft(X_tilde), atol=1e-5) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index ed48204..e55c116 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize("shift", test_data_wasserstein) 
diff --git a/src/fdiff/utils/wandb.py b/src/fdiff/utils/wandb.py
index 428ccc9..0510528 100644
--- a/src/fdiff/utils/wandb.py
+++ b/src/fdiff/utils/wandb.py
@@ -8,10 +8,7 @@ def maybe_initialize_wandb(cfg: DictConfig) -> str | None:
     """Initialize wandb if necessary."""
     cfg_flat = flatten_config(cfg)
     if "pytorch_lightning.loggers.WandbLogger" in cfg_flat.values():
-        wandb.init(
-            project="FourierDiffusion",
-            config=cfg_flat,
-        )
+        wandb.init(project="FourierDiffusion", config=cfg_flat, entity="fdiff")
     assert wandb.run is not None
     run_id = wandb.run.id
     assert isinstance(run_id, str)
diff --git a/tests/test_datamodules.py b/tests/test_datamodules.py
index 18237a8..fdb612c 100644
--- a/tests/test_datamodules.py
+++ b/tests/test_datamodules.py
@@ -4,6 +4,7 @@
 from fdiff.dataloaders.datamodules import Datamodule
 from fdiff.utils.dataclasses import DiffusableBatch
+from fdiff.utils.fourier import idft
 
 max_len = 30
 n_channels = 3
@@ -20,15 +21,20 @@ def __init__(
         batch_size: int = batch_size,
         max_len: int = max_len,
         n_channels: int = n_channels,
+        fourier_transform: bool = False,
     ) -> None:
         super().__init__(
-            data_dir=data_dir, random_seed=random_seed, batch_size=batch_size
+            data_dir=data_dir,
+            random_seed=random_seed,
+            batch_size=batch_size,
+            fourier_transform=fourier_transform,
         )
         self.max_len = max_len
         self.n_channels = n_channels
         self.batch_size = batch_size
 
     def setup(self, stage: str = "fit") -> None:
+        torch.manual_seed(self.random_seed)
         self.X_train = torch.randn(
             (10 * self.batch_size, self.max_len, self.n_channels), dtype=torch.float32
         )
@@ -46,7 +52,7 @@ def dataset_name(self) -> str:
         return "dummy"
 
 
-def test_dataloader():
+def test_dataloader() -> None:
     datamodule = DummyDatamodule()
     datamodule.prepare_data()
     datamodule.setup()
@@ -55,3 +61,20 @@
     assert isinstance(batch, DiffusableBatch)
     assert batch.X.shape == (batch_size, max_len, n_channels)
     assert batch.y.shape == (batch_size,)
+
+
+def test_fourier_transform() -> None:
+    # Default datamodule
+    datamodule = DummyDatamodule()
+    datamodule.prepare_data()
+    datamodule.setup()
+
+    # Fourier datamodule
+    datamodule_fourier = DummyDatamodule(fourier_transform=True)
+    datamodule_fourier.prepare_data()
+    datamodule_fourier.setup()
+
+    X = datamodule.train_dataloader().dataset.X
+    X_tilde = datamodule_fourier.train_dataloader().dataset.X
+
+    assert torch.allclose(X, idft(X_tilde), atol=1e-5)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index ed48204..e55c116 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -15,7 +15,7 @@
 
 @pytest.mark.parametrize("shift", test_data_wasserstein)
-def test_sliced_waserstein(shift: float):
+def test_sliced_waserstein(shift: float) -> None:
     # Set random seed
     np.random.seed(random_seed)
@@ -28,6 +28,7 @@ def test_sliced_waserstein(shift: float):
         X_t=check_flat_array(dataset1),
         X_s=check_flat_array(dataset2),
         n_projections=num_directions,
+        seed=random_seed,
     )
 
     # Compute sliced wasserstein distance
@@ -43,7 +44,7 @@
 
 @pytest.mark.parametrize("shift", test_data_wasserstein)
-def test_marginal_waserstein(shift: float):
+def test_marginal_waserstein(shift: float) -> None:
     # Set random seed
     np.random.seed(random_seed)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index eca55d5..65522c4 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,9 +1,15 @@
+import torch
 from omegaconf import DictConfig
 
 from fdiff.utils.extraction import flatten_config
+from fdiff.utils.fourier import dft, idft
 
+max_len = 100
+n_channels = 3
+batch_size = 100
 
-def test_flatten_config():
+def test_flatten_config() -> None:
     cfg_dict = {
         "Option1": "Value1",
         "Option2": {
@@ -25,3 +31,21 @@
         "Option5": ["Value5_0", "Value5_1"],
         "Option6": "Value6",
     }
+
+
+def test_dft() -> None:
+    # Create a random real time series
+    x_even = torch.randn(batch_size, max_len, n_channels)
+    x_odd = torch.randn(batch_size, max_len + 1, n_channels)
+
+    # Compute the DFT
+    x_even_tilde = dft(x_even)
+    x_odd_tilde = dft(x_odd)
+
+    # Compute the inverse DFT
+    x_even_hat = idft(x_even_tilde)
+    x_odd_hat = idft(x_odd_tilde)
+
+    # Check that the inverse DFT recovers the original time series
+    assert torch.allclose(x_even, x_even_hat, atol=1e-5)
+    assert torch.allclose(x_odd, x_odd_hat, atol=1e-5)
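# --- Commentary (not part of the patch) ---------------------------------------
# One more property of the dft used throughout this patch, worth keeping in mind
# when noising in the frequency domain: dft is a real-linear map (rfft is linear,
# and taking real/imaginary parts preserves real-linearity), so linear
# combinations of series, and in particular Gaussian perturbations, pass through
# the transform. A quick numerical sanity check under that assumption:
import torch

from fdiff.utils.fourier import dft

x, y = torch.randn(2, 4, 50, 3).unbind(0)
assert torch.allclose(dft(2.0 * x - 3.0 * y), 2.0 * dft(x) - 3.0 * dft(y), atol=1e-4)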