Skip to content

Commit

Permalink
Autoformat and clean up CRF code.
Browse files Browse the repository at this point in the history
  • Loading branch information
HapeMask committed Nov 18, 2020
1 parent cf526b6 commit dbddf3f
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 70 deletions.
8 changes: 1 addition & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,4 @@ implementation.

`python bilateral.py input.png output.png 20 0.25`

<a
href="https://github.com/HapeMask/crfrnn_layer/raw/master/images/wimr_small.png"><img
src="https://github.com/HapeMask/crfrnn_layer/raw/master/images/wimr_small.png"
width=400 /></a> <a
href="https://github.com/HapeMask/crfrnn_layer/raw/master/images/filtered.png"><img
src="https://github.com/HapeMask/crfrnn_layer/raw/master/images/filtered.png"
width=400 /></a>
![Input Image](images/wimr_small.png) ![Filtered](images/filtered.png)
123 changes: 76 additions & 47 deletions crfrnn/crf.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,81 @@
import torch as th
import torch.nn as nn
import numpy as np
import torch.nn.functional as thf

from permutohedral.hash import get_hash_cap, make_hashtable
from permutohedral.gfilt import gfilt, make_gfilt_buffers
from permutohedral.gfilt import gfilt


# Scale each reference channel by its kernel standard deviation, then run the
# permutohedral Gaussian filter (`gfilt`) on `val`.
# NOTE(review): this three-argument form appears to be the version ADDED by
# the commit; `kstd` is indexed with two leading axes, and the caller in
# `forward` passes `self.kstd[None]` -- presumably (1, n_channels), confirm.
def gaussian_filter(ref, val, kstd):
return gfilt(ref / kstd[:, :, None, None], val)

# NOTE(review): this five-argument form appears to be the version REMOVED by
# the commit; `kstd` is indexed with a single leading axis and the optional
# `hb` / `gb` arguments are forwarded to `gfilt` unchanged -- presumably
# pre-built hash and filter buffers, confirm against permutohedral.gfilt.
def gaussian_filter(ref, val, kstd, hb=None, gb=None):
return gfilt(ref / kstd[:, None, None], val, hb, gb)

def mgrid(h, w, dev):
    """Build a dense integer coordinate grid for an ``h`` x ``w`` image.

    Args:
        h: Number of rows.
        w: Number of columns.
        dev: Device on which the grid is allocated.

    Returns:
        Integer tensor of shape ``(2, h, w)``. Channel 0 holds the row index
        (y) of every pixel and channel 1 holds the column index (x).
    """
    rows = th.arange(0, h, device=dev)
    cols = th.arange(0, w, device=dev)
    # Broadcast views instead of materialising repeated copies; th.stack
    # writes the single contiguous result.
    y = rows[:, None].expand(h, w)
    x = cols[None, :].expand(h, w)
    return th.stack([y, x], 0)


def gkern(std, chans, dev=None):
    """Build a per-channel 2D Gaussian kernel usable as a ``conv2d`` weight.

    Args:
        std: Standard deviation of the Gaussian, in pixels.
        chans: Number of channels; every channel is filtered independently.
        dev: Device for the returned tensor. Optional so that callers which
            register the kernel as a module buffer can omit it; ``None``
            allocates on the default device.

    Returns:
        Float tensor of shape ``(chans, chans, s, s)`` with ``s = 2 * r + 1``.
        Only the ``[i, i]`` slices are non-zero and each of them sums to 1.
    """
    sig_sq = std ** 2
    # Kernel radius: the variance itself when it is odd, otherwise one less.
    # Cast to int so that a float ``std`` still yields valid tensor sizes.
    r = int(sig_sq if (sig_sq % 2) else sig_sq - 1)
    s = 2 * r + 1

    # Squared distance of every kernel cell from the centre cell (r, r).
    offsets = th.arange(s, device=dev) - r
    sq_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2
    k = th.exp(-sq_dist / (2 * sig_sq))
    k = k / k.sum()  # normalise once rather than once per channel

    W = th.zeros(chans, chans, s, s, device=dev)
    for i in range(chans):
        W[i, i] = k
    return W


