-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b80f082
Showing
73 changed files
with
3,827 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
"""Copyright (c) Dreamfold.""" | ||
from .version import __version__ |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Copyright (c) Dreamfold.""" | ||
import torch | ||
from torch import Tensor | ||
from functorch import vmap | ||
from .so3_helpers import so3_exp_map | ||
|
||
def f_igso3_small(omega, sigma): | ||
""" Borrowed from: https://github.com/tomato1mule/edf/blob/1dd342e849fcb34d3eb4b6ad2245819abbd6c812/edf/dist.py#L99 | ||
This function implements the approximation of the density function of omega of the isotropic Gaussian distribution. | ||
""" | ||
# TODO: check for stability and maybe replace by limit in 0 for small values | ||
|
||
#TODO: figure out scaling constant: eps = eps/2 | ||
eps = (sigma / torch.sqrt(torch.tensor([2])).to(device=omega.device))**2 | ||
|
||
pi = torch.Tensor([torch.pi]).to(device=omega.device) | ||
|
||
small_number = 1e-9 | ||
small_num = small_number / 2 | ||
small_dnm = (1-torch.exp(-1. * pi**2 / eps)*(2 - 4 * (pi**2) / eps)) * small_number | ||
|
||
return (0.5 * torch.sqrt(pi) * (eps ** -1.5) * | ||
torch.exp((eps - (omega**2 / eps))/4) / (torch.sin(omega/2) + small_num) * | ||
(small_dnm + omega - ((omega - 2*pi)*torch.exp(pi * (omega - pi) / eps) | ||
+ (omega + 2*pi)*torch.exp(-pi * (omega+pi) / eps)))) | ||
|
||
|
||
# Marginal density of rotation angle for uniform density on SO(3) | ||
def angle_density_unif(omega): | ||
return (1-torch.cos(omega))/torch.pi | ||
|
||
def interp(x: Tensor, xp: Tensor, fp: Tensor) -> Tensor: | ||
"""One-dimensional linear interpolation for monotonically increasing sample | ||
points. | ||
Returns the one-dimensional piecewise linear interpolant to a function with | ||
given discrete data points :math:`(xp, fp)`, evaluated at :math:`x`. | ||
Args: | ||
x: the :math:`x`-coordinates at which to evaluate the interpolated | ||
values. | ||
xp: the :math:`x`-coordinates of the data points, must be increasing. | ||
fp: the :math:`y`-coordinates of the data points, same length as `xp`. | ||
Returns: | ||
the interpolated values, same size as `x`. | ||
""" | ||
m = (fp[1:] - fp[:-1]) / (xp[1:] - xp[:-1]) # slope | ||
b = fp[:-1] - (m * xp[:-1]) # y-intercept | ||
|
||
indicies = torch.sum(torch.ge(x[:, None], xp[None, :]), dim=1) - 1 | ||
indicies = torch.clamp(indicies, 0, len(m) - 1) | ||
|
||
return m[indicies] * x + b[indicies] | ||
|
||
|
||
def _f(omega, eps): | ||
return f_igso3_small(omega, eps) | ||
|
||
def _pdf(omega, eps): | ||
f_unif = angle_density_unif(omega) | ||
return _f(omega, eps) * f_unif | ||
|
||
def _sample(eps, n): | ||
# sample n points from IGSO3(I, eps) | ||
num_omegas = 1024 | ||
omega_grid = torch.linspace(0, torch.pi, num_omegas+1).to(eps.device)[1:] # skip omega=0 | ||
# numerical integration of (1-cos(omega))/pi*f_igso3(omega, eps) over omega | ||
pdf = _pdf(omega_grid, eps) | ||
dx = omega_grid[1] - omega_grid[0] | ||
cdf = torch.cumsum(pdf, dim=-1) * dx # cumalative density function | ||
|
||
# sample n points from the distribution | ||
rand_angle = torch.rand(n).to(eps.device) | ||
omegas = interp(rand_angle, cdf, omega_grid) | ||
axes = torch.randn(n, 3).to(eps.device) #sample axis uniformly | ||
axis_angle = omegas[..., None] * axes / torch.linalg.norm(axes, dim=-1, keepdim=True) | ||
return axis_angle | ||
|
||
def _batch_sample(mu, eps, n): | ||
aa_samples = vmap(_sample, in_dims=(0, None), randomness="different")(eps, n).squeeze() | ||
return mu @ so3_exp_map(aa_samples) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
"""Copyright (c) Dreamfold.""" | ||
import matplotlib | ||
|
||
matplotlib.use("Agg") | ||
import torch | ||
import argparse | ||
|
||
from torch import nn | ||
from torch.utils.data import DataLoader | ||
|
||
#import plotting | ||
import os | ||
import json | ||
import ipdb | ||
|
||
MY_EPS = {torch.float32: 1e-4, torch.float64: 1e-8} | ||
|
||
# noinspection PyShadowingNames | ||
def proj2manifold(x): | ||
u, _, vT = torch.linalg.svd(x) | ||
return u @ vT | ||
|
||
|
||
# noinspection PyShadowingNames | ||
def proj2tangent(x, v): | ||
shape = v.shape | ||
m = x.size(1) | ||
if v.ndim == 2: | ||
v = v.view(-1, m, m) | ||
return 0.5 * (v - x @ v.permute(0, 2, 1) @ x).view(shape) | ||
|
||
EPS = 1e-9 | ||
MIN_NORM = 1e-15 | ||
|
||
|
||
# noinspection PyShadowingNames,PyAbstractClass | ||
class Manifold: | ||
def __init__(self, ambient_dim, manifold_dim): | ||
""" | ||
ambient_dim: dimension of ambient space | ||
manifold_dim: dimension of manifold | ||
""" | ||
self.ambient_dim = ambient_dim | ||
self.manifold_dim = manifold_dim | ||
|
||
@staticmethod | ||
def phi(x): | ||
""" | ||
x: point on ambient space | ||
return: point on euclidean patch | ||
""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def invphi(x_tilde): | ||
""" | ||
x_tilde: point on euclidean patch | ||
return: point on ambient space | ||
""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def project(x): | ||
""" | ||
x: manifold point on ambient space | ||
return: projection of x onto the manifold in the ambient space | ||
""" | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def g(x): | ||
""" | ||
x: manifold point on ambient space | ||
return: differentiable determinant of the metric tensor at point x | ||
""" | ||
raise NotImplementedError | ||
|
||
def norm(self, x, u, squared=False, keepdim=False): | ||
norm_sq = self.inner(x, u, u, keepdim) | ||
norm_sq.data.clamp_(MY_EPS[u.dtype]) | ||
return norm_sq if squared else norm_sq.sqrt() | ||
|
||
# noinspection PyShadowingNames,PyAbstractClass | ||
class OrthogonalGroup(Manifold): | ||
def __init__(self): | ||
super(OrthogonalGroup, self).__init__(ambient_dim=9, manifold_dim=3) | ||
|
||
@staticmethod | ||
def proj2manifold(x): | ||
return proj2manifold(x) | ||
|
||
@staticmethod | ||
def proj2tangent(x, v): | ||
return proj2tangent(x, v) |
Oops, something went wrong.