Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: for the same k, node_cost is double counted #89

Merged
merged 55 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
5d82d97
change python tests/test_a_star/prepare_for_test.py to python pygmtoo…
heatingma Aug 1, 2023
4c05275
add bdist_wheel
heatingma Aug 1, 2023
4155230
delete test/test_a_star
heatingma Aug 1, 2023
b083a7b
add astar ori_files
heatingma Aug 1, 2023
678a49e
Create publish_req.txt
heatingma Aug 1, 2023
6cf8381
Delete a_star.tar.gz
heatingma Aug 1, 2023
3678388
version is a23
heatingma Aug 1, 2023
ef3f41c
Update setup.py
heatingma Aug 1, 2023
439c757
delete some notes
heatingma Aug 1, 2023
6e0ea0e
add alternate url to fix download problem
heatingma Aug 2, 2023
18d3873
add return
heatingma Aug 3, 2023
be38bf0
fix the md5 problem with genn_astar pretrained models
heatingma Aug 3, 2023
d07498e
add the missed ','
heatingma Aug 3, 2023
b5f3133
fix the "diff=200.0" problem
heatingma Aug 3, 2023
57230e8
swap the url and the url_alter
heatingma Aug 3, 2023
55379dc
change the astar_pretrain_path
heatingma Aug 3, 2023
bde8106
Revert "swap the url and the url_alter"
heatingma Aug 4, 2023
220f782
Revert "change the astar_pretrain_path"
heatingma Aug 4, 2023
3143592
change the pretrained path
heatingma Aug 4, 2023
61181f8
only small files has url_alter
heatingma Aug 4, 2023
0642a14
add new url for pretrained modles
heatingma Aug 4, 2023
0eec2ee
add new download pretrained models' paths for jittor backend
heatingma Aug 4, 2023
52143ae
add new download pretrained models' paths for jittor backend
heatingma Aug 4, 2023
329a115
add new download pretrained models' paths for jittor backend
heatingma Aug 4, 2023
f3a78c0
add new download path for pytorch backend pretrained models
heatingma Aug 4, 2023
d8f821b
add new download path for paddle backend
heatingma Aug 4, 2023
c368697
delete some unused url
heatingma Aug 5, 2023
0ae85c7
add new alternate download path for cie and pca
heatingma Aug 5, 2023
35a4e02
don't test neural_solvers now
heatingma Aug 5, 2023
6f42e7c
only test neural_solvers
heatingma Aug 5, 2023
62f2bef
only test neural_solvers
heatingma Aug 5, 2023
9c76481
only test neural_solvers
heatingma Aug 5, 2023
a591e7e
only test neural
heatingma Aug 5, 2023
ceeaf60
add the forget ","
heatingma Aug 5, 2023
18c98a9
add all tests
heatingma Aug 5, 2023
f408a51
add new url_path and change the download func
heatingma Aug 6, 2023
7e4a9de
delete dropout, trust_fact and no_pred_size for astar
heatingma Aug 7, 2023
6e6f773
delete dropout for genn_astar
heatingma Aug 7, 2023
a3ef42e
delete some parameters for astar and genn_astar
heatingma Aug 7, 2023
86ab428
Merge branch 'Thinklab-SJTU:main' into main
heatingma Aug 7, 2023
ba4d408
python
heatingma Aug 8, 2023
6df2f3f
Merge branch 'Thinklab-SJTU:main' into main
heatingma Aug 8, 2023
66a4b05
Merge branch 'Thinklab-SJTU:main' into main
heatingma Aug 22, 2023
37f88af
Merge branch 'Thinklab-SJTU:main' into main
heatingma Sep 4, 2023
8616fb4
change a_star to astar
heatingma Sep 12, 2023
de06854
change the function name "astar" from cython to "c_astar"
heatingma Oct 5, 2023
b0c8c0d
fix: 'module' object is not callable
heatingma Oct 5, 2023
57b0c26
change astar to c_astar
heatingma Oct 5, 2023
3a1373a
Merge branch 'Thinklab-SJTU:main' into main
heatingma Oct 11, 2023
086e182
Merge branch 'Thinklab-SJTU:main' into main
heatingma Nov 7, 2023
ed731b3
Merge branch 'Thinklab-SJTU:main' into main
heatingma Dec 1, 2023
e6237d9
Fix: for the same k, node_cost is double counted
heatingma Dec 1, 2023
cfe5b5d
Merge branch 'main' of https://github.com/heatingma/pygmtools
heatingma Dec 1, 2023
8bb5841
Add alter urls for datasets
heatingma Dec 2, 2023
1b44be0
fix: No such file or directory: 'data/SPair-71k/Layout/small/trn.txt'
heatingma Dec 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 20 additions & 12 deletions pygmtools/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,8 @@ def __init__(self, sets, obj_resize, **ds_dict):
SPLIT_OFFSET = dataset_cfg.WillowObject.SPLIT_OFFSET
TRAIN_SAME_AS_TEST = dataset_cfg.WillowObject.TRAIN_SAME_AS_TEST
RAND_OUTLIER = dataset_cfg.WillowObject.RAND_OUTLIER
URL = 'http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip'
URL = ['http://www.di.ens.fr/willow/research/graphlearning/WILLOW-ObjectClass_dataset.zip',
'https://huggingface.co/heatingma/pygmtools/resolve/main/WILLOW-ObjectClass_dataset.zip']
if len(ds_dict.keys()) > 0:
if 'CLASSES' in ds_dict.keys():
CLASSES = ds_dict['CLASSES']
Expand Down Expand Up @@ -750,6 +751,8 @@ def __init__(self, sets, obj_resize, problem='2GM', **ds_dict):
COMB_CLS = dataset_cfg.SPair.COMB_CLS
SIZE = dataset_cfg.SPair.SIZE
ROOT_DIR = dataset_cfg.SPair.ROOT_DIR
URL = ['https://huggingface.co/heatingma/pygmtools/resolve/main/SPair-71k.tar.gz',
'http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz']
if len(ds_dict.keys()) > 0:
if 'TRAIN_DIFF_PARAMS' in ds_dict.keys():
TRAIN_DIFF_PARAMS = ds_dict['TRAIN_DIFF_PARAMS']
Expand All @@ -775,14 +778,12 @@ def __init__(self, sets, obj_resize, problem='2GM', **ds_dict):
)

