Fix some bugs and adjust the code for the number of frames. #106

Open · wants to merge 3 commits into main
8 changes: 4 additions & 4 deletions README.md
@@ -190,11 +190,11 @@ The folder structure of the dataset is shown below:
|────...
```
#### Script
-Config the the checkpoint and dataset paths in [video_llama_stage1_pretrain.yaml](./train_configs/video_llama_stage1_pretrain.yaml).
+Config the checkpoint and dataset paths in [visionbranch_stage1_pretrain.yaml](./train_configs/visionbranch_stage1_pretrain.yaml).
Run the script:
```
conda activate videollama
-torchrun --nproc_per_node=8 train.py --cfg-path ./train_configs/video_llama_stage1_pretrain.yaml
+torchrun --nproc_per_node=8 train.py --cfg-path ./train_configs/visionbranch_stage1_pretrain.yaml
```

### 2. Instruction Fine-tuning
@@ -205,10 +205,10 @@ For now, the fine-tuning dataset consists of:
* 11K video-based instructions from VideoChat [[link](https://github.com/OpenGVLab/InternVideo/tree/main/Data/instruction_data)]

#### Script
-Config the checkpoint and dataset paths in [video_llama_stage2_finetune.yaml](./train_configs/video_llama_stage2_finetune.yaml).
+Config the checkpoint and dataset paths in [visionbranch_stage2_finetune.yaml](./train_configs/visionbranch_stage2_finetune.yaml).
```
conda activate videollama
-torchrun --nproc_per_node=8 train.py --cfg-path ./train_configs/video_llama_stage2_finetune.yaml
+torchrun --nproc_per_node=8 train.py --cfg-path ./train_configs/visionbranch_stage2_finetune.yaml
```

## Recommended GPUs
2 changes: 1 addition & 1 deletion demo_audiovideo.py
@@ -68,7 +68,7 @@ def setup_seeds(config):
model.eval()
vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), vis_processor_cfg=vis_processor_cfg)
print('Initialization Finished')

# ========================================
2 changes: 1 addition & 1 deletion demo_video.py
@@ -68,7 +68,7 @@ def setup_seeds(config):
model.eval()
vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)
-chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id))
+chat = Chat(model, vis_processor, device='cuda:{}'.format(args.gpu_id), vis_processor_cfg=vis_processor_cfg)
print('Initialization Finished')

# ========================================
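Both demo scripts now hand the visual-processor config to `Chat`, so the frame count and image size follow the eval YAML instead of hard-coded literals. A minimal sketch of the adjusted initialization, assuming the config is passed by keyword (a bare positional argument after `device=` would not be valid Python):

```python
# Sketch of the adjusted demo setup (applies to demo_video.py and demo_audiovideo.py).
# The processor settings live under cfg.datasets_cfg.webvid.vis_processor.train,
# as in the existing demos.
vis_processor_cfg = cfg.datasets_cfg.webvid.vis_processor.train
vis_processor = registry.get_processor_class(vis_processor_cfg.name).from_config(vis_processor_cfg)

chat = Chat(
    model,
    vis_processor,
    device='cuda:{}'.format(args.gpu_id),
    vis_processor_cfg=vis_processor_cfg,  # carries n_frms and image_size for Chat to read
)
```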
4 changes: 2 additions & 2 deletions eval_configs/video_llama_eval_only_vl.yaml
@@ -1,10 +1,10 @@
model:
arch: video_llama
-model_type: pretrain_vicuna
+model_type: pretrain_vicuna # or pretrain_llama_v2
freeze_vit: True
freeze_qformer: True
max_txt_len: 512
-end_sym: "###"
+end_sym: "###" # or "</s>"
low_resource: False

frozen_llama_proj: False
4 changes: 2 additions & 2 deletions eval_configs/video_llama_eval_withaudio.yaml
@@ -1,10 +1,10 @@
model:
arch: video_llama
-model_type: pretrain_vicuna
+model_type: pretrain_vicuna # or pretrain_llama_v2
freeze_vit: True
freeze_qformer: True
max_txt_len: 512
-end_sym: "###"
+end_sym: "###" # or "</s>"
low_resource: False

frozen_llama_proj: False
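The eval configs now spell out that `model_type` must match the language-model backbone and that `end_sym` changes with it. A short sketch of the pairing implied by the two edited values; the mapping is an assumption drawn from these configs, not an exhaustive list:

```python
# Hedged sketch: pick exactly one backbone; "pretrain_vicuna or pretrain_llama_v2"
# is not itself a usable config value.
BACKBONE_SETTINGS = {
    "pretrain_vicuna":   {"end_sym": "###"},
    "pretrain_llama_v2": {"end_sym": "</s>"},
}

def check_model_config(model_type: str, end_sym: str) -> None:
    """Raise if end_sym does not match the chosen backbone."""
    expected = BACKBONE_SETTINGS[model_type]["end_sym"]
    if end_sym != expected:
        raise ValueError(f"{model_type} expects end_sym {expected!r}, got {end_sym!r}")

check_model_config("pretrain_vicuna", "###")  # passes silently
```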
2 changes: 1 addition & 1 deletion train_configs/audiobranch_stage1_pretrain.yaml
@@ -1,6 +1,6 @@
model:
arch: video_llama
-model_type: pretrain_vicuna
+model_type: pretrain_vicuna # or pretrain_llama_v2
freeze_vit: True
freeze_qformer: True
low_resource: False
2 changes: 1 addition & 1 deletion train_configs/audiobranch_stage2_finetune.yaml
@@ -1,6 +1,6 @@
model:
arch: video_llama
-model_type: pretrain_vicuna
+model_type: pretrain_vicuna # or pretrain_llama_v2
freeze_vit: True
freeze_qformer: True

2 changes: 1 addition & 1 deletion train_configs/visionbranch_stage1_pretrain.yaml
@@ -1,6 +1,6 @@
model:
arch: video_llama
-model_type: pretrain_vicuna
+model_type: pretrain_vicuna # or pretrain_llama_v2
freeze_vit: True
freeze_qformer: True

2 changes: 1 addition & 1 deletion train_configs/visionbranch_stage2_finetune.yaml
@@ -1,6 +1,6 @@
model:
arch: video_llama
-model_type: pretrain_vicuna
+model_type: pretrain_vicuna # or pretrain_llama_v2
freeze_vit: True
freeze_qformer: True

20 changes: 11 additions & 9 deletions video_llama/conversation/conversation_video.py
@@ -166,11 +166,13 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
sep2="</s>",
)
class Chat:
-def __init__(self, model, vis_processor, device='cuda:0'):
+def __init__(self, model, vis_processor, device='cuda:0', vis_processor_cfg=None):
self.device = device
self.model = model
self.vis_processor = vis_processor
self.image_vis_processor = Blip2ImageEvalProcessor()
+self.n_frms = vis_processor_cfg.n_frms
+self.image_size = vis_processor_cfg.image_size
# stop_words_ids = [torch.tensor([835]).to(self.device),
# torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
# self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
@@ -230,7 +232,7 @@ def answer(self, conv, img_list, max_new_tokens=300, num_beams=1, min_length=1,
conv.messages[-1][1] = output_text
return output_text, output_token.cpu().numpy()

-def upload_video(self, video_path, conv, img_list):
+def upload_video(self, video_path, conv, img_list, n_frms=8, image_size=224):

msg = ""
if isinstance(video_path, str): # is a video path
@@ -239,9 +241,9 @@ def upload_video(self, video_path, conv, img_list):
# image = self.vis_processor(image).unsqueeze(0).to(self.device)
video, msg = load_video(
video_path=video_path,
-n_frms=8,
-height=224,
-width=224,
+n_frms=self.n_frms,
+height=self.image_size,
+width=self.image_size,
sampling ="uniform", return_msg = True
)
video = self.vis_processor.transform(video)
@@ -277,17 +279,17 @@ def upload_video(self, video_path, conv, img_list):
conv.append_message(conv.roles[0], "<Video><ImageHere></Video> "+ msg)
return "Received."

-def upload_video_without_audio(self, video_path, conv, img_list):
+def upload_video_without_audio(self, video_path, conv, img_list, n_frms=8, image_size=224):
msg = ""
if isinstance(video_path, str): # is a video path
ext = os.path.splitext(video_path)[-1].lower()
print(video_path)
# image = self.vis_processor(image).unsqueeze(0).to(self.device)
video, msg = load_video(
video_path=video_path,
-n_frms=8,
-height=224,
-width=224,
+n_frms=self.n_frms,
+height=self.image_size,
+width=self.image_size,
sampling ="uniform", return_msg = True
)
video = self.vis_processor.transform(video)
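With the processor config threaded into the constructor, `Chat` reads the frame count and resolution once and reuses them in both upload paths, so changing `n_frms` in the YAML no longer requires editing code. A minimal sketch of that flow, assuming `vis_processor_cfg` defaults to `None` so the argument order stays valid and the old literals remain the fallback:

```python
# Simplified sketch of the config-driven Chat; load_video is the frame loader
# already imported by conversation_video.py.
from video_llama.processors.video_processor import load_video

class Chat:
    def __init__(self, model, vis_processor, device='cuda:0', vis_processor_cfg=None):
        self.device = device
        self.model = model
        self.vis_processor = vis_processor
        # Fall back to the previous hard-coded values when no config is supplied.
        self.n_frms = vis_processor_cfg.n_frms if vis_processor_cfg is not None else 8
        self.image_size = vis_processor_cfg.image_size if vis_processor_cfg is not None else 224

    def upload_video_without_audio(self, video_path, conv, img_list):
        # Frame count and resolution now track the YAML instead of the literals 8 and 224.
        video, msg = load_video(
            video_path=video_path,
            n_frms=self.n_frms,
            height=self.image_size,
            width=self.image_size,
            sampling="uniform",
            return_msg=True,
        )
        # Simplified: return the transformed clip instead of embedding it into img_list.
        return self.vis_processor.transform(video), msg
```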
2 changes: 2 additions & 0 deletions video_llama/datasets/builders/instruct_builder.py
@@ -44,6 +44,8 @@ def build(self):
vis_root=build_info.videos_dir,
ann_root=build_info.anno_dir,
num_video_query_token = num_video_query_token,
+resize_size=self.config.vis_processor.train.image_size,
+num_frm=self.config.vis_processor.train.n_frms,
tokenizer_name = tokenizer_name,
data_type = self.config.data_type
)
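The builder forwards the processor's `image_size` and `n_frms` into the dataset as `resize_size` and `num_frm`, so one pair of YAML fields drives both the processor and the dataset. A hedged sketch of where those values live, assuming the OmegaConf-style config layout used by the train configs and `alpro_video_train` as an example processor name:

```python
# Standalone illustration of the two config fields the builder now reads.
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "vis_processor": {"train": {"name": "alpro_video_train", "n_frms": 8, "image_size": 224}},
    "data_type": "video",
})

resize_size = cfg.vis_processor.train.image_size  # forwarded as resize_size
num_frm = cfg.vis_processor.train.n_frms          # forwarded as num_frm
print(resize_size, num_frm)                       # 224 8
```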
6 changes: 3 additions & 3 deletions video_llama/datasets/datasets/llava_instruct_dataset.py
@@ -39,7 +39,7 @@
IGNORE_INDEX = -100

class Instruct_Dataset(BaseDataset):
-def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'image', model_type='vicuna'):
+def __init__(self, vis_processor, text_processor, vis_root, ann_root, num_video_query_token=32, resize_size=224, num_frm=8, tokenizer_name='/mnt/workspace/ckpt/vicuna-13b/', data_type='image', model_type='vicuna'):
"""
vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
@@ -52,8 +52,8 @@ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_q
self.annotation = json.load(f)

self.vis_root = vis_root
-self.resize_size = 224
-self.num_frm = 8
+self.resize_size = resize_size
+self.num_frm = num_frm
self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
self.tokenizer.pad_token = self.tokenizer.unk_token
self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
6 changes: 3 additions & 3 deletions video_llama/datasets/datasets/video_instruct_dataset.py
@@ -40,7 +40,7 @@
IGNORE_INDEX = -100

class Video_Instruct_Dataset(BaseDataset):
-def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_query_token=32,tokenizer_name = '/mnt/workspace/ckpt/vicuna-13b/',data_type = 'video', model_type='vicuna'):
+def __init__(self, vis_processor, text_processor, vis_root, ann_root, num_video_query_token=32, resize_size=224, num_frm=8, tokenizer_name='/mnt/workspace/ckpt/vicuna-13b/', data_type='video', model_type='vicuna'):
"""
vis_root (string): Root directory of Llava images (e.g. webvid_eval/video/)
ann_root (string): Root directory of video (e.g. webvid_eval/annotations/)
@@ -54,8 +54,8 @@ def __init__(self, vis_processor, text_processor, vis_root, ann_root,num_video_q

self.num_video_query_token = num_video_query_token
self.vis_root = vis_root
-self.resize_size = 224
-self.num_frm = 8
+self.resize_size = resize_size
+self.num_frm = num_frm
self.tokenizer = LlamaTokenizer.from_pretrained(tokenizer_name, use_fast=False)
self.tokenizer.pad_token = self.tokenizer.unk_token
self.tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
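Both instruction datasets now accept `resize_size` and `num_frm` instead of fixing them at 224 and 8, so a different frame budget becomes a constructor (and ultimately a config) change. A usage sketch with placeholder processors and the example paths from the class docstring:

```python
# Hedged usage sketch; vis_processor/text_processor stand in for processors built
# via the registry, and the paths follow the docstring examples.
from video_llama.datasets.datasets.video_instruct_dataset import Video_Instruct_Dataset

dataset = Video_Instruct_Dataset(
    vis_processor=vis_processor,
    text_processor=text_processor,
    vis_root="webvid_eval/video/",
    ann_root="webvid_eval/annotations/",
    num_video_query_token=32,
    resize_size=224,                      # previously hard-coded inside __init__
    num_frm=16,                           # e.g. double the old default of 8
    tokenizer_name="/mnt/workspace/ckpt/vicuna-13b/",
    data_type="video",
)
```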
6 changes: 3 additions & 3 deletions video_llama/runners/runner_base.py
@@ -627,14 +627,14 @@ def _load_checkpoint(self, url_or_filename):
cached_file = download_cached_file(
url_or_filename, check_hash=False, progress=True
)
-checkpoint = torch.load(cached_file, map_location=self.device, strict=False)
+checkpoint = torch.load(cached_file, map_location=self.device)
elif os.path.isfile(url_or_filename):
-checkpoint = torch.load(url_or_filename, map_location=self.device, strict=False)
+checkpoint = torch.load(url_or_filename, map_location=self.device)
else:
raise RuntimeError("checkpoint url or path is invalid")

state_dict = checkpoint["model"]
-self.unwrap_dist_model(self.model).load_state_dict(state_dict)
+self.unwrap_dist_model(self.model).load_state_dict(state_dict, strict=False)

self.optimizer.load_state_dict(checkpoint["optimizer"])
if self.scaler and "scaler" in checkpoint:
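`torch.load` accepts no `strict` argument, so passing one raises a `TypeError`; the flag belongs to `nn.Module.load_state_dict`, where `strict=False` lets a resumed run load a checkpoint even if some keys are missing or unexpected (for example, branches that were frozen and never saved). A minimal sketch of the corrected resume path, simplified from the runner:

```python
import os
import torch

def load_checkpoint(model, optimizer, path, device="cpu"):
    """Simplified sketch of the corrected resume logic."""
    if not os.path.isfile(path):
        raise RuntimeError("checkpoint url or path is invalid")
    # map_location controls device placement; torch.load takes no `strict` flag.
    checkpoint = torch.load(path, map_location=device)
    # strict=False tolerates missing/unexpected keys in the saved state dict.
    model.load_state_dict(checkpoint["model"], strict=False)
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint.get("epoch", 0)
```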