Adding common loss Modules
- MeanSquaredError
- MeanAbsoluteError
- BinaryCrossEntropyLoss

Added typedefs for some alternative names.
pavanky committed Jul 24, 2017
1 parent ecc2bad commit 7a64564
Showing 11 changed files with 147 additions and 9 deletions.
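
A minimal usage sketch of the new modules (the module and typedef names come from the diff below; the data, shapes, and setup here are hypothetical, not code from this commit):

    #include <af/autograd.h>
    #include <af/nn.h>

    int main()
    {
        using namespace af;

        // Hypothetical data: ten predictions in [0, 1] against all-ones targets
        auto preds   = nn::parameter(af::randu(10, 1));
        auto targets = nn::noGrad(af::constant(1.0, 10, 1));

        nn::MeanSquaredError mse;        // typedef: nn::MSE
        nn::MeanAbsoluteError mae;       // typedefs: nn::MAE, nn::L1Loss
        nn::BinaryCrossEntropyLoss bce;  // typedef: nn::BCELoss

        auto l = mse.forward(preds, targets);
        l.backward();                    // the new zero-argument overload
        af_print(preds.grad().array());  // d(loss)/d(preds)
        return 0;
    }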
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -15,6 +15,7 @@ target_sources(afml
src/nn/Modules/Activations.cpp
src/nn/Modules/Container.cpp
src/nn/Modules/Linear.cpp
src/nn/Modules/Loss.cpp
src/nn/Modules/Module.cpp
src/nn/Init.cpp
)
1 change: 1 addition & 0 deletions examples/autograd.cpp
@@ -8,6 +8,7 @@
********************************************************/

#include <af/autograd.h>
#include <af/nn.h>

#define VERIFY(VAL) do { \
auto res = af::allTrue<bool>(af::abs(VAL) < 1E-5); \
12 changes: 6 additions & 6 deletions examples/perceptron.cpp
@@ -39,7 +39,9 @@ int main()
perceptron.add(nn::Linear(inputSize, outputSize));
perceptron.add(nn::Sigmoid());

Variable result;
auto loss = nn::MeanSquaredError();

Variable result, l;
for (int i = 0; i < 1000; i++) {
for (int j = 0; j < numSamples; j++) {
perceptron.train();
@@ -52,17 +54,15 @@ int main()
result = perceptron.forward(nn::input(in_j));

// Calculate loss
// TODO: Use loss function
af::array diff = out_j - result.array();
l = loss.forward(result, nn::noGrad(out_j));

// Backward propagation
auto d_result = Variable(diff, false);
result.backward(d_result);
l.backward();

// Update parameters
// TODO: Should use optimizer
for (auto &param : perceptron.parameters()) {
param.array() += lr * param.grad().array();
param.array() -= lr * param.grad().array();
param.array().eval();
}
}
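
Note on the flipped sign in the parameter update: the old code seeded backward with diff = out_j - result.array(), which is the negative of d(MSE)/d(result) up to a constant factor, so ascending with += still reduced the loss. Backpropagating the loss itself yields the true gradient, so the step must now descend with -=, i.e. plain gradient descent (the TODO about using a proper optimizer still stands).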
1 change: 1 addition & 0 deletions include/af/autograd/Variable.hpp
@@ -72,6 +72,7 @@ namespace af {

void backward(const Variable &grad, bool retain_grad_graph = false);

void backward(bool retain_grad_graph = false);

private:
void evalGrad(bool retain_grad_graph = false);
2 changes: 2 additions & 0 deletions include/af/nn/Init.hpp
@@ -15,6 +15,8 @@ namespace af {

autograd::Variable input(const af::array &arr);

autograd::Variable noGrad(const af::array &arr);

autograd::Variable parameter(const af::array &arr);

autograd::Variable uniform(int input_size, int output_size,
1 change: 1 addition & 0 deletions include/af/nn/Modules.hpp
@@ -12,3 +12,4 @@
#include <af/nn/Modules/Linear.hpp>
#include <af/nn/Modules/Container.hpp>
#include <af/nn/Modules/Activations.hpp>
#include <af/nn/Modules/Loss.hpp>
4 changes: 1 addition & 3 deletions include/af/nn/Modules/Container.hpp
@@ -41,8 +41,6 @@ namespace af
ModulePtr get(int id);

std::vector<ModulePtr> modules();

virtual autograd::Variable forward(const autograd::Variable &input) = 0;
};

class Sequential : public Container
@@ -51,7 +49,7 @@ namespace af

Sequential();

virtual autograd::Variable forward(const autograd::Variable &input);
autograd::Variable forward(const autograd::Variable &input);
};
}
}
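
With the redundant redeclaration removed, Container simply inherits the forward interface from Module, and Sequential's override is implicitly virtual through the base class. The same Module interface is what lets Loss (below) pair a throwing one-argument forward with its real two-argument one.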
64 changes: 64 additions & 0 deletions include/af/nn/Modules/Loss.hpp
@@ -0,0 +1,64 @@
/*******************************************************
* Copyright (c) 2017, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#pragma once

#include <af/nn/Modules/Module.hpp>

namespace af
{
namespace nn
{
class Loss : public Module
{
public:
Loss() {}

virtual autograd::Variable forward(const autograd::Variable &inputs,
const autograd::Variable &targets) = 0;

autograd::Variable forward(const autograd::Variable &inputs);
};

class MeanSquaredError : public Loss
{
public:
MeanSquaredError() {}

autograd::Variable forward(const autograd::Variable &inputs,
const autograd::Variable &targets);
};

class MeanAbsoluteError : public Loss
{
public:
MeanAbsoluteError() {}

autograd::Variable forward(const autograd::Variable &inputs,
const autograd::Variable &targets);
};

class BinaryCrossEntropyLoss : public Loss
{
public:
BinaryCrossEntropyLoss() {}

autograd::Variable forward(const autograd::Variable &inputs,
const autograd::Variable &targets);

autograd::Variable forward(const autograd::Variable &inputs,
const autograd::Variable &targets,
const autograd::Variable &weights);
};

typedef MeanSquaredError MSE;
typedef MeanAbsoluteError MAE;
typedef MeanAbsoluteError L1Loss;
typedef BinaryCrossEntropyLoss BCELoss;
}
}
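
The two-argument forward is the customization point; the inherited one-argument form is implemented to throw (see Loss.cpp below). A sketch of adding a further loss on the same interface (hypothetical class, not part of this commit; it uses only the mean/flat primitives the diff itself uses):

    namespace af
    {
    namespace nn
    {
    class MeanBiasError : public Loss
    {
    public:
        MeanBiasError() {}

        autograd::Variable forward(const autograd::Variable &inputs,
                                   const autograd::Variable &targets)
        {
            // Signed average residual; positive means over-prediction
            return autograd::mean(autograd::flat(inputs - targets), {0});
        }
    };
    }
    }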
6 changes: 6 additions & 0 deletions src/autograd/Variable.cpp
@@ -172,6 +172,12 @@ namespace af {
}
}

void Variable::backward(bool retain_grad_graph)
{
auto ones = Variable(af::constant(1, this->dims()), false);
this->backward(ones, retain_grad_graph);
}

Variable::DAG_t Variable::build(const Variable &var)
{
Cache_t cache;
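A minimal sketch of what the new overload does (hypothetical values; it simply seeds the backward pass with ones of the output's shape):

    auto x = nn::parameter(af::constant(3.0, 1));
    auto y = x * x;   // dy/dx = 2x
    y.backward();     // same as y.backward(Variable(af::constant(1, y.dims()), false))
    // x.grad().array() now holds 6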
5 changes: 5 additions & 0 deletions src/nn/Init.cpp
@@ -21,6 +21,11 @@ namespace af {
return Variable(arr, false);
}

Variable noGrad(const af::array &arr)
{
return Variable(arr, false);
}

Variable parameter(const af::array &arr)
{
return Variable(arr, true);
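Note that noGrad is implemented identically to input (both wrap the array in a Variable with the gradient flag off); the distinct name documents intent at call sites such as loss targets, as in the perceptron example above.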
59 changes: 59 additions & 0 deletions src/nn/Modules/Loss.cpp
@@ -0,0 +1,59 @@
/*******************************************************
* Copyright (c) 2017, ArrayFire
* All rights reserved.
*
* This file is distributed under 3-clause BSD license.
* The complete license agreement can be obtained at:
* http://arrayfire.com/licenses/BSD-3-Clause
********************************************************/
#include <af/autograd/Functions.hpp>
#include <af/nn/Modules/Loss.hpp>


namespace af
{
namespace nn
{
using namespace autograd;

autograd::Variable Loss::forward(const autograd::Variable &inputs)
{
throw af::exception("Loss module requires both inputs and targets");
}

autograd::Variable MeanSquaredError::forward(const autograd::Variable &inputs,
const autograd::Variable &targets)
{
auto df = inputs - targets;
auto res = mean(flat(df * df), {0});
return res;
}

autograd::Variable MeanAbsoluteError::forward(const autograd::Variable &inputs,
const autograd::Variable &targets)
{
auto df = inputs - targets;
auto res = mean(flat(abs(df)), {0});
return res;
}

static autograd::Variable
binaryCrossEntropy(const autograd::Variable &inputs,
const autograd::Variable &targets)
{
// Elementwise binary cross entropy: -(t*log(p) + (1-t)*log(1-p)).
// Assumes log() is available from the autograd Functions header; 0 - (...)
// reuses the scalar-Variable operators this file already relies on, e.g. (1 - targets).
return 0 - (targets * log(inputs) + (1 - targets) * log(1 - inputs));
}

autograd::Variable BinaryCrossEntropyLoss::forward(const autograd::Variable &inputs,
const autograd::Variable &targets)
{
return mean(flat(binaryCrossEntropy(inputs, targets)), {0});
}

autograd::Variable BinaryCrossEntropyLoss::forward(const autograd::Variable &inputs,
const autograd::Variable &targets,
const autograd::Variable &weights)
{
return mean(flat(weights * binaryCrossEntropy(inputs, targets)), {0});
}
}
}
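
A quick numeric sanity check for binaryCrossEntropy (hypothetical values, not part of the commit): a prediction of 0.5 against a target of 1 gives -(1*log(0.5) + 0*log(0.5)) = log(2) ≈ 0.6931.

    auto p = nn::noGrad(af::constant(0.5, 1));
    auto t = nn::noGrad(af::constant(1.0, 1));
    nn::BCELoss bce;
    af_print(bce.forward(p, t).array()); // ~0.6931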