assert not problem == 'MGM', 'No match found for problem {} in SPair-71k'.format(problem)
self.dataset_dir = 'data/SPair-71k'

if not os.path.exists(SPair71k_image_path):
assert ROOT_DIR == dataset_cfg.SPair.ROOT_DIR, 'you should not change ROOT_DIR unless the data have been manually downloaded'
self.download(url='http://cvlab.postech.ac.kr/research/SPair-71k/data/SPair-71k.tar.gz')

if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)

self.download(url=URL)

self.dataset_dir = 'data/SPair-71k'
self.obj_resize = obj_resize
self.sets = sets_translation_dict[sets]
self.ann_files = open(os.path.join(self.SPair71k_layout_path, self.SPair71k_dataset_size, self.sets + ".txt"), "r").read().split("\n")
Expand Down Expand Up @@ -815,20 +816,24 @@ def download(self, url=None, retries=5):
if not os.path.exists(dirs):
os.makedirs(dirs)
print('Downloading dataset SPair-71k...')
filename = "data/SPair-71k.tgz"
filename = "data/SPair-71k.tar.gz"
download(filename=filename, url=url, to_cache=False)
try:
tar = tarfile.open(filename, "r")
except tarfile.ReadError as err:
print('Warning: Content error. Retrying...\n', err)
os.remove(filename)
return self.download(url, retries - 1)


self.dataset_dir = 'data/SPair-71k'
if not os.path.exists(self.dataset_dir):
os.makedirs(self.dataset_dir)

