diff --git a/README.md b/README.md
index b4d8ad6..28078de 100644
--- a/README.md
+++ b/README.md
@@ -83,10 +83,4 @@ implementation.
 
 `python bilateral.py input.png output.png 20 0.25`
 
-
+![Input Image](images/wimr_small.png) ![Filtered](images/filtered.png)
diff --git a/crfrnn/crf.py b/crfrnn/crf.py
index 441f7a2..fbd9f1d 100644
--- a/crfrnn/crf.py
+++ b/crfrnn/crf.py
@@ -1,35 +1,81 @@
 import torch as th
 import torch.nn as nn
-import numpy as np
+import torch.nn.functional as thf
 
-from permutohedral.hash import get_hash_cap, make_hashtable
-from permutohedral.gfilt import gfilt, make_gfilt_buffers
+from permutohedral.gfilt import gfilt
+
+
+def gaussian_filter(ref, val, kstd):
+    return gfilt(ref / kstd[:, :, None, None], val)
 
-def gaussian_filter(ref, val, kstd, hb=None, gb=None):
-    return gfilt(ref / kstd[:, None, None], val, hb, gb)
 
 def mgrid(h, w, dev):
     y = th.arange(0, h, device=dev).repeat(w, 1).t()
     x = th.arange(0, w, device=dev).repeat(h, 1)
     return th.stack([y, x], 0)
 
+
 def gkern(std, chans, dev):
     sig_sq = std ** 2
     r = sig_sq if (sig_sq % 2) else sig_sq - 1
     s = 2 * r + 1
-    k = th.exp(-((mgrid(s, s, dev)-r) ** 2).sum(0) / (2 * sig_sq))
+    k = th.exp(-((mgrid(s, s, dev) - r) ** 2).sum(0) / (2 * sig_sq))
     W = th.zeros(chans, chans, s, s, device=dev)
     for i in range(chans):
         W[i, i] = k / k.sum()
     return W
 
+
 class CRF(nn.Module):
-    def __init__(self, sxy_bf=70, sc_bf=12, compat_bf=4, sxy_spatial=6,
-                 compat_spatial=2, num_iter=5, normalize_final_iter=True,
-                 trainable_kstd=False):
+    def __init__(
+        self,
+        n_ref: int,
+        n_out: int,
+        sxy_bf: float = 70,
+        sc_bf: float = 12,
+        compat_bf: float = 4,
+        sxy_spatial: float = 6,
+        compat_spatial: float = 2,
+        num_iter: int = 5,
+        normalize_final_iter: bool = True,
+        trainable_kstd: bool = False,
+    ):
+        """Implements fast approximate mean-field inference for a
+        fully-connected CRF with Gaussian edge potentials within a neural
+        network layer using fast bilateral filtering.
+
+        Args:
+            n_ref: Number of channels in the reference images.
+
+            n_out: Number of labels.
+
+            sxy_bf: Spatial standard deviation of the bilateral filter.
+
+            sc_bf: Color standard deviation of the bilateral filter.
+
+            compat_bf: Label compatibility weight for the bilateral filter.
+                Assumes a Potts model with a single parameter.
+
+            sxy_spatial: Spatial standard deviation of the 2D Gaussian
+                convolution kernel.
+
+            compat_spatial: Label compatibility weight of the 2D Gaussian
+                convolution kernel.
+
+            num_iter: Number of steps to run in the inference loop.
+
+            normalize_final_iter: If pre-softmax outputs are desired rather
+                than label probabilities, set this to False.
+
+            trainable_kstd: Allow the parameters of the bilateral filter to
+                be learned as well. This option may make training less
+                stable.
+        """
+        assert n_ref in {1, 3}, "Reference image must be either RGB or greyscale (3 or 1 channels)."
         super().__init__()
+        self.n_ref = n_ref
+        self.n_out = n_out
         self.sxy_bf = sxy_bf
         self.sc_bf = sc_bf
         self.compat_bf = compat_bf
@@ -39,57 +85,40 @@ def __init__(self, sxy_bf=70, sc_bf=12, compat_bf=4, sxy_spatial=6,
         self.normalize_final_iter = normalize_final_iter
         self.trainable_kstd = trainable_kstd
 
-        if isinstance(sc_bf, (int, float)):
-            sc_bf = 3 * [sc_bf]
+        kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf, sc_bf, sc_bf])
+        if n_ref == 1:
+            kstd = kstd[:3]
 
-        kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf[0], sc_bf[1], sc_bf[2]])
         if trainable_kstd:
             self.kstd = nn.Parameter(kstd)
         else:
             self.register_buffer("kstd", kstd)
 
-    def forward(self, unary, ref):
-        N, ref_dim, H, W = ref.shape
-        Nu, val_dim, Hu, Wu = unary.shape
-        assert(Nu == N and Hu == H and Wu == W)
+        self.register_buffer("gk", gkern(sxy_spatial, n_out, th.device("cpu")))
 
-        if ref_dim not in [3, 1]:
-            raise ValueError("Reference image must be either color or greyscale (3 or 1 channels).")
-        ref_dim += 2
-
-        kstd = self.kstd[:3] if ref_dim == 3 else self.kstd
-        gk = gkern(self.sxy_spatial, val_dim, unary.device)
-
-        yx = mgrid(H, W, unary.device)
-        grid = yx[None].repeat(N, 1, 1, 1)
-
-        cap = get_hash_cap(H * W, ref_dim)
-        stacked = th.cat([grid, ref], dim=1)
-        gb = make_gfilt_buffers(val_dim, H, W, cap, unary.device)
-        gb1 = make_gfilt_buffers(1, H, W, cap, unary.device)
-
-        def _bilateral(V, R, hb):
-            o = th.ones(1, H, W, device=unary.device)
-            norm = th.sqrt(gaussian_filter(R, o, kstd, hb, gb1)) + 1e-8
-            return gaussian_filter(R, V / norm, kstd, hb, gb) / norm
-
-        def _step(prev_q, U, ref, hb, normalize=True):
-            qbf = _bilateral(prev_q, ref, hb)
-            qsf = th.nn.functional.conv2d(prev_q[None], gk, padding=gk.shape[-1]//2)[0]
+    def forward(self, unary, ref):
+        def _bilateral(V, R):
+            return gaussian_filter(R, V, self.kstd[None])
 
+        def _step(prev_q, U, ref, normalize=True):
+            qbf = _bilateral(prev_q, ref)
+            qsf = thf.conv2d(prev_q, self.gk, padding=self.gk.shape[-1] // 2)
             q_hat = -self.compat_bf * qbf - self.compat_spatial * qsf
             q_hat = U - q_hat
+            return th.softmax(q_hat, dim=1) if normalize else q_hat
 
-            return th.softmax(q_hat, dim=0) if normalize else q_hat
-
-        def _inference(unary_i, ref_i):
-            U = th.log(th.clamp(unary_i, 1e-5, 1))
-            prev_q = th.softmax(U, dim=0)
-            hb = make_hashtable(ref_i.data / kstd[:, None, None].data)
+        def _inference(unary, ref):
+            U = th.log(th.clamp(unary, 1e-5, 1))
+            prev_q = th.softmax(U, dim=1)
 
             for i in range(self.num_iter):
                 normalize = self.normalize_final_iter or i < self.num_iter - 1
-                prev_q = _step(prev_q, U, ref_i, hb, normalize=normalize)
+                prev_q = _step(prev_q, U, ref, normalize=normalize)
 
             return prev_q
 
-        return th.stack([_inference(unary[i], stacked[i]) for i in range(N)])
+        N, _, H, W = unary.shape
+        yx = mgrid(H, W, unary.device)
+        grid = yx[None].repeat(N, 1, 1, 1)
+        stacked = th.cat([grid, ref], dim=1)
+
+        return _inference(unary, stacked)
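A minimal usage sketch for the batched layer above. The device, shapes, label count, and random inputs are illustrative assumptions; the underlying filtering ops require CUDA, and with the default sc_bf=12 the reference image is presumably on a 0-255 color scale:

    import torch as th
    from crfrnn.crf import CRF

    crf = CRF(n_ref=3, n_out=21).cuda()  # RGB reference, 21 labels (hypothetical)
    # forward() expects per-pixel label probabilities; it clamps and logs them
    # internally to form the unary potentials.
    unary = th.softmax(th.randn(2, 21, 240, 320, device="cuda"), dim=1)
    ref = 255 * th.rand(2, 3, 240, 320, device="cuda")
    q = crf(unary, ref)  # (2, 21, 240, 320) per-pixel label distributions

Note that forward() now consumes the whole batch at once: pixel coordinates from mgrid are concatenated with ref to form the 5-channel (position + color) bilateral features, and the old per-sample Python loop is gone.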
diff --git a/permutohedral/gfilt.py b/permutohedral/gfilt.py
index 190ddcf..56f32f8 100644
--- a/permutohedral/gfilt.py
+++ b/permutohedral/gfilt.py
@@ -1,15 +1,19 @@
 import torch as th
-import numpy as np
 from .hash import make_hashtable
 import permutohedral_ext
+
 th.ops.load_library(permutohedral_ext.__file__)
 gfilt_cuda = th.ops.permutohedral_ext.gfilt_cuda
 
+
 def make_gfilt_buffers(b, val_dim, h, w, cap, dev):
-    return [th.zeros(b, val_dim, h, w, device=dev),# output
-            th.empty(cap, val_dim+1, device=dev),  # tmp_vals_1
-            th.empty(cap, val_dim+1, device=dev)]  # tmp_vals_2
+    return [
+        th.zeros(b, val_dim, h, w, device=dev),  # output
+        th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_1
+        th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_2
+    ]
+
 
 class GaussianFilter(th.autograd.Function):
     @staticmethod
@@ -37,7 +41,7 @@ def forward(ctx, ref, val, _hash_buffers=None, _gfilt_buffers=None):
             gfilt_cuda(val, *gfilt_buffers, *hash_buffers, ref_dim, False)
         else:
             raise NotImplementedError("Gfilt currently requires CUDA support.")
-            #gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)
+            # gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)
 
         out = gfilt_buffers[0]
 
@@ -76,11 +80,11 @@ def filt(v):
 
         grads[0] = th.stack(
            [
-                (grad_output * (filt(val * r_i) - r_i * out)) +
-                (val * (filt(grad_output * r_i) - r_i * filt_og))
+                (grad_output * (filt(val * r_i) - r_i * out))
+                + (val * (filt(grad_output * r_i) - r_i * filt_og))
                 for r_i in ref.split(1, dim=1)
             ],
-            dim=1
+            dim=1,
         ).sum(dim=2)
 
         if ctx.needs_input_grad[1]:
@@ -88,4 +92,5 @@ def filt(v):
 
         return grads[0], grads[1], grads[2], grads[3]
 
+
 gfilt = GaussianFilter.apply
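Because the refactored gaussian_filter applies the raw splat-blur-slice transform without normalization (the old _bilateral helper divided by a filtered all-ones channel), a standalone bilateral filter built on it has to normalize explicitly. A minimal sketch, assuming a CUDA build of permutohedral_ext and illustrative feature scales:

    import torch as th
    from crfrnn.crf import gaussian_filter, mgrid

    img = th.rand(1, 3, 64, 64, device="cuda")          # colors in [0, 1]
    yx = mgrid(64, 64, img.device)[None].to(img.dtype)  # (1, 2, H, W) pixel coords
    feat = th.cat([yx, img], dim=1)                     # (1, 5, H, W) position + color
    kstd = th.tensor([[5.0, 5.0, 0.1, 0.1, 0.1]], device="cuda")  # assumed stds
    norm = gaussian_filter(feat, th.ones_like(img[:, :1]), kstd).clamp_min(1e-8)
    filtered = gaussian_filter(feat, img, kstd) / norm  # normalized bilateral filter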
diff --git a/permutohedral/hash.py b/permutohedral/hash.py
index a25c61b..ca2fb97 100644
--- a/permutohedral/hash.py
+++ b/permutohedral/hash.py
@@ -1,19 +1,25 @@
 import torch as th
 import permutohedral_ext
+
 th.ops.load_library(permutohedral_ext.__file__)
 build_hash_cuda = th.ops.permutohedral_ext.build_hash_cuda
 
+
 def make_hash_buffers(b, dim, h, w, cap, dev):
-    return [-th.ones(b, cap, dtype=th.int32, device=dev),  # hash_entries
-            th.zeros(b, cap, dim, dtype=th.int16, device=dev),  # hash_keys
-            th.zeros(b, dim+1, h, w, dtype=th.int32, device=dev),  # neib_ents
-            th.zeros(b, dim+1, h, w, device=dev),  # barycentric
-            th.zeros(b, cap, dtype=th.int32, device=dev),  # valid_entries
-            th.zeros(b, 1).int().to(device=dev)]  # n_valid_entries
+    return [
+        -th.ones(b, cap, dtype=th.int32, device=dev),  # hash_entries
+        th.zeros(b, cap, dim, dtype=th.int16, device=dev),  # hash_keys
+        th.zeros(b, dim + 1, h, w, dtype=th.int32, device=dev),  # neib_ents
+        th.zeros(b, dim + 1, h, w, device=dev),  # barycentric
+        th.zeros(b, cap, dtype=th.int32, device=dev),  # valid_entries
+        th.zeros(b, 1).int().to(device=dev),  # n_valid_entries
+    ]
+
 
 def get_hash_cap(N, dim):
-    return N*(dim+1)
+    return N * (dim + 1)
+
 
 def make_hashtable(points):
     b, dim, h, w = points.shape
@@ -25,6 +31,6 @@ def make_hashtable(points):
         build_hash_cuda(points.contiguous(), *buffers)
     else:
         raise NotImplementedError("Hash table currently requires CUDA support.")
-        #build_hash_cpu(points.contiguous(), *buffers, cap)
+        # build_hash_cpu(points.contiguous(), *buffers, cap)
 
     return buffers
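For completeness, a sketch of building the hash table directly (sizes illustrative; CUDA assumed). The points are the bilateral features already divided by their standard deviations, exactly as gaussian_filter prepares them:

    import torch as th
    from permutohedral.hash import get_hash_cap, make_hashtable

    pts = th.rand(2, 5, 64, 64, device="cuda")  # (batch, dim, H, W) scaled features
    entries, keys, neib_ents, barycentric, valid, n_valid = make_hashtable(pts)
    cap = get_hash_cap(64 * 64, 5)  # hash_entries has shape (batch, cap)

Each of the H * W points is embedded in the permutohedral lattice and touches dim + 1 vertices of its enclosing simplex, which is why get_hash_cap sizes the table as N * (dim + 1).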