Feature/hiera #418
Conversation
This is a lot!! Loved reading it... As always, minor comments
configs/model/im2im/hiera.yaml (outdated)
backbone:
  _target_: cyto_dl.nn.vits.mae.HieraMAE
  spatial_dims: ${spatial_dims}
  patch_size: 2 # patch_size * num_patches should be your patch shape
should be image shape?
can this be a list for ZYX patch size?
yes - the terminology is confusing here haha. "patch" = the small crop extracted from your original image, but "patch" is also the tokenized component of the image fed into the network. The patch size can be either an int (repeated for each spatial dim) or a list of size spatial_dims
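For concreteness, a small sketch of the two accepted forms, using the values from the config above (variable names here are just for illustration):

```python
# patch_size as an int (repeated per spatial dim) vs. an explicit ZYX list
spatial_dims = 3
num_patches = 8          # tokens per spatial dim

patch_size = 2           # int form: 2 is used for Z, Y and X
crop_shape = tuple(patch_size * num_patches for _ in range(spatial_dims))
assert crop_shape == (16, 16, 16)   # the "patch" (crop) fed to the network

patch_size = [2, 4, 4]   # list form: one entry per ZYX dim, len == spatial_dims
crop_shape = tuple(p * num_patches for p in patch_size)
assert crop_shape == (16, 32, 32)
```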
configs/model/im2im/hiera.yaml (outdated)
spatial_dims: ${spatial_dims}
patch_size: 2 # patch_size * num_patches should be your patch shape
num_patches: 8 # patch_size * num_patches = img_shape
num_mask_units: 4 # img_shape / num_mask_units = size of each mask unit in pixels, num_patches / num_mask_units = number of patches per mask unit
Clarify what a mask unit is here?
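For reference, the arithmetic implied by the config comments above works out as follows (a sketch only; the precise definition of a mask unit lives in the Hiera implementation):

```python
patch_size = 2
num_patches = 8
num_mask_units = 4

img_shape = patch_size * num_patches                    # 16 pixels per spatial dim
mask_unit_px = img_shape // num_mask_units              # each mask unit spans 4 pixels per dim
patches_per_mask_unit = num_patches // num_mask_units   # 2 patches per mask unit per dim
assert (img_shape, mask_unit_px, patches_per_mask_unit) == (16, 4, 2)
```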
configs/model/im2im/hiera.yaml (outdated)
architecture:
  # mask_unit_attention blocks - attention is only done within a mask unit and not across mask units
  # the total amount of q_stride across the architecture must be less than the number of patches per mask unit
  - repeat: 1
what is repeat?
  # self attention transformer - attention is done across all patches, irrespective of which mask unit they're in
  - repeat: 2
    num_heads: 4
    self_attention: True
so last layer is global attention and first 2 layers are local attention? Is 3 layers the recommended hierarchy?
correct. 3 layers is small enough to test quickly. All of the models with unit tests are tiny by default in the configs and I have somewhere in the docs that you should increase the model size if you want good performance.
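For illustration, the same pattern written as a plain Python list (roughly what hydra builds from the YAML above); the num_heads of the local block is assumed here and not taken from the config:

```python
# tiny test architecture: local (mask-unit) attention first, global self-attention last
architecture = [
    # mask_unit_attention block: attention only within each mask unit
    {"repeat": 1, "num_heads": 2},   # num_heads here is illustrative
    # self-attention block: attention across all patches
    {"repeat": 2, "num_heads": 4, "self_attention": True},
]
# to scale up, add more blocks / repeats, keeping the total q_stride below the
# number of patches per mask unit (see the config comment above)
```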
if self.spatial_dims == 3:
    q = reduce(
        q,
        "b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n h (n_patches_z n_patches_y n_patches_x) c",
can you use the same nomenclature here? e.g. n = num_mask_units = mask_units, num_heads = h = n_heads
c = head_dim
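A sketch of that renaming applied to the reduce above (the tensor shapes and the max reduction are assumed for illustration; the actual reduction argument isn't visible in this snippet):

```python
import torch
from einops import reduce

b, mask_units, n_heads, head_dim = 2, 4, 2, 8
n_z = n_y = n_x = 2   # patches kept per spatial dim after pooling
qz = qy = qx = 2      # q_stride per spatial dim

q = torch.randn(b, mask_units, n_heads, (n_z * qz) * (n_y * qy) * (n_x * qx), head_dim)
q = reduce(
    q,
    "b mask_units n_heads (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) head_dim "
    "-> b mask_units n_heads (n_patches_z n_patches_y n_patches_x) head_dim",
    "max",
    n_patches_z=n_z, q_stride_z=qz,
    n_patches_y=n_y, q_stride_y=qy,
    n_patches_x=n_x, q_stride_x=qx,
)
assert q.shape == (b, mask_units, n_heads, n_z * n_y * n_x, head_dim)
```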
self.spatial_dims = spatial_dims
self.num_heads = num_heads
self.head_dim = dim_out // num_heads
self.scale = qk_scale or self.head_dim**-0.5
this isn't used anywhere
# change dimension and subsample within mask unit for skip connection
x = self.proj(x_norm)

x = x + self.drop_path(self.attn(x_norm))
Does dim_out = dim for skip connection with attention?
Good question - each block specified in the architecture argument doubles the embedding dimension and halves the size of the mask unit. This doubling/pooling happens on the last repeat of the block, so dim_out=dim for all repeats except the last. I updated the docstring with an example.
cool!
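A quick numeric sketch of the doubling-on-the-last-repeat behaviour described above (the starting dim and block spec are illustrative):

```python
embed_dim = 16
architecture = [{"repeat": 1}, {"repeat": 2, "self_attention": True}]

dim = embed_dim
for block in architecture:
    for i in range(block["repeat"]):
        last_repeat = i == block["repeat"] - 1
        dim_out = dim * 2 if last_repeat else dim   # dim_out == dim except on the last repeat
        dim = dim_out
assert dim == 64   # 16 -> 32 after block 1, 32 -> 64 after block 2
```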
dim_out: int,
heads: int,
spatial_dims: int = 3,
mlp_ratio: float = 4.0,
what is mlp_ratio? add to docstring?
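If it follows the usual ViT convention, mlp_ratio is the width multiplier of the transformer MLP (hidden features = dim * mlp_ratio); a minimal sketch under that assumption (make_mlp is illustrative, not the repo's class):

```python
import torch.nn as nn

def make_mlp(dim: int, mlp_ratio: float = 4.0) -> nn.Sequential:
    hidden = int(dim * mlp_ratio)   # e.g. 64 * 4.0 = 256 hidden features
    return nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

mlp = make_mlp(dim=64)   # 64 -> 256 -> 64
```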
class PatchifyHiera(PatchifyBase):
    """Class for converting images to a masked sequence of patches with positional embeddings."""
to "mask units" instead of masked sequence? since that's what a regular patchify does?
cyto_dl/nn/vits/utils.py (outdated)
@@ -40,3 +47,8 @@ def get_positional_embedding(
    cls_token = torch.zeros(1, 1, emb_dim)
    pe = torch.cat([cls_token, pe], dim=0)
    return torch.nn.Parameter(pe, requires_grad=False)


def validate_spatial_dims(spatial_dims, tuples):
I feel like the code might be clearer by not having this be a separate function and just calling these 2 lines in every class? I thought this function was doing a lot more based on the name (like some math to check that the spatial dimensions of each patch and mask is correct). What do you think?
I'd prefer renaming it to something clearer rather than repeating the code, maybe match_tuple_to_spatial_dims?
sounds good
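A sketch of what the renamed helper might look like; the original two lines aren't shown in this thread, so the exact behaviour is an assumption (broadcast an int to a tuple of length spatial_dims and validate explicit lists):

```python
def match_tuple_to_spatial_dims(spatial_dims, value):
    # hypothetical implementation, not the repo's actual code
    if isinstance(value, int):
        return (value,) * spatial_dims
    if len(value) != spatial_dims:
        raise ValueError(f"expected {spatial_dims} values, got {len(value)}")
    return tuple(value)
```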
Windows tests always seem to be failing. any ideas why?
the Windows tests are just way slower for some reason... all the tests pass but I set a 70 minute timeout so we don't rack up crazy costs.
What does this PR do?
Before submitting
- Did you run the tests with the pytest command?
- Did you run the pre-commit checks with the pre-commit run -a command?
Did you have fun?
Make sure you had fun coding 🙃