class CRF(nn.Module):
def __init__(self, sxy_bf=70, sc_bf=12, compat_bf=4, sxy_spatial=6,
compat_spatial=2, num_iter=5, normalize_final_iter=True,
trainable_kstd=False):
def __init__(
self,
n_ref: int,
n_out: int,
sxy_bf: float = 70,
sc_bf: float = 12,
compat_bf: float = 4,
sxy_spatial: float = 6,
compat_spatial: float = 2,
num_iter: int = 5,
normalize_final_iter: bool = True,
trainable_kstd: bool = False,
):
"""Implements fast approximate mean-field inference for a
fully-connected CRF with Gaussian edge potentials within a neural
network layer using fast bilateral filtering.
Args:
n_ref: Number of channels in the reference images.
n_out: Number of labels.
sxy_bf: Spatial standard deviation of the bilateral filter.
sc_bf: Color standard deviation of the bilateral filter.
compat_bf: Label compatibility weight for the bilateral filter.
Assumes a Potts model w/one parameter.
sxy_spatial: Spatial standard deviation of the 2D Gaussian
convolution kernel.
compat_spatial: Label compatibility weight of the 2D Gaussian
convolution kernel.
num_iter: Number of steps to run in the inference loop.
normalize_final_iter: If pre-softmax outputs are desired rather
than label probabilities, set this to False.
trainable_kstd: Allow the parameters of the bilateral filter to be
learned as well. This option may make training less stable.
"""
assert n_ref in {1, 3}, "Reference image must be either RGB or greyscale (3 or 1 channels)."

super().__init__()

self.n_ref = n_ref
self.n_out = n_out
self.sxy_bf = sxy_bf
self.sc_bf = sc_bf
self.compat_bf = compat_bf
Expand All @@ -39,57 +85,40 @@ def __init__(self, sxy_bf=70, sc_bf=12, compat_bf=4, sxy_spatial=6,
self.normalize_final_iter = normalize_final_iter
self.trainable_kstd = trainable_kstd

if isinstance(sc_bf, (int, float)):
sc_bf = 3 * [sc_bf]
kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf, sc_bf, sc_bf])
if n_ref == 1:
kstd = kstd[:3]

kstd = th.FloatTensor([sxy_bf, sxy_bf, sc_bf[0], sc_bf[1], sc_bf[2]])
if trainable_kstd:
self.kstd = nn.Parameter(kstd)
else:
self.register_buffer("kstd", kstd)

def forward(self, unary, ref):
N, ref_dim, H, W = ref.shape
Nu, val_dim, Hu, Wu = unary.shape
assert(Nu == N and Hu == H and Wu == W)
self.register_buffer("gk", gkern(sxy_spatial, n_out))

if ref_dim not in [3, 1]:
raise ValueError("Reference image must be either color or greyscale (3 or 1 channels).")
ref_dim += 2

kstd = self.kstd[:3] if ref_dim == 3 else self.kstd
gk = gkern(self.sxy_spatial, val_dim, unary.device)

yx = mgrid(H, W, unary.device)
grid = yx[None].repeat(N, 1, 1, 1)

cap = get_hash_cap(H * W, ref_dim)
stacked = th.cat([grid, ref], dim=1)
gb = make_gfilt_buffers(val_dim, H, W, cap, unary.device)
gb1 = make_gfilt_buffers(1, H, W, cap, unary.device)

def _bilateral(V, R, hb):
o = th.ones(1, H, W, device=unary.device)
norm = th.sqrt(gaussian_filter(R, o, kstd, hb, gb1)) + 1e-8
return gaussian_filter(R, V / norm, kstd, hb, gb) / norm

