From 3d276277df1b1f8b216cae246d5cdc4f6dd02028 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 8 Nov 2017 14:17:38 +0800
Subject: [PATCH 1/7] Add nce op

1. Add NCE forward and backward kernels for CPU.

---
 paddle/operators/nce_op.cc | 120 +++++++++++++++++++++
 paddle/operators/nce_op.h  | 210 +++++++++++++++++++++++++++++++++++++
 2 files changed, 330 insertions(+)
 create mode 100644 paddle/operators/nce_op.cc
 create mode 100644 paddle/operators/nce_op.h

diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
new file mode 100644
index 0000000000000..afd61b88514b8
--- /dev/null
+++ b/paddle/operators/nce_op.cc
@@ -0,0 +1,120 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/nce_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+class NCEOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"));
+    PADDLE_ENFORCE(ctx->HasInput("Label"));
+    PADDLE_ENFORCE(ctx->HasInput("W"));
+    PADDLE_ENFORCE(ctx->HasOutput("Out"));
+    PADDLE_ENFORCE(ctx->HasOutput("SampleLogits"));
+    PADDLE_ENFORCE(ctx->HasOutput("SampleLabels"));
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto label_dims = ctx->GetInputDim("Label");
+    PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
+    if (ctx->HasInput("B")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("W")[0], ctx->GetInputDim("B")[0]);
+    }
+    int num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
+    int num_classes = ctx->Attrs().Get<int>("num_classes");
+    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("W")[0]);
+    PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
+
+    // set dims of output(Out)
+    std::vector<int64_t> out_dims(1);
+    out_dims.push_back(x_dims[0]);
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+
+    // set dims of output(SampleOut)
+    std::vector<int64_t> sample_out_dims(2);
+    sample_out_dims.push_back(x_dims[0]);
+    sample_out_dims.push_back(num_sampled_classes + 1);
+    ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
+    ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
+  }
+};
+
+class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "");
+    AddInput("Label", "");
+    AddInput("W", "");
+    AddInput("B", "");
+    AddInput("SampleWeight", "");
+    AddOutput("Out", "");
+    AddOutput("SampleLogits", "");
+    AddOutput("SampleLabels", "");
+    AddAttr<int>("num_classes", "");
+    AddAttr<int>("num_sampled_classes", "").SetDefault(10);
+    AddComment(R"DOC(
+Expand input(X) according to LOD of input(Y).
+
+)DOC");
+  }
+};
+
+class NCEOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"));
+    PADDLE_ENFORCE(ctx->HasInput("W"));
+    PADDLE_ENFORCE(ctx->HasInput("Out"));
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "The input(Out@GRAD) should not be null");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto x_grad_name = framework::GradVarName("X");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+
+    auto w_dims = ctx->GetInputDim("W");
+    auto w_grad_name = framework::GradVarName("W");
+    if (ctx->HasOutput(w_grad_name)) {
+      ctx->SetOutputDim(w_grad_name, w_dims);
+    }
+
+    auto bias_grad_name = framework::GradVarName("B");
+    if (ctx->HasOutput(bias_grad_name)) {
+      auto bias_dims = ctx->GetInputDim("B");
+      ctx->SetOutputDim(bias_grad_name, bias_dims);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad);
+REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(nce_grad,
+                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
new file mode 100644
index 0000000000000..ce1717c9b016a
--- /dev/null
+++ b/paddle/operators/nce_op.h
@@ -0,0 +1,210 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+
+#include <random>
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/memory/memcpy.h"
+#include "unsupported/Eigen/CXX11/Tensor"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+void PrepareSamples(const framework::ExecutionContext& context) {
+  auto label = context.Input<Tensor>("Label");
+  const T* label_data = label->data<T>();
+  auto label_dims = label->dims();
+  int num_classes = context.Attr<int>("num_classes");
+  // random machine
+  std::random_device rd;
+  std::mt19937 rng(rd());
+  std::uniform_int_distribution<int> rand(0, num_classes - 1);
+
+  auto sample_labels = context.Output<Tensor>("SampleLabels");
+  auto sample_labels_dims = sample_labels->dims();
+  int* sample_labels_data =
+      sample_labels->mutable_data<int>(context.GetPlace());
+
+  int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
+  for (size_t i = 0; i < label_dims[0]; ++i) {
+    int j = 0;
+    for (; j < num_label; ++j) {
+      sample_labels_data[sample_labels_dims[1] * i + j] =
+          label_data[i * num_label + j];
+    }
+    for (; j < sample_labels_dims[1]; ++j) {
+      int id = rand(rng);
+      sample_labels_data[sample_labels_dims[1] * i + j] = id;
+    }
+  }
+}
+
+template <typename Place, typename T>
+class NCEKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    PrepareSamples<Place, T>(context);
+    auto sample_labels = context.Output<Tensor>("SampleLabels");
+    const int* sample_labels_data = sample_labels->data<int>();
+    auto sample_out = context.Output<Tensor>("SampleLogits");
+    T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
+    auto label = context.Input<Tensor>("Label");
+    auto sample_weight = context.Input<Tensor>("SampleWeight");
+    const T* sample_weight_data = nullptr;
+    if (sample_weight != nullptr) {
+      sample_weight_data = sample_weight->data<T>();
+    }
+    auto out = context.Output<Tensor>("Out");
+    T* out_data = out->mutable_data<T>(context.GetPlace());
+    int num_smalped_classes = context.Attr<int>("num_sampled_classes");
+    int num_classes = context.Attr<int>("num_classes");
+    int num_true_class = 1;
+    if (label != nullptr) {
+      num_true_class = label->dims()[1];
+    }
+    T b = 1. / num_classes * num_smalped_classes;
+
+    // forward bias
+    auto bias = context.Input<Tensor>("B");
+    if (bias != nullptr) {
+      const T* bias_data = bias->data<T>();
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        sample_out_data[i] = bias_data[sample_labels_data[i]];
+      }
+    } else {
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        sample_out_data[i] = 0;
+      }
+    }
+
+    // forward mul
+    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("X")));
+    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("W")));
+    for (size_t i = 0; i < sample_labels->numel(); ++i) {
+      // sample_out_data[i] += (input_mat.chip((int)(i /
+      // sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i],
+      // 0)).sum();
+      Eigen::Tensor<float, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
+          (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
+           weight_mat.chip(sample_labels_data[i], 0))
+              .sum();
+      sample_out_data[i] += result(0);
+      // activation_->forward
+      sample_out_data[i] = (1 / 1 + (sample_out_data[i]));
+    }
+
+    // forward cost
+    for (size_t i = 0; i < sample_labels->dims()[0]; ++i) {
+      size_t j = 0;
+      T w = sample_weight == nullptr ? 1 : sample_weight_data[i];
+      // for true classes
+      for (; j < num_true_class; ++j) {
+        T o = sample_out_data[i * sample_out->dims()[1] + j];
+        T cost = -log(o / (o + b));
+        out_data[i] += w * cost;
+      }
+      // for sampled neg classes
+      for (; j < sample_labels->dims()[1]; ++j) {
+        T o = sample_out_data[i * sample_out->dims()[1] + j];
+        T cost = -log(b / (o + b));
+        out_data[i] += w * cost;
+      }
+    }
+  }
+};
+
+template <typename Place, typename T>
+class NCEGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto label = context.Input<Tensor>("Label");
+    auto sample_out = context.Input<Tensor>("SampleLogits");
+    const T* sample_out_data = sample_out->data<T>();
+    auto sample_labels = context.Input<Tensor>("SampleLabels");
+    const int* sample_labels_data = sample_labels->data<int>();
+    auto sample_weight = context.Input<Tensor>("SampleWeight");
+    const T* sample_weight_data = nullptr;
+    if (sample_weight != nullptr) {
+      sample_weight_data = sample_weight->data<T>();
+    }
+    int num_smalped_classes = context.Attr<int>("num_sampled_classes");
+    int num_classes = context.Attr<int>("num_classes");
+    int num_true_class = 1;
+    if (label != nullptr) {
+      num_true_class = label->dims()[1];
+    }
+    T b = 1. / num_classes * num_smalped_classes;
+
+    Tensor sample_grad;  // tmp tensor
+    T* sample_grad_data =
+        sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
+
+    // backward cost
+    for (size_t i = 0; i < sample_labels->numel(); ++i) {
+      T o = sample_out_data[i];
+      T w = sample_weight == nullptr
+                ? 1
+                : sample_weight_data[i / sample_labels->dims()[1]];
+      sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
+                                ? -w * b / (o * (o + b))
+                                : w / (o + b);
+      // sigmoid->backward
+      sample_grad_data[i] =
+          (o > 0) ? sample_grad_data[i] : ((o < 0) ? -sample_grad_data[i] : 0);
+    }
+
+    // get d_bias
+    auto d_bias = context.Output<Tensor>(framework::GradVarName("B"));
+    if (d_bias != nullptr) {
+      T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
+      }
+    }
+    // get d_w
+    auto d_w = context.Output<Tensor>(framework::GradVarName("W"));
+    if (d_w != nullptr) {
+      auto d_w_matrix = EigenMatrix<T>::From(*d_w);
+      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("X")));
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        d_w_matrix.chip(sample_labels_data[i], 0) =
+            x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) *
+            sample_grad_data[i];
+      }
+    }
+
+    // get d_x
+    auto d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    if (d_x != nullptr) {
+      auto d_x_matrix = EigenMatrix<T>::From(*d_x);
+      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("W")));
+      for (size_t i = 0; i < sample_labels->numel(); ++i) {
+        d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) +=
+            w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle

From 09d32b068cbdf65f93e98f7b357dbc7e90f11734 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 16 Nov 2017 00:01:55 +0800
Subject: [PATCH 2/7] Add unit tests and comments.
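
This adds doc strings for the op's inputs/outputs and a NumPy reference
implementation to the new unit test. As a reading aid, here is a minimal
sketch of the per-sample cost the forward kernel computes, assuming
uniform negative sampling (`nce_cost` is a hypothetical helper written
for this note, not part of the test):

    import numpy as np

    def nce_cost(logits, num_true_class, num_classes, num_sampled_classes):
        # logits: pre-sigmoid scores x.w + bias for one sample's true
        # labels followed by its sampled negative labels.
        b = 1.0 / num_classes * num_sampled_classes  # uniform noise mass
        o = 1.0 / (1.0 + np.exp(-logits))            # forward activation
        true_cost = -np.log(o[:num_true_class] / (o[:num_true_class] + b))
        noise_cost = -np.log(b / (o[num_true_class:] + b))
        return true_cost.sum() + noise_cost.sum()
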
---
 paddle/operators/nce_op.cc                   | 115 +++++++++++++------
 paddle/operators/nce_op.h                    |  79 +++++++------
 python/paddle/v2/framework/tests/test_nce.py |  96 ++++++++++++++++
 3 files changed, 212 insertions(+), 78 deletions(-)
 create mode 100644 python/paddle/v2/framework/tests/test_nce.py

diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index afd61b88514b8..c365d5d922286 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -23,57 +23,87 @@ class NCEOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
+    PADDLE_ENFORCE(ctx->HasInput("Input"));
     PADDLE_ENFORCE(ctx->HasInput("Label"));
-    PADDLE_ENFORCE(ctx->HasInput("W"));
-    PADDLE_ENFORCE(ctx->HasOutput("Out"));
+    PADDLE_ENFORCE(ctx->HasInput("Weight"));
+    PADDLE_ENFORCE(ctx->HasOutput("Cost"));
     PADDLE_ENFORCE(ctx->HasOutput("SampleLogits"));
     PADDLE_ENFORCE(ctx->HasOutput("SampleLabels"));
 
-    auto x_dims = ctx->GetInputDim("X");
+    auto x_dims = ctx->GetInputDim("Input");
     auto label_dims = ctx->GetInputDim("Label");
     PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
-    if (ctx->HasInput("B")) {
-      PADDLE_ENFORCE_EQ(ctx->GetInputDim("W")[0], ctx->GetInputDim("B")[0]);
+    int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
+    if (ctx->HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
+                        ctx->GetInputDim("Bias")[0]);
     }
-    int num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
-    int num_classes = ctx->Attrs().Get<int>("num_classes");
-    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("W")[0]);
+    auto num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
+    auto num_classes = ctx->Attrs().Get<int>("num_classes");
+    std::vector<int> sampled_labels =
+        ctx->Attrs().Get<std::vector<int>>("sampled_labels");
+    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("Weight")[0]);
     PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
-
+    if (sampled_labels.size() > 0) {
+      PADDLE_ENFORCE_EQ(sampled_labels.size(),
+                        static_cast<size_t>(num_sampled_classes));
+    }
     // set dims of output(Out)
-    std::vector<int64_t> out_dims(1);
+    std::vector<int64_t> out_dims;
     out_dims.push_back(x_dims[0]);
-    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+    ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
 
     // set dims of output(SampleOut)
-    std::vector<int64_t> sample_out_dims(2);
+    std::vector<int64_t> sample_out_dims;
     sample_out_dims.push_back(x_dims[0]);
-    sample_out_dims.push_back(num_sampled_classes + 1);
+    sample_out_dims.push_back(num_sampled_classes + num_true_classes);
     ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
     ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
   }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+        ctx.device_context());
+  }
 };
 
 class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "");
-    AddInput("Label", "");
-    AddInput("W", "");
-    AddInput("B", "");
-    AddInput("SampleWeight", "");
-    AddOutput("Out", "");
-    AddOutput("SampleLogits", "");
-    AddOutput("SampleLabels", "");
-    AddAttr<int>("num_classes", "");
-    AddAttr<int>("num_sampled_classes", "").SetDefault(10);
+    AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
+    AddInput("Label",
+             "(Tensor) A tensor of shape [batch_size, num_true_class]. "
+             "'num_true_class' is the number of target class in each sample.");
+    AddInput("Weight",
+             "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
+             "total number of class.");
+    AddInput("Bias",
+             "(Tensor) A tensor of shape [num_class]. 'num_class' is the total "
+             "number of class. It is a dispensable input.")
+        .AsDispensable();
+    AddInput("SampleWeight",
+             "(Tensor) A tensor of shape [batch_size] storing a weight for "
+             "each sample. And it is a dispensable input. The default value of "
+             "sample is 1.")
+        .AsDispensable();
+    AddOutput("Cost",
+              "(Tensor) A tensor of shape [batch_size]. Cost of samples.");
+    AddOutput("SampleLogits", "An intermediate tensor.").AsIntermediate();
+    AddOutput("SampleLabels", "An intermediate tensor.").AsIntermediate();
+    AddAttr<int>("num_classes", "Total number of classes.");
+    AddAttr<int>("num_sampled_classes", "The number of negative classes.")
+        .SetDefault(10);
+    AddAttr<std::vector<int>>("sampled_labels", "");
     AddComment(R"DOC(
-Expand input(X) according to LOD of input(Y).
-
+Computes and returns the noise-contrastive estimation training loss.
+See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
+By default this uses a uniform distribution for sampling.
+The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class.
 )DOC");
   }
 };
@@ -82,32 +112,41 @@ class NCEOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
- protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasInput("W"));
-    PADDLE_ENFORCE(ctx->HasInput("Out"));
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+    PADDLE_ENFORCE(ctx->HasInput("Input"));
+    PADDLE_ENFORCE(ctx->HasInput("Weight"));
+    PADDLE_ENFORCE(ctx->HasInput("Cost"));
+    PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
+    PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
                    "The input(Out@GRAD) should not be null");
 
-    auto x_dims = ctx->GetInputDim("X");
-    auto x_grad_name = framework::GradVarName("X");
+    auto x_dims = ctx->GetInputDim("Input");
+    auto x_grad_name = framework::GradVarName("Input");
     if (ctx->HasOutput(x_grad_name)) {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
 
-    auto w_dims = ctx->GetInputDim("W");
-    auto w_grad_name = framework::GradVarName("W");
+    auto w_dims = ctx->GetInputDim("Weight");
+    auto w_grad_name = framework::GradVarName("Weight");
     if (ctx->HasOutput(w_grad_name)) {
       ctx->SetOutputDim(w_grad_name, w_dims);
     }
 
-    auto bias_grad_name = framework::GradVarName("B");
+    auto bias_grad_name = framework::GradVarName("Bias");
     if (ctx->HasOutput(bias_grad_name)) {
-      auto bias_dims = ctx->GetInputDim("B");
+      auto bias_dims = ctx->GetInputDim("Bias");
       ctx->SetOutputDim(bias_grad_name, bias_dims);
     }
   }
+
+ protected:
+  framework::OpKernelType GetKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
+        ctx.device_context());
+  }
 };
 
 }  // namespace operators
diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index ce1717c9b016a..3017bccdca4bb 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -14,12 +14,11 @@
 
 #pragma once
 
+#include <math.h>
 #include <random>
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/memory/memcpy.h"
 #include "unsupported/Eigen/CXX11/Tensor"
-
 namespace paddle {
 namespace operators {
 
@@ -32,9 +31,12 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 template <typename Place, typename T>
 void PrepareSamples(const framework::ExecutionContext& context) {
   auto label = context.Input<Tensor>("Label");
-  const T* label_data = label->data<T>();
+  const int64_t* label_data = label->data<int64_t>();
   auto label_dims = label->dims();
   int num_classes = context.Attr<int>("num_classes");
+  // for unit tests
+  std::vector<int> sampled_labels =
+      context.Attr<std::vector<int>>("sampled_labels");
   // random machine
   std::random_device rd;
   std::mt19937 rng(rd());
@@ -42,19 +44,24 @@ void PrepareSamples(const framework::ExecutionContext& context) {
 
   auto sample_labels = context.Output<Tensor>("SampleLabels");
   auto sample_labels_dims = sample_labels->dims();
-  int* sample_labels_data =
-      sample_labels->mutable_data<int>(context.GetPlace());
+  int64_t* sample_labels_data =
+      sample_labels->mutable_data<int64_t>(context.GetPlace());
 
   int num_label = label_dims.size() == 2 ? label_dims[1] : 1;
+  int index = 0;
   for (size_t i = 0; i < label_dims[0]; ++i) {
     int j = 0;
     for (; j < num_label; ++j) {
-      sample_labels_data[sample_labels_dims[1] * i + j] =
-          label_data[i * num_label + j];
+      sample_labels_data[index++] = label_data[i * num_label + j];
     }
-    for (; j < sample_labels_dims[1]; ++j) {
-      int id = rand(rng);
-      sample_labels_data[sample_labels_dims[1] * i + j] = id;
+    if (sampled_labels.size() > 0) {
+      for (auto label : sampled_labels) {
+        sample_labels_data[index++] = label;
+      }
+    } else {
+      for (; j < sample_labels_dims[1]; ++j) {
+        sample_labels_data[index++] = rand(rng);
+      }
     }
   }
 }
@@ -65,7 +72,7 @@ class NCEKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     PrepareSamples<Place, T>(context);
     auto sample_labels = context.Output<Tensor>("SampleLabels");
-    const int* sample_labels_data = sample_labels->data<int>();
+    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
     auto sample_out = context.Output<Tensor>("SampleLogits");
     T* sample_out_data = sample_out->mutable_data<T>(context.GetPlace());
     auto label = context.Input<Tensor>("Label");
@@ -74,7 +81,7 @@ class NCEKernel : public framework::OpKernel<T> {
     if (sample_weight != nullptr) {
       sample_weight_data = sample_weight->data<T>();
     }
-    auto out = context.Output<Tensor>("Out");
+    auto out = context.Output<Tensor>("Cost");
     T* out_data = out->mutable_data<T>(context.GetPlace());
     int num_smalped_classes = context.Attr<int>("num_sampled_classes");
     int num_classes = context.Attr<int>("num_classes");
@@ -83,9 +90,8 @@ class NCEKernel : public framework::OpKernel<T> {
       num_true_class = label->dims()[1];
     }
     T b = 1. / num_classes * num_smalped_classes;
-
     // forward bias
-    auto bias = context.Input<Tensor>("B");
+    auto bias = context.Input<Tensor>("Bias");
     if (bias != nullptr) {
       const T* bias_data = bias->data<T>();
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
@@ -96,27 +102,23 @@ class NCEKernel : public framework::OpKernel<T> {
         sample_out_data[i] = 0;
       }
     }
-
     // forward mul
-    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("X")));
-    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("W")));
+    auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
+    auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
     for (size_t i = 0; i < sample_labels->numel(); ++i) {
-      // sample_out_data[i] += (input_mat.chip((int)(i /
-      // sample_labels->dims()[1]), 0) * weight_mat.chip(sample_labels_data[i],
-      // 0)).sum();
       Eigen::Tensor<float, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
           (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
       sample_out_data[i] += result(0);
       // activation_->forward
-      sample_out_data[i] = (1 / 1 + (sample_out_data[i]));
+      sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
     }
-
     // forward cost
     for (size_t i = 0; i < sample_labels->dims()[0]; ++i) {
       size_t j = 0;
-      T w = sample_weight == nullptr ? 1 : sample_weight_data[i];
+      out_data[i] = 0;
+      T w = sample_weight == nullptr ? 1. : sample_weight_data[i];
       // for true classes
       for (; j < num_true_class; ++j) {
         T o = sample_out_data[i * sample_out->dims()[1] + j];
@@ -137,11 +139,13 @@ template <typename Place, typename T>
 class NCEGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    auto d_out = context.Input<Tensor>(framework::GradVarName("Cost"));
+    const T* d_out_data = d_out->data<T>();
     auto label = context.Input<Tensor>("Label");
     auto sample_out = context.Input<Tensor>("SampleLogits");
     const T* sample_out_data = sample_out->data<T>();
     auto sample_labels = context.Input<Tensor>("SampleLabels");
-    const int* sample_labels_data = sample_labels->data<int>();
+    const int64_t* sample_labels_data = sample_labels->data<int64_t>();
     auto sample_weight = context.Input<Tensor>("SampleWeight");
     const T* sample_weight_data = nullptr;
     if (sample_weight != nullptr) {
@@ -154,11 +158,9 @@ class NCEGradKernel : public framework::OpKernel<T> {
       num_true_class = label->dims()[1];
     }
     T b = 1. / num_classes * num_smalped_classes;
-
     Tensor sample_grad;  // tmp tensor
     T* sample_grad_data =
         sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
-
     // backward cost
     for (size_t i = 0; i < sample_labels->numel(); ++i) {
       T o = sample_out_data[i];
@@ -166,15 +168,12 @@ class NCEGradKernel : public framework::OpKernel<T> {
                 ? 1
                 : sample_weight_data[i / sample_labels->dims()[1]];
       sample_grad_data[i] = (i % sample_labels->dims()[1]) < num_true_class
-                                ? -w * b / (o * (o + b))
-                                : w / (o + b);
-      // sigmoid->backward
-      sample_grad_data[i] =
-          (o > 0) ? sample_grad_data[i] : ((o < 0) ? -sample_grad_data[i] : 0);
+                                ? w * (b / (o + b)) * (o - 1)
+                                : w * (o * (1 - o) / (o + b));
+      sample_grad_data[i] *= d_out_data[i / sample_labels->dims()[1]];
     }
-
     // get d_bias
-    auto d_bias = context.Output<Tensor>(framework::GradVarName("B"));
+    auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
     if (d_bias != nullptr) {
       T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
@@ -182,22 +181,23 @@ class NCEGradKernel : public framework::OpKernel<T> {
       }
     }
     // get d_w
-    auto d_w = context.Output<Tensor>(framework::GradVarName("W"));
+    auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
     if (d_w != nullptr) {
+      d_w->mutable_data<T>(context.GetPlace());
       auto d_w_matrix = EigenMatrix<T>::From(*d_w);
-      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("X")));
+      auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
-        d_w_matrix.chip(sample_labels_data[i], 0) =
+        d_w_matrix.chip(sample_labels_data[i], 0) +=
             x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) *
             sample_grad_data[i];
       }
     }
-
     // get d_x
-    auto d_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto d_x = context.Output<Tensor>(framework::GradVarName("Input"));
    if (d_x != nullptr) {
+      d_x->mutable_data<T>(context.GetPlace());
       auto d_x_matrix = EigenMatrix<T>::From(*d_x);
-      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("W")));
+      auto w_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
         d_x_matrix.chip((int)(i / sample_labels->dims()[1]), 0) +=
             w_matrix.chip(sample_labels_data[i], 0) * sample_grad_data[i];
@@ -205,6 +205,5 @@ class NCEGradKernel : public framework::OpKernel<T> {
       }
     }
   }
 };
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_nce.py b/python/paddle/v2/framework/tests/test_nce.py
new file mode 100644
index 0000000000000..8b1e7a6bb535f
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_nce.py
@@ -0,0 +1,96 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def nce(input, weight, bias, sample_weight, labels, num_classes,
+        num_sample_class):
+    samples = []
+    sample_labels = []
+    batch_size = input.shape[0]
+    num_true_class = labels.shape[1]
+    for i in range(batch_size):
+        w = 1 if sample_weight is None else sample_weight[i]
+        for label in labels[i]:
+            samples.append((i, label, True, w))
+            sample_labels.append(label)
+        for num in range(num_sample_class):
+            samples.append((i, num, False, w))
+            sample_labels.append(num)
+    # forward bias
+    sampleOut = np.zeros(len(samples)).astype(np.float32)
+    if bias is not None:
+        for i in range(len(samples)):
+            sampleOut[i] = bias[samples[i][1]]
+    # forward weight
+    for i in range(len(samples)):
+        sampleOut[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
+
+    # forward activation
+    sampleOut = 1.0 / (1.0 + np.exp(-sampleOut))
+    # forward cost
+    out = np.zeros(batch_size).astype(np.float32)
+    b = 1.0 / num_classes * num_sample_class
+    for i in range(len(samples)):
+        o = sampleOut[i]
+        cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
+        out[samples[i][0]] += cost * samples[i][3]
+    return (out, np.array(sampleOut).reshape(batch_size,
+                                             num_sample_class + num_true_class),
+            np.array(sample_labels).reshape(batch_size,
+                                            num_sample_class + num_true_class))
+
+
+class TestNCE(OpTest):
+    def generate_data(self, dim, batch_size, num_classes, num_true_class,
+                      num_sampled_classes):
+        input = np.random.randn(batch_size, dim).astype(np.float32)
+        weight = np.random.randn(num_classes, dim).astype(np.float32)
+        bias = np.random.randn(num_classes).astype(np.float32)
+        sample_weight = np.random.randn(batch_size).astype(np.float32)
+        labels = np.random.randint(0, num_classes, (batch_size, num_true_class))
+        self.attrs = {
+            'num_classes': num_classes,
+            'num_sampled_classes': num_sampled_classes,
+            'sampled_labels': range(num_sampled_classes)
+        }
+        self.inputs = {
+            'X': input,
+            'Label': labels,
+            'W': weight,
+            'B': bias,
+            'SampleWeight': sample_weight
+        }
+
+    def set_data(self):
+        self.generate_data(5, 5, 4, 1, 2)
+
+    def compute(self):
+        out = nce(self.inputs['X'], self.inputs['W'], self.inputs['B'],
+                  self.inputs['SampleWeight'], self.inputs['Label'],
+                  self.attrs['num_classes'], self.attrs['num_sampled_classes'])
+        self.outputs = {
+            'Out': out[0],
+            'SampleLogits': out[1],
+            'SampleLabels': out[2]
+        }
+
+    def setUp(self):
+        self.op_type = 'nce'
+        self.set_data()
+        self.compute()
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X", "W", "B"], "Out", max_relative_error=0.02)
+
+
+class TestNCECase1(TestNCE):
+    def set_data(self):
+        self.generate_data(10, 20, 10, 2, 5)
+
+
+if __name__ == '__main__':
+    unittest.main()

From e60eb1eacdac476b52cbd029660249fe709b7196 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Thu, 16 Nov 2017 00:45:36 +0800
Subject: [PATCH 3/7] Fix unit test

---
 .../v2/{framework => fluid}/tests/test_nce.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)
 rename python/paddle/v2/{framework => fluid}/tests/test_nce.py (86%)

diff --git a/python/paddle/v2/framework/tests/test_nce.py b/python/paddle/v2/fluid/tests/test_nce.py
similarity index 86%
rename from python/paddle/v2/framework/tests/test_nce.py
rename to python/paddle/v2/fluid/tests/test_nce.py
index 8b1e7a6bb535f..82978f2d230a3 100644
--- a/python/paddle/v2/framework/tests/test_nce.py
+++ b/python/paddle/v2/fluid/tests/test_nce.py
@@ -55,10 +55,10 @@ def generate_data(self, dim, batch_size, num_classes, num_true_class,
             'sampled_labels': range(num_sampled_classes)
         }
         self.inputs = {
-            'X': input,
+            'Input': input,
             'Label': labels,
-            'W': weight,
-            'B': bias,
+            'Weight': weight,
+            'Bias': bias,
             'SampleWeight': sample_weight
         }
 
@@ -66,11 +66,12 @@ def set_data(self):
         self.generate_data(5, 5, 4, 1, 2)
 
     def compute(self):
-        out = nce(self.inputs['X'], self.inputs['W'], self.inputs['B'],
-                  self.inputs['SampleWeight'], self.inputs['Label'],
-                  self.attrs['num_classes'], self.attrs['num_sampled_classes'])
+        out = nce(self.inputs['Input'], self.inputs['Weight'],
+                  self.inputs['Bias'], self.inputs['SampleWeight'],
+                  self.inputs['Label'], self.attrs['num_classes'],
+                  self.attrs['num_sampled_classes'])
         self.outputs = {
-            'Out': out[0],
+            'Cost': out[0],
             'SampleLogits': out[1],
             'SampleLabels': out[2]
         }
@@ -84,7 +85,8 @@ def test_check_output(self):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(["X", "W", "B"], "Out", max_relative_error=0.02)
+        self.check_grad(
+            ["Input", "Weight", "Bias"], "Cost", max_relative_error=0.02)
 
 
 class TestNCECase1(TestNCE):

From ea7359c60bdf6062b1296f471f50cbeaf8da243e Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 28 Nov 2017 12:47:17 +0800
Subject: [PATCH 4/7] Refine code and comments

1. Remove checking for num_neg_samples.
2. Fix dims of Output(Cost) and Input(Bias).
3. Rename num_sampled_classes to num_neg_samples.
4. Add a TODO for more distribution samplers.
5. Initialize the bias gradient to zero.
6. Refine comments.
7. Register a kernel for type double.
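
The cost gradient is now taken with respect to the pre-sigmoid logit in
closed form, folding the sigmoid derivative o * (1 - o) into the
derivative of the NCE cost instead of running a separate sigmoid
backward step. A sketch of the algebra behind the new expressions, with
o = sigmoid(x . w) and b = num_neg_samples / num_total_classes
(`logit_grad` is a hypothetical name used only in this note):

    # True class:    cost = -log(o / (o + b))
    #                d cost / d o     = -b / (o * (o + b))
    #                d cost / d logit = (b / (o + b)) * (o - 1)
    # Sampled class: cost = -log(b / (o + b))
    #                d cost / d o     = 1 / (o + b)
    #                d cost / d logit = o * (1 - o) / (o + b)
    def logit_grad(o, b, is_true_class, w=1.0):
        if is_true_class:
            return w * (b / (o + b)) * (o - 1.0)
        return w * o * (1.0 - o) / (o + b)
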
---
 paddle/operators/nce_op.cc                | 95 +++++++++++++---------
 paddle/operators/nce_op.h                 | 15 ++--
 python/paddle/v2/fluid/tests/test_nce.py | 14 ++--
 3 files changed, 77 insertions(+), 47 deletions(-)

diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index c365d5d922286..bb9346b134c88 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -1,16 +1,16 @@
 /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
 
-   http://www.apache.org/licenses/LICENSE-2.0
+    http://www.apache.org/licenses/LICENSE-2.0
 
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/operators/nce_op.h"
@@ -39,25 +39,25 @@ class NCEOp : public framework::OperatorWithKernel {
       PADDLE_ENFORCE_EQ(ctx->GetInputDim("Weight")[0],
                         ctx->GetInputDim("Bias")[0]);
     }
-    auto num_sampled_classes = ctx->Attrs().Get<int>("num_sampled_classes");
-    auto num_classes = ctx->Attrs().Get<int>("num_classes");
+    auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
+    auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
     std::vector<int> sampled_labels =
         ctx->Attrs().Get<std::vector<int>>("sampled_labels");
-    PADDLE_ENFORCE_EQ(num_classes, ctx->GetInputDim("Weight")[0]);
-    PADDLE_ENFORCE_LT(num_sampled_classes, num_classes);
+    PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
     if (sampled_labels.size() > 0) {
       PADDLE_ENFORCE_EQ(sampled_labels.size(),
-                        static_cast<size_t>(num_sampled_classes));
+                        static_cast<size_t>(num_neg_samples));
     }
     // set dims of output(Out)
     std::vector<int64_t> out_dims;
     out_dims.push_back(x_dims[0]);
+    out_dims.push_back(1);
     ctx->SetOutputDim("Cost", framework::make_ddim(out_dims));
 
     // set dims of output(SampleOut)
     std::vector<int64_t> sample_out_dims;
     sample_out_dims.push_back(x_dims[0]);
-    sample_out_dims.push_back(num_sampled_classes + num_true_classes);
+    sample_out_dims.push_back(num_neg_samples + num_true_classes);
     ctx->SetOutputDim("SampleLogits", framework::make_ddim(sample_out_dims));
     ctx->SetOutputDim("SampleLabels", framework::make_ddim(sample_out_dims));
   }
@@ -76,34 +76,59 @@ class NCEOpMaker : public framework::OpProtoAndCheckerMaker {
   NCEOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("Input", "(Tensor) A tensor of shape [batch_size, dim].");
-    AddInput("Label",
-             "(Tensor) A tensor of shape [batch_size, num_true_class]. "
-             "'num_true_class' is the number of target class in each sample.");
+    AddInput(
+        "Label",
+        "(Tensor) A tensor of shape [batch_size, num_true_class]. "
+        "'num_true_class' is the number of target classes in each sample. "
+        "The number of target classes per sample should be the same. "
+        "If you have a variable number of target classes, "
+        "you can pad them out to a constant number by either repeating them"
+        " or by padding with an otherwise unused class.");
     AddInput("Weight",
              "(Tensor) A tensor of shape [num_class, dim]. 'num_class' is the "
             "total number of class.");
-    AddInput("Bias",
-             "(Tensor) A tensor of shape [num_class]. 'num_class' is the total "
-             "number of class. It is a dispensable input.")
+    AddInput(
+        "Bias",
+        "(Tensor) A tensor of shape [num_class, 1]. 'num_class' is the total "
+        "number of class. It is a dispensable input.")
         .AsDispensable();
     AddInput("SampleWeight",
-             "(Tensor) A tensor of shape [batch_size] storing a weight for "
+             "(Tensor) A tensor of shape [batch_size, 1] storing a weight for "
              "each sample. And it is a dispensable input. The default value of "
             "sample is 1.")
         .AsDispensable();
     AddOutput("Cost",
-              "(Tensor) A tensor of shape [batch_size]. Cost of samples.");
-    AddOutput("SampleLogits", "An intermediate tensor.").AsIntermediate();
-    AddOutput("SampleLabels", "An intermediate tensor.").AsIntermediate();
-    AddAttr<int>("num_classes", "Total number of classes.");
-    AddAttr<int>("num_sampled_classes", "The number of negative classes.")
+              "(Tensor) A tensor of shape [batch_size, 1]. Cost of samples.");
+    AddOutput("SampleLogits",
+              "An intermediate tensor of shape [batch_size, num_neg_samples + "
+              "num_pos_samples]. This tensor is the output of the forward "
+              "kernel and is used in the backward kernel to compute grads. "
+              "Given X as the dot product of the input tensor and the sampled "
+              "labels' weights, 'SampleLogits' is sigmoid(X).")
+        .AsIntermediate();
+    AddOutput("SampleLabels",
+              "An intermediate tensor of shape [batch_size, num_neg_samples + "
+              "num_pos_samples]. This tensor is the output of the forward "
+              "kernel and is used in the backward kernel to compute grads.")
+        .AsIntermediate();
+    AddAttr<int>("num_total_classes",
+                 "Total number of classes in all samples.");
+    AddAttr<int>("num_neg_samples",
+                 "The number of negative classes. The default value is 10.")
         .SetDefault(10);
-    AddAttr<std::vector<int>>("sampled_labels", "");
+    AddAttr<std::vector<int>>("custom_neg_classes",
+                              "This attribute is only used in unit tests. "
+                              "Classes in this list will be used as negative "
+                              "classes for every sample. Under normal "
+                              "conditions, users should avoid setting this "
+                              "attribute.");
     AddComment(R"DOC(
-Computes and returns the noise-contrastive estimation training loss.
+Compute and return the noise-contrastive estimation training loss.
 See [Noise-contrastive estimation: A new estimation principle for unnormalized statistical models](http://www.jmlr.org/proceedings/papers/v9/gutmann10a/gutmann10a.pdf).
-By default this uses a uniform distribution for sampling.
-The number of target classes per example should be same. If you have a variable number of target classes, you can pad them out to a constant number by either repeating them or by padding with an otherwise unused class.
+By default this operator uses a uniform distribution for sampling.
 )DOC");
   }
 };
@@ -119,7 +144,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("SampleLogits"));
     PADDLE_ENFORCE(ctx->HasInput("SampleLabels"));
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cost")),
-                   "The input(Out@GRAD) should not be null");
+                   "The input(Out@GRAD) should not be null.");
 
     auto x_dims = ctx->GetInputDim("Input");
     auto x_grad_name = framework::GradVarName("Input");
@@ -154,6 +179,8 @@ class NCEOpGrad : public framework::OperatorWithKernel {
 
 namespace ops = paddle::operators;
 REGISTER_OP(nce, ops::NCEOp, ops::NCEOpMaker, nce_grad, ops::NCEOpGrad);
-REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(nce, ops::NCEKernel<paddle::platform::CPUPlace, float>,
+                       ops::NCEKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(nce_grad,
-                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>);
+                       ops::NCEGradKernel<paddle::platform::CPUPlace, float>,
+                       ops::NCEGradKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index 3017bccdca4bb..c41393d26023f 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -22,7 +22,7 @@
 namespace paddle {
 namespace operators {
 
-using Tensor = framework::Tensor;
+using framework::Tensor;
 
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
@@ -35,8 +35,8 @@ void PrepareSamples(const framework::ExecutionContext& context) {
   auto label_dims = label->dims();
   int num_classes = context.Attr<int>("num_classes");
   // for unit tests
-  std::vector<int> sampled_labels =
-      context.Attr<std::vector<int>>("sampled_labels");
+  std::vector<int> custom_neg_classes =
+      context.Attr<std::vector<int>>("custom_neg_classes");
   // random machine
   std::random_device rd;
   std::mt19937 rng(rd());
@@ -54,12 +54,13 @@ void PrepareSamples(const framework::ExecutionContext& context) {
     for (; j < num_label; ++j) {
       sample_labels_data[index++] = label_data[i * num_label + j];
     }
-    if (sampled_labels.size() > 0) {
-      for (auto label : sampled_labels) {
+    if (custom_neg_classes.size() > 0) {
+      for (auto label : custom_neg_classes) {
         sample_labels_data[index++] = label;
       }
     } else {
       for (; j < sample_labels_dims[1]; ++j) {
+        // TODO: support more distribution sampling
         sample_labels_data[index++] = rand(rng);
       }
     }
@@ -176,6 +177,7 @@ class NCEGradKernel : public framework::OpKernel<T> {
     auto d_bias = context.Output<Tensor>(framework::GradVarName("Bias"));
     if (d_bias != nullptr) {
       T* d_bias_data = d_bias->mutable_data<T>(context.GetPlace());
+      std::fill(d_bias_data, d_bias_data + d_bias->numel(), 0.0);
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
         d_bias_data[sample_labels_data[i]] += sample_grad_data[i];
       }
@@ -183,7 +185,8 @@ class NCEGradKernel : public framework::OpKernel<T> {
     // get d_w
     auto d_w = context.Output<Tensor>(framework::GradVarName("Weight"));
     if (d_w != nullptr) {
-      d_w->mutable_data<T>(context.GetPlace());
+      auto d_w_data = d_w->mutable_data<T>(context.GetPlace());
+      std::fill(d_w_data, d_w_data + d_w->numel(), 0.0);
       auto d_w_matrix = EigenMatrix<T>::From(*d_w);
       auto x_matrix = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
       for (size_t i = 0; i < sample_labels->numel(); ++i) {
diff --git a/python/paddle/v2/fluid/tests/test_nce.py b/python/paddle/v2/fluid/tests/test_nce.py
index 82978f2d230a3..6cbf468e0a983 100644
--- a/python/paddle/v2/fluid/tests/test_nce.py
+++ b/python/paddle/v2/fluid/tests/test_nce.py
@@ -18,25 +18,25 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
         samples.append((i, num, False, w))
         sample_labels.append(num)
     # forward bias
-    sampleOut = np.zeros(len(samples)).astype(np.float32)
+    sample_out = np.zeros(len(samples)).astype(np.float32)
     if bias is not None:
         for i in range(len(samples)):
-            sampleOut[i] = bias[samples[i][1]]
+            sample_out[i] = bias[samples[i][1]]
     # forward weight
     for i in range(len(samples)):
-        sampleOut[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
+        sample_out[i] += np.dot(input[samples[i][0]], weight[samples[i][1]])
 
     # forward activation
-    sampleOut = 1.0 / (1.0 + np.exp(-sampleOut))
+    sample_out = 1.0 / (1.0 + np.exp(-sample_out))
     # forward cost
     out = np.zeros(batch_size).astype(np.float32)
     b = 1.0 / num_classes * num_sample_class
     for i in range(len(samples)):
-        o = sampleOut[i]
+        o = sample_out[i]
         cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
         out[samples[i][0]] += cost * samples[i][3]
-    return (out, np.array(sampleOut).reshape(batch_size,
-                                             num_sample_class + num_true_class),
+    return (out, np.array(sample_out).reshape(
+        batch_size, num_sample_class + num_true_class),
             np.array(sample_labels).reshape(batch_size,
                                             num_sample_class + num_true_class))

From ab9d59c5396002a1c0695075164da5109c530150 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 28 Nov 2017 14:45:11 +0800
Subject: [PATCH 5/7] Fix double type error while using the Eigen API

---
 paddle/operators/nce_op.h | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index c41393d26023f..7a910703293a6 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -22,7 +22,7 @@
 namespace paddle {
 namespace operators {
 
-using framework::Tensor;
+using Tensor = framework::Tensor;
 
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
@@ -107,12 +107,11 @@ class NCEKernel : public framework::OpKernel<T> {
     auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
     auto weight_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Weight")));
     for (size_t i = 0; i < sample_labels->numel(); ++i) {
-      Eigen::Tensor<float, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
+      Eigen::Tensor<T, 0, Eigen::RowMajor, Eigen::DenseIndex> result =
          (input_mat.chip((int)(i / sample_labels->dims()[1]), 0) *
           weight_mat.chip(sample_labels_data[i], 0))
              .sum();
       sample_out_data[i] += result(0);
-      // activation_->forward
       sample_out_data[i] = (1. / (1. + exp(-sample_out_data[i])));
     }

From 76a65a83a015a38bd8f6654b4dc27d6040bcd5d8 Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Tue, 28 Nov 2017 15:54:54 +0800
Subject: [PATCH 6/7] Fix comment style

---
 paddle/operators/nce_op.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index 7a910703293a6..8df20f432dade 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -60,7 +60,7 @@ void PrepareSamples(const framework::ExecutionContext& context) {
       }
     } else {
       for (; j < sample_labels_dims[1]; ++j) {
-        // TODO: support more distribution sampling
+        // TODO(wanghaoshuang): support more distribution sampling
         sample_labels_data[index++] = rand(rng);
       }
     }

From 29262ab24d8675d5b50fe21dda59f4102db1bb7b Mon Sep 17 00:00:00 2001
From: wanghaoshuang
Date: Wed, 29 Nov 2017 11:56:29 +0800
Subject: [PATCH 7/7] Fix unit test.
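
Besides renaming the attributes (num_total_classes, num_neg_samples,
custom_neg_classes) on the Python side, the NumPy reference now returns
Cost with the same dims as the op, i.e. [batch_size, 1] instead of
[batch_size]. A quick illustration of the reshape (hypothetical values,
not taken from the test):

    import numpy as np

    out = np.zeros(5)            # per-sample costs, shape (5,)
    cost = out[:, np.newaxis]    # shape (5, 1), matches Output("Cost")
    assert cost.shape == (5, 1)
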
---
 paddle/operators/nce_op.cc               |  8 ++++----
 paddle/operators/nce_op.h                | 16 ++++++++--------
 python/paddle/v2/fluid/tests/test_nce.py | 14 +++++++-------
 3 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/paddle/operators/nce_op.cc b/paddle/operators/nce_op.cc
index bb9346b134c88..952da10434df0 100644
--- a/paddle/operators/nce_op.cc
+++ b/paddle/operators/nce_op.cc
@@ -41,11 +41,11 @@ class NCEOp : public framework::OperatorWithKernel {
     }
     auto num_neg_samples = ctx->Attrs().Get<int>("num_neg_samples");
     auto num_total_classes = ctx->Attrs().Get<int>("num_total_classes");
-    std::vector<int> sampled_labels =
-        ctx->Attrs().Get<std::vector<int>>("sampled_labels");
+    std::vector<int> custom_neg_classes =
+        ctx->Attrs().Get<std::vector<int>>("custom_neg_classes");
     PADDLE_ENFORCE_EQ(num_total_classes, ctx->GetInputDim("Weight")[0]);
-    if (sampled_labels.size() > 0) {
-      PADDLE_ENFORCE_EQ(sampled_labels.size(),
+    if (custom_neg_classes.size() > 0) {
+      PADDLE_ENFORCE_EQ(custom_neg_classes.size(),
                         static_cast<size_t>(num_neg_samples));
     }
     // set dims of output(Out)
diff --git a/paddle/operators/nce_op.h b/paddle/operators/nce_op.h
index 8df20f432dade..ea92a797fe18e 100644
--- a/paddle/operators/nce_op.h
+++ b/paddle/operators/nce_op.h
@@ -33,14 +33,14 @@ void PrepareSamples(const framework::ExecutionContext& context) {
   auto label = context.Input<Tensor>("Label");
   const int64_t* label_data = label->data<int64_t>();
   auto label_dims = label->dims();
-  int num_classes = context.Attr<int>("num_classes");
+  int num_total_classes = context.Attr<int>("num_total_classes");
   // for unit tests
   std::vector<int> custom_neg_classes =
       context.Attr<std::vector<int>>("custom_neg_classes");
   // random machine
   std::random_device rd;
   std::mt19937 rng(rd());
-  std::uniform_int_distribution<int> rand(0, num_classes - 1);
+  std::uniform_int_distribution<int> rand(0, num_total_classes - 1);
 
   auto sample_labels = context.Output<Tensor>("SampleLabels");
   auto sample_labels_dims = sample_labels->dims();
@@ -84,13 +84,13 @@ class NCEKernel : public framework::OpKernel<T> {
     }
     auto out = context.Output<Tensor>("Cost");
     T* out_data = out->mutable_data<T>(context.GetPlace());
-    int num_smalped_classes = context.Attr<int>("num_sampled_classes");
-    int num_classes = context.Attr<int>("num_classes");
+    int num_neg_samples = context.Attr<int>("num_neg_samples");
+    int num_total_classes = context.Attr<int>("num_total_classes");
     int num_true_class = 1;
     if (label != nullptr) {
       num_true_class = label->dims()[1];
     }
-    T b = 1. / num_classes * num_smalped_classes;
+    T b = 1. / num_total_classes * num_neg_samples;
     // forward bias
     auto bias = context.Input<Tensor>("Bias");
     if (bias != nullptr) {
@@ -151,13 +151,13 @@ class NCEGradKernel : public framework::OpKernel<T> {
     if (sample_weight != nullptr) {
       sample_weight_data = sample_weight->data<T>();
     }
-    int num_smalped_classes = context.Attr<int>("num_sampled_classes");
-    int num_classes = context.Attr<int>("num_classes");
+    int num_neg_samples = context.Attr<int>("num_neg_samples");
+    int num_total_classes = context.Attr<int>("num_total_classes");
     int num_true_class = 1;
     if (label != nullptr) {
       num_true_class = label->dims()[1];
     }
-    T b = 1. / num_classes * num_smalped_classes;
+    T b = 1. / num_total_classes * num_neg_samples;
     Tensor sample_grad;  // tmp tensor
     T* sample_grad_data =
         sample_grad.mutable_data<T>(sample_labels->dims(), context.GetPlace());
diff --git a/python/paddle/v2/fluid/tests/test_nce.py b/python/paddle/v2/fluid/tests/test_nce.py
index 6cbf468e0a983..8aeba69769525 100644
--- a/python/paddle/v2/fluid/tests/test_nce.py
+++ b/python/paddle/v2/fluid/tests/test_nce.py
@@ -35,7 +35,7 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
         o = sample_out[i]
         cost = -np.log(o / (o + b)) if samples[i][2] else -np.log(b / (o + b))
         out[samples[i][0]] += cost * samples[i][3]
-    return (out, np.array(sample_out).reshape(
+    return (out[:, np.newaxis], np.array(sample_out).reshape(
         batch_size, num_sample_class + num_true_class),
             np.array(sample_labels).reshape(batch_size,
                                             num_sample_class + num_true_class))
@@ -43,16 +43,16 @@ def nce(input, weight, bias, sample_weight, labels, num_classes,
 
 class TestNCE(OpTest):
     def generate_data(self, dim, batch_size, num_classes, num_true_class,
-                      num_sampled_classes):
+                      num_neg_samples):
         input = np.random.randn(batch_size, dim).astype(np.float32)
         weight = np.random.randn(num_classes, dim).astype(np.float32)
         bias = np.random.randn(num_classes).astype(np.float32)
         sample_weight = np.random.randn(batch_size).astype(np.float32)
         labels = np.random.randint(0, num_classes, (batch_size, num_true_class))
         self.attrs = {
-            'num_classes': num_classes,
-            'num_sampled_classes': num_sampled_classes,
-            'sampled_labels': range(num_sampled_classes)
+            'num_total_classes': num_classes,
+            'num_neg_samples': num_neg_samples,
+            'custom_neg_classes': range(num_neg_samples)
         }
         self.inputs = {
             'Input': input,
@@ -68,8 +68,8 @@ def set_data(self):
     def compute(self):
         out = nce(self.inputs['Input'], self.inputs['Weight'],
                   self.inputs['Bias'], self.inputs['SampleWeight'],
-                  self.inputs['Label'], self.attrs['num_classes'],
-                  self.attrs['num_sampled_classes'])
+                  self.inputs['Label'], self.attrs['num_total_classes'],
+                  self.attrs['num_neg_samples'])
         self.outputs = {
             'Cost': out[0],
             'SampleLogits': out[1],