Skip to content

Commit

Permalink
[NPU] add multinomial op (#42613)
Browse files Browse the repository at this point in the history
* [NPU] add multinomial op

* fix place

* deal with cann version

* fix for old operator

* change another way
  • Loading branch information
Aganlengzi authored May 17, 2022
1 parent 6b58de9 commit fd14069
Show file tree
Hide file tree
Showing 4 changed files with 302 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cmake/external/ascend.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@ endif()
if (WITH_ASCEND_CL)
macro(find_ascend_toolkit_version ascend_toolkit_version_info)
file(READ ${ascend_toolkit_version_info} ASCEND_TOOLKIT_VERSION_CONTENTS)
string(REGEX MATCH "version=([0-9]+\.[0-9]+\.(RC)?[0-9]+\.[a-z]*[0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}")
string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.(RC)?[0-9]+\.[a-z]*[0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}")
string(REGEX MATCH "version=([0-9]+\.[0-9]+\.(RC)?[0-9][.a-z0-9]*)" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION_CONTENTS}")
string(REGEX REPLACE "version=([0-9]+\.[0-9]+\.(RC)?[0-9][.a-z0-9]*)" "\\1" ASCEND_TOOLKIT_VERSION "${ASCEND_TOOLKIT_VERSION}")
string(REGEX REPLACE "[A-Z]|[a-z|\.]" "" CANN_VERSION ${ASCEND_TOOLKIT_VERSION})
STRING(SUBSTRING "${CANN_VERSION}000" 0 6 CANN_VERSION)
add_definitions("-DCANN_VERSION_CODE=${CANN_VERSION}")
if(NOT ASCEND_TOOLKIT_VERSION)
set(ASCEND_TOOLKIT_VERSION "???")
Expand Down
6 changes: 6 additions & 0 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ function(op_library TARGET)
elseif (WITH_XPU_KP AND ${xpu_kp_cc_srcs_len} GREATER 0)
xpu_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} ${xpu_cc_srcs} ${xpu_kp_cc_srcs} DEPS ${op_library_DEPS} ${op_common_deps})
else()
# CANN version control for NPU operators, applied before the operator library
# is built and its kernels are registered: multinomial_op_npu.cc wraps its whole
# body in `#if (CANN_VERSION_CODE >= 504000)`, so for older toolkits the file is
# dropped from the NPU source list here.
# CANN_VERSION is the 6-digit code computed in cmake/external/ascend.cmake.
if (WITH_ASCEND_CL)
if (CANN_VERSION LESS 504000)
list(REMOVE_ITEM npu_cc_srcs "multinomial_op_npu.cc")
endif()
endif()
# Unity Build relies on global option `WITH_UNITY_BUILD` and local option `UNITY`.
if(WITH_UNITY_BUILD AND op_library_UNITY)
# Combine the cc source files.
Expand Down
58 changes: 58 additions & 0 deletions paddle/fluid/operators/multinomial_op_npu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// TODO(Aganlengzi): delete this macro control and remove the matching
// REMOVE_ITEM in cmake/operators.cmake once Paddle no longer supports CANN
// versions whose CANN_VERSION_CODE is below 504000.
#if (CANN_VERSION_CODE >= 504000)

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

// NPU kernel for the `multinomial` op. It forwards input "X" together with the
// `num_samples` and `replacement` attributes to the CANN operator
// "MultinomialWithReplacementD" and stores the result in the int64 output
// "Out".
template <typename DeviceContext, typename T>
class NPUMultinomialKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* probs = ctx.Input<framework::Tensor>("X");
    auto* samples = ctx.Output<framework::Tensor>("Out");

    // The attribute is read as `int` and widened to int64_t before it is
    // handed to the runner.
    const int64_t sample_count = ctx.Attr<int>("num_samples");
    const bool with_replacement = ctx.Attr<bool>("replacement");

    auto npu_stream =
        ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();

    // Allocate the int64 output buffer on the place this kernel runs on.
    samples->mutable_data<int64_t>(ctx.GetPlace());

    NpuOpRunner multinomial_runner(
        "MultinomialWithReplacementD", {*probs}, {*samples},
        {{"num_samples", sample_count}, {"replacement", with_replacement}});
    multinomial_runner.Run(npu_stream);
  }
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
// Register the NPU `multinomial` kernel, instantiated for float and double.
REGISTER_OP_NPU_KERNEL(
    multinomial, ops::NPUMultinomialKernel<plat::NPUDeviceContext, float>,
    ops::NPUMultinomialKernel<plat::NPUDeviceContext, double>)
#endif
235 changes: 235 additions & 0 deletions python/paddle/fluid/tests/unittests/npu/test_multinomial_op_npu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
import sys
sys.path.append("..")
from op_test import OpTest
import numpy as np
import os

