Skip to content

Commit

Permalink
mmscan-devkit
Browse files Browse the repository at this point in the history
  • Loading branch information
rbler1234 committed Nov 29, 2024
1 parent a84d590 commit 11c5c4e
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 17 deletions.
11 changes: 9 additions & 2 deletions mmscan/utils/euler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,15 @@

import numpy as np
import torch
from pytorch3d.ops import box3d_overlap
from pytorch3d.transforms import euler_angles_to_matrix, matrix_to_euler_angles

try:
from pytorch3d.ops import box3d_overlap
from pytorch3d.transforms import (euler_angles_to_matrix,
matrix_to_euler_angles)
except ImportError:
box3d_overlap = None
euler_angles_to_matrix = None
matrix_to_euler_angles = None
from torch import Tensor


Expand Down
15 changes: 2 additions & 13 deletions models/LEO/model/leo_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,8 @@ def forward(self, data_dict):
bs = len(data_dict['prompt_after_obj'])
if 'obj_tokens' not in data_dict:

try:
data_dict = self.pcd_encoder(data_dict)
except:
torch.save(
self.state_dict(),
'/mnt/petrelfs/linjingli/tmp/data/big_tmp/model_dict_1028.pth'
)
torch.save(
data_dict,
'/mnt/petrelfs/linjingli/tmp/data/big_tmp/debug_leo_1028.pt'
)
import IPython
IPython.embed()
data_dict = self.pcd_encoder(data_dict)

data_dict['obj_tokens'] = self.pcd_proj(
data_dict['obj_tokens'].to(device))
# data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed
Expand Down
24 changes: 22 additions & 2 deletions models/LL3DA/eval_utils/evaluate_mmscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@

from mmscan import QA_Evaluator

model_config = {
'simcse': '/mnt/petrelfs/linjingli/mmscan_modelzoo-main/evaluation/pc',
'sbert': '/mnt/petrelfs/linjingli/mmscan_modelzoo-main/evaluation/st'
}


def to_mmscan_form(raw_input):
_input = {}
Expand All @@ -39,8 +44,9 @@ def evaluate(
logout=print,
curr_train_iter=-1,
):
"""define the mmscan_eval here."""
model_config = {'simcse': '', 'sbert': ''}

# prepare ground truth caption labels
print('preparing corpus...')

evaluator = QA_Evaluator(model_config)

Expand Down Expand Up @@ -130,6 +136,20 @@ def evaluate(
logout(f'Evaluate {epoch_str}; Batch [{curr_iter}/{num_batches}]; '
f'Evaluating on iter: {curr_train_iter}; '
f'Iter time {time_delta.avg:0.2f}; Mem {mem_mb:0.2f}MB')
if curr_iter % 200 == 0:
with open(
os.path.join(args.checkpoint_dir,
f'qa_pred_gt_val_{curr_iter}.json'),
'w') as f:
pred_gt_val = {}
for index_, scene_object_id_key in enumerate(candidates):
pred_gt_val[scene_object_id_key] = {
'instruction': scene_object_id_key.split('@')[1],
'pred': candidates[scene_object_id_key],
'gt': corpus[scene_object_id_key],
}
json.dump(pred_gt_val, f, indent=4)
print(f'save pred_gt_val {curr_iter}')
barrier()

if is_primary():
Expand Down

0 comments on commit 11c5c4e

Please sign in to comment.