From c01bc22eacec48f3dd0d4c7d4384c258102341a1 Mon Sep 17 00:00:00 2001 From: Yezhen Cong <52420115+THU17cyz@users.noreply.github.com> Date: Fri, 15 Oct 2021 20:42:14 -0700 Subject: [PATCH] Fix bug --- sunrgbd/sunrgbd_ssl_dataset.py | 5 +++-- train.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/sunrgbd/sunrgbd_ssl_dataset.py b/sunrgbd/sunrgbd_ssl_dataset.py index a5b755b..2319d37 100644 --- a/sunrgbd/sunrgbd_ssl_dataset.py +++ b/sunrgbd/sunrgbd_ssl_dataset.py @@ -183,7 +183,7 @@ def __getitem__(self, idx): class SunrgbdSSLUnlabeledDataset(Dataset): def __init__(self, labeled_sample_list=None, num_points=20000, use_color=False, use_height=False, use_v1=False, - aug_num=1, scan_idx_list=None, load_labels=None): + aug_num=1, scan_idx_list=None, load_labels=None, augment=True): print('----------------Sunrgbd Unlabeled Dataset Initialization----------------') if use_v1: self.data_path = os.path.join(ROOT_DIR, 'sunrgbd/sunrgbd_pc_bbox_votes_50k_v1_train') @@ -211,6 +211,7 @@ def __init__(self, labeled_sample_list=None, num_points=20000, use_color=False, self.use_height = use_height self.aug_num = aug_num self.load_labels = load_labels + self.augment = augment if load_labels: print('Warning! Loading labels for analysis') @@ -308,4 +309,4 @@ def __getitem__(self, idx): ret_dict['scale'] = np.array(scale_ratio).astype(np.float32) ret_dict['scan_idx'] = np.array(idx).astype(np.int64) ret_dict['supervised_mask'] = np.array(0).astype(np.int64) - return ret_dict \ No newline at end of file + return ret_dict diff --git a/train.py b/train.py index 44b8d7d..4cfceed 100644 --- a/train.py +++ b/train.py @@ -122,7 +122,8 @@ def my_worker_init_fn(worker_id): use_color=FLAGS.use_color, use_height=(not FLAGS.no_height), use_v1=(not FLAGS.use_sunrgbd_v2), - load_labels=FLAGS.view_stats) + load_labels=FLAGS.view_stats, + augment=True) TEST_DATASET = SunrgbdDetectionVotesDataset('val', num_points=NUM_POINT, augment=False, use_color=FLAGS.use_color, use_height=(not FLAGS.no_height), @@ -142,7 +143,8 @@ def my_worker_init_fn(worker_id): num_points=NUM_POINT, use_color=FLAGS.use_color, use_height=(not FLAGS.no_height), - load_labels=FLAGS.view_stats) + load_labels=FLAGS.view_stats, + augment=True) TEST_DATASET = ScannetDetectionDataset('val', num_points=NUM_POINT, augment=False, use_color=FLAGS.use_color, use_height=(not FLAGS.no_height))