diff --git a/src/model_clay.py b/src/model_clay.py index e89fef36..f8ad8e62 100644 --- a/src/model_clay.py +++ b/src/model_clay.py @@ -56,7 +56,6 @@ def __init__( # noqa: PLR0913 mask_ratio, image_size, patch_size, - shuffle, dim, depth, heads, @@ -74,7 +73,6 @@ def __init__( # noqa: PLR0913 self.mask_ratio = mask_ratio self.image_size = image_size self.patch_size = patch_size - self.shuffle = shuffle self.dim = dim self.bands = bands self.band_groups = band_groups @@ -246,7 +244,7 @@ def mask_out(self, patches): GL == self.num_patches ), f"Expected {self.num_patches} patches, got {GL} patches." - if self.shuffle: # Shuffle the patches + if self.training: # Shuffle the patches noise = torch.randn((B, GL), device=patches.device) # [B GL] else: # Don't shuffle useful for interpolation & inspection of embeddings noise = rearrange( @@ -589,14 +587,12 @@ def __init__( # noqa: PLR0913 "sar": (10, 11), "dem": (12,), }, - shuffle=True, **kwargs, ): super().__init__() self.mask_ratio = mask_ratio self.image_size = image_size self.patch_size = patch_size - self.shuffle = shuffle self.bands = bands self.band_groups = band_groups @@ -604,7 +600,6 @@ def __init__( # noqa: PLR0913 mask_ratio=mask_ratio, image_size=image_size, patch_size=patch_size, - shuffle=shuffle, dim=dim, depth=depth, heads=heads,