diff --git a/README.md b/README.md
index b4d8ad6..28078de 100644
--- a/README.md
+++ b/README.md
@@ -83,10 +83,4 @@ implementation.
`python bilateral.py input.png output.png 20 0.25`
-
+![Input Image](images/wimr_small.png) ![Filtered](images/filtered.png)
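+
+(The trailing `20` and `0.25` are the filter standard deviations, presumably
+spatial and color respectively; see `bilateral.py` for the exact argument
+meanings.)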
diff --git a/crfrnn/crf.py b/crfrnn/crf.py
index 441f7a2..fbd9f1d 100644
--- a/crfrnn/crf.py
+++ b/crfrnn/crf.py
@@ -1,35 +1,81 @@
import torch as th
import torch.nn as nn
-import numpy as np
-from permutohedral.hash import get_hash_cap, make_hashtable
-from permutohedral.gfilt import gfilt, make_gfilt_buffers
-def gaussian_filter(ref, val, kstd, hb=None, gb=None):
-    return gfilt(ref / kstd[:, None, None], val, hb, gb)
+import torch.nn.functional as thf
+
+from permutohedral.gfilt import gfilt
+
+
+def gaussian_filter(ref, val, kstd):
+    return gfilt(ref / kstd[:, :, None, None], val)
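+
+
+# (2, h, w) integer grid of (y, x) pixel coordinates; used as the spatial
+# half of the bilateral features.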
def mgrid(h, w, dev):
y = th.arange(0, h, device=dev).repeat(w, 1).t()
x = th.arange(0, w, device=dev).repeat(h, 1)
return th.stack([y, x], 0)
+
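+# Builds a (chans, chans, s, s) convolution weight whose only non-zero
+# entries are Gaussian kernels on the channel diagonal, so each label map
+# is smoothed independently.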
def gkern(std, chans, dev):
sig_sq = std ** 2
r = sig_sq if (sig_sq % 2) else sig_sq - 1
s = 2 * r + 1
- k = th.exp(-((mgrid(s, s, dev)-r) ** 2).sum(0) / (2 * sig_sq))
+ k = th.exp(-((mgrid(s, s, dev) - r) ** 2).sum(0) / (2 * sig_sq))
W = th.zeros(chans, chans, s, s, device=dev)
for i in range(chans):
W[i, i] = k / k.sum()
return W
+
class CRF(nn.Module):
- def __init__(self, sxy_bf=70, sc_bf=12, compat_bf=4, sxy_spatial=6,
- compat_spatial=2, num_iter=5, normalize_final_iter=True,
- trainable_kstd=False):
+ def __init__(
+ self,
+ n_ref: int,
+ n_out: int,
+ sxy_bf: float = 70,
+ sc_bf: float = 12,
+ compat_bf: float = 4,
+ sxy_spatial: float = 6,
+ compat_spatial: float = 2,
+ num_iter: int = 5,
+ normalize_final_iter: bool = True,
+ trainable_kstd: bool = False,
+ ):
+ """Implements fast approximate mean-field inference for a
+ fully-connected CRF with Gaussian edge potentials within a neural
+ network layer using fast bilateral filtering.
+
+ Args:
+ n_ref: Number of channels in the reference images.
+
+ n_out: Number of labels.
+
+ sxy_bf: Spatial standard deviation of the bilateral filter.
+
+ sc_bf: Color standard deviation of the bilateral filter.
+
+            compat_bf: Label compatibility weight for the bilateral filter.
+                Assumes a Potts model with a single parameter.
+
+ sxy_spatial: Spatial standard deviation of the 2D Gaussian
+ convolution kernel.
+
+ compat_spatial: Label compatibility weight of the 2D Gaussian
+ convolution kernel.
+
+ num_iter: Number of steps to run in the inference loop.
+
+            normalize_final_iter: If False, the final iteration returns
+                pre-softmax outputs instead of label probabilities.
+
+            trainable_kstd: Learn the standard deviations of the bilateral
+                filter as model parameters. May make training less stable.
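+
+        Example (illustrative shapes; the filtering op requires a CUDA
+        build of the permutohedral extension)::
+
+            crf = CRF(n_ref=3, n_out=21).cuda()
+            unary = th.softmax(th.randn(2, 21, 240, 320, device="cuda"), dim=1)
+            ref = th.rand(2, 3, 240, 320, device="cuda")
+            probs = crf(unary, ref)  # (2, 21, 240, 320)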
+ """
+ assert n_ref in {1, 3}, "Reference image must be either RGB or greyscale (3 or 1 channels)."
super().__init__()
+ self.n_ref = n_ref
+ self.n_out = n_out
self.sxy_bf = sxy_bf
self.sc_bf = sc_bf
self.compat_bf = compat_bf
@@ -39,57 +85,40 @@ def __init__(self, sxy_bf=70, sc_bf=12, compat_bf=4, sxy_spatial=6,
self.normalize_final_iter = normalize_final_iter
self.trainable_kstd = trainable_kstd
-        if isinstance(sc_bf, (int, float)):
-            sc_bf = 3 * [sc_bf]
-        kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf[0], sc_bf[1], sc_bf[2]])
+        kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf, sc_bf, sc_bf])
+        if n_ref == 1:
+            kstd = kstd[:3]
if trainable_kstd:
self.kstd = nn.Parameter(kstd)
else:
self.register_buffer("kstd", kstd)
- def forward(self, unary, ref):
- N, ref_dim, H, W = ref.shape
- Nu, val_dim, Hu, Wu = unary.shape
- assert(Nu == N and Hu == H and Wu == W)
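+        # Fixed Gaussian kernel used for the spatial (non-bilateral) message.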
+        self.register_buffer("gk", gkern(sxy_spatial, n_out, th.device("cpu")))
- if ref_dim not in [3, 1]:
- raise ValueError("Reference image must be either color or greyscale (3 or 1 channels).")
- ref_dim += 2
-
- kstd = self.kstd[:3] if ref_dim == 3 else self.kstd
- gk = gkern(self.sxy_spatial, val_dim, unary.device)
-
- yx = mgrid(H, W, unary.device)
- grid = yx[None].repeat(N, 1, 1, 1)
-
- cap = get_hash_cap(H * W, ref_dim)
- stacked = th.cat([grid, ref], dim=1)
- gb = make_gfilt_buffers(val_dim, H, W, cap, unary.device)
- gb1 = make_gfilt_buffers(1, H, W, cap, unary.device)
-
- def _bilateral(V, R, hb):
- o = th.ones(1, H, W, device=unary.device)
- norm = th.sqrt(gaussian_filter(R, o, kstd, hb, gb1)) + 1e-8
- return gaussian_filter(R, V / norm, kstd, hb, gb) / norm
-
- def _step(prev_q, U, ref, hb, normalize=True):
- qbf = _bilateral(prev_q, ref, hb)
- qsf = th.nn.functional.conv2d(prev_q[None], gk, padding=gk.shape[-1]//2)[0]
+ def forward(self, unary, ref):
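+        """Runs mean-field inference on a batch.
+
+        Args:
+            unary: (N, n_out, H, W) per-pixel label probabilities, e.g. the
+                softmax output of a segmentation network.
+            ref: (N, n_ref, H, W) reference images.
+
+        Returns:
+            (N, n_out, H, W) refined label probabilities, or pre-softmax
+            outputs if normalize_final_iter is False.
+        """
+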
+ def _bilateral(V, R):
+ return gaussian_filter(R, V, self.kstd[None])
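+
+        # One mean-field update: message passing through both filters,
+        # weighted by the compatibility terms, added to the unaries, and
+        # optionally renormalized with a softmax.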
+ def _step(prev_q, U, ref, normalize=True):
+ qbf = _bilateral(prev_q, ref)
+ qsf = thf.conv2d(prev_q, self.gk, padding=self.gk.shape[-1] // 2)
q_hat = -self.compat_bf * qbf - self.compat_spatial * qsf
q_hat = U - q_hat
+ return th.softmax(q_hat, dim=1) if normalize else q_hat
- return th.softmax(q_hat, dim=0) if normalize else q_hat
-
- def _inference(unary_i, ref_i):
- U = th.log(th.clamp(unary_i, 1e-5, 1))
- prev_q = th.softmax(U, dim=0)
- hb = make_hashtable(ref_i.data / kstd[:, None, None].data)
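+
+        # Initialize Q from the unaries alone, then iterate updates.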
+ def _inference(unary, ref):
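+            # Unaries are probabilities: clamp away zeros before the log.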
+ U = th.log(th.clamp(unary, 1e-5, 1))
+ prev_q = th.softmax(U, dim=1)
for i in range(self.num_iter):
normalize = self.normalize_final_iter or i < self.num_iter - 1
- prev_q = _step(prev_q, U, ref_i, hb, normalize=normalize)
+ prev_q = _step(prev_q, U, ref, normalize=normalize)
return prev_q
- return th.stack([_inference(unary[i], stacked[i]) for i in range(N)])
+ N, _, H, W = unary.shape
+ yx = mgrid(H, W, unary.device)
+ grid = yx[None].repeat(N, 1, 1, 1)
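+        # Bilateral features: (y, x) coordinates stacked over the reference
+        # channels, matching the [sxy, sxy, sc, ...] layout of kstd.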
+ stacked = th.cat([grid, ref], dim=1)
+
+ return _inference(unary, stacked)
diff --git a/permutohedral/gfilt.py b/permutohedral/gfilt.py
index 190ddcf..56f32f8 100644
--- a/permutohedral/gfilt.py
+++ b/permutohedral/gfilt.py
@@ -1,15 +1,19 @@
import torch as th
-import numpy as np
from .hash import make_hashtable
import permutohedral_ext
+
th.ops.load_library(permutohedral_ext.__file__)
gfilt_cuda = th.ops.permutohedral_ext.gfilt_cuda
+
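+# Workspace for the CUDA filtering op: the output image plus two scratch
+# value tables sized to the hash-table capacity.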
def make_gfilt_buffers(b, val_dim, h, w, cap, dev):
- return [th.zeros(b, val_dim, h, w, device=dev),# output
- th.empty(cap, val_dim+1, device=dev), # tmp_vals_1
- th.empty(cap, val_dim+1, device=dev)] # tmp_vals_2
+ return [
+ th.zeros(b, val_dim, h, w, device=dev), # output
+ th.empty(cap, val_dim + 1, device=dev), # tmp_vals_1
+ th.empty(cap, val_dim + 1, device=dev), # tmp_vals_2
+ ]
+
class GaussianFilter(th.autograd.Function):
@staticmethod
@@ -37,7 +41,7 @@ def forward(ctx, ref, val, _hash_buffers=None, _gfilt_buffers=None):
gfilt_cuda(val, *gfilt_buffers, *hash_buffers, ref_dim, False)
else:
raise NotImplementedError("Gfilt currently requires CUDA support.")
- #gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)
+ # gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)
out = gfilt_buffers[0]
@@ -76,11 +80,11 @@ def filt(v):
grads[0] = th.stack(
[
- (grad_output * (filt(val * r_i) - r_i * out)) +
- (val * (filt(grad_output * r_i) - r_i * filt_og))
+ (grad_output * (filt(val * r_i) - r_i * out))
+ + (val * (filt(grad_output * r_i) - r_i * filt_og))
for r_i in ref.split(1, dim=1)
],
- dim=1
+ dim=1,
).sum(dim=2)
if ctx.needs_input_grad[1]:
@@ -88,4 +92,5 @@ def filt(v):
return grads[0], grads[1], grads[2], grads[3]
+
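+# Functional entry point: gfilt(ref, val) filters `val` with Gaussian
+# weights derived from distances between `ref` feature vectors.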
gfilt = GaussianFilter.apply
diff --git a/permutohedral/hash.py b/permutohedral/hash.py
index a25c61b..ca2fb97 100644
--- a/permutohedral/hash.py
+++ b/permutohedral/hash.py
@@ -1,19 +1,25 @@
import torch as th
import permutohedral_ext
+
th.ops.load_library(permutohedral_ext.__file__)
build_hash_cuda = th.ops.permutohedral_ext.build_hash_cuda
+
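+# Buffers for the lattice hash table: entry slots, lattice-point keys,
+# per-pixel neighbor entries and barycentric weights, plus a list of the
+# occupied entries.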
def make_hash_buffers(b, dim, h, w, cap, dev):
- return [-th.ones(b, cap, dtype=th.int32, device=dev), # hash_entries
- th.zeros(b, cap, dim, dtype=th.int16, device=dev), # hash_keys
- th.zeros(b, dim+1, h, w, dtype=th.int32, device=dev), # neib_ents
- th.zeros(b, dim+1, h, w, device=dev), # barycentric
- th.zeros(b, cap, dtype=th.int32, device=dev), # valid_entries
- th.zeros(b, 1).int().to(device=dev)] # n_valid_entries
+ return [
+ -th.ones(b, cap, dtype=th.int32, device=dev), # hash_entries
+ th.zeros(b, cap, dim, dtype=th.int16, device=dev), # hash_keys
+ th.zeros(b, dim + 1, h, w, dtype=th.int32, device=dev), # neib_ents
+ th.zeros(b, dim + 1, h, w, device=dev), # barycentric
+ th.zeros(b, cap, dtype=th.int32, device=dev), # valid_entries
+        th.zeros(b, 1, dtype=th.int32, device=dev),  # n_valid_entries
+ ]
+
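+# Each of the N points is interpolated over the (dim + 1) vertices of its
+# enclosing simplex, which bounds the number of occupied hash entries.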
def get_hash_cap(N, dim):
- return N*(dim+1)
+ return N * (dim + 1)
+
def make_hashtable(points):
b, dim, h, w = points.shape
@@ -25,6 +31,6 @@ def make_hashtable(points):
build_hash_cuda(points.contiguous(), *buffers)
else:
raise NotImplementedError("Hash table currently requires CUDA support.")
- #build_hash_cpu(points.contiguous(), *buffers, cap)
+ # build_hash_cpu(points.contiguous(), *buffers, cap)
return buffers