From 5795e2b21309f6beb690df4edd5d4c8599d60bac Mon Sep 17 00:00:00 2001
From: heatingma <115260102+heatingma@users.noreply.github.com>
Date: Mon, 4 Dec 2023 14:17:26 +0800
Subject: [PATCH] rm redundant A-star code and backup links for datasets (#89)

---
 pygmtools/dataset.py               | 32 +++++++++++++++++++-----------
 pygmtools/pytorch_astar_modules.py |  3 +--
 pygmtools/pytorch_backend.py       | 16 ++++++++-------
 3 files changed, 30 insertions(+), 21 deletions(-)

diff --git a/pygmtools/dataset.py b/pygmtools/dataset.py
index 41cc426..a0edab7 100644
--- a/pygmtools/dataset.py
+++ b/pygmtools/dataset.py
@@ -476,7 +476,8 @@ def __init__(self, sets, obj_resize, **ds_dict):
         SPLIT_OFFSET = dataset_cfg.WillowObject.SPLIT_OFFSET
         TRAIN_SAME_AS_TEST = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST
         RAND_OUTLIER = dataset_cfg.WillowObject.RAND_OUTLIER
-        URL = 'http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip'
+        URL = ['http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip',
+               'https://huggingface.co/heatingma/pygmtools/resolve/main/WILLOW-ObjectClass_dataset.zip']
         if len(ds_dict.keys()) > 0:
             if 'CLASSES' in ds_dict.keys():
                 CLASSES = ds_dict['CLASSES']
@@ -750,6 +751,8 @@ def __init__(self, sets, obj_resize, problem='2GM', **ds_dict):
         COMB_CLS = dataset_cfg.SPair.COMB_CLS
         SIZE = dataset_cfg.SPair.SIZE
         ROOT_DIR = dataset_cfg.SPair.ROOT_DIR
+        URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/SPair-71k.tar.gz',
+               'http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz']
         if len(ds_dict.keys()) > 0:
             if 'TRAIN_DIFF_PARAMS' in ds_dict.keys():
                 TRAIN_DIFF_PARAMS = ds_dict['TRAIN_DIFF_PARAMS']
@@ -775,14 +778,12 @@ def __init__(self, sets, obj_resize, problem='2GM', **ds_dict):
         )
 
         assert not problem == 'MGM', 'No match found for problem {} in SPair-71k'.format(problem)
-        self.dataset_dir = 'data/SPair-71k'
+
         if not os.path.exists(SPair71k_image_path):
             assert ROOT_DIR == dataset_cfg.SPair.ROOT_DIR, 'you should not change ROOT_DIR unless the data have been manually downloaded'
-            self.download(url='http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz')
-
-        if not os.path.exists(self.dataset_dir):
-            os.makedirs(self.dataset_dir)
-
+            self.download(url=URL)
+
+        self.dataset_dir = 'data/SPair-71k'
         self.obj_resize = obj_resize
         self.sets = sets_translation_dict[sets]
         self.ann_files = open(os.path.join(self.SPair71k_layout_path, self.SPair71k_dataset_size, self.sets + ".txt"), "r").read().split("\n")
@@ -815,7 +816,7 @@ def download(self, url=None, retries=5):
         if not os.path.exists(dirs):
             os.makedirs(dirs)
         print('Downloading dataset SPair-71k...')
-        filename = "data/SPair-71k.tgz"
+        filename = "data/SPair-71k.tar.gz"
        download(filename=filename, url=url, to_cache=False)
         try:
             tar = tarfile.open(filename, "r")
@@ -823,12 +824,16 @@ def download(self, url=None, retries=5):
             print('Warning: Content error. Retrying...\n', err)
             os.remove(filename)
             return self.download(url, retries - 1)
-
+
+        self.dataset_dir = 'data/SPair-71k'
+        if not os.path.exists(self.dataset_dir):
+            os.makedirs(self.dataset_dir)
+
         file_names = tar.getnames()
         print('Unzipping files...')
         sleep(0.5)
         for file_name in tqdm(file_names):
-            tar.extract(file_name, "data/")
+            tar.extract(file_name, self.dataset_dir)
         tar.close()
         try:
             os.remove(filename)
@@ -1018,7 +1023,9 @@ def __init__(self, sets, obj_resize, **ds_dict):
         CLASSES = dataset_cfg.IMC_PT_SparseGM.CLASSES
         ROOT_DIR_NPZ = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
         ROOT_DIR_IMG = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
-        URL = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1Po9pRMWXTqKK2ABPpVmkcsOq-6K_2v-B'
+        URL = ['https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download',
+               'https://huggingface.co/heatingma/pygmtools/resolve/main/IMC-PT-SparseGM.tar.gz']
+
         if len(ds_dict.keys()) > 0:
             if 'MAX_KPT_NUM' in ds_dict.keys():
                 MAX_KPT_NUM = ds_dict['MAX_KPT_NUM']
@@ -1190,7 +1197,8 @@ class CUB2011:
     def __init__(self, sets, obj_resize, **ds_dict):
         CLS_SPLIT = dataset_cfg.CUB2011.CLASS_SPLIT
         ROOT_DIR = dataset_cfg.CUB2011.ROOT_DIR
-        URL = 'https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
+        URL = ['https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45',
+               'https://huggingface.co/heatingma/pygmtools/resolve/main/CUB_200_2011.tar.gz']
         if len(ds_dict.keys()) > 0:
             if 'ROOT_DIR' in ds_dict.keys():
                 ROOT_DIR = ds_dict['ROOT_DIR']
diff --git a/pygmtools/pytorch_astar_modules.py b/pygmtools/pytorch_astar_modules.py
index c527910..52868c2 100644
--- a/pygmtools/pytorch_astar_modules.py
+++ b/pygmtools/pytorch_astar_modules.py
@@ -47,8 +47,7 @@ def check_layer_parameter(params):
     return True
 
 
-def node_metric(node1, node2): 
-
+def node_metric(node1, node2):
     encoding = torch.sum(torch.abs(node1.unsqueeze(2) - node2.unsqueeze(1)), dim=-1)
     non_zero = torch.nonzero(encoding)
     for i in range(non_zero.shape[0]):
diff --git a/pygmtools/pytorch_backend.py b/pygmtools/pytorch_backend.py
index 6bac03d..274ab90 100644
--- a/pygmtools/pytorch_backend.py
+++ b/pygmtools/pytorch_backend.py
@@ -953,7 +953,6 @@ def _astar(self, data: GraphPair):
         ns_2 = torch.bincount(data.g2.batch)
 
         adj_1 = to_dense_adj(edge_index_1, batch=batch_1, edge_attr=edge_attr_1)
-
         dummy_adj_1 = torch.zeros(adj_1.shape[0], adj_1.shape[1] + 1, adj_1.shape[2] + 1, device=device)
         dummy_adj_1[:, :-1, :-1] = adj_1
         adj_2 = to_dense_adj(edge_index_2, batch=batch_2, edge_attr=edge_attr_2)
@@ -1055,12 +1054,15 @@ def net_prediction_cache(self, data: GraphPair, partial_pmat=None, return_ged_no
         return ged
 
     def heuristic_prediction_hun(self, k: torch.Tensor, n1, n2, partial_pmat):
-        k_prime = k.reshape(-1, n1 + 1, n2 + 1)
-        node_costs = torch.empty(k_prime.shape[0])
-        for i in range(k_prime.shape[0]):
-            _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
-        node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
-        self.heuristic_cache['node_cost'] = node_cost_mat
+        if 'node_cost' in self.heuristic_cache:
+            node_cost_mat = self.heuristic_cache['node_cost']
+        else:
+            k_prime = k.reshape(-1, n1 + 1, n2 + 1)
+            node_costs = torch.empty(k_prime.shape[0])
+            for i in range(k_prime.shape[0]):
+                _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
+            node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
+            self.heuristic_cache['node_cost'] = node_cost_mat
 
         graph_1_mask = ~partial_pmat.sum(dim=-1).to(dtype=torch.bool)
         graph_2_mask = ~partial_pmat.sum(dim=-2).to(dtype=torch.bool)
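
The dataset.py hunks above replace each single download URL with a list of
mirrors (the original host plus a Hugging Face backup, in whichever order the
dataset prefers), so self.download(url=URL) can move on to the next link when
one host is unreachable. The sketch below is a minimal, self-contained
illustration of that fallback loop, not pygmtools code: it uses only the
Python standard library, and the helper name download_with_fallback is
hypothetical.

import os
import urllib.request


def download_with_fallback(filename, urls, retries_per_url=2):
    """Try each mirror in `urls` in order until `filename` exists on disk."""
    if os.path.exists(filename):
        return filename
    last_err = None
    for url in urls:
        for _ in range(retries_per_url):
            try:
                urllib.request.urlretrieve(url, filename)
                return filename
            except OSError as err:  # covers URLError/HTTPError; try next attempt
                last_err = err
                if os.path.exists(filename):
                    os.remove(filename)  # discard any partial download
    raise RuntimeError('all mirrors failed for ' + filename) from last_err


# Usage with the SPair-71k mirror list from the patch:
SPAIR_URLS = ['https://huggingface.co/heatingma/pygmtools/resolve/main/SPair-71k.tar.gz',
              'http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz']
# download_with_fallback('data/SPair-71k.tar.gz', SPAIR_URLS)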
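
The heuristic_prediction_hun hunk is a compute-once cache: the node-cost
matrix is built from Hungarian matchings on the first call and read back from
self.heuristic_cache on every later call. Below is a standalone sketch of the
same pattern under stated assumptions: the class name is hypothetical, and a
cheap min-sum stands in for pygmtools' hungarian_ged, whose second return
value is the matching cost.

import torch


class NodeCostCache:
    """Hypothetical holder for the compute-once node-cost pattern."""

    def __init__(self):
        self.heuristic_cache = {}

    def node_cost_matrix(self, k, n1, n2):
        if 'node_cost' in self.heuristic_cache:   # cache hit: skip the loop
            return self.heuristic_cache['node_cost']
        k_prime = k.reshape(-1, n1 + 1, n2 + 1)   # one (n1+1)x(n2+1) slice per cell
        node_costs = torch.empty(k_prime.shape[0])
        for i in range(k_prime.shape[0]):
            # stand-in for: _, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
            node_costs[i] = k_prime[i].min(dim=-1).values.sum()
        node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
        self.heuristic_cache['node_cost'] = node_cost_mat  # fill once, reuse after
        return node_cost_mat


# k needs ((n1+1)*(n2+1))**2 entries so both reshapes line up:
cache = NodeCostCache()
n1, n2 = 3, 4
k = torch.rand((n1 + 1) * (n2 + 1), (n1 + 1) * (n2 + 1))
first = cache.node_cost_matrix(k, n1, n2)   # computed
second = cache.node_cost_matrix(k, n1, n2)  # served from the cache
assert torch.equal(first, second)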