ENH: Add unet and dependencies #11
base: master
dipp/tensorflow/activation.py
@@ -0,0 +1,29 @@
# Copyright 2014-2017 The DIPP contributors
#
# This file is part of DIPP.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import tensorflow as tf


__all__ = ('leaky_relu', 'prelu')

def leaky_relu(_x, alpha=0.2, name='leaky_relu'):
    return prelu(_x, init=alpha, name=name, trainable=False)

def prelu(_x, init=0.0, name='prelu', trainable=True):
    with tf.variable_scope(name):
        alphas = tf.get_variable('alphas',
                                 shape=[int(_x.get_shape()[-1])],
                                 initializer=tf.constant_initializer(init),
                                 dtype=tf.float32,
                                 trainable=trainable)  # honor the flag; leaky_relu passes False
        pos = tf.nn.relu(_x)
        neg = -alphas * tf.nn.relu(-_x)

        return pos + neg
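
A minimal usage sketch (not part of the PR; it assumes the module is importable as dipp.tensorflow.activation): apply both activations to a small constant and check leaky_relu against its closed form max(x, 0) + alpha * min(x, 0).

import tensorflow as tf
from dipp.tensorflow.activation import leaky_relu, prelu

x = tf.constant([[-1.0, 0.5, 2.0]], dtype=tf.float32)
y_leaky = leaky_relu(x, alpha=0.2, name='lrelu_demo')  # fixed slope 0.2
y_prelu = prelu(x, init=0.25, name='prelu_demo')       # learnable slopes

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(y_leaky))  # [[-0.2   0.5   2.  ]]
    print(sess.run(y_prelu))  # [[-0.25  0.5   2.  ]]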

dipp/tensorflow/layers.py
@@ -0,0 +1,97 @@
# Copyright 2014-2017 The DIPP contributors
#
# This file is part of DIPP.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import tensorflow as tf
[Review comment] Commenting here. In general, at least a minimal docstring that says what the function/class does would be great.
__all__ = ('conv1d', 'conv1dtransp', 'conv2d', 'conv2dtransp',
           'maxpool1d', 'maxpool2d')


def conv1d(x, W, stride=1, padding='SAME'):
    with tf.name_scope('conv1d'):
        return tf.nn.conv1d(x, W,
                            stride=stride,
                            padding=padding)


def conv1dtransp(x, W, stride=1, out_shape=None, padding='SAME'):
[Review comment] Is this convolving with the flipped kernel? It's a bit hard to imagine what "transposed convolution" is supposed to mean in 1d.
[Reply] Transpose of the convolution operator, i.e. the adjoint. This is standard nomenclature in the field.
[Reply] I'm baffled that tensorflow has no implementation for that. 🤔
    with tf.name_scope('conv1dtransp'):
        x_shape = tf.shape(x)
        W_shape = tf.shape(W)
        if out_shape is None:
            out_shape = tf.stack([x_shape[0],
                                  1,
                                  stride * x_shape[1],
                                  W_shape[1]])
        else:
            out_shape = tf.stack([out_shape[0],
                                  1,
                                  out_shape[1],
                                  out_shape[2]])

        # Emulate the missing tf.nn.conv1d_transpose by inserting a
        # singleton height axis and using the 2d transposed convolution.
        x_reshaped = tf.expand_dims(x, 1)
        W_reshaped = tf.expand_dims(W, 0)
        strides = [1, 1, stride, 1]

        result = tf.nn.conv2d_transpose(x_reshaped, W_reshaped,
                                        output_shape=out_shape,
                                        strides=strides,
                                        padding=padding)

        return tf.squeeze(result, axis=1)
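
Since the thread above asks what "transposed convolution" means in 1d, a quick numerical sketch (not part of the PR) illustrates the adjoint property <conv1d(x, W), y> = <x, conv1dtransp(y, W)> for stride 1 and 'SAME' padding:

import numpy as np
import tensorflow as tf
from dipp.tensorflow.layers import conv1d, conv1dtransp

# A is the linear map x -> conv1d(x, W); conv1dtransp applies A^T.
x = tf.constant(np.random.randn(1, 8, 1), dtype=tf.float32)
y = tf.constant(np.random.randn(1, 8, 1), dtype=tf.float32)
W = tf.constant(np.random.randn(3, 1, 1), dtype=tf.float32)  # (width, in, out)

lhs = tf.reduce_sum(conv1d(x, W) * y)        # <A x, y>
rhs = tf.reduce_sum(x * conv1dtransp(y, W))  # <x, A^T y>

with tf.Session() as sess:
    print(sess.run([lhs, rhs]))  # the two inner products agree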

def conv2d(x, W, stride=(1, 1), padding='SAME'):
    with tf.name_scope('conv2d'):
        strides = [1, stride[0], stride[1], 1]
        if padding in ('SAME', 'VALID'):
            return tf.nn.conv2d(x, W,
                                strides=strides, padding=padding)
        else:
[Review comment] Kind of strange logic to go into this case for whatever input except 'SAME' or 'VALID'.
            # Any other mode ('REFLECT', 'SYMMETRIC', ...) is forwarded to
            # tf.pad; the fixed [1, 1] padding matches 3x3 kernels.
            paddings = [[0, 0],
                        [1, 1],
                        [1, 1],
                        [0, 0]]
            x = tf.pad(x, paddings=paddings, mode=padding)

            return tf.nn.conv2d(x, W,
                                strides=strides, padding='VALID')
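
A small sketch of the non-'SAME'/'VALID' branch (not part of the PR): any mode understood by tf.pad, e.g. 'SYMMETRIC' or 'REFLECT', can be passed through, and the output shape is preserved for 3x3 kernels:

import numpy as np
import tensorflow as tf
from dipp.tensorflow.layers import conv2d

x = tf.constant(np.random.randn(1, 5, 5, 1), dtype=tf.float32)
W = tf.constant(np.random.randn(3, 3, 1, 1), dtype=tf.float32)  # 3x3 kernel

# 'SYMMETRIC' is forwarded to tf.pad, then a VALID convolution is applied.
y = conv2d(x, W, padding='SYMMETRIC')

with tf.Session() as sess:
    print(sess.run(tf.shape(y)))  # [1 5 5 1]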

def conv2dtransp(x, W, stride=(1, 1), out_shape=None, padding='SAME'):
[Review comment] Reason to deviate from the TF naming scheme?
    with tf.name_scope('conv2dtransp'):
        x_shape = tf.shape(x)
        W_shape = tf.shape(W)
        if out_shape is None:
            out_shape = tf.stack([x_shape[0],
                                  stride[0] * x_shape[1],
                                  stride[1] * x_shape[2],
                                  W_shape[2]])

        return tf.nn.conv2d_transpose(x, W,
                                      output_shape=out_shape,
                                      strides=[1, stride[0], stride[1], 1],
                                      padding=padding)
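
For illustration (not part of the PR), the default out_shape doubles the spatial dimensions when stride=(2, 2); the filter follows the tf.nn.conv2d_transpose layout (height, width, out_channels, in_channels):

import numpy as np
import tensorflow as tf
from dipp.tensorflow.layers import conv2dtransp

x = tf.constant(np.random.randn(1, 4, 4, 8), dtype=tf.float32)
W = tf.constant(np.random.randn(2, 2, 16, 8), dtype=tf.float32)

y = conv2dtransp(x, W, stride=(2, 2))  # default out_shape: (1, 8, 8, 16)

with tf.Session() as sess:
    print(sess.run(tf.shape(y)))  # [ 1  8  8 16]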

def maxpool1d(x, stride=2, padding='SAME'):
    with tf.name_scope('maxpool1d'):
        ksize = [1, 1, stride, 1]
        strides = [1, 1, stride, 1]

        # As with conv1dtransp, route through the 2d op via a singleton
        # height axis, since TF has no native 1d max pooling here.
        x_pad = tf.expand_dims(x, 1)
        result = tf.nn.max_pool(x_pad, ksize, strides, padding)
        return tf.squeeze(result, axis=1)


def maxpool2d(x, stride=(2, 2), padding='SAME'):
    with tf.name_scope('maxpool2d'):
        ksize = [1, stride[0], stride[1], 1]
        strides = [1, stride[0], stride[1], 1]
        return tf.nn.max_pool(x, ksize, strides, padding)
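
A shape-level sketch of the pooling helpers (not part of the PR):

import numpy as np
import tensorflow as tf
from dipp.tensorflow.layers import maxpool1d, maxpool2d

x1 = tf.constant(np.random.randn(1, 8, 3), dtype=tf.float32)
x2 = tf.constant(np.random.randn(1, 8, 8, 3), dtype=tf.float32)

with tf.Session() as sess:
    print(sess.run(tf.shape(maxpool1d(x1))))  # [1 4 3]
    print(sess.run(tf.shape(maxpool2d(x2))))  # [1 4 4 3]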

dipp/tensorflow/unet.py
@@ -0,0 +1,251 @@
# Copyright 2014-2017 The DIPP contributors
#
# This file is part of DIPP.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

import numpy as np
import tensorflow as tf
from dipp.tensorflow.activation import prelu, leaky_relu
from dipp.tensorflow.layers import (conv1d, conv2d,
                                    conv1dtransp, conv2dtransp,
                                    maxpool1d, maxpool2d)


__all__ = ('unet',)

def unet(x, nout,
         features=64,
         keep_prob=1.0,
         use_batch_norm=True,
         activation='relu',
         is_training=True,
         init='he',
         depth=4,
         name='unet'):
[Review comment] Weird stacking of arguments, why not use the full line length?
"""Reference implementation of the original U-net. | ||
|
||
All defaults are according to the reference article: | ||
|
||
https://arxiv.org/abs/1505.04597 | ||
|
||
Parameters | ||
---------- | ||
x : `tf.Tensor` with shape ``(B, L, C)`` or ``(B, H, W, C)`` | ||
The input vector/image | ||
[Review comment] .
    nout : int
        Number of output channels.
    features : int, optional
        Number of features at the finest resolution.
    keep_prob : float in [0, 1], optional
        Used for dropout.
    use_batch_norm : bool, optional
        Whether batch normalization should be used.
    activation : {'relu', 'elu', 'leaky_relu', 'prelu'}, optional
        Activation function to use.
    is_training : bool or `tf.Tensor` with dtype bool, optional
        Flag indicating if training is currently done.
        Needed for batch normalization.
    init : {'he', 'xavier'}, optional
        Initialization scheme for the weights. Biases are initialized to zero.
    depth : positive int, optional
        Number of downsamplings that should be done.
    name : str, optional
        Name of the created layer.

    Returns
    -------
    unet : `tf.Tensor` with shape ``(B, L, nout)`` or ``(B, H, W, nout)``

    Examples
    --------
    Create a 2d U-net:

    >>> data = np.array([[1, 2, 3],
    ...                  [4, 5, 6],
    ...                  [7, 8, 9]])
    >>> x = tf.constant(data[None, ..., None])  # add empty batch and channel
    >>> y = unet(x, 1)
    >>> y.shape
    TensorShape([Dimension(1), Dimension(3), Dimension(3), Dimension(1)])
    """
    x = tf.cast(x, 'float32')
    ndim = len(x.shape) - 2

    assert depth >= 1

    def get_weight_bias(nin, nout, transpose, size):
[Review comment] Why the …
        if transpose:
            shape = [size] * ndim + [nout, nin]
        else:
            shape = [size] * ndim + [nin, nout]

        b_shape = [1] * (1 + ndim) + [nout]

        if init == 'xavier':
            stddev = np.sqrt(2.6 / (size ** ndim * (nin + nout)))
        elif init == 'he':
            stddev = np.sqrt(2.6 / (size ** ndim * nin))

        w = tf.Variable(tf.truncated_normal(shape, stddev=stddev))
        b = tf.Variable(tf.constant(0.0, shape=b_shape))

        return w, b
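
For comparison, the textbook He and Xavier rules use the constant 2.0 where get_weight_bias uses 2.6; the deviation looks deliberate (perhaps tuned for the leaky/parametric activations), so it is illustrated rather than changed. The helper names below are hypothetical:

import numpy as np

def he_stddev(size, ndim, nin):
    # He et al.: stddev = sqrt(2 / fan_in), with fan_in = size**ndim * nin.
    return np.sqrt(2.0 / (size ** ndim * nin))

def xavier_stddev(size, ndim, nin, nout):
    # Glorot/Xavier: stddev = sqrt(2 / (fan_in + fan_out)).
    return np.sqrt(2.0 / (size ** ndim * (nin + nout)))

print(he_stddev(3, 2, 64))          # ~0.0589
print(xavier_stddev(3, 2, 64, 64))  # ~0.0417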
    def apply_activation(x):
[Review comment] Remove …
        if activation == 'relu':
            return tf.nn.relu(x)
        elif activation == 'elu':
            return tf.nn.elu(x)
        elif activation == 'leaky_relu':
            return leaky_relu(x)
        elif activation == 'prelu':
            return prelu(x)
        else:
            raise RuntimeError('unknown activation')
    def apply_conv(x, nout,
                   stride=False,
                   size=3,
                   disable_batch_norm=False,
                   disable_dropout=False,
                   disable_activation=False):
        # `stride` is a flag here: True means downsample by 2, False means
        # keep the current resolution.
        if stride:
            if ndim == 1:
                stride = 2
            elif ndim == 2:
                stride = (2, 2)
        else:
            if ndim == 1:
                stride = 1
            elif ndim == 2:
                stride = (1, 1)

        with tf.name_scope('apply_conv'):
            nin = int(x.get_shape()[-1])

            w, b = get_weight_bias(nin, nout, transpose=False, size=size)

            if ndim == 1:
                out = conv1d(x, w, stride=stride) + b
            elif ndim == 2:
                out = conv2d(x, w, stride=stride) + b

            if use_batch_norm and not disable_batch_norm:
                out = tf.contrib.layers.batch_norm(out,
                                                   is_training=is_training)
            if keep_prob != 1.0 and not disable_dropout:
                out = tf.contrib.layers.dropout(out, keep_prob=keep_prob,
                                                is_training=is_training)

            if not disable_activation:
                out = apply_activation(out)

            return out
    def apply_convtransp(x, nout,
                         stride=True, out_shape=None,
                         size=2,
                         disable_batch_norm=False,
                         disable_dropout=False,
                         disable_activation=False):
        if stride:
            if ndim == 1:
                stride = 2
            elif ndim == 2:
                stride = (2, 2)
        else:
            if ndim == 1:
                stride = 1
            elif ndim == 2:
                stride = (1, 1)

        with tf.name_scope('apply_convtransp'):
            nin = int(x.get_shape()[-1])

            w, b = get_weight_bias(nin, nout, transpose=True, size=size)

            if ndim == 1:
                out = conv1dtransp(x, w, stride=stride,
                                   out_shape=out_shape) + b
            elif ndim == 2:
                out = conv2dtransp(x, w, stride=stride,
                                   out_shape=out_shape) + b

            if use_batch_norm and not disable_batch_norm:
                out = tf.contrib.layers.batch_norm(out,
                                                   is_training=is_training)
            if keep_prob != 1.0 and not disable_dropout:
                out = tf.contrib.layers.dropout(out, keep_prob=keep_prob,
                                                is_training=is_training)

            if not disable_activation:
                out = apply_activation(out)

            return out
    def apply_maxpool(x):
        if ndim == 1:
            return maxpool1d(x)
        else:
            return maxpool2d(x)
    finals = []

    with tf.name_scope('{}_call'.format(name)):
        with tf.name_scope('in'):
            current = apply_conv(x, features)
            current = apply_conv(current, features)
            finals.append(current)
        for layer in range(depth - 1):
            with tf.name_scope('down_{}'.format(layer + 1)):
                # Double the feature count at each resolution level.
                features_layer = features * 2 ** (layer + 1)
                current = apply_maxpool(current)
                current = apply_conv(current, features_layer)
                current = apply_conv(current, features_layer)
                finals.append(current)
        with tf.name_scope('coarse'):
            current = apply_maxpool(current)
            current = apply_conv(current, features * 2 ** depth)
            current = apply_conv(current, features * 2 ** depth)
        for layer in reversed(range(depth - 1)):
            with tf.name_scope('up_{}'.format(layer + 1)):
                features_layer = features * 2 ** (layer + 1)
                skip = finals.pop()
                current = apply_convtransp(current, features_layer,
                                           out_shape=tf.shape(skip),
                                           disable_activation=True)
                current = tf.concat([current, skip], axis=-1)

                current = apply_conv(current, features_layer)
                current = apply_conv(current, features_layer)
        with tf.name_scope('out'):
            skip = finals.pop()
            current = apply_convtransp(current, features,
                                       out_shape=tf.shape(skip),
                                       disable_activation=True)
            current = tf.concat([current, skip], axis=-1)

            current = apply_conv(current, features)
            current = apply_conv(current, features)

            current = apply_conv(current, nout,
                                 size=1,
                                 disable_activation=True,
                                 disable_batch_norm=True,
                                 disable_dropout=True)

        return current

if __name__ == '__main__':
    from dipp.util.testutils import run_doctests
    with tf.Session():
        run_doctests()
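
Finally, an end-to-end sketch of the new unet (not part of the PR; the 128x128 single-channel shapes are illustrative):

import numpy as np
import tensorflow as tf
from dipp.tensorflow.unet import unet

x = tf.placeholder(tf.float32, shape=[None, 128, 128, 1])
is_training = tf.placeholder(tf.bool)

# Default 4-level U-net with 64 features at the finest scale.
y = unet(x, nout=1, is_training=is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(2, 128, 128, 1).astype('float32')
    print(sess.run(y, feed_dict={x: batch, is_training: False}).shape)
    # (2, 128, 128, 1)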
[Review comment] You may want to add …
[Review comment] Established name? Reason for the underscore(s) in _x? Doc?
[Reply] The name is established, fixing x.