Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
eshoguli committed Oct 21, 2023
1 parent d21b9bc commit ce0d9d8
Showing 1 changed file with 31 additions and 51 deletions.
82 changes: 31 additions & 51 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1555,26 +1555,6 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor {
static const int optimalTensorRank = 6;
};

namespace {

// Reads the object that `value` points to as a `Result`, reinterpreting its
// storage instead of converting the value.
// Preconditions are checked with `assert` only, so they are enforced solely in
// builds where NDEBUG is not defined:
//   - `Current` and `Result` have the same size;
//   - neither `Current` nor `Result` is `float` or `dnnl::impl::float16_t`.
template<class Result, class Current>
inline Result bitwise_cast(Current* value) {
    assert(sizeof(Result) == sizeof(Current));
    assert(!((typeid(Current) == typeid(float)) || (typeid(Current) == typeid(dnnl::impl::float16_t))));
    assert(!((typeid(Result) == typeid(float)) || (typeid(Result) == typeid(dnnl::impl::float16_t))));
    Result* const reinterpreted = reinterpret_cast<Result*>(value);
    return *reinterpreted;
}

// Returns `value` reinterpreted as a pointer to `Result`; the pointee is not
// read or copied here.
// Preconditions are checked with `assert` only, so they are enforced solely in
// builds where NDEBUG is not defined:
//   - `Current` and `Result` have the same size;
//   - neither `Current` nor `Result` is `float` or `dnnl::impl::float16_t`.
template<class Result, class Current>
inline Result* bitwise_ptr_cast(Current* value) {
    assert(sizeof(Result) == sizeof(Current));
    assert(!((typeid(Current) == typeid(float)) || (typeid(Current) == typeid(dnnl::impl::float16_t))));
    assert(!((typeid(Result) == typeid(float)) || (typeid(Result) == typeid(dnnl::impl::float16_t))));
    Result* const reinterpreted = reinterpret_cast<Result*>(value);
    return reinterpreted;
}

} // namespace

