Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
wget jx_vit_base_p16_224-80ecf9dd.pth
  • Loading branch information
QinghongLin authored Aug 22, 2023
1 parent 063bd10 commit f3e8895
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ def __init__(self,
arch_config = video_params.get('arch_config', 'base_patch16_224')
vit_init = video_params.get('vit_init', 'imagenet-21k')
if arch_config == 'base_patch16_224':
vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained)
# vit_model = torch.load("pretrained/jx_vit_base_p16_224-80ecf9dd.pth", map_location="cpu")
# you can download the checkpoint via wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth
# vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained)
vit_model = torch.load("pretrained/jx_vit_base_p16_224-80ecf9dd.pth", map_location="cpu")
model = SpaceTimeTransformer(num_frames=num_frames,
time_init=time_init,
attention_style=attention_style)
Expand Down

0 comments on commit f3e8895

Please sign in to comment.