Move bernoulli further into ATen (pytorch#7578)
cpuhrsch authored May 16, 2018
1 parent 330a725 commit 221e615
Showing 4 changed files with 72 additions and 14 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/Declarations.cwrap
@@ -3610,7 +3610,7 @@
     - double p
 ]]
 [[
-  name: bernoulli
+  name: _th_bernoulli
   types:
     - Float
     - Double
53 changes: 49 additions & 4 deletions aten/src/ATen/native/Distributions.cpp
@@ -108,14 +108,59 @@ int64_t sample_poisson(double lambda, THGenerator* generator) {
 
 namespace at {
 namespace native {
-Tensor& bernoulli_(Tensor& self, const Tensor& p, Generator* generator) {
-  self.copy_(at::bernoulli(std::get<0>(expand_inplace(self, p)), generator));
+
+Tensor bernoulli(const Tensor& self, const Tensor& p, Generator* gen) {
+  Tensor result = self.type().tensor();
+  result.resize_(self.sizes());
+  return native::bernoulli_(result, p, gen);
+}
+
+Tensor bernoulli(const Tensor& self, double p, Generator* gen) {
+  Tensor result = self.type().tensor();
+  result.resize_(self.sizes());
+  return native::bernoulli_(result, p, gen);
+}
+
+Tensor bernoulli(const Tensor& self, Generator* gen) {
+  Tensor result = self.type().tensor();
+  result.resize_(self.sizes());
+  return native::bernoulli(result, self, gen);
+}
+
+Tensor& bernoulli_(Tensor& self, const Tensor& p_, Generator* gen) {
+  if (!self.is_cuda() && !p_.is_cuda()) {
+    Tensor p = p_.toType(kDouble);
+    AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_", [&] {
+      THGenerator* generator = get_generator(gen);
+      std::lock_guard<std::mutex> lock(generator->mutex);
+      CPU_tensor_apply2<scalar_t, double>(
+          self, p, [generator](scalar_t& ret_val, double& p_val) {
+            ret_val = (scalar_t)THRandom_bernoulli(generator, p_val);
+          });
+    });
+    return self;
+  }
+  self.copy_(at::_th_bernoulli(std::get<0>(expand_inplace(self, p_)), gen));
   return self;
 }
 
-Tensor& bernoulli_(Tensor& self, double p, Generator* generator) {
+Tensor& bernoulli_(Tensor& self, double p, Generator* gen) {
+  if (!self.is_cuda()) {
+    AT_DISPATCH_ALL_TYPES(self.type(), "bernoulli_", [&] {
+      THGenerator* generator = get_generator(gen);
+      std::lock_guard<std::mutex> lock(generator->mutex);
+      CPU_tensor_apply1<scalar_t>(self, [generator, p](scalar_t& ret_val) {
+        ret_val = (scalar_t)THRandom_bernoulli(generator, p);
+      });
+    });
+    return self;
+  }
   Tensor probs = self.type().toScalarType(kDouble).tensor({}).fill_(p);
-  return native::bernoulli_(self, probs, generator);
+  return native::bernoulli_(self, probs, gen);
 }
 
+Tensor& bernoulli_(Tensor& self, Generator* gen) {
+  return native::bernoulli_(self, 0.5, gen);
+}
+
 Tensor _standard_gamma_grad_cpu(const Tensor& self, const Tensor& output) {
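The CPU branches added above follow a simple pattern: take the generator's mutex once, then write one independent Bernoulli draw per element. The standalone sketch below reproduces that pattern with the C++ standard library so it compiles on its own; std::mt19937 and std::bernoulli_distribution stand in for THGenerator and THRandom_bernoulli, and none of the names below are taken from ATen.

#include <cstddef>
#include <mutex>
#include <random>
#include <vector>

// Shared generator guarded by a mutex, mirroring get_generator(gen) plus the
// std::lock_guard in the diff. These globals are illustrative only.
static std::mt19937 shared_generator;
static std::mutex generator_mutex;

// Analogue of the CPU_tensor_apply2 path: one draw per element, with a
// per-element success probability.
void bernoulli_fill(std::vector<float>& out, const std::vector<double>& p) {
  std::lock_guard<std::mutex> lock(generator_mutex);
  for (std::size_t i = 0; i < out.size(); ++i) {
    std::bernoulli_distribution draw(p[i]);
    out[i] = draw(shared_generator) ? 1.0f : 0.0f;
  }
}

int main() {
  std::vector<double> p = {0.1, 0.5, 0.9};
  std::vector<float> out(p.size());
  bernoulli_fill(out, p);  // out now holds independent 0/1 samples
  return 0;
}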
8 changes: 8 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -152,10 +152,18 @@
 - func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
   variants: function
 
+- func: bernoulli(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor
+
+- func: bernoulli(Tensor self, double p=0.5, Generator* generator=nullptr) -> Tensor
+
+- func: bernoulli(Tensor self, Generator* generator=nullptr) -> Tensor
+
 - func: bernoulli_(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor
 
 - func: bernoulli_(Tensor self, double p=0.5, Generator* generator=nullptr) -> Tensor
 
+- func: bernoulli_(Tensor self, Generator* generator=nullptr) -> Tensor
+
 - func: bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias) -> Tensor
   variants: function
 
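Each entry above is picked up by the ATen code generator; with no variants line they are expected to surface both as at:: functions and as Tensor methods whose C++ signatures mirror the native:: definitions in Distributions.cpp, with the yaml defaults (p=0.5, generator=nullptr) carried over. A hedged usage sketch follows, assuming the 2018-era at::CPU(...) factory style; none of these lines appear in the diff.

#include <ATen/ATen.h>

int main() {
  // Per-element probabilities; ones(...).mul_(0.3) is just a convenient way
  // to build them and is an assumption about the factory API of the time.
  at::Tensor p = at::CPU(at::kDouble).ones({2, 3}).mul_(0.3);

  at::Tensor a = p.bernoulli();   // out-of-place: one 0/1 draw per element of p
  at::Tensor b = at::CPU(at::kFloat).zeros({2, 3});
  b.bernoulli_(p);                // in-place, Tensor probabilities
  b.bernoulli_(0.25);             // in-place, scalar probability
  b.bernoulli_();                 // in-place, defaults to p = 0.5
  return 0;
}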
23 changes: 14 additions & 9 deletions aten/src/ATen/templates/TensorMethods.h
@@ -32,15 +32,20 @@ inline Tensor Tensor::toBackend(Backend b) const {
 // all static inline to allow for inlining of the non-dynamic part of dispatch
 ${tensor_method_definitions}
 
-#define DEFINE_CAST(T,name,_) \
-template<> \
-inline T* Tensor::data() const { \
-  AT_CHECK(type().scalarType() == ScalarType::name, \
-           "expected scalar type % s but found %s", #name, \
-           at::toString(type().scalarType())); \
-  return static_cast<T*>(this->data_ptr()); \
-} \
-inline T* Tensor::to##name##Data() const { return data<T>(); }
+#define DEFINE_CAST(T, name, _)                     \
+  template <>                                       \
+  inline T* Tensor::data() const {                  \
+    AT_CHECK(                                       \
+        type().scalarType() == ScalarType::name,    \
+        "expected scalar type ",                    \
+        #name,                                      \
+        " but found ",                              \
+        at::toString(type().scalarType()));         \
+    return static_cast<T*>(this->data_ptr());       \
+  }                                                 \
+  inline T* Tensor::to##name##Data() const {        \
+    return data<T>();                               \
+  }
 
 AT_FORALL_SCALAR_TYPES(DEFINE_CAST)
 #undef DEFINE_CAST
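For one entry of AT_FORALL_SCALAR_TYPES, say (float, Float, d), the reformatted macro expands to roughly the following; this is a hand-expanded sketch with whitespace added, not codegen output. Note that the AT_CHECK message is now assembled from separate arguments rather than a printf-style format string.

// Hand-expanded sketch of DEFINE_CAST(float, Float, d); the real expansion is
// produced by the preprocessor and has no line breaks.
template <>
inline float* Tensor::data() const {
  AT_CHECK(
      type().scalarType() == ScalarType::Float,
      "expected scalar type ",
      "Float",                     // from #name
      " but found ",
      at::toString(type().scalarType()));
  return static_cast<float*>(this->data_ptr());
}
inline float* Tensor::toFloatData() const {  // from to##name##Data
  return data<float>();
}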
