diff --git a/cyto_dl/image/transforms/generate_jepa_masks.py b/cyto_dl/image/transforms/generate_jepa_masks.py index 445587a2..da8201fb 100644 --- a/cyto_dl/image/transforms/generate_jepa_masks.py +++ b/cyto_dl/image/transforms/generate_jepa_masks.py @@ -5,7 +5,7 @@ from monai.transforms import RandomizableTransform from skimage.segmentation import find_boundaries -from cyto_dl.nn.vits.utils import validate_spatial_dims +from cyto_dl.nn.vits.utils import match_tuple_dimensions class JEPAMaskGenerator(RandomizableTransform): @@ -39,7 +39,7 @@ def __init__( """ assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1" - num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0] + num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0] assert mask_size * max(block_aspect_ratio) < min( num_patches[-2:] ), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches" diff --git a/cyto_dl/nn/vits/decoder.py b/cyto_dl/nn/vits/decoder.py index fe296121..0ef96cda 100644 --- a/cyto_dl/nn/vits/decoder.py +++ b/cyto_dl/nn/vits/decoder.py @@ -12,8 +12,8 @@ from cyto_dl.nn.vits.blocks import CrossAttentionBlock from cyto_dl.nn.vits.utils import ( get_positional_embedding, + match_tuple_dimensions, take_indexes, - validate_spatial_dims, ) @@ -51,7 +51,7 @@ def __init__( If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. """ super().__init__() - num_patches, patch_size = validate_spatial_dims(spatial_dims, [num_patches, patch_size]) + num_patches, patch_size = match_tuple_dimensions(spatial_dims, [num_patches, patch_size]) self.has_cls_token = has_cls_token diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index ab792e2f..17d3e472 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -8,7 +8,7 @@ from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder -from cyto_dl.nn.vits.utils import validate_spatial_dims +from cyto_dl.nn.vits.utils import match_tuple_dimensions class MAE_Base(torch.nn.Module, ABC): @@ -16,7 +16,7 @@ def __init__( self, spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels ): super().__init__() - num_patches, patch_size, context_pixels = validate_spatial_dims( + num_patches, patch_size, context_pixels = match_tuple_dimensions( spatial_dims, [num_patches, patch_size, context_pixels] ) @@ -213,7 +213,7 @@ def __init__( features_only=features_only, context_pixels=context_pixels, ) - num_mask_units = validate_spatial_dims(self.spatial_dims, [num_mask_units])[0] + num_mask_units = match_tuple_dimensions(self.spatial_dims, [num_mask_units])[0] self._encoder = HieraEncoder( num_patches=self.num_patches, diff --git a/cyto_dl/nn/vits/predictor.py b/cyto_dl/nn/vits/predictor.py index 476901b7..8eb5978f 100644 --- a/cyto_dl/nn/vits/predictor.py +++ b/cyto_dl/nn/vits/predictor.py @@ -8,8 +8,8 @@ from cyto_dl.nn.vits.blocks import CrossAttentionBlock from cyto_dl.nn.vits.utils import ( get_positional_embedding, + match_tuple_dimensions, take_indexes, - validate_spatial_dims, ) @@ -56,7 +56,7 @@ def __init__( ] ) - num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0] + num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0] self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) self.pos_embedding = get_positional_embedding(