From e80f7687fcbfcbb57596f4e30f62f234d8411372 Mon Sep 17 00:00:00 2001
From: Neal Gafter
Date: Wed, 31 Aug 2022 08:31:59 -0700
Subject: [PATCH] implement LOG1P and MATRIX_LOG1P in BMG

Summary:
Implement the scalar operator LOG1P, which computes log(1 + x), and the
corresponding elementwise matrix operator MATRIX_LOG1P.

Reviewed By: ericlippert

Differential Revision: D38923743

fbshipit-source-id: 430057749af89f900c9c8bf43ef928f80dba5094
---
 src/beanmachine/graph/graph.h                 |   2 +
 src/beanmachine/graph/operator/backward.cpp   |  16 +++
 src/beanmachine/graph/operator/gradient.cpp   |  27 +++++
 src/beanmachine/graph/operator/linalgop.cpp   |  29 +++++
 src/beanmachine/graph/operator/linalgop.h     |  15 +++
 src/beanmachine/graph/operator/register.cpp   |   8 ++
 .../graph/operator/tests/gradient_test.cpp    |  64 ++++++++++-
 .../graph/operator/tests/operator_test.cpp    | 100 ++++++++++++++++++
 src/beanmachine/graph/operator/unaryop.cpp    |  22 ++++
 src/beanmachine/graph/operator/unaryop.h      |  15 +++
 src/beanmachine/graph/pybindings.cpp          |   4 +-
 src/beanmachine/graph/to_dot.cpp              |   4 +
 12 files changed, 304 insertions(+), 2 deletions(-)

diff --git a/src/beanmachine/graph/graph.h b/src/beanmachine/graph/graph.h
index 8677cce900..81cb0f66f9 100644
--- a/src/beanmachine/graph/graph.h
+++ b/src/beanmachine/graph/graph.h
@@ -344,6 +344,8 @@ enum class OperatorType {
   LOG_PROB,
   MATRIX_SUM,
   MATRIX_LOG,
+  LOG1P,
+  MATRIX_LOG1P,
 };

 enum class DistributionType {
diff --git a/src/beanmachine/graph/operator/backward.cpp b/src/beanmachine/graph/operator/backward.cpp
index 6bec129491..57bf0ea3b1 100644
--- a/src/beanmachine/graph/operator/backward.cpp
+++ b/src/beanmachine/graph/operator/backward.cpp
@@ -369,5 +369,21 @@ void MatrixSum::backward() {
   }
 }

+void Log1p::backward() {
+  assert(in_nodes.size() == 1);
+  if (in_nodes[0]->needs_gradient()) {
+    double jacob = 1 / (1 + in_nodes[0]->value._double);
+    in_nodes[0]->back_grad1 += back_grad1 * jacob;
+  }
+}
+
+void MatrixLog1p::backward() {
+  assert(in_nodes.size() == 1);
+  if (in_nodes[0]->needs_gradient()) {
+    in_nodes[0]->back_grad1 += back_grad1.as_matrix().cwiseQuotient(
+        (in_nodes[0]->value._matrix.array() + 1).matrix());
+  }
+}
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/gradient.cpp b/src/beanmachine/graph/operator/gradient.cpp
index 138b8745be..e57ba5786b 100644
--- a/src/beanmachine/graph/operator/gradient.cpp
+++ b/src/beanmachine/graph/operator/gradient.cpp
@@ -600,5 +600,32 @@ void MatrixSum::compute_gradients() {
   grad2 = in_nodes[0]->Grad2.sum();
 }

+void Log1p::compute_gradients() {
+  assert(in_nodes.size() == 1);
+  // f(x) = log(g(x) + 1)
+  // f'(x) = g'(x) / (g(x) + 1)
+  // f''(x) = ((g(x) + 1) g''(x) - g'(x)^2) / (g(x) + 1)^2
+  auto g = in_nodes[0]->value._double;
+  auto gp1 = g + 1;
+  auto g1 = in_nodes[0]->grad1;
+  auto g2 = in_nodes[0]->grad2;
+  grad1 = g1 / gp1;
+  grad2 = (gp1 * g2 - g1 * g1) / (gp1 * gp1);
+}
+
+void MatrixLog1p::compute_gradients() {
+  assert(in_nodes.size() == 1);
+  // f(x) = log(g(x) + 1)
+  // f'(x) = g'(x) / (g(x) + 1)
+  // f''(x) = ((g(x) + 1) g''(x) - g'(x)^2) / (g(x) + 1)^2
+  auto g = in_nodes[0]->value._matrix;
+  auto gp1 = (g.array() + 1).matrix();
+  auto g1 = in_nodes[0]->Grad1;
+  auto g2 = in_nodes[0]->Grad2;
+  Grad1 = g1.cwiseQuotient(gp1);
+  Grad2 = (gp1.cwiseProduct(g2) - g1.cwiseProduct(g1))
+              .cwiseQuotient(gp1.cwiseProduct(gp1));
+}
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/linalgop.cpp b/src/beanmachine/graph/operator/linalgop.cpp
index 5db0c48fef..48fe91a0a1 100644
--- a/src/beanmachine/graph/operator/linalgop.cpp
+++ b/src/beanmachine/graph/operator/linalgop.cpp
@@ -504,5 +504,34 @@ void MatrixLog::eval(std::mt19937& /* gen */) {
   value._matrix = Eigen::log(in_nodes[0]->value._matrix.array());
 }

+MatrixLog1p::MatrixLog1p(const std::vector<graph::Node*>& in_nodes)
+    : Operator(graph::OperatorType::MATRIX_LOG1P) {
+  if (in_nodes.size() != 1) {
+    throw std::invalid_argument("MATRIX_LOG1P requires one parent node");
+  }
+  auto type = in_nodes[0]->value.type;
+  if (type.variable_type != graph::VariableType::BROADCAST_MATRIX) {
+    throw std::invalid_argument(
+        "the parent of MATRIX_LOG1P must be a BROADCAST_MATRIX");
+  }
+  auto atomic_type = type.atomic_type;
+  graph::AtomicType new_type;
+  if (atomic_type == graph::AtomicType::POS_REAL) {
+    new_type = graph::AtomicType::REAL;
+  } else if (atomic_type == graph::AtomicType::PROBABILITY) {
+    new_type = graph::AtomicType::POS_REAL;
+  } else {
+    throw std::invalid_argument(
+        "operator MATRIX_LOG1P requires a probability or pos_real parent");
+  }
+  value = graph::NodeValue(graph::ValueType(
+      graph::VariableType::BROADCAST_MATRIX, new_type, type.rows, type.cols));
+}
+
+void MatrixLog1p::eval(std::mt19937& /* gen */) {
+  assert(in_nodes.size() == 1);
+  value._matrix = Eigen::log1p(in_nodes[0]->value._matrix.array());
+}
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/linalgop.h b/src/beanmachine/graph/operator/linalgop.h
index 85aef4d4e9..d52f60e809 100644
--- a/src/beanmachine/graph/operator/linalgop.h
+++ b/src/beanmachine/graph/operator/linalgop.h
@@ -198,5 +198,20 @@ class MatrixLog : public Operator {
   }
 };

+class MatrixLog1p : public Operator {
+ public:
+  explicit MatrixLog1p(const std::vector<graph::Node*>& in_nodes);
+  ~MatrixLog1p() override {}
+
+  void eval(std::mt19937& gen) override;
+  void backward() override;
+  void compute_gradients() override;
+
+  static std::unique_ptr<Operator> new_op(
+      const std::vector<graph::Node*>& in_nodes) {
+    return std::make_unique<MatrixLog1p>(in_nodes);
+  }
+};
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/register.cpp b/src/beanmachine/graph/operator/register.cpp
index b4802e362f..faaec4fea9 100644
--- a/src/beanmachine/graph/operator/register.cpp
+++ b/src/beanmachine/graph/operator/register.cpp
@@ -107,6 +107,10 @@ bool ::beanmachine::oper::OperatorFactory::factories_are_registered =
     OperatorFactory::register_op(graph::OperatorType::LOG, &(Log::new_op)) &&

+    OperatorFactory::register_op(
+        graph::OperatorType::LOG1P,
+        &(Log1p::new_op)) &&
+
     // linear algebra op
     OperatorFactory::register_op(
         graph::OperatorType::TRANSPOSE,
@@ -140,6 +144,10 @@ bool ::beanmachine::oper::OperatorFactory::factories_are_registered =
         graph::OperatorType::MATRIX_LOG,
         &(MatrixLog::new_op)) &&

+    OperatorFactory::register_op(
+        graph::OperatorType::MATRIX_LOG1P,
+        &(MatrixLog1p::new_op)) &&
+
     // matrix index
     OperatorFactory::register_op(
         graph::OperatorType::INDEX,
diff --git a/src/beanmachine/graph/operator/tests/gradient_test.cpp b/src/beanmachine/graph/operator/tests/gradient_test.cpp
index acbaff39ad..443ae00847 100644
--- a/src/beanmachine/graph/operator/tests/gradient_test.cpp
+++ b/src/beanmachine/graph/operator/tests/gradient_test.cpp
@@ -1126,7 +1126,7 @@ TEST(testgradient, matrix_log_grad_backward) {
 # backwards gradients as PyTorch.

 import torch
-hn = torch.distributions.HalfNormal(0, 1)
+hn = torch.distributions.HalfNormal(1)
 s0 = torch.tensor(1.0, requires_grad=True)
 s1 = torch.tensor(0.5, requires_grad=True)
 mlog0 = s0.log()
@@ -1564,3 +1564,65 @@ TEST(testgradient, matrix_sum) {
   EXPECT_NEAR((*grad1[0]), 2.0, 1e-3);
   EXPECT_NEAR((*grad1[1]), -1.25, 1e-3);
 }
+
+TEST(testgradient, matrix_log1p_grad_backward) {
+  /*
+# Test backward differentiation
+#
+# Build the same model in PyTorch and BMG; we should get the same
+# backwards gradients as PyTorch.
+
+import torch
+hn = torch.distributions.HalfNormal(1)
+s0 = torch.tensor(1.0, requires_grad=True)
+s1 = torch.tensor(0.5, requires_grad=True)
+mlog0 = s0.log1p()
+mlog1 = s1.log1p()
+n0 = torch.distributions.Normal(mlog0, 1.0)
+n1 = torch.distributions.Normal(mlog1, 1.0)
+sn0 = torch.tensor(2.5, requires_grad=True)
+sn1 = torch.tensor(1.5, requires_grad=True)
+log_prob = (n0.log_prob(sn0) + n1.log_prob(sn1)
+            + hn.log_prob(s0) + hn.log_prob(s1))
+print(torch.autograd.grad(log_prob, s0, retain_graph=True))  # -0.0966
+print(torch.autograd.grad(log_prob, s1, retain_graph=True))  # 0.2297
+print(torch.autograd.grad(log_prob, sn0, retain_graph=True))  # -1.8069
+print(torch.autograd.grad(log_prob, sn1, retain_graph=True))  # -1.0945
+  */
+
+  Graph g;
+  auto one = g.add_constant_pos_real(1.0);
+  auto hn = g.add_distribution(
+      DistributionType::HALF_NORMAL, AtomicType::POS_REAL, {one});
+  auto two = g.add_constant((natural_t)2);
+  auto hn_sample =
+      g.add_operator(OperatorType::IID_SAMPLE, std::vector<uint>{hn, two});
+  Eigen::MatrixXd hn_observed(2, 1);
+  hn_observed << 1.0, 0.5;
+  g.observe(hn_sample, hn_observed);
+
+  auto mlog_pos = g.add_operator(OperatorType::MATRIX_LOG1P, {hn_sample});
+  auto mlog = g.add_operator(OperatorType::TO_REAL_MATRIX, {mlog_pos});
+  auto index_zero = g.add_constant((natural_t)0);
+  auto mlog0 = g.add_operator(OperatorType::INDEX, {mlog, index_zero});
+  auto index_one = g.add_constant((natural_t)1);
+  auto mlog1 = g.add_operator(OperatorType::INDEX, {mlog, index_one});
+
+  auto n0 = g.add_distribution(
+      DistributionType::NORMAL, AtomicType::REAL, {mlog0, one});
+  auto n1 = g.add_distribution(
+      DistributionType::NORMAL, AtomicType::REAL, {mlog1, one});
+
+  auto ns0 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n0});
+  g.observe(ns0, 2.5);
+  auto ns1 = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{n1});
+  g.observe(ns1, 1.5);
+
+  std::vector<DoubleMatrix*> grad1;
+  g.eval_and_grad(grad1);
+  EXPECT_EQ(grad1.size(), 3);
+  EXPECT_NEAR((*grad1[0])(0), -0.0966, 1e-3);
+  EXPECT_NEAR((*grad1[0])(1), 0.2297, 1e-3);
+  EXPECT_NEAR((*grad1[1]), -1.8069, 1e-3);
+  EXPECT_NEAR((*grad1[2]), -1.0945, 1e-3);
+}
diff --git a/src/beanmachine/graph/operator/tests/operator_test.cpp b/src/beanmachine/graph/operator/tests/operator_test.cpp
index 30ce4dec3e..e004c599c0 100644
--- a/src/beanmachine/graph/operator/tests/operator_test.cpp
+++ b/src/beanmachine/graph/operator/tests/operator_test.cpp
@@ -1933,3 +1933,103 @@ TEST(testoperator, matrix_sum) {
   g.get_node(sum3)->eval(gen);
   EXPECT_NEAR(g.get_node(sum3)->value._double, m3.sum(), 1e-5);
 }
+
+TEST(testoperator, log1p) {
+  Graph g;
+  // negative tests: the input must be exactly one real or pos_real node.
+  EXPECT_THROW(
+      g.add_operator(OperatorType::LOG1P, std::vector<uint>{}),
+      std::invalid_argument);
+  auto prob1 = g.add_constant_probability(0.5);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::LOG1P, std::vector<uint>{prob1}),
+      std::invalid_argument);
+  auto real1 = g.add_constant(-0.5);
+  /* ok */ g.add_operator(OperatorType::LOG1P, std::vector<uint>{real1});
+  auto pos1 = g.add_constant_pos_real(1.0);
+  /* ok */ g.add_operator(OperatorType::LOG1P, std::vector<uint>{pos1});
+  EXPECT_THROW(
+      g.add_operator(OperatorType::LOG1P, std::vector<uint>{pos1, pos1}),
+      std::invalid_argument);
+
+  // y ~ Normal(log1p(x^2), 1)
+  // If we observe x = 0.5 then the mean should be log1p(0.25) = 0.223.
+  auto prior = g.add_distribution(
+      DistributionType::FLAT, AtomicType::POS_REAL, std::vector<uint>{});
+  auto x = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{prior});
+  auto x_sq = g.add_operator(OperatorType::MULTIPLY, std::vector<uint>{x, x});
+  auto log1p_x_sq =
+      g.add_operator(OperatorType::LOG1P, std::vector<uint>{x_sq});
+  auto likelihood = g.add_distribution(
+      DistributionType::NORMAL,
+      AtomicType::REAL,
+      std::vector<uint>{log1p_x_sq, pos1});
+  auto y = g.add_operator(OperatorType::SAMPLE, std::vector<uint>{likelihood});
+  g.query(y);
+  g.observe(x, 0.5);
+  const auto& means = g.infer_mean(10000, InferenceType::NMC);
+  EXPECT_NEAR(means[0], 0.223, 0.01);
+  g.observe(y, 0.0);
+
+  // check forward gradient:
+  // Verified in PyTorch using the following code:
+  //
+  // x = tensor(0.5, requires_grad=True)
+  // fx = Normal((x * x).log1p(), tensor(1.0)).log_prob(tensor(0.0))
+  // f1x = grad(fx, x, create_graph=True)
+  // f2x = grad(f1x, x)
+  //
+  // f1x -> -0.1785 and f2x -> -0.8542
+  double grad1 = 0;
+  double grad2 = 0;
+  g.gradient_log_prob(x, grad1, grad2);
+  EXPECT_NEAR(grad1, -0.1785, 1e-3);
+  EXPECT_NEAR(grad2, -0.8542, 1e-3);
+
+  // test the reverse gradient
+  std::vector<DoubleMatrix*> grad;
+  g.eval_and_grad(grad);
+  EXPECT_EQ(grad.size(), 2);
+  EXPECT_NEAR((*grad[0]), -0.1785, 1e-3);
+}
+
+TEST(testoperator, matrix_log1p) {
+  Graph g;
+
+  // negative tests
+  // MATRIX_LOG1P requires a matrix parent
+  auto real_number = g.add_constant(2.0);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::MATRIX_LOG1P, {real_number}),
+      std::invalid_argument);
+  // must be pos real or prob
+  Eigen::MatrixXb bools(2, 1);
+  bools << false, true;
+  auto bools_matrix = g.add_constant_bool_matrix(bools);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::MATRIX_LOG1P, {bools_matrix}),
+      std::invalid_argument);
+  // can only have one parent
+  Eigen::MatrixXd m1(3, 1);
+  m1 << 2.0, 1.0, 3.0;
+  auto m1_matrix = g.add_constant_pos_matrix(m1);
+  Eigen::MatrixXd m2(1, 2);
+  m2 << 0.5, 20.0;
+  auto m2_matrix = g.add_constant_pos_matrix(m2);
+  EXPECT_THROW(
+      g.add_operator(OperatorType::MATRIX_LOG1P, {m1_matrix, m2_matrix}),
+      std::invalid_argument);
+
+  auto mlog1p = g.add_operator(OperatorType::MATRIX_LOG1P, {m1_matrix});
+  g.query(mlog1p);
+
+  auto mlog1p_infer = g.infer(2, InferenceType::REJECTION)[0][0];
+  Eigen::MatrixXd mlog1p_expected(3, 1);
+  mlog1p_expected << 2.0, 1.0, 3.0;
+  mlog1p_expected = Eigen::log1p(mlog1p_expected.array());
+  for (uint i = 0; i < mlog1p_infer.type.rows; i++) {
+    for (uint j = 0; j < mlog1p_infer.type.cols; j++) {
+      EXPECT_NEAR(mlog1p_expected(i, j), mlog1p_infer._matrix(i, j), 1e-4);
+    }
+  }
+}
diff --git a/src/beanmachine/graph/operator/unaryop.cpp b/src/beanmachine/graph/operator/unaryop.cpp
index 37d759faff..1456a96b83 100644
--- a/src/beanmachine/graph/operator/unaryop.cpp
+++ b/src/beanmachine/graph/operator/unaryop.cpp
@@ -560,5 +560,27 @@ void LogSumExpVector::eval(std::mt19937& /* gen */) {
   }
 }

+Log1p::Log1p(const std::vector<graph::Node*>& in_nodes)
+    : UnaryOperator(graph::OperatorType::LOG1P, in_nodes) {
+  if (in_nodes.size() != 1) {
+    throw std::invalid_argument("LOG1P requires one parent node");
+  }
+
+  graph::ValueType type0 = in_nodes[0]->value.type;
+  if (type0 != graph::AtomicType::POS_REAL &&
+      type0 != graph::AtomicType::REAL) {
+    throw std::invalid_argument(
+        "operator LOG1P requires a real or pos_real parent");
+  }
+
+  value = graph::NodeValue(graph::AtomicType::REAL);
+}
+
+void Log1p::eval(std::mt19937& /* gen */) {
+  assert(in_nodes.size() == 1);
+  const graph::NodeValue& parent = in_nodes[0]->value;
+  value = graph::NodeValue(graph::AtomicType::REAL, std::log1p(parent._double));
+}
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/operator/unaryop.h b/src/beanmachine/graph/operator/unaryop.h
index 5030618704..4c3150f01a 100644
--- a/src/beanmachine/graph/operator/unaryop.h
+++ b/src/beanmachine/graph/operator/unaryop.h
@@ -290,5 +290,20 @@ class LogSumExpVector : public UnaryOperator {
   }
 };

+class Log1p : public UnaryOperator {
+ public:
+  explicit Log1p(const std::vector<graph::Node*>& in_nodes);
+  ~Log1p() override {}
+
+  void eval(std::mt19937& gen) override;
+  void compute_gradients() override;
+  void backward() override;
+
+  static std::unique_ptr<Operator> new_op(
+      const std::vector<graph::Node*>& in_nodes) {
+    return std::make_unique<Log1p>(in_nodes);
+  }
+};
+
 } // namespace oper
 } // namespace beanmachine
diff --git a/src/beanmachine/graph/pybindings.cpp b/src/beanmachine/graph/pybindings.cpp
index bb8fb7be27..0fff54118b 100644
--- a/src/beanmachine/graph/pybindings.cpp
+++ b/src/beanmachine/graph/pybindings.cpp
@@ -72,6 +72,7 @@ PYBIND11_MODULE(graph, module) {
       .value("LOGSUMEXP", OperatorType::LOGSUMEXP)
       .value("IF_THEN_ELSE", OperatorType::IF_THEN_ELSE)
       .value("LOG", OperatorType::LOG)
+      .value("LOG1P", OperatorType::LOG1P)
       .value("POW", OperatorType::POW)
       .value("TRANSPOSE", OperatorType::TRANSPOSE)
       .value("MATRIX_MULTIPLY", OperatorType::MATRIX_MULTIPLY)
@@ -91,7 +92,8 @@ PYBIND11_MODULE(graph, module) {
       .value("CHOLESKY", OperatorType::CHOLESKY)
       .value("MATRIX_EXP", OperatorType::MATRIX_EXP)
       .value("MATRIX_SUM", OperatorType::MATRIX_SUM)
-      .value("MATRIX_LOG", OperatorType::MATRIX_LOG);
+      .value("MATRIX_LOG", OperatorType::MATRIX_LOG)
+      .value("MATRIX_LOG1P", OperatorType::MATRIX_LOG1P);

   py::enum_<DistributionType>(module, "DistributionType")
       .value("TABULAR", DistributionType::TABULAR)
diff --git a/src/beanmachine/graph/to_dot.cpp b/src/beanmachine/graph/to_dot.cpp
index 18ddd3f560..4f143f1422 100644
--- a/src/beanmachine/graph/to_dot.cpp
+++ b/src/beanmachine/graph/to_dot.cpp
@@ -169,6 +169,8 @@ class DOT {
         return "LogSumExp";
       case OperatorType::LOG:
         return "Log";
+      case OperatorType::LOG1P:
+        return "Log1p";
       case OperatorType::POW:
         return "^";
       case OperatorType::LOG1MEXP:
@@ -203,6 +205,8 @@ class DOT {
         return "MatrixSum";
       case OperatorType::MATRIX_LOG:
         return "MatrixLog";
+      case OperatorType::MATRIX_LOG1P:
+        return "MatrixLog1p";
       default:
         throw std::invalid_argument(
             "internal error: missing case for OperatorType");
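
For reviewers who want to try the new operators out, here is a minimal usage sketch (not part of the diff). It builds a tiny graph that applies both LOG1P and MATRIX_LOG1P to constants, using the same C++ Graph API that the tests above use; the header path, namespaces, and main() harness are assumptions added for illustration.

```cpp
// Illustrative sketch only; header path and harness are assumed, the
// operator names and Graph calls mirror the tests in this patch.
#include <iostream>

#include "beanmachine/graph/graph.h"

using namespace beanmachine::graph;

int main() {
  Graph g;

  // Scalar LOG1P: log1p(2.0) = log(3.0) ~= 1.0986.
  auto x = g.add_constant_pos_real(2.0);
  auto log1p_x = g.add_operator(OperatorType::LOG1P, std::vector<uint>{x});
  g.query(log1p_x);

  // Elementwise MATRIX_LOG1P on a 2x1 positive-real matrix.
  Eigen::MatrixXd m(2, 1);
  m << 0.5, 2.0;
  auto m_const = g.add_constant_pos_matrix(m);
  auto m_log1p =
      g.add_operator(OperatorType::MATRIX_LOG1P, std::vector<uint>{m_const});
  g.query(m_log1p);

  // Both queried nodes are deterministic, so a couple of rejection
  // samples is enough to read their values back.
  const auto& samples = g.infer(2, InferenceType::REJECTION);
  std::cout << "log1p(2.0) = " << samples[0][0]._double << std::endl;
  std::cout << "matrix_log1p(0.5) = " << samples[0][1]._matrix(0, 0)
            << std::endl;
  return 0;
}
```

The sketch queries the operator nodes directly, the same pattern the matrix_log1p unit test uses with REJECTION inference to read deterministic values back out of the graph.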