Skip to content

Commit

Permalink
[upd] update /pretrain/README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
keyu-tian committed Apr 4, 2023
1 parent 6ffe453 commit ec877d3
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pretrain/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See [/pretrain/models/custom.py](/pretrain/models/custom.py). The things needed
- implementing member function `get_feature_map_channels` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- implementing member function `forward` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- define `your_convnet(...)` with `@register_model` in [/pretrain/models/custom.py](/pretrain/models/custom.py).
- add default kwargs of `your_convnet(...)` in [/pretrain/models/__init__.py](/pretrain/models/__init__.py).
- add default kwargs of `your_convnet(...)` in [/pretrain/models/\_\_init\_\_.py](/pretrain/models/__init__.py).

Then you can use `--model=your_convnet` in the pre-training script.

Expand Down Expand Up @@ -97,4 +97,4 @@ Here is the reason: when we do mask, we:
3. then progressively upsample it (i.e., expand its 2nd and 3rd dimensions by calling `repeat_interleave(..., 2)` and `repeat_interleave(..., 3)` in [/pretrain/encoder.py line16](/pretrain/encoder.py)), to mask those feature maps ([`x` in line21](/pretrain/encoder.py)) with larger resolutions .

So if you want a patch size of 16 or 8, you should actually define a new CNN model with a downsample ratio of 16 or 8.
See `Tutorial for customizing your own CNN model` above.
See [Tutorial for customizing your own CNN model (above)](https://github.com/keyu-tian/SparK/tree/main/pretrain#some-details-how-we-mask-images-and-how-to-set-the-patch-size).
1 change: 1 addition & 0 deletions pretrain/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from models.convnext import ConvNeXt
from models.resnet import ResNet
from models.custom import YourConvNet
_import_resnets_for_timm_registration = (ResNet,)


Expand Down

0 comments on commit ec877d3

Please sign in to comment.