# Start the module in static-graph mode; the dygraph tests below switch modes
# explicitly and switch back to static mode afterwards.
paddle.enable_static()


def sample_output_one_dimension(out, dim):
    """Return the empirical frequency of every category index found in `out`.

    Args:
        out: array-like of integer category indices (np.unique flattens it).
        dim: number of categories, i.e. the length of the returned vector.

    Returns:
        A float32 array of length `dim`, normalized to sum to 1. Categories
        that never occur in `out` get frequency 0.
    """
    categories, counts = np.unique(out, return_counts=True)
    frequencies = np.zeros(dim).astype("float32")
    # Scatter the per-category counts, then turn counts into proportions.
    frequencies[categories] = counts
    frequencies /= frequencies.sum()
    return frequencies


def sample_output_two_dimension(out, shape):
    """Return per-row empirical category frequencies for a batch of samples.

    Args:
        out: array of integer category indices whose first axis has
            `shape[0]` rows, one row of samples per distribution.
        shape: `[num_distributions, num_categories]` of the returned array.

    Returns:
        A float32 array of the given shape in which every row is normalized
        to sum to 1.
    """
    num_dist = shape[0]
    rows = np.split(out, num_dist, axis=0)
    frequencies = np.zeros(shape).astype("float32")
    for row_index, row_samples in enumerate(rows):
        categories, counts = np.unique(row_samples, return_counts=True)
        frequencies[row_index][categories] = counts
    # Normalize all rows at once, after every row has been filled in.
    frequencies /= frequencies.sum(axis=-1, keepdims=True)
    return frequencies


class TestMultinomialOp(OpTest):
    """Checks the `multinomial` op on NPU against the input distribution.

    Default case: a length-4 weight vector sampled 100000 times with
    replacement. Subclasses override `init_data`, `sample_output` and/or
    `verify_output` to cover other configurations.
    """

    def setUp(self):
        self.set_npu()
        self.op_type = "multinomial"
        self.init_data()
        self.inputs = {"X": self.input_np}

    def set_npu(self):
        # Set the class-level `use_npu` flag and target NPU device 0.
        self.__class__.use_npu = True
        self.place = paddle.NPUPlace(0)

    def init_data(self):
        # Input probability is a vector, and replacement is True.
        num_samples = 100000
        self.input_np = np.random.rand(4)
        self.outputs = {"Out": np.zeros(num_samples).astype("int64")}
        self.attrs = {"num_samples": num_samples, "replacement": True}

    def test_check_output(self):
        self.check_output_customized(
            self.verify_output, custom_place=self.place)

    def sample_output(self, out):
        """Convert raw sampled indices into per-category frequencies."""
        return sample_output_one_dimension(out, 4)

    def verify_output(self, outs):
        """Assert sampled frequencies are within 0.01 of the normalized input."""
        # Normalize the input weights to get the expected probabilities.
        expected = self.input_np / self.input_np.sum(axis=-1, keepdims=True)
        observed = self.sample_output(np.array(outs[0]))
        message = "sample_prob: " + str(observed) + "\nprob: " + str(expected)
        self.assertTrue(
            np.allclose(
                observed, expected, rtol=0, atol=0.01), message)


class TestMultinomialOp2(TestMultinomialOp):
    """Batched variant: a 3x4 weight matrix, sampled with replacement."""

    def init_data(self):
        # Input probability is a matrix.
        num_samples = 100000
        self.input_np = np.random.rand(3, 4)
        self.outputs = {"Out": np.zeros((3, num_samples)).astype("int64")}
        self.attrs = {"num_samples": num_samples, "replacement": True}

    def sample_output(self, out):
        """Convert raw sampled indices into per-row category frequencies."""
        return sample_output_two_dimension(out, [3, 4])


class TestMultinomialOp3(TestMultinomialOp):
    """Sampling without replacement: every drawn category must be distinct."""

    def init_data(self):
        # replacement is False. number of samples must be less than number of
        # categories.
        num_samples = 100
        self.input_np = np.random.rand(1000)
        self.outputs = {"Out": np.zeros(num_samples).astype("int64")}
        self.attrs = {"num_samples": num_samples, "replacement": False}

    def verify_output(self, outs):
        """Assert that the 100 sampled indices contain no duplicates."""
        distinct = np.unique(np.array(outs[0]))
        self.assertEqual(
            len(distinct), 100,
            "replacement is False. categories can't be sampled repeatedly")


