Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Noietch committed Nov 16, 2024
1 parent ebb2e01 commit dad3a3f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
13 changes: 8 additions & 5 deletions exps/muti_label/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ def parse_data_info(self, raw_data_info):
# print(raw_img_info)
# print([cur_cates[ann['category_id']-1] for ann in raw_ann_info])
# to one-hot
category_ids = torch.unique(torch.tensor([ann['category_id'] for ann in raw_ann_info])) - 1
cate_one_hot = torch.eye(len(cur_cates))[category_ids].sum(dim=0)
if len(raw_ann_info) == 0:
cate_one_hot = torch.zeros(len(cur_cates))
else:
category_ids = torch.unique(torch.tensor([ann['category_id'] for ann in raw_ann_info])) - 1
cate_one_hot = torch.eye(len(cur_cates))[category_ids].sum(dim=0)

return {
"img_path": os.path.join(self.data_root, raw_img_info['file_name']),
Expand All @@ -45,9 +48,9 @@ def load_data_list(self) -> list[dict]:
ann_ids = self.lvis.get_ann_ids(img_ids=[img_id])
raw_ann_info = self.lvis.load_anns(ann_ids)

if len(raw_ann_info) == 0:
# print(f"Image {img_id} has no annotations, skipped.")
continue
# if len(raw_ann_info) == 0:
# # print(f"Image {img_id} has no annotations, skipped.")
# continue

parsed_data_info = self.parse_data_info({
'raw_ann_info':
Expand Down
8 changes: 4 additions & 4 deletions oadp/dp/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def forward(
img_path = img_meta['img_path']
if img_path not in self.ram_pred_result:
print(f"Image path {img_path} not found in RAM prediction results")
ram_cls_score.append(np.ones((1, cls_score.shape[1])))
raise ValueError(f"Image path {img_path} not found in RAM prediction results")
ram_cls_score.append(np.ones((cls_score.shape[1] - 1, )))
# raise ValueError(f"Image path {img_path} not found in RAM prediction results")
else:
ram_cls_score.append(self.ram_pred_result[img_path])
ram_cls_score = np.concatenate(ram_cls_score, axis=0)
Expand All @@ -203,8 +203,8 @@ def forward(
ram_cls_score = repeat(ram_cls_score, 'c -> b c', b=num_box).sigmoid()
ram_cls_score_with_bg = torch.cat([ram_cls_score, background_score], dim=1).to(cls_score.device)

return cls_score * ram_cls_score_with_bg
return (cls_score.softmax(-1) * ram_cls_score_with_bg).log()


@MODELS.register_module()
class RAMEnsembleOADPRoIHead(OADPRoIHead):
Expand Down

0 comments on commit dad3a3f

Please sign in to comment.