From 11c5c4e1ca84b6becb08ee89ffddebf7ced5a77a Mon Sep 17 00:00:00 2001 From: rbler1234 Date: Fri, 29 Nov 2024 15:52:05 +0800 Subject: [PATCH] mmscan-devkit --- mmscan/utils/euler_utils.py | 11 ++++++++-- models/LEO/model/leo_agent.py | 15 ++------------ models/LL3DA/eval_utils/evaluate_mmscan.py | 24 ++++++++++++++++++++-- 3 files changed, 33 insertions(+), 17 deletions(-) diff --git a/mmscan/utils/euler_utils.py b/mmscan/utils/euler_utils.py index 40da72f..4f078b4 100644 --- a/mmscan/utils/euler_utils.py +++ b/mmscan/utils/euler_utils.py @@ -4,8 +4,15 @@ import numpy as np import torch -from pytorch3d.ops import box3d_overlap -from pytorch3d.transforms import euler_angles_to_matrix, matrix_to_euler_angles + +try: + from pytorch3d.ops import box3d_overlap + from pytorch3d.transforms import (euler_angles_to_matrix, + matrix_to_euler_angles) +except ImportError: + box3d_overlap = None + euler_angles_to_matrix = None + matrix_to_euler_angles = None from torch import Tensor diff --git a/models/LEO/model/leo_agent.py b/models/LEO/model/leo_agent.py index 09b0e94..d5c7ecd 100644 --- a/models/LEO/model/leo_agent.py +++ b/models/LEO/model/leo_agent.py @@ -281,19 +281,8 @@ def forward(self, data_dict): bs = len(data_dict['prompt_after_obj']) if 'obj_tokens' not in data_dict: - try: - data_dict = self.pcd_encoder(data_dict) - except: - torch.save( - self.state_dict(), - '/mnt/petrelfs/linjingli/tmp/data/big_tmp/model_dict_1028.pth' - ) - torch.save( - data_dict, - '/mnt/petrelfs/linjingli/tmp/data/big_tmp/debug_leo_1028.pt' - ) - import IPython - IPython.embed() + data_dict = self.pcd_encoder(data_dict) + data_dict['obj_tokens'] = self.pcd_proj( data_dict['obj_tokens'].to(device)) # data_dict['obj_tokens'] = data_dict['obj_tokens'] + self.pcd_type_embed diff --git a/models/LL3DA/eval_utils/evaluate_mmscan.py b/models/LL3DA/eval_utils/evaluate_mmscan.py index 274cdba..09ff578 100644 --- a/models/LL3DA/eval_utils/evaluate_mmscan.py +++ b/models/LL3DA/eval_utils/evaluate_mmscan.py @@ -18,6 +18,11 @@ from mmscan import QA_Evaluator +model_config = { + 'simcse': '/mnt/petrelfs/linjingli/mmscan_modelzoo-main/evaluation/pc', + 'sbert': '/mnt/petrelfs/linjingli/mmscan_modelzoo-main/evaluation/st' +} + def to_mmscan_form(raw_input): _input = {} @@ -39,8 +44,9 @@ def evaluate( logout=print, curr_train_iter=-1, ): - """define the mmscan_eval here.""" - model_config = {'simcse': '', 'sbert': ''} + + # prepare ground truth caption labels + print('preparing corpus...') evaluator = QA_Evaluator(model_config) @@ -130,6 +136,20 @@ def evaluate( logout(f'Evaluate {epoch_str}; Batch [{curr_iter}/{num_batches}]; ' f'Evaluating on iter: {curr_train_iter}; ' f'Iter time {time_delta.avg:0.2f}; Mem {mem_mb:0.2f}MB') + if curr_iter % 200 == 0: + with open( + os.path.join(args.checkpoint_dir, + f'qa_pred_gt_val_{curr_iter}.json'), + 'w') as f: + pred_gt_val = {} + for index_, scene_object_id_key in enumerate(candidates): + pred_gt_val[scene_object_id_key] = { + 'instruction': scene_object_id_key.split('@')[1], + 'pred': candidates[scene_object_id_key], + 'gt': corpus[scene_object_id_key], + } + json.dump(pred_gt_val, f, indent=4) + print(f'save pred_gt_val {curr_iter}') barrier() if is_primary():