Skip to content

Commit

Permalink
Adding necessary functions for Loss modules.
Browse files Browse the repository at this point in the history
- log, flat, moddims
  • Loading branch information
pavanky authored and umar456 committed Jul 27, 2017
1 parent 1288a3e commit dd40725
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
5 changes: 5 additions & 0 deletions include/af/autograd/Functions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
********************************************************/
#pragma once

#include <arrayfire.h>
#include <vector>

namespace af {
Expand Down Expand Up @@ -48,6 +49,7 @@ namespace af {
Variable reciprocal(const Variable &input);

Variable exp(const Variable &input);
Variable log(const Variable &input);
Variable sin(const Variable &input);
Variable cos(const Variable &input);
Variable tanh(const Variable &input);
Expand All @@ -73,6 +75,9 @@ namespace af {
Variable matmulTN(const Variable &lhs, const Variable &rhs);
Variable matmulNT(const Variable &lhs, const Variable &rhs);

Variable abs(const Variable &input);

Variable flat(const Variable &input);
Variable moddims(const Variable &input, const dim4 &dims);
}
}
39 changes: 39 additions & 0 deletions src/autograd/Functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,15 @@ namespace af {
return Variable(result, {input}, grad_func);
}

Variable log(const Variable &input)
{
auto result = log(input.array());
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
inputs[0].addGrad(grad_output / inputs[0]);
};
return Variable(result, {input}, grad_func);
}

Variable sin(const Variable &input)
{
auto result = sin(input.array());
Expand Down Expand Up @@ -375,5 +384,35 @@ namespace af {
};
return Variable(result, {lhs, rhs}, grad_func);
}

Variable abs(const Variable &input)
{
auto result = af::abs(input.array());
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
// af::sign returns signbit
// Convert it into -1, 1
auto sign = Variable(1 - 2 * af::sign(inputs[0].array()), false);
inputs[0].addGrad(sign * grad_output);
};
return Variable(result, {input}, grad_func);
}

Variable flat(const Variable &input)
{
auto result = af::flat(input.array());
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
inputs[0].addGrad(moddims(grad_output, inputs[0].dims()));
};
return Variable(result, {input}, grad_func);
}

Variable moddims(const Variable &input, const dim4 &dims)
{
auto result = af::moddims(input.array(), dims);
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {
inputs[0].addGrad(moddims(grad_output, inputs[0].dims()));
};
return Variable(result, {input}, grad_func);
}
}
}

0 comments on commit dd40725

Please sign in to comment.