Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initialize shonan using minimum spanning tree #777

Open
wants to merge 28 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
f936169
Initialize shonan using minimum spanning tree
ayushbaid Feb 1, 2024
1dea8a6
Move out MST initialization into utility function
ayushbaid Feb 3, 2024
1a1de1f
Add typinh and docstring
ayushbaid Feb 20, 2024
d530283
Use number of inliers as edge weights for MST
ayushbaid Feb 27, 2024
91c4cc4
Remove unused import
ayushbaid Feb 29, 2024
63eeccd
Fix bugs and add docstring
ayushbaid Feb 29, 2024
f5f61d5
cleanup docstrings
Mar 11, 2024
9770142
add more tests
Mar 11, 2024
c3a6589
add different impl
Mar 11, 2024
8e0bafe
clean up notation
Mar 11, 2024
ebfff90
clean up test
Mar 11, 2024
e6d200a
python black fixes
Mar 11, 2024
83c354b
fix flake8
Mar 11, 2024
c655512
use i1 < i2 convention for pair indices
Mar 12, 2024
fbdf99a
use v_corr_idxs instead of two_view_estimation_reports to get inlier …
Mar 13, 2024
3a1ec06
python black reformat
Mar 13, 2024
429177a
fix import error
Mar 13, 2024
beef33b
Control MST init with flag plus fixes
ayushbaid May 14, 2024
c01f254
Log initialization technique
ayushbaid May 14, 2024
4d6c3a5
Add unit test comparing initializations
ayushbaid May 14, 2024
42b3e70
Log shonan optimality
ayushbaid May 14, 2024
8848d76
Remove optimality logging
ayushbaid May 14, 2024
980b94d
Remove unused import
ayushbaid May 14, 2024
f55f2b3
Merge master
ayushbaid May 21, 2024
337c6db
Add unit test for initialization on larger scene
ayushbaid May 26, 2024
dde24e8
env v2
ayushbaid May 13, 2024
725aedf
Merge branch 'master' into feature/shonan_mst_init
ayushbaid May 26, 2024
35b3a7b
Remove duplicate args plus update process meta
ayushbaid May 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions environment_v2_linux_cpuonly.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
name: gtsfm-v2
channels:
# for priority order, we prefer pytorch as the highest priority as it supplies
# latest stable packages for numerous deep learning based methods. conda-forge
# supplies higher versions of packages like opencv compared to the defaults
# channel.
- pytorch
- conda-forge
dependencies:
# python essentials
- python
- pip
# formatting and dev environment
- black
- coverage
- mypy
- pylint
- pytest
- flake8
- isort
# dask and related
- dask # same as dask[complete] pip distribution
- asyncssh
- python-graphviz
# core functionality and APIs
- matplotlib
- networkx
- numpy
- nodejs
- pandas
- pillow
- scikit-learn
- seaborn
- scipy
- hydra-core
- gtsam
# 3rd party algorithms for different modules
- cpuonly # replacement of cudatoolkit for cpu only machines
- pytorch
- torchvision
- kornia
- pycolmap
- opencv
# io
- h5py
- plotly
- tabulate
- simplejson
- open3d
- colour
- pydot
- trimesh
# testing
- parameterized
# - pip:
# - pydegensac
3 changes: 2 additions & 1 deletion gtsfm/averaging/rotation/rotation_averaging_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ class RotationAveragingBase(GTSFMProcess):
rotations.
"""

@staticmethod
def get_ui_metadata() -> UiMetadata:
"""Returns data needed to display node and edge info for this process in the process graph."""

return UiMetadata(
display_name="Rotation Averaging",
input_products=("View-Graph Relative Rotations", "Relative Pose Priors"),
input_products=("View-Graph Relative Rotations", "Relative Pose Priors", "Verified Correspondences"),
output_products=("Global Rotations",),
parent_plate="Sparse Reconstruction",
)
Expand Down
33 changes: 28 additions & 5 deletions gtsfm/averaging/rotation/shonan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
Rot3,
ShonanAveraging3,
ShonanAveragingParameters3,
Values,
)

import gtsfm.utils.logger as logger_utils
import gtsfm.utils.rotation as rotation_util
from gtsfm.averaging.rotation.rotation_averaging_base import RotationAveragingBase
from gtsfm.common.pose_prior import PosePrior

Expand All @@ -38,7 +40,10 @@ class ShonanRotationAveraging(RotationAveragingBase):
"""Performs Shonan rotation averaging."""

def __init__(
self, two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA, weight_by_inliers: bool = True
self,
two_view_rotation_sigma: float = _DEFAULT_TWO_VIEW_ROTATION_SIGMA,
weight_by_inliers: bool = True,
use_mst_init: bool = False,
) -> None:
"""Initializes module.