/* enabled only for float and float16_t at the moment
* can be extended in the future */
template<typename T>
Expand Down Expand Up @@ -1686,52 +1666,52 @@ template<typename T,
::type * = nullptr>
class EltwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
public:
EltwiseRefExecutor(Eltwise::EltwiseData opData,
EltwiseRefExecutor<T>(Eltwise::EltwiseData opData,
const VectorDims& outBlkDims,
std::vector<VectorDims> inpDims) : EltwiseRefBaseExecutor(opData, outBlkDims, inpDims) {
std::vector<VectorDims> inpDims) : EltwiseRefBaseExecutor<T>(opData, outBlkDims, inpDims) {
}

void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override {
if (_opData.algo == Algorithm::EltwiseLog) {
if (this->_opData.algo == Algorithm::EltwiseLog) {
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr);
parallel_for(_fullWorkAmount, [&](size_t i) {
parallel_for(this->_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = logf(src_ptr_f[i]);
});
return;
}
if (_opData.algo == Algorithm::EltwisePowerStatic) {
if (this->_opData.algo == Algorithm::EltwisePowerStatic) {
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr);
if (_opData.alpha == 2) {
parallel_for(_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = (_opData.beta * src_ptr_f[i] + _opData.gamma) *
(_opData.beta * src_ptr_f[i] + _opData.gamma);
if (this->_opData.alpha == 2) {
parallel_for(this->_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = (this->_opData.beta * src_ptr_f[i] + this->_opData.gamma) *
(this->_opData.beta * src_ptr_f[i] + this->_opData.gamma);
});
} else {
parallel_for(_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = powf(_opData.beta * src_ptr_f[i] + _opData.gamma, _opData.alpha);
parallel_for(this->_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = powf(this->_opData.beta * src_ptr_f[i] + this->_opData.gamma, this->_opData.alpha);
});
}
return;
}
if (_opData.algo == Algorithm::EltwisePowerDynamic) {
if (this->_opData.algo == Algorithm::EltwisePowerDynamic) {
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
const T* src_ptr_f_pow = reinterpret_cast<const T*>(args_ptrs.src_ptr[1]);
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr);

uint32_t count_of_power_values = 1;
for (unsigned long i : _inpDims[1]) {
for (unsigned long i : this->_inpDims[1]) {
count_of_power_values *= i;
}

if (count_of_power_values == 1) {
if (src_ptr_f_pow[0] != 2) {
parallel_for(_fullWorkAmount, [&](size_t i) {
parallel_for(this->_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = powf(src_ptr_f[i], src_ptr_f_pow[0]);
});
} else {
parallel_for(_fullWorkAmount, [&](size_t i) {
parallel_for(this->_fullWorkAmount, [&](size_t i) {
dst_ptr_f[i] = src_ptr_f[i] * src_ptr_f[i];
});
}
Expand All @@ -1740,23 +1720,23 @@ class EltwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
}

std::shared_ptr<ref_eltwise_scalar_fwd_t> ref_eltwise_injector = nullptr;
if (_opData.onednnAlgorithm != dnnl::algorithm::undef) {
if (this->_opData.onednnAlgorithm != dnnl::algorithm::undef) {
ref_eltwise_injector = std::make_shared<ref_eltwise_scalar_fwd_t>(
static_cast<dnnl_alg_kind_t>(_opData.onednnAlgorithm), _opData.alpha, _opData.beta, 1.f);
static_cast<dnnl_alg_kind_t>(this->_opData.onednnAlgorithm), this->_opData.alpha, this->_opData.beta, 1.f);
}

parallel_nt(0, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(_fullWorkAmount, nthr, ithr, start, end);
splitter(this->_fullWorkAmount, nthr, ithr, start, end);

std::vector<size_t> counters(dims_out.size(), 0);

for (size_t iwork = start; iwork < end; ++iwork) {
std::vector<T> src_f(_inputNum);
std::vector<T> src_f(this->_inputNum);
T* dst_ptr_f;
init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f);
this->init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f);

switch (_opData.algo) {
switch (this->_opData.algo) {
case Algorithm::EltwiseRelu:
case Algorithm::EltwiseGeluErf:
case Algorithm::EltwiseGeluTanh:
Expand Down Expand Up @@ -1803,8 +1783,8 @@ class EltwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
// @todo implement proper isinfinite for non-float precisions
case Algorithm::EltwiseIsFinite: *dst_ptr_f = std::isfinite(static_cast<float>(src_f[0])); break;
case Algorithm::EltwiseIsInf:
*dst_ptr_f = (_opData.alpha && (src_f[0] == -std::numeric_limits<T>::infinity())) ||
(_opData.beta && (src_f[0] == std::numeric_limits<T>::infinity()));
*dst_ptr_f = (this->_opData.alpha && (src_f[0] == -std::numeric_limits<T>::infinity())) ||
(this->_opData.beta && (src_f[0] == std::numeric_limits<T>::infinity()));
break;
case Algorithm::EltwiseIsNaN: {
if (sizeof(T) == 4) {
Expand Down Expand Up @@ -1832,30 +1812,30 @@ template<typename T,
::type * = nullptr>
class BitwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
public:
BitwiseRefExecutor(Eltwise::EltwiseData opData,
BitwiseRefExecutor<T>(Eltwise::EltwiseData opData,
const VectorDims& outBlkDims,
std::vector<VectorDims> inpDims) : EltwiseRefBaseExecutor(opData, outBlkDims, inpDims) {
std::vector<VectorDims> inpDims) : EltwiseRefBaseExecutor<T>(opData, outBlkDims, inpDims) {
}

void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override {
std::shared_ptr<ref_eltwise_scalar_fwd_t> ref_eltwise_injector = nullptr;
if (_opData.onednnAlgorithm != dnnl::algorithm::undef) {
if (this->_opData.onednnAlgorithm != dnnl::algorithm::undef) {
ref_eltwise_injector = std::make_shared<ref_eltwise_scalar_fwd_t>(
static_cast<dnnl_alg_kind_t>(_opData.onednnAlgorithm), _opData.alpha, _opData.beta, 1.f);
static_cast<dnnl_alg_kind_t>(this->_opData.onednnAlgorithm), this->_opData.alpha, this->_opData.beta, 1.f);
}

parallel_nt(0, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(_fullWorkAmount, nthr, ithr, start, end);
splitter(this->_fullWorkAmount, nthr, ithr, start, end);

std::vector<size_t> counters(dims_out.size(), 0);

for (size_t iwork = start; iwork < end; ++iwork) {
std::vector<T> src_f(_inputNum);
std::vector<T> src_f(this->_inputNum);
T* dst_ptr_f;
init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f);
this->init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f);

switch (_opData.algo) {
switch (this->_opData.algo) {
case Algorithm::EltwiseBitwiseAnd: {
*dst_ptr_f = src_f[0] & src_f[1];
break;
Expand Down

0 comments on commit ce0d9d8

Please sign in to comment.