class TestMultinomialApi(unittest.TestCase):
    """API-level checks of `paddle.multinomial` on device 'npu:0'.

    The dygraph tests switch the process to dygraph mode. The switch back to
    static mode is registered with `addCleanup`, so it also happens when an
    assertion fails and a failing test cannot leave dygraph mode enabled for
    the tests that run after it.
    """

    def _use_dygraph_on_npu(self):
        """Select 'npu:0' and enter dygraph mode until the current test ends."""
        paddle.set_device('npu:0')
        # Register the restore first so static mode is re-enabled no matter
        # how the test exits.
        self.addCleanup(paddle.enable_static)
        paddle.disable_static()

    def test_dygraph(self):
        # input probability is a vector, and replacement is True
        self._use_dygraph_on_npu()
        x_numpy = np.random.rand(4)
        x = paddle.to_tensor(x_numpy)
        out = paddle.multinomial(x, num_samples=100000, replacement=True)

        sample_prob = sample_output_one_dimension(out.numpy(), 4)
        prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))

    def test_dygraph2(self):
        # input probability is a matrix, and replacement is True
        self._use_dygraph_on_npu()
        x_numpy = np.random.rand(3, 4)
        x = paddle.to_tensor(x_numpy)
        out = paddle.multinomial(x, num_samples=100000, replacement=True)

        sample_prob = sample_output_two_dimension(out.numpy(), [3, 4])
        prob = x_numpy / x_numpy.sum(axis=-1, keepdims=True)
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))

    def test_dygraph3(self):
        # replacement is False. number of samples must be less than number of
        # categories.
        self._use_dygraph_on_npu()
        x_numpy = np.random.rand(1000)
        x = paddle.to_tensor(x_numpy)
        out = paddle.multinomial(x, num_samples=100, replacement=False)

        unique_out = np.unique(out.numpy())
        self.assertEqual(
            len(unique_out), 100,
            "replacement is False. categories can't be sampled repeatedly")

    def test_dygraph4(self):
        self._use_dygraph_on_npu()
        logits = -1 * paddle.ones([2800])
        # Categorical.sample API will call multinomial op with replacement=True
        cat = paddle.distribution.Categorical(logits.exp())
        cat.sample([1])

    def test_static(self):
        paddle.set_device('npu:0')
        startup_program = fluid.Program()
        train_program = fluid.Program()
        with fluid.program_guard(train_program, startup_program):
            x = fluid.data('x', shape=[4], dtype='float32')
            out = paddle.multinomial(x, num_samples=100000, replacement=True)

            place = fluid.NPUPlace(0)
            exe = fluid.Executor(place)

        exe.run(startup_program)
        x_np = np.random.rand(4).astype('float32')
        # `exe.run` output is passed straight to the frequency helper, whose
        # np.unique call flattens it.
        out = exe.run(train_program, feed={'x': x_np}, fetch_list=[out])

        sample_prob = sample_output_one_dimension(out, 4)
        prob = x_np / x_np.sum(axis=-1, keepdims=True)
        self.assertTrue(
            np.allclose(
                sample_prob, prob, rtol=0, atol=0.01),
            "sample_prob: " + str(sample_prob) + "\nprob: " + str(prob))


class TestMultinomialAlias(unittest.TestCase):
    """Smoke test that all three import paths of `multinomial` can be called."""

    def test_alias(self):
        paddle.set_device('npu:0')
        x = paddle.rand([4])
        aliases = (paddle.multinomial, paddle.tensor.multinomial,
                   paddle.tensor.random.multinomial)
        # The results are not inspected; the test only exercises each call.
        for multinomial in aliases:
            multinomial(x, num_samples=10, replacement=True)


class TestMultinomialError(unittest.TestCase):
    """Argument combinations that are expected to raise ValueError."""

    def setUp(self):
        paddle.set_device('npu:0')
        paddle.disable_static()

    def tearDown(self):
        paddle.enable_static()

    def test_num_sample(self):
        # A negative number of samples.
        with self.assertRaises(ValueError):
            x = paddle.rand([4])
            paddle.multinomial(x, num_samples=-2)

    def test_input_probs_dim(self):
        # Input with more than two dimensions.
        with self.assertRaises(ValueError):
            x = paddle.rand([2, 3, 3])
            paddle.multinomial(x)

        # Input with fewer than one dimension.
        with self.assertRaises(ValueError):
            x_np = np.random.random([])
            x = paddle.to_tensor(x_np)
            paddle.multinomial(x)

        with self.assertRaises(ValueError):
            prob = paddle.rand([20, 1000])
            # NOTE(review): `1:0` is an empty slice. Possibly `prob[1, :] = 0`
            # (an all-zero row) was intended; confirm which statement is
            # expected to raise before changing it.
            prob[1:0] = 0
            paddle.multinomial(prob)


# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
unittest.main()

0 comments on commit fd14069

Please sign in to comment.