Improvements to apex.mlp (#804)
* update fused bias relu backward kernel

* adding support for not require first layer dgrad

* fix bug: wrong layer in requires grad

* add infrastructure for optional bias and activation, currently only support no bias and no relu

* make bias and relu optional separately

* add sigmoid activation option
FDecaYed authored Apr 30, 2020
1 parent aad9300 commit 31aceea
Showing 4 changed files with 772 additions and 135 deletions.
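For orientation, here is a minimal usage sketch of the options this commit adds. It assumes MLP is re-exported from apex.mlp and a CUDA build of the mlp_cuda extension; the layer sizes and tensor shapes are illustrative only, not taken from the diff.

```python
import torch
from apex.mlp import MLP  # assumed import path; the class lives in apex/mlp/mlp.py

# 480 -> 1024 -> 1024 -> 512 MLP using the new options: the bias parameters can now
# be dropped entirely, and the activation can be 'none', 'relu', or 'sigmoid'.
mlp = MLP([480, 1024, 1024, 512], bias=False, activation='sigmoid').cuda()

x = torch.randn(64, 480, device='cuda', requires_grad=True)
y = mlp(x)          # one fused extension call instead of per-layer addmm + activation
y.sum().backward()  # gradients come back through MlpFunction.backward
```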
apex/mlp/mlp.py: 28 additions & 19 deletions
@@ -7,17 +7,19 @@

 class MlpFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, *args):
-        output = mlp_cuda.forward(args)
+    def forward(ctx, bias, activation, *args):
+        output = mlp_cuda.forward(bias, activation, args)
         ctx.save_for_backward(*args)
         ctx.outputs = output
+        ctx.bias = bias
+        ctx.activation = activation
         return output[0]

     @staticmethod
     def backward(ctx, grad_o):
-        grads = mlp_cuda.backward(grad_o, ctx.outputs, ctx.saved_tensors)
+        grads = mlp_cuda.backward(ctx.bias, ctx.activation, grad_o, ctx.outputs, ctx.saved_tensors)
         del ctx.outputs
-        return tuple(grads)
+        return (None, None, *grads)

 mlp_function = amp.half_function(MlpFunction.apply)

@@ -29,27 +31,33 @@ class MLP(torch.nn.Module):
         bias (bool): Default True:
         relu (bool): Default True
     """
-    def __init__(self, mlp_sizes, bias=True, relu=True):
-        if not (bias and relu):
-            raise TypeError("bias and relu must be both true.")
+    def __init__(self, mlp_sizes, bias=True, activation='relu'):
         super(MLP, self).__init__()
         self.num_layers = len(mlp_sizes) - 1
         self.mlp_sizes = copy(mlp_sizes)
-        self.bias = bias
-        self.relu= relu
+        self.bias = 1 if bias else 0
+
+        if activation is 'none':
+            self.activation = 0
+        elif activation is 'relu':
+            self.activation = 1
+        elif activation is 'sigmoid':
+            self.activation = 2
+        else:
+            raise TypeError("activation must be relu or none.")

-        # ignoring bias = False now
         self.weights = []
         self.biases = []
         for i in range(self.num_layers):
             w = torch.nn.Parameter(torch.empty(mlp_sizes[i+1], mlp_sizes[i]))
             self.weights.append(w)
             name = 'weight_{}'.format(i)
             setattr(self, name, w)
-            b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))
-            self.biases.append(b)
-            name = 'bias_{}'.format(i)
-            setattr(self, name, b)
+            if self.bias:
+                b = torch.nn.Parameter(torch.empty(mlp_sizes[i+1]))
+                self.biases.append(b)
+                name = 'bias_{}'.format(i)
+                setattr(self, name, b)

         self.reset_parameters()

@@ -58,13 +66,14 @@ def reset_parameters(self):
             dimsum = weight.size(0) + weight.size(1)
             std = math.sqrt(2. / float(dimsum))
             nn.init.normal_(weight, 0., std)
-        for bias in self.biases:
-            std = math.sqrt(1. / float(bias.size(0)))
-            nn.init.normal_(bias, 0., std)
+        if self.bias:
+            for bias in self.biases:
+                std = math.sqrt(1. / float(bias.size(0)))
+                nn.init.normal_(bias, 0., std)

     def forward(self, input):
-        return mlp_function(input, *self.weights, *self.biases)
+        return mlp_function(self.bias, self.activation, input, *self.weights, *self.biases)

     def extra_repr(self):
-        s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, ReLU={self.relu}"
+        s = F"MLP sizes: {self.mlp_sizes}, Bias={self.bias}, activation={self.activation}"
         return s
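To make the new calling convention concrete: mlp_function receives the integer-encoded bias and activation flags, then the input, then all weights, then (only when bias is enabled) all biases, which is the flattened layout the C++ side unpacks below. The following pure-PyTorch sketch is an illustrative equivalent, not the extension itself; in particular, it applies the activation after every layer, which should be checked against the kernel's actual behavior.

```python
import torch
import torch.nn.functional as F

def mlp_reference(bias, activation, input, *weights_and_biases):
    """Rough equivalent of mlp_function(bias, activation, input, *weights, *biases).

    Uses the same integer encoding as MLP.__init__:
    bias: 1 or 0; activation: 0 = none, 1 = relu, 2 = sigmoid.
    """
    num_layers = len(weights_and_biases) // 2 if bias else len(weights_and_biases)
    weights = weights_and_biases[:num_layers]
    biases = weights_and_biases[num_layers:] if bias else (None,) * num_layers

    out = input
    for w, b in zip(weights, biases):
        out = F.linear(out, w, b)  # weight shape: (out_features, in_features)
        if activation == 1:
            out = torch.relu(out)
        elif activation == 2:
            out = torch.sigmoid(out)
    return out
```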
csrc/mlp.cpp: 41 additions & 16 deletions
@@ -19,7 +19,9 @@ int mlp_fp(
     int* output_features,
     T** BPtr,
     T* Y,
-    T* reserved_space);
+    T* reserved_space,
+    int use_bias,
+    int activation);

 template <typename T>
 int mlp_bp(
@@ -35,11 +37,18 @@ int mlp_bp(
     T* work_space,
     T* dX,
     T** dwPtr,
-    T** dbPtr);
+    T** dbPtr,
+    bool requires_grad,
+    int use_bias,
+    int activation);

-std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
-  // inputs contains (input, weights, biases)
-  auto num_layers = (inputs.size() - 1) / 2;
+std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
+
+  auto num_layers = inputs.size() - 1;
+  if (use_bias) {
+    // inputs contains (input, weights, biases)
+    num_layers /= 2;
+  }
   auto batch_size = inputs[0].size(0);
   auto input_features = inputs[0].size(1);

@@ -60,7 +69,9 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
     std::vector<scalar_t*> b_ptr;
     for (int i = 0; i < num_layers; i++) {
       w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
-      b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
+      if (use_bias) {
+        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
+      }
     }
     auto result = mlp_fp<scalar_t>(
         inputs[0].data_ptr<scalar_t>(),
@@ -71,37 +82,48 @@ std::vector<at::Tensor> mlp_forward(std::vector<at::Tensor> inputs) {
         output_features.data(),
         b_ptr.data(),
         out.data_ptr<scalar_t>(),
-        reserved_space.data_ptr<scalar_t>());
+        reserved_space.data_ptr<scalar_t>(),
+        use_bias,
+        activation);
   });

   return {out, reserved_space};
 }

 std::vector<at::Tensor> mlp_backward(
-    at::Tensor grad_o,
-    std::vector<at::Tensor> fprop_outputs,
-    std::vector<at::Tensor> inputs) {
-  // same code to get sizes and W pointers
-  auto num_layers = (inputs.size() - 1) / 2;
+    int use_bias,
+    int activation,
+    at::Tensor grad_o,
+    std::vector<at::Tensor> fprop_outputs,
+    std::vector<at::Tensor> inputs) {
+
+  auto num_layers = inputs.size() - 1;
+  if (use_bias) {
+    // inputs contains (input, weights, biases)
+    num_layers /= 2;
+  }
+
   auto batch_size = inputs[0].size(0);
   auto input_features = inputs[0].size(1);

+  // TODO: not creating empty tensor for it?
+  bool requires_grad = inputs[0].requires_grad();
+
   std::vector<int> output_features;
   for (int i = 0; i < num_layers; i++) {
     output_features.push_back(inputs[i + 1].size(0));
   }
   // create outputs, length of inputs
+  // TODO: not create bias if not needed
   std::vector<at::Tensor> outputs;
   for (int i = 0; i < inputs.size(); i++) {
     outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));  // clone for testing now
   }

-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
     std::vector<scalar_t*> w_ptr;
-    std::vector<scalar_t*> b_ptr;
     for (int i = 0; i < num_layers; i++) {
       w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
-      b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
     }
     std::vector<scalar_t*> outputs_ptr;
     for (int i = 0; i < inputs.size(); i++) {
@@ -127,7 +149,10 @@ std::vector<at::Tensor> mlp_backward(
         work_space.data_ptr<scalar_t>(),
         outputs_ptr[0],
         outputs_ptr.data() + 1,
-        outputs_ptr.data() + 1 + num_layers);
+        outputs_ptr.data() + 1 + num_layers,
+        requires_grad,
+        use_bias,
+        activation);
   });

   return outputs;
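The requires_grad flag threaded into mlp_bp here corresponds to the commit message "adding support for not require first layer dgrad": when the network input does not need a gradient, the first layer's data gradient can be skipped. From the Python side the effect looks roughly like the snippet below, reusing the hypothetical mlp module from the first sketch.

```python
# Input that does not require grad, e.g. raw data at the bottom of the network.
x = torch.randn(64, 480, device='cuda')  # requires_grad defaults to False
y = mlp(x)
y.sum().backward()

assert x.grad is None                    # no data gradient was requested for the input
assert mlp.weights[0].grad is not None   # weight gradients are still produced
```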