
implement LOG1P and MATRIX_LOG1P in BMG #1638

Open · wants to merge 1 commit into main
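
For orientation, a minimal usage sketch of the two new operators, pieced together from the tests in this diff. The include paths and the standalone `main` wrapper are assumptions for illustration, not part of the change:

```cpp
#include <beanmachine/graph/graph.h>  // assumed include path

#include <Eigen/Dense>

using namespace beanmachine::graph;

int main() {
  Graph g;

  // Scalar LOG1P: log1p(x) = log(1 + x), for a real or pos_real parent.
  auto x = g.add_constant_pos_real(0.25);
  auto log1p_x = g.add_operator(OperatorType::LOG1P, {x});

  // MATRIX_LOG1P: elementwise log1p for a pos_real or probability matrix.
  Eigen::MatrixXd m(3, 1);
  m << 2.0, 1.0, 3.0;
  auto m_const = g.add_constant_pos_matrix(m);
  auto log1p_m = g.add_operator(OperatorType::MATRIX_LOG1P, {m_const});

  g.query(log1p_x);
  g.query(log1p_m);
  return 0;
}
```
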
2 changes: 2 additions & 0 deletions src/beanmachine/graph/graph.h
@@ -344,6 +344,8 @@ enum class OperatorType {
LOG_PROB,
MATRIX_SUM,
MATRIX_LOG,
LOG1P,
MATRIX_LOG1P,
};

enum class DistributionType {
16 changes: 16 additions & 0 deletions src/beanmachine/graph/operator/backward.cpp
@@ -369,5 +369,21 @@ void MatrixSum::backward() {
}
}

void Log1p::backward() {
assert(in_nodes.size() == 1);
if (in_nodes[0]->needs_gradient()) {
double jacob = 1 / (1 + in_nodes[0]->value._double);
in_nodes[0]->back_grad1 += back_grad1 * jacob;
}
}

void MatrixLog1p::backward() {
assert(in_nodes.size() == 1);
if (in_nodes[0]->needs_gradient()) {
in_nodes[0]->back_grad1 += back_grad1.as_matrix().cwiseQuotient(
(in_nodes[0]->value._matrix.array() + 1).matrix());
}
}

} // namespace oper
} // namespace beanmachine
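
For y = log1p(x) the Jacobian is dy/dx = 1 / (1 + x), which is exactly what both backward passes above accumulate (elementwise in the matrix case). A quick standalone finite-difference check of that derivative; plain C++, not part of the diff:

```cpp
#include <cassert>
#include <cmath>

int main() {
  double x = 0.5;
  double eps = 1e-6;
  // Analytic Jacobian used in Log1p::backward().
  double jacob = 1.0 / (1.0 + x);
  // Central finite difference of std::log1p around x.
  double fd = (std::log1p(x + eps) - std::log1p(x - eps)) / (2 * eps);
  assert(std::fabs(jacob - fd) < 1e-6);  // both ~0.666667
  return 0;
}
```
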
27 changes: 27 additions & 0 deletions src/beanmachine/graph/operator/gradient.cpp
@@ -600,5 +600,32 @@ void MatrixSum::compute_gradients() {
grad2 = in_nodes[0]->Grad2.sum();
}

void Log1p::compute_gradients() {
assert(in_nodes.size() == 1);
// f(x) = log(g(x) + 1)
// f'(x) = g'(x) / (g(x) + 1)
// f''(x) = ((g(x) + 1) g''(x) - g'(x)^2)/(g(x) + 1)^2
auto g = in_nodes[0]->value._double;
auto gp1 = g + 1;
auto g1 = in_nodes[0]->grad1;
auto g2 = in_nodes[0]->grad2;
grad1 = g1 / gp1;
grad2 = (gp1 * g2 - g1 * g1) / (gp1 * gp1);
}

void MatrixLog1p::compute_gradients() {
assert(in_nodes.size() == 1);
// f(x) = log(g(x) + 1)
// f'(x) = g'(x) / (g(x) + 1)
// f''(x) = ((g(x) + 1) g''(x) - g'(x)^2)/(g(x) + 1)^2
auto g = in_nodes[0]->value._matrix;
auto gp1 = (g.array() + 1).matrix();
auto g1 = in_nodes[0]->Grad1;
auto g2 = in_nodes[0]->Grad2;
Grad1 = g1.cwiseQuotient(gp1);
Grad2 = (gp1.cwiseProduct(g2) - g1.cwiseProduct(g1))
.cwiseQuotient(gp1.cwiseProduct(gp1));
}

} // namespace oper
} // namespace beanmachine
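
To make the comment formulas concrete, take g(x) = x^2 at x = 0.5, the same composition exercised by the LOG1P operator test later in this diff: g = 0.25, g' = 1, g'' = 2, so f' = 1 / 1.25 = 0.8 and f'' = (1.25 * 2 - 1) / 1.25^2 = 0.96. A standalone check of the same arithmetic, not part of the diff:

```cpp
#include <cassert>
#include <cmath>

int main() {
  // f(x) = log1p(g(x)) with g(x) = x^2, evaluated at x = 0.5.
  double x = 0.5;
  double g = x * x;    // 0.25
  double g1 = 2 * x;   // g'(x)  = 1.0
  double g2 = 2.0;     // g''(x) = 2.0
  double gp1 = g + 1;  // 1.25

  // Same formulas as Log1p::compute_gradients().
  double grad1 = g1 / gp1;                            // 0.8
  double grad2 = (gp1 * g2 - g1 * g1) / (gp1 * gp1);  // 0.96

  // Cross-check grad1 against a central finite difference of log1p(x^2).
  double eps = 1e-6;
  double fd = (std::log1p((x + eps) * (x + eps)) -
               std::log1p((x - eps) * (x - eps))) /
      (2 * eps);
  assert(std::fabs(grad1 - fd) < 1e-6);
  assert(std::fabs(grad1 - 0.8) < 1e-9 && std::fabs(grad2 - 0.96) < 1e-9);
  return 0;
}
```
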
29 changes: 29 additions & 0 deletions src/beanmachine/graph/operator/linalgop.cpp
@@ -504,5 +504,34 @@ void MatrixLog::eval(std::mt19937& /* gen */) {
value._matrix = Eigen::log(in_nodes[0]->value._matrix.array());
}

MatrixLog1p::MatrixLog1p(const std::vector<graph::Node*>& in_nodes)
: Operator(graph::OperatorType::MATRIX_LOG1P) {
if (in_nodes.size() != 1) {
throw std::invalid_argument("MATRIX_LOG1P requires one parent node");
}
auto type = in_nodes[0]->value.type;
if (type.variable_type != graph::VariableType::BROADCAST_MATRIX) {
throw std::invalid_argument(
"the parent of MATRIX_LOG1P must be a BROADCAST_MATRIX");
}
auto atomic_type = type.atomic_type;
// log1p(x) = log(1 + x) is non-negative for any non-negative input, so both
// pos_real and probability parents yield a pos_real result (unlike MATRIX_LOG,
// where a probability parent maps to neg_real).
graph::AtomicType new_type;
if (atomic_type == graph::AtomicType::POS_REAL ||
atomic_type == graph::AtomicType::PROBABILITY) {
new_type = graph::AtomicType::POS_REAL;
} else {
throw std::invalid_argument(
"operator MATRIX_LOG1P requires a probability or pos_real parent");
}
value = graph::NodeValue(graph::ValueType(
graph::VariableType::BROADCAST_MATRIX, new_type, type.rows, type.cols));
}

void MatrixLog1p::eval(std::mt19937& /* gen */) {
assert(in_nodes.size() == 1);
value._matrix = Eigen::log1p(in_nodes[0]->value._matrix.array());
}

} // namespace oper
} // namespace beanmachine
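
MatrixLog1p::eval applies log1p elementwise through Eigen's array interface. A standalone sketch of the same computation on the 3x1 matrix used in the MATRIX_LOG1P operator test below, where the expected values are log1p(2) ≈ 1.0986, log1p(1) ≈ 0.6931, log1p(3) ≈ 1.3863; illustrative only:

```cpp
#include <Eigen/Dense>
#include <iostream>

int main() {
  Eigen::MatrixXd m(3, 1);
  m << 2.0, 1.0, 3.0;

  // Elementwise log1p, as in MatrixLog1p::eval().
  Eigen::MatrixXd out(3, 1);
  out = Eigen::log1p(m.array());

  std::cout << out << std::endl;  // ~1.0986, 0.6931, 1.3863
  return 0;
}
```
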
15 changes: 15 additions & 0 deletions src/beanmachine/graph/operator/linalgop.h
@@ -198,5 +198,20 @@ class MatrixLog : public Operator {
}
};

class MatrixLog1p : public Operator {
public:
explicit MatrixLog1p(const std::vector<graph::Node*>& in_nodes);
~MatrixLog1p() override {}

void eval(std::mt19937& gen) override;
void backward() override;
void compute_gradients() override;

static std::unique_ptr<Operator> new_op(
const std::vector<graph::Node*>& in_nodes) {
return std::make_unique<MatrixLog1p>(in_nodes);
}
};

} // namespace oper
} // namespace beanmachine
8 changes: 8 additions & 0 deletions src/beanmachine/graph/operator/register.cpp
@@ -107,6 +107,10 @@ bool ::beanmachine::oper::OperatorFactory::factories_are_registered =

OperatorFactory::register_op(graph::OperatorType::LOG, &(Log::new_op)) &&

OperatorFactory::register_op(
graph::OperatorType::LOG1P,
&(Log1p::new_op)) &&

// linear algebra op
OperatorFactory::register_op(
graph::OperatorType::TRANSPOSE,
@@ -140,6 +144,10 @@ bool ::beanmachine::oper::OperatorFactory::factories_are_registered =
graph::OperatorType::MATRIX_LOG,
&(MatrixLog::new_op)) &&

OperatorFactory::register_op(
graph::OperatorType::MATRIX_LOG1P,
&(MatrixLog1p::new_op)) &&

// matrix index
OperatorFactory::register_op(
graph::OperatorType::INDEX,
64 changes: 63 additions & 1 deletion src/beanmachine/graph/operator/tests/gradient_test.cpp
@@ -1126,7 +1126,7 @@ TEST(testgradient, matrix_log_grad_backward) {
# backwards gradients as PyTorch.

import torch
- hn = torch.distributions.HalfNormal(0, 1)
+ hn = torch.distributions.HalfNormal(1)
s0 = torch.tensor(1.0, requires_grad=True)
s1 = torch.tensor(0.5, requires_grad=True)
mlog0 = s0.log()
@@ -1564,3 +1564,65 @@ TEST(testgradient, matrix_sum) {
EXPECT_NEAR((*grad1[0]), 2.0, 1e-3);
EXPECT_NEAR((*grad1[1]), -1.25, 1e-3);
}

TEST(testgradient, matrix_log1p_grad_backward) {
/*
# Test backward differentiation
#
# Build the same model in PyTorch and BMG; we should get the same
# backwards gradients as PyTorch.

import torch
hn = torch.distributions.HalfNormal(1)
s0 = torch.tensor(1.0, requires_grad=True)
s1 = torch.tensor(0.5, requires_grad=True)
mlog0 = s0.log1p()
mlog1 = s1.log1p()
n0 = torch.distributions.Normal(mlog0, 1.0)
n1 = torch.distributions.Normal(mlog1, 1.0)
sn0 = torch.tensor(2.5, requires_grad=True)
sn1 = torch.tensor(1.5, requires_grad=True)
log_prob = (n0.log_prob(sn0) + n1.log_prob(sn1) +
hn.log_prob(s0) + hn.log_prob(s1))
print(torch.autograd.grad(log_prob, s0, retain_graph=True)) # -0.0966
print(torch.autograd.grad(log_prob, s1, retain_graph=True)) # 0.2297
print(torch.autograd.grad(log_prob, sn0, retain_graph=True)) # -1.8069
print(torch.autograd.grad(log_prob, sn1, retain_graph=True)) # -1.0945
*/

Graph g;
auto one = g.add_constant_pos_real(1.0);
auto hn = g.add_distribution(
DistributionType::HALF_NORMAL, AtomicType::POS_REAL, {one});
auto two = g.add_constant((natural_t)2);
auto hn_sample =
g.add_operator(OperatorType::IID_SAMPLE, std::vector<uint>{hn, two});
Eigen::MatrixXd hn_observed(2, 1);
hn_observed << 1.0, 0.5;
g.observe(hn_sample, hn_observed);

auto mlog_pos = g.add_operator(OperatorType::MATRIX_LOG1P, {hn_sample});
auto mlog = g.add_operator(OperatorType::TO_REAL_MATRIX, {mlog_pos});
auto index_zero = g.add_constant((natural_t)0);
auto mlog0 = g.add_operator(OperatorType::INDEX, {mlog, index_zero});
auto index_one = g.add_constant((natural_t)1);
auto mlog1 = g.add_operator(OperatorType::INDEX, {mlog, index_one});

auto n0 = g.add_distribution(
DistributionType::NORMAL, AtomicType::REAL, {mlog0, one});
auto n1 = g.add_distribution(
DistributionType::NORMAL, AtomicType::REAL, {mlog1, one});

auto ns0 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n0});
g.observe(ns0, 2.5);
auto ns1 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n1});
g.observe(ns1, 1.5);

std::vector<DoubleMatrix*> grad1;
g.eval_and_grad(grad1);
EXPECT_EQ(grad1.size(), 3);
EXPECT_NEAR((*grad1[0])(0), -0.0966, 1e-3);
EXPECT_NEAR((*grad1[0])(1), 0.2297, 1e-3);
EXPECT_NEAR((*grad1[1]), -1.8069, 1e-3);
EXPECT_NEAR((*grad1[2]), -1.0945, 1e-3);
}
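
The expected value for s0 can also be reproduced by hand: the Normal term contributes (sn0 - log1p(s0)) / (1 + s0) via the chain rule, and HalfNormal(1) contributes -s0, so the gradient is (2.5 - log 2) / 2 - 1 ≈ -0.0966. A small standalone arithmetic check, not part of the diff:

```cpp
#include <cassert>
#include <cmath>

int main() {
  // Reproduce the expected backward gradient for s0 in the test above.
  double s0 = 1.0, sn0 = 2.5;
  double m0 = std::log1p(s0);                 // mean of n0
  double d_normal = (sn0 - m0) / (1.0 + s0);  // Normal term, chain rule through log1p
  double d_halfnormal = -s0;                  // HalfNormal(1) log-density term
  double grad_s0 = d_normal + d_halfnormal;   // ~ -0.0966
  assert(std::fabs(grad_s0 - (-0.0966)) < 1e-3);
  return 0;
}
```
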
100 changes: 100 additions & 0 deletions src/beanmachine/graph/operator/tests/operator_test.cpp
@@ -1933,3 +1933,103 @@ TEST(testoperator, matrix_sum) {
g.get_node(sum3)->eval(gen);
EXPECT_NEAR(g.get_node(sum3)->value._double, m3.sum(), 1e-5);
}

TEST(testoperator, log1p) {
Graph g;
// negative tests: LOG1P requires exactly one real or pos_real parent.
EXPECT_THROW(
g.add_operator(OperatorType::LOG1P, std::vector<uint>{}),
std::invalid_argument);
auto prob1 = g.add_constant_probability(0.5);
EXPECT_THROW(
g.add_operator(OperatorType::LOG1P, std::vector<uint>{prob1}),
std::invalid_argument);
auto real1 = g.add_constant(-0.5);
/* ok */ g.add_operator(OperatorType::LOG1P, std::vector<uint>{real1});
auto pos1 = g.add_constant_pos_real(1.0);
/* ok */ g.add_operator(OperatorType::LOG1P, std::vector<uint>{pos1});
EXPECT_THROW(
g.add_operator(OperatorType::LOG1P, std::vector<uint>{pos1, pos1}),
std::invalid_argument);

// y ~ Normal(log1p(x^2), 1)
// If we observe x = 0.5 then the mean should be log1p(0.25) = 0.223.
auto prior = g.add_distribution(
DistributionType::FLAT, AtomicType::POS_REAL, std::vector<uint>{});
auto x = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{prior});
auto x_sq = g.add_operator(OperatorType::MULTIPLY, std::vector<uint>{x, x});
auto log1p_x_sq =
g.add_operator(OperatorType::LOG1P, std::vector<uint>{x_sq});
auto likelihood = g.add_distribution(
DistributionType::NORMAL,
AtomicType::REAL,
std::vector<uint>{log1p_x_sq, pos1});
auto y = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{likelihood});
g.query(y);
g.observe(x, 0.5);
const auto& means = g.infer_mean(10000, InferenceType::NMC);
EXPECT_NEAR(means[0], 0.223, 0.01);
g.observe(y, 0.0);

// check forward gradient:
// Verified in PyTorch using the following code:
//
// x = tensor(0.5, requires_grad=True)
// fx = Normal((x * x).log1p(), tensor(1.0)).log_prob(tensor(0.0))
// f1x = grad(fx, x, create_graph=True)
// f2x = grad(f1x, x)
//
// f1x -> -0.1785 and f2x -> -0.8542
double grad1 = 0;
double grad2 = 0;
g.gradient_log_prob(x, grad1, grad2);
EXPECT_NEAR(grad1, -0.1785, 1e-3);
EXPECT_NEAR(grad2, -0.8542, 1e-3);

// test the reverse gradient
std::vector<DoubleMatrix*> grad;
g.eval_and_grad(grad);
EXPECT_EQ(grad.size(), 2);
EXPECT_NEAR((*grad[0]), -0.1785, 1e-3);
}

TEST(testoperator, matrix_log1p) {
Graph g;

// negative tests
// MATRIX_LOG1P requires matrix parent
auto real_number = g.add_constant(2.0);
EXPECT_THROW(
g.add_operator(OperatorType::MATRIX_LOG1P, {real_number}),
std::invalid_argument);
// must be pos real or prob
Eigen::MatrixXb bools(2, 1);
bools << false, true;
auto bools_matrix = g.add_constant_bool_matrix(bools);
EXPECT_THROW(
g.add_operator(OperatorType::MATRIX_LOG1P, {bools_matrix}),
std::invalid_argument);
// can only have one parent
Eigen::MatrixXd m1(3, 1);
m1 << 2.0, 1.0, 3.0;
auto m1_matrix = g.add_constant_pos_matrix(m1);
Eigen::MatrixXd m2(1, 2);
m2 << 0.5, 20.0;
auto m2_matrix = g.add_constant_pos_matrix(m2);
EXPECT_THROW(
g.add_operator(OperatorType::MATRIX_LOG1P, {m1_matrix, m2_matrix}),
std::invalid_argument);

auto mlog1p = g.add_operator(OperatorType::MATRIX_LOG1P, {m1_matrix});
g.query(mlog1p);

auto mlog1p_infer = g.infer(2, InferenceType::REJECTION)[0][0];
Eigen::MatrixXd mlog1p_expected(3, 1);
mlog1p_expected << 2.0, 1.0, 3.0;
mlog1p_expected = Eigen::log1p(mlog1p_expected.array());
for (uint i = 0; i < mlog1p_infer.type.rows; i++) {
for (uint j = 0; j < mlog1p_infer.type.cols; j++) {
EXPECT_NEAR(mlog1p_expected(i, j), mlog1p_infer._matrix(i, j), 1e-4);
}
}
}
22 changes: 22 additions & 0 deletions src/beanmachine/graph/operator/unaryop.cpp
@@ -560,5 +560,27 @@ void LogSumExpVector::eval(std::mt19937& /* gen */) {
}
}

Log1p::Log1p(const std::vector<graph::Node*>& in_nodes)
: UnaryOperator(graph::OperatorType::LOG1P, in_nodes) {
if (in_nodes.size() != 1) {
throw std::invalid_argument("LOG1P requires one parent node");
}
graph::ValueType type0 = in_nodes[0]->value.type;
if (type0 != graph::AtomicType::POS_REAL &&
type0 != graph::AtomicType::REAL) {
throw std::invalid_argument(
"operator LOG1P requires a real or pos_real parent");
}
value = graph::NodeValue(graph::AtomicType::REAL);
}

void Log1p::eval(std::mt19937& /* gen */) {
assert(in_nodes.size() == 1);
const graph::NodeValue& parent = in_nodes[0]->value;
value = graph::NodeValue(graph::AtomicType::REAL, std::log1p(parent._double));
}

} // namespace oper
} // namespace beanmachine
15 changes: 15 additions & 0 deletions src/beanmachine/graph/operator/unaryop.h
@@ -290,5 +290,20 @@ class LogSumExpVector : public UnaryOperator {
}
};

class Log1p : public UnaryOperator {
public:
explicit Log1p(const std::vector<graph::Node*>& in_nodes);
~Log1p() override {}

void eval(std::mt19937& gen) override;
void compute_gradients() override;
void backward() override;

static std::unique_ptr<Operator> new_op(
const std::vector<graph::Node*>& in_nodes) {
return std::make_unique<Log1p>(in_nodes);
}
};

} // namespace oper
} // namespace beanmachine
4 changes: 3 additions & 1 deletion src/beanmachine/graph/pybindings.cpp
@@ -72,6 +72,7 @@ PYBIND11_MODULE(graph, module) {
.value("LOGSUMEXP", OperatorType::LOGSUMEXP)
.value("IF_THEN_ELSE", OperatorType::IF_THEN_ELSE)
.value("LOG", OperatorType::LOG)
.value("LOG1P", OperatorType::LOG1P)
.value("POW", OperatorType::POW)
.value("TRANSPOSE", OperatorType::TRANSPOSE)
.value("MATRIX_MULTIPLY", OperatorType::MATRIX_MULTIPLY)
@@ -91,7 +92,8 @@ PYBIND11_MODULE(graph, module) {
.value("CHOLESKY", OperatorType::CHOLESKY)
.value("MATRIX_EXP", OperatorType::MATRIX_EXP)
.value("MATRIX_SUM", OperatorType::MATRIX_SUM)
.value("MATRIX_LOG", OperatorType::MATRIX_LOG);
.value("MATRIX_LOG", OperatorType::MATRIX_LOG)
.value("MATRIX_LOG1P", OperatorType::MATRIX_LOG1P);

py::enum_<DistributionType>(module, "DistributionType")
.value("TABULAR", DistributionType::TABULAR)