file_names = tar.getnames()
print('Unzipping files...')
sleep(0.5)
for file_name in tqdm(file_names):
tar.extract(file_name, "data/")
tar.extract(file_name, self.dataset_dir)
tar.close()
try:
os.remove(filename)
Expand Down Expand Up @@ -1018,7 +1023,9 @@ def __init__(self, sets, obj_resize, **ds_dict):
CLASSES = dataset_cfg.IMC_PT_SparseGM.CLASSES
ROOT_DIR_NPZ = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_NPZ
ROOT_DIR_IMG = dataset_cfg.IMC_PT_SparseGM.ROOT_DIR_IMG
URL = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=1Po9pRMWXTqKK2ABPpVmkcsOq-6K_2v-B'
URL = ['https://drive.google.com/u/0/uc?id=1bisri2Ip1Of3RsUA8OBrdH5oa6HlH3k-&export=download',
'https://huggingface.co/heatingma/pygmtools/resolve/main/IMC-PT-SparseGM.tar.gz']

if len(ds_dict.keys()) > 0:
if 'MAX_KPT_NUM' in ds_dict.keys():
MAX_KPT_NUM = ds_dict['MAX_KPT_NUM']
Expand Down Expand Up @@ -1190,7 +1197,8 @@ class CUB2011:
def __init__(self, sets, obj_resize, **ds_dict):
CLS_SPLIT = dataset_cfg.CUB2011.CLASS_SPLIT
ROOT_DIR = dataset_cfg.CUB2011.ROOT_DIR
URL = 'https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45'
URL = ['https://drive.google.com/u/0/uc?export=download&confirm=B8eu&id=1hbzc_P1FuxMkcabkgn9ZKinBwW683j45',
'https://huggingface.co/heatingma/pygmtools/resolve/main/CUB_200_2011.tar.gz']
if len(ds_dict.keys()) > 0:
if 'ROOT_DIR' in ds_dict.keys():
ROOT_DIR = ds_dict['ROOT_DIR']
Expand Down
3 changes: 1 addition & 2 deletions pygmtools/pytorch_astar_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ def check_layer_parameter(params):
return True


def node_metric(node1, node2):

def node_metric(node1, node2):
encoding = torch.sum(torch.abs(node1.unsqueeze(2) - node2.unsqueeze(1)), dim=-1)
non_zero = torch.nonzero(encoding)
for i in range(non_zero.shape[0]):
Expand Down
16 changes: 9 additions & 7 deletions pygmtools/pytorch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,6 @@ def _astar(self, data: GraphPair):
ns_2 = torch.bincount(data.g2.batch)

adj_1 = to_dense_adj(edge_index_1, batch=batch_1, edge_attr=edge_attr_1)

dummy_adj_1 = torch.zeros(adj_1.shape[0], adj_1.shape[1] + 1, adj_1.shape[2] + 1, device=device)
dummy_adj_1[:, :-1, :-1] = adj_1
adj_2 = to_dense_adj(edge_index_2, batch=batch_2, edge_attr=edge_attr_2)
Expand Down Expand Up @@ -1055,12 +1054,15 @@ def net_prediction_cache(self, data: GraphPair, partial_pmat=None, return_ged_no
return ged

def heuristic_prediction_hun(self, k: torch.Tensor, n1, n2, partial_pmat):
k_prime = k.reshape(-1, n1 + 1, n2 + 1)
node_costs = torch.empty(k_prime.shape[0])
for i in range(k_prime.shape[0]):
_, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
self.heuristic_cache['node_cost'] = node_cost_mat
if 'node_cost' in self.heuristic_cache:
node_cost_mat = self.heuristic_cache['node_cost']
else:
k_prime = k.reshape(-1, n1 + 1, n2 + 1)
node_costs = torch.empty(k_prime.shape[0])
for i in range(k_prime.shape[0]):
_, node_costs[i] = hungarian_ged(k_prime[i], n1, n2)
node_cost_mat = node_costs.reshape(n1 + 1, n2 + 1)
self.heuristic_cache['node_cost'] = node_cost_mat

graph_1_mask = ~partial_pmat.sum(dim=-1).to(dtype=torch.bool)
graph_2_mask = ~partial_pmat.sum(dim=-2).to(dtype=torch.bool)
Expand Down
Loading