Commit

refactoring + fix

eshoguli committed Oct 20, 2023
1 parent 6f48cb4 commit d21b9bc
Showing 2 changed files with 128 additions and 102 deletions.
226 changes: 124 additions & 102 deletions src/plugins/intel_cpu/src/nodes/eltwise.cpp
@@ -1577,21 +1577,12 @@ inline Result* bitwise_ptr_cast(Current* value) {

/* enabled only for float and float16_t at the moment
* can be extended in the future */
template<typename T,
typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, int8_t>::value ||
std::is_same<T, uint8_t>::value ||
std::is_same<T, int16_t>::value ||
std::is_same<T, uint16_t>::value ||
std::is_same<T, int32_t>::value ||
std::is_same<T, dnnl::impl::float16_t>::value>
::type* = nullptr>
class EltwiseRefExecutor : public Eltwise::IEltwiseExecutor {
template<typename T>
class EltwiseRefBaseExecutor : public Eltwise::IEltwiseExecutor {
public:
EltwiseRefExecutor(Eltwise::EltwiseData opData,
const VectorDims& outBlkDims,
std::vector<VectorDims> inpDims)
EltwiseRefBaseExecutor(Eltwise::EltwiseData opData,
const VectorDims& outBlkDims,
std::vector<VectorDims> inpDims)
: _opData(std::move(opData)), _inpDims(inpDims) {
if (inpDims.empty()) {
IE_THROW() << "Can not make Eltwise executor from empty input dims array";
@@ -1633,6 +1624,73 @@ class EltwiseRefExecutor : public Eltwise::IEltwiseExecutor {
}
}

const VectorDims& getOutDims() const override {
return _dims;
}

size_t getBatchDimIdx() const override {
return _batchDimIdx;
}

protected:
void init_ptr(const jit_eltwise_call_args_ptrs& args_ptrs,
const VectorDims& dims_out,
std::vector<size_t>& counters,
const size_t iwork,
std::vector<T>& src_f,
T*& dst_ptr_f) {
size_t tmp = iwork;
for (ptrdiff_t j = dims_out.size() - 1; j >= 0; j--) {
counters[j] = tmp % dims_out[j];
tmp /= dims_out[j];
}

size_t index_in[MAX_ELTWISE_INPUTS] = { 0 };
for (size_t i = 0; i < _inputNum; i++) {
index_in[i] = 0;
for (size_t j = 0; j < counters.size(); j++) {
index_in[i] += counters[j] * _src_offsets[i][j];
}
index_in[i] /= sizeof(T);
}

size_t index_out = 0;
for (size_t j = 0; j < counters.size(); j++) {
index_out += counters[j] * _dst_offsets[j];
}
index_out /= sizeof(T);

//std::vector<T> src_f(_inputNum);
for (size_t i = 0; i < _inputNum; i++) {
src_f[i] = (reinterpret_cast<const T*>(args_ptrs.src_ptr[i]) + index_in[i])[0];
}
dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr) + index_out;
}

const Eltwise::EltwiseData _opData;
VectorDims _dims;
VectorDims _src_offsets[MAX_ELTWISE_INPUTS];
VectorDims _dst_offsets;
size_t _fullWorkAmount = 0;
size_t _inputNum = 0;
size_t _batchDimIdx = 0;
std::vector<VectorDims> _inpDims;
};
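The decomposition performed by init_ptr can be illustrated with a small standalone sketch (not part of this commit; the shape, strides, and names below are made up for illustration): the flat work-item index is decomposed into per-dimension counters, which are then folded with precomputed byte strides and divided by the element size to obtain an element index.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
    const std::vector<size_t> dims_out = {2, 3, 4};       // output shape
    const std::vector<size_t> src_offsets = {48, 16, 4};  // byte strides of one f32 input

    const size_t iwork = 17;                               // flat work-item index
    std::vector<size_t> counters(dims_out.size(), 0);

    // Decompose the flat index into per-dimension counters (innermost dimension last).
    size_t tmp = iwork;
    for (ptrdiff_t j = dims_out.size() - 1; j >= 0; j--) {
        counters[j] = tmp % dims_out[j];
        tmp /= dims_out[j];
    }

    // Fold the counters with the byte strides, then convert to an element index.
    size_t index_in = 0;
    for (size_t j = 0; j < counters.size(); j++)
        index_in += counters[j] * src_offsets[j];
    index_in /= sizeof(float);

    std::printf("counters = {%zu, %zu, %zu}, element index = %zu\n",
                counters[0], counters[1], counters[2], index_in);
    return 0;
}

For the contiguous strides assumed here the element index equals the flat index (17); a broadcast input can be handled by the same loop simply by giving its broadcast dimensions a zero stride.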

/* enabled only for float and float16_t at the moment
* can be extended in the future */
template<typename T,
typename std::enable_if<
std::is_same<T, float>::value ||
std::is_same<T, dnnl::impl::float16_t>::value>
::type * = nullptr>
class EltwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
public:
EltwiseRefExecutor(Eltwise::EltwiseData opData,
const VectorDims& outBlkDims,
std::vector<VectorDims> inpDims) : EltwiseRefBaseExecutor<T>(opData, outBlkDims, inpDims) {
}

void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override {
if (_opData.algo == Algorithm::EltwiseLog) {
const T* src_ptr_f = reinterpret_cast<const T*>(args_ptrs.src_ptr[0]);
@@ -1694,32 +1752,9 @@ class EltwiseRefExecutor : public Eltwise::IEltwiseExecutor {
std::vector<size_t> counters(dims_out.size(), 0);

for (size_t iwork = start; iwork < end; ++iwork) {
size_t tmp = iwork;
for (ptrdiff_t j = dims_out.size() - 1; j >= 0; j--) {
counters[j] = tmp % dims_out[j];
tmp /= dims_out[j];
}

size_t index_in[MAX_ELTWISE_INPUTS] = {0};
for (size_t i = 0; i < _inputNum; i++) {
index_in[i] = 0;
for (size_t j = 0; j < counters.size(); j++) {
index_in[i] += counters[j] * _src_offsets[i][j];
}
index_in[i] /= sizeof(T);
}

size_t index_out = 0;
for (size_t j = 0; j < counters.size(); j++) {
index_out += counters[j] * _dst_offsets[j];
}
index_out /= sizeof(T);

std::vector<T> src_f(_inputNum);
for (size_t i = 0; i < _inputNum; i++) {
src_f[i] = (reinterpret_cast<const T*>(args_ptrs.src_ptr[i]) + index_in[i])[0];
}
T* dst_ptr_f = reinterpret_cast<T*>(args_ptrs.dst_ptr) + index_out;
T* dst_ptr_f;
init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f);

switch (_opData.algo) {
case Algorithm::EltwiseRelu:
@@ -1780,81 +1815,68 @@ class EltwiseRefExecutor : public Eltwise::IEltwiseExecutor {
break;
}
case Algorithm::EltwiseSelect: *dst_ptr_f = src_f[0] ? src_f[1] : src_f[2]; break;
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
}
}
});
}
};
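The enable_if gate used on EltwiseRefExecutor above (and on BitwiseRefExecutor below) restricts which element types the template can be instantiated with. A minimal, self-contained sketch of the same pattern, with an invented type list:

#include <type_traits>

// Extra defaulted non-type parameter: enable_if<...>::type exists only for the
// allowed types, so any other instantiation is rejected at compile time.
template <typename T,
          typename std::enable_if<
              std::is_same<T, float>::value ||
              std::is_same<T, double>::value>::type* = nullptr>
struct OnlyFloatingPoint {
    T value;
};

int main() {
    OnlyFloatingPoint<float> ok{1.0f};  // instantiates: float is in the allowed set
    // OnlyFloatingPoint<int> bad{1};   // would not compile: enable_if<...> has no ::type
    (void)ok;
    return 0;
}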

template<typename T,
typename std::enable_if<
std::is_same<T, int8_t>::value ||
std::is_same<T, uint8_t>::value ||
std::is_same<T, int16_t>::value ||
std::is_same<T, uint16_t>::value ||
std::is_same<T, int32_t>::value>
::type * = nullptr>
class BitwiseRefExecutor : public EltwiseRefBaseExecutor<T> {
public:
BitwiseRefExecutor(Eltwise::EltwiseData opData,
const VectorDims& outBlkDims,
std::vector<VectorDims> inpDims) : EltwiseRefBaseExecutor<T>(opData, outBlkDims, inpDims) {
}

void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override {
std::shared_ptr<ref_eltwise_scalar_fwd_t> ref_eltwise_injector = nullptr;
if (_opData.onednnAlgorithm != dnnl::algorithm::undef) {
ref_eltwise_injector = std::make_shared<ref_eltwise_scalar_fwd_t>(
static_cast<dnnl_alg_kind_t>(_opData.onednnAlgorithm), _opData.alpha, _opData.beta, 1.f);
}

parallel_nt(0, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(_fullWorkAmount, nthr, ithr, start, end);

std::vector<size_t> counters(dims_out.size(), 0);

for (size_t iwork = start; iwork < end; ++iwork) {
std::vector<T> src_f(_inputNum);
T* dst_ptr_f;
init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f);

switch (_opData.algo) {
case Algorithm::EltwiseBitwiseAnd: {
const auto size = sizeof(T);
if (size == 1) {
*bitwise_ptr_cast<uint8_t>(dst_ptr_f) = bitwise_cast<uint8_t>(&src_f[0]) & bitwise_cast<uint8_t>(&src_f[1]);
} else if (size == 2) {
*bitwise_ptr_cast<uint16_t>(dst_ptr_f) = bitwise_cast<uint16_t>(&src_f[0]) & bitwise_cast<uint16_t>(&src_f[1]);
} else if (size == 4) {
*bitwise_ptr_cast<uint32_t>(dst_ptr_f) = bitwise_cast<uint32_t>(&src_f[0]) & bitwise_cast<uint32_t>(&src_f[1]);
} else {
IE_THROW() << "Unsupported operation type for EltwiseBitwiseAnd";
}
*dst_ptr_f = src_f[0] & src_f[1];
break;
}
case Algorithm::EltwiseBitwiseNot: {
const auto size = sizeof(T);
if (size == 1) {
*bitwise_ptr_cast<uint8_t>(dst_ptr_f) = ~bitwise_cast<uint8_t>(&src_f[0]);
} else if (size == 2) {
*bitwise_ptr_cast<uint16_t>(dst_ptr_f) = ~bitwise_cast<uint16_t>(&src_f[0]);
} else if (size == 4) {
*bitwise_ptr_cast<uint32_t>(dst_ptr_f) = ~bitwise_cast<uint32_t>(&src_f[0]);
} else {
IE_THROW() << "Unsupported operation type for EltwiseBitwiseNot";
}
*dst_ptr_f = ~src_f[0];
break;
}
case Algorithm::EltwiseBitwiseOr: {
const auto size = sizeof(T);
if (size == 1) {
*bitwise_ptr_cast<uint8_t>(dst_ptr_f) = bitwise_cast<uint8_t>(&src_f[0]) | bitwise_cast<uint8_t>(&src_f[1]);
} else if (size == 2) {
*bitwise_ptr_cast<uint16_t>(dst_ptr_f) = bitwise_cast<uint16_t>(&src_f[0]) | bitwise_cast<uint16_t>(&src_f[1]);
} else if (size == 4) {
*bitwise_ptr_cast<uint32_t>(dst_ptr_f) = bitwise_cast<uint32_t>(&src_f[0]) | bitwise_cast<uint32_t>(&src_f[1]);
} else {
IE_THROW() << "Unsupported operation type for EltwiseBitwiseOr";
}
*dst_ptr_f = src_f[0] | src_f[1];
break;
}
case Algorithm::EltwiseBitwiseXor: {
const auto size = sizeof(T);
if (size == 1) {
*bitwise_ptr_cast<uint8_t>(dst_ptr_f) = bitwise_cast<uint8_t>(&src_f[0]) ^ bitwise_cast<uint8_t>(&src_f[1]);
} else if (size == 2) {
*bitwise_ptr_cast<uint16_t>(dst_ptr_f) = bitwise_cast<uint16_t>(&src_f[0]) ^ bitwise_cast<uint16_t>(&src_f[1]);
} else if (size == 4) {
*bitwise_ptr_cast<uint32_t>(dst_ptr_f) = bitwise_cast<uint32_t>(&src_f[0]) ^ bitwise_cast<uint32_t>(&src_f[1]);
} else {
IE_THROW() << "Unsupported operation type for EltwiseBitwiseXor";
}
*dst_ptr_f = src_f[0] ^ src_f[1];
break;
}
default: IE_THROW() << "Unsupported operation type for Eltwise executor";
}
}
});
}

const VectorDims& getOutDims() const override {
return _dims;
}

size_t getBatchDimIdx() const override {
return _batchDimIdx;
}

private:
const Eltwise::EltwiseData _opData;
VectorDims _dims;
VectorDims _src_offsets[MAX_ELTWISE_INPUTS];
VectorDims _dst_offsets;
size_t _fullWorkAmount = 0;
size_t _inputNum = 0;
size_t _batchDimIdx = 0;
std::vector<VectorDims> _inpDims;
};
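Splitting the bitwise algorithms into an integer-only executor is what allows the simple direct assignments such as *dst_ptr_f = src_f[0] & src_f[1]: the bitwise operators are defined for the integral element types themselves, whereas the previous combined executor had to reinterpret the element bytes by size before operating on them. A standalone sketch of the two styles (bit_view is an invented stand-in for the bitwise_cast helper):

#include <cstdint>
#include <cstring>

// Invented stand-in for a bitwise_cast-style helper: copy the object
// representation into a same-sized unsigned integer.
template <typename To, typename From>
To bit_view(From v) {
    static_assert(sizeof(To) == sizeof(From), "size mismatch");
    To out;
    std::memcpy(&out, &v, sizeof(To));
    return out;
}

int main() {
    // Integer element type: the operator applies directly, as in BitwiseRefExecutor.
    int16_t a = 0x0F0F, b = 0x00FF;
    int16_t direct = static_cast<int16_t>(a & b);         // 0x000F

    // Floating-point element type: reinterpret to a same-sized unsigned type first.
    float x = 1.5f, y = 2.5f;
    uint32_t masked = bit_view<uint32_t>(x) & bit_view<uint32_t>(y);

    (void)direct; (void)masked;
    return 0;
}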

} // namespace
@@ -1874,31 +1896,31 @@ static Eltwise::executorPtr buildRefExecutor(const EltwiseKey& key) {
key.outBlkDims,
key.inpDims);
case Precision::I8:
return std::make_shared<EltwiseRefExecutor<PrecisionTrait<Precision::I8>::value_type>>(
return std::make_shared<BitwiseRefExecutor<PrecisionTrait<Precision::I8>::value_type>>(
key.eltwise_data.front(),
key.outBlkDims,
key.inpDims);

case Precision::U8:
return std::make_shared<EltwiseRefExecutor<PrecisionTrait<Precision::U8>::value_type>>(
return std::make_shared<BitwiseRefExecutor<PrecisionTrait<Precision::U8>::value_type>>(
key.eltwise_data.front(),
key.outBlkDims,
key.inpDims);

case Precision::I16:
return std::make_shared<EltwiseRefExecutor<PrecisionTrait<Precision::I16>::value_type>>(
return std::make_shared<BitwiseRefExecutor<PrecisionTrait<Precision::I16>::value_type>>(
key.eltwise_data.front(),
key.outBlkDims,
key.inpDims);

case Precision::U16:
return std::make_shared<EltwiseRefExecutor<PrecisionTrait<Precision::U16>::value_type>>(
return std::make_shared<BitwiseRefExecutor<PrecisionTrait<Precision::U16>::value_type>>(
key.eltwise_data.front(),
key.outBlkDims,
key.inpDims);
case Precision::I32:
return std::make_shared<EltwiseRefExecutor<PrecisionTrait<Precision::I32>::value_type>>(
return std::make_shared<BitwiseRefExecutor<PrecisionTrait<Precision::I32>::value_type>>(
key.eltwise_data.front(),
key.outBlkDims,
key.inpDims);
@@ -193,6 +193,10 @@ void EltwiseLayerCPUTest::SetUp() {
auto data_ptr = reinterpret_cast<uint32_t*>(data_tensor.data());
std::vector<uint32_t> data(data_ptr, data_ptr + ngraph::shape_size(shape));
secondaryInput = ngraph::builder::makeConstant(netType, shape, data);
} else if (netType == ElementType::f16) {
auto data_ptr = reinterpret_cast<ov::float16*>(data_tensor.data());
std::vector<ov::float16> data(data_ptr, data_ptr + ngraph::shape_size(shape));
secondaryInput = ngraph::builder::makeConstant(netType, shape, data);
} else {
auto data_ptr = reinterpret_cast<float*>(data_tensor.data());
std::vector<float> data(data_ptr, data_ptr + ngraph::shape_size(shape));