diff --git a/FeatureMatching/README.md b/FeatureMatching/README.md
index 498f8e3..cdf5198 100644
--- a/FeatureMatching/README.md
+++ b/FeatureMatching/README.md
@@ -11,3 +11,18 @@ The parameter top K can be reduced for speeding up. The performance won't drop t
 ```
 sh scripts/reproduce_test/indoor_ds_quadtree.sh
 ```
+
+## Sample to Test an Image Pair
+
+- Download outdoor weights from this [Github Release](https://github.com/Tangshitao/QuadTreeAttention/releases/tag/QuadTreeAttention_feature_match)
+
+- Run the following command:
+
+```bash
+python3 test_one_image_pair_sample.py --weight_path ./outdoor.ckpt \
+    --config_path ./configs/loftr/outdoor/loftr_ds_quadtree.py \
+    --query_path ./assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg \
+    --ref_path ./assets/phototourism_sample_images/london_bridge_49190386_5209386933.jpg
+```
+
+![Feature Matching Sample](./docs/images/feature_matching_result.jpg)
diff --git a/FeatureMatching/docs/images/feature_matching_result.jpg b/FeatureMatching/docs/images/feature_matching_result.jpg
new file mode 100644
index 0000000..f228073
Binary files /dev/null and b/FeatureMatching/docs/images/feature_matching_result.jpg differ
diff --git a/FeatureMatching/test_one_image_pair_sample.py b/FeatureMatching/test_one_image_pair_sample.py
new file mode 100644
index 0000000..e17ddc2
--- /dev/null
+++ b/FeatureMatching/test_one_image_pair_sample.py
@@ -0,0 +1,117 @@
+from typing import Tuple
+
+import cv2
+import numpy as np
+import torch
+
+from src.config.default import get_cfg_defaults
+from src.loftr import LoFTR
+from src.utils.misc import lower_config
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser("test quadtree attention-based feature matching")
+    parser.add_argument("--weight_path", type=str, required=True)
+    parser.add_argument("--config_path", type=str, required=True)
+    parser.add_argument("--query_path", type=str, required=True)
+    parser.add_argument("--ref_path", type=str, required=True)
+    parser.add_argument("--confidence_thresh", type=float, default=0.5)
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    config = get_cfg_defaults()
+    config.merge_from_file(args.config_path)
+    config = lower_config(config)
+
+    matcher = LoFTR(config=config["loftr"])
+    state_dict = torch.load(args.weight_path, map_location="cpu")["state_dict"]
+    matcher.load_state_dict(state_dict, strict=True)
+
+    query_image = cv2.imread(args.query_path, 1)
+    ref_image = cv2.imread(args.ref_path, 1)
+
+    new_shape = (480, 640)
+
+    query_gray = cv2.cvtColor(query_image, cv2.COLOR_BGR2GRAY)
+    ref_gray = cv2.cvtColor(ref_image, cv2.COLOR_BGR2GRAY)
+
+    query_gray = cv2.resize(query_gray, new_shape[::-1])
+    ref_gray = cv2.resize(ref_gray, new_shape[::-1])
+
+    with torch.no_grad():
+        batch = {
+            "image0": load_torch_image(query_gray),
+            "image1": load_torch_image(ref_gray),
+        }
+
+        matcher.eval()
+        matcher.to("cuda")
+        matcher(batch)
+
+        query_kpts = batch["mkpts0_f"].cpu().numpy()
+        ref_kpts = batch["mkpts1_f"].cpu().numpy()
+        confidences = batch["mconf"].cpu().numpy()
+        del batch
+
+    conf_mask = np.where(confidences > args.confidence_thresh)
+    query_kpts = query_kpts[conf_mask]
+    ref_kpts = ref_kpts[conf_mask]
+
+    def _np_to_cv2_kpts(np_kpts):
+        cv2_kpts = []
+        for np_kpt in np_kpts:
+            cur_cv2_kpt = cv2.KeyPoint()
+            cur_cv2_kpt.pt = tuple(np_kpt)
+            cv2_kpts.append(cur_cv2_kpt)
+        return cv2_kpts
+
+    query_shape = query_image.shape[:2]
+    ref_shape = ref_image.shape[:2]
+    query_kpts = resample_kpts(
+        query_kpts,
+        query_shape[0] / new_shape[0],
+        query_shape[1] / new_shape[1],
+    )
+
+    ref_kpts = resample_kpts(
+        ref_kpts,
+        ref_shape[0] / new_shape[0],
+        ref_shape[1] / new_shape[1],
+    )
+    query_kpts, ref_kpts = _np_to_cv2_kpts(query_kpts), _np_to_cv2_kpts(ref_kpts)
+
+    matched_image = cv2.drawMatches(
+        query_image,
+        query_kpts,
+        ref_image,
+        ref_kpts,
+        [
+            cv2.DMatch(_queryIdx=idx, _trainIdx=idx, _distance=0)
+            for idx in range(len(query_kpts))
+        ],
+        None,
+        flags=2,
+    )
+    cv2.imwrite("result.jpg", matched_image)
+
+
+def resample_kpts(kpts: np.ndarray, height_ratio, width_ratio):
+    kpts[:, 0] *= width_ratio
+    kpts[:, 1] *= height_ratio
+
+    return kpts
+
+
+def load_torch_image(image):
+    image = torch.from_numpy(image)[None][None].cuda() / 255
+    return image
+
+
+if __name__ == "__main__":
+    main()
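
A minimal sketch (not part of the patch) of the coordinate rescaling that `resample_kpts` performs in the script above: the matcher only ever sees the 480x640 grayscale resize, so each matched keypoint is scaled by the original-to-resized width and height ratios before being drawn on the full-resolution images. The original image size below is a hypothetical stand-in.

```python
import numpy as np

# Hypothetical shapes: the matcher input is fixed at (480, 640); the original
# image size (height, width) is made up for illustration.
new_shape = (480, 640)
orig_shape = (960, 1280)

# One matched keypoint, as (x, y) in the resized 480x640 frame.
kpts = np.array([[320.0, 240.0]])

# Scale x by the width ratio and y by the height ratio, as resample_kpts does.
kpts[:, 0] *= orig_shape[1] / new_shape[1]
kpts[:, 1] *= orig_shape[0] / new_shape[0]

print(kpts)  # [[640. 480.]]
```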