def _step(prev_q, U, ref, hb, normalize=True):
qbf = _bilateral(prev_q, ref, hb)
qsf = th.nn.functional.conv2d(prev_q[None], gk, padding=gk.shape[-1]//2)[0]
def forward(self, unary, ref):
def _bilateral(V, R):
return gaussian_filter(R, V, self.kstd[None])

def _step(prev_q, U, ref, normalize=True):
qbf = _bilateral(prev_q, ref)
qsf = thf.conv2d(prev_q, self.gk, padding=self.gk.shape[-1] // 2)
q_hat = -self.compat_bf * qbf - self.compat_spatial * qsf
q_hat = U - q_hat
return th.softmax(q_hat, dim=1) if normalize else q_hat

return th.softmax(q_hat, dim=0) if normalize else q_hat

def _inference(unary_i, ref_i):
U = th.log(th.clamp(unary_i, 1e-5, 1))
prev_q = th.softmax(U, dim=0)
hb = make_hashtable(ref_i.data / kstd[:, None, None].data)
def _inference(unary, ref):
U = th.log(th.clamp(unary, 1e-5, 1))
prev_q = th.softmax(U, dim=1)

for i in range(self.num_iter):
normalize = self.normalize_final_iter or i < self.num_iter - 1
prev_q = _step(prev_q, U, ref_i, hb, normalize=normalize)
prev_q = _step(prev_q, U, ref, normalize=normalize)
return prev_q

return th.stack([_inference(unary[i], stacked[i]) for i in range(N)])
N, _, H, W = unary.shape
yx = mgrid(H, W, unary.device)
grid = yx[None].repeat(N, 1, 1, 1)
stacked = th.cat([grid, ref], dim=1)

return _inference(unary, stacked)
21 changes: 13 additions & 8 deletions permutohedral/gfilt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import torch as th
import numpy as np

from .hash import make_hashtable
import permutohedral_ext

th.ops.load_library(permutohedral_ext.__file__)
gfilt_cuda = th.ops.permutohedral_ext.gfilt_cuda


def make_gfilt_buffers(b, val_dim, h, w, cap, dev):
    """Allocate the tensors handed to the Gaussian-filter kernel.

    Args:
        b: Batch size.
        val_dim: Number of value channels being filtered.
        h: Image height.
        w: Image width.
        cap: Hash-table capacity (number of rows in the scratch buffers).
        dev: Device on which every buffer is allocated.

    Returns:
        A three-element list ``[output, tmp_vals_1, tmp_vals_2]``. ``output``
        is zero-initialised with shape ``(b, val_dim, h, w)``; the two
        scratch buffers are uninitialised with shape ``(cap, val_dim + 1)``.
    """
    return [
        th.zeros(b, val_dim, h, w, device=dev),  # output
        th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_1
        th.empty(cap, val_dim + 1, device=dev),  # tmp_vals_2
    ]


class GaussianFilter(th.autograd.Function):
@staticmethod
Expand Down Expand Up @@ -37,7 +41,7 @@ def forward(ctx, ref, val, _hash_buffers=None, _gfilt_buffers=None):
gfilt_cuda(val, *gfilt_buffers, *hash_buffers, ref_dim, False)
else:
raise NotImplementedError("Gfilt currently requires CUDA support.")
#gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)
# gfilt_cpu(val, *gfilt_buffers, *hash_buffers, cap, ref_dim, False)

out = gfilt_buffers[0]

Expand Down Expand Up @@ -76,16 +80,17 @@ def filt(v):

grads[0] = th.stack(
[
(grad_output * (filt(val * r_i) - r_i * out)) +
(val * (filt(grad_output * r_i) - r_i * filt_og))
(grad_output * (filt(val * r_i) - r_i * out))
+ (val * (filt(grad_output * r_i) - r_i * filt_og))
for r_i in ref.split(1, dim=1)
],
dim=1
dim=1,
).sum(dim=2)

if ctx.needs_input_grad[1]:
grads[1] = filt_og

return grads[0], grads[1], grads[2], grads[3]


gfilt = GaussianFilter.apply
22 changes: 14 additions & 8 deletions permutohedral/hash.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
import torch as th

import permutohedral_ext

th.ops.load_library(permutohedral_ext.__file__)
build_hash_cuda = th.ops.permutohedral_ext.build_hash_cuda


def make_hash_buffers(b, dim, h, w, cap, dev):
    """Allocate the tensors handed to the hash-table build kernel.

    Args:
        b: Batch size.
        dim: Dimensionality of the reference (position) vectors.
        h: Image height.
        w: Image width.
        cap: Hash-table capacity.
        dev: Device on which every buffer is allocated.

    Returns:
        A six-element list, in this order: ``hash_entries`` (int32, filled
        with -1), ``hash_keys`` (int16), ``neib_ents`` (int32),
        ``barycentric`` (default float dtype), ``valid_entries`` (int32) and
        ``n_valid_entries`` (int32). All but the first are zero-filled.
    """
    return [
        -th.ones(b, cap, dtype=th.int32, device=dev),  # hash_entries
        th.zeros(b, cap, dim, dtype=th.int16, device=dev),  # hash_keys
        th.zeros(b, dim + 1, h, w, dtype=th.int32, device=dev),  # neib_ents
        th.zeros(b, dim + 1, h, w, device=dev),  # barycentric
        th.zeros(b, cap, dtype=th.int32, device=dev),  # valid_entries
        # Allocated directly as int32 on the target device rather than
        # created as float on the CPU and then converted and moved.
        th.zeros(b, 1, dtype=th.int32, device=dev),  # n_valid_entries
    ]


def get_hash_cap(N, dim):
    """Return the hash-table capacity used for ``N`` points of dimension ``dim``.

    The capacity is ``N * (dim + 1)``.
    """
    return N * (dim + 1)


def make_hashtable(points):
b, dim, h, w = points.shape
Expand All @@ -25,6 +31,6 @@ def make_hashtable(points):
build_hash_cuda(points.contiguous(), *buffers)
else:
raise NotImplementedError("Hash table currently requires CUDA support.")
#build_hash_cpu(points.contiguous(), *buffers, cap)
# build_hash_cpu(points.contiguous(), *buffers, cap)

return buffers

0 comments on commit dbddf3f

Please sign in to comment.