Expand All @@ -50,10 +55,11 @@ def __init__(
of inlier correspondences per edge.
"""
super().__init__()
self._two_view_rotation_sigma = two_view_rotation_sigma
self._p_min = 3
self._p_max = 64
self._two_view_rotation_sigma = two_view_rotation_sigma
self._weight_by_inliers = weight_by_inliers
self._use_mst_init = use_mst_init

def __get_shonan_params(self) -> ShonanAveragingParameters3:
lm_params = LevenbergMarquardtParams.CeresDefaults()
Expand Down Expand Up @@ -108,7 +114,7 @@ def get_isotropic_noise_model_sigma(covariance: np.ndarray) -> float:
return measurements

def _run_with_consecutive_ordering(
self, num_connected_nodes: int, measurements: gtsam.BinaryMeasurementsRot3
self, num_connected_nodes: int, measurements: gtsam.BinaryMeasurementsRot3, initial: Optional[Values]
) -> List[Optional[Rot3]]:
"""Run the rotation averaging on a connected graph w/ N keys ordered consecutively [0,...,N-1].

Expand All @@ -134,7 +140,9 @@ def _run_with_consecutive_ordering(
)
shonan = ShonanAveraging3(measurements, self.__get_shonan_params())

initial = shonan.initializeRandomly()
if initial is None:
logger.info("Using random initialization for Shonan")
initial = shonan.initializeRandomly()
logger.info("Initial cost: %.5f", shonan.cost(initial))
result, _ = shonan.run(initial, self._p_min, self._p_max)
logger.info("Final cost: %.5f", shonan.cost(result))
Expand Down Expand Up @@ -203,13 +211,28 @@ def run_rotation_averaging(
if (i1, i2) in i2Ri1_dict
}

# Use negative of the number of correspondences as the edge weight.
initial_values: Optional[Values] = None
if self._use_mst_init:
logger.info("Using MST initialization for Shonan")
wRi_initial_ = rotation_util.initialize_global_rotations_using_mst(
len(nodes_with_edges),
i2Ri1_dict_remapped,
edge_weights={
(i1, i2): -num_correspondences_dict.get((i1, i2), 0) for i1, i2 in i2Ri1_dict_remapped.keys()
},
)
initial_values = Values()
for i, wRi in enumerate(wRi_initial_):
initial_values.insert(i, wRi)

def _create_factors_and_run() -> List[Rot3]:
measurements: gtsam.BinaryMeasurementsRot3 = self.__measurements_from_2view_relative_rotations(
i2Ri1_dict=i2Ri1_dict_remapped, num_correspondences_dict=num_correspondences_dict
)
measurements.extend(self._measurements_from_pose_priors(i1Ti2_priors, old_to_new_idxs))
wRi_list_subset = self._run_with_consecutive_ordering(
num_connected_nodes=len(nodes_with_edges), measurements=measurements
num_connected_nodes=len(nodes_with_edges), measurements=measurements, initial=initial_values
)
return wRi_list_subset

Expand Down
4 changes: 3 additions & 1 deletion gtsfm/utils/geometry_comparisons.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ def compare_rotations(
Args:
aTi_list: 1st list of rotations.
bTi_list: 2nd list of rotations.
angular_error_threshold_degrees: the threshold for angular error between two rotations.
angular_error_threshold_degrees: Threshold for angular error between two rotations.

Returns:
Result of the comparison.
"""
Expand All @@ -55,6 +56,7 @@ def compare_rotations(
relative_rotations_angles = np.array(
[compute_relative_rotation_angle(aRi, aRi_) for (aRi, aRi_) in zip(aRi_list, aRi_list_)], dtype=np.float32
)
print(relative_rotations_angles)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be removed

return np.all(relative_rotations_angles < angular_error_threshold_degrees)


Expand Down
76 changes: 76 additions & 0 deletions gtsfm/utils/rotation.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: would spanning_tree be a better name for this file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I indented to keep all rotations related util functions here. I feel this is not as generic to be named a spanning tree right now because the args are rotations and not a generic type.

Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Utility functions for rotations.

Authors: Ayush Baid
"""

from typing import Dict, List, Tuple

import networkx as nx
import numpy as np
from gtsam import Rot3


def random_rotation(angle_scale_factor: float = 0.1) -> Rot3:
"""Sample a random rotation by generating a sample from the 4d unit sphere."""
q = np.random.rand(4)
# make unit-length quaternion
q /= np.linalg.norm(q)
qw, qx, qy, qz = q
R = Rot3(qw, qx, qy, qz)
axis, angle = R.axisAngle()
angle = angle * angle_scale_factor
return Rot3.AxisAngle(axis.point3(), angle)


def initialize_global_rotations_using_mst(
num_images: int, i2Ri1_dict: Dict[Tuple[int, int], Rot3], edge_weights: Dict[Tuple[int, int], int]
) -> List[Rot3]:
"""Initializes rotations using minimum spanning tree (weighted by number of correspondences).

Args:
num_images: Number of images in the scene.
i2Ri1_dict: Dictionary of relative rotations (i1, i2): i2Ri1.
edge_weights: Weight of the edges (i1, i2). All edges in i2Ri1 must have an edge weight.

Returns:
Global rotations wRi initialized using an MST. Randomly initialized if we have a forest.
"""
# Create a graph from the relative rotations dictionary.
graph = nx.Graph()
for i1, i2 in i2Ri1_dict.keys():
graph.add_edge(i1, i2, weight=edge_weights[(i1, i2)])

if not nx.is_connected(graph):
raise ValueError("Relative rotation graph is not connected")

# Compute the Minimum Spanning Tree (MST)
mst = nx.minimum_spanning_tree(graph)

# MST graph.
G = nx.Graph()
G.add_edges_from(mst.edges)

wRi_list: List[Rot3] = [Rot3()] * num_images
# Choose origin node.
origin_node = list(G.nodes)[0]
wRi_list[origin_node] = Rot3()

# Ignore 0th node, as we already set its global pose as the origin
for dst_node in list(G.nodes)[1:]:
# Determine the path to this node from the origin. ordered from [origin_node,...,dst_node]
path = nx.shortest_path(G, source=origin_node, target=dst_node)

# Chain relative rotations w.r.t. origin node. Initialize as identity Rot3 w.r.t origin node `i1`.
wRi1 = Rot3()
for i1, i2 in zip(path[:-1], path[1:]):
# NOTE: i1, i2 may not be in sorted order here. May need to reverse ordering.
if i1 < i2:
i1Ri2 = i2Ri1_dict[(i1, i2)].inverse()
else:
i1Ri2 = i2Ri1_dict[(i2, i1)]
Comment on lines +67 to +70
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not guaranteed, right?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm why not?

# Path order is (origin -> ... -> i1 -> i2 -> ... -> dst_node). Set `i2` to be new `i1`.
wRi1 = wRi1 * i1Ri2

wRi_list[dst_node] = wRi1

return wRi_list
Loading
Loading