From f3e8895c7a1a691bc7fb0c07618c3be0015887eb Mon Sep 17 00:00:00 2001 From: Kevin Date: Tue, 22 Aug 2023 13:43:36 +0800 Subject: [PATCH] Update model.py wget jx_vit_base_p16_224-80ecf9dd.pth --- model/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model/model.py b/model/model.py index c1a5d99..dc0b4d2 100644 --- a/model/model.py +++ b/model/model.py @@ -43,8 +43,9 @@ def __init__(self, arch_config = video_params.get('arch_config', 'base_patch16_224') vit_init = video_params.get('vit_init', 'imagenet-21k') if arch_config == 'base_patch16_224': - vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained) - # vit_model = torch.load("pretrained/jx_vit_base_p16_224-80ecf9dd.pth", map_location="cpu") + # you can download the checkpoint via wget https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth + # vit_model = timm.models.vision_transformer.vit_base_patch16_224(pretrained=pretrained) + vit_model = torch.load("pretrained/jx_vit_base_p16_224-80ecf9dd.pth", map_location="cpu") model = SpaceTimeTransformer(num_frames=num_frames, time_init=time_init, attention_style=attention_style)