Add option to output raw patch embeddings #133
Conversation
src/model_clay.py
Outdated
```python
)
if self.hparams.output_patch_embeddings:
    embeddings_raw = rearrange(
        embeddings_raw[:, :-2, :], "b (w h g) s -> b (w h s) g", w=16, h=16, g=6
```
The order of the tensor matters. In this case the output from the encoder, i.e. `embeddings_raw`, is of shape `batch x (group x num_spatial_patches) x embedding_dims`. So the einops operation should be `b (g l) d -> b g l d`. You can check this notebook for reference: https://github.com/Clay-foundation/model/blob/docs/model/docs/clay-v0-visualization.ipynb
Done, but please review again to make sure I got it right this time. Related to this, do you think this is a good way to unravel the patch embeddings?
Mentioned this below at #133 (comment), but can we just keep the unravelled shape (B, 256, 768), or (B, 1536, 768)? Less work for the downstream user since they won't need to figure out how to unravel the tensor.
Yes absolutely. Don't know why I thought it had to be a 1d array...
That said, do we want to store a (256, 768) embedding tensor in a single row, or split it into 256 rows of 768-length embeddings? I'm reading the original thread at #127, and it sounds like we need to enable searching on the 32x32 patch level instead of 256x256 chips, which might mean storing a row for each patch?
We can store 2D arrays in GeoParquet, but I'm not sure if vector databases allow indexing 2D arrays, or if we need to make it 1D. Note that this would increase the embedding file size significantly (x256), and the vector database indexing will be much slower. But if these raw embeddings are only meant to be generated on an ad-hoc basis for small locations, and not for similarity search applications, it should be fine.
I thought about that too, but that would also imply that we have to track the MGRS tile and the source path 256x more, which seems inefficient. My thinking was that the splitting into rows can happen at the moment of ingestion into a vector search DB, or when analysing if required. For "local manual inspection" situations, having the patches in a single array is useful too. So I vote to keep it at one row in the GeoDataFrame and let the separation be a downstream issue.
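The deferred row-splitting could look roughly like this at ingestion time (a sketch with made-up metadata values; `mgrs_tile` and `source_path` stand in for the tracked columns):

```python
import numpy as np
import pandas as pd

# One row per chip: a single (256, 768) patch-embedding array plus chip-level
# metadata that would otherwise be repeated 256 times.
chip = pd.DataFrame(
    {
        "mgrs_tile": ["32VLM"],       # made-up tile id
        "source_path": ["chip.tif"],  # made-up source path
        "embeddings": [np.zeros((256, 768))],
    }
)

# Downstream: split into one row per 32x32 patch only when ingesting into a
# vector search DB, duplicating the metadata at that point.
patches = chip.assign(
    patch_id=chip["embeddings"].apply(lambda e: list(range(len(e)))),
    embeddings=chip["embeddings"].apply(list),
).explode(["patch_id", "embeddings"])

print(len(patches))  # 256
```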
docs/model_embeddings.md
Outdated
```markdown
The `output_patch_embeddings` flag determines how the embeddings are calculated.
If `False`, one average embedding per MGRS tile of size 768 will be created. If
`True`, the embeddings will be kept at the patch level. The embedding array will
be of size 16 * 16 * 768, representing.
```
Dangling sentence, representing what? Also, shouldn't the size be 16 * 16, 768?
Updated the wording to adapt to the embedding levels input.
```markdown
   output embedding of shape (B, 768).
2. By default, the mean or average is taken across the 1536 patch dimension,
   yielding an output embedding of shape (B, 768). If patch embeddings are
   requested, the shape is (B, 16 * 16 * 768), one embedding per patch.
```
If we decide to use the unsqueezed shape.
```diff
- requested, the shape is (B, 16 * 16 * 768), one embedding per patch.
+ requested, the shape is (B, 6*16*16, 768), one embedding per patch.
```
src/model_clay.py
Outdated
```python
if self.hparams.output_patch_embeddings:
    # Take the mean of the embeddings along the group dimension
    # excluding the last two latlon_ and time_ embeddings. This
    # results in one embedding per patch.
    embeddings_raw = rearrange(
        embeddings_raw[:, :-2, :], "b (g l) s -> b g (l s)", l=256, g=6
    )
    embeddings_mean = reduce(embeddings_raw, "b g s -> b s", "mean")
    assert embeddings_mean.shape == torch.Size(
        [
            self.model.encoder.B,
            256 * 768,
        ]  # (batch_size, nr of patches * hidden_size)
    )
```
Why are we squeezing the embeddings to a size like (B, 256*768)? Can't we just keep the raw embedding shape of (B, 1536, 768), where 1536 is 6 (band groups) * 16 * 16 (patch size)? It will be a lot harder to unsqueeze a 196608-length tensor back to 6*16*16 later.
Unsqueezed would undoubtedly be better, but I assumed that we can only store one-dimensional arrays in a field. I guess from your question that is not the case. Can we store multidimensional arrays here? If yes, let's store the unsqueezed one for sure!
Yes, GeoParquet/Arrow technically supports FixedShapeTensorArray which can be multi-dimensional.
I just found out that this will require a rewrite without pandas, because pandas cannot handle this and implicitly flattens the `FixedShapeTensorArray`, see the example below. I ran across this when running tests and not getting the expected multidimensional arrays. This definitely complicates things in terms of keeping the structure.
```python
import numpy as np
import pandas as pd
import pyarrow as pa

array = np.arange(8).reshape((2, 2, 2))
arrow = pa.FixedShapeTensorArray.from_numpy_ndarray(array)
df = pd.DataFrame(data={"arrow": arrow})
df
#           arrow
# 0  [0, 1, 2, 3]
# 1  [4, 5, 6, 7]
```
Ok, so based on the above, the arrays are output in a flat structure for now. I updated the documentation with a sentence about where the structure comes from and why the arrays are flat. This is not ideal, but keeping the multidimensional arrays now would require refactoring how we construct the GeoParquet files.
Hmm, have you tried inserting a FixedShapeTensorArray into a row cell directly? This seems to work:
```python
import numpy as np
import pandas as pd
import pyarrow as pa

array0 = np.arange(0, 8).reshape((2, 2, 2))
array1 = np.arange(9, 17).reshape((2, 2, 2))
arrow0 = pa.FixedShapeTensorArray.from_numpy_ndarray(array0)
arrow1 = pa.FixedShapeTensorArray.from_numpy_ndarray(array1)
df = pd.DataFrame(data={"embeddings": [arrow0, arrow1]})
print(df)
#                             embeddings
# 0         ([0, 1, 2, 3], [4, 5, 6, 7])
# 1  ([9, 10, 11, 12], [13, 14, 15, 16])
```
The arrow part still flattens the arrays that are input. For this to work, we would have to do the manual deconstruction up to three dimensions deep. I am not sure that this will make things more user-friendly. With this approach, the user would have to reconstruct a list of a list of a list of arrays in the "group" case.
The patch embeddings are averages over the band groups.
Allow for 3 levels: mean, patch, group. Arrays are flattened when passed to pandas. This could be improved in the future.
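Given the flattened storage, reconstructing the structure downstream is a reshape (a sketch, assuming the row-major flattening order that pandas applies):

```python
import numpy as np

# Patch level: one flat row of 256 * 768 values back to (256, 768),
# one 768-d embedding per 32x32 patch.
flat_patch = np.zeros(256 * 768)
patches = flat_patch.reshape(256, 768)
print(patches.shape)  # (256, 768)

# Group level: 6 * 256 * 768 values back to (6, 256, 768),
# one embedding per band group per patch.
flat_group = np.zeros(6 * 256 * 768)
groups = flat_group.reshape(6, 256, 768)
print(groups.shape)  # (6, 256, 768)
```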
Force-pushed from cd6bb7b to f6a64d8.
No longer necessary after #135
Force-pushed from f6a64d8 to bda3b24.
Like the idea of storing embeddings at multiple levels.
```diff
@@ -87,19 +87,27 @@ def test_model_vit_fit(datapipe):
 @pytest.mark.parametrize(
     "litmodule,precision",
     [
-        (CLAYModule, "bf16-mixed" if torch.cuda.is_available() else "32-true"),
         (ViTLitModule, "bf16-mixed"),
+        (CLAYModule, "16-mixed" if torch.cuda.is_available() else "32-true"),
```
Is `bf16-mixed` precision giving any issues while inferencing?
It is for the Nvidia GeForce RTX on my local machine, which apparently does not support this type.
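A possible guard for this (a sketch; it assumes `torch.cuda.is_bf16_supported()` matches what Lightning's `bf16-mixed` setting requires):

```python
import torch

# Fall back to fp16 mixed precision on GPUs without bfloat16 support
# (e.g. some consumer GeForce RTX cards), and to full fp32 on CPU.
if torch.cuda.is_available():
    precision = "bf16-mixed" if torch.cuda.is_bf16_supported() else "16-mixed"
else:
    precision = "32-true"
print(precision)
```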
This PR adds an option to the Lightning CLI to output patch-level embeddings, i.e. one embedding per patch in each chip. The band group dimension is reduced, so the patch embeddings are averages over the band groups.
This adds two args to the CLI: `output_patch_embeddings` and `shuffle`. For the patch embeddings we need to ensure that shuffle is off, while it is on by default. See also #123. Updated the documentation to explain these changes.
This PR closes #130.