diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt
index 04909c7d8f5a5a..eb56a3fb39503e 100644
--- a/src/plugins/intel_cpu/CMakeLists.txt
+++ b/src/plugins/intel_cpu/CMakeLists.txt
@@ -242,7 +242,8 @@ ov_add_plugin(NAME ${TARGET_NAME}
               DEVICE_NAME "CPU"
               AS_EXTENSION
               VERSION_DEFINES_FOR src/plugin.cpp
-              SOURCES ${SOURCES} ${HEADERS})
+              SOURCES ${SOURCES} ${HEADERS}
+              ADD_CLANG_FORMAT)
 
 # give a different file name depending on target platform architecture
 if(ARM OR AARCH64)
diff --git a/src/plugins/intel_cpu/src/cache/cache_entry.h b/src/plugins/intel_cpu/src/cache/cache_entry.h
index 135a1090a60045..6e71e207b0a71c 100644
--- a/src/plugins/intel_cpu/src/cache/cache_entry.h
+++ b/src/plugins/intel_cpu/src/cache/cache_entry.h
@@ -4,8 +4,9 @@
 
 #pragma once
 
-#include <memory>
 #include <functional>
+#include <memory>
+
 #include "lru_cache.h"
 
 namespace ov {
@@ -13,27 +14,24 @@ namespace intel_cpu {
 
 class CacheEntryBase {
 public:
-    enum class LookUpStatus : int8_t {
-        Hit,
-        Miss
-    };
+    enum class LookUpStatus : int8_t { Hit, Miss };
+
 public:
    virtual ~CacheEntryBase() = default;
 };
 
 /**
  * @brief Class represents a templated record in multi cache
- * @tparam KeyType is a key type that must define hash() const method with return type convertible to size_t and define comparison operator.
+ * @tparam KeyType is a key type that must define hash() const method with return type convertible to size_t and define
+ * comparison operator.
  * @tparam ValType is a type that must meet all the requirements to the std::unordered_map mapped type
- * @tparam ImplType is a type for the internal storage. It must provide put(KeyType, ValueType) and ValueType get(const KeyType&)
- *         interface and must have constructor of type ImplType(size_t).
+ * @tparam ImplType is a type for the internal storage. It must provide put(KeyType, ValueType) and ValueType get(const
+ * KeyType&) interface and must have constructor of type ImplType(size_t).
  *
  * @note In this implementation default constructed value objects are treated as empty objects.
  */
-template<typename KeyType, typename ValType, typename ImplType = LruCache<KeyType, ValType>>
+template <typename KeyType, typename ValType, typename ImplType = LruCache<KeyType, ValType>>
 class CacheEntry : public CacheEntryBase {
 public:
     using ResultType = std::pair<ValType, LookUpStatus>;
@@ -42,11 +40,12 @@ class CacheEntry : public CacheEntryBase {
     explicit CacheEntry(size_t capacity) : _impl(capacity) {}
 
     /**
-     * @brief Searches the key in the underlying storage and returns value if it exists, or creates a value using the builder functor and adds it to
-     *        the underlying storage.
+     * @brief Searches the key in the underlying storage and returns value if it exists, or creates a value using the
+     *        builder functor and adds it to the underlying storage.
      * @param key is the search key
      * @param builder is a callable object that creates the ValType object from the KeyType lval reference
-     * @return result of the operation which is a pair of the requested object of ValType and the status of whether the cache hit or miss occurred
+     * @return result of the operation which is a pair of the requested object of ValType and the status of whether the
+     *         cache hit or miss occurred
      */
 
     ResultType getOrCreate(const KeyType& key, std::function<ValType(const KeyType&)> builder) {
@@ -70,5 +69,5 @@ class CacheEntry : public CacheEntryBase {
     ImplType _impl;
 };
 
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
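The hunks above only rewrap the CacheEntry documentation, so the contract itself is unchanged. For orientation, this is roughly how the class is meant to be used (a minimal sketch; SumKey is a hypothetical key type written here to satisfy the hash()/operator== requirements from the @tparam comment, not a type taken from this diff):

    #include <cstddef>
    #include <functional>
    #include <string>

    #include "cache_entry.h"

    // Hypothetical key type: provides hash() and operator== as KeyType requires.
    struct SumKey {
        int a;
        int b;
        size_t hash() const {
            return std::hash<int>()(a) ^ (std::hash<int>()(b) << 1);
        }
        bool operator==(const SumKey& rhs) const {
            return a == rhs.a && b == rhs.b;
        }
    };

    void cache_entry_example() {
        ov::intel_cpu::CacheEntry<SumKey, std::string> entry(16);  // room for 16 records
        // The first call misses and runs the builder; a repeated call hits the cache.
        auto result = entry.getOrCreate(SumKey{2, 3}, [](const SumKey& k) {
            return std::to_string(k.a + k.b);
        });
        // result.first == "5"; result.second is LookUpStatus::Miss on the first call.
    }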
diff --git a/src/plugins/intel_cpu/src/cache/lru_cache.h b/src/plugins/intel_cpu/src/cache/lru_cache.h
index 792451da16c484..c3a4d47aa9de9f 100644
--- a/src/plugins/intel_cpu/src/cache/lru_cache.h
+++ b/src/plugins/intel_cpu/src/cache/lru_cache.h
@@ -10,7 +10,8 @@
 
 /**
  * @brief This is yet another implementation of a preemptive cache with LRU eviction policy.
- * @tparam Key is a key type that must define hash() const method with return type convertible to size_t and define comparison operator.
+ * @tparam Key is a key type that must define hash() const method with return type convertible to size_t and define
+ * comparison operator.
  * @tparam Value is a type that must meet all the requirements to the std::unordered_map mapped type
  *
  * @attention This cache implementation IS NOT THREAD SAFE!
@@ -19,7 +20,7 @@
 namespace ov {
 namespace intel_cpu {
 
-template<typename Key, typename Value>
+template <typename Key, typename Value>
 class LruCache {
 public:
     using value_type = std::pair<Key, Value>;
@@ -33,7 +34,7 @@ class LruCache {
      * @param value
      */
 
-    void put(const Key &key, const Value &val) {
+    void put(const Key& key, const Value& val) {
         if (0 == _capacity) {
             return;
         }
@@ -56,7 +57,7 @@ class LruCache {
      * @return Value associated with the key or default constructed instance of the Value type.
      */
 
-    Value get(const Key &key) {
+    Value get(const Key& key) {
         auto itr = _cacheMapper.find(key);
         if (itr == _cacheMapper.end()) {
             return Value();
@@ -82,13 +83,13 @@ class LruCache {
      * @brief Returns the current capacity value
      * @return the current capacity value
      */
-   size_t getCapacity() const noexcept {
-       return _capacity;
-   }
+    size_t getCapacity() const noexcept {
+        return _capacity;
+    }
 
 private:
     struct key_hasher {
-        std::size_t operator()(const Key &k) const {
+        std::size_t operator()(const Key& k) const {
             return k.hash();
         }
     };
@@ -105,5 +106,5 @@ class LruCache {
     size_t _capacity;
 };
 
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
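Nothing semantic changes in lru_cache.h either. Whereas CacheEntry owns an LruCache by default, the cache can also be driven directly; a minimal sketch reusing the hypothetical SumKey type from the previous example (the exact eviction order shown in the comments follows from the documented LRU policy, not from lines visible in this diff):

    #include <string>

    #include "lru_cache.h"

    void lru_cache_example() {
        ov::intel_cpu::LruCache<SumKey, std::string> cache(2);  // capacity of two records
        cache.put(SumKey{1, 1}, "2");
        cache.put(SumKey{2, 2}, "4");
        cache.get(SumKey{1, 1});       // touches {1, 1}, so {2, 2} is now least recently used
        cache.put(SumKey{3, 3}, "6");  // exceeds capacity -> the LRU record {2, 2} is evicted
        auto hit = cache.get(SumKey{1, 1});   // "2"
        auto miss = cache.get(SumKey{2, 2});  // default-constructed value, i.e. ""
        (void)hit;
        (void)miss;
    }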
diff --git a/src/plugins/intel_cpu/src/cache/multi_cache.cpp b/src/plugins/intel_cpu/src/cache/multi_cache.cpp
index 29dad18a41c770..325dfb517831b5 100644
--- a/src/plugins/intel_cpu/src/cache/multi_cache.cpp
+++ b/src/plugins/intel_cpu/src/cache/multi_cache.cpp
@@ -9,5 +9,5 @@ namespace intel_cpu {
 
 std::atomic_size_t MultiCache::_typeIdCounter{0};
 
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/cache/multi_cache.h b/src/plugins/intel_cpu/src/cache/multi_cache.h
index d9b6e5f8bfe19a..e216efe6fea801 100644
--- a/src/plugins/intel_cpu/src/cache/multi_cache.h
+++ b/src/plugins/intel_cpu/src/cache/multi_cache.h
@@ -4,9 +4,10 @@
 
 #pragma once
 
+#include <atomic>
 #include <functional>
 #include <unordered_map>
-#include <atomic>
+
 #include "cache_entry.h"
 
 namespace ov {
@@ -20,27 +21,28 @@ namespace intel_cpu {
 
 class MultiCache {
 public:
-    template<typename KeyType, typename ValueType>
+    template <typename KeyType, typename ValueType>
     using EntryTypeT = CacheEntry<KeyType, ValueType>;
     using EntryBasePtr = std::shared_ptr<CacheEntryBase>;
-    template<typename KeyType, typename ValueType>
+    template <typename KeyType, typename ValueType>
     using EntryPtr = std::shared_ptr<EntryTypeT<KeyType, ValueType>>;
 
 public:
     /**
-    * @param capacity here means maximum records limit FOR EACH entry specified by a pair of Key/Value types.
-    * @note zero capacity means empty cache so no records are stored and no entries are created
-    */
+     * @param capacity here means maximum records limit FOR EACH entry specified by a pair of Key/Value types.
+     * @note zero capacity means empty cache so no records are stored and no entries are created
+     */
     explicit MultiCache(size_t capacity) : _capacity(capacity) {}
     /**
-    * @brief Searches a value of ValueType in the cache using the provided key or creates a new ValueType instance (if nothing was found)
-    *        using the key and the builder functor and adds the new record to the cache
-    * @param key is the search key
-    * @param builder is a callable object that creates the ValType object from the KeyType lval reference.
-    *        Also the builder type is used for the ValueType deduction
-    * @return result of the operation which is a pair of the requested object of ValType and the status of whether the cache hit or miss occurred
-    */
+     * @brief Searches a value of ValueType in the cache using the provided key or creates a new ValueType instance (if
+     * nothing was found) using the key and the builder functor and adds the new record to the cache
+     * @param key is the search key
+     * @param builder is a callable object that creates the ValType object from the KeyType lval reference.
+     *        Also the builder type is used for the ValueType deduction
+     * @return result of the operation which is a pair of the requested object of ValType and the status of whether the
+     * cache hit or miss occurred
+     */
     template <typename KeyType,
               typename BuilderType,
 #if (defined(_MSVC_LANG) && (_MSVC_LANG > 201703L)) || (defined(__cplusplus) && (__cplusplus > 201703L))
@@ -54,9 +56,9 @@ class MultiCache {
     }
 
 private:
-    template<typename T>
+    template <typename T>
     size_t getTypeId();
-    template<typename KeyType, typename ValueType>
+    template <typename KeyType, typename ValueType>
     EntryPtr<KeyType, ValueType> getEntry();
 
 private:
@@ -65,13 +67,13 @@ class MultiCache {
-template<typename T>
+template <typename T>
 size_t MultiCache::getTypeId() {
     static size_t id = _typeIdCounter.fetch_add(1);
     return id;
 }
 
-template<typename KeyType, typename ValueType>
+template <typename KeyType, typename ValueType>
 MultiCache::EntryPtr<KeyType, ValueType> MultiCache::getEntry() {
     using EntryType = EntryTypeT<KeyType, ValueType>;
     size_t id = getTypeId<EntryType>();
@@ -88,5 +90,5 @@ using MultiCacheWeakCPtr = std::weak_ptr<const MultiCache>;
 using MultiCachePtr = std::shared_ptr<MultiCache>;
 using MultiCacheCPtr = std::shared_ptr<const MultiCache>;
 
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
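MultiCache ties the two headers above together: it keeps one type-erased entry per Key/Value type pair (via getTypeId/_storage), each entry being an LruCache-backed CacheEntry with the capacity passed to the constructor. A sketch of the deduction behavior documented in the getOrCreate comment, reusing the same hypothetical SumKey:

    #include <string>

    #include "cache/multi_cache.h"

    void multi_cache_example() {
        ov::intel_cpu::MultiCache cache(16);  // capacity applies per Key/Value entry type
        // ValueType is deduced from the builder's return type (std::string here), so this
        // call creates (or reuses) the CacheEntry<SumKey, std::string> entry internally.
        auto result = cache.getOrCreate(SumKey{4, 5}, [](const SumKey& k) {
            return std::to_string(k.a + k.b);
        });
        // result.first == "9"; result.second reports LookUpStatus::Hit or Miss.
    }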
diff --git a/src/plugins/intel_cpu/src/compiled_model.cpp b/src/plugins/intel_cpu/src/compiled_model.cpp
index 275fd0dbfff755..b237d507526d49 100644
--- a/src/plugins/intel_cpu/src/compiled_model.cpp
+++ b/src/plugins/intel_cpu/src/compiled_model.cpp
@@ -3,29 +3,30 @@
 //
 
 #include "compiled_model.h"
+
+#include <cstring>
+#include <utility>
+
 #include "async_infer_request.h"
+#include "cpu/x64/cpu_isa_traits.hpp"
 #include "infer_request.h"
 #include "itt.h"
 #include "low_precision/low_precision.hpp"
 #include "memory_state.h"
 #include "openvino/core/type/element_type.hpp"
 #include "openvino/runtime/intel_cpu/properties.hpp"
-#include "openvino/runtime/threading/executor_manager.hpp"
-#include "transformations/transformation_pipeline.h"
 #include "openvino/runtime/properties.hpp"
-#include "openvino/util/common_util.hpp"
+#include "openvino/runtime/threading/cpu_message.hpp"
 #include "openvino/runtime/threading/cpu_streams_executor.hpp"
-#include "transformations/utils/utils.hpp"
 #include "openvino/runtime/threading/cpu_streams_info.hpp"
-#include "openvino/runtime/threading/cpu_message.hpp"
+#include "openvino/runtime/threading/executor_manager.hpp"
+#include "openvino/util/common_util.hpp"
+#include "transformations/transformation_pipeline.h"
+#include "transformations/utils/utils.hpp"
 #include "utils/serialize.hpp"
 
-#include "cpu/x64/cpu_isa_traits.hpp"
-#include <cstring>
-#include <utility>
-
 #if defined(OV_CPU_WITH_ACL)
-#include "nodes/executors/acl/acl_ie_scheduler.hpp"
+#    include "nodes/executors/acl/acl_ie_scheduler.hpp"
 #endif
 
 using namespace ov::threading;
@@ -333,8 +334,7 @@ ov::Any CompiledModel::get_property(const std::string& name) const {
         return decltype(ov::intel_cpu::sparse_weights_decompression_rate)::value_type(
             config.fcSparseWeiDecompressionRate);
     } else if (name == ov::hint::dynamic_quantization_group_size) {
-        return decltype(ov::hint::dynamic_quantization_group_size)::value_type(
-            config.fcDynamicQuantizationGroupSize);
+        return decltype(ov::hint::dynamic_quantization_group_size)::value_type(config.fcDynamicQuantizationGroupSize);
     } else if (name == ov::hint::kv_cache_precision) {
         return decltype(ov::hint::kv_cache_precision)::value_type(config.kvCachePrecision);
     } else if (name == ov::hint::key_cache_precision) {
diff --git a/src/plugins/intel_cpu/src/compiled_model.h b/src/plugins/intel_cpu/src/compiled_model.h
index faedf1ae5a744c..dc3735b4f3b63e 100644
--- a/src/plugins/intel_cpu/src/compiled_model.h
+++ b/src/plugins/intel_cpu/src/compiled_model.h
@@ -94,5 +94,5 @@ class CompiledModel : public ov::ICompiledModel {
     bool m_has_sub_compiled_models = false;
 };
-} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/config.cpp b/src/plugins/intel_cpu/src/config.cpp index 6262027e344032..86f8dea14db4e1 100644 --- a/src/plugins/intel_cpu/src/config.cpp +++ b/src/plugins/intel_cpu/src/config.cpp @@ -4,19 +4,19 @@ #include "config.h" +#include +#include +#include + #include "cpu/x64/cpu_isa_traits.hpp" #include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type_traits.hpp" #include "openvino/runtime/intel_cpu/properties.hpp" #include "openvino/runtime/internal_properties.hpp" #include "openvino/runtime/properties.hpp" +#include "utils/cpu_utils.hpp" #include "utils/debug_capabilities.h" #include "utils/precision_support.h" -#include "utils/cpu_utils.hpp" - -#include -#include -#include namespace ov { namespace intel_cpu { @@ -61,9 +61,7 @@ Config::Config() { */ void Config::applyDebugCapsProperties() { // always enable perf counters for verbose, performance summary and average counters - if (!debugCaps.verbose.empty() || - !debugCaps.summaryPerf.empty() || - !debugCaps.averageCountersPath.empty()) { + if (!debugCaps.verbose.empty() || !debugCaps.summaryPerf.empty() || !debugCaps.averageCountersPath.empty()) { collectPerfCounters = true; } } @@ -151,10 +149,10 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { logLevel = val.as(); } catch (const ov::Exception&) { OPENVINO_THROW("Wrong value ", - val.as(), - " for property key ", - key, - ". Expected only ov::log::Level::NO/ERR/WARNING/INFO/DEBUG/TRACE."); + val.as(), + " for property key ", + key, + ". Expected only ov::log::Level::NO/ERR/WARNING/INFO/DEBUG/TRACE."); } } else if (key == ov::hint::num_requests.name()) { try { @@ -243,8 +241,8 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { fcDynamicQuantizationGroupSize = val.as(); } catch (const ov::Exception&) { OPENVINO_THROW("Wrong value for property key ", - ov::hint::dynamic_quantization_group_size.name(), - ". Expected only unsinged integer numbers"); + ov::hint::dynamic_quantization_group_size.name(), + ". 
Expected only unsinged integer numbers"); } } else if (key == ov::enable_profiling.name()) { try { @@ -366,7 +364,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) { if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) { kvCachePrecision = prec; } else { - OPENVINO_THROW("invalid value"); + OPENVINO_THROW("invalid value"); } } catch (ov::Exception&) { OPENVINO_THROW("Wrong value ", @@ -511,10 +509,13 @@ void Config::updateProperties() { void Config::applyRtInfo(const std::shared_ptr& model) { // if user sets explicitly, it will be higher priority than rt_info - if (!kvCachePrecisionSetExplicitly && model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) { - this->kvCachePrecision = model->get_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()}); + if (!kvCachePrecisionSetExplicitly && + model->has_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()})) { + this->kvCachePrecision = + model->get_rt_info({"runtime_options", ov::hint::kv_cache_precision.name()}); } - if (!fcDynamicQuantizationGroupSizeSetExplicitly && model->has_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()})) { + if (!fcDynamicQuantizationGroupSizeSetExplicitly && + model->has_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()})) { this->fcDynamicQuantizationGroupSize = model->get_rt_info({"runtime_options", ov::hint::dynamic_quantization_group_size.name()}); } diff --git a/src/plugins/intel_cpu/src/config.h b/src/plugins/intel_cpu/src/config.h index bcde841814d09c..94d4b6e90c531d 100644 --- a/src/plugins/intel_cpu/src/config.h +++ b/src/plugins/intel_cpu/src/config.h @@ -4,18 +4,17 @@ #pragma once +#include +#include +#include + +#include "internal_properties.hpp" #include "openvino/core/type/element_type.hpp" #include "openvino/runtime/properties.hpp" #include "openvino/runtime/threading/istreams_executor.hpp" #include "openvino/util/common_util.hpp" - -#include "internal_properties.hpp" #include "utils/debug_caps_config.h" -#include -#include -#include - namespace ov { namespace intel_cpu { struct Config { @@ -38,11 +37,7 @@ struct Config { Disable, }; - enum class ModelType { - CNN, - LLM, - Unknown - }; + enum class ModelType { CNN, LLM, Unknown }; bool collectPerfCounters = false; bool exclusiveAsyncRequests = false; @@ -75,7 +70,8 @@ struct Config { bool streamsChanged = false; int threads = 0; int threadsPerStream = 0; - ov::threading::IStreamsExecutor::ThreadBindingType threadBindingType = ov::threading::IStreamsExecutor::ThreadBindingType::NONE; + ov::threading::IStreamsExecutor::ThreadBindingType threadBindingType = + ov::threading::IStreamsExecutor::ThreadBindingType::NONE; ov::hint::PerformanceMode hintPerfMode = ov::hint::PerformanceMode::LATENCY; std::vector> streamsRankTable; bool changedHintPerfMode = false; @@ -128,4 +124,4 @@ struct Config { }; } // namespace intel_cpu -} // namespace ov +} // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_memory.cpp b/src/plugins/intel_cpu/src/cpu_memory.cpp index 8e5fe8d72fd1f2..7cb4abc2161f14 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.cpp +++ b/src/plugins/intel_cpu/src/cpu_memory.cpp @@ -3,14 +3,17 @@ // #include "cpu_memory.h" -#include "memory_desc/cpu_memory_desc_utils.h" + #include -#include "nodes/reorder.h" + +#include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/common/cpu_memcpy.h" +#include "nodes/reorder.h" #include "utils/debug_capabilities.h" #if 
defined(__linux__) # include /* Definition of SYS_* constants */ # include + # include /* strerror(errno) */ #endif @@ -27,69 +30,72 @@ BlockedMemoryDescPtr IMemory::getDescWithType() const { } namespace { - inline void setSubnormalsToZero(float *data, size_t size) { - uint32_t *u32data = reinterpret_cast(data); - for (size_t i = 0; i < size; ++i) { - if ((u32data[i] & (0xFF << 23)) == 0) { - u32data[i] = 0; - } +inline void setSubnormalsToZero(float* data, size_t size) { + uint32_t* u32data = reinterpret_cast(data); + for (size_t i = 0; i < size; ++i) { + if ((u32data[i] & (0xFF << 23)) == 0) { + u32data[i] = 0; } } +} - void transferData(const IMemory& src, const IMemory& dst, bool ftz) { - node::Reorder::reorderData(src, dst); +void transferData(const IMemory& src, const IMemory& dst, bool ftz) { + node::Reorder::reorderData(src, dst); - if (!ftz) { - return; - } - if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() == ov::element::bf16) { + if (!ftz) { + return; + } + if (src.getDesc().getPrecision() != ov::element::f32 || dst.getDesc().getPrecision() == ov::element::bf16) { + return; + } + size_t offset = 0; + if (dst.getDesc().getType() & MemoryDescType::Dnnl) { + // here we can safely cast to DnnlMemoryDesc + auto dnnl_desc = dst.getDescWithType(); + auto desc = dnnl_desc->getDnnlDesc(); + dnnl::impl::memory_desc_wrapper wrapper(desc.get()); + offset = wrapper.offset0(); + if (wrapper.is_wino_desc() || wrapper.is_rnn_packed_desc()) { return; } - size_t offset = 0; - if (dst.getDesc().getType() & MemoryDescType::Dnnl) { - // here we can safely cast to DnnlMemoryDesc - auto dnnl_desc = dst.getDescWithType(); - auto desc = dnnl_desc->getDnnlDesc(); - dnnl::impl::memory_desc_wrapper wrapper(desc.get()); - offset = wrapper.offset0(); - if (wrapper.is_wino_desc() || wrapper.is_rnn_packed_desc()) { - return; - } - } - // actual FTZ - auto* memData = static_cast(dst.getData()); - memData += offset; - setSubnormalsToZero(memData, dst.getSize() / sizeof(float)); } + // actual FTZ + auto* memData = static_cast(dst.getData()); + memData += offset; + setSubnormalsToZero(memData, dst.getSize() / sizeof(float)); +} -} // namespace +} // namespace -Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) : - m_eng(eng), - m_pMemDesc(desc), - m_blockHandle(std::make_shared(make_unique()), this), - dnnlMemHandle(this) { - if (desc->getPrecision() == element::string) { - OPENVINO_THROW("[CPU] Memory object cannot be created for string data."); - } - create(m_pMemDesc, data, pads_zeroing); +Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) + : m_eng(eng), + m_pMemDesc(desc), + m_blockHandle(std::make_shared(make_unique()), this), + dnnlMemHandle(this) { + if (desc->getPrecision() == element::string) { + OPENVINO_THROW("[CPU] Memory object cannot be created for string data."); } + create(m_pMemDesc, data, pads_zeroing); +} -Memory::Memory(const dnnl::engine& eng, const MemoryDesc& desc, const void* data, bool pads_zeroing) : - Memory::Memory(eng, desc.clone(), data, pads_zeroing) {} - -Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, MemoryBlockPtr block) : - m_eng(eng), m_pMemDesc(desc), m_blockHandle(block, this), dnnlMemHandle(this) { - if (desc->getPrecision() == element::string) { - OPENVINO_THROW("[CPU] Memory object can't be created for string data."); - } - bool memAllocated = m_blockHandle->getRawPtr(); +Memory::Memory(const dnnl::engine& eng, const MemoryDesc& 
desc, const void* data, bool pads_zeroing) + : Memory::Memory(eng, desc.clone(), data, pads_zeroing) {} - create(desc, nullptr, !memAllocated); +Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, MemoryBlockPtr block) + : m_eng(eng), + m_pMemDesc(desc), + m_blockHandle(block, this), + dnnlMemHandle(this) { + if (desc->getPrecision() == element::string) { + OPENVINO_THROW("[CPU] Memory object can't be created for string data."); } + bool memAllocated = m_blockHandle->getRawPtr(); -Memory::Memory(const dnnl::engine& eng, const MemoryDesc& desc, MemoryBlockPtr block) : - Memory::Memory(eng, desc.clone(), block) {} + create(desc, nullptr, !memAllocated); +} + +Memory::Memory(const dnnl::engine& eng, const MemoryDesc& desc, MemoryBlockPtr block) + : Memory::Memory(eng, desc.clone(), block) {} size_t Memory::getSize() const { auto size = getDesc().getCurrentMemSize(); @@ -99,7 +105,7 @@ size_t Memory::getSize() const { return size; } -void Memory::create(const MemoryDesc &desc, const void *data, bool pads_zeroing) { +void Memory::create(const MemoryDesc& desc, const void* data, bool pads_zeroing) { create(desc.clone(), data, pads_zeroing); } @@ -187,9 +193,7 @@ dnnl::memory Memory::DnnlMemPrimHandle::getPrim() const { void* Memory::getData() const { void* data = getDataNoThrow(); - if (data == nullptr && - m_pMemDesc->getShape().isStatic() && - m_pMemDesc->getShape().getElementsCount() != 0) + if (data == nullptr && m_pMemDesc->getShape().isStatic() && m_pMemDesc->getShape().getElementsCount() != 0) OPENVINO_THROW("Memory has not been allocated"); return data; } @@ -198,7 +202,7 @@ void* MemoryBlockWithReuse::getRawPtr() const noexcept { return m_data.get(); } -void MemoryBlockWithReuse::setExtBuff(void *ptr, size_t size) { +void MemoryBlockWithReuse::setExtBuff(void* ptr, size_t size) { m_useExternalStorage = true; m_memUpperBound = size; m_data = decltype(m_data)(ptr, release); @@ -208,7 +212,7 @@ bool MemoryBlockWithReuse::resize(size_t size) { constexpr int cacheLineSize = 64; bool sizeChanged = false; if (size > m_memUpperBound) { - void *ptr = dnnl::impl::malloc(size, cacheLineSize); + void* ptr = dnnl::impl::malloc(size, cacheLineSize); if (!ptr) { OPENVINO_THROW("Failed to allocate ", size, " bytes of memory"); } @@ -236,15 +240,17 @@ void MemoryBlockWithReuse::free() { m_useExternalStorage = false; } -void MemoryBlockWithReuse::release(void *ptr) {} +void MemoryBlockWithReuse::release(void* ptr) {} -void MemoryBlockWithReuse::destroy(void *ptr) { +void MemoryBlockWithReuse::destroy(void* ptr) { dnnl::impl::free(ptr); } /////////////// StringMemory /////////////// -StringMemory::StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc, const void* data) : m_engine(engine), m_mem_desc(desc) { +StringMemory::StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc, const void* data) + : m_engine(engine), + m_mem_desc(desc) { if (m_mem_desc->getPrecision() != element::string) { OPENVINO_THROW("[CPU] StringMemory supports String type only."); } @@ -258,8 +264,8 @@ StringMemory::StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc const auto string_size = m_mem_desc->getShape().getElementsCount(); if (data != nullptr) { - auto not_const_data = const_cast(data); - m_memoryBlock->setExtBuff(reinterpret_cast(not_const_data), string_size); + auto not_const_data = const_cast(data); + m_memoryBlock->setExtBuff(reinterpret_cast(not_const_data), string_size); } else { m_memoryBlock->resize(string_size); } @@ -273,7 +279,7 @@ void StringMemory::load(const 
IMemory& src, bool ftz) const { transferData(src, *this, false); } -void* StringMemory::getData() const { +void* StringMemory::getData() const { return m_memoryBlock->getRawPtr(); } @@ -297,7 +303,7 @@ void StringMemory::nullify() { } } -size_t StringMemory::getSize() const { // In bytes +size_t StringMemory::getSize() const { // In bytes auto size = getDesc().getCurrentMemSize(); if (size == MemoryDesc::UNDEFINED_SIZE) { OPENVINO_THROW("Can't get memory size for undefined shape."); @@ -329,7 +335,7 @@ bool StringMemory::StringMemoryBlock::resize(size_t size) { if (size > PTRDIFF_MAX) { OPENVINO_THROW("Requested allocation size { ", size, " } exceeds PTRDIFF_MAX."); } - auto ptr_size = static_cast(size); // WA for warning alloc-size-larger-than + auto ptr_size = static_cast(size); // WA for warning alloc-size-larger-than auto ptr = new OvString[ptr_size]; if (!ptr) { OPENVINO_THROW("Failed to allocate ", size, " bytes of memory"); @@ -355,7 +361,7 @@ void StringMemory::StringMemoryBlock::destroy(OvString* ptr) { } void* StringMemory::StringMemoryBlock::getRawPtr() const noexcept { - return reinterpret_cast(m_data.get()); + return reinterpret_cast(m_data.get()); } /////////////// DnnlMemoryBlock /////////////// @@ -364,7 +370,7 @@ void* DnnlMemoryBlock::getRawPtr() const noexcept { return m_pMemBlock->getRawPtr(); } -void DnnlMemoryBlock::setExtBuff(void *ptr, size_t size) { +void DnnlMemoryBlock::setExtBuff(void* ptr, size_t size) { m_pMemBlock->setExtBuff(ptr, size); notifyUpdate(); } @@ -401,8 +407,9 @@ void DnnlMemoryBlock::notifyUpdate() { } } -StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) : - m_eng(eng), m_pMemDesc(desc) { +StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) + : m_eng(eng), + m_pMemDesc(desc) { if (desc->getPrecision() == element::string) { OPENVINO_THROW("[CPU] StaticMemory object cannot be created for string data."); } @@ -427,14 +434,13 @@ StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const vo // // ======================== m_prim.set_data_handle(m_pMemBlock->getRawPtr()); - } - catch (const std::exception& exc) { + } catch (const std::exception& exc) { dnnlErrorCtx = exc.what(); } } -StaticMemory::StaticMemory(const dnnl::engine& eng, const MemoryDesc& desc, const void* data, bool pads_zeroing) : - StaticMemory::StaticMemory(eng, desc.clone(), data, pads_zeroing) {} +StaticMemory::StaticMemory(const dnnl::engine& eng, const MemoryDesc& desc, const void* data, bool pads_zeroing) + : StaticMemory::StaticMemory(eng, desc.clone(), data, pads_zeroing) {} const MemoryDesc& StaticMemory::getDesc() const { return *m_pMemDesc; @@ -475,7 +481,7 @@ MemoryBlockPtr StaticMemory::getMemoryBlock() const { return m_pMemBlock; } -//oneDNN specifics for backward compatibility +// oneDNN specifics for backward compatibility dnnl::memory StaticMemory::getPrimitive() const { if (!m_prim) { OPENVINO_THROW("Couldn't create dnnl::memory object: ", dnnlErrorCtx); @@ -517,11 +523,11 @@ bool StaticMemory::StaticMemoryBlock::hasExtBuffer() const noexcept { } void StaticMemory::StaticMemoryBlock::registerMemory(Memory* memPtr) { - //do nothing + // do nothing } void StaticMemory::StaticMemoryBlock::unregisterMemory(Memory* memPtr) { - //do nothing + // do nothing } #if defined(__linux__) @@ -529,9 +535,9 @@ void StaticMemory::StaticMemoryBlock::unregisterMemory(Memory* memPtr) { # define MPOL_BIND 2 # define MPOL_MF_STRICT (1 << 0) # define 
MPOL_MF_MOVE (1 << 1) -#if !defined(__NR_mbind) && defined(__x86_64__) -# define __NR_mbind 237 -#endif +# if !defined(__NR_mbind) && defined(__x86_64__) +# define __NR_mbind 237 +# endif static long mbind(void* start, unsigned long len, int mode, @@ -585,7 +591,12 @@ bool mbind_move(const dnnl::memory mem, int numaNodeID) { return mbind_move(data, size, numaNodeID); } -MemoryPtr split_horizontal(const dnnl::engine& eng, const MemoryPtr src, int dim, int w_rank, int w_size, bool need_fill) { +MemoryPtr split_horizontal(const dnnl::engine& eng, + const MemoryPtr src, + int dim, + int w_rank, + int w_size, + bool need_fill) { auto desc = src->getDescPtr(); auto shape = src->getShape(); auto dims = shape.getDims(); @@ -620,7 +631,9 @@ MemoryPtr split_horizontal(const dnnl::engine& eng, const MemoryPtr src, int dim // reference stride VectorDims stride_dims = dims; stride_dims[dim] = splited_dim_vec[0]; - size_t stride = std::accumulate(stride_dims.begin(), stride_dims.end(), static_cast(1), std::multiplies()) * prec.size(); + size_t stride = + std::accumulate(stride_dims.begin(), stride_dims.end(), static_cast(1), std::multiplies()) * + prec.size(); // create new shape for target memory VectorDims new_dims = dims; @@ -641,7 +654,12 @@ MemoryPtr split_horizontal(const dnnl::engine& eng, const MemoryPtr src, int dim return ptr; } -MemoryPtr split_vertical(const dnnl::engine& eng, const MemoryPtr src, int dim, int w_rank, int w_size, bool need_fill) { +MemoryPtr split_vertical(const dnnl::engine& eng, + const MemoryPtr src, + int dim, + int w_rank, + int w_size, + bool need_fill) { auto desc = src->getDescPtr(); auto shape = src->getShape(); auto dims = shape.getDims(); @@ -697,7 +715,7 @@ MemoryPtr split_vertical(const dnnl::engine& eng, const MemoryPtr src, int dim, strideSize /= 2; copySize /= 2; } - parallel_for(step, [&](int i){ + parallel_for(step, [&](int i) { int dst_offset = i * copySize; int src_offset = i * splited_size + w_rank * strideSize; cpu_parallel_memcpy(dstPtr + dst_offset, srcPtr + src_offset, copySize); @@ -705,5 +723,5 @@ MemoryPtr split_vertical(const dnnl::engine& eng, const MemoryPtr src, int dim, return ptr; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_memory.h b/src/plugins/intel_cpu/src/cpu_memory.h index 70e6713e36b886..f6837064babfa6 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.h +++ b/src/plugins/intel_cpu/src/cpu_memory.h @@ -4,18 +4,18 @@ #pragma once -#include "memory_desc/cpu_memory_desc.h" -#include "dnnl_extension_utils.h" -#include #include - -#include "openvino/core/type/element_type.hpp" -#include "openvino/core/type/element_type_traits.hpp" +#include #include #include #include +#include "dnnl_extension_utils.h" +#include "memory_desc/cpu_memory_desc.h" +#include "openvino/core/type/element_type.hpp" +#include "openvino/core/type/element_type_traits.hpp" + /** * @file contains a concept classes to work with memory/tensor/blob abstractions on plugin level. * @@ -47,7 +47,8 @@ class IMemoryBlock { virtual void* getRawPtr() const noexcept = 0; /** - * @brief Allows to set externally allocated memory buffer. In that case, the object has no control over the provided memory. + * @brief Allows to set externally allocated memory buffer. In that case, the object has no control over the + * provided memory. 
* @param ptr - pointer to the memory * @param size - size of the memory buffer */ @@ -82,11 +83,11 @@ class MemoryBlockWithReuse : public IMemoryBlock { private: bool m_useExternalStorage = false; size_t m_memUpperBound = 0ul; - std::unique_ptr m_data; + std::unique_ptr m_data; int numa_node; - static void release(void *ptr); - static void destroy(void *ptr); + static void release(void* ptr); + static void destroy(void* ptr); }; class IMemoryBlockObserver : public IMemoryBlock { @@ -128,13 +129,13 @@ class DnnlMemBlockHandle { } DnnlMemBlockHandle(const DnnlMemBlockHandle&) = delete; - DnnlMemBlockHandle& operator= (const DnnlMemBlockHandle&) = delete; + DnnlMemBlockHandle& operator=(const DnnlMemBlockHandle&) = delete; DnnlMemBlockHandle(DnnlMemBlockHandle&& source) { std::swap(m_pMemBlock, source.m_pMemBlock); std::swap(m_pMem, source.m_pMem); } - DnnlMemBlockHandle& operator= (DnnlMemBlockHandle&& rhs) { + DnnlMemBlockHandle& operator=(DnnlMemBlockHandle&& rhs) { std::swap(m_pMemBlock, rhs.m_pMemBlock); std::swap(m_pMem, rhs.m_pMem); return *this; @@ -166,7 +167,7 @@ class IMemory { virtual const MemoryDesc& getDesc() const = 0; virtual MemoryDescPtr getDescPtr() const = 0; - virtual void* getData() const = 0; // pointer to the actual memory + virtual void* getData() const = 0; // pointer to the actual memory template ::type> T* getDataAs() const { @@ -177,7 +178,7 @@ class IMemory { return static_cast(getData()); } - virtual size_t getSize() const = 0; // in bytes + virtual size_t getSize() const = 0; // in bytes virtual const Shape& getShape() const = 0; virtual const VectorDims& getStaticDims() const = 0; @@ -199,7 +200,7 @@ class IMemory { return false; } - //oneDNN specifics for backward compatibility + // oneDNN specifics for backward compatibility virtual dnnl::memory getPrimitive() const = 0; ov::element::Type getPrecision() const { @@ -211,8 +212,8 @@ class IMemory { } template ::value && !std::is_reference::value, int>::type = 0, - typename std::enable_if::value, int>::type = 0> + typename std::enable_if::value && !std::is_reference::value, int>::type = 0, + typename std::enable_if::value, int>::type = 0> std::shared_ptr getDescWithType() const; }; @@ -241,17 +242,17 @@ class StaticMemory final : public IMemory { StaticMemory(const dnnl::engine& eng, const MemoryDesc& desc, const void* data = nullptr, bool pads_zeroing = true); StaticMemory(const StaticMemory&) = delete; - StaticMemory& operator= (const StaticMemory&) = delete; + StaticMemory& operator=(const StaticMemory&) = delete; StaticMemory(Memory&&) = delete; - StaticMemory& operator= (StaticMemory&&) = delete; + StaticMemory& operator=(StaticMemory&&) = delete; const MemoryDesc& getDesc() const override; MemoryDescPtr getDescPtr() const override; - void* getData() const override; // pointer to the actual memory + void* getData() const override; // pointer to the actual memory - size_t getSize() const override; // in bytes + size_t getSize() const override; // in bytes const Shape& getShape() const override; const VectorDims& getStaticDims() const override; @@ -262,7 +263,7 @@ class StaticMemory final : public IMemory { MemoryBlockPtr getMemoryBlock() const override; - //oneDNN specifics for backward compatibility + // oneDNN specifics for backward compatibility dnnl::memory getPrimitive() const override; void nullify() override; @@ -284,10 +285,10 @@ class Memory : public IMemory { Memory(const dnnl::engine& eng, const MemoryDesc& desc, MemoryBlockPtr block); Memory(const Memory&) = delete; - Memory& operator= (const 
Memory&) = delete; + Memory& operator=(const Memory&) = delete; Memory(Memory&&) = delete; - Memory& operator= (Memory&&) = delete; + Memory& operator=(Memory&&) = delete; dnnl::memory getPrimitive() const override; @@ -341,7 +342,7 @@ class Memory : public IMemory { bool m_padsZeroing = true; class DnnlMemPrimHandle { public: - explicit DnnlMemPrimHandle(const Memory* memObjPtr): m_memObjPtr(memObjPtr) {} + explicit DnnlMemPrimHandle(const Memory* memObjPtr) : m_memObjPtr(memObjPtr) {} bool isInit() const; dnnl::memory getPrim() const; void resetDnnlPrim(); @@ -376,7 +377,7 @@ class StringMemory : public IMemory { private: bool m_use_external_storage = false; size_t m_str_upper_bound = 0lu; - std::unique_ptr m_data; + std::unique_ptr m_data; static void release(OvString* ptr) {} static void destroy(OvString* ptr); @@ -390,7 +391,9 @@ class StringMemory : public IMemory { : StringMemory(engine, desc.clone(), data) {} StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc, const StringMemoryBlockPtr& block) - : m_engine(engine), m_mem_desc(desc), m_memoryBlock(block) {} + : m_engine(engine), + m_mem_desc(desc), + m_memoryBlock(block) {} StringMemory(const dnnl::engine& engine, const MemoryDesc& desc, const StringMemoryBlockPtr& block) : StringMemory(engine, desc.clone(), block) {} @@ -405,7 +408,7 @@ class StringMemory : public IMemory { void* getData() const override; - size_t getSize() const override; // In bytes + size_t getSize() const override; // In bytes const Shape& getShape() const override { return m_mem_desc->getShape(); @@ -443,8 +446,18 @@ bool mbind_move(void* data, size_t size, int numaNodeID); bool mbind_move(const MemoryCPtr mem, int numaNodeID); bool mbind_move(const dnnl::memory mem, int numaNodeID); -MemoryPtr split_horizontal(const dnnl::engine& eng, const MemoryPtr src, int dim, int w_rank, int w_size, bool need_fill = true); -MemoryPtr split_vertical(const dnnl::engine& eng, const MemoryPtr src, int dim, int w_rank, int w_size, bool need_fill = true); - -} // namespace intel_cpu -} // namespace ov +MemoryPtr split_horizontal(const dnnl::engine& eng, + const MemoryPtr src, + int dim, + int w_rank, + int w_size, + bool need_fill = true); +MemoryPtr split_vertical(const dnnl::engine& eng, + const MemoryPtr src, + int dim, + int w_rank, + int w_size, + bool need_fill = true); + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_shape.cpp b/src/plugins/intel_cpu/src/cpu_shape.cpp index 4c6b5793d9f2ef..2b7011af1a1f5e 100644 --- a/src/plugins/intel_cpu/src/cpu_shape.cpp +++ b/src/plugins/intel_cpu/src/cpu_shape.cpp @@ -3,12 +3,13 @@ // #include "cpu_shape.h" + #include "utils/general_utils.h" namespace ov { namespace intel_cpu { -bool Shape::isCompatible(const VectorDims &vecDims) const { +bool Shape::isCompatible(const VectorDims& vecDims) const { if (getRank() != vecDims.size()) { return false; } @@ -21,17 +22,21 @@ bool Shape::isCompatible(const VectorDims &vecDims) const { return false; } - if (!std::equal(getMaxDims().begin(), getMaxDims().end(), vecDims.begin(), [](Dim lhs, Dim rhs) { return lhs >= rhs; })) { + if (!std::equal(getMaxDims().begin(), getMaxDims().end(), vecDims.begin(), [](Dim lhs, Dim rhs) { + return lhs >= rhs; + })) { return false; } - if (!std::equal(getMinDims().begin(), getMinDims().end(), vecDims.begin(), [](Dim lhs, Dim rhs) { return lhs <= rhs; })) { + if (!std::equal(getMinDims().begin(), getMinDims().end(), vecDims.begin(), [](Dim lhs, Dim rhs) { + return lhs <= rhs; + })) { return false; } 
return true; } -std::string Shape::toString() const { +std::string Shape::toString() const { std::stringstream output; output << "{"; @@ -50,10 +55,10 @@ std::string Shape::toString() const { Shape mergeShapes(const Shape& lhs, const Shape& rhs) { OPENVINO_ASSERT(lhs.getRank() == rhs.getRank(), - "Couldn't merge shapes of different ranks: shape 1:", - lhs.toString(), - " shape 2: ", - rhs.toString()); + "Couldn't merge shapes of different ranks: shape 1:", + lhs.toString(), + " shape 2: ", + rhs.toString()); const auto& lhsMinDims = lhs.getMinDims(); const auto& lhsMaxDims = lhs.getMaxDims(); @@ -66,10 +71,11 @@ Shape mergeShapes(const Shape& lhs, const Shape& rhs) { for (size_t i = 0; i < resultMinDims.size(); ++i) { resultMinDims[i] = std::max(lhsMinDims[i], rhsMinDims[i]); resultMaxDims[i] = std::min(lhsMaxDims[i], rhsMaxDims[i]); - OPENVINO_ASSERT(resultMinDims[i] <= resultMaxDims[i], "Couldn't merge shapes as the dims intervals are not overlapping."); + OPENVINO_ASSERT(resultMinDims[i] <= resultMaxDims[i], + "Couldn't merge shapes as the dims intervals are not overlapping."); } return Shape{resultMinDims, resultMaxDims}; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_shape.h b/src/plugins/intel_cpu/src/cpu_shape.h index a04b043689e520..f2895287e2f8fe 100644 --- a/src/plugins/intel_cpu/src/cpu_shape.h +++ b/src/plugins/intel_cpu/src/cpu_shape.h @@ -31,13 +31,17 @@ class Shape { type = shape.is_static() ? ShapeType::Static : ShapeType::Dynamic; initDims(); - hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { return dim == 0; } ); + hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { + return dim == 0; + }); } explicit Shape(const VectorDims& shape) { dims = minDims = maxDims = shape; type = ShapeType::Static; - hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { return dim == 0; } ); + hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { + return dim == 0; + }); } Shape(const VectorDims& minDims, const VectorDims& maxDims) { @@ -49,13 +53,17 @@ class Shape { initDims(); - if (std::any_of(dims.begin(), dims.end(), [](size_t dim) { return dim == Shape::UNDEFINED_DIM; } )) { + if (std::any_of(dims.begin(), dims.end(), [](size_t dim) { + return dim == Shape::UNDEFINED_DIM; + })) { type = ShapeType::Dynamic; } else { type = ShapeType::Static; } - hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { return dim == 0; } ); + hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { + return dim == 0; + }); } Shape(const std::initializer_list& shape) { @@ -69,7 +77,9 @@ class Shape { initDims(); - hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { return dim == 0; } ); + hasZeroDimensions = std::any_of(dims.begin(), dims.end(), [](size_t dim) { + return dim == 0; + }); } /** @@ -181,21 +191,21 @@ class Shape { std::string toString() const; - bool operator == (const Shape& rhs) const { + bool operator==(const Shape& rhs) const { return minDims == rhs.minDims && maxDims == rhs.maxDims; } - bool operator != (const Shape& rhs) const { + bool operator!=(const Shape& rhs) const { return !(*this == rhs); } bool hasDefinedUpperBounds() const { - return std::all_of(maxDims.begin(), maxDims.end(), [](Dim dim){ return dim != UNDEFINED_DIM; }); + return std::all_of(maxDims.begin(), maxDims.end(), [](Dim dim) { + return dim != UNDEFINED_DIM; + }); } - enum : Dim { - 
UNDEFINED_DIM = std::numeric_limits::max() - }; + enum : Dim { UNDEFINED_DIM = std::numeric_limits::max() }; private: void initDims() { @@ -205,10 +215,7 @@ class Shape { } } - enum class ShapeType { - Static, - Dynamic - } type {ShapeType::Static}; + enum class ShapeType { Static, Dynamic } type{ShapeType::Static}; bool hasZeroDimensions = false; @@ -229,5 +236,5 @@ class Shape { Shape mergeShapes(const Shape& lhs, const Shape& rhs); -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp b/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp index 244adb7c40c23c..be6f5c4035d1ee 100644 --- a/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp +++ b/src/plugins/intel_cpu/src/cpu_streams_calculation.cpp @@ -4,6 +4,11 @@ #include "cpu_streams_calculation.hpp" +#include +#include +#include +#include + #include "cpu_map_scheduling.hpp" #include "graph.h" #include "openvino/op/fake_quantize.hpp" @@ -13,29 +18,25 @@ #include "transformations/utils.hpp" #include "transformations/utils/utils.hpp" -#include -#include -#include -#include - using namespace ov; using namespace ov::threading; -#define INIT_VAL -100 +#define INIT_VAL -100 #define TP_CPU_LIMIT 32 namespace ov { namespace intel_cpu { -std::vector> get_streams_info_table(const int input_streams, - const bool input_streams_changed, - const int input_threads, - const int input_infer_requests, - const int model_prefer_threads, - const int input_current_socket_id, - const std::string input_perf_hint, - const std::set hint_model_distribution_policy, - const std::vector>& proc_type_table) { +std::vector> get_streams_info_table( + const int input_streams, + const bool input_streams_changed, + const int input_threads, + const int input_infer_requests, + const int model_prefer_threads, + const int input_current_socket_id, + const std::string input_perf_hint, + const std::set hint_model_distribution_policy, + const std::vector>& proc_type_table) { std::vector stream_info(CPU_STREAMS_TABLE_SIZE, INIT_VAL); std::vector> streams_info_table; std::vector> proc_socket_table; @@ -339,8 +340,7 @@ std::vector> get_streams_info_table(const int input_streams, n_threads_per_stream = static_cast(n_threads / n_streams); check_threads_per_stream(); } else { - n_threads_per_stream = - model_threads > 0 ? model_threads : static_cast(n_threads / n_streams); + n_threads_per_stream = model_threads > 0 ? model_threads : static_cast(n_threads / n_streams); } } } @@ -590,7 +590,7 @@ int get_model_prefer_threads(const int num_streams, (networkToleranceForLowCache.ratio_mem_limited_gemms > ov::MemBandwidthPressure::LIMITED))) { config.modelPreferThreads = 8; } -#elif((defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)) && defined(__APPLE__)) +#elif ((defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)) && defined(__APPLE__)) config.modelPreferThreads = 1; if (networkToleranceForLowCache.max_mem_tolerance == ov::MemBandwidthPressure::UNKNOWN) { if ((networkToleranceForLowCache.ratio_compute_convs == ov::MemBandwidthPressure::ALL) || diff --git a/src/plugins/intel_cpu/src/cpu_streams_calculation.hpp b/src/plugins/intel_cpu/src/cpu_streams_calculation.hpp index e362c0373d8d1d..0a0b4a1449b7cb 100644 --- a/src/plugins/intel_cpu/src/cpu_streams_calculation.hpp +++ b/src/plugins/intel_cpu/src/cpu_streams_calculation.hpp @@ -44,15 +44,16 @@ namespace intel_cpu { * in previous function. * @return streams information table which will be used by StreamsExecutor. 
*/ -std::vector> get_streams_info_table(const int input_streams, - const bool input_streams_changed, - const int input_threads, - const int input_infer_requests, - const int model_prefer_threads, - const int input_current_socket_id, - const std::string input_perf_hint, - const std::set hint_llm_distribution_policy, - const std::vector>& proc_type_table); +std::vector> get_streams_info_table( + const int input_streams, + const bool input_streams_changed, + const int input_threads, + const int input_infer_requests, + const int model_prefer_threads, + const int input_current_socket_id, + const std::string input_perf_hint, + const std::set hint_llm_distribution_policy, + const std::vector>& proc_type_table); /** * @brief Generate streams rank table for tensor parallel according to streams info table. @@ -106,9 +107,7 @@ std::vector> generate_stream_info(const int streams, * @param[in] model graph handle * @param[in] config intel cpu configuration */ -void get_num_streams(const int streams, - const std::shared_ptr& model, - Config& config); +void get_num_streams(const int streams, const std::shared_ptr& model, Config& config); } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_tensor.cpp b/src/plugins/intel_cpu/src/cpu_tensor.cpp index 1a045ca117a538..0f82a8a9a4dfec 100644 --- a/src/plugins/intel_cpu/src/cpu_tensor.cpp +++ b/src/plugins/intel_cpu/src/cpu_tensor.cpp @@ -16,7 +16,8 @@ Tensor::Tensor(MemoryPtr memptr) : m_memptr{memptr} { // only support plain data format ncsp. auto memdesc = m_memptr->getDescPtr(); - OPENVINO_ASSERT(memdesc->hasLayoutType(LayoutType::ncsp), "intel_cpu::Tensor only supports memory with ncsp layout."); + OPENVINO_ASSERT(memdesc->hasLayoutType(LayoutType::ncsp), + "intel_cpu::Tensor only supports memory with ncsp layout."); m_element_type = memdesc->getPrecision(); } @@ -24,8 +25,14 @@ Tensor::Tensor(MemoryPtr memptr) : m_memptr{memptr} { void Tensor::set_shape(ov::Shape new_shape) { const auto& shape = m_memptr->getDescPtr()->getShape(); if (shape.isStatic()) { - DEBUG_LOG("tensor's memory object ", m_memptr.get(), ", ", vec2str(shape.getStaticDims()), " -> ", new_shape.to_string()); - if (shape.getStaticDims() == new_shape) return; + DEBUG_LOG("tensor's memory object ", + m_memptr.get(), + ", ", + vec2str(shape.getStaticDims()), + " -> ", + new_shape.to_string()); + if (shape.getStaticDims() == new_shape) + return; } auto desc = m_memptr->getDescPtr(); @@ -69,7 +76,7 @@ void Tensor::update_strides() const { OPENVINO_ASSERT(blocked_desc, "not a valid blocked memory descriptor."); auto& strides = blocked_desc->getStrides(); m_strides.resize(strides.size()); - std::transform(strides.cbegin(), strides.cend(), m_strides.begin(), [this] (const size_t stride) { + std::transform(strides.cbegin(), strides.cend(), m_strides.begin(), [this](const size_t stride) { return stride * m_element_type.size(); }); } @@ -96,5 +103,5 @@ std::shared_ptr make_tensor(MemoryPtr mem) { return std::make_shared(mem); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/cpu_tensor.h b/src/plugins/intel_cpu/src/cpu_tensor.h index 0f073e0d298faf..86648ce969b168 100644 --- a/src/plugins/intel_cpu/src/cpu_tensor.h +++ b/src/plugins/intel_cpu/src/cpu_tensor.h @@ -4,8 +4,8 @@ #pragma once -#include "openvino/runtime/itensor.hpp" #include "cpu_memory.h" +#include "openvino/runtime/itensor.hpp" namespace ov { namespace intel_cpu { @@ -29,7 +29,9 @@ class Tensor : public ITensor { void* 
data(const element::Type& type = {}) const override; - MemoryPtr get_memory() {return m_memptr;} + MemoryPtr get_memory() { + return m_memptr; + } private: void update_strides() const; @@ -44,5 +46,5 @@ class Tensor : public ITensor { std::shared_ptr make_tensor(MemoryPtr mem); -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index 30884bbe649962..67c538bd78341a 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -2,10 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 // #include "cpu_types.h" -#include "cpu_shape.h" -#include #include +#include + +#include "cpu_shape.h" namespace ov { namespace intel_cpu { @@ -260,8 +261,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { {"QKVProjection", Type::QKVProjection}, {"RMS", Type::RMS}, {"SearchSorted", Type::SearchSorted}, - {"LoraSubgraph", Type::LoRA} - }; + {"LoraSubgraph", Type::LoRA}}; return type_to_name_tbl; } diff --git a/src/plugins/intel_cpu/src/dnnl_extension_utils.cpp b/src/plugins/intel_cpu/src/dnnl_extension_utils.cpp index 3d9b2f69bd8f66..457f8368f734dd 100644 --- a/src/plugins/intel_cpu/src/dnnl_extension_utils.cpp +++ b/src/plugins/intel_cpu/src/dnnl_extension_utils.cpp @@ -47,79 +47,79 @@ uint8_t DnnlExtensionUtils::sizeOfDataType(dnnl::memory::data_type dataType) { dnnl::memory::data_type DnnlExtensionUtils::ElementTypeToDataType(const ov::element::Type& elementType) { switch (elementType) { - case ov::element::f32: - return memory::data_type::f32; - case ov::element::i32: - return memory::data_type::s32; - case ov::element::bf16: - return memory::data_type::bf16; - case ov::element::i8: - return memory::data_type::s8; - case ov::element::u8: - case ov::element::boolean: - return memory::data_type::u8; - case ov::element::u1: - return memory::data_type::bin; - case ov::element::f16: - return memory::data_type::f16; - case ov::element::nf4: - return memory::data_type::nf4; - case ov::element::i4: - return memory::data_type::s4; - case ov::element::u4: - return memory::data_type::u4; - case ov::element::f8e8m0: - return memory::data_type::f8_e8m0; - case ov::element::f4e2m1: - return memory::data_type::f4_e2m1; - case ov::element::undefined: - return memory::data_type::undef; - default: { - OPENVINO_THROW("CPU plugin does not support ", elementType.to_string(), " for use with oneDNN."); - } + case ov::element::f32: + return memory::data_type::f32; + case ov::element::i32: + return memory::data_type::s32; + case ov::element::bf16: + return memory::data_type::bf16; + case ov::element::i8: + return memory::data_type::s8; + case ov::element::u8: + case ov::element::boolean: + return memory::data_type::u8; + case ov::element::u1: + return memory::data_type::bin; + case ov::element::f16: + return memory::data_type::f16; + case ov::element::nf4: + return memory::data_type::nf4; + case ov::element::i4: + return memory::data_type::s4; + case ov::element::u4: + return memory::data_type::u4; + case ov::element::f8e8m0: + return memory::data_type::f8_e8m0; + case ov::element::f4e2m1: + return memory::data_type::f4_e2m1; + case ov::element::undefined: + return memory::data_type::undef; + default: { + OPENVINO_THROW("CPU plugin does not support ", elementType.to_string(), " for use with oneDNN."); + } } } ov::element::Type DnnlExtensionUtils::DataTypeToElementType(const dnnl::memory::data_type& 
dataType) { switch (dataType) { - case memory::data_type::f32: - return ov::element::f32; - case memory::data_type::s32: - return ov::element::i32; - case memory::data_type::bf16: - return ov::element::bf16; - case memory::data_type::s8: - return ov::element::i8; - case memory::data_type::u8: - return ov::element::u8; - case memory::data_type::bin: - return ov::element::u1; - case memory::data_type::f16: - return ov::element::f16; - case memory::data_type::f64: - return ov::element::f64; - case memory::data_type::nf4: - return ov::element::nf4; - case memory::data_type::s4: - return ov::element::i4; - case memory::data_type::u4: - return ov::element::u4; - case memory::data_type::f8_e8m0: - return ov::element::f8e8m0; - case memory::data_type::f4_e2m1: - return ov::element::f4e2m1; - case memory::data_type::undef: - return ov::element::undefined; - default: { - OPENVINO_THROW("Unsupported data type."); - } + case memory::data_type::f32: + return ov::element::f32; + case memory::data_type::s32: + return ov::element::i32; + case memory::data_type::bf16: + return ov::element::bf16; + case memory::data_type::s8: + return ov::element::i8; + case memory::data_type::u8: + return ov::element::u8; + case memory::data_type::bin: + return ov::element::u1; + case memory::data_type::f16: + return ov::element::f16; + case memory::data_type::f64: + return ov::element::f64; + case memory::data_type::nf4: + return ov::element::nf4; + case memory::data_type::s4: + return ov::element::i4; + case memory::data_type::u4: + return ov::element::u4; + case memory::data_type::f8_e8m0: + return ov::element::f8e8m0; + case memory::data_type::f4_e2m1: + return ov::element::f4e2m1; + case memory::data_type::undef: + return ov::element::undefined; + default: { + OPENVINO_THROW("Unsupported data type."); + } } } -Dim DnnlExtensionUtils::convertToDim(const dnnl::memory::dim &dim) { - return dim == DNNL_RUNTIME_DIM_VAL ? Shape::UNDEFINED_DIM : static_cast(dim); +Dim DnnlExtensionUtils::convertToDim(const dnnl::memory::dim& dim) { + return dim == DNNL_RUNTIME_DIM_VAL ? Shape::UNDEFINED_DIM : static_cast(dim); } -dnnl::memory::dim DnnlExtensionUtils::convertToDnnlDim(const Dim &dim) { +dnnl::memory::dim DnnlExtensionUtils::convertToDnnlDim(const Dim& dim) { return dim == Shape::UNDEFINED_DIM ? 
DNNL_RUNTIME_DIM_VAL : static_cast(dim); } @@ -141,25 +141,25 @@ memory::dims DnnlExtensionUtils::convertToDnnlDims(const VectorDims& dims) { memory::format_tag DnnlExtensionUtils::GetPlainFormatByRank(size_t rank) { switch (rank) { - case 0: - case 1: - return memory::format_tag::a; - case 2: - return memory::format_tag::ab; - case 3: - return memory::format_tag::abc; - case 4: - return memory::format_tag::abcd; - case 5: - return memory::format_tag::abcde; - case 6: - return memory::format_tag::abcdef; - default: - return memory::format_tag::undef; + case 0: + case 1: + return memory::format_tag::a; + case 2: + return memory::format_tag::ab; + case 3: + return memory::format_tag::abc; + case 4: + return memory::format_tag::abcd; + case 5: + return memory::format_tag::abcde; + case 6: + return memory::format_tag::abcdef; + default: + return memory::format_tag::undef; } } -DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor(const dnnl::memory::desc &desc) { +DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor(const dnnl::memory::desc& desc) { return makeDescriptor(desc.get()); } @@ -182,7 +182,8 @@ size_t DnnlExtensionUtils::getMemSizeForDnnlDesc(const dnnl::memory::desc& desc) return size; } -std::shared_ptr DnnlExtensionUtils::makeUndefinedDesc(const memory::desc &desc, const Shape &shape) { +std::shared_ptr DnnlExtensionUtils::makeUndefinedDesc(const memory::desc& desc, + const Shape& shape) { if (desc.get_format_kind() == memory::format_kind::blocked) { return std::shared_ptr(new DnnlBlockedMemoryDesc(desc, shape)); } else { @@ -190,7 +191,9 @@ std::shared_ptr DnnlExtensionUtils::makeUndefinedDesc(con } } -DnnlMemoryDescPtr DnnlExtensionUtils::query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx) { +DnnlMemoryDescPtr DnnlExtensionUtils::query_md(const const_dnnl_primitive_desc_t& pd, + const dnnl::query& what, + int idx) { auto query = dnnl::convert_to_c(what); const auto* cdesc = dnnl_primitive_desc_query_md(pd, query, idx); @@ -201,7 +204,7 @@ DnnlMemoryDescPtr DnnlExtensionUtils::query_md(const const_dnnl_primitive_desc_t } std::string DnnlExtensionUtils::query_impl_info_str(const const_dnnl_primitive_desc_t& pd) { - const char *res; + const char* res; dnnl_status_t status = dnnl_primitive_desc_query(pd, dnnl_query_impl_info_str, 0, &res); if (status != dnnl_success) OPENVINO_THROW("query_impl_info_str failed."); @@ -209,10 +212,9 @@ std::string DnnlExtensionUtils::query_impl_info_str(const const_dnnl_primitive_d } bool DnnlExtensionUtils::find_implementation(dnnl::primitive_desc& desc, impl_desc_type impl_type) { - return DnnlExtensionUtils::find_implementation(desc, - [impl_type](impl_desc_type cur_impl_type){ - return cur_impl_type == impl_type; - }); + return DnnlExtensionUtils::find_implementation(desc, [impl_type](impl_desc_type cur_impl_type) { + return cur_impl_type == impl_type; + }); } dnnl_memory_desc_t DnnlExtensionUtils::clone_desc(const_dnnl_memory_desc_t cdesc) { @@ -233,31 +235,33 @@ const char* DnnlExtensionUtils::query_pd_info(const_dnnl_primitive_desc_t pd) { bool DnnlExtensionUtils::isUnarySupportedAsPostOp(Algorithm alg) { #if defined(OV_CPU_WITH_ACL) - return one_of(alg, Algorithm::EltwiseRelu, - Algorithm::EltwiseTanh, - Algorithm::EltwiseElu, - Algorithm::EltwiseAbs, - Algorithm::EltwiseSqrt, - Algorithm::EltwiseSoftRelu, - Algorithm::EltwiseSigmoid, - Algorithm::EltwiseClamp); + return one_of(alg, + Algorithm::EltwiseRelu, + Algorithm::EltwiseTanh, + Algorithm::EltwiseElu, + Algorithm::EltwiseAbs, + Algorithm::EltwiseSqrt, + 
Algorithm::EltwiseSoftRelu, + Algorithm::EltwiseSigmoid, + Algorithm::EltwiseClamp); #elif defined(OPENVINO_ARCH_X86_64) - return one_of(alg, Algorithm::EltwiseRelu, - Algorithm::EltwiseGeluErf, - Algorithm::EltwiseGeluTanh, - Algorithm::EltwiseElu, - Algorithm::EltwiseSigmoid, - Algorithm::EltwiseClamp, - Algorithm::EltwiseTanh, - Algorithm::EltwiseSwish, - Algorithm::EltwiseHswish, - Algorithm::EltwiseMish, - Algorithm::EltwiseHsigmoid, - Algorithm::EltwiseRoundHalfToEven, - Algorithm::EltwiseRoundHalfAwayFromZero, - Algorithm::EltwiseAbs, - Algorithm::EltwiseSqrt, - Algorithm::EltwiseSoftRelu); + return one_of(alg, + Algorithm::EltwiseRelu, + Algorithm::EltwiseGeluErf, + Algorithm::EltwiseGeluTanh, + Algorithm::EltwiseElu, + Algorithm::EltwiseSigmoid, + Algorithm::EltwiseClamp, + Algorithm::EltwiseTanh, + Algorithm::EltwiseSwish, + Algorithm::EltwiseHswish, + Algorithm::EltwiseMish, + Algorithm::EltwiseHsigmoid, + Algorithm::EltwiseRoundHalfToEven, + Algorithm::EltwiseRoundHalfAwayFromZero, + Algorithm::EltwiseAbs, + Algorithm::EltwiseSqrt, + Algorithm::EltwiseSoftRelu); #else return false; #endif @@ -269,5 +273,5 @@ std::string DnnlExtensionUtils::computeWeightsStringHash(const std::shared_ptr(memory->getData())); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/dnnl_extension_utils.h b/src/plugins/intel_cpu/src/dnnl_extension_utils.h index 7a968ea3c71c3d..ecf223b48497cd 100644 --- a/src/plugins/intel_cpu/src/dnnl_extension_utils.h +++ b/src/plugins/intel_cpu/src/dnnl_extension_utils.h @@ -10,11 +10,11 @@ #include +#include "common/c_types_map.hpp" #include "cpu_types.h" #include "onednn/dnnl.h" #include "onednn/iml_type_mapper.h" #include "openvino/core/type/element_type.hpp" -#include "common/c_types_map.hpp" namespace ov { namespace intel_cpu { @@ -29,8 +29,8 @@ class DnnlExtensionUtils { static uint8_t sizeOfDataType(dnnl::memory::data_type dataType); static dnnl::memory::data_type ElementTypeToDataType(const ov::element::Type& elementType); static ov::element::Type DataTypeToElementType(const dnnl::memory::data_type& dataType); - static Dim convertToDim(const dnnl::memory::dim &dim); - static dnnl::memory::dim convertToDnnlDim(const Dim &dim); + static Dim convertToDim(const dnnl::memory::dim& dim); + static dnnl::memory::dim convertToDnnlDim(const Dim& dim); static VectorDims convertToVectorDims(const dnnl::memory::dims& dims); static VectorDims convertToVectorDims(const dnnl::impl::dims_t dims, const int ndims); static std::vector convertToDnnlDims(const VectorDims& dims); @@ -41,25 +41,28 @@ class DnnlExtensionUtils { * @param desc dnnl::memory::desc from which one of the descriptors will be created * @return pointer to DnnlBlockedMemoryDesc or DnnlMemoryDesc */ - static std::shared_ptr makeDescriptor(const dnnl::memory::desc &desc); + static std::shared_ptr makeDescriptor(const dnnl::memory::desc& desc); static std::shared_ptr makeDescriptor(const_dnnl_memory_desc_t desc); /** * @brief Helper function that creates DnnlBlockedMemoryDesc from defined dnnl::memory::desc and undefined shape. - * It uses desc as an basis for the new undefined one. Specifically, type, layout, precision, blocks, extra data will be preserved. + * It uses desc as a basis for the new undefined one. Specifically, type, layout, precision, blocks, extra data + * will be preserved.
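+     *
+     * A hypothetical usage sketch (illustrative only; the source desc and the dynamic shape shown are assumptions,
+     * not part of this change):
+     * @code
+     * dnnl::memory::desc defined = someDnnlMemory.get_desc();  // a fully defined blocked desc
+     * auto relaxed = DnnlExtensionUtils::makeUndefinedDesc(defined, Shape(ov::PartialShape{-1, 64}));
+     * // 'relaxed' keeps the type/layout/blocking of 'defined' but carries the undefined shape
+     * @endcode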
* @param desc dnnl::memory::desc dnnl desc which will be used as a basis of the new descriptor * @param shape a new undefined shape * @return pointer to the created DnnlBlockedMemoryDesc * @note Only blocked descriptors are allowed at the moment */ - static std::shared_ptr makeUndefinedDesc(const dnnl::memory::desc &desc, const Shape& shape); + static std::shared_ptr makeUndefinedDesc(const dnnl::memory::desc& desc, const Shape& shape); static size_t getMemSizeForDnnlDesc(const dnnl::memory::desc& desc); - static std::shared_ptr query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx = 0); + static std::shared_ptr query_md(const const_dnnl_primitive_desc_t& pd, + const dnnl::query& what, + int idx = 0); static std::string query_impl_info_str(const const_dnnl_primitive_desc_t& pd); - template + template static bool find_implementation(dnnl::primitive_desc& desc, T&& comparator) { dnnl::primitive_desc_iterator& itpd = desc; @@ -77,7 +80,7 @@ class DnnlExtensionUtils { return false; } - template + template static void for_each_implementation(dnnl::primitive_desc& desc, bool first_match, T&& comparator, L&& func) { dnnl::primitive_desc_iterator& itpd = desc; @@ -113,5 +116,5 @@ class DnnlExtensionUtils { const std::shared_ptr& dstDesc); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp b/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp index 70d28f1f4ac739..7d62e5cb6b673d 100644 --- a/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp +++ b/src/plugins/intel_cpu/src/dnnl_postops_composer.cpp @@ -656,7 +656,7 @@ static MemoryPtr prepackDecompressionParams(const MemoryCPtr& paramsPtr, // weights without batch: (OC, G) // weights with batch: (B, OC, G) const size_t OC = shape[shape.size() - 2]; - const size_t G = shape[shape.size() - 1]; + const size_t G = shape[shape.size() - 1]; Shape dstShape = Shape({OC, G}); @@ -698,8 +698,7 @@ void DnnlPostOpsComposer::appendDecompressionZeroPoints(const MemoryCPtr& zero_p if (zero_points_ptr == nullptr) return; - auto zeroPointsMem = - prepackDecompressionParams(zero_points_ptr, needTranspose, dstPrecision, engine); + auto zeroPointsMem = prepackDecompressionParams(zero_points_ptr, needTranspose, dstPrecision, engine); attr.set_zero_points_dims(DNNL_ARG_WEIGHTS, DnnlExtensionUtils::convertToDnnlDims(zeroPointsMem->getStaticDims()), DnnlExtensionUtils::ElementTypeToDataType(dstPrecision)); diff --git a/src/plugins/intel_cpu/src/dnnl_postops_composer.h b/src/plugins/intel_cpu/src/dnnl_postops_composer.h index 8c2718aaaed4d5..7ae634658b005f 100644 --- a/src/plugins/intel_cpu/src/dnnl_postops_composer.h +++ b/src/plugins/intel_cpu/src/dnnl_postops_composer.h @@ -12,8 +12,8 @@ #include "cpu_memory.h" #include "nodes/executors/dnnl/dnnl_aliases.hpp" -#include "post_ops.hpp" #include "nodes/executors/dnnl/dnnl_post_op_data.hpp" +#include "post_ops.hpp" namespace ov { namespace intel_cpu { @@ -31,7 +31,9 @@ class DnnlPostOpsComposer { const dnnl::memory::data_type outDataType); DnnlPrimitiveAttrs compose(); void appendDecompressionScales(const MemoryCPtr& scales_ptr, bool needTranspose, ov::element::Type dstPrecision); - void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, bool needTranspose, ov::element::Type dstPrecision); + void appendDecompressionZeroPoints(const MemoryCPtr& zero_points_ptr, + bool needTranspose, + ov::element::Type dstPrecision); void setDynamicQuantizationParams(uint64_t groupSize);
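+
+    // A minimal composition sketch (hypothetical caller code; the memory objects and precisions below are
+    // assumptions, not part of this header):
+    //
+    //   DnnlPostOpsComposer composer(/* the node's post-op context */);
+    //   composer.appendDecompressionScales(scalesMem, /*needTranspose=*/false, ov::element::f32);
+    //   composer.appendDecompressionZeroPoints(zeroPointsMem, /*needTranspose=*/false, ov::element::u8);
+    //   DnnlPrimitiveAttrs attrs = composer.compose();  // dnnl attrs carrying the accumulated post-ops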
private: diff --git a/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.cpp b/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.cpp index cb59492463f410..3e40ead65d6cc3 100644 --- a/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.cpp +++ b/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.cpp @@ -3,9 +3,11 @@ // #include "dnnl_postops_composer_legacy.h" + #include #include + #include "utils/debug_capabilities.h" namespace ov { @@ -39,10 +41,10 @@ DnnlPostOpsComposerLegacy::DnnlPostOpsComposerLegacy(const dnnl::engine& engine, wei_scale_mask = wei_scale_values.size() > 1 ? weiScaleMaskPerChannel : 0; dst_scale_val = 1.0; - //set the DQscale into attr weight scale before appending any post-ops. + // set the DQscale into attr weight scale before appending any post-ops. updateWeiScales(); - //If having the bias, attr weight scale can't be updated for further ops-ops optimization. - //ONEDNN 3.x quantization for scheme: QuantizedInput * QuantizedWeight * DQScale + Bias. + // If the bias is present, the attr weight scale can't be updated for further post-ops optimization. + // ONEDNN 3.x quantization for scheme: QuantizedInput * QuantizedWeight * DQScale + Bias. weightScaleAvailable = !hasBias; } else if (!DQScales.empty()) { // DQ scale is fused but switching back to non-INT8 for execution in some cases. @@ -115,22 +117,22 @@ bool DnnlPostOpsComposerLegacy::appendScale(const std::vector& scale, boo return true; } if (weightScaleAvailable) { - //oneDNN v3.* weight scale can also be used in the further optimization patterns. - // there are so many possible optimizations can be done, for example: // - // we can switch the existing postOps's order to take - // advantage of output scale if it's available: - // relu(x)*scale = relu(x*scale) - // or we can fuse it into previous one as long as they are - // compatible in shape - // x*A*s = x*(A*s) - // or even with add: - // (x*A + B)*s = x*(A*s) + (B*s) - // or we can combine these two tricks: - // relu(x*A)*s = relu(x*(A*s)) + // oneDNN v3.* weight scale can also be used in the further optimization patterns. + // there are many possible optimizations that can be done, for example: + // + // we can switch the existing postOps's order to take + // advantage of the output scale if it's available: + // relu(x)*scale = relu(x*scale) + // or we can fuse it into the previous one as long as they are + // compatible in shape: + // x*A*s = x*(A*s) + // or even with add: + // (x*A + B)*s = x*(A*s) + (B*s) + // or we can combine these two tricks: + // relu(x*A)*s = relu(x*(A*s)) // - // we cannot implement all of them, so we just add the one - // that we observed in real models. + // we cannot implement all of them, so we just add the one + // that we observed in real models.
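+        //
+        // A quick numeric check of the last identity above (it holds for s > 0): with A = 2, s = 3,
+        //   x = 1:  relu(1*2)*3 = 6 and relu(1*(2*3)) = 6
+        //   x = -1: relu(-1*2)*3 = 0 and relu(-1*(2*3)) = 0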
if ((ops.len() == 0)) fuseIntoWeiScale = true; @@ -201,9 +203,9 @@ bool DnnlPostOpsComposerLegacy::appendShift(const std::vector& shift, boo } bool DnnlPostOpsComposerLegacy::appendLinear(const std::vector& scale, - const std::vector& shift, - bool isLastPostOp, - bool allowBinary) { + const std::vector& shift, + bool isLastPostOp, + bool allowBinary) { if (scale.size() == 1 && shift.size() == 1) { if (shift[0] == 0.0f) return appendScale(scale, isLastPostOp, allowBinary); diff --git a/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.h b/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.h index 82fdda94012f15..485fa31fb5d956 100644 --- a/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.h +++ b/src/plugins/intel_cpu/src/dnnl_postops_composer_legacy.h @@ -8,11 +8,10 @@ */ #pragma once -#include "dnnl_types.h" - #include #include "cpu_memory.h" +#include "dnnl_types.h" #include "memory_desc/cpu_memory_desc.h" #include "memory_desc/dnnl_blocked_memory_desc.h" #include "onednn/dnnl.h" @@ -39,7 +38,10 @@ class DnnlPostOpsComposerLegacy { void appendRoundHTE(); bool appendScale(const std::vector& scale, bool isLastPostOp, bool allowBinary = true); bool appendShift(const std::vector& shift, bool allowBinary = true); - bool appendLinear(const std::vector& scale, const std::vector& shift, bool isLastPostOp, bool allowBinary = true); + bool appendLinear(const std::vector& scale, + const std::vector& shift, + bool isLastPostOp, + bool allowBinary = true); void appendClip(const std::vector& low, const std::vector& high); const VectorDims& getOutputDims() { diff --git a/src/plugins/intel_cpu/src/edge.cpp b/src/plugins/intel_cpu/src/edge.cpp index c49b924477f694..1eabc6275bf4b0 100644 --- a/src/plugins/intel_cpu/src/edge.cpp +++ b/src/plugins/intel_cpu/src/edge.cpp @@ -3,8 +3,9 @@ // #include "edge.h" -#include "node.h" + #include "dnnl_extension_utils.h" +#include "node.h" #include "openvino/core/type/element_type.hpp" #include "openvino/util/pp.hpp" @@ -12,8 +13,11 @@ using namespace dnnl; namespace ov { namespace intel_cpu { -Edge::Edge(const NodePtr &parent, const NodePtr &child, int pr_port, int ch_port) : - parent(parent), child(child), parent_port(pr_port), child_port(ch_port) {} +Edge::Edge(const NodePtr& parent, const NodePtr& child, int pr_port, int ch_port) + : parent(parent), + child(child), + parent_port(pr_port), + child_port(ch_port) {} const NodePtr Edge::getParent() const { auto parentPtr = parent.lock(); @@ -39,14 +43,14 @@ bool Edge::isDropped() const { auto parent_ptr = parent.lock(); if (parent_ptr) { - for (auto &edge : parent_ptr->childEdges) + for (auto& edge : parent_ptr->childEdges) if (edge.lock().get() == this) not_in_parent = false; } auto child_ptr = child.lock(); if (child_ptr) { - for (auto &edge : child_ptr->parentEdges) + for (auto& edge : child_ptr->parentEdges) if (edge.lock().get() == this) not_in_child = false; } @@ -131,8 +135,8 @@ bool Edge::enforceReorder() { } static inline bool isPhycicalMemCompatible(const MemoryDesc& lhsMemDesc, const MemoryDesc& rhsMemDesc) { - if (!lhsMemDesc.isDefined() || !rhsMemDesc.isDefined() || - !(lhsMemDesc.getType() & MemoryDescType::Blocked) || !(rhsMemDesc.getType() & MemoryDescType::Blocked) || + if (!lhsMemDesc.isDefined() || !rhsMemDesc.isDefined() || !(lhsMemDesc.getType() & MemoryDescType::Blocked) || + !(rhsMemDesc.getType() & MemoryDescType::Blocked) || (lhsMemDesc.getType() == DnnlBlocked && !lhsMemDesc.as()->hasEmptyExtraData()) || (rhsMemDesc.getType() == DnnlBlocked && 
!rhsMemDesc.as()->hasEmptyExtraData())) return false; @@ -140,13 +144,21 @@ static inline bool isPhycicalMemCompatible(const MemoryDesc& lhsMemDesc, const M const auto lhsBlockMemDesc = lhsMemDesc.as(); const auto rhsBlockMemDesc = rhsMemDesc.as(); - if (lhsBlockMemDesc->getShape() != rhsBlockMemDesc->getShape() || lhsBlockMemDesc->getPrecision() != rhsBlockMemDesc->getPrecision()) + if (lhsBlockMemDesc->getShape() != rhsBlockMemDesc->getShape() || + lhsBlockMemDesc->getPrecision() != rhsBlockMemDesc->getPrecision()) return false; // dims padding check - bool isZeroDimsPaddings = - std::all_of(lhsBlockMemDesc->getOffsetPaddingToData().begin(), lhsBlockMemDesc->getOffsetPaddingToData().end(), [](size_t x){ return x == 0; }) && - std::all_of(rhsBlockMemDesc->getOffsetPaddingToData().begin(), rhsBlockMemDesc->getOffsetPaddingToData().end(), [](size_t x){ return x == 0; }); + bool isZeroDimsPaddings = std::all_of(lhsBlockMemDesc->getOffsetPaddingToData().begin(), + lhsBlockMemDesc->getOffsetPaddingToData().end(), + [](size_t x) { + return x == 0; + }) && + std::all_of(rhsBlockMemDesc->getOffsetPaddingToData().begin(), + rhsBlockMemDesc->getOffsetPaddingToData().end(), + [](size_t x) { + return x == 0; + }); bool isSameElementsCount = lhsBlockMemDesc->getPaddedElementsCount() == rhsBlockMemDesc->getPaddedElementsCount(); if (!isZeroDimsPaddings || !isSameElementsCount) return false; @@ -161,7 +173,8 @@ static inline bool isPhycicalMemCompatible(const MemoryDesc& lhsMemDesc, const M std::vector lhsStridesDefault(lhsBlockDims.size()); lhsStridesDefault[lhsBlockDims.size() - 1] = 1; for (size_t i = 2; i <= lhsBlockDims.size(); i++) { - lhsStridesDefault[lhsBlockDims.size() - i] = lhsStridesDefault[lhsBlockDims.size() - (i - 1)] * lhsBlockDims[lhsBlockDims.size() - (i - 1)]; + lhsStridesDefault[lhsBlockDims.size() - i] = + lhsStridesDefault[lhsBlockDims.size() - (i - 1)] * lhsBlockDims[lhsBlockDims.size() - (i - 1)]; } auto rhsBlockDims = rhsBlockMemDesc->getBlockDims(); @@ -169,11 +182,11 @@ static inline bool isPhycicalMemCompatible(const MemoryDesc& lhsMemDesc, const M rhsStridesDefault[rhsBlockDims.size() - 1] = 1; for (size_t i = 2; i <= rhsBlockDims.size(); i++) { rhsStridesDefault[rhsBlockDims.size() - i] = - rhsStridesDefault[rhsBlockDims.size() - (i - 1)] * rhsBlockDims[rhsBlockDims.size() - (i - 1)]; + rhsStridesDefault[rhsBlockDims.size() - (i - 1)] * rhsBlockDims[rhsBlockDims.size() - (i - 1)]; } - // this check needed to avoid inserting unnecessary reorders if the memory is used in place and the batch size is equal to 1 - // in nodes like concate and split + // this check is needed to avoid inserting unnecessary reorders if the memory is used in place and the batch size + // is equal to 1 in nodes like concat and split size_t lhsSkipAxis = lhsBlockDims.size() > 0 && lhsBlockDims[0] == 1 ? 0 : Shape::UNDEFINED_DIM; size_t rhsSkipAxis = rhsBlockDims.size() > 0 && rhsBlockDims[0] == 1 ?
0 : Shape::UNDEFINED_DIM; @@ -219,8 +232,10 @@ Edge::ReorderStatus Edge::needReorder() { // Check whether the child node may accept the parent produced tensor if (!outPortDesc->isCompatible(*inputPortDesc)) { - // Performance optimization which exploit the fact that some tensors do not need actual data reordering to be read using different descriptors - if (isPhycicalMemCompatible(*inputPortDesc->getMemDesc(), *outPortDesc->getMemDesc()) && !getParent()->isConstant()) { + // Performance optimization which exploits the fact that some tensors do not need actual data reordering to be + // read using different descriptors + if (isPhycicalMemCompatible(*inputPortDesc->getMemDesc(), *outPortDesc->getMemDesc()) && + !getParent()->isConstant()) { optimized = true; } else { return ReorderStatus::Regular; @@ -297,8 +312,8 @@ std::string Edge::hash() const { std::stringstream result; - return parentPtr->getName() + "_" + std::to_string(parent_port) + "_" + - childPtr->getName() + "_" + std::to_string(child_port); + return parentPtr->getName() + "_" + std::to_string(parent_port) + "_" + childPtr->getName() + "_" + + std::to_string(child_port); } void Edge::externalAllocate(WeightsSharing::Ptr weightsCache) { @@ -306,10 +321,13 @@ void Edge::externalAllocate(WeightsSharing::Ptr weightsCache) { return; if (weightsCache) { - auto alloc = [this] () { + auto alloc = [this]() { auto allocateFunc = [this](const MemoryDesc& inputDesc) -> MemoryPtr { auto parentPtr = getParent(); - return std::make_shared(parentPtr->getEngine(), inputDesc, nullptr, false); // no pads zeroing + return std::make_shared(parentPtr->getEngine(), + inputDesc, + nullptr, + false); // no pads zeroing }; allocateCommon(allocateFunc); @@ -424,7 +442,7 @@ const MemoryDesc& Edge::getDesc() const { return getInputDesc(); } -const IMemory &Edge::getMemory() { +const IMemory& Edge::getMemory() { auto memPtr = getMemoryPtr(); OPENVINO_ASSERT(memPtr != nullptr, " Dereferencing NULL memory in edge: ", *this); return *memPtr; @@ -434,7 +452,7 @@ MemoryPtr Edge::getMemoryPtr() const { return memoryPtr; } -void Edge::sharedMemFrom(const EdgePtr &edge) { +void Edge::sharedMemFrom(const EdgePtr& edge) { memoryFromEdge = edge; DEBUG_LOG(*this, " sharedMemFrom ", *edge); status = Status::NotAllocated; @@ -474,10 +492,8 @@ void Edge::init() { DEBUG_LOG(*this, " getBaseEdge() return itself"); changeStatus(Status::NeedAllocation); } else { - if (Type::Input == edgePtr->getParent()->getType() && - Type::MemoryInput != getParent()->getType() && - edgePtr->getParent()->isConstant() && - !edgePtr->getChild()->isConstant()) { + if (Type::Input == edgePtr->getParent()->getType() && Type::MemoryInput != getParent()->getType() && + edgePtr->getParent()->isConstant() && !edgePtr->getChild()->isConstant()) { changeStatus(Status::NeedAllocation); DEBUG_LOG(*this, " edge inplace from ", *edgePtr, " is broken!"); return; @@ -505,11 +521,11 @@ EdgePtr Edge::getBaseEdge(int look) { if ((childInPlacePort >= 0) && (look & LOOK_DOWN)) { auto ch_edges = getChild()->getChildEdgesAtPort(childInPlacePort); - auto &next_ch_edge = ch_edges[0]; + auto& next_ch_edge = ch_edges[0]; // Multiple connection to some out port // Will try to find inplace consumer - for (auto &ch_edge : ch_edges) { + for (auto& ch_edge : ch_edges) { if (ch_edge->getChild()->inPlaceInputPort(ch_edge->getOutputNum()) >= 0) { next_ch_edge = ch_edge; // To align with upstream-inplace, we stop searching once found the first inplace consumer @@ -525,14 +541,16 @@ EdgePtr Edge::getBaseEdge(int look) { for (auto
edge : edgesForSamePort) { if (edge.get() != this) { // Return once found the first inplace consumer - if (edge->inPlace()) return edge; + if (edge->inPlace()) + return edge; } } // Return the first output edge as the base if there is no inPlace consumers // thus benefits zero-copy of outputs. for (auto edge : edgesForSamePort) { - if (Type::Output == edge->getChild()->getType()) return edge; + if (Type::Output == edge->getChild()->getType()) + return edge; } return edgesForSamePort[0]; @@ -579,7 +597,7 @@ NodePtr Edge::modifiedInPlace() const { for (size_t i = 0; i < outConfs.size(); ++i) { const auto& conf = outConfs[i]; if (childPort < 0 || conf.inPlace() != childPort || - Type::MemoryInput == childNode->getType()) { //exception type, it doesn't modify memory + Type::MemoryInput == childNode->getType()) { // exception type, it doesn't modify memory continue; } if (childNode->isExecutable()) { @@ -599,12 +617,14 @@ NodePtr Edge::modifiedInPlace() const { return nullptr; } -std::ostream& operator<<(std::ostream &os, const Edge& edge) { - return os << "(" << edge.getParent()->getName() << ")" << "[" << edge.getInputNum() << "] " +std::ostream& operator<<(std::ostream& os, const Edge& edge) { + return os << "(" << edge.getParent()->getName() << ")" + << "[" << edge.getInputNum() << "] " << "<->" - << "(" << edge.getChild()->getName() << ")" << "[" << edge.getOutputNum() << "]" + << "(" << edge.getChild()->getName() << ")" + << "[" << edge.getOutputNum() << "]" << ":" << Edge::statusToString(edge.getStatus()); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/edge.h b/src/plugins/intel_cpu/src/edge.h index 5c418b2665924d..38f49ff00db075 100644 --- a/src/plugins/intel_cpu/src/edge.h +++ b/src/plugins/intel_cpu/src/edge.h @@ -4,15 +4,15 @@ #pragma once +#include +#include + #include "cpu_shape.h" #include "internal_properties.hpp" #include "memory_desc/cpu_memory_desc.h" #include "nodes/node_config.h" #include "weights_cache.hpp" -#include -#include - namespace ov { namespace intel_cpu { @@ -24,23 +24,11 @@ using EdgeWeakPtr = std::weak_ptr; class Edge { public: - Edge(const std::shared_ptr& parent, - const std::shared_ptr& child, - int pr_port = 0, int ch_port = 0); - - enum class Status { - Uninitialized, - NeedAllocation, - NotAllocated, - Allocated, - Validated - }; - - enum class ReorderStatus { - Regular = 0, - Optimized = 1, - No = 2 - }; + Edge(const std::shared_ptr& parent, const std::shared_ptr& child, int pr_port = 0, int ch_port = 0); + + enum class Status { Uninitialized, NeedAllocation, NotAllocated, Allocated, Validated }; + + enum class ReorderStatus { Regular = 0, Optimized = 1, No = 2 }; enum LOOK { LOOK_UP = 1, LOOK_DOWN = 2, LOOK_BOTH = LOOK_UP | LOOK_DOWN }; @@ -52,15 +40,15 @@ class Edge { #define CASE(_status) \ case Status::_status: \ return #_status; - switch (status) { - CASE(Uninitialized); - CASE(NeedAllocation); - CASE(NotAllocated); - CASE(Allocated); - CASE(Validated); - } + switch (status) { + CASE(Uninitialized); + CASE(NeedAllocation); + CASE(NotAllocated); + CASE(Allocated); + CASE(Validated); + } #undef CASE - return "Unexpected"; + return "Unexpected"; } void changeStatus(Status state); @@ -87,7 +75,9 @@ class Edge { int getInputNum() const; int getOutputNum() const; - void setChildPort(const size_t port) { child_port = port; } + void setChildPort(const size_t port) { + child_port = port; + } void sharedMemFrom(const EdgePtr& edge); EdgePtr getSharedEdge() const; @@ -126,8 
+116,7 @@ class Edge { friend class Graph; }; -std::ostream& operator<<(std::ostream &os, const Edge& edge); - -} // namespace intel_cpu -} // namespace ov +std::ostream& operator<<(std::ostream& os, const Edge& edge); +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp index 01af9dbde7fe01..c2c6ddf6f271fc 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.cpp @@ -4,9 +4,10 @@ #ifdef CPU_DEBUG_CAPS -#include "debug_capabilities.hpp" -#include -#include +# include "debug_capabilities.hpp" + +# include +# include namespace ov { namespace intel_cpu { @@ -14,25 +15,26 @@ namespace intel_cpu { using namespace Xbyak; using namespace dnnl::impl::cpu::x64; -template void RegPrinter::print(jit_generator &h, Xmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Xmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Ymm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Ymm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Zmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Zmm reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg64 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg64 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg32 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg32 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg16 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg16 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); -template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); +template void RegPrinter::print(jit_generator& h, Xmm reg, const char* name); +template void RegPrinter::print(jit_generator& h, Xmm reg, const char* name); +template void RegPrinter::print(jit_generator& h, Ymm reg, const char* name); +template void RegPrinter::print(jit_generator& h, Ymm reg, const char* name); +template void RegPrinter::print(jit_generator& h, Zmm reg, const char* name); +template void RegPrinter::print(jit_generator& h, Zmm reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg64 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg64 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg32 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg32 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg16 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg16 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg8 reg, const char* name); +template void RegPrinter::print(jit_generator& h, Reg8 reg, const char* name); template -void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) { +void RegPrinter::print_reg_prc(const char* name, const char* ori_name, T* ptr) { std::stringstream ss; - if (name) ss << name << " | "; + if (name) + ss << name << " | "; ss << ori_name << ": "; if (std::is_floating_point::value) { ss << *ptr; @@ -48,9 +50,10 @@ void RegPrinter::print_reg_prc(const char *name, const 
char *ori_name, T *ptr) { } template -void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) { +void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, PRC_T* ptr) { std::stringstream ss; - if (name) ss << name << " | "; + if (name) + ss << name << " | "; ss << ori_name << ": {" << ptr[0]; for (size_t i = 1; i < vlen / sizeof(float); i++) { ss << ", " << ptr[i]; @@ -58,15 +61,15 @@ void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *pt ss << "}" << std::endl; std::cout << ss.str(); } -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); -template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); +template void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, float* ptr); +template void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, float* ptr); +template void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, float* ptr); +template void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, int* ptr); +template void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, int* ptr); +template void RegPrinter::print_vmm_prc(const char* name, const char* ori_name, int* ptr); template -struct vmm_traits{}; +struct vmm_traits {}; template <> struct vmm_traits { @@ -87,7 +90,7 @@ struct vmm_traits { }; template -void RegPrinter::save_vmm(jit_generator &h) { +void RegPrinter::save_vmm(jit_generator& h) { h.sub(h.rsp, vmm_traits::vmm_len * vmm_traits::vmm_cnt); for (size_t i = 0; i < vmm_traits::vmm_cnt; i++) { h.uni_vmovups(h.ptr[h.rsp + i * vmm_traits::vmm_len], T(i)); @@ -95,52 +98,52 @@ void RegPrinter::save_vmm(jit_generator &h) { } template -void RegPrinter::restore_vmm(jit_generator &h) { +void RegPrinter::restore_vmm(jit_generator& h) { for (size_t i = 0; i < vmm_traits::vmm_cnt; i++) { h.uni_vmovups(T(i), h.ptr[h.rsp + i * vmm_traits::vmm_len]); } h.add(h.rsp, vmm_traits::vmm_len * vmm_traits::vmm_cnt); } -void RegPrinter::save_reg(jit_generator &h) { +void RegPrinter::save_reg(jit_generator& h) { h.sub(h.rsp, reg_len * reg_cnt); for (size_t i = 0; i < reg_cnt; i++) { h.mov(h.ptr[h.rsp + i * reg_len], Reg64(i)); } } -void RegPrinter::restore_reg(jit_generator &h) { +void RegPrinter::restore_reg(jit_generator& h) { for (size_t i = 0; i < reg_cnt; i++) { h.mov(Reg64(i), h.ptr[h.rsp + i * reg_len]); } h.add(h.rsp, reg_len * reg_cnt); } -void RegPrinter::preamble(jit_generator &h) { +void RegPrinter::preamble(jit_generator& h) { save_reg(h); - mayiuse(cpu_isa_t::avx512_core) ? save_vmm(h) : (mayiuse(cpu_isa_t::avx2) ? - save_vmm(h) : save_vmm(h)); + mayiuse(cpu_isa_t::avx512_core) ? save_vmm(h) + : (mayiuse(cpu_isa_t::avx2) ? save_vmm(h) : save_vmm(h)); } -void RegPrinter::postamble(jit_generator &h) { - mayiuse(cpu_isa_t::avx512_core) ? restore_vmm(h) : (mayiuse(cpu_isa_t::avx2) ? - restore_vmm(h) : restore_vmm(h)); +void RegPrinter::postamble(jit_generator& h) { + mayiuse(cpu_isa_t::avx512_core) ? restore_vmm(h) + : (mayiuse(cpu_isa_t::avx2) ? 
restore_vmm(h) : restore_vmm(h)); restore_reg(h); } // ABI requires 16-byte stack alignment before a call -void RegPrinter::align_rsp(jit_generator &h) { +void RegPrinter::align_rsp(jit_generator& h) { constexpr int alignment = 16; h.mov(h.r15, h.rsp); h.and_(h.rsp, ~(alignment - 1)); } -void RegPrinter::restore_rsp(jit_generator &h) { +void RegPrinter::restore_rsp(jit_generator& h) { h.mov(h.rsp, h.r15); } template -void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { +void RegPrinter::print_vmm(jit_generator& h, REG_T vmm, const char* name) { preamble(h); h.push(h.rax); @@ -181,7 +184,7 @@ void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { } template -void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { +void RegPrinter::print_reg(jit_generator& h, REG_T reg, const char* name) { preamble(h); h.push(h.rax); @@ -213,8 +216,7 @@ void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { postamble(h); } -} // namespace intel_cpu -} // namespace ov - +} // namespace intel_cpu +} // namespace ov -#endif // CPU_DEBUG_CAPS +#endif // CPU_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp index fd7135b17bf5b9..dcac847dfd1e0f 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/debug_capabilities.hpp @@ -6,7 +6,7 @@ #ifdef CPU_DEBUG_CAPS -#include "cpu/x64/jit_generator.hpp" +# include "cpu/x64/jit_generator.hpp" namespace ov { namespace intel_cpu { @@ -56,42 +56,44 @@ namespace intel_cpu { class RegPrinter { public: using jit_generator = dnnl::impl::cpu::x64::jit_generator; - template ::value, int>::type = 0> - static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { + template ::value, int>::type = 0> + static void print(jit_generator& h, REG_T reg, const char* name = nullptr) { print_vmm(h, reg, name); } - template ::value, int>::type = 0> - static void print(jit_generator &h, REG_T reg, const char *name = nullptr) { + template ::value, int>::type = 0> + static void print(jit_generator& h, REG_T reg, const char* name = nullptr) { print_reg(h, reg, name); } private: RegPrinter() {} template - static void print_vmm(jit_generator &h, REG_T vmm, const char *name); + static void print_vmm(jit_generator& h, REG_T vmm, const char* name); template - static void print_reg(jit_generator &h, REG_T reg, const char *name); + static void print_reg(jit_generator& h, REG_T reg, const char* name); template - static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr); + static void print_vmm_prc(const char* name, const char* ori_name, PRC_T* ptr); template - static void print_reg_prc(const char *name, const char *ori_name, T *val); - static void preamble(jit_generator &h); - static void postamble(jit_generator &h); + static void print_reg_prc(const char* name, const char* ori_name, T* val); + static void preamble(jit_generator& h); + static void postamble(jit_generator& h); template - static void save_vmm(jit_generator &h); + static void save_vmm(jit_generator& h); template - static void restore_vmm(jit_generator &h); - static void save_reg(jit_generator &h); - static void
restore_reg(jit_generator& h); + static void align_rsp(jit_generator& h); + static void restore_rsp(jit_generator& h); static constexpr size_t reg_len = 8; static constexpr size_t reg_cnt = 16; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov -#endif // CPU_DEBUG_CAPS +#endif // CPU_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp index 43a2c2eb6b045f..2bfbaa68880aa8 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp @@ -11,14 +11,18 @@ namespace intel_cpu { class jit_uni_vcvtneps2bf16 : public jit_emitter { public: - jit_uni_vcvtneps2bf16(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::bf16) : jit_emitter(host, host_isa, exec_prc) { + jit_uni_vcvtneps2bf16(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::bf16) + : jit_emitter(host, host_isa, exec_prc) { if (!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni_2)) prepare_table(); } - size_t get_inputs_num() const override { return 1; } + size_t get_inputs_num() const override { + return 1; + } private: void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override { @@ -36,7 +40,8 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { template void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { using namespace Xbyak; - using Vmm = typename dnnl::impl::utils::conditional3::type; + using Vmm = typename dnnl::impl::utils:: + conditional3::type; Vmm in = Vmm(in_vec_idxs[0]); @@ -79,7 +84,7 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { h->uni_vpackusdw(aux, aux, aux); if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) { - h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); //11 01 10 00 + h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); // 11 01 10 00 h->vextracti128(out, Ymm(aux.getIdx()), 0); } else { h->uni_vmovups(out, aux); @@ -123,5 +128,5 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.cpp index 544960008c9158..2e90af39fb9cf1 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.cpp @@ -6,7 +6,6 @@ #include "utils/bfloat16.hpp" - using namespace dnnl::impl::utils; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; @@ -15,19 +14,23 @@ using namespace Xbyak; namespace ov { namespace intel_cpu { -jit_convert_emitter::jit_convert_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) { +jit_convert_emitter::jit_convert_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { input_type = node->get_input_element_type(0); output_type = node->get_output_element_type(0); if (output_type == 
ov::element::bf16) - uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(host, host_isa)); + uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(host, host_isa)); } void jit_convert_emitter::validate_types() const { auto is_supported_type = [this](const ov::element::Type& type) { - return any_of(supported_types.begin(), supported_types.end(), - [&type](const ov::element::Type& supported_type) { return supported_type == type; } ); + return any_of(supported_types.begin(), supported_types.end(), [&type](const ov::element::Type& supported_type) { + return supported_type == type; + }); }; if (!is_supported_type(input_type)) @@ -36,7 +39,9 @@ void jit_convert_emitter::validate_types() const { OV_CPU_JIT_EMITTER_THROW("Unsupported output type: ", output_type.get_type_name()); } -size_t jit_convert_emitter::get_inputs_num() const { return 1; } +size_t jit_convert_emitter::get_inputs_num() const { + return 1; +} void jit_convert_emitter::emit_data() const { jit_emitter::emit_data(); @@ -45,19 +50,22 @@ void jit_convert_emitter::emit_data() const { } template -void jit_convert_emitter::float2bfloat(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_convert_emitter::float2bfloat(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); - Vmm vmm_dst = Vmm(out_vec_idxs[0]); + Vmm vmm_dst = Vmm(out_vec_idxs[0]); if (!uni_vcvtneps2bf16) OV_CPU_JIT_EMITTER_THROW("Converter from float to bf16 isn't initialized!"); uni_vcvtneps2bf16->emit_code({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); } -jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa, - const std::shared_ptr& node, ov::element::Type exec_prc) - : jit_convert_emitter(host, host_isa, node, exec_prc) { +jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, node, exec_prc) { prepare_table(); } @@ -66,7 +74,8 @@ bool jit_convert_truncation_emitter::is_i8_and_u8_case() const { one_of(output_type, ov::element::i8, ov::element::u8); } -void jit_convert_truncation_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_convert_truncation_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { validate_types(); if (host_isa_ == cpu::x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); @@ -80,10 +89,11 @@ void jit_convert_truncation_emitter::emit_impl(const std::vector &in_vec } template -void jit_convert_truncation_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_convert_truncation_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); - Vmm vmm_dst = Vmm(out_vec_idxs[0]); + Vmm vmm_dst = Vmm(out_vec_idxs[0]); Xmm xmm_dst = Xmm(out_vec_idxs[0]); Ymm ymm_dst = Ymm(out_vec_idxs[0]); @@ -97,95 +107,95 @@ void jit_convert_truncation_emitter::emit_isa(const std::vector &in_vec_ } switch (input_type) { - case ov::element::f32: - if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) - h->uni_vcvttps2dq(vmm_dst, vmm_src); - break; - case ov::element::i32: - if (one_of(output_type, ov::element::f32, ov::element::bf16, ov::element::f16)) - 
h->uni_vcvtdq2ps(vmm_dst, vmm_src); - break; - case ov::element::bf16: - h->vpmovzxwd(vmm_dst, vmm_src); - h->uni_vpslld(vmm_dst, vmm_dst, 16); - if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) - h->uni_vcvttps2dq(vmm_dst, vmm_dst); - break; - case ov::element::f16: - if (isa == dnnl::impl::cpu::x64::avx512_core) - h->vcvtph2ps(vmm_dst, Ymm(vmm_src.getIdx())); - else - h->vcvtph2ps(vmm_dst, - Xmm(vmm_src.getIdx())); // for avx2_vnni_2? - if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) - h->uni_vcvttps2dq(vmm_dst, vmm_dst); - break; - case ov::element::i8: - h->uni_vpmovsxbd(vmm_dst, vmm_src); - break; - case ov::element::u8: - h->uni_vpmovzxbd(vmm_dst, vmm_src); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input data type"); + case ov::element::f32: + if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) + h->uni_vcvttps2dq(vmm_dst, vmm_src); + break; + case ov::element::i32: + if (one_of(output_type, ov::element::f32, ov::element::bf16, ov::element::f16)) + h->uni_vcvtdq2ps(vmm_dst, vmm_src); + break; + case ov::element::bf16: + h->vpmovzxwd(vmm_dst, vmm_src); + h->uni_vpslld(vmm_dst, vmm_dst, 16); + if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) + h->uni_vcvttps2dq(vmm_dst, vmm_dst); + break; + case ov::element::f16: + if (isa == dnnl::impl::cpu::x64::avx512_core) + h->vcvtph2ps(vmm_dst, Ymm(vmm_src.getIdx())); + else + h->vcvtph2ps(vmm_dst, + Xmm(vmm_src.getIdx())); // for avx2_vnni_2? + if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) + h->uni_vcvttps2dq(vmm_dst, vmm_dst); + break; + case ov::element::i8: + h->uni_vpmovsxbd(vmm_dst, vmm_src); + break; + case ov::element::u8: + h->uni_vpmovzxbd(vmm_dst, vmm_src); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input data type"); } switch (output_type) { - case ov::element::f32: - if (!one_of(input_type, ov::element::i32, ov::element::bf16, ov::element::f16)) { + case ov::element::f32: + if (!one_of(input_type, ov::element::i32, ov::element::bf16, ov::element::f16)) { + h->uni_vcvtdq2ps(vmm_dst, vmm_dst); + } + break; + case ov::element::i32: + break; + case ov::element::bf16: + if (input_type == ov::element::f32) { + float2bfloat({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); + } else { + if (one_of(input_type, ov::element::i8, ov::element::u8)) { h->uni_vcvtdq2ps(vmm_dst, vmm_dst); } - break; - case ov::element::i32: - break; - case ov::element::bf16: - if (input_type == ov::element::f32) { - float2bfloat({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); - } else { - if (one_of(input_type, ov::element::i8, ov::element::u8)) { - h->uni_vcvtdq2ps(vmm_dst, vmm_dst); - } - float2bfloat({static_cast(vmm_dst.getIdx())}, {static_cast(vmm_dst.getIdx())}); - } - break; - case ov::element::f16: - if (input_type == ov::element::f32) { - if (isa == dnnl::impl::cpu::x64::avx512_core) - h->vcvtps2ph(ymm_dst, vmm_src, 0x4); - else - h->vcvtps2ph(xmm_dst, vmm_src, 0x4); - } else { - if (one_of(input_type, ov::element::i8, ov::element::u8)) { - h->uni_vcvtdq2ps(vmm_dst, vmm_dst); - } - if (isa == dnnl::impl::cpu::x64::avx512_core) - h->vcvtps2ph(ymm_dst, vmm_dst, 0x4); - else - h->vcvtps2ph(xmm_dst, vmm_dst, 0x4); - } - break; - case ov::element::i8: - case ov::element::u8: - if (input_type == ov::element::i32) { - dword2int8({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); - } else { - dword2int8({static_cast(vmm_dst.getIdx())}, 
{static_cast(vmm_dst.getIdx())}); + float2bfloat({static_cast(vmm_dst.getIdx())}, {static_cast(vmm_dst.getIdx())}); + } + break; + case ov::element::f16: + if (input_type == ov::element::f32) { + if (isa == dnnl::impl::cpu::x64::avx512_core) + h->vcvtps2ph(ymm_dst, vmm_src, 0x4); + else + h->vcvtps2ph(xmm_dst, vmm_src, 0x4); + } else { + if (one_of(input_type, ov::element::i8, ov::element::u8)) { + h->uni_vcvtdq2ps(vmm_dst, vmm_dst); } - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output data type"); + if (isa == dnnl::impl::cpu::x64::avx512_core) + h->vcvtps2ph(ymm_dst, vmm_dst, 0x4); + else + h->vcvtps2ph(xmm_dst, vmm_dst, 0x4); + } + break; + case ov::element::i8: + case ov::element::u8: + if (input_type == ov::element::i32) { + dword2int8({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); + } else { + dword2int8({static_cast(vmm_dst.getIdx())}, {static_cast(vmm_dst.getIdx())}); + } + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output data type"); } } void jit_convert_truncation_emitter::register_table_entries() { - if (host_isa_ == dnnl::impl::cpu::x64::avx2 && - one_of(output_type, ov::element::i8, ov::element::u8) && + if (host_isa_ == dnnl::impl::cpu::x64::avx2 && one_of(output_type, ov::element::i8, ov::element::u8) && !is_i8_and_u8_case()) push_arg_entry_of("mask_byte", 0x000000ff, true); } template -void jit_convert_truncation_emitter::dword2int8(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_convert_truncation_emitter::dword2int8(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); @@ -204,12 +214,14 @@ void jit_convert_truncation_emitter::dword2int8(const std::vector &in_ve } } -jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator *host, cpu_isa_t host_isa, - const std::shared_ptr& node, ov::element::Type exec_prc) - : jit_convert_emitter(host, host_isa, node, exec_prc) { -} +jit_convert_saturation_emitter::jit_convert_saturation_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_convert_emitter(host, host_isa, node, exec_prc) {} -void jit_convert_saturation_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_convert_saturation_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { validate_types(); if (host_isa_ == cpu::x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); @@ -223,10 +235,11 @@ void jit_convert_saturation_emitter::emit_impl(const std::vector &in_vec } template -void jit_convert_saturation_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_convert_saturation_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); - Vmm vmm_dst = Vmm(out_vec_idxs[0]); + Vmm vmm_dst = Vmm(out_vec_idxs[0]); Xmm xmm_dst = Xmm(out_vec_idxs[0]); Ymm ymm_dst = Ymm(out_vec_idxs[0]); @@ -237,88 +250,94 @@ void jit_convert_saturation_emitter::emit_isa(const std::vector &in_vec_ } switch (input_type) { - case ov::element::f32: - if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) - h->uni_vcvtps2dq(vmm_dst, vmm_src); - break; - case ov::element::i32: - if (one_of(output_type, ov::element::f32, ov::element::bf16, ov::element::f16)) - 
h->uni_vcvtdq2ps(vmm_dst, vmm_src); - break; - case ov::element::bf16: - h->vpmovzxwd(vmm_dst, vmm_src); - h->uni_vpslld(vmm_dst, vmm_dst, 16); - if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) - h->uni_vcvttps2dq(vmm_dst, vmm_dst); - break; - case ov::element::f16: - if (isa == dnnl::impl::cpu::x64::avx512_core) - h->vcvtph2ps(vmm_dst, Ymm(vmm_src.getIdx())); - else - h->vcvtph2ps(vmm_dst, - Xmm(vmm_src.getIdx())); // for avx2_vnni_2? - if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) - h->uni_vcvttps2dq(vmm_dst, vmm_dst); - break; - case ov::element::i8: - h->uni_vpmovsxbd(vmm_dst, vmm_src); - break; - case ov::element::u8: - h->uni_vpmovzxbd(vmm_dst, vmm_src); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported input data type"); + case ov::element::f32: + if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) + h->uni_vcvtps2dq(vmm_dst, vmm_src); + break; + case ov::element::i32: + if (one_of(output_type, ov::element::f32, ov::element::bf16, ov::element::f16)) + h->uni_vcvtdq2ps(vmm_dst, vmm_src); + break; + case ov::element::bf16: + h->vpmovzxwd(vmm_dst, vmm_src); + h->uni_vpslld(vmm_dst, vmm_dst, 16); + if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) + h->uni_vcvttps2dq(vmm_dst, vmm_dst); + break; + case ov::element::f16: + if (isa == dnnl::impl::cpu::x64::avx512_core) + h->vcvtph2ps(vmm_dst, Ymm(vmm_src.getIdx())); + else + h->vcvtph2ps(vmm_dst, + Xmm(vmm_src.getIdx())); // for avx2_vnni_2? + if (one_of(output_type, ov::element::i32, ov::element::i8, ov::element::u8)) + h->uni_vcvttps2dq(vmm_dst, vmm_dst); + break; + case ov::element::i8: + h->uni_vpmovsxbd(vmm_dst, vmm_src); + break; + case ov::element::u8: + h->uni_vpmovzxbd(vmm_dst, vmm_src); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported input data type"); } switch (output_type) { - case ov::element::f32: - if (!one_of(input_type, ov::element::i32, ov::element::bf16, ov::element::f16)) { + case ov::element::f32: + if (!one_of(input_type, ov::element::i32, ov::element::bf16, ov::element::f16)) { + h->uni_vcvtdq2ps(vmm_dst, vmm_dst); + } + break; + case ov::element::i32: + break; + case ov::element::bf16: + if (input_type == ov::element::f32) { + float2bfloat({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); + } else { + if (one_of(input_type, ov::element::i8, ov::element::u8)) { h->uni_vcvtdq2ps(vmm_dst, vmm_dst); } - break; - case ov::element::i32: - break; - case ov::element::bf16: - if (input_type == ov::element::f32) { - float2bfloat({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}); - } else { - if (one_of(input_type, ov::element::i8, ov::element::u8)) { - h->uni_vcvtdq2ps(vmm_dst, vmm_dst); - } - float2bfloat({static_cast(vmm_dst.getIdx())}, {static_cast(vmm_dst.getIdx())}); - } - break; - case ov::element::f16: - if (input_type == ov::element::f32) { - if (isa == dnnl::impl::cpu::x64::avx512_core) - h->vcvtps2ph(ymm_dst, vmm_src, 0x4); - else - h->vcvtps2ph(xmm_dst, vmm_src, 0x4); - } else { - if (one_of(input_type, ov::element::i8, ov::element::u8)) { - h->uni_vcvtdq2ps(vmm_dst, vmm_dst); - } - if (isa == dnnl::impl::cpu::x64::avx512_core) - h->vcvtps2ph(ymm_dst, vmm_dst, 0x4); - else - h->vcvtps2ph(xmm_dst, vmm_dst, 0x4); - } - break; - case ov::element::i8: - case ov::element::u8: - if (input_type == ov::element::i32) { - dword2int8({static_cast(vmm_src.getIdx())}, {static_cast(vmm_dst.getIdx())}, output_type.is_signed()); - } else { - 
dword2int8({static_cast(vmm_dst.getIdx())}, {static_cast(vmm_dst.getIdx())}, output_type.is_signed()); + float2bfloat({static_cast(vmm_dst.getIdx())}, {static_cast(vmm_dst.getIdx())}); + } + break; + case ov::element::f16: + if (input_type == ov::element::f32) { + if (isa == dnnl::impl::cpu::x64::avx512_core) + h->vcvtps2ph(ymm_dst, vmm_src, 0x4); + else + h->vcvtps2ph(xmm_dst, vmm_src, 0x4); + } else { + if (one_of(input_type, ov::element::i8, ov::element::u8)) { + h->uni_vcvtdq2ps(vmm_dst, vmm_dst); } - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported output data type"); + if (isa == dnnl::impl::cpu::x64::avx512_core) + h->vcvtps2ph(ymm_dst, vmm_dst, 0x4); + else + h->vcvtps2ph(xmm_dst, vmm_dst, 0x4); + } + break; + case ov::element::i8: + case ov::element::u8: + if (input_type == ov::element::i32) { + dword2int8({static_cast(vmm_src.getIdx())}, + {static_cast(vmm_dst.getIdx())}, + output_type.is_signed()); + } else { + dword2int8({static_cast(vmm_dst.getIdx())}, + {static_cast(vmm_dst.getIdx())}, + output_type.is_signed()); + } + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported output data type"); } } template -void jit_convert_saturation_emitter::dword2int8(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs, bool is_signed) const { +void jit_convert_saturation_emitter::dword2int8(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs, + bool is_signed) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); @@ -330,7 +349,7 @@ void jit_convert_saturation_emitter::dword2int8(const std::vector &in_ve if (is_signed) { h->vpmovsdb(xmm_dst, vmm_src); } else { - Vmm vmm_zero = Vmm(aux_vec_idxs[0]); + Vmm vmm_zero = Vmm(aux_vec_idxs[0]); h->vpxord(vmm_zero, vmm_zero, vmm_zero); h->vpmaxsd(vmm_dst, vmm_src, vmm_zero); h->vpmovusdb(xmm_dst, vmm_dst); @@ -353,8 +372,8 @@ void jit_convert_saturation_emitter::dword2int8(const std::vector &in_ve size_t jit_convert_saturation_emitter::aux_vecs_count() const { // 1 register is for dword2int8 unsigned - return output_type == ov::element::u8 && host_isa_ == dnnl::impl::cpu::x64::avx512_core? 1 : 0; + return output_type == ov::element::u8 && host_isa_ == dnnl::impl::cpu::x64::avx512_core ? 
1 : 0; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.hpp index ee451ed358dd1a..29b85079573bee 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_conversion_emitters.hpp @@ -4,16 +4,18 @@ #pragma once -#include "jit_emitter.hpp" #include "jit_bf16_emitters.hpp" +#include "jit_emitter.hpp" namespace ov { namespace intel_cpu { class jit_convert_emitter : public jit_emitter { public: - jit_convert_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; @@ -22,19 +24,13 @@ class jit_convert_emitter : public jit_emitter { void validate_types() const; template - void float2bfloat(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void float2bfloat(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; ov::element::Type input_type; ov::element::Type output_type; - const ov::element::TypeVector supported_types = { - ov::element::f32, - ov::element::i32, - ov::element::bf16, - ov::element::f16, - ov::element::i8, - ov::element::u8 - }; + const ov::element::TypeVector supported_types = + {ov::element::f32, ov::element::i32, ov::element::bf16, ov::element::f16, ov::element::i8, ov::element::u8}; std::shared_ptr uni_vcvtneps2bf16 = nullptr; }; @@ -45,16 +41,18 @@ class jit_convert_emitter : public jit_emitter { // 129 -> -127 class jit_convert_truncation_emitter : public jit_convert_emitter { public: - jit_convert_truncation_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_truncation_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); private: void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; template - void dword2int8(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void dword2int8(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; bool is_i8_and_u8_case() const; void register_table_entries() override; @@ -66,19 +64,23 @@ class jit_convert_truncation_emitter : public jit_convert_emitter { // 129 -> 127 class jit_convert_saturation_emitter : public jit_convert_emitter { public: - jit_convert_saturation_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); + jit_convert_saturation_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); private: void emit_impl(const std::vector& in, const std::vector& out) 
const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; template - void dword2int8(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs, bool is_signed) const; + void dword2int8(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs, + bool is_signed) const; size_t aux_vecs_count() const override; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.cpp index 0b315cdd309715..51e801208b927c 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_dnnl_emitters.hpp" + #include using namespace dnnl::impl::utils; @@ -17,9 +18,11 @@ std::set> jit_dnnl_emitter::get_supported_precisions( return {{element::f32}}; } -jit_dnnl_emitter::jit_dnnl_emitter(jit_generator *host, cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) +jit_dnnl_emitter::jit_dnnl_emitter(jit_generator* host, + cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { - kind = dnnl_eltwise_tanh; alpha = 0.f; beta = 0.f; @@ -27,33 +30,42 @@ jit_dnnl_emitter::jit_dnnl_emitter(jit_generator *host, cpu_isa_t host_isa, cons set_injector(); } -jit_dnnl_emitter::jit_dnnl_emitter(jit_generator *host, cpu_isa_t host_isa, - dnnl_alg_kind_t algKind, float alpha, float beta, +jit_dnnl_emitter::jit_dnnl_emitter(jit_generator* host, + cpu_isa_t host_isa, + dnnl_alg_kind_t algKind, + float alpha, + float beta, ov::element::Type exec_prc) - : jit_emitter(host, host_isa, exec_prc), kind(algKind), alpha(alpha), beta(beta) { - + : jit_emitter(host, host_isa, exec_prc), + kind(algKind), + alpha(alpha), + beta(beta) { set_injector(); } void jit_dnnl_emitter::set_injector() { if (host_isa_ == cpu::x64::sse41) { - eltwise_injector_sse42 = std::make_shared>( - h, kind, alpha, beta, 1.f); + eltwise_injector_sse42 = + std::make_shared>(h, kind, alpha, beta, 1.f); } else if (host_isa_ == cpu::x64::avx2) { - eltwise_injector_avx2 = std::make_shared>( - h, kind, alpha, beta, 1.f); + eltwise_injector_avx2 = + std::make_shared>(h, kind, alpha, beta, 1.f); } else if (host_isa_ == cpu::x64::avx512_core) { - eltwise_injector_avx512_core = std::make_shared>( - h, kind, alpha, beta, 1.f); + eltwise_injector_avx512_core = + std::make_shared>(h, kind, alpha, beta, 1.f); } else { OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_); } } -size_t jit_dnnl_emitter::get_inputs_num() const { return 1; } +size_t jit_dnnl_emitter::get_inputs_num() const { + return 1; +} -void jit_dnnl_emitter::emit_code(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_dnnl_emitter::emit_code(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { if (host_isa_ == cpu::x64::sse41) { if (out_vec_idxs[0] != in_vec_idxs[0]) h->uni_vmovups(Xmm(out_vec_idxs[0]), Xmm(in_vec_idxs[0])); @@ -83,11 +95,13 @@ void jit_dnnl_emitter::emit_data() const { } } -jit_dnnl_aux_emitter::jit_dnnl_aux_emitter(jit_generator *host, cpu_isa_t host_isa, - 
dnnl_alg_kind_t algKind, float inpAlpha, float inpBeta, +jit_dnnl_aux_emitter::jit_dnnl_aux_emitter(jit_generator* host, + cpu_isa_t host_isa, + dnnl_alg_kind_t algKind, + float inpAlpha, + float inpBeta, ov::element::Type exec_prc) - : jit_dnnl_emitter(host, host_isa, algKind, inpAlpha, inpBeta, exec_prc) { -} + : jit_dnnl_emitter(host, host_isa, algKind, inpAlpha, inpBeta, exec_prc) {} -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.hpp index bdf04108370ed5..22e003ad261555 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_emitters.hpp @@ -4,8 +4,8 @@ #pragma once -#include "cpu/x64/jit_generator.hpp" #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/jit_generator.hpp" #include "jit_emitter.hpp" namespace ov { @@ -13,30 +13,41 @@ namespace intel_cpu { class jit_dnnl_emitter : public jit_emitter { public: - void emit_code(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const override; + void emit_code(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const override; void emit_data() const override; - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override {}; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override{}; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); protected: - jit_dnnl_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - dnnl_alg_kind_t algKind, float inpAlpha, float inpBeta, - ov::element::Type exec_prc = ov::element::f32); - jit_dnnl_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_dnnl_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + dnnl_alg_kind_t algKind, + float inpAlpha, + float inpBeta, + ov::element::Type exec_prc = ov::element::f32); + jit_dnnl_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); void set_injector(); - dnnl_alg_kind_t kind {dnnl_alg_kind_undef}; - float alpha {0.f}; - float beta {0.f}; + dnnl_alg_kind_t kind{dnnl_alg_kind_undef}; + float alpha{0.f}; + float beta{0.f}; - std::shared_ptr> eltwise_injector_sse42; - std::shared_ptr> eltwise_injector_avx2; - std::shared_ptr> eltwise_injector_avx512_core; + std::shared_ptr> + eltwise_injector_sse42; + std::shared_ptr> + eltwise_injector_avx2; + std::shared_ptr> + eltwise_injector_avx512_core; private: size_t get_inputs_num() const override; @@ -44,12 +55,15 @@ class jit_dnnl_emitter : public jit_emitter { class jit_dnnl_aux_emitter : public jit_dnnl_emitter { public: - jit_dnnl_aux_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - dnnl_alg_kind_t algKind, float inpAlpha, float inpBeta, - ov::element::Type exec_prc = ov::element::f32); + 
jit_dnnl_aux_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + dnnl_alg_kind_t algKind, + float inpAlpha, + float inpBeta, + ov::element::Type exec_prc = ov::element::f32); private: }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp index 7a4d1e31277e3b..0b7396b6fcd830 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp @@ -4,9 +4,9 @@ #pragma once +#include "jit_dnnl_emitters.hpp" #include "openvino/opsets/opset5.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" -#include "jit_dnnl_emitters.hpp" #include "utils/ngraph_utils.hpp" namespace ov { @@ -14,88 +14,102 @@ namespace intel_cpu { class jit_relu_emitter : public jit_dnnl_emitter { public: - jit_relu_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) + jit_relu_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { - kind = dnnl_eltwise_relu; - alpha = 0.f; - beta = 0.f; + kind = dnnl_eltwise_relu; + alpha = 0.f; + beta = 0.f; - set_injector(); - } + set_injector(); + } }; class jit_sigmoid_emitter : public jit_dnnl_emitter { public: - jit_sigmoid_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) + jit_sigmoid_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { - kind = dnnl_eltwise_logistic; - alpha = 0.f; - beta = 0.f; + kind = dnnl_eltwise_logistic; + alpha = 0.f; + beta = 0.f; - set_injector(); - } + set_injector(); + } }; class jit_tanh_emitter : public jit_dnnl_emitter { public: - jit_tanh_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) + jit_tanh_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { - kind = dnnl_eltwise_tanh; - alpha = 0.f; - beta = 0.f; + kind = dnnl_eltwise_tanh; + alpha = 0.f; + beta = 0.f; - set_injector(); - } + set_injector(); + } }; class jit_elu_emitter : public jit_dnnl_emitter { public: - jit_elu_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) + jit_elu_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { - kind = dnnl_eltwise_elu; - alpha = ov::as_type_ptr(n)->get_alpha(); - beta = 0.f; + kind = dnnl_eltwise_elu; + alpha = ov::as_type_ptr(n)->get_alpha(); + beta = 0.f; - set_injector(); - } + 
set_injector(); + } }; class jit_abs_emitter : public jit_dnnl_emitter { public: - jit_abs_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) + jit_abs_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { - kind = dnnl_eltwise_abs; - alpha = 0.f; - beta = 0.f; + kind = dnnl_eltwise_abs; + alpha = 0.f; + beta = 0.f; - set_injector(); - } + set_injector(); + } }; class jit_clamp_emitter : public jit_dnnl_emitter { public: - jit_clamp_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) + jit_clamp_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { - kind = dnnl_eltwise_clip; - auto op = ov::as_type_ptr(n); - alpha = op->get_min(); - beta = op->get_max(); + kind = dnnl_eltwise_clip; + auto op = ov::as_type_ptr(n); + alpha = op->get_min(); + beta = op->get_max(); - set_injector(); - } + set_injector(); + } }; class jit_swish_emitter : public jit_dnnl_emitter { public: - jit_swish_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) - : jit_dnnl_emitter(host, host_isa, n, exec_prc) { + jit_swish_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) + : jit_dnnl_emitter(host, host_isa, n, exec_prc) { kind = dnnl_eltwise_swish; auto op = ov::as_type_ptr(n); alpha = op->get_alpha(); @@ -107,9 +121,11 @@ class jit_swish_emitter : public jit_dnnl_emitter { class jit_hswish_emitter : public jit_dnnl_emitter { public: - jit_hswish_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) - : jit_dnnl_emitter(host, host_isa, n, exec_prc) { + jit_hswish_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) + : jit_dnnl_emitter(host, host_isa, n, exec_prc) { // since v3.0 oneDNN has flexible version of hardswish, ov still uses the one with hardcoded alpha and beta kind = dnnl_eltwise_hardswish; alpha = 1.f / 6.f; @@ -121,9 +137,11 @@ class jit_hswish_emitter : public jit_dnnl_emitter { class jit_gelu_v0_emitter : public jit_dnnl_emitter { public: - jit_gelu_v0_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_gelu_v0_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32) - : jit_dnnl_emitter(host, host_isa, n, exec_prc) { + : jit_dnnl_emitter(host, host_isa, n, exec_prc) { kind = dnnl_eltwise_gelu_erf; set_injector(); @@ -132,9 +150,11 @@ class jit_gelu_v0_emitter : public jit_dnnl_emitter { class jit_gelu_v7_emitter : public jit_dnnl_emitter { public: - 
jit_gelu_v7_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_gelu_v7_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32) - : jit_dnnl_emitter(host, host_isa, n, exec_prc) { + : jit_dnnl_emitter(host, host_isa, n, exec_prc) { auto gelu = getNgraphOpAs(n); ov::op::GeluApproximationMode approximationMode = gelu->get_approximation_mode(); if (approximationMode == ov::op::GeluApproximationMode::ERF) @@ -152,11 +172,11 @@ class jit_gelu_v7_emitter : public jit_dnnl_emitter { class jit_round_emitter : public jit_dnnl_emitter { public: - jit_round_emitter( - dnnl::impl::cpu::x64::jit_generator *host, - dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32) : jit_dnnl_emitter(host, host_isa, n, exec_prc) { + jit_round_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32) + : jit_dnnl_emitter(host, host_isa, n, exec_prc) { const auto round = getNgraphOpAs(n); const auto mode = round->get_mode(); if ((mode != ov::opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO) && @@ -165,12 +185,11 @@ class jit_round_emitter : public jit_dnnl_emitter { static_cast(mode)); } - kind = mode == ov::opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO ? - dnnl_eltwise_round_half_away_from_zero : - dnnl_eltwise_round_half_to_even; + kind = mode == ov::opset5::Round::RoundMode::HALF_AWAY_FROM_ZERO ? dnnl_eltwise_round_half_away_from_zero + : dnnl_eltwise_round_half_to_even; set_injector(); } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp index 0331a3ee4908b9..7a091fc946c2d8 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp @@ -8,8 +8,8 @@ using namespace dnnl::impl::utils; using namespace dnnl::impl::cpu; using namespace Xbyak; -#define CONST_1_F 0x3f800000 // 1.f -#define INF_MASK 0x7F800000 +#define CONST_1_F 0x3f800000 // 1.f +#define INF_MASK 0x7F800000 #define INF_NEG_MASK 0xFF800000 namespace ov { @@ -22,23 +22,30 @@ ov::element::Type get_arithmetic_binary_exec_precision(const std::shared_ptr& node) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_add_emitter::jit_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_add_emitter::jit_add_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_add_emitter::jit_add_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_add_emitter::get_inputs_num() const { return 2; } +size_t jit_add_emitter::get_inputs_num() const { + return 2; +} -void jit_add_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_add_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { 
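        // [editor's note, not part of the patch] This dispatch pattern repeats
        // in every emitter below: host_isa_ picks the instantiation of the
        // templated emit_isa<isa>, whose conditional3 alias maps Vmm to Xmm on
        // SSE4.1, Ymm on AVX2 and Zmm on AVX-512, so one body serves all ISAs.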
emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -51,7 +58,7 @@ void jit_add_emitter::emit_impl(const std::vector &in_vec_idxs, const st } template -void jit_add_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_add_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -59,9 +66,14 @@ void jit_add_emitter::emit_isa(const std::vector &in_vec_idxs, const std auto uni_vadd = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: h->uni_vaddps(vmm_dst, vmm_src0, vmm_src1); break; - case ov::element::i32: h->uni_vpaddd(vmm_dst, vmm_src0, vmm_src1); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: + h->uni_vaddps(vmm_dst, vmm_src0, vmm_src1); + break; + case ov::element::i32: + h->uni_vpaddd(vmm_dst, vmm_src0, vmm_src1); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -78,14 +90,19 @@ std::set> jit_add_emitter::get_supported_precisions(c } /// MUL_ADD /// -jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_mul_add_emitter::jit_mul_add_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_mul_add_emitter::get_inputs_num() const { return 3; } +size_t jit_mul_add_emitter::get_inputs_num() const { + return 3; +} -void jit_mul_add_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mul_add_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -98,7 +115,8 @@ void jit_mul_add_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_mul_add_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mul_add_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -109,47 +127,49 @@ void jit_mul_add_emitter::emit_isa(const std::vector &in_vec_idxs, const auto uni_vfmadd231_xmm = [this](Xmm vmm_dst, Xmm vmm_src0, Xmm vmm_src1, Xmm vmm_src2) { h->uni_vmovups(vmm_dst, vmm_src0); switch (exec_prc_) { - case ov::element::f32: { - h->uni_vmulps(vmm_dst, vmm_dst, vmm_src1); - h->uni_vaddps(vmm_dst, vmm_dst, vmm_src2); - } break; - case ov::element::i32: { - h->uni_vpmulld(vmm_dst, vmm_dst, vmm_src1); - h->uni_vpaddd(vmm_dst, vmm_dst, vmm_src2); - } break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: { + h->uni_vmulps(vmm_dst, vmm_dst, vmm_src1); + h->uni_vaddps(vmm_dst, vmm_dst, vmm_src2); + } break; + case 
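// [editor's note, not part of the patch] x86 SIMD has no integer fused
// multiply-add, so the i32 branch below multiplies with uni_vpmulld and then
// adds with uni_vpaddd instead of issuing a single FMA as the f32 branch does.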
ov::element::i32: { + h->uni_vpmulld(vmm_dst, vmm_dst, vmm_src1); + h->uni_vpaddd(vmm_dst, vmm_dst, vmm_src2); + } break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; auto uni_vfmadd231_vmm = [this, vmm_aux0](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1, Vmm vmm_src2) { switch (exec_prc_) { - case ov::element::f32: { - Vmm vmm_mul0; - if (vmm_dst.getIdx() == vmm_src0.getIdx()) { - h->uni_vmovups(vmm_aux0, vmm_src0); - vmm_mul0 = vmm_aux0; - } else { - vmm_mul0 = vmm_src0; - } - - Vmm vmm_mul1; - if (vmm_dst.getIdx() == vmm_src1.getIdx()) { - h->uni_vmovups(vmm_aux0, vmm_src1); - vmm_mul1 = vmm_aux0; - } else { - vmm_mul1 = vmm_src1; - } - - if (vmm_dst.getIdx() != vmm_src2.getIdx()) - h->uni_vmovups(vmm_dst, vmm_src2); - - h->uni_vfmadd231ps(vmm_dst, vmm_mul0, vmm_mul1); - } break; - case ov::element::i32: { - h->uni_vpmulld(vmm_dst, vmm_src0, vmm_src1); - h->uni_vpaddd(vmm_dst, vmm_dst, vmm_src2); - } break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: { + Vmm vmm_mul0; + if (vmm_dst.getIdx() == vmm_src0.getIdx()) { + h->uni_vmovups(vmm_aux0, vmm_src0); + vmm_mul0 = vmm_aux0; + } else { + vmm_mul0 = vmm_src0; + } + + Vmm vmm_mul1; + if (vmm_dst.getIdx() == vmm_src1.getIdx()) { + h->uni_vmovups(vmm_aux0, vmm_src1); + vmm_mul1 = vmm_aux0; + } else { + vmm_mul1 = vmm_src1; + } + + if (vmm_dst.getIdx() != vmm_src2.getIdx()) + h->uni_vmovups(vmm_dst, vmm_src2); + + h->uni_vfmadd231ps(vmm_dst, vmm_mul0, vmm_mul1); + } break; + case ov::element::i32: { + h->uni_vpmulld(vmm_dst, vmm_src0, vmm_src1); + h->uni_vpaddd(vmm_dst, vmm_dst, vmm_src2); + } break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -164,19 +184,27 @@ size_t jit_mul_add_emitter::aux_vecs_count() const { return 1; } -std::set> jit_mul_add_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_mul_add_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32, element::f32}, {element::i32, element::i32, element::i32}}; } /// SUB /// -jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_subtract_emitter::jit_subtract_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_subtract_emitter::get_inputs_num() const { return 2; } +size_t jit_subtract_emitter::get_inputs_num() const { + return 2; +} -void jit_subtract_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_subtract_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -189,7 +217,8 @@ void jit_subtract_emitter::emit_impl(const std::vector &in_vec_idxs, con } template -void jit_subtract_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_subtract_emitter::emit_isa(const 
std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -197,9 +226,14 @@ void jit_subtract_emitter::emit_isa(const std::vector &in_vec_idxs, cons auto uni_vsub = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: h->uni_vsubps(vmm_dst, vmm_src0, vmm_src1); break; - case ov::element::i32: h->uni_vpsubd(vmm_dst, vmm_src0, vmm_src1); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: + h->uni_vsubps(vmm_dst, vmm_src0, vmm_src1); + break; + case ov::element::i32: + h->uni_vpsubd(vmm_dst, vmm_src0, vmm_src1); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -211,19 +245,27 @@ void jit_subtract_emitter::emit_isa(const std::vector &in_vec_idxs, cons } } -std::set> jit_subtract_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_subtract_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// MULTIPLY /// -jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_multiply_emitter::jit_multiply_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_multiply_emitter::get_inputs_num() const { return 2; } +size_t jit_multiply_emitter::get_inputs_num() const { + return 2; +} -void jit_multiply_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_multiply_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -236,7 +278,8 @@ void jit_multiply_emitter::emit_impl(const std::vector &in_vec_idxs, con } template -void jit_multiply_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_multiply_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -244,9 +287,14 @@ void jit_multiply_emitter::emit_isa(const std::vector &in_vec_idxs, cons auto uni_vmul = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: h->uni_vmulps(vmm_dst, vmm_src0, vmm_src1); break; - case ov::element::i32: h->uni_vpmulld(vmm_dst, vmm_src0, vmm_src1); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: + h->uni_vmulps(vmm_dst, vmm_src0, vmm_src1); + break; + case ov::element::i32: + h->uni_vpmulld(vmm_dst, vmm_src0, vmm_src1); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -258,19 +306,26 @@ void jit_multiply_emitter::emit_isa(const std::vector &in_vec_idxs, cons } } 
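
[editor's note, not part of the patch] The MUL_ADD hunk above reorders operands
before uni_vfmadd231ps because that instruction accumulates in place
(dst += mul0 * mul1): dst has to be preloaded with the addend src2, and a
multiplicand that aliases dst must first be copied into the aux register. The
standalone sketch below models that selection logic with a plain array standing
in for the Vmm register file; names such as fma231_model are illustrative only.

    #include <array>
    #include <cstddef>

    // Registers modelled as floats, indices standing in for Vmm ids.
    float fma231_model(std::array<float, 8>& regs,
                       std::size_t dst,
                       std::size_t src0,
                       std::size_t src1,
                       std::size_t src2,
                       std::size_t aux) {
        std::size_t mul0 = src0;
        std::size_t mul1 = src1;
        if (dst == src0) {  // src0 would be clobbered once dst is preloaded
            regs[aux] = regs[src0];
            mul0 = aux;
        }
        if (dst == src1) {  // likewise for src1
            regs[aux] = regs[src1];
            mul1 = aux;
        }
        if (dst != src2) {  // vfmadd231 accumulates into dst, so dst = src2 first
            regs[dst] = regs[src2];
        }
        regs[dst] += regs[mul0] * regs[mul1];  // what uni_vfmadd231ps does per lane
        return regs[dst];
    }

For example, fma231_model(regs, 0, 0, 1, 2, 7) routes src0 through slot 7 so
that preloading src2 into slot 0 cannot destroy it.
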
-std::set> jit_multiply_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_multiply_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// DIVIDE /// -jit_divide_emitter::jit_divide_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_divide_emitter::jit_divide_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_divide_emitter::jit_divide_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_divide_emitter::jit_divide_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_divide_emitter::get_inputs_num() const { return 2; } +size_t jit_divide_emitter::get_inputs_num() const { + return 2; +} -void jit_divide_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_divide_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -283,7 +338,8 @@ void jit_divide_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_divide_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_divide_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -291,23 +347,24 @@ void jit_divide_emitter::emit_isa(const std::vector &in_vec_idxs, const auto uni_vdiv = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: { - h->uni_vdivps(vmm_dst, vmm_src0, vmm_src1); - break; - } - case ov::element::i32: { - Vmm vmm_aux0 = Vmm(aux_vec_idxs[0]); - - // The opset doesn't contain vector instruction for integer divide operation - // As WA we emulate its behavior via fp divide followed by rounding to zero - h->uni_vcvtdq2ps(vmm_dst, vmm_src0); - h->uni_vcvtdq2ps(vmm_aux0, vmm_src1); - h->uni_vdivps(vmm_dst, vmm_dst, vmm_aux0); - h->uni_vroundps(vmm_dst, vmm_dst, 3); // rounding to zero - h->uni_vcvtps2dq(vmm_dst, vmm_dst); - break; - } - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: { + h->uni_vdivps(vmm_dst, vmm_src0, vmm_src1); + break; + } + case ov::element::i32: { + Vmm vmm_aux0 = Vmm(aux_vec_idxs[0]); + + // The opset doesn't contain vector instruction for integer divide operation + // As WA we emulate its behavior via fp divide followed by rounding to zero + h->uni_vcvtdq2ps(vmm_dst, vmm_src0); + h->uni_vcvtdq2ps(vmm_aux0, vmm_src1); + h->uni_vdivps(vmm_dst, vmm_dst, vmm_aux0); + h->uni_vroundps(vmm_dst, vmm_dst, 3); // rounding to zero + h->uni_vcvtps2dq(vmm_dst, vmm_dst); + break; + } + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -319,7 +376,8 @@ void jit_divide_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -std::set> jit_divide_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> 
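// [editor's note, not part of the patch] A scalar model of the integer-divide
// workaround above, with uni_vroundps imm 3 playing the role of std::trunc:
//     int32_t div_i32_model(int32_t a, int32_t b) {
//         return static_cast<int32_t>(std::trunc(static_cast<float>(a) / static_cast<float>(b)));
//     }
// This reproduces C's truncating division, at the cost of float precision for
// magnitudes above 2^24, where int32 values are no longer exactly representable.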
jit_divide_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}, {element::i32, element::i32}}; } @@ -328,18 +386,25 @@ size_t jit_divide_emitter::aux_vecs_count() const { } /// FLOOR /// -jit_floor_emitter::jit_floor_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} -jit_floor_emitter::jit_floor_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_floor_emitter::jit_floor_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} +jit_floor_emitter::jit_floor_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_floor_emitter::get_inputs_num() const { return 1; } +size_t jit_floor_emitter::get_inputs_num() const { + return 1; +} -std::set> jit_floor_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_floor_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_floor_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_floor_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -352,7 +417,8 @@ void jit_floor_emitter::emit_impl(const std::vector& in_vec_idxs, const } template -void jit_floor_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_floor_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); Vmm vmm_dst = Vmm(out_vec_idxs[0]); @@ -360,14 +426,20 @@ void jit_floor_emitter::emit_isa(const std::vector &in_vec_idxs, const s } /// CEILING /// -jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} -jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) +jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} +jit_ceiling_emitter::jit_ceiling_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_ceiling_emitter::get_inputs_num() const { return 1; } +size_t jit_ceiling_emitter::get_inputs_num() const { + return 1; +} -std::set> jit_ceiling_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_ceiling_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } @@ -385,7 +457,8 @@ void jit_ceiling_emitter::emit_impl(const std::vector& in_vec_idxs, } template -void jit_ceiling_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_ceiling_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = 
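// [editor's note, not part of the patch] The floor and ceiling bodies, elided
// by these hunks, differ only in the uni_vroundps immediate: in the SSE4.1
// rounding-control encoding used throughout this file, imm 1 rounds down,
// imm 2 rounds up, and imm 3 truncates, as the comments further below confirm.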
Vmm(in_vec_idxs[0]); Vmm vmm_dst = Vmm(out_vec_idxs[0]); @@ -403,13 +476,17 @@ jit_floor_mod_emitter::jit_floor_mod_emitter(x64::jit_generator* host, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_floor_mod_emitter::get_inputs_num() const { return 2; } +size_t jit_floor_mod_emitter::get_inputs_num() const { + return 2; +} -std::set> jit_floor_mod_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_floor_mod_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } -void jit_floor_mod_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_floor_mod_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -422,7 +499,8 @@ void jit_floor_mod_emitter::emit_impl(const std::vector& in_vec_idxs, co } template -void jit_floor_mod_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_floor_mod_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -434,14 +512,14 @@ void jit_floor_mod_emitter::emit_isa(const std::vector &in_vec_idxs, con h->uni_vmovups(vmm_dst, vmm_src0); h->uni_vmovups(vmm_aux0, vmm_src0); h->uni_vdivps(vmm_aux0, vmm_aux0, vmm_src1); - h->uni_vroundps(vmm_aux0, vmm_aux0, 1); // rounding down + h->uni_vroundps(vmm_aux0, vmm_aux0, 1); // rounding down h->uni_vmulps(vmm_aux0, vmm_aux0, vmm_src1); h->uni_vsubps(vmm_dst, vmm_dst, vmm_aux0); } else { if (vmm_dst.getIdx() != vmm_src0.getIdx()) h->uni_vmovups(vmm_dst, vmm_src0); h->uni_vdivps(vmm_aux0, vmm_src0, vmm_src1); - h->uni_vroundps(vmm_aux0, vmm_aux0, 1); // rounding down + h->uni_vroundps(vmm_aux0, vmm_aux0, 1); // rounding down h->uni_vmulps(vmm_aux0, vmm_aux0, vmm_src1); h->uni_vsubps(vmm_dst, vmm_dst, vmm_aux0); } @@ -452,12 +530,17 @@ size_t jit_floor_mod_emitter::aux_vecs_count() const { } /// MOD /// -jit_mod_emitter::jit_mod_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} -jit_mod_emitter::jit_mod_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_mod_emitter::jit_mod_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} +jit_mod_emitter::jit_mod_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_mod_emitter::get_inputs_num() const { return 2; } +size_t jit_mod_emitter::get_inputs_num() const { + return 2; +} std::set> jit_mod_emitter::get_supported_precisions(const std::shared_ptr& node) { return {{element::f32, element::f32}}; @@ -476,7 +559,7 @@ void jit_mod_emitter::emit_impl(const std::vector& in_vec_idxs, const st } template -void jit_mod_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_mod_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = 
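// [editor's note, not part of the patch] FLOOR_MOD below computes
// a - floor(a / b) * b (uni_vroundps imm 1), while the MOD emitter after it
// computes a - trunc(a / b) * b (imm 3); the two agree except when the
// operands have opposite signs, e.g. floor_mod(-7, 3) == 2 but mod(-7, 3) == -1.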
Vmm(in_vec_idxs[1]); @@ -488,14 +571,14 @@ void jit_mod_emitter::emit_isa(const std::vector &in_vec_idxs, const std h->uni_vmovups(vmm_dst, vmm_src0); h->uni_vmovups(vmm_aux0, vmm_src0); h->uni_vdivps(vmm_aux0, vmm_aux0, vmm_src1); - h->uni_vroundps(vmm_aux0, vmm_aux0, 3); // truncate + h->uni_vroundps(vmm_aux0, vmm_aux0, 3); // truncate h->uni_vmulps(vmm_aux0, vmm_aux0, vmm_src1); h->uni_vsubps(vmm_dst, vmm_dst, vmm_aux0); } else { if (vmm_dst.getIdx() != vmm_src0.getIdx()) h->uni_vmovups(vmm_dst, vmm_src0); h->uni_vdivps(vmm_aux0, vmm_src0, vmm_src1); - h->uni_vroundps(vmm_aux0, vmm_aux0, 3); // truncate + h->uni_vroundps(vmm_aux0, vmm_aux0, 3); // truncate h->uni_vmulps(vmm_aux0, vmm_aux0, vmm_src1); h->uni_vsubps(vmm_dst, vmm_dst, vmm_aux0); } @@ -506,14 +589,19 @@ size_t jit_mod_emitter::aux_vecs_count() const { } /// MAXIMUM /// -jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_maximum_emitter::jit_maximum_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_maximum_emitter::get_inputs_num() const { return 2; } +size_t jit_maximum_emitter::get_inputs_num() const { + return 2; +} -void jit_maximum_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_maximum_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -526,7 +614,8 @@ void jit_maximum_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_maximum_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_maximum_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -534,9 +623,14 @@ void jit_maximum_emitter::emit_isa(const std::vector &in_vec_idxs, const auto uni_vmax = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: h->uni_vmaxps(vmm_dst, vmm_src0, vmm_src1); break; - case ov::element::i32: h->uni_vpmaxsd(vmm_dst, vmm_src0, vmm_src1); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: + h->uni_vmaxps(vmm_dst, vmm_src0, vmm_src1); + break; + case ov::element::i32: + h->uni_vpmaxsd(vmm_dst, vmm_src0, vmm_src1); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -549,19 +643,25 @@ void jit_maximum_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -std::set> jit_maximum_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_maximum_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// MINIMUM /// -jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator *host, 
x64::cpu_isa_t host_isa, const std::shared_ptr& node) -: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} -jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node) + : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {} +jit_minimum_emitter::jit_minimum_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_minimum_emitter::get_inputs_num() const { return 2; } +size_t jit_minimum_emitter::get_inputs_num() const { + return 2; +} -void jit_minimum_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_minimum_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -574,7 +674,8 @@ void jit_minimum_emitter::emit_impl(const std::vector &in_vec_idxs, cons } template -void jit_minimum_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_minimum_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -582,9 +683,14 @@ void jit_minimum_emitter::emit_isa(const std::vector &in_vec_idxs, const auto uni_vmin = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: h->uni_vminps(vmm_dst, vmm_src0, vmm_src1); break; - case ov::element::i32: h->uni_vpminsd(vmm_dst, vmm_src0, vmm_src1); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: + h->uni_vminps(vmm_dst, vmm_src0, vmm_src1); + break; + case ov::element::i32: + h->uni_vpminsd(vmm_dst, vmm_src0, vmm_src1); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -597,20 +703,28 @@ void jit_minimum_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -std::set> jit_minimum_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_minimum_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// SQUARED_DIFFERENCE /// -jit_squared_difference_emitter::jit_squared_difference_emitter( - x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} -jit_squared_difference_emitter::jit_squared_difference_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) {} +jit_squared_difference_emitter::jit_squared_difference_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} +jit_squared_difference_emitter::jit_squared_difference_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_squared_difference_emitter::get_inputs_num() const { return 2; } +size_t jit_squared_difference_emitter::get_inputs_num() const { + return 2; +} -void 
jit_squared_difference_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_squared_difference_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -623,7 +737,8 @@ void jit_squared_difference_emitter::emit_impl(const std::vector &in_vec } template -void jit_squared_difference_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_squared_difference_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -631,15 +746,16 @@ void jit_squared_difference_emitter::emit_isa(const std::vector &in_vec_ auto uni_vsqdiff = [this](Vmm vmm_dst, Vmm vmm_src0, Vmm vmm_src1) { switch (exec_prc_) { - case ov::element::f32: { - h->uni_vsubps(vmm_dst, vmm_src0, vmm_src1); - h->uni_vmulps(vmm_dst, vmm_dst, vmm_dst); - } break; - case ov::element::i32: { - h->uni_vpsubd(vmm_dst, vmm_src0, vmm_src1); - h->uni_vpmulld(vmm_dst, vmm_dst, vmm_dst); - } break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); + case ov::element::f32: { + h->uni_vsubps(vmm_dst, vmm_src0, vmm_src1); + h->uni_vmulps(vmm_dst, vmm_dst, vmm_dst); + } break; + case ov::element::i32: { + h->uni_vpsubd(vmm_dst, vmm_src0, vmm_src1); + h->uni_vpmulld(vmm_dst, vmm_dst, vmm_dst); + } break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported precision"); } }; @@ -652,24 +768,33 @@ void jit_squared_difference_emitter::emit_isa(const std::vector &in_vec_ } } -std::set> jit_squared_difference_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_squared_difference_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}, {element::i32, element::i32}}; } /// POWER_DYNAMIC /// -jit_power_dynamic_emitter::jit_power_dynamic_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, +jit_power_dynamic_emitter::jit_power_dynamic_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -jit_power_dynamic_emitter::jit_power_dynamic_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) +jit_power_dynamic_emitter::jit_power_dynamic_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_power_dynamic_emitter::get_inputs_num() const { return 2; } +size_t jit_power_dynamic_emitter::get_inputs_num() const { + return 2; +} -std::set> jit_power_dynamic_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_power_dynamic_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } -void jit_power_dynamic_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_power_dynamic_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -682,7 +807,8 @@ void jit_power_dynamic_emitter::emit_impl(const std::vector& in_vec_idxs } template -void jit_power_dynamic_emitter::emit_isa(const 
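// [editor's note, not part of the patch] The POWER_DYNAMIC body reformatted
// below has no SIMD pow to lean on: it saves eleven general-purpose registers
// and the whole vector file to the stack, calls libm's powf once per 32-bit
// lane, then restores everything, which makes this emitter comparatively
// expensive.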
std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_power_dynamic_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -692,8 +818,8 @@ void jit_power_dynamic_emitter::emit_isa(const std::vector &in_vec_idxs, // caller obligation to save gprs as callee may use them size_t gpr_size = 8; - Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, - h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; + Xbyak::Operand gprs_to_save[] = + {h->r8, h->r9, h->r10, h->r11, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); h->sub(h->rsp, n_gprs_to_save * gpr_size); @@ -722,8 +848,8 @@ void jit_power_dynamic_emitter::emit_isa(const std::vector &in_vec_idxs, h->sub(h->rsp, (get_max_vecs_count() + 2) * get_vec_length()); for (size_t i = 2; i < get_max_vecs_count() + 2; ++i) h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i - 2)); - h->uni_vmovups(h->ptr[h->rsp + 0 * get_vec_length()], vmm_src0); // src - h->uni_vmovups(h->ptr[h->rsp + 1 * get_vec_length()], vmm_src1); // beta + h->uni_vmovups(h->ptr[h->rsp + 0 * get_vec_length()], vmm_src0); // src + h->uni_vmovups(h->ptr[h->rsp + 1 * get_vec_length()], vmm_src1); // beta // save function address in gpr to pass in in call instruction h->mov(h->rbp, reinterpret_cast(powf)); @@ -735,7 +861,7 @@ void jit_power_dynamic_emitter::emit_isa(const std::vector &in_vec_idxs, // Take src, apply powf on it and replace value on a stack with dst. for (size_t i = 0; i < get_vec_length() / sizeof(float); ++i) { - const Address &source = h->ptr[h->rsp + h->rbx + i * sizeof(float)]; + const Address& source = h->ptr[h->rsp + h->rbx + i * sizeof(float)]; h->uni_vmovss(xmm0, source); h->uni_vmovss(xmm1, h->ptr[h->rsp + h->rbx + get_vec_length() + i * sizeof(float)]); h->call(h->rbp); @@ -767,24 +893,30 @@ void jit_power_dynamic_emitter::emit_isa(const std::vector &in_vec_idxs, h->add(h->rsp, n_gprs_to_save * gpr_size); } - /// EQUAL /// -jit_equal_emitter::jit_equal_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr& node, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) { +jit_equal_emitter::jit_equal_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + const std::shared_ptr& node, + ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -jit_equal_emitter::jit_equal_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) -: jit_emitter(host, host_isa, exec_prc) { +jit_equal_emitter::jit_equal_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) + : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_equal_emitter::get_inputs_num() const { return 2; } +size_t jit_equal_emitter::get_inputs_num() const { + return 2; +} -std::set> jit_equal_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_equal_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32, element::f32}}; } -void jit_equal_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_equal_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if 
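// [editor's note, not part of the patch] EQUAL and the comparison emitters
// that follow materialize booleans as floats: their bodies, mostly outside
// these hunks, compute a per-lane compare mask and blend between 0.0f and the
// 1.f table constant (CONST_1_F) that prepare_table() registers.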
(host_isa_ == x64::avx2) {
@@ -797,7 +929,8 @@ void jit_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const
 }
 
 template <x64::cpu_isa_t isa>
-void jit_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_equal_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                 const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -846,13 +979,17 @@ jit_not_equal_emitter::jit_not_equal_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-size_t jit_not_equal_emitter::get_inputs_num() const { return 2; }
+size_t jit_not_equal_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_not_equal_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_not_equal_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_not_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_not_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                      const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -865,7 +1002,8 @@ void jit_not_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, co
 }
 
 template <x64::cpu_isa_t isa>
-void jit_not_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_not_equal_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                     const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -900,22 +1038,29 @@ size_t jit_not_equal_emitter::aux_vecs_count() const {
 }
 
 /// GREATER ///
-jit_greater_emitter::jit_greater_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc) {
+jit_greater_emitter::jit_greater_emitter(x64::jit_generator* host,
+                                         x64::cpu_isa_t host_isa,
+                                         const std::shared_ptr<ov::Node>& node,
+                                         ov::element::Type exec_prc)
+    : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-jit_greater_emitter::jit_greater_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc) {
+jit_greater_emitter::jit_greater_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
+    : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-size_t jit_greater_emitter::get_inputs_num() const { return 2; }
+size_t jit_greater_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_greater_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_greater_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_greater_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_greater_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                    const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -928,7 +1073,8 @@ void jit_greater_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, cons
 }
 
 template <x64::cpu_isa_t isa>
-void jit_greater_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_greater_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                   const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -963,23 +1109,31 @@ size_t jit_greater_emitter::aux_vecs_count() const {
 }
 
 /// GREATER_EQUAL ///
-jit_greater_equal_emitter::jit_greater_equal_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node,
+jit_greater_equal_emitter::jit_greater_equal_emitter(x64::jit_generator* host,
+                                                     x64::cpu_isa_t host_isa,
+                                                     const std::shared_ptr<ov::Node>& node,
                                                      ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc) {
+    : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-jit_greater_equal_emitter::jit_greater_equal_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc) {
+jit_greater_equal_emitter::jit_greater_equal_emitter(x64::jit_generator* host,
+                                                     x64::cpu_isa_t host_isa,
+                                                     ov::element::Type exec_prc)
+    : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-size_t jit_greater_equal_emitter::get_inputs_num() const { return 2; }
+size_t jit_greater_equal_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_greater_equal_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_greater_equal_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
    return {{element::f32, element::f32}};
 }
 
-void jit_greater_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_greater_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                          const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -992,7 +1146,8 @@ void jit_greater_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs
 }
 
 template <x64::cpu_isa_t isa>
-void jit_greater_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_greater_equal_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                         const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1027,22 +1182,28 @@ size_t jit_greater_equal_emitter::aux_vecs_count() const {
 }
 
 /// LESS ///
-jit_less_emitter::jit_less_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc) {
+jit_less_emitter::jit_less_emitter(x64::jit_generator* host,
+                                   x64::cpu_isa_t host_isa,
+                                   const std::shared_ptr<ov::Node>& node,
+                                   ov::element::Type exec_prc)
+    : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-jit_less_emitter::jit_less_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc) {
+jit_less_emitter::jit_less_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
+    : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-size_t jit_less_emitter::get_inputs_num() const { return 2; }
+size_t jit_less_emitter::get_inputs_num() const {
+    return 2;
+}
 
 std::set<std::vector<element::Type>> jit_less_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_less_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_less_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                 const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1055,7 +1216,7 @@ void jit_less_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const s
 }
 
 template <x64::cpu_isa_t isa>
-void jit_less_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_less_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1104,13 +1265,17 @@ jit_less_equal_emitter::jit_less_equal_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-size_t jit_less_equal_emitter::get_inputs_num() const { return 2; }
+size_t jit_less_equal_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_less_equal_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_less_equal_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
    return {{element::f32, element::f32}};
 }
 
-void jit_less_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_less_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                       const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1123,7 +1288,8 @@ void jit_less_equal_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, c
 }
 
 template <x64::cpu_isa_t isa>
-void jit_less_equal_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_less_equal_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                      const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1173,13 +1339,17 @@ jit_logical_and_emitter::jit_logical_and_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-size_t jit_logical_and_emitter::get_inputs_num() const { return 2; }
+size_t jit_logical_and_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_logical_and_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_logical_and_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_logical_and_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_logical_and_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                        const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1192,7 +1362,8 @@ void jit_logical_and_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
 }
 
 template <x64::cpu_isa_t isa>
-void jit_logical_and_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_logical_and_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                       const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1261,13 +1432,17 @@ jit_logical_or_emitter::jit_logical_or_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-size_t jit_logical_or_emitter::get_inputs_num() const { return 2; }
+size_t jit_logical_or_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_logical_or_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_logical_or_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_logical_or_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_logical_or_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                       const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1280,7 +1455,8 @@ void jit_logical_or_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, c
 }
 
 template <x64::cpu_isa_t isa>
-void jit_logical_or_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_logical_or_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                      const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1349,13 +1525,17 @@ jit_logical_xor_emitter::jit_logical_xor_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-size_t jit_logical_xor_emitter::get_inputs_num() const { return 2; }
+size_t jit_logical_xor_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_logical_xor_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_logical_xor_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_logical_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_logical_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                        const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1368,7 +1548,8 @@ void jit_logical_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
 }
 
 template <x64::cpu_isa_t isa>
-void jit_logical_xor_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_logical_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                       const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1437,13 +1618,17 @@ jit_logical_not_emitter::jit_logical_not_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-size_t jit_logical_not_emitter::get_inputs_num() const { return 1; }
+size_t jit_logical_not_emitter::get_inputs_num() const {
+    return 1;
+}
 
-std::set<std::vector<element::Type>> jit_logical_not_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_logical_not_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32}};
 }
 
-void jit_logical_not_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_logical_not_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                        const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1456,7 +1641,8 @@ void jit_logical_not_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
 }
 
 template <x64::cpu_isa_t isa>
-void jit_logical_not_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_logical_not_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                       const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_dst = Vmm(out_vec_idxs[0]);
@@ -1507,20 +1693,30 @@ jit_power_static_emitter::jit_power_static_emitter(x64::jit_generator* host,
     prepare_table();
 }
 
-jit_power_static_emitter::jit_power_static_emitter(x64::jit_generator *host, x64::cpu_isa_t host_isa,
-                                                   float inpPower, float inpScale, float inpShift,
+jit_power_static_emitter::jit_power_static_emitter(x64::jit_generator* host,
+                                                   x64::cpu_isa_t host_isa,
+                                                   float inpPower,
+                                                   float inpScale,
+                                                   float inpShift,
                                                    ov::element::Type exec_prc)
-: jit_emitter(host, host_isa, exec_prc), power(inpPower), scale(inpScale), shift(inpShift) {
+    : jit_emitter(host, host_isa, exec_prc),
+      power(inpPower),
+      scale(inpScale),
+      shift(inpShift) {
     prepare_table();
 }
 
-size_t jit_power_static_emitter::get_inputs_num() const { return 1; }
+size_t jit_power_static_emitter::get_inputs_num() const {
+    return 1;
+}
 
-std::set<std::vector<element::Type>> jit_power_static_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_power_static_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
    return {{element::f32}};
 }
 
-void jit_power_static_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_power_static_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                         const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1533,7 +1729,8 @@ void jit_power_static_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
 }
 
 template <x64::cpu_isa_t isa>
-void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_power_static_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                        const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_dst = Vmm(out_vec_idxs[0]);
@@ -1600,8 +1797,8 @@ void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs,
         // caller obligation to save gprs as callee may use them
         size_t gpr_size = 8;
-        Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax,
-                                         h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
+        Xbyak::Operand gprs_to_save[] =
+            {h->r8, h->r9, h->r10, h->r11, h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx};
         size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);
 
         h->sub(h->rsp, n_gprs_to_save * gpr_size);
@@ -1630,8 +1827,8 @@ void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs,
         h->sub(h->rsp, (get_max_vecs_count() + 2) * get_vec_length());
         for (size_t i = 2; i < get_max_vecs_count() + 2; ++i)
             h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i - 2));
-        h->uni_vmovups(h->ptr[h->rsp + 0 * get_vec_length()], vmm_dst); // src
-        h->uni_vmovups(h->ptr[h->rsp + 1 * get_vec_length()], vmm_aux0); // beta
+        h->uni_vmovups(h->ptr[h->rsp + 0 * get_vec_length()], vmm_dst);   // src
+        h->uni_vmovups(h->ptr[h->rsp + 1 * get_vec_length()], vmm_aux0);  // beta
 
         // save function address in gpr to pass in in call instruction
        h->mov(h->rbp, reinterpret_cast<uintptr_t>(powf));
@@ -1643,7 +1840,7 @@ void jit_power_static_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs,
 
         // Take src, apply powf on it and replace value on a stack with dst.
         for (size_t i = 0; i < get_vec_length() / sizeof(float); ++i) {
-            const Address &source = h->ptr[h->rsp + h->rbx + i * sizeof(float)];
+            const Address& source = h->ptr[h->rsp + h->rbx + i * sizeof(float)];
             h->uni_vmovss(xmm0, source);
             h->uni_vmovss(xmm1, h->ptr[h->rsp + h->rbx + get_vec_length() + i * sizeof(float)]);
             h->call(h->rbp);
@@ -1680,7 +1877,7 @@ void jit_power_static_emitter::register_table_entries() {
     push_arg_entry_of("power", x64::float2int(power), true);
     push_arg_entry_of("scale", x64::float2int(scale), true);
     push_arg_entry_of("shift", x64::float2int(shift), true);
-    push_arg_entry_of("one",   x64::float2int(1.f), true);
+    push_arg_entry_of("one", x64::float2int(1.f), true);
 }
 
 size_t jit_power_static_emitter::aux_vecs_count() const {
@@ -1699,13 +1896,17 @@ jit_prelu_emitter::jit_prelu_emitter(x64::jit_generator* host, x64::cpu_isa_t ho
     : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-size_t jit_prelu_emitter::get_inputs_num() const { return 2; }
+size_t jit_prelu_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_prelu_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_prelu_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32, element::f32}};
 }
 
-void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                  const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1718,7 +1919,8 @@ void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const
 }
 
 template <x64::cpu_isa_t isa>
-void jit_prelu_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_prelu_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                 const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -1761,13 +1963,16 @@ jit_sqrt_emitter::jit_sqrt_emitter(x64::jit_generator* host,
 
 jit_sqrt_emitter::jit_sqrt_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
     : jit_emitter(host, host_isa, exec_prc) {}
 
-size_t jit_sqrt_emitter::get_inputs_num() const { return 1; }
+size_t jit_sqrt_emitter::get_inputs_num() const {
+    return 1;
+}
 
 std::set<std::vector<element::Type>> jit_sqrt_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
     return {{element::f32}};
 }
 
-void jit_sqrt_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_sqrt_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                 const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1780,12 +1985,12 @@ void jit_sqrt_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const s
 }
 
 template <x64::cpu_isa_t isa>
-void jit_sqrt_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_sqrt_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_dst = Vmm(out_vec_idxs[0]);
 
-    h->uni_vsqrtps(vmm_dst,  vmm_src0);
+    h->uni_vsqrtps(vmm_dst, vmm_src0);
 }
 
 /// Negate ///
@@ -1795,13 +2000,17 @@ jit_negative_emitter::jit_negative_emitter(x64::jit_generator* host,
                                            ov::element::Type exec_prc)
     : jit_emitter(host, host_isa, exec_prc) {}
 
-size_t jit_negative_emitter::get_inputs_num() const { return 1; }
+size_t jit_negative_emitter::get_inputs_num() const {
+    return 1;
+}
 
-std::set<std::vector<element::Type>> jit_negative_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
+std::set<std::vector<element::Type>> jit_negative_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
     return {{element::f32}};
 }
 
-void jit_negative_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_negative_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                     const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1814,33 +2023,38 @@ void jit_negative_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, con
 }
 
 template <x64::cpu_isa_t isa>
-void jit_negative_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_negative_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                    const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src = Vmm(in_vec_idxs[0]);
-    Vmm vmm_dst  = Vmm(out_vec_idxs[0]);
+    Vmm vmm_dst = Vmm(out_vec_idxs[0]);
     h->uni_vpxor(vmm_dst, vmm_dst, vmm_dst);
     h->uni_vsubps(vmm_dst, vmm_dst, vmm_src);
 }
-
 
 /// EXP ///
 jit_exp_emitter::jit_exp_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc)
     : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-jit_exp_emitter::jit_exp_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
+jit_exp_emitter::jit_exp_emitter(x64::jit_generator* host,
+                                 x64::cpu_isa_t host_isa,
+                                 const std::shared_ptr<ov::Node>& node,
+                                 ov::element::Type exec_prc)
     : jit_emitter(host, host_isa, exec_prc) {
     prepare_table();
 }
 
-size_t jit_exp_emitter::get_inputs_num() const { return 1; }
+size_t jit_exp_emitter::get_inputs_num() const {
+    return 1;
+}
 
 std::set<std::vector<element::Type>> jit_exp_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
     return {{element::f32}};
 }
 
-void jit_exp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_exp_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1853,7 +2067,7 @@ void jit_exp_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const st
 }
 
 template <x64::cpu_isa_t isa>
-void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_exp_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src = Vmm(in_vec_idxs[0]);
     Vmm vmm_dst = Vmm(out_vec_idxs[0]);
@@ -1862,7 +2076,7 @@ void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
     Vmm vmm_aux0 = Vmm(aux_vec_idxs[0 + static_cast<size_t>(need_vmm_mask())]);
     Vmm vmm_aux1 = Vmm(aux_vec_idxs[1 + static_cast<size_t>(need_vmm_mask())]);
 
-    auto compute_cmp_mask = [&](const Vmm &vmm_src, const Xbyak::Operand &compare_operand, int cmp_predicate) {
+    auto compute_cmp_mask = [&](const Vmm& vmm_src, const Xbyak::Operand& compare_operand, int cmp_predicate) {
         if (host_isa_ == x64::avx512_core) {
             h->vcmpps(k_mask, vmm_src, compare_operand, cmp_predicate);
         } else {
@@ -1870,7 +2084,7 @@ void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
         }
     };
 
-    auto blend_with_mask = [&](const Vmm &vmm_dst, const Xbyak::Operand &src) {
+    auto blend_with_mask = [&](const Vmm& vmm_dst, const Xbyak::Operand& src) {
         if (host_isa_ == x64::avx512_core) {
             h->vblendmps(vmm_dst | k_mask, vmm_dst, src);
         } else {
@@ -1924,11 +2138,11 @@ void jit_exp_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
 }
 
 void jit_exp_emitter::register_table_entries() {
-    push_arg_entry_of("pol1", 0x3f7ffffb, true); // p1 = 0.999999701f
-    push_arg_entry_of("pol2", 0x3efffee3, true); // p2 = 0.499991506f
-    push_arg_entry_of("pol3", 0x3e2aad40, true); // p3 = 0.166676521f
-    push_arg_entry_of("pol4", 0x3d2b9d0d, true); // p4 = 0.0418978221f
-    push_arg_entry_of("pol5", 0x3c07cfce, true); // p5 = 0.00828929059f
+    push_arg_entry_of("pol1", 0x3f7ffffb, true);  // p1 = 0.999999701f
+    push_arg_entry_of("pol2", 0x3efffee3, true);  // p2 = 0.499991506f
+    push_arg_entry_of("pol3", 0x3e2aad40, true);  // p3 = 0.166676521f
+    push_arg_entry_of("pol4", 0x3d2b9d0d, true);  // p4 = 0.0418978221f
+    push_arg_entry_of("pol5", 0x3c07cfce, true);  // p5 = 0.00828929059f
 
     push_arg_entry_of("one", CONST_1_F, true);
     push_arg_entry_of("half", 0x3f000000, true);
@@ -1950,16 +2164,21 @@ jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host, x64::cpu_isa_t host_i
     prepare_table();
 }
 
-jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
+jit_erf_emitter::jit_erf_emitter(x64::jit_generator* host,
+                                 x64::cpu_isa_t host_isa,
+                                 const std::shared_ptr<ov::Node>& node,
+                                 ov::element::Type exec_prc)
     : jit_erf_emitter(host, host_isa, exec_prc) {}
 
-size_t jit_erf_emitter::get_inputs_num() const { return 1; }
+size_t jit_erf_emitter::get_inputs_num() const {
+    return 1;
+}
 
 std::set<std::vector<element::Type>> jit_erf_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
     return {{element::f32}};
 }
 
-void jit_erf_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_erf_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -1972,7 +2191,7 @@ void jit_erf_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const st
 }
 
 template <x64::cpu_isa_t isa>
-void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
+void jit_erf_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src = Vmm(in_vec_idxs[0]);
     Vmm vmm_dst = Vmm(out_vec_idxs[0]);
@@ -1991,8 +2210,11 @@ void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
 
     // pass the current `aux_vec_idxs` to `exp_emitter` excepting `vmm_aux3`
     auto exp_aux_vec_idxs = aux_vec_idxs;
-    exp_aux_vec_idxs.erase(std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(vmm_aux3.getIdx())));
-    m_exp_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())}, {static_cast<size_t>(vmm_dst.getIdx())}, exp_aux_vec_idxs);
+    exp_aux_vec_idxs.erase(
+        std::find(exp_aux_vec_idxs.begin(), exp_aux_vec_idxs.end(), static_cast<size_t>(vmm_aux3.getIdx())));
+    m_exp_emitter->emit_code({static_cast<size_t>(vmm_dst.getIdx())},
+                             {static_cast<size_t>(vmm_dst.getIdx())},
+                             exp_aux_vec_idxs);
 
     h->uni_vxorps(vmm_dst, vmm_dst, table_val("sign_mask"));
 
@@ -2027,16 +2249,16 @@ void jit_erf_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std
 }
 
 void jit_erf_emitter::register_table_entries() {
-    push_arg_entry_of("approx_const", 0x3ea7ba05, true); // 0.3275911
+    push_arg_entry_of("approx_const", 0x3ea7ba05, true);  // 0.3275911
     push_arg_entry_of("one", CONST_1_F, true);
push_arg_entry_of("sign_mask", 0x80000000, true); push_arg_entry_of("positive_mask", 0x7fffffff, true); - push_arg_entry_of("pol1", 0x3e827906, true); // p1 = 0.254829592f - push_arg_entry_of("pol2", 0xbe91a98e, true); // p2 = -0.284496736f - push_arg_entry_of("pol3", 0x3fb5f0e3, true); // p3 = 1.421413741f - push_arg_entry_of("pol4", 0xbfba00e3, true); // p4 = -1.453152027f - push_arg_entry_of("pol5", 0x3f87dc22, true); // p5 = 1.061405429f + push_arg_entry_of("pol1", 0x3e827906, true); // p1 = 0.254829592f + push_arg_entry_of("pol2", 0xbe91a98e, true); // p2 = -0.284496736f + push_arg_entry_of("pol3", 0x3fb5f0e3, true); // p3 = 1.421413741f + push_arg_entry_of("pol4", 0xbfba00e3, true); // p4 = -1.453152027f + push_arg_entry_of("pol5", 0x3f87dc22, true); // p5 = 1.061405429f } size_t jit_erf_emitter::aux_vecs_count() const { @@ -2063,13 +2285,17 @@ jit_soft_sign_emitter::jit_soft_sign_emitter(x64::jit_generator* host, prepare_table(); } -size_t jit_soft_sign_emitter::get_inputs_num() const { return 1; } +size_t jit_soft_sign_emitter::get_inputs_num() const { + return 1; +} -std::set> jit_soft_sign_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_soft_sign_emitter::get_supported_precisions( + const std::shared_ptr& node) { return {{element::f32}}; } -void jit_soft_sign_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_soft_sign_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2082,7 +2308,8 @@ void jit_soft_sign_emitter::emit_impl(const std::vector& in_vec_idxs, co } template -void jit_soft_sign_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_soft_sign_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); Vmm vmm_dst = Vmm(out_vec_idxs[0]); @@ -2100,10 +2327,11 @@ void jit_soft_sign_emitter::register_table_entries() { /// IS_FINITE /// template <> -void jit_is_finite_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_finite_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { auto vmm_src = Zmm(in_vec_idxs[0]); auto vmm_dst = Zmm(out_vec_idxs[0]); - auto &ones_mask = h->k1; + auto& ones_mask = h->k1; auto reg32_one = Reg32(aux_gpr_idxs[0]); h->mov(reg32_one, CONST_1_F); @@ -2113,13 +2341,14 @@ void jit_is_finite_emitter::emit_isa(const std::vector } template -void jit_is_finite_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_finite_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional::type; auto vmm_src = Vmm(in_vec_idxs[0]); auto vmm_dst = Vmm(out_vec_idxs[0]); h->uni_vandps(vmm_src, vmm_src, table_val("inf")); - h->uni_vcmpps(vmm_src, vmm_src, table_val("inf"), 0B00000100); // NEq + h->uni_vcmpps(vmm_src, vmm_src, table_val("inf"), 0B00000100); // NEq if (isa == x64::avx2) { h->uni_vandps(vmm_dst, vmm_src, table_val("one")); @@ -2131,7 +2360,8 @@ void jit_is_finite_emitter::emit_isa(const std::vector &in_vec_idxs, con } } -void jit_is_finite_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void 
jit_is_finite_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::avx512_core) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2152,12 +2382,13 @@ void jit_is_finite_emitter::register_table_entries() { /// IS_INF /// template <> -void jit_is_inf_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_inf_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { Zmm vmm_src = Zmm(in_vec_idxs[0]); Zmm vmm_dst = Zmm(out_vec_idxs[0]); if (detect_negative || detect_positive) { - auto &ones_mask = h->k1; + auto& ones_mask = h->k1; auto reg32_one = Reg32(aux_gpr_idxs[0]); uint8_t imm = detect_negative ? 0B00010000 : 0B00000000; if (detect_positive) { @@ -2173,7 +2404,8 @@ void jit_is_inf_emitter::emit_isa(const std::vector &i } template -void jit_is_inf_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_inf_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional::type; if (detect_negative || detect_positive) { @@ -2204,7 +2436,8 @@ void jit_is_inf_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -void jit_is_inf_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_inf_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::avx512_core) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2226,10 +2459,11 @@ void jit_is_inf_emitter::register_table_entries() { /// IS_NAN /// template <> -void jit_is_nan_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_nan_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { auto vmm_src = Zmm(in_vec_idxs[0]); auto vmm_dst = Zmm(out_vec_idxs[0]); - auto &ones_mask = h->k1; + auto& ones_mask = h->k1; auto reg32_one = Reg32(aux_gpr_idxs[0]); h->mov(reg32_one, CONST_1_F); @@ -2238,7 +2472,8 @@ void jit_is_nan_emitter::emit_isa(const std::vector &i } template -void jit_is_nan_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_nan_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional::type; auto vmm_src = Vmm(in_vec_idxs[0]); auto vmm_dst = Vmm(out_vec_idxs[0]); @@ -2254,7 +2489,8 @@ void jit_is_nan_emitter::emit_isa(const std::vector &in_vec_idxs, const } } -void jit_is_nan_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_is_nan_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::avx512_core) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2281,9 +2517,12 @@ jit_select_emitter::jit_select_emitter(x64::jit_generator* host, jit_select_emitter::jit_select_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_select_emitter::get_inputs_num() const { return 3; } +size_t jit_select_emitter::get_inputs_num() const { + return 3; +} -std::set> jit_select_emitter::get_supported_precisions(const std::shared_ptr& node) { +std::set> jit_select_emitter::get_supported_precisions( + const 
std::shared_ptr& node) { return {{element::f32, element::f32, element::f32}}; } @@ -2292,11 +2531,12 @@ size_t jit_select_emitter::aux_vecs_count() const { return 0; else if (host_isa_ == x64::avx2) // tmp vec for mask return 1; - else // mask should be xmm0 on sse41 + tmp vec for mask + else // mask should be xmm0 on sse41 + tmp vec for mask return 2; } -void jit_select_emitter::emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_select_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2309,7 +2549,8 @@ void jit_select_emitter::emit_impl(const std::vector &in_vec_idxs, const } template -void jit_select_emitter::emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const { +void jit_select_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_cond = Vmm(in_vec_idxs[0]); Vmm vmm_src0 = Vmm(in_vec_idxs[1]); @@ -2346,20 +2587,22 @@ jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) +jit_bitwise_and_emitter::jit_bitwise_and_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_bitwise_and_emitter::get_inputs_num() const { return 2; } +size_t jit_bitwise_and_emitter::get_inputs_num() const { + return 2; +} -std::set> jit_bitwise_and_emitter::get_supported_precisions(const std::shared_ptr& node) { - return { - {element::i8, element::i8}, - {element::u8, element::u8}, - {element::i32, element::i32} - }; +std::set> jit_bitwise_and_emitter::get_supported_precisions( + const std::shared_ptr& node) { + return {{element::i8, element::i8}, {element::u8, element::u8}, {element::i32, element::i32}}; } -void jit_bitwise_and_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_bitwise_and_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2372,7 +2615,8 @@ void jit_bitwise_and_emitter::emit_impl(const std::vector& in_vec_idxs, } template -void jit_bitwise_and_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_bitwise_and_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -2399,24 +2643,28 @@ jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, prepare_table(); } -jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) +jit_bitwise_not_emitter::jit_bitwise_not_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) { prepare_table(); } -size_t jit_bitwise_not_emitter::get_inputs_num() const { return 1; } +size_t jit_bitwise_not_emitter::get_inputs_num() const { + return 1; +} -std::set> 
jit_bitwise_not_emitter::get_supported_precisions(const std::shared_ptr& node) { - return { - {element::i8}, - {element::u8}, - {element::i32} - }; +std::set> jit_bitwise_not_emitter::get_supported_precisions( + const std::shared_ptr& node) { + return {{element::i8}, {element::u8}, {element::i32}}; } -size_t jit_bitwise_not_emitter::aux_vecs_count() const { return 1; } +size_t jit_bitwise_not_emitter::aux_vecs_count() const { + return 1; +} -void jit_bitwise_not_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_bitwise_not_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2429,7 +2677,8 @@ void jit_bitwise_not_emitter::emit_impl(const std::vector& in_vec_idxs, } template -void jit_bitwise_not_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_bitwise_not_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src = Vmm(in_vec_idxs[0]); Vmm vmm_dst = Vmm(out_vec_idxs[0]); @@ -2457,20 +2706,22 @@ jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) +jit_bitwise_or_emitter::jit_bitwise_or_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_bitwise_or_emitter::get_inputs_num() const { return 2; } +size_t jit_bitwise_or_emitter::get_inputs_num() const { + return 2; +} -std::set> jit_bitwise_or_emitter::get_supported_precisions(const std::shared_ptr& node) { - return { - {element::i8, element::i8}, - {element::u8, element::u8}, - {element::i32, element::i32} - }; +std::set> jit_bitwise_or_emitter::get_supported_precisions( + const std::shared_ptr& node) { + return {{element::i8, element::i8}, {element::u8, element::u8}, {element::i32, element::i32}}; } -void jit_bitwise_or_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_bitwise_or_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { if (host_isa_ == x64::sse41) { emit_isa(in_vec_idxs, out_vec_idxs); } else if (host_isa_ == x64::avx2) { @@ -2483,7 +2734,8 @@ void jit_bitwise_or_emitter::emit_impl(const std::vector& in_vec_idxs, c } template -void jit_bitwise_or_emitter::emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_bitwise_or_emitter::emit_isa(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { using Vmm = typename conditional3::type; Vmm vmm_src0 = Vmm(in_vec_idxs[0]); Vmm vmm_src1 = Vmm(in_vec_idxs[1]); @@ -2508,20 +2760,22 @@ jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, x64::cpu_isa_t host_isa, ov::element::Type exec_prc) +jit_bitwise_xor_emitter::jit_bitwise_xor_emitter(x64::jit_generator* host, + x64::cpu_isa_t host_isa, + ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {} -size_t jit_bitwise_xor_emitter::get_inputs_num() const { return 2; } 
+size_t jit_bitwise_xor_emitter::get_inputs_num() const {
+    return 2;
+}
 
-std::set<std::vector<element::Type>> jit_bitwise_xor_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
-    return {
-        {element::i8, element::i8},
-        {element::u8, element::u8},
-        {element::i32, element::i32}
-    };
+std::set<std::vector<element::Type>> jit_bitwise_xor_emitter::get_supported_precisions(
+    const std::shared_ptr<ov::Node>& node) {
+    return {{element::i8, element::i8}, {element::u8, element::u8}, {element::i32, element::i32}};
 }
 
-void jit_bitwise_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_bitwise_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
+                                        const std::vector<size_t>& out_vec_idxs) const {
     if (host_isa_ == x64::sse41) {
         emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
     } else if (host_isa_ == x64::avx2) {
@@ -2534,7 +2788,8 @@ void jit_bitwise_xor_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
 }
 
 template <x64::cpu_isa_t isa>
-void jit_bitwise_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
+void jit_bitwise_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
+                                       const std::vector<size_t>& out_vec_idxs) const {
     using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
     Vmm vmm_src0 = Vmm(in_vec_idxs[0]);
     Vmm vmm_src1 = Vmm(in_vec_idxs[1]);
@@ -2543,5 +2798,5 @@ void jit_bitwise_xor_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs, c
     h->uni_vxorps(vmm_dst, vmm_src0, vmm_src1);
 }
 
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.hpp
index c8c4b06d6f3347..84c65d44a12280 100644
--- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.hpp
+++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.hpp
@@ -11,418 +11,488 @@ namespace intel_cpu {
 
 class jit_add_emitter : public jit_emitter {
 public:
-    jit_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_add_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                     ov::element::Type exec_prc = ov::element::f32);
-    jit_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n);
+    jit_add_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                    const std::shared_ptr<ov::Node>& n);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 };
 
 class jit_mul_add_emitter : public jit_emitter {
 public:
-    jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                         ov::element::Type exec_prc = ov::element::f32);
-    jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n);
+    jit_mul_add_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                        const std::shared_ptr<ov::Node>& n);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_subtract_emitter : public jit_emitter {
 public:
-    jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                         dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                          ov::element::Type exec_prc = ov::element::f32);
-    jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n);
+    jit_subtract_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                         dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                         const std::shared_ptr<ov::Node>& n);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};
-
 
 class jit_multiply_emitter : public jit_emitter {
 public:
-    jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                         dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                          ov::element::Type exec_prc = ov::element::f32);
-    jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n);
+    jit_multiply_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                         dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                         const std::shared_ptr<ov::Node>& n);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
 };
-
 
 class jit_divide_emitter : public jit_emitter {
 public:
-    jit_divide_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_divide_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                       dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                        ov::element::Type exec_prc = ov::element::f32);
-    jit_divide_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_divide_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                       dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                       const std::shared_ptr<ov::Node>& n,
                        ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     size_t aux_vecs_count() const override;
 };
 
 class jit_floor_emitter : public jit_emitter {
 public:
-    jit_floor_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_floor_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                      dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                       ov::element::Type exec_prc = ov::element::f32);
-    jit_floor_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_floor_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                      dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                      const std::shared_ptr<ov::Node>& n,
                       ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 };
 
 class jit_ceiling_emitter : public jit_emitter {
 public:
-    jit_ceiling_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
-                        ov::element::Type exec_prc = ov::element::f32);
-    jit_ceiling_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
-                        ov::element::Type exec_prc = ov::element::f32);
+    jit_ceiling_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                        ov::element::Type exec_prc = ov::element::f32);
+    jit_ceiling_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                        const std::shared_ptr<ov::Node>& n,
+                        ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 };
 
 class jit_floor_mod_emitter : public jit_emitter {
 public:
-    jit_floor_mod_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_floor_mod_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                          dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                           ov::element::Type exec_prc = ov::element::f32);
-    jit_floor_mod_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_floor_mod_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                          dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                          const std::shared_ptr<ov::Node>& n,
                           ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_mod_emitter : public jit_emitter {
 public:
-    jit_mod_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_mod_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                     ov::element::Type exec_prc = ov::element::f32);
-    jit_mod_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_mod_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                    dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                    const std::shared_ptr<ov::Node>& n,
                     ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_maximum_emitter : public jit_emitter {
 public:
-    jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                         ov::element::Type exec_prc = ov::element::f32);
-    jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n);
+    jit_maximum_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                        const std::shared_ptr<ov::Node>& n);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 };
-
 
 class jit_minimum_emitter : public jit_emitter {
 public:
-    jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                         ov::element::Type exec_prc = ov::element::f32);
-    jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n);
+    jit_minimum_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                        const std::shared_ptr<ov::Node>& n);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
 };
-
 
 class jit_squared_difference_emitter : public jit_emitter {
 public:
-    jit_squared_difference_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_squared_difference_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                                   dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                                    ov::element::Type exec_prc = ov::element::f32);
-    jit_squared_difference_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_squared_difference_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                                   dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                                    const std::shared_ptr<ov::Node>& n,
                                    ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 };
-
 
 class jit_power_dynamic_emitter : public jit_emitter {
 public:
-    jit_power_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_power_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                              dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                               ov::element::Type exec_prc = ov::element::f32);
-    jit_power_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_power_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                              dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                              const std::shared_ptr<ov::Node>& n,
                               ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 };
-
 
 class jit_equal_emitter : public jit_emitter {
 public:
-    jit_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                      dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                       ov::element::Type exec_prc = ov::element::f32);
-    jit_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                      dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                      const std::shared_ptr<ov::Node>& n,
                       ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_not_equal_emitter : public jit_emitter {
 public:
-    jit_not_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_not_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                          dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                           ov::element::Type exec_prc = ov::element::f32);
-    jit_not_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_not_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                          dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                          const std::shared_ptr<ov::Node>& n,
                           ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_greater_emitter : public jit_emitter {
 public:
-    jit_greater_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_greater_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                         ov::element::Type exec_prc = ov::element::f32);
-    jit_greater_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_greater_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                        dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                        const std::shared_ptr<ov::Node>& n,
                         ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_greater_equal_emitter : public jit_emitter {
 public:
-    jit_greater_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_greater_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                              dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                               ov::element::Type exec_prc = ov::element::f32);
-    jit_greater_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_greater_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                              dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                              const std::shared_ptr<ov::Node>& n,
                               ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_less_emitter : public jit_emitter {
 public:
-    jit_less_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_less_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                     dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                      ov::element::Type exec_prc = ov::element::f32);
-    jit_less_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_less_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                     dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                     const std::shared_ptr<ov::Node>& n,
                      ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_less_equal_emitter : public jit_emitter {
 public:
-    jit_less_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_less_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                           dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                            ov::element::Type exec_prc = ov::element::f32);
-    jit_less_equal_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_less_equal_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                           dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                           const std::shared_ptr<ov::Node>& n,
                            ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_logical_and_emitter : public jit_emitter {
 public:
-    jit_logical_and_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_logical_and_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                            dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                             ov::element::Type exec_prc = ov::element::f32);
-    jit_logical_and_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_logical_and_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                            dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                            const std::shared_ptr<ov::Node>& n,
                             ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-    void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;
+    void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
 
     template <dnnl::impl::cpu::x64::cpu_isa_t isa>
-    void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
+    void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
 
     void register_table_entries() override;
 
     size_t aux_vecs_count() const override;
 };
-
 
 class jit_logical_or_emitter : public jit_emitter {
 public:
-    jit_logical_or_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+    jit_logical_or_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                           dnnl::impl::cpu::x64::cpu_isa_t host_isa,
                            ov::element::Type exec_prc = ov::element::f32);
-    jit_logical_or_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr<ov::Node>& n,
+    jit_logical_or_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                           dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                           const std::shared_ptr<ov::Node>& n,
                            ov::element::Type exec_prc = ov::element::f32);
 
     size_t get_inputs_num() const override;
-    static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);
+    static std::set<std::vector<element::Type>> get_supported_precisions(
+        const std::shared_ptr<ov::Node>& node = nullptr);
 
 private:
-
void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; size_t aux_vecs_count() const override; }; - class jit_logical_xor_emitter : public jit_emitter { public: - jit_logical_xor_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_logical_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_logical_xor_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_logical_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; size_t aux_vecs_count() const override; @@ -430,19 +500,23 @@ class jit_logical_xor_emitter : public jit_emitter { class jit_logical_not_emitter : public jit_emitter { public: - jit_logical_not_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_logical_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_logical_not_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_logical_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; size_t aux_vecs_count() const override; @@ -450,21 +524,26 @@ class jit_logical_not_emitter : public jit_emitter { class jit_power_static_emitter : public jit_emitter { public: - jit_power_static_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - float inpPower, float inpScale, float inpShift, + 
jit_power_static_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + float inpPower, + float inpScale, + float inpShift, ov::element::Type exec_prc = ov::element::f32); - jit_power_static_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_power_static_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); - + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; size_t aux_vecs_count() const override; @@ -476,73 +555,90 @@ class jit_power_static_emitter : public jit_emitter { class jit_prelu_emitter : public jit_emitter { public: - jit_prelu_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_prelu_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_prelu_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_prelu_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; size_t aux_vecs_count() const override; }; class jit_sqrt_emitter : public jit_emitter { public: - jit_sqrt_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32); - jit_sqrt_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_sqrt_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32); + jit_sqrt_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const 
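Reviewer note: the three scalar constants taken by jit_power_static_emitter above are, by the usual OpenVINO PowerStatic convention, combined as (scale * x + shift) ^ power per element. A minimal scalar reference, assuming that convention (the helper name is invented for illustration):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Scalar model of what the emitter is assumed to compute elementwise.
    static std::vector<float> power_static_ref(const std::vector<float>& x,
                                               float power, float scale, float shift) {
        std::vector<float> y(x.size());
        for (std::size_t i = 0; i < x.size(); ++i)
            y[i] = std::pow(scale * x[i] + shift, power);  // (inpScale*x + inpShift)^inpPower
        return y;
    }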
std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_negative_emitter : public jit_emitter { public: - jit_negative_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_negative_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_exp_emitter : public jit_emitter { public: - jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_exp_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; - bool need_vmm_mask() const { return host_isa_ != dnnl::impl::cpu::x64::avx512_core; } + bool need_vmm_mask() const { + return host_isa_ != dnnl::impl::cpu::x64::avx512_core; + } void register_table_entries() override; size_t aux_vecs_count() const override; @@ -550,103 +646,132 @@ class jit_exp_emitter : public jit_emitter { class jit_erf_emitter : public jit_emitter { public: - jit_erf_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32); + jit_erf_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32); - jit_erf_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_erf_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); void emit_data() const override; size_t get_inputs_num() 
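Reviewer note on need_vmm_mask() above: the assumed rationale is that AVX-512 keeps comparison/blend masks in dedicated k-registers, while SSE/AVX2 must materialize the mask in a spare vector register. A sketch of how that feeds an auxiliary-register budget (hypothetical shape; the real emitter defines its own count):

    #include <cstddef>

    // One extra vector register is reserved only when the ISA has no opmask regs.
    std::size_t aux_vecs_with_mask(std::size_t base_aux, bool need_vmm_mask) {
        return base_aux + (need_vmm_mask ? 1 : 0);
    }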
const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl( - const std::vector &in_vec_idxs, - const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; size_t aux_vecs_count() const override; - std::unique_ptr m_exp_emitter {nullptr}; + std::unique_ptr m_exp_emitter{nullptr}; }; class jit_soft_sign_emitter : public jit_emitter { public: - jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_soft_sign_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; void register_table_entries() override; }; class jit_is_finite_emitter : public jit_emitter { public: - jit_is_finite_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t hostIsa, - ov::element::Type execPrc = ov::element::f32) : jit_emitter(host, hostIsa, execPrc) { + jit_is_finite_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t hostIsa, + ov::element::Type execPrc = ov::element::f32) + : jit_emitter(host, hostIsa, execPrc) { prepare_table(); } - jit_is_finite_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t hostIsa, const std::shared_ptr& node, - ov::element::Type execPrc = ov::element::f32) : jit_emitter(host, hostIsa, execPrc) { + jit_is_finite_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t hostIsa, + const std::shared_ptr& node, + ov::element::Type execPrc = ov::element::f32) + : jit_emitter(host, hostIsa, execPrc) { prepare_table(); } - size_t get_inputs_num() const override { return 1; }; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + size_t get_inputs_num() const override { + return 1; + }; + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr) { return {{element::f32}}; } protected: - size_t aux_gprs_count() const override { return (entry_map_.empty() ? 0 : 1) + 1; } + size_t aux_gprs_count() const override { + return (entry_map_.empty() ? 
0 : 1) + 1; + } void register_table_entries() override; private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_is_inf_emitter : public jit_emitter { public: - jit_is_inf_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t hostIsa, - ov::element::Type execPrc = ov::element::f32, bool detect_negative = true, bool detect_positive = true) - : jit_emitter(host, hostIsa, execPrc), detect_negative(detect_negative), detect_positive(detect_positive) { + jit_is_inf_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t hostIsa, + ov::element::Type execPrc = ov::element::f32, + bool detect_negative = true, + bool detect_positive = true) + : jit_emitter(host, hostIsa, execPrc), + detect_negative(detect_negative), + detect_positive(detect_positive) { prepare_table(); } - jit_is_inf_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t hostIsa, const std::shared_ptr& node, - ov::element::Type execPrc = ov::element::f32): jit_emitter(host, hostIsa, execPrc) { + jit_is_inf_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t hostIsa, + const std::shared_ptr& node, + ov::element::Type execPrc = ov::element::f32) + : jit_emitter(host, hostIsa, execPrc) { prepare_table(); } - size_t get_inputs_num() const override { return 1; }; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + size_t get_inputs_num() const override { + return 1; + }; + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr) { return {{element::f32}}; } protected: - size_t aux_gprs_count() const override { return (entry_map_.empty() ? 0 : 1) + 1; } + size_t aux_gprs_count() const override { + return (entry_map_.empty() ? 
0 : 1) + 1; + } void register_table_entries() override; private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; bool detect_negative; bool detect_positive; @@ -654,58 +779,76 @@ class jit_is_inf_emitter : public jit_emitter { class jit_is_nan_emitter : public jit_emitter { public: - jit_is_nan_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t hostIsa, - ov::element::Type execPrc = ov::element::f32) : jit_emitter(host, hostIsa, execPrc) { + jit_is_nan_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t hostIsa, + ov::element::Type execPrc = ov::element::f32) + : jit_emitter(host, hostIsa, execPrc) { prepare_table(); } - jit_is_nan_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t hostIsa, const std::shared_ptr& node, - ov::element::Type execPrc = ov::element::f32) : jit_emitter(host, hostIsa, execPrc) { + jit_is_nan_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t hostIsa, + const std::shared_ptr& node, + ov::element::Type execPrc = ov::element::f32) + : jit_emitter(host, hostIsa, execPrc) { prepare_table(); } - size_t get_inputs_num() const override { return 1; } - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + size_t get_inputs_num() const override { + return 1; + } + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr) { return {{element::f32}}; } protected: - size_t aux_gprs_count() const override { return (entry_map_.empty() ? 0 : 1) + 1; } + size_t aux_gprs_count() const override { + return (entry_map_.empty() ? 
0 : 1) + 1; + } void register_table_entries() override; private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_select_emitter : public jit_emitter { public: - jit_select_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_select_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::element::Type exec_prc = ov::element::f32); - jit_select_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, + jit_select_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); size_t aux_vecs_count() const override; private: - void emit_impl(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const override; + void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; template - void emit_isa(const std::vector &in_vec_idxs, const std::vector &out_vec_idxs) const; + void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; class jit_bitwise_and_emitter : public jit_emitter { public: - jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32); - jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_and_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; @@ -716,13 +859,17 @@ class jit_bitwise_and_emitter : public jit_emitter { class jit_bitwise_not_emitter : public jit_emitter { public: - jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32); - jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_not_emitter(dnnl::impl::cpu::x64::jit_generator* host, + 
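Reviewer note: the recurring aux_gprs_count() expression (entry_map_.empty() ? 0 : 1) + 1 in the is_finite/is_inf/is_nan emitters decomposes as one GPR to hold the constants-table pointer when the table is non-empty, plus one always-present scratch GPR. A hypothetical breakdown, consistent with the table machinery in jit_emitter below:

    #include <cstddef>

    std::size_t aux_gprs_count_sketch(bool has_table_entries) {
        std::size_t table_ptr = has_table_entries ? 1 : 0;  // GPR backing p_table
        std::size_t scratch   = 1;                          // temp GPR for the emitter body
        return table_ptr + scratch;
    }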
dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); size_t aux_vecs_count() const override; private: @@ -735,13 +882,17 @@ class jit_bitwise_not_emitter : public jit_emitter { class jit_bitwise_or_emitter : public jit_emitter { public: - jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32); - jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_or_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; @@ -752,13 +903,17 @@ class jit_bitwise_or_emitter : public jit_emitter { class jit_bitwise_xor_emitter : public jit_emitter { public: - jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32); - jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n, - ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32); + jit_bitwise_xor_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n, + ov::element::Type exec_prc = ov::element::f32); size_t get_inputs_num() const override; - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; @@ -767,5 +922,5 @@ class jit_bitwise_xor_emitter : public jit_emitter { void emit_isa(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp index acbb04ea01af80..7ee4d5184b311a 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.cpp @@ -3,9 +3,11 @@ // #include "jit_emitter.hpp" + #include -#include "utils/general_utils.h" + #include "utils.hpp" +#include "utils/general_utils.h" using namespace dnnl::impl::cpu; using namespace dnnl::impl; @@ -19,11 +21,12 @@ size_t jit_emitter::get_max_vecs_count() const { } size_t 
jit_emitter::get_vec_length() const { - return one_of(host_isa_, cpu::x64::avx512_core, cpu::x64::avx512_core) ? 64 : - one_of(host_isa_, cpu::x64::avx2) ? 32 : 16; + return one_of(host_isa_, cpu::x64::avx512_core, cpu::x64::avx512_core) ? 64 + : one_of(host_isa_, cpu::x64::avx2) ? 32 + : 16; } -void jit_emitter::push_vec(const Xbyak::Address &addr, size_t vec_idx) const { +void jit_emitter::push_vec(const Xbyak::Address& addr, size_t vec_idx) const { if (host_isa_ == cpu::x64::sse41) { h->uni_vmovups(addr, Xmm(vec_idx)); } else if (host_isa_ == cpu::x64::avx2) { @@ -33,7 +36,7 @@ void jit_emitter::push_vec(const Xbyak::Address &addr, size_t vec_idx) const { } } -void jit_emitter::pop_vec(size_t vec_idx, const Xbyak::Address &addr) const { +void jit_emitter::pop_vec(size_t vec_idx, const Xbyak::Address& addr) const { if (host_isa_ == cpu::x64::sse41) { h->uni_vmovups(Xmm(vec_idx), addr); } else if (host_isa_ == cpu::x64::avx2) { @@ -60,11 +63,15 @@ std::set> jit_emitter::get_supported_precisions(const return {}; } -void jit_emitter::emitter_preamble(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_emitter::emitter_preamble(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { using namespace Xbyak::util; - bool is_vec_input = (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::vec_to_gpr); - bool is_vec_output = (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::gpr_to_vec); + bool is_vec_input = + (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::vec_to_gpr); + bool is_vec_output = + (in_out_type_ == emitter_in_out_map::vec_to_vec) || (in_out_type_ == emitter_in_out_map::gpr_to_vec); for (auto idx : pool_vec_idxs) aux_vec_idxs.push_back(idx); @@ -73,9 +80,11 @@ void jit_emitter::emitter_preamble(const std::vector &in_idxs, const std if (host_isa_ == cpu::x64::sse41 && aux_vecs_count() > 0) { size_t idx = 0; if (is_vec_input) - OV_CPU_JIT_EMITTER_ASSERT(std::find(in_idxs.begin(), in_idxs.end(), idx) == in_idxs.end(), "Xmm(0) cannot be input register in SSE41"); + OV_CPU_JIT_EMITTER_ASSERT(std::find(in_idxs.begin(), in_idxs.end(), idx) == in_idxs.end(), + "Xmm(0) cannot be input register in SSE41"); if (is_vec_output) - OV_CPU_JIT_EMITTER_ASSERT(std::find(out_idxs.begin(), out_idxs.end(), idx) == out_idxs.end(), "Xmm(0) cannot be output register in SSE41"); + OV_CPU_JIT_EMITTER_ASSERT(std::find(out_idxs.begin(), out_idxs.end(), idx) == out_idxs.end(), + "Xmm(0) cannot be output register in SSE41"); if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) == aux_vec_idxs.end()) { aux_vec_idxs.push_back(idx); preserved_vec_idxs.push_back(idx); @@ -93,16 +102,21 @@ void jit_emitter::emitter_preamble(const std::vector &in_idxs, const std } for (size_t idx = 0; idx < get_max_vecs_count(); idx++) { - if (aux_vec_idxs.size() >= aux_vecs_count()) break; + if (aux_vec_idxs.size() >= aux_vecs_count()) + break; if (is_vec_input) { - if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) continue; + if (std::find(in_idxs.begin(), in_idxs.end(), idx) != in_idxs.end()) + continue; } if (is_vec_output) { - if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) continue; + if (std::find(out_idxs.begin(), out_idxs.end(), idx) != out_idxs.end()) + continue; } - if 
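Reviewer note: the reflowed ternary in get_vec_length() encodes the widest vector register each ISA offers. Spelled out (helper name is illustrative):

    #include <cstddef>

    // avx512_core -> zmm, 64 bytes (16 f32 lanes)
    // avx2        -> ymm, 32 bytes ( 8 f32 lanes)
    // sse41       -> xmm, 16 bytes ( 4 f32 lanes)
    std::size_t vec_length_sketch(bool has_avx512, bool has_avx2) {
        return has_avx512 ? 64 : has_avx2 ? 32 : 16;
    }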
(std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) continue; - if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) continue; + if (std::find(aux_vec_idxs.begin(), aux_vec_idxs.end(), idx) != aux_vec_idxs.end()) + continue; + if (std::find(preserved_vec_idxs.begin(), preserved_vec_idxs.end(), idx) != preserved_vec_idxs.end()) + continue; aux_vec_idxs.push_back(idx); preserved_vec_idxs.push_back(idx); @@ -115,18 +129,24 @@ void jit_emitter::emitter_preamble(const std::vector &in_idxs, const std aux_gpr_idxs.push_back(idx); for (size_t gpr_idx = 0; gpr_idx <= Operand::R15; ++gpr_idx) { - size_t _idx = Operand::R15 - gpr_idx; // we allocate from the end + size_t _idx = Operand::R15 - gpr_idx; // we allocate from the end - if (aux_gpr_idxs.size() >= aux_gprs_count()) break; - if (_idx == Operand::RSP) continue; + if (aux_gpr_idxs.size() >= aux_gprs_count()) + break; + if (_idx == Operand::RSP) + continue; if (!is_vec_input) { - if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) continue; + if (std::find(in_idxs.begin(), in_idxs.end(), _idx) != in_idxs.end()) + continue; } if (!is_vec_output) { - if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) continue; + if (std::find(out_idxs.begin(), out_idxs.end(), _idx) != out_idxs.end()) + continue; } - if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) continue; - if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) continue; + if (std::find(aux_gpr_idxs.begin(), aux_gpr_idxs.end(), _idx) != aux_gpr_idxs.end()) + continue; + if (std::find(preserved_gpr_idxs.begin(), preserved_gpr_idxs.end(), _idx) != preserved_gpr_idxs.end()) + continue; aux_gpr_idxs.push_back(_idx); preserved_gpr_idxs.push_back(_idx); @@ -154,7 +174,6 @@ void jit_emitter::emitter_preamble(const std::vector &in_idxs, const std load_table_addr(); } - void jit_emitter::emitter_postamble() const { using namespace Xbyak::util; @@ -183,7 +202,7 @@ void jit_emitter::emit_data() const { // Run through the map and insert values stored there for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) { - const auto &te = (*it).second; // get map entry for a given key + const auto& te = (*it).second; // get map entry for a given key const auto len = te.bcast ? get_vec_length() : sizeof(table_entry_val_t); for (size_t d = 0; d < len; d += sizeof(table_entry_val_t)) h->dd(te.val); @@ -199,14 +218,16 @@ void jit_emitter::prepare_table() { // prepare_table. size_t off = 0; for (auto it = entry_map_.begin(); it != entry_map_.end(); it++) { - auto &te = (*it).second; + auto& te = (*it).second; te.off = off; off += te.bcast ? 
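Reviewer note: the emitter_preamble() allocation loops above reduce to one idea, which the sketch below condenses: scan the register pool for free indices, skipping anything used for input/output or already reserved (the real code additionally records picks for save/restore, and allocates GPRs from R15 downward, skipping RSP):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    static bool taken(const std::vector<std::size_t>& v, std::size_t idx) {
        return std::find(v.begin(), v.end(), idx) != v.end();
    }

    static std::vector<std::size_t> allocate_aux(std::size_t pool_size, std::size_t needed,
                                                 const std::vector<std::size_t>& in_idxs,
                                                 const std::vector<std::size_t>& out_idxs) {
        std::vector<std::size_t> aux;
        for (std::size_t idx = 0; idx < pool_size && aux.size() < needed; ++idx) {
            if (taken(in_idxs, idx) || taken(out_idxs, idx) || taken(aux, idx))
                continue;
            aux.push_back(idx);  // a preserved_*_idxs entry would be recorded here too
        }
        return aux;
    }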
get_vec_length() : sizeof(table_entry_val_t); } } -void jit_emitter::emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_emitter::emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs); emit_impl(in_idxs, out_idxs); @@ -214,5 +235,5 @@ void jit_emitter::emit_code(const std::vector &in_idxs, const std::vecto emitter_postamble(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp index c5729613f1bfe5..04ac2e6ea0684d 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp @@ -4,17 +4,17 @@ #pragma once -#include "cpu/x64/jit_generator.hpp" - -#include "snippets/snippets_isa.hpp" -#include "snippets/generator.hpp" -#include "emitters/utils.hpp" #include #include +#include "cpu/x64/jit_generator.hpp" +#include "emitters/utils.hpp" +#include "snippets/generator.hpp" +#include "snippets/snippets_isa.hpp" + #ifdef SNIPPETS_DEBUG_CAPS -#include "emitters/snippets/x64/verbose.hpp" +# include "emitters/snippets/x64/verbose.hpp" #endif namespace ov { @@ -34,14 +34,23 @@ struct emitter_params { class jit_emitter : public ov::snippets::Emitter { public: - jit_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::f32, emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) - : Emitter(), h(host), host_isa_(host_isa), exec_prc_(exec_prc), l_table (new Xbyak::Label()), in_out_type_(in_out_type) { - k_mask = Xbyak::Opmask(1); // FIXME: in general case we need preserve k_mask state as well + jit_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type exec_prc = ov::element::f32, + emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_vec) + : Emitter(), + h(host), + host_isa_(host_isa), + exec_prc_(exec_prc), + l_table(new Xbyak::Label()), + in_out_type_(in_out_type) { + k_mask = Xbyak::Opmask(1); // FIXME: in general case we need preserve k_mask state as well } - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; void emit_data() const override; virtual size_t get_inputs_num() const = 0; @@ -53,10 +62,11 @@ class jit_emitter : public ov::snippets::Emitter { * Precisions are ordered, the first bigger bitness precision with the same type will be selected. * Empty collection means the emitter supports any input precisions. 
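Reviewer note on the emit_code()/prepare_table() pair above, as a usage sketch. Register indices are invented, and the ordering is inferred from the preamble, postamble, and table code in this diff rather than stated by it:

    // jit_exp_emitter exp(h, host_isa);     // derived ctor fills entry_map_ and
    //                                       // calls prepare_table()
    // exp.emit_code({src_idx}, {dst_idx});  // emitter_preamble -> emit_impl
    //                                       // -> emitter_postamble
    // exp.emit_data();                      // after the kernel body: dump the
    //                                       // 32-bit constants at l_table
    //
    // prepare_table() only assigns per-entry byte offsets; the values are not
    // written until emit_data(), which is why the two steps stay separate.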
*/ - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); #ifdef SNIPPETS_DEBUG_CAPS - const char *info() const { + const char* info() const { if (!info_.is_initialized()) info_.init(this); return info_.c_str(); @@ -77,12 +87,14 @@ class jit_emitter : public ov::snippets::Emitter { virtual void prepare_table(); virtual void register_table_entries() {} - void load_table_addr() const { h->mov(p_table, *l_table.get()); } + void load_table_addr() const { + h->mov(p_table, *l_table.get()); + } // we accept only 32bit hexadecimal table values to avoid any rounding using table_entry_val_t = uint32_t; - using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table - using table_entry_bcast_t = bool; // true => bcast value + using table_entry_offset_t = size_t; // offsets are in bytes wrt p_table + using table_entry_bcast_t = bool; // true => bcast value struct table_entry_t { table_entry_val_t val; @@ -106,10 +118,12 @@ class jit_emitter : public ov::snippets::Emitter { _cmp_gt_os = dnnl::impl::cpu::x64::jit_generator::_cmp_nle_us, }; - virtual void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const = 0; + virtual void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const = 0; - virtual void emitter_preamble(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const; + virtual void emitter_preamble(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const; virtual void emitter_postamble() const; emitter_in_out_map in_out_type_; @@ -132,14 +146,14 @@ class jit_emitter : public ov::snippets::Emitter { mapped_table_t entry_map_; void push_arg_entry_of(const std::string key, const table_entry_val_t val, const bool broadcast) { - mapped_table_entry_t te {0, val, broadcast}; + mapped_table_entry_t te{0, val, broadcast}; entry_map_.insert(std::make_pair(key, te)); } - void push_entries_of(const table_t &t) { + void push_entries_of(const table_t& t) { for (auto it = t.begin(); it != t.end(); it++) { auto key = (*it).first; - auto te = (*it).second; // copy values from table + auto te = (*it).second; // copy values from table push_arg_entry_of(key, te.val, te.bcast); } } @@ -155,20 +169,20 @@ class jit_emitter : public ov::snippets::Emitter { mutable std::vector preserved_vec_idxs; mutable std::vector preserved_gpr_idxs; - void push_vec(const Xbyak::Address &addr, size_t vec_idx) const; - void pop_vec(size_t vec_idx, const Xbyak::Address &addr) const; + void push_vec(const Xbyak::Address& addr, size_t vec_idx) const; + void pop_vec(size_t vec_idx, const Xbyak::Address& addr) const; size_t table_off(std::string& key, size_t key_off_val_shift = 0) const { // assumption: all table entries sharing the same key also // share their broadcast property // TODO: enforce through data structure - const auto it = entry_map_.find(key); // search an entry for a key + const auto it = entry_map_.find(key); // search an entry for a key OV_CPU_JIT_EMITTER_ASSERT(it != entry_map_.end(), "Value has not been found in the table"); - const auto &te = (*it).second; + const auto& te = (*it).second; const auto scale = te.bcast ? 
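Reviewer note: the get_supported_precisions() contract documented above reads as follows (return values are illustrative, not taken from any particular emitter):

    // return {};                          // empty set: any input precisions accepted
    // return {{element::f32}};            // single-input emitter, f32 lanes only
    // return {{element::i32, element::i32},
    //         {element::f32, element::f32}};
    //                                     // ordered preference: the caller walks the
    //                                     // list and, among entries of the same type,
    //                                     // the first wider-bitness match is selected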
get_vec_length() : sizeof(table_entry_val_t); return te.off + key_off_val_shift * scale; } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.cpp index 893c18768a9511..513c1f70d22932 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.cpp @@ -3,6 +3,7 @@ // #include "jit_load_store_emitters.hpp" + #include "utils/bfloat16.hpp" using namespace dnnl::impl; @@ -16,19 +17,20 @@ using namespace Xbyak::util; // An auxiliary vector reg(data_reg_new) is used as destination vector for source pollution instructions, // After updated, processed with new vector and no more need to update as source is preserved. // e.g. with STORE_KEEP_SOURCE(vextractf128, xmm, Xmm(aux_src_idx), ymm, 1); -// if ymm is already updated, h->vextractf128(xmm, ymm, 1) is used, which change ymm values as xmm and ymm have the same index. -// if ymm is not updated, h->vextractf128(Xmm(aux_src_idx), ymm, 1) is used, which keep ymm values unchanged as destination is another vector reg. +// if ymm is already updated, h->vextractf128(xmm, ymm, 1) is used, which change ymm values as xmm and ymm have the +// same index. if ymm is not updated, h->vextractf128(Xmm(aux_src_idx), ymm, 1) is used, which keep ymm values +// unchanged as destination is another vector reg. #define STORE_KEEP_SOURCE(instruction, data_reg, data_reg_new, ...) \ - if (data_reg_updated) { \ - h->instruction(data_reg, __VA_ARGS__); \ - } else { \ - h->instruction(data_reg_new, __VA_ARGS__); \ - data_idx = aux_src_idx; \ - xmm = Xbyak::Xmm(data_idx); \ - ymm = Xbyak::Ymm(data_idx); \ - zmm = Xbyak::Zmm(data_idx); \ - vmm = Vmm(data_idx); \ - data_reg_updated = true; \ + if (data_reg_updated) { \ + h->instruction(data_reg, __VA_ARGS__); \ + } else { \ + h->instruction(data_reg_new, __VA_ARGS__); \ + data_idx = aux_src_idx; \ + xmm = Xbyak::Xmm(data_idx); \ + ymm = Xbyak::Ymm(data_idx); \ + zmm = Xbyak::Zmm(data_idx); \ + vmm = Vmm(data_idx); \ + data_reg_updated = true; \ } namespace ov { @@ -39,7 +41,7 @@ namespace { constexpr int threshold_for_mask_emu_load = 14; // heuristic threshold number by byte between mask store and emulation with several simple partial store constexpr int threshold_for_mask_emu_store = 6; -} // namespace +} // namespace size_t load_emitter_params::hash() const { size_t seed = 0; @@ -61,46 +63,69 @@ size_t store_emitter_params::hash() const { return seed; } -static int get_aux_regs_as_temp(const int elem_count, const int data_size, bool is_pure_move, bool is_store_as_real16, - const int avx512_threshold_for_mask = 0, const bool is_fill = false) { +static int get_aux_regs_as_temp(const int elem_count, + const int data_size, + bool is_pure_move, + bool is_store_as_real16, + const int avx512_threshold_for_mask = 0, + const bool is_fill = false) { if (mayiuse(cpu::x64::avx512_core) && is_fill) return 1; // for pure move, there are direct no-mask instructions to move on full xmm/ymm/zmm, so aux_gpr is not needed. // for move+convert: - // there are direct no-mask instructions to load i8/u8/i16/u16/bf16/fp16 to full xmm/ymm/zmm as f32/i32, so aux_gpr is not needed. - // there are direct no-mask instructions to store i32 on full xmm/ymm/zmm to i8/u8/i16/u16, so aux_gpr is not needed. 
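Reviewer note: the STORE_KEEP_SOURCE macro reformatted above implements a small state machine, restated here in plain C++ so the false-dependency avoidance is easier to follow (types and names are hypothetical):

    #include <cstddef>

    struct store_ctx {
        bool data_reg_updated = false;
        std::size_t data_idx = 0;      // caller's source register index
        std::size_t aux_src_idx = 1;   // spare register for the working copy
    };

    template <typename Op>
    void keep_source(store_ctx& c, Op op) {
        if (c.data_reg_updated) {
            op(c.data_idx);              // safe: we already operate on the copy
        } else {
            op(c.aux_src_idx);           // first destructive write goes beside
            c.data_idx = c.aux_src_idx;  // the source; switch later work to it
            c.data_reg_updated = true;
        }
    }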
- // store f32 on full xmm/ymm/zmm to bf16/fp16, need convert to bf16/fp16 on vmm, then store vmm to memory, use store_dword_to_word/byte_base condition. - // store_num == 16, vector: 16 * f32 -> 16 * bf16 -> ymm(256bit) -> store - // store_num == 8, vector: 8 * f32 -> 8 * bf16 -> xmm(128bit) -> store - // store_num == 4, vector: 4 * f32 -> 4 * bf16 -> 64bit -> masked instruction with aux_gpr needed - // f32<->i32 is on full vmm, so aux_gpr is not needed. + // there are direct no-mask instructions to load i8/u8/i16/u16/bf16/fp16 to full xmm/ymm/zmm as f32/i32, so aux_gpr + // is not needed. there are direct no-mask instructions to store i32 on full xmm/ymm/zmm to i8/u8/i16/u16, so + // aux_gpr is not needed. store f32 on full xmm/ymm/zmm to bf16/fp16, need convert to bf16/fp16 on vmm, then store + // vmm to memory, use store_dword_to_word/byte_base condition. store_num == 16, vector: 16 * f32 -> 16 * bf16 -> + // ymm(256bit) -> store store_num == 8, vector: 8 * f32 -> 8 * bf16 -> xmm(128bit) -> store store_num == 4, + // vector: 4 * f32 -> 4 * bf16 -> 64bit -> masked instruction with aux_gpr needed f32<->i32 is on full vmm, + // so aux_gpr is not needed. const int byte_size = elem_count * data_size; - if ((is_pure_move && one_of(byte_size, 16, 32, 64)) || (!is_pure_move && one_of(elem_count, 4, 8, 16) && !is_store_as_real16)) + if ((is_pure_move && one_of(byte_size, 16, 32, 64)) || + (!is_pure_move && one_of(elem_count, 4, 8, 16) && !is_store_as_real16)) return 0; - if ((mayiuse(cpu::x64::avx512_core) && (byte_size > avx512_threshold_for_mask)) || (one_of(byte_size % 16, 1, 2, 3))) + if ((mayiuse(cpu::x64::avx512_core) && (byte_size > avx512_threshold_for_mask)) || + (one_of(byte_size % 16, 1, 2, 3))) return 1; return 0; } /// LOAD /// -jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, ov::element::Type exec_prc, - bool is_fill, std::string fill_value, emitter_in_out_map in_out_type) -: jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), load_num_(load_num), src_prc_(src_prc), - dst_prc_(dst_prc), is_fill_(is_fill), fill_value_(fill_value) { +jit_load_emitter::jit_load_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type src_prc, + ov::element::Type dst_prc, + int load_num, + ov::element::Type exec_prc, + bool is_fill, + std::string fill_value, + emitter_in_out_map in_out_type) + : jit_emitter(host, host_isa, exec_prc, in_out_type), + name_("unknown"), + load_num_(load_num), + src_prc_(src_prc), + dst_prc_(dst_prc), + is_fill_(is_fill), + fill_value_(fill_value) { prepare_table(); load_size_ = load_num * src_prc.size(); v_len_elt_ = get_vec_length() / exec_prc.size(); } -size_t jit_load_emitter::get_inputs_num() const { return 1; } +size_t jit_load_emitter::get_inputs_num() const { + return 1; +} size_t jit_load_emitter::aux_gprs_count() const { // 0 for temp reg for mask load in avx512 if needed - const auto is_pure_load = (src_prc_ == dst_prc_) || - (one_of(src_prc_, ov::element::f32, ov::element::i32) && - one_of(dst_prc_, ov::element::f32, ov::element::i32)); - int count = get_aux_regs_as_temp(load_num_, static_cast(src_prc_.size()), is_pure_load, false, threshold_for_mask_emu_load, is_fill_); + const auto is_pure_load = (src_prc_ == dst_prc_) || (one_of(src_prc_, ov::element::f32, ov::element::i32) && + one_of(dst_prc_, ov::element::f32, ov::element::i32)); + int count 
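Reviewer note: a few cases of get_aux_regs_as_temp() traced by hand through the branches shown above, assuming avx512_core is available and threshold_for_mask_emu_load == 14:

    // load 16 x f32, pure move : byte_size 64 -> full-zmm move, 0 aux GPRs
    // load  3 x i8 + convert   : byte_size 3, 3 % 16 in {1, 2, 3}
    //                            -> 1 aux GPR (movzx/shl compose the tail bytes)
    // load 20 bytes, pure move : 20 not in {16, 32, 64} and 20 > 14
    //                            -> 1 aux GPR to build the k-mask for vmovdqu8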
= get_aux_regs_as_temp(load_num_, + static_cast(src_prc_.size()), + is_pure_load, + false, + threshold_for_mask_emu_load, + is_fill_); // 1 for table address if (is_fill_) @@ -109,7 +134,7 @@ size_t jit_load_emitter::aux_gprs_count() const { return count; } -void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_load_emitter::emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const { // offset in load emitter is the offset of src gpr register, should be parsed from in_idxs. const int offset = in_idxs.size() == 2 ? in_idxs[1] : 0; if (host_isa_ == cpu::x64::sse41) { @@ -124,7 +149,7 @@ void jit_load_emitter::emit_impl(const std::vector &in_idxs, const std:: } template -void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_idx, const int offset) const { +void jit_load_emitter::emit_isa(const Xbyak::Reg64& reg_src, const int out_vec_idx, const int offset) const { bool matched_prc = (dst_prc_ == src_prc_) || (dst_prc_ == ov::element::f32) || (dst_prc_ == ov::element::i32); if (!matched_prc) { OV_CPU_JIT_EMITTER_THROW("only support output precision of FP32 or I32 or the same precision as input."); @@ -139,43 +164,43 @@ void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_i if (src_prc_ == dst_prc_) { load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); } else { - // "pure load" + convert. dst_prc must be FP32 or I32. + // "pure load" + convert. dst_prc must be FP32 or I32. switch (src_prc_) { - case ov::element::f32: - case ov::element::i32: - load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); - break; - case ov::element::i8: - load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, true, load_size_); - break; - case ov::element::u8: - load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, load_size_); - break; - case ov::element::i16: - case ov::element::u16: - case ov::element::bf16: - case ov::element::f16: - load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, src_prc_, load_size_); - break; - default: - OV_CPU_JIT_EMITTER_THROW("has unsupported src precision to load."); + case ov::element::f32: + case ov::element::i32: + load_bytes(Vmm(out_vec_idx), reg_src, offset, load_size_); + break; + case ov::element::i8: + load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, true, load_size_); + break; + case ov::element::u8: + load_bytes_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, false, load_size_); + break; + case ov::element::i16: + case ov::element::u16: + case ov::element::bf16: + case ov::element::f16: + load_words_to_dword_extension(Vmm(out_vec_idx), reg_src, offset, src_prc_, load_size_); + break; + default: + OV_CPU_JIT_EMITTER_THROW("has unsupported src precision to load."); } } // post convert between I32 and FP32 if (src_prc_ != dst_prc_) { switch (dst_prc_) { - case ov::element::f32: - if (!src_prc_.is_real()) - h->uni_vcvtdq2ps(Vmm(out_vec_idx), Vmm(out_vec_idx)); - break; - case ov::element::i32: - if (src_prc_.is_real()) { - h->uni_vcvtps2dq(Vmm(out_vec_idx), Vmm(out_vec_idx)); - } - break; - default: - break; + case ov::element::f32: + if (!src_prc_.is_real()) + h->uni_vcvtdq2ps(Vmm(out_vec_idx), Vmm(out_vec_idx)); + break; + case ov::element::i32: + if (src_prc_.is_real()) { + h->uni_vcvtps2dq(Vmm(out_vec_idx), Vmm(out_vec_idx)); + } + break; + default: + break; } } @@ -186,19 +211,19 @@ void jit_load_emitter::emit_isa(const Xbyak::Reg64 ®_src, const int out_vec_i } /** -* load_bytes is the 
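Reviewer note: the emit_isa() load path above is two-stage: first widen the source to 32-bit lanes, then convert between integer and float domains only when src_prc_ and dst_prc_ disagree. A one-element scalar reference of the i8 -> f32 case:

    #include <cstdint>

    float load_one_i8_as_f32(const int8_t* src) {
        const int32_t widened = static_cast<int32_t>(*src);  // stage 1: sign-extend
                                                             // (load_bytes_to_dword_extension)
        return static_cast<float>(widened);                  // stage 2: post-convert
                                                             // (uni_vcvtdq2ps)
    }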
utility function to facilitate loading of -* load_size (0 <= load_size <= 64) many contiguous bytes into the Xmm/Ymm/Zmm -* register from the memory referenced by ptr[reg + offset] address. -* -* Functionally, invocation of load_bytes is equivalent to -* the following loop: -* -* for (int idx = 0; idx < load_size; ++idx) -* vpinsrb(vmm, vmm, ptr[reg + offset + idx], idx); -* -*/ + * load_bytes is the utility function to facilitate loading of + * load_size (0 <= load_size <= 64) many contiguous bytes into the Xmm/Ymm/Zmm + * register from the memory referenced by ptr[reg + offset] address. + * + * Functionally, invocation of load_bytes is equivalent to + * the following loop: + * + * for (int idx = 0; idx < load_size; ++idx) + * vpinsrb(vmm, vmm, ptr[reg + offset + idx], idx); + * + */ template -void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, int load_size) const { +void jit_load_emitter::load_bytes(const Vmm& vmm, const Xbyak::Reg64& reg, int offset, int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -249,14 +274,17 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o } // Cornerstone of partial load is combinaion of vpinsrb/w/d. - // As vpinsrb/w/d will not only write(insert) values into vmm, but also read values in vmm to copy from to positions that not in imm mask, - // this could introduce RAW false dependency(we actually do not care about values not in imm mask). - // To eliminate this false dependency, + // As vpinsrb/w/d will not only write(insert) values into vmm, but also read values in vmm to copy from to + // positions that not in imm mask, this could introduce RAW false dependency(we actually do not care about + // values not in imm mask). To eliminate this false dependency, // 1. For 1/2/3/4 bytes tails, replace vpinsrb/w/d with mov,shl etc instructions that don't read vmm. - // Besides eliminate RAW, these instructions have smaller latency, which also bring better perf, especially for small loop iteration case. + // Besides eliminate RAW, these instructions have smaller latency, which also bring better perf, especially + // for small loop iteration case. // 2. For 8/16 bytes, use vmovq/vmovdqu instructions to load, which also don't read src vmm. - // 3. For other size, insert vpxor before vpinsrb/w/d. vpxor and read vmm instructions in previous loop have WAR(write after read) relationship. - // CPU can identify this scenario and assign another physical vector register(register renameing) in next loop to eliminate RAW. + // 3. For other size, insert vpxor before vpinsrb/w/d. vpxor and read vmm instructions in previous loop have + // WAR(write after read) relationship. + // CPU can identify this scenario and assign another physical vector register(register renameing) in next + // loop to eliminate RAW. 
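Reviewer note: as the doc comment above says, load_bytes is byte-for-byte equivalent to inserting load_size contiguous bytes into the destination register; the mov/shl, vmovq/vmovdqu, and vpxor tricks only exist to avoid the read-after-write dependency that vpinsrb/w/d introduce by reading the destination. A scalar model treating the register as a byte image:

    #include <cstdint>
    #include <cstring>

    static void load_bytes_ref(uint8_t* vmm_image, std::size_t vmm_bytes,
                               const uint8_t* mem, int load_size, bool zero_tail) {
        std::memcpy(vmm_image, mem, static_cast<std::size_t>(load_size));
        if (zero_tail)  // models the avx512 masked variant (vmovdqu8 zmm|k_mask|T_z);
                        // the generic path leaves the upper bytes unspecified
            std::memset(vmm_image + load_size, 0, vmm_bytes - load_size);
    }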
if (!one_of(bytes_to_load, 0, 1, 2, 3, 4, 8, 16)) { h->uni_vpxor(vmm, vmm, vmm); } @@ -266,121 +294,136 @@ void jit_load_emitter::load_bytes(const Vmm &vmm, const Xbyak::Reg64 ®, int o h->uni_vmovdqu(xmm, addr(start_bytes)); switch (bytes_to_load) { - case 0: break; - case 1: - h->movzx(Reg32(aux_gpr_idxs[0]), addr(start_bytes)); - h->uni_vmovq(xmm, Reg64(aux_gpr_idxs[0])); - break; - case 2: - h->movzx(Reg32(aux_gpr_idxs[0]), word_addr(start_bytes)); - h->uni_vmovq(xmm, Reg64(aux_gpr_idxs[0])); - break; - case 3: - h->movzx(Reg32(aux_gpr_idxs[0]), addr(start_bytes + 2)); - h->shl(Reg32(aux_gpr_idxs[0]), 16); - h->mov(Reg16(aux_gpr_idxs[0]), word_addr(start_bytes)); - h->uni_vmovq(xmm, Reg64(aux_gpr_idxs[0])); - break; - case 4: h->uni_vmovss(xmm, addr(start_bytes)); break; - case 5: - h->uni_vmovss(xmm, addr(start_bytes)); - h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 4), 4); - break; - case 6: - h->uni_vmovss(xmm, addr(start_bytes)); - h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2); - break; - case 7: - h->uni_vmovss(xmm, addr(start_bytes)); - h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2); - h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 6), 6); - break; - case 8: break; - case 9: h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 8), 8); break; - case 10: h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4); break; - case 11: - h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4); - h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 10), 10); - break; - case 12: h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); break; - case 13: - h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); - h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 12), 12); - break; - case 14: - h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); - h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6); - break; - case 15: - h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); - h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6); - h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 14), 14); - break; - case 16: break; - default: - OV_CPU_JIT_EMITTER_THROW("has unexpected number of values to load in load_byte."); + case 0: + break; + case 1: + h->movzx(Reg32(aux_gpr_idxs[0]), addr(start_bytes)); + h->uni_vmovq(xmm, Reg64(aux_gpr_idxs[0])); + break; + case 2: + h->movzx(Reg32(aux_gpr_idxs[0]), word_addr(start_bytes)); + h->uni_vmovq(xmm, Reg64(aux_gpr_idxs[0])); + break; + case 3: + h->movzx(Reg32(aux_gpr_idxs[0]), addr(start_bytes + 2)); + h->shl(Reg32(aux_gpr_idxs[0]), 16); + h->mov(Reg16(aux_gpr_idxs[0]), word_addr(start_bytes)); + h->uni_vmovq(xmm, Reg64(aux_gpr_idxs[0])); + break; + case 4: + h->uni_vmovss(xmm, addr(start_bytes)); + break; + case 5: + h->uni_vmovss(xmm, addr(start_bytes)); + h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 4), 4); + break; + case 6: + h->uni_vmovss(xmm, addr(start_bytes)); + h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2); + break; + case 7: + h->uni_vmovss(xmm, addr(start_bytes)); + h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 4), 2); + h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 6), 6); + break; + case 8: + break; + case 9: + h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 8), 8); + break; + case 10: + h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4); + break; + case 11: + h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 8), 4); + h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 10), 10); + break; + case 12: + h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); + break; + case 13: + h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); + h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 12), 12); + break; + case 14: + 
h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); + h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6); + break; + case 15: + h->uni_vpinsrd(xmm, xmm, addr(start_bytes + 8), 2); + h->uni_vpinsrw(xmm, xmm, addr(start_bytes + 12), 6); + h->uni_vpinsrb(xmm, xmm, addr(start_bytes + 14), 14); + break; + case 16: + break; + default: + OV_CPU_JIT_EMITTER_THROW("has unexpected number of values to load in load_byte."); } if (has_xmm_block) { - h->vinsertf128(ymm, ymm, xmm, 1); // insert to upper bits of ymm + h->vinsertf128(ymm, ymm, xmm, 1); // insert to upper bits of ymm if (has_ymm_block) - h->vinsertf128(ymm, ymm, addr(32), 0); // insert to lower bits of ymm + h->vinsertf128(ymm, ymm, addr(32), 0); // insert to lower bits of ymm else - h->vinsertf128(ymm, ymm, addr(0), 0); // insert to lower bits of ymm + h->vinsertf128(ymm, ymm, addr(0), 0); // insert to lower bits of ymm } if (has_ymm_block) { - h->vinsertf64x4(zmm, zmm, ymm, 1); // insert to upper bits of zmm - h->vinsertf64x4(zmm, zmm, addr(0), 0); // insert to lower bits of zmm + h->vinsertf64x4(zmm, zmm, ymm, 1); // insert to upper bits of zmm + h->vinsertf64x4(zmm, zmm, addr(0), 0); // insert to lower bits of zmm } }; switch (load_size) { - case 64: - h->uni_vmovdqu(zmm, addr(0)); - break; - case 32: - h->uni_vmovdqu(ymm, addr(0)); - break; - case 16: - h->uni_vmovdqu(xmm, addr(0)); - break; - default: { - if (mayiuse(cpu::x64::avx512_core) && load_size > threshold_for_mask_emu_load) { - uint64_t mask = 1; - mask = (mask << load_size) - mask; - h->mov(Reg64(aux_gpr_idxs[0]), mask); - h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); - h->vmovdqu8(zmm | k_mask | T_z, addr(0)); - } else { - load_byte_base(); - } - break; + case 64: + h->uni_vmovdqu(zmm, addr(0)); + break; + case 32: + h->uni_vmovdqu(ymm, addr(0)); + break; + case 16: + h->uni_vmovdqu(xmm, addr(0)); + break; + default: { + if (mayiuse(cpu::x64::avx512_core) && load_size > threshold_for_mask_emu_load) { + uint64_t mask = 1; + mask = (mask << load_size) - mask; + h->mov(Reg64(aux_gpr_idxs[0]), mask); + h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); + h->vmovdqu8(zmm | k_mask | T_z, addr(0)); + } else { + load_byte_base(); } + break; + } } } /** -* load_bytes_to_dword_extension is the utility function to facilitate -* loading of load_size (0 <= load_size <= 16) many contiguous bytes in -* the xmm register from the memory referenced by ptr[reg + offset] -* address and then do signed/zero extension of those to double words. -* -* Functionally, invocation of load_bytes_to_dword_extension is equivalent -* to the following: -* -* for (int idx = 0; idx < load_size; ++idx) -* vpinsrb(vmm, vmm, ptr[reg + offset + idx], idx); -* if (is_signed) vpmovsxbd(vmm, vmm); else vpmovzxbd(vmm, vmm); -* -* Valid values for the load_size variable are: -* [0..4] for XMM version of the function, i.e. 4 bytes -> 4 * 32 bit == 128 bit -* [0..8] for YMM version of the function. i.e. 8 bytes -> 8 * 32 bit == 256 bit -* [0..16] for ZMM version of the function. i.e. 16 bytes -> 16 * 32 bit == 512 bit -*/ + * load_bytes_to_dword_extension is the utility function to facilitate + * loading of load_size (0 <= load_size <= 16) many contiguous bytes in + * the xmm register from the memory referenced by ptr[reg + offset] + * address and then do signed/zero extension of those to double words. 
+ * + * Functionally, invocation of load_bytes_to_dword_extension is equivalent + * to the following: + * + * for (int idx = 0; idx < load_size; ++idx) + * vpinsrb(vmm, vmm, ptr[reg + offset + idx], idx); + * if (is_signed) vpmovsxbd(vmm, vmm); else vpmovzxbd(vmm, vmm); + * + * Valid values for the load_size variable are: + * [0..4] for XMM version of the function, i.e. 4 bytes -> 4 * 32 bit == 128 bit + * [0..8] for YMM version of the function. i.e. 8 bytes -> 8 * 32 bit == 256 bit + * [0..16] for ZMM version of the function. i.e. 16 bytes -> 16 * 32 bit == 512 bit + */ template -void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, bool is_signed, int load_size) const { +void jit_load_emitter::load_bytes_to_dword_extension(const Vmm& vmm, + const Xbyak::Reg64& reg, + int offset, + bool is_signed, + int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -401,76 +444,80 @@ void jit_load_emitter::load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak // For load_size == 4/8/16, do load/extension in one go switch (load_size) { - case 16: { - // full size of zmm - const auto zmm = Xbyak::Zmm(vmm.getIdx()); - if (is_signed) - h->uni_vpmovsxbd(zmm, ptr[reg + offset]); - else - h->uni_vpmovzxbd(zmm, ptr[reg + offset]); - break; - } - case 8: { - // full size of ymm or ymm_block of zmm - const auto ymm = Xbyak::Ymm(vmm.getIdx()); + case 16: { + // full size of zmm + const auto zmm = Xbyak::Zmm(vmm.getIdx()); + if (is_signed) + h->uni_vpmovsxbd(zmm, ptr[reg + offset]); + else + h->uni_vpmovzxbd(zmm, ptr[reg + offset]); + break; + } + case 8: { + // full size of ymm or ymm_block of zmm + const auto ymm = Xbyak::Ymm(vmm.getIdx()); + if (is_signed) + h->uni_vpmovsxbd(ymm, ptr[reg + offset]); + else + h->uni_vpmovzxbd(ymm, ptr[reg + offset]); + break; + } + case 4: { + // full size of xmm or xmm_block of ymm/zmm + const auto xmm = Xbyak::Xmm(vmm.getIdx()); + if (is_signed) + h->uni_vpmovsxbd(xmm, ptr[reg + offset]); + else + h->uni_vpmovzxbd(xmm, ptr[reg + offset]); + break; + } + default: { + if (is_zmm && load_size > threshold_for_mask_emu_load) { + unsigned int mask = 1; + mask = (mask << load_size) - mask; + h->mov(Reg32(aux_gpr_idxs[0]), mask); + h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); if (is_signed) - h->uni_vpmovsxbd(ymm, ptr[reg + offset]); + h->uni_vpmovsxbd(vmm | k_mask | T_z, ptr[reg + offset]); else - h->uni_vpmovzxbd(ymm, ptr[reg + offset]); - break; - } - case 4: { - // full size of xmm or xmm_block of ymm/zmm + h->uni_vpmovzxbd(vmm | k_mask | T_z, ptr[reg + offset]); + } else { const auto xmm = Xbyak::Xmm(vmm.getIdx()); + load_bytes(xmm, reg, offset, load_size); if (is_signed) - h->uni_vpmovsxbd(xmm, ptr[reg + offset]); + h->uni_vpmovsxbd(vmm, xmm); else - h->uni_vpmovzxbd(xmm, ptr[reg + offset]); - break; - } - default: { - if (is_zmm && load_size > threshold_for_mask_emu_load) { - unsigned int mask = 1; - mask = (mask << load_size) - mask; - h->mov(Reg32(aux_gpr_idxs[0]), mask); - h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); - if (is_signed) - h->uni_vpmovsxbd(vmm | k_mask | T_z, ptr[reg + offset]); - else - h->uni_vpmovzxbd(vmm | k_mask | T_z, ptr[reg + offset]); - } else { - const auto xmm = Xbyak::Xmm(vmm.getIdx()); - load_bytes(xmm, reg, offset, load_size); - if (is_signed) - h->uni_vpmovsxbd(vmm, xmm); - else - h->uni_vpmovzxbd(vmm, xmm); - } - break; + h->uni_vpmovzxbd(vmm, xmm); } + break; + } } } /** -* 
load_words_to_dword_extension is the utility function to facilitate
-* loading of load_size (0 <= load_size <= 32) byte many contiguous words(num == load_size / 2)
-* in the Vmm register from the memory referenced by ptr[reg + offset]
-* address and then do signed/zero extension of those to double words.
-*
-* Functionally, invocation of load_words_to_dword_extension is equivalent
-* to the following extended pseudo code:
-*
-* for (int idx = 0; idx < load_size / 2; ++idx)
-* vpinsrw(vmm, vmm, ptr[reg + offset + 2 * idx], idx);
-* if (is_signed) vpmovsxwd(vmm, vmm); else vpmovzxwd(vmm, vmm);
-*
-* Valid values for the load_size variable are:
-* [0..8] for XMM version of the function. i.e. 4 words -> 4 * 32 bit == 128 bit
-* [0..16] for YMM version of the function. i.e. 8 words -> 8 * 32 bit == 256 bit
-* [0.. 32] for ZMM version of the function. i.e. 16 words -> 16 * 32 bit == 512 bit
-*/
+ * load_words_to_dword_extension is the utility function to facilitate
+ * loading of load_size (0 <= load_size <= 32) bytes, i.e. load_size / 2 contiguous words,
+ * into the Vmm register from the memory referenced by ptr[reg + offset]
+ * address and then do signed/zero extension of those to double words.
+ *
+ * Functionally, invocation of load_words_to_dword_extension is equivalent
+ * to the following extended pseudo code:
+ *
+ * for (int idx = 0; idx < load_size / 2; ++idx)
+ *     vpinsrw(vmm, vmm, ptr[reg + offset + 2 * idx], idx);
+ * if (is_signed) vpmovsxwd(vmm, vmm); else vpmovzxwd(vmm, vmm);
+ *
+ * Valid values for the load_size variable are:
+ * [0..8] for the XMM version of the function, i.e. 4 words -> 4 * 32 bit == 128 bit
+ * [0..16] for the YMM version of the function, i.e. 8 words -> 8 * 32 bit == 256 bit
+ * [0..32] for the ZMM version of the function, i.e.
16 words -> 16 * 32 bit == 512 bit + */ template -void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 ®, int offset, ov::element::Type prc, int load_size) const { +void jit_load_emitter::load_words_to_dword_extension(const Vmm& vmm, + const Xbyak::Reg64& reg, + int offset, + ov::element::Type prc, + int load_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -503,87 +550,87 @@ void jit_load_emitter::load_words_to_dword_extension(const Vmm &vmm, const Xbyak // For load_size == 32/16/8, do load/extension in one go // including xmm/ymm tail block for ymm/zmm, so explicite xmm/ymm/zmm switch (load_size) { - case 32: { - if (is_bf16) { + case 32: { + if (is_bf16) { + h->uni_vpmovzxwd(zmm, ptr[reg + offset]); + h->uni_vpslld(zmm, zmm, 16); + } else if (is_f16) { + h->vcvtph2ps(zmm, ptr[reg + offset]); + } else { + if (is_signed) + h->uni_vpmovsxwd(zmm, ptr[reg + offset]); + else h->uni_vpmovzxwd(zmm, ptr[reg + offset]); - h->uni_vpslld(zmm, zmm, 16); - } else if (is_f16) { - h->vcvtph2ps(zmm, ptr[reg + offset]); - } else { - if (is_signed) - h->uni_vpmovsxwd(zmm, ptr[reg + offset]); - else - h->uni_vpmovzxwd(zmm, ptr[reg + offset]); - } - break; } - case 16: { - if (is_bf16) { + break; + } + case 16: { + if (is_bf16) { + h->uni_vpmovzxwd(ymm, ptr[reg + offset]); + h->uni_vpslld(ymm, ymm, 16); + } else if (is_f16) { + h->vcvtph2ps(ymm, ptr[reg + offset]); + } else { + if (is_signed) + h->uni_vpmovsxwd(ymm, ptr[reg + offset]); + else h->uni_vpmovzxwd(ymm, ptr[reg + offset]); - h->uni_vpslld(ymm, ymm, 16); + } + break; + } + case 8: { + if (is_bf16) { + h->uni_vpmovzxwd(xmm, ptr[reg + offset]); + h->uni_vpslld(xmm, xmm, 16); + } else if (is_f16) { + h->vcvtph2ps(xmm, ptr[reg + offset]); + } else { + if (is_signed) + h->uni_vpmovsxwd(xmm, ptr[reg + offset]); + else + h->uni_vpmovzxwd(xmm, ptr[reg + offset]); + } + break; + } + default: { + if (is_zmm && load_size > threshold_for_mask_emu_load) { + unsigned int mask = 1; + mask = (mask << (load_size / 2)) - mask; + h->mov(Reg32(aux_gpr_idxs[0]), mask); + h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); + if (is_bf16) { + h->uni_vpmovzxwd(vmm | k_mask | T_z, ptr[reg + offset]); + h->uni_vpslld(vmm, vmm, 16); } else if (is_f16) { - h->vcvtph2ps(ymm, ptr[reg + offset]); + h->vcvtph2ps(vmm | k_mask | T_z, ptr[reg + offset]); } else { if (is_signed) - h->uni_vpmovsxwd(ymm, ptr[reg + offset]); + h->uni_vpmovsxwd(vmm | k_mask | T_z, ptr[reg + offset]); else - h->uni_vpmovzxwd(ymm, ptr[reg + offset]); + h->uni_vpmovzxwd(vmm | k_mask | T_z, ptr[reg + offset]); } - break; - } - case 8: { + } else { + // xmm or ymm version + load_bytes(xmm, reg, offset, load_size); if (is_bf16) { - h->uni_vpmovzxwd(xmm, ptr[reg + offset]); - h->uni_vpslld(xmm, xmm, 16); + h->uni_vpmovzxwd(vmm, xmm); + h->uni_vpslld(vmm, vmm, 16); } else if (is_f16) { - h->vcvtph2ps(xmm, ptr[reg + offset]); + h->vcvtph2ps(ymm, xmm); } else { if (is_signed) - h->uni_vpmovsxwd(xmm, ptr[reg + offset]); + h->uni_vpmovsxwd(vmm, xmm); else - h->uni_vpmovzxwd(xmm, ptr[reg + offset]); - } - break; - } - default: { - if (is_zmm && load_size > threshold_for_mask_emu_load) { - unsigned int mask = 1; - mask = (mask << (load_size / 2)) - mask; - h->mov(Reg32(aux_gpr_idxs[0]), mask); - h->kmovw(k_mask, Reg32(aux_gpr_idxs[0])); - if (is_bf16) { - h->uni_vpmovzxwd(vmm | k_mask | T_z, ptr[reg + offset]); - h->uni_vpslld(vmm, vmm, 16); - } else if (is_f16) { - h->vcvtph2ps(vmm | k_mask 
| T_z, ptr[reg + offset]); - } else { - if (is_signed) - h->uni_vpmovsxwd(vmm | k_mask | T_z, ptr[reg + offset]); - else - h->uni_vpmovzxwd(vmm | k_mask | T_z, ptr[reg + offset]); - } - } else { - // xmm or ymm version - load_bytes(xmm, reg, offset, load_size); - if (is_bf16) { h->uni_vpmovzxwd(vmm, xmm); - h->uni_vpslld(vmm, vmm, 16); - } else if (is_f16) { - h->vcvtph2ps(ymm, xmm); - } else { - if (is_signed) - h->uni_vpmovsxwd(vmm, xmm); - else - h->uni_vpmovzxwd(vmm, xmm); - } } - break; } + break; + } } } template -void jit_load_emitter::fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const { +void jit_load_emitter::fill_with_default(const Vmm& vmm, std::string fill_value, const int& load_num) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -614,10 +661,20 @@ void jit_load_emitter::register_table_entries() { } /// STORE /// -jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int store_num, arithmetic_mode mode, ov::element::Type exec_prc, +jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + ov::element::Type src_prc, + ov::element::Type dst_prc, + int store_num, + arithmetic_mode mode, + ov::element::Type exec_prc, emitter_in_out_map in_out_type) - : jit_emitter(host, host_isa, exec_prc, in_out_type), name_("unknown"), store_num_(store_num), src_prc_(src_prc), dst_prc_(dst_prc), mode_(mode) { + : jit_emitter(host, host_isa, exec_prc, in_out_type), + name_("unknown"), + store_num_(store_num), + src_prc_(src_prc), + dst_prc_(dst_prc), + mode_(mode) { prepare_table(); v_len_elt_ = get_vec_length() / exec_prc.size(); store_size_ = store_num * dst_prc.size(); @@ -630,17 +687,20 @@ inline bool jit_store_emitter::is_saturation() const { // case for SSE and AVX2 when we should use AND to truncate values inline bool jit_store_emitter::is_truncation_emulation() const { - return !mayiuse(cpu::x64::avx512_core) && !is_saturation() && - src_prc_ != dst_prc_ && one_of(dst_prc_, ov::element::u16, ov::element::i16, ov::element::u8, ov::element::i8); + return !mayiuse(cpu::x64::avx512_core) && !is_saturation() && src_prc_ != dst_prc_ && + one_of(dst_prc_, ov::element::u16, ov::element::i16, ov::element::u8, ov::element::i8); } size_t jit_store_emitter::aux_gprs_count() const { // for temp reg for store(mask version or special number cases) - const auto is_pure_store = (src_prc_ == dst_prc_) || - (one_of(src_prc_, ov::element::f32, ov::element::i32) && - one_of(dst_prc_, ov::element::f32, ov::element::i32)); + const auto is_pure_store = (src_prc_ == dst_prc_) || (one_of(src_prc_, ov::element::f32, ov::element::i32) && + one_of(dst_prc_, ov::element::f32, ov::element::i32)); const auto is_store_as_real16 = one_of(dst_prc_, ov::element::bf16, ov::element::f16); - int count = get_aux_regs_as_temp(store_num_, static_cast(dst_prc_.size()), is_pure_store, is_store_as_real16, threshold_for_mask_emu_store); + int count = get_aux_regs_as_temp(store_num_, + static_cast(dst_prc_.size()), + is_pure_store, + is_store_as_real16, + threshold_for_mask_emu_store); // for table value in truncation arithmetic mode if (is_truncation_emulation()) @@ -661,14 +721,17 @@ size_t jit_store_emitter::aux_vecs_count() const { if ((host_isa_ == cpu::x64::sse41) && (src_prc_ == ov::element::f32 && dst_prc_ == 
ov::element::bf16)) count++; - // zero value, zeroed and passed from caller from performance standpoint(zeroed one time and not need preserve and restore status) + // zero value, zeroed and passed from caller from performance standpoint(zeroed one time and not need preserve and + // restore status) if (mayiuse(cpu::x64::avx512_core) && one_of(dst_prc_, ov::element::u8, ov::element::u16)) count++; return count; } -size_t jit_store_emitter::get_inputs_num() const { return 1; } +size_t jit_store_emitter::get_inputs_num() const { + return 1; +} void jit_store_emitter::emit_data() const { jit_emitter::emit_data(); @@ -676,7 +739,7 @@ void jit_store_emitter::emit_data() const { uni_vcvtneps2bf16_->emit_data(); } -void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_store_emitter::emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const { // offset in store emitter is the offset of dst gpr register, should be parsed from out_idxs. const int offset = out_idxs.size() == 2 ? out_idxs[1] : 0; if (host_isa_ == cpu::x64::sse41) { @@ -691,7 +754,7 @@ void jit_store_emitter::emit_impl(const std::vector &in_idxs, const std: } template -void jit_store_emitter::emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_dst, const int offset) const { +void jit_store_emitter::emit_isa(const int in_vec_idx, const Xbyak::Reg64& reg_dst, const int offset) const { bool matched_prc = (src_prc_ == dst_prc_) || (src_prc_ == ov::element::f32) || (src_prc_ == ov::element::i32); if (!matched_prc) { OV_CPU_JIT_EMITTER_THROW("only support input precision of FP32 or I32 or the same precision as output."); @@ -707,29 +770,29 @@ void jit_store_emitter::emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_d data_idx = in_vec_idx; data_reg_updated = false; if (!aux_vec_idxs.empty()) - aux_src_idx = aux_vec_idxs.back(); // to avoid src pollution + aux_src_idx = aux_vec_idxs.back(); // to avoid src pollution if (src_prc_ != dst_prc_) { switch (src_prc_) { - case ov::element::f32: - if (!dst_prc_.is_real()) { - if (is_saturation()) { - h->uni_vcvtps2dq(Vmm(aux_src_idx), Vmm(data_idx)); - } else { - h->uni_vcvttps2dq(Vmm(aux_src_idx), Vmm(data_idx)); - } - data_idx = aux_src_idx; - data_reg_updated = true; - } - break; - case ov::element::i32: - if (dst_prc_.is_real()) { - h->uni_vcvtdq2ps(Vmm(aux_src_idx), Vmm(data_idx)); - data_idx = aux_src_idx; - data_reg_updated = true; + case ov::element::f32: + if (!dst_prc_.is_real()) { + if (is_saturation()) { + h->uni_vcvtps2dq(Vmm(aux_src_idx), Vmm(data_idx)); + } else { + h->uni_vcvttps2dq(Vmm(aux_src_idx), Vmm(data_idx)); } - break; - default: - break; + data_idx = aux_src_idx; + data_reg_updated = true; + } + break; + case ov::element::i32: + if (dst_prc_.is_real()) { + h->uni_vcvtdq2ps(Vmm(aux_src_idx), Vmm(data_idx)); + data_idx = aux_src_idx; + data_reg_updated = true; + } + break; + default: + break; } } @@ -737,44 +800,44 @@ void jit_store_emitter::emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_d store_bytes(reg_dst, offset, store_size_); } else { switch (dst_prc_) { - case ov::element::f32: - case ov::element::i32: - store_bytes(reg_dst, offset, store_size_); - break; - case ov::element::i8: - store_dword_to_byte_extension(reg_dst, offset, true, store_num_); - break; - case ov::element::u8: - store_dword_to_byte_extension(reg_dst, offset, false, store_num_); - break; - case ov::element::i16: - case ov::element::u16: - case ov::element::bf16: - case ov::element::f16: - store_dword_to_word_extension(reg_dst, 
offset, dst_prc_, store_num_); - break; - default: - OV_CPU_JIT_EMITTER_THROW("has unsupported dst precision to store."); + case ov::element::f32: + case ov::element::i32: + store_bytes(reg_dst, offset, store_size_); + break; + case ov::element::i8: + store_dword_to_byte_extension(reg_dst, offset, true, store_num_); + break; + case ov::element::u8: + store_dword_to_byte_extension(reg_dst, offset, false, store_num_); + break; + case ov::element::i16: + case ov::element::u16: + case ov::element::bf16: + case ov::element::f16: + store_dword_to_word_extension(reg_dst, offset, dst_prc_, store_num_); + break; + default: + OV_CPU_JIT_EMITTER_THROW("has unsupported dst precision to store."); } } } /** -* store_bytes is the utility function to facilitate storing of -* store_size (0 <= store_size <= 64) many contiguous bytes from the Xmm/Ymm/Zmm -* register into the memory referenced by ptr[reg + offset] address. -* -* Additionally, when store_size > 16, the input Ymm register will not be -* preserved due to the usage of vextracti128 instruction. -* -* Functionally, invocation of store_bytes is equivalent -* to the following loop: -* -* for (int idx = 0; idx < store_size; ++idx) -* vpextrb(ptr[reg + offset + idx], vmm, idx); -* -*/ + * store_bytes is the utility function to facilitate storing of + * store_size (0 <= store_size <= 64) many contiguous bytes from the Xmm/Ymm/Zmm + * register into the memory referenced by ptr[reg + offset] address. + * + * Additionally, when store_size > 16, the input Ymm register will not be + * preserved due to the usage of vextracti128 instruction. + * + * Functionally, invocation of store_bytes is equivalent + * to the following loop: + * + * for (int idx = 0; idx < store_size; ++idx) + * vpextrb(ptr[reg + offset + idx], vmm, idx); + * + */ template -void jit_store_emitter::store_bytes(const Xbyak::Reg64 ®, int offset, int store_size) const { +void jit_store_emitter::store_bytes(const Xbyak::Reg64& reg, int offset, int store_size) const { constexpr bool is_xmm = std::is_same::value; constexpr bool is_ymm = std::is_same::value; constexpr bool is_zmm = std::is_same::value; @@ -805,7 +868,7 @@ void jit_store_emitter::store_bytes(const Xbyak::Reg64 ®, int offset, int sto int bytes_to_store = store_size; if (store_size > 32) { - h->uni_vmovdqu(addr(0), ymm); // store lower bits from zmm + h->uni_vmovdqu(addr(0), ymm); // store lower bits from zmm start_bytes += 32; bytes_to_store -= 32; // load upper bits from zmm into ymm @@ -813,7 +876,7 @@ void jit_store_emitter::store_bytes(const Xbyak::Reg64 ®, int offset, int sto } if (bytes_to_store > 16) { - h->uni_vmovdqu(addr(start_bytes), xmm); // store lower bits from ymm + h->uni_vmovdqu(addr(start_bytes), xmm); // store lower bits from ymm start_bytes += 16; bytes_to_store -= 16; // load upper bits from ymm into xmm @@ -834,93 +897,108 @@ void jit_store_emitter::store_bytes(const Xbyak::Reg64 ®, int offset, int sto h->mov(addr(start_bytes + bytes_offset), Reg8(gpr_idx, ext8bit)); }; switch (bytes_to_store) { - case 0: break; - case 1: - h->uni_vmovq(Reg64(aux_gpr_idxs[0]), xmm); - store_one_byte(0, aux_gpr_idxs[0]); - break; - case 2: - h->uni_vmovq(Reg64(aux_gpr_idxs[0]), xmm); - h->mov(addr(start_bytes), Reg16(aux_gpr_idxs[0])); - break; - case 3: - h->uni_vmovq(Reg64(aux_gpr_idxs[0]), xmm); - h->mov(addr(start_bytes), Reg16(aux_gpr_idxs[0])); - h->shr(Reg64(aux_gpr_idxs[0]), 16); - store_one_byte(2, aux_gpr_idxs[0]); - break; - case 4: h->uni_vmovss(addr(start_bytes), xmm); break; - case 5: - 
h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrb(addr(start_bytes + 4), xmm, 4); - break; - case 6: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); - break; - case 7: - h->uni_vmovss(addr(start_bytes), xmm); - h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); - h->uni_vpextrb(addr(start_bytes + 6), xmm, 6); - break; - case 8: break; - case 9: h->uni_vpextrb(addr(start_bytes + 8), xmm, 8); break; - case 10: h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); break; - case 11: - h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); - h->uni_vpextrb(addr(start_bytes + 10), xmm, 10); - break; - case 12: h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); break; - case 13: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrb(addr(start_bytes + 12), xmm, 12); - break; - case 14: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); - break; - case 15: - h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); - h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); - h->uni_vpextrb(addr(start_bytes + 14), xmm, 14); - break; - case 16: break; - default: - OV_CPU_JIT_EMITTER_THROW("has unexpected number of values to store in store_bytes."); - } - }; - - switch (store_size) { - case 64: - h->uni_vmovdqu(addr(0), zmm); + case 0: + break; + case 1: + h->uni_vmovq(Reg64(aux_gpr_idxs[0]), xmm); + store_one_byte(0, aux_gpr_idxs[0]); + break; + case 2: + h->uni_vmovq(Reg64(aux_gpr_idxs[0]), xmm); + h->mov(addr(start_bytes), Reg16(aux_gpr_idxs[0])); + break; + case 3: + h->uni_vmovq(Reg64(aux_gpr_idxs[0]), xmm); + h->mov(addr(start_bytes), Reg16(aux_gpr_idxs[0])); + h->shr(Reg64(aux_gpr_idxs[0]), 16); + store_one_byte(2, aux_gpr_idxs[0]); + break; + case 4: + h->uni_vmovss(addr(start_bytes), xmm); + break; + case 5: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrb(addr(start_bytes + 4), xmm, 4); + break; + case 6: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); + break; + case 7: + h->uni_vmovss(addr(start_bytes), xmm); + h->uni_vpextrw(addr(start_bytes + 4), xmm, 2); + h->uni_vpextrb(addr(start_bytes + 6), xmm, 6); + break; + case 8: break; - case 32: - h->uni_vmovdqu(addr(0), ymm); + case 9: + h->uni_vpextrb(addr(start_bytes + 8), xmm, 8); + break; + case 10: + h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); + break; + case 11: + h->uni_vpextrw(addr(start_bytes + 8), xmm, 4); + h->uni_vpextrb(addr(start_bytes + 10), xmm, 10); + break; + case 12: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + break; + case 13: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrb(addr(start_bytes + 12), xmm, 12); + break; + case 14: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); + break; + case 15: + h->uni_vpextrd(addr(start_bytes + 8), xmm, 2); + h->uni_vpextrw(addr(start_bytes + 12), xmm, 6); + h->uni_vpextrb(addr(start_bytes + 14), xmm, 14); break; case 16: - h->uni_vmovdqu(addr(0), xmm); break; default: - if (mayiuse(cpu::x64::avx512_core) && store_size > threshold_for_mask_emu_store) { - uint64_t mask = 1; - mask = (mask << store_size) - mask; - h->mov(Reg64(aux_gpr_idxs[0]), mask); - h->kmovq(k_mask, Reg64(aux_gpr_idxs[0])); - h->vmovdqu8(addr(0), zmm | k_mask); - } else { - store_byte_base(); - } - break; + OV_CPU_JIT_EMITTER_THROW("has unexpected number of values to store in store_bytes."); + } + }; + + switch (store_size) { + case 64: + h->uni_vmovdqu(addr(0), zmm); + break; + case 32: + h->uni_vmovdqu(addr(0), ymm); + break; 
+ case 16:
+     h->uni_vmovdqu(addr(0), xmm);
+     break;
+ default:
+     if (mayiuse(cpu::x64::avx512_core) && store_size > threshold_for_mask_emu_store) {
+         uint64_t mask = 1;
+         mask = (mask << store_size) - mask;
+         h->mov(Reg64(aux_gpr_idxs[0]), mask);
+         h->kmovq(k_mask, Reg64(aux_gpr_idxs[0]));
+         h->vmovdqu8(addr(0), zmm | k_mask);
+     } else {
+         store_byte_base();
+     }
+     break;
 }
}

/**
-* store_dword_to_byte_extension is the utility function to
-* 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num bytes, singed or unsinged, truncated or saturated.
-* 2. store the packed byte into the memory referenced by ptr[reg + offset] address.
-*/
+ * store_dword_to_byte_extension is the utility function to
+ * 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num bytes, signed or unsigned,
+ * truncated or saturated.
+ * 2. store the packed bytes into the memory referenced by ptr[reg + offset] address.
+ */
 template
-void jit_store_emitter::store_dword_to_byte_extension(const Xbyak::Reg64 &reg, int offset, bool is_signed, int store_num) const {
+void jit_store_emitter::store_dword_to_byte_extension(const Xbyak::Reg64& reg,
+                                                      int offset,
+                                                      bool is_signed,
+                                                      int store_num) const {
 constexpr bool is_xmm = std::is_same::value;
 constexpr bool is_ymm = std::is_same::value;
 constexpr bool is_zmm = std::is_same::value;
@@ -1032,7 +1110,7 @@ void jit_store_emitter::store_dword_to_byte_extension(const Xbyak::Reg64 &reg, i
 break;
 case 4:
 if (mayiuse(cpu::x64::avx512_core)) {
- if (is_saturation()) { // xmm block on avx512F + VL
+ if (is_saturation()) {  // xmm block on avx512F + VL
 if (is_signed) {
 h->vpmovsdb(addr(0), xmm);
 } else {
@@ -1074,13 +1152,16 @@ void jit_store_emitter::store_dword_to_byte_extension(const Xbyak::Reg64 &reg, i
 }

 /**
-* store_dword_to_word_extension is the utility function to
-* 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num words with singed or unsinged saturation.
-* 2. store the packed words into the memory referenced by ptr[reg + offset] address.
-*/
+ * store_dword_to_word_extension is the utility function to
+ * 1. convert store_num (0 <= store_num <= 16) dwords in the Xmm/Ymm/Zmm to store_num words with signed or unsigned
+ * saturation.
+ * 2. store the packed words into the memory referenced by ptr[reg + offset] address.
+ */
 template
-void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 &reg,
-                                                      int offset, ov::element::Type precision, int store_num) const {
+void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64& reg,
+                                                      int offset,
+                                                      ov::element::Type precision,
+                                                      int store_num) const {
 const bool is_bf16 = (precision == ov::element::bf16);
 const bool is_f16 = (precision == ov::element::f16);
 const bool is_signed = precision.is_signed();
@@ -1151,7 +1232,8 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 &reg,

 if (is_bf16) {
 if (mayiuse(cpu::x64::avx512_core)) {
- // to avoid src vmm pollution, this check means no precision convert happens, so data_idx is still original_data_idx.
+ // to avoid src vmm pollution: this check means no precision conversion happened, so data_idx is still
+ // original_data_idx.
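The two conversion helpers documented above differ in how out-of-range dwords become bytes or words: the saturating forms (vpmovsdb/vpmovusdb) clamp to the destination range, while the truncating form (vpmovdb) simply keeps the low bits. A hypothetical scalar model of the signed dword-to-byte case follows; these helpers are inventions of this note, not plugin code.

#include <algorithm>
#include <cstdint>

static int8_t dword_to_byte_saturate(int32_t v) {  // models vpmovsdb
    return static_cast<int8_t>(std::min(std::max(v, -128), 127));
}

static int8_t dword_to_byte_truncate(int32_t v) {  // models vpmovdb: keep the low 8 bits
    return static_cast<int8_t>(static_cast<uint32_t>(v) & 0xFFu);
}

// Example: 300 saturates to 127 but truncates to 44 (0x2C);
//          -200 saturates to -128 but truncates to 56 (0x38).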
if (src_prc_ == ov::element::f32) { ymm = Ymm(aux_vec_idxs[0]); } @@ -1171,7 +1253,8 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 ®, if (host_isa_ == cpu::x64::sse41 && src_prc_ == ov::element::f32) { auto xmm_aux1 = Xmm(aux_vec_idxs[1]); h->uni_vmovups(xmm_aux1, vmm); - uni_vcvtneps2bf16_->emit_code({static_cast(vmm.getIdx())}, {static_cast(vmm.getIdx())}, + uni_vcvtneps2bf16_->emit_code({static_cast(vmm.getIdx())}, + {static_cast(vmm.getIdx())}, {static_cast(xmm.getIdx())}); h->uni_vmovups(xmm, vmm); h->uni_vmovups(vmm, xmm_aux1); // return original data to src vmm @@ -1222,7 +1305,7 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 ®, Vmm zero(aux_vec_idxs[0]); h->uni_vpxor(zero, zero, zero); STORE_KEEP_SOURCE(uni_vpmaxsd, vmm, Vmm(aux_src_idx), vmm, zero); - h->vpmovusdw(ptr[reg + offset], vmm); // unsinged int32 saturate to unsigned int16. + h->vpmovusdw(ptr[reg + offset], vmm); // unsinged int32 saturate to unsigned int16. } } else { h->vpmovdw(ptr[reg + offset], vmm); @@ -1261,7 +1344,7 @@ void jit_store_emitter::store_dword_to_word_extension(const Xbyak::Reg64 ®, h->vpmovdw(ptr[reg + offset], xmm); } } else { - store_dword_to_word_base(); + store_dword_to_word_base(); } break; default: @@ -1297,5 +1380,5 @@ void jit_store_emitter::register_table_entries() { } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp index 9570a836aa64ee..2c4e15ccaeb28b 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp @@ -4,16 +4,23 @@ #pragma once -#include "jit_emitter.hpp" #include "jit_bf16_emitters.hpp" +#include "jit_emitter.hpp" namespace ov { namespace intel_cpu { struct load_emitter_params : public emitter_params { - load_emitter_params(ov::element::Type src_prc, ov::element::Type dst_prc, - int load_num, bool is_fill = false, std::string fill_value = "zero"): - src_prc_(src_prc), dst_prc_(dst_prc), load_num_(load_num), is_fill_(is_fill), fill_value_(fill_value) {} + load_emitter_params(ov::element::Type src_prc, + ov::element::Type dst_prc, + int load_num, + bool is_fill = false, + std::string fill_value = "zero") + : src_prc_(src_prc), + dst_prc_(dst_prc), + load_num_(load_num), + is_fill_(is_fill), + fill_value_(fill_value) {} size_t hash() const override; @@ -25,8 +32,10 @@ struct load_emitter_params : public emitter_params { }; struct store_emitter_params : public emitter_params { - store_emitter_params(ov::element::Type src_prc, ov::element::Type dst_prc, int store_num): - src_prc_(src_prc), dst_prc_(dst_prc), store_num_(store_num) {} + store_emitter_params(ov::element::Type src_prc, ov::element::Type dst_prc, int store_num) + : src_prc_(src_prc), + dst_prc_(dst_prc), + store_num_(store_num) {} size_t hash() const override; @@ -36,57 +45,61 @@ struct store_emitter_params : public emitter_params { }; // Arithmetic modes for data type conversion in store_emitter -enum arithmetic_mode { - saturation, - truncation -}; +enum arithmetic_mode { saturation, truncation }; class jit_load_emitter : public jit_emitter { public: - jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type src_prc, ov::element::Type dst_prc, int load_num, + 
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                 dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                 ov::element::Type src_prc,
+                 ov::element::Type dst_prc,
+                 int load_num,
                  ov::element::Type exec_prc = ov::element::f32,
-                 bool is_fill = false, std::string fill_value = "zero",
+                 bool is_fill = false,
+                 std::string fill_value = "zero",
                  emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);

 /**
- * load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc, where offset_byte is in_idxs[1]
- * is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values.
- * fill_value: when load_num can not fully fit in vector register, what values should be filled as default values.
- * currently support "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and "float_max".
- * supported src_prc and dst_prc pairs are as below(x indicate for support):
- * FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
- * FP32 x x x x x x x
- * I32 x x x x x x x
- * I16 x
- * U16 x
- * I8 x
- * U8 x
- * BF16 x
- * |
- * \|/
- * dst_prc
- */
+ * load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to
+ * Vmm[out_idxs[0]] as dst_prc, where offset_byte is in_idxs[1].
+ * is_fill: when load_num can not fully fit in the vector register, whether fill_value should be filled in as the
+ * default value.
+ * fill_value: when load_num can not fully fit in the vector register, which value should be filled in as the
+ * default; currently supported: "zero", "int_one", "float_one", "int32_min", "float_min", "int32_max" and
+ * "float_max".
+ * Supported src_prc and dst_prc pairs are as below (x indicates support):
+ *       FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
+ * FP32   x    x   x   x  x  x   x
+ * I32    x    x   x   x  x  x   x
+ * I16             x
+ * U16                 x
+ * I8                      x
+ * U8                         x
+ * BF16                          x
+ *   |
+ *  \|/
+ * dst_prc
+ */
 // offset in load emitter is the offset of src gpr register, should be parsed from in_idxs.
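Since the reflowed comment made the matrix above hard to read, the same support table can also be expressed as a checkable predicate. This is an illustrative sketch of the documented table only; no such function exists in the plugin.

#include <set>
#include <string>

// Encodes the jit_load_emitter src/dst precision table above: FP32 and I32
// destinations accept every listed source type; every other destination only
// accepts its own type.
static bool load_pair_supported(const std::string& src_prc, const std::string& dst_prc) {
    static const std::set<std::string> known = {"FP32", "I32", "I16", "U16", "I8", "U8", "BF16"};
    if (known.count(src_prc) == 0 || known.count(dst_prc) == 0)
        return false;
    if (dst_prc == "FP32" || dst_prc == "I32")
        return true;            // the two fully populated rows
    return src_prc == dst_prc;  // the diagonal entries
}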
- void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override;
+ void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override;

 size_t get_inputs_num() const override;

private:
 template
- void emit_isa(const Xbyak::Reg64 &reg_src, const int out_vec_idx, const int offset) const;
+ void emit_isa(const Xbyak::Reg64& reg_src, const int out_vec_idx, const int offset) const;

 template
- void load_bytes(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, int load_size) const;
+ void load_bytes(const Vmm& vmm, const Xbyak::Reg64& reg, int offset, int load_size) const;

 template
- void load_bytes_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, bool is_signed, int load_size) const;
+ void load_bytes_to_dword_extension(const Vmm& vmm,
+                                    const Xbyak::Reg64& reg,
+                                    int offset,
+                                    bool is_signed,
+                                    int load_size) const;

 template
- void load_words_to_dword_extension(const Vmm &vmm, const Xbyak::Reg64 &reg, int offset, ov::element::Type prc, int load_size) const;
+ void load_words_to_dword_extension(const Vmm& vmm,
+                                    const Xbyak::Reg64& reg,
+                                    int offset,
+                                    ov::element::Type prc,
+                                    int load_size) const;

 template
- void fill_with_default(const Vmm &vmm, std::string fill_value, const int &load_num) const;
+ void fill_with_default(const Vmm& vmm, std::string fill_value, const int& load_num) const;

 void register_table_entries() override;

@@ -104,30 +117,27 @@ class jit_load_emitter : public jit_emitter {

class jit_store_emitter : public jit_emitter {
public:
- jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
-                   ov::element::Type src_prc, ov::element::Type dst_prc, int store_num,
+ jit_store_emitter(dnnl::impl::cpu::x64::jit_generator* host,
+                   dnnl::impl::cpu::x64::cpu_isa_t host_isa,
+                   ov::element::Type src_prc,
+                   ov::element::Type dst_prc,
+                   int store_num,
                    arithmetic_mode mode = arithmetic_mode::saturation,
                    ov::element::Type exec_prc = ov::element::f32,
                    emitter_in_out_map in_out_type = emitter_in_out_map::vec_to_gpr);

 /**
- * store_num values with src_prc in Vmm[in_vec_idx] is stored to ptr[reg_dst + offset_byte] address as dst_prc data, where offset_byte is in_idxs[1]
- * supported src_prc and dst_prc pairs are as below(x indicate for support):
- * FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
- * FP32 x x
- * I32 x x
- * I16 x x x
- * U16 x x x
- * I8 x x x
- * U8 x x x
- * BF16 x* x* x
- * \|/
- * dst_prc
- * note: FP32/I32-->BF16(x*) is supported only on at least avx512-core plateform
- */
+ * store_num values with src_prc in Vmm[in_vec_idx] are stored to ptr[reg_dst + offset_byte] address as dst_prc
+ * data, where offset_byte is in_idxs[1].
+ * Supported src_prc and dst_prc pairs are as below (x indicates support):
+ *       FP32 I32 I16 U16 I8 U8 BF16 --> src_prc
+ * FP32   x    x
+ * I32    x    x
+ * I16    x    x   x
+ * U16    x    x       x
+ * I8     x    x          x
+ * U8     x    x             x
+ * BF16   x*   x*                x
+ *  \|/
+ * dst_prc
+ * note: FP32/I32-->BF16 (x*) is supported only on at least an avx512_core platform
+ */
 // offset in store emitter is the offset of dst gpr register, should be parsed from out_idxs.
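The BF16 entries in both tables rest on the fact that bfloat16 is the upper half of an IEEE-754 float32, which is why the load path earlier in this patch widens it with uni_vpmovzxwd followed by uni_vpslld(_, _, 16). A minimal scalar model of that layout, assuming nothing beyond the standard library:

#include <cstdint>
#include <cstring>

// bf16 -> f32: zero-extend the 16-bit payload and shift it into the high half,
// the per-element equivalent of uni_vpmovzxwd + uni_vpslld(_, _, 16).
static float bf16_to_f32(uint16_t bf16_bits) {
    const uint32_t f32_bits = static_cast<uint32_t>(bf16_bits) << 16;
    float f;
    std::memcpy(&f, &f32_bits, sizeof(f));
    return f;
}

// f32 -> bf16 by truncation. The store emitter instead uses uni_vcvtneps2bf16,
// which additionally performs round-to-nearest-even; plain truncation is shown
// here only to illustrate the bit layout.
static uint16_t f32_to_bf16_truncate(float f) {
    uint32_t f32_bits;
    std::memcpy(&f32_bits, &f, sizeof(f32_bits));
    return static_cast<uint16_t>(f32_bits >> 16);
}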
- void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; size_t get_inputs_num() const override; @@ -139,16 +149,19 @@ class jit_store_emitter : public jit_emitter { private: template - void emit_isa(const int in_vec_idx, const Xbyak::Reg64 ®_dst, const int offset) const; + void emit_isa(const int in_vec_idx, const Xbyak::Reg64& reg_dst, const int offset) const; template - void store_bytes(const Xbyak::Reg64 ®, int offset, int store_size) const; + void store_bytes(const Xbyak::Reg64& reg, int offset, int store_size) const; template - void store_dword_to_byte_extension(const Xbyak::Reg64 ®, int offset, bool is_signed, int store_size) const; + void store_dword_to_byte_extension(const Xbyak::Reg64& reg, int offset, bool is_signed, int store_size) const; template - void store_dword_to_word_extension(const Xbyak::Reg64 ®, int offset, ov::element::Type precision, int store_size) const; + void store_dword_to_word_extension(const Xbyak::Reg64& reg, + int offset, + ov::element::Type precision, + int store_size) const; void register_table_entries() override; @@ -176,5 +189,5 @@ class jit_store_emitter : public jit_emitter { mutable int aux_src_idx = 0; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp index ea16122f2f9793..420e9691ebc73c 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp @@ -21,8 +21,21 @@ EmitABIRegSpills::~EmitABIRegSpills() { void EmitABIRegSpills::preamble() { // gprs - Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, - h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; + Xbyak::Operand gprs_to_save[] = {h->r8, + h->r9, + h->r10, + h->r11, + h->r12, + h->r13, + h->r14, + h->r15, + h->rax, + h->rbx, + h->rcx, + h->rdx, + h->rdi, + h->rsi, + h->rbp}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); h->sub(h->rsp, n_gprs_to_save * gpr_size); @@ -75,8 +88,21 @@ void EmitABIRegSpills::postamble() { } // restore gpr registers - Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, - h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp}; + Xbyak::Operand gprs_to_save[] = {h->r8, + h->r9, + h->r10, + h->r11, + h->r12, + h->r13, + h->r14, + h->r15, + h->rax, + h->rbx, + h->rcx, + h->rdx, + h->rdi, + h->rsi, + h->rbp}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); for (int i = n_gprs_to_save - 1; i >= 0; --i) h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); @@ -113,13 +139,17 @@ void EmitABIRegSpills::rsp_restore() { cpu_isa_t EmitABIRegSpills::get_isa() { // need preserve based on cpu capability, instead of host isa. // in case there are possibilty that different isa emitters exist in one kernel from perf standpoint in the future. - // e.g. other emitters isa is avx512, while this emitter isa is avx2, and internal call is used. Internal call may use avx512 and spoil k-reg, ZMM. - // do not care about platform w/ avx512_common but w/o avx512_core(knight landing), which is obsoleted. - if (mayiuse(avx512_core)) return avx512_core; - if (mayiuse(avx2)) return avx2; - if (mayiuse(sse41)) return sse41; + // e.g. 
other emitters isa is avx512, while this emitter isa is avx2, and internal call is used. Internal call may + // use avx512 and spoil k-reg, ZMM. do not care about platform w/ avx512_common but w/o avx512_core(knight landing), + // which is obsoleted. + if (mayiuse(avx512_core)) + return avx512_core; + if (mayiuse(avx2)) + return avx2; + if (mayiuse(sse41)) + return sse41; OV_CPU_JIT_EMITTER_THROW("unsupported isa"); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp index 16a66beba7a536..ba956f3375f054 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp @@ -30,11 +30,15 @@ class EmitABIRegSpills { static dnnl::impl::cpu::x64::cpu_isa_t get_isa(); - inline size_t get_max_vecs_count() const { return dnnl::impl::cpu::x64::isa_num_vregs(isa); } - inline size_t get_vec_length() const { return dnnl::impl::cpu::x64::isa_max_vlen(isa); } + inline size_t get_max_vecs_count() const { + return dnnl::impl::cpu::x64::isa_num_vregs(isa); + } + inline size_t get_vec_length() const { + return dnnl::impl::cpu::x64::isa_max_vlen(isa); + } - dnnl::impl::cpu::x64::jit_generator* h {nullptr}; - const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::cpu_isa_t::isa_undef}; + dnnl::impl::cpu::x64::jit_generator* h{nullptr}; + const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::cpu_isa_t::isa_undef}; static constexpr int k_mask_size = 8; static constexpr int k_mask_num = 8; @@ -44,5 +48,5 @@ class EmitABIRegSpills { bool rsp_status = true; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp index 79e8dcafb218f6..cfe03d21eac19e 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_kernel_executor_table.hpp @@ -4,34 +4,38 @@ #pragma once -#include "snippets/kernel_executor_table.hpp" #include "cache/multi_cache.h" +#include "snippets/kernel_executor_table.hpp" namespace ov { namespace intel_cpu { -template +template class CPUKernelExecutor : public snippets::KernelExecutor { public: - CPUKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, Conf c) : - snippets::KernelExecutor(std::move(c)), m_kernel_cache(std::move(kernel_cache)) {} + CPUKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, Conf c) + : snippets::KernelExecutor(std::move(c)), + m_kernel_cache(std::move(kernel_cache)) {} - void update_kernel(const Conf& config, std::shared_ptr& kernel) const override final { // NOLINT + void update_kernel(const Conf& config, std::shared_ptr& kernel) const override final { // NOLINT const auto& cache = m_kernel_cache.lock(); OPENVINO_ASSERT(cache, "Invalid kernel cache pointer in CPUKernelExecutor::update_kernel()"); - const auto& lookup_result = cache->getOrCreate(Key(config), - [this](const Key& k) { - return compile_kernel(k.config); - }); + const auto& lookup_result = cache->getOrCreate(Key(config), [this](const Key& k) { + return compile_kernel(k.config); + }); kernel = lookup_result.first; } protected: struct Key { explicit Key(Conf c) : config{std::move(c)} {} - const Conf config; - size_t hash() const { return config.hash(); } - bool 
operator==(const Key& rhs) const { return config == rhs.config; } + const Conf config; + size_t hash() const { + return config.hash(); + } + bool operator==(const Key& rhs) const { + return config == rhs.config; + } }; /** Compile kernel managed by KernelExecutor instance. Will be called only if Kernel is not found in the cache */ virtual std::shared_ptr compile_kernel(const Conf& c) const = 0; @@ -39,5 +43,5 @@ class CPUKernelExecutor : public snippets::KernelExecutor { ov::intel_cpu::MultiCacheWeakPtr m_kernel_cache; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp index b2758735b2d27a..65741d7031d289 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.cpp @@ -8,8 +8,8 @@ #include "snippets/utils/utils.hpp" #ifndef OPENVINO_ARCH_ARM64 -#include "transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.hpp" -#include "transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp" +# include "transformations/snippets/x64/pass/lowered/brgemm_copy_b_loop_ports_adjuster.hpp" +# include "transformations/snippets/x64/pass/lowered/external_repacking_adjuster.hpp" #endif namespace ov { namespace intel_cpu { @@ -21,7 +21,8 @@ const size_t CPURuntimeConfigurator::rank6D = 6; std::string CPURuntimeConfig::to_string() const { std::stringstream out; out << RuntimeConfig::to_string(); - out << "Loop Parameters:" << "\n"; + out << "Loop Parameters:" + << "\n"; for (size_t i = 0; i < loop_args.size(); ++i) { const auto& loop = loop_args[i]; out << "\t[" << i << "] WA: " << loop.m_work_amount << "\n"; @@ -38,8 +39,8 @@ std::string CPURuntimeConfig::to_string() const { } #endif -CPURuntimeConfigurator::CPURuntimeConfigurator() : ov::snippets::RuntimeConfigurator(std::make_shared()) { -} +CPURuntimeConfigurator::CPURuntimeConfigurator() + : ov::snippets::RuntimeConfigurator(std::make_shared()) {} void CPURuntimeConfigurator::initialization(const ov::snippets::lowered::LinearIRCPtr& linear_ir) { RuntimeConfigurator::initialization(linear_ir); @@ -78,12 +79,14 @@ void CPURuntimeConfigurator::update_loop_args(const ov::snippets::lowered::Linea const auto& data_sizes = loop_info->get_data_sizes(); auto& loop_arg = cpu_config->loop_args[idx]; - loop_arg = jit_snippets_call_args::loop_args_t(loop_info->get_work_amount(), loop_info->get_ptr_increments(), loop_info->get_finalization_offsets()); + loop_arg = jit_snippets_call_args::loop_args_t(loop_info->get_work_amount(), + loop_info->get_ptr_increments(), + loop_info->get_finalization_offsets()); for (int64_t i = 0; i < loop_arg.m_num_data_ptrs; ++i) { loop_arg.m_ptr_increments[i] *= (increment * data_sizes[i]); loop_arg.m_finalization_offsets[i] *= data_sizes[i]; } } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp index 42ce35a3c66c2b..1706670ce870d1 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/cpu_runtime_configurator.hpp @@ -34,6 +34,7 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator { * @param linear_ir 
LinearIR */ void update_loop_args(const ov::snippets::lowered::LinearIRCPtr& linear_ir) const; + protected: void update(const ov::snippets::lowered::LinearIRCPtr& linear_ir) override; void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const override; @@ -43,5 +44,5 @@ class CPURuntimeConfigurator : public ov::snippets::RuntimeConfigurator { static const size_t rank6D; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp index 6f78c43fd54797..ceee57f3c0cd28 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.cpp @@ -3,15 +3,18 @@ // #include "jit_container_emitter.hpp" + #include "emitters/utils.hpp" #include "utils/general_utils.h" namespace ov { namespace intel_cpu { -void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, mapping_info& vec_map_pool, +void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, + mapping_info& vec_map_pool, snippets::lowered::LinearIR::container& expressions) const { - OV_CPU_JIT_EMITTER_ASSERT(!expressions.empty(), "Cannot map registers when there is no allocated_emitters provided"); + OV_CPU_JIT_EMITTER_ASSERT(!expressions.empty(), + "Cannot map registers when there is no allocated_emitters provided"); auto map_regs = [&](const std::vector& abstract_regs) { std::vector physical_regs = abstract_regs; @@ -19,13 +22,16 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, m const auto& abstract_reg = abstract_regs[i]; const auto& type = abstract_reg.type; const auto& abstract = abstract_reg.idx; - OV_CPU_JIT_EMITTER_ASSERT(one_of(type, snippets::RegType::gpr, snippets::RegType::vec), "Incorrect reg type detected!"); + OV_CPU_JIT_EMITTER_ASSERT(one_of(type, snippets::RegType::gpr, snippets::RegType::vec), + "Incorrect reg type detected!"); auto& mapping = type == snippets::RegType::gpr ? gpr_map_pool : vec_map_pool; auto& abstract_to_physical = mapping.first; auto& regs_pool = mapping.second; auto& physical = physical_regs[i]; if (abstract_to_physical.count(abstract) == 0) { - OV_CPU_JIT_EMITTER_ASSERT(!regs_pool.empty(), "Cannot map registers for jit_container_emitter: not enough regs in the pool"); + OV_CPU_JIT_EMITTER_ASSERT( + !regs_pool.empty(), + "Cannot map registers for jit_container_emitter: not enough regs in the pool"); physical.idx = regs_pool.back(); regs_pool.pop_back(); abstract_to_physical[abstract] = physical.idx; @@ -48,5 +54,5 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, m } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp index 2325c6ef1a2eb3..7737e7e1150926 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/jit_container_emitter.hpp @@ -20,8 +20,10 @@ class jit_container_emitter { protected: // maps gpr and vec abstract registers to physical ones. 
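The map_regs lambda above is essentially a memoized pool allocator: the first time an abstract register index is seen, it is pinned to a physical register popped from the pool, and every later occurrence reuses that assignment. A condensed stand-alone sketch for one register type follows; the mapping_info alias is modeled here as a map plus a pool, matching how mapping.first and mapping.second are used above, and map_one is a name invented for this note.

#include <cstddef>
#include <map>
#include <stdexcept>
#include <utility>
#include <vector>

using mapping_info = std::pair<std::map<size_t, size_t>, std::vector<size_t>>;

// Resolve one abstract register index to a physical one, allocating from the
// pool on first use and reusing the cached assignment afterwards.
static size_t map_one(mapping_info& mapping, size_t abstract) {
    auto& abstract_to_physical = mapping.first;
    auto& regs_pool = mapping.second;
    const auto it = abstract_to_physical.find(abstract);
    if (it != abstract_to_physical.end())
        return it->second;  // already pinned: reuse the earlier assignment
    if (regs_pool.empty())
        throw std::runtime_error("not enough regs in the pool");
    const size_t physical = regs_pool.back();
    regs_pool.pop_back();
    abstract_to_physical[abstract] = physical;
    return physical;
}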
- void map_abstract_registers(mapping_info& gpr_map_pool, mapping_info& vec_map_pool, snippets::lowered::LinearIR::container& expressions) const; + void map_abstract_registers(mapping_info& gpr_map_pool, + mapping_info& vec_map_pool, + snippets::lowered::LinearIR::container& expressions) const; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.cpp b/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.cpp index 48f98c2ffb2450..20e19bcba7e4f4 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.cpp @@ -3,10 +3,11 @@ // #include "jit_snippets_call_args.hpp" -#include "emitters/utils.hpp" #include +#include "emitters/utils.hpp" + namespace ov { namespace intel_cpu { @@ -21,16 +22,19 @@ void jit_snippets_call_args::register_loops(const std::vector& loop std::copy(loops.begin(), loops.end(), loop_args); } -jit_snippets_call_args::loop_args_t::loop_args_t(int64_t work_amount, const std::vector& ptr_increments, +jit_snippets_call_args::loop_args_t::loop_args_t(int64_t work_amount, + const std::vector& ptr_increments, const std::vector& finalization_offsets) : m_work_amount(work_amount) { - OV_CPU_JIT_EMITTER_ASSERT(ptr_increments.size() == finalization_offsets.size(), "Inconsistent sizes of ptr_increments and finalization_offsets"); + OV_CPU_JIT_EMITTER_ASSERT(ptr_increments.size() == finalization_offsets.size(), + "Inconsistent sizes of ptr_increments and finalization_offsets"); m_num_data_ptrs = static_cast(ptr_increments.size()); init_pointers_and_copy_data(m_num_data_ptrs, ptr_increments.data(), finalization_offsets.data()); } jit_snippets_call_args::loop_args_t::loop_args_t(const loop_args_t& other) - : m_work_amount(other.m_work_amount), m_num_data_ptrs(other.m_num_data_ptrs) { + : m_work_amount(other.m_work_amount), + m_num_data_ptrs(other.m_num_data_ptrs) { init_pointers_and_copy_data(m_num_data_ptrs, other.m_ptr_increments, other.m_finalization_offsets); } @@ -44,7 +48,8 @@ jit_snippets_call_args::loop_args_t& jit_snippets_call_args::loop_args_t::operat return *this; } -void jit_snippets_call_args::loop_args_t::init_pointers_and_copy_data(const int64_t num_elements, const int64_t* ptr_increments, +void jit_snippets_call_args::loop_args_t::init_pointers_and_copy_data(const int64_t num_elements, + const int64_t* ptr_increments, const int64_t* finalization_offsets) { const size_t chunk_size = num_elements * sizeof(int64_t); m_ptr_increments = new int64_t[num_elements]; @@ -60,5 +65,5 @@ void swap(jit_snippets_call_args::loop_args_t& first, jit_snippets_call_args::lo std::swap(first.m_finalization_offsets, second.m_finalization_offsets); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.hpp b/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.hpp index 027655d493784d..eb74190dd71676 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/jit_snippets_call_args.hpp @@ -4,9 +4,9 @@ #pragma once -#include -#include #include +#include +#include #include "dnnl_types.h" #include "openvino/core/visibility.hpp" @@ -15,12 +15,12 @@ namespace ov { namespace intel_cpu { #if defined(OPENVINO_ARCH_ARM64) -#define SNIPPETS_MAX_DATA_PTR_COUNT 23 +# define 
SNIPPETS_MAX_DATA_PTR_COUNT 23 #else -#define SNIPPETS_MAX_DATA_PTR_COUNT 11 +# define SNIPPETS_MAX_DATA_PTR_COUNT 11 #endif -#define GET_OFF(field) offsetof(jit_snippets_call_args, field) +#define GET_OFF(field) offsetof(jit_snippets_call_args, field) #define GET_OFF_LOOP_ARGS(field) offsetof(jit_snippets_call_args::loop_args_t, field) struct amx_tile_config_t { @@ -37,9 +37,9 @@ struct jit_snippets_call_args { void register_loops(const std::vector& loops); - const void *src_ptrs[SNIPPETS_MAX_DATA_PTR_COUNT] = {}; - void *dst_ptrs[SNIPPETS_MAX_DATA_PTR_COUNT] = {}; - void *buffer_scratchpad_ptr = nullptr; + const void* src_ptrs[SNIPPETS_MAX_DATA_PTR_COUNT] = {}; + void* dst_ptrs[SNIPPETS_MAX_DATA_PTR_COUNT] = {}; + void* buffer_scratchpad_ptr = nullptr; // Note: Ideally loop_args must be private, since we manage this pointer manually. // However, standard-layout class definition (to use offset_of) requires the same access specifier @@ -51,14 +51,18 @@ struct jit_snippets_call_args { struct jit_snippets_call_args::loop_args_t { loop_args_t() = default; - loop_args_t(int64_t work_amount, const std::vector& ptr_increments, const std::vector& finalization_offsets); + loop_args_t(int64_t work_amount, + const std::vector& ptr_increments, + const std::vector& finalization_offsets); loop_args_t(const loop_args_t& other); ~loop_args_t(); loop_args_t& operator=(loop_args_t other); friend void swap(loop_args_t& first, loop_args_t& second); - void init_pointers_and_copy_data(const int64_t num_elements, const int64_t* ptr_increments, const int64_t* finalization_offsets); + void init_pointers_and_copy_data(const int64_t num_elements, + const int64_t* ptr_increments, + const int64_t* finalization_offsets); int64_t m_work_amount = 0; int64_t m_num_data_ptrs = 0; @@ -71,5 +75,5 @@ struct jit_snippets_compile_args { std::vector exec_domain = {}; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.cpp b/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.cpp index b7c51539861ff8..e4c3c40e1d8120 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.cpp @@ -3,7 +3,7 @@ // #ifdef SNIPPETS_DEBUG_CAPS -#include "debug_caps_config.hpp" +# include "debug_caps_config.hpp" namespace ov { namespace intel_cpu { @@ -20,7 +20,7 @@ void SnippetsDebugCapsConfig::readProperties() { enable_segfault_detector = readEnv("OV_CPU_SNIPPETS_SEGFAULT_DETECTOR") ? 
true : false; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov -#endif // SNIPPETS_DEBUG_CAPS +#endif // SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.hpp b/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.hpp index 14dcae0ddf0c69..8f01e85063f5e9 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/utils/debug_caps_config.hpp @@ -3,10 +3,10 @@ // #ifdef SNIPPETS_DEBUG_CAPS -#pragma once +# pragma once -#include -#include +# include +# include namespace ov { namespace intel_cpu { @@ -23,7 +23,7 @@ class SnippetsDebugCapsConfig { void readProperties(); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov -#endif // SNIPPETS_DEBUG_CAPS +#endif // SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp index c210782db8f91c..39e384837856a1 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/cpu_generator.cpp @@ -4,63 +4,61 @@ #include "cpu_generator.hpp" -#include "snippets/snippets_isa.hpp" -#include "emitters/snippets/cpu_runtime_configurator.hpp" +#include +#include "emitters/plugin/x64/jit_conversion_emitters.hpp" +#include "emitters/plugin/x64/jit_dnnl_emitters.hpp" +#include "emitters/plugin/x64/jit_dnnl_ext_emitters.hpp" +#include "emitters/plugin/x64/jit_eltwise_emitters.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" +#include "emitters/snippets/cpu_runtime_configurator.hpp" #include "emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp" #include "emitters/snippets/x64/jit_brgemm_emitter.hpp" -#include "emitters/snippets/x64/jit_memory_emitters.hpp" +#include "emitters/snippets/x64/jit_fill_emitter.hpp" +#include "emitters/snippets/x64/jit_horizon_emitter.hpp" #include "emitters/snippets/x64/jit_kernel_emitter.hpp" #include "emitters/snippets/x64/jit_loop_emitters.hpp" +#include "emitters/snippets/x64/jit_memory_emitters.hpp" #include "emitters/snippets/x64/jit_snippets_emitters.hpp" -#include "emitters/snippets/x64/jit_fill_emitter.hpp" -#include "emitters/snippets/x64/jit_horizon_emitter.hpp" -#include "emitters/plugin/x64/jit_eltwise_emitters.hpp" -#include "emitters/plugin/x64/jit_dnnl_emitters.hpp" -#include "emitters/plugin/x64/jit_dnnl_ext_emitters.hpp" -#include "emitters/plugin/x64/jit_conversion_emitters.hpp" - -#include "transformations/snippets/x64/op/load_convert.hpp" -#include "transformations/snippets/x64/op/store_convert.hpp" +#include "snippets/snippets_isa.hpp" +#include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/snippets/common/op/fused_mul_add.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "transformations/snippets/x64/op/load_convert.hpp" #include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" -#include "transformations/cpu_opset/common/op/swish_cpu.hpp" +#include "transformations/snippets/x64/op/store_convert.hpp" #include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp" -#include -#include "emitters/snippets/cpu_kernel_executor_table.hpp" - #ifdef SNIPPETS_DEBUG_CAPS -#include "emitters/snippets/x64/jit_perf_count_chrono_emitters.hpp" -#include 
"emitters/snippets/x64/jit_perf_count_rdtsc_emitters.hpp" -#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" -#include "emitters/snippets/x64/jit_debug_emitter.hpp" -#include "emitters/snippets/x64/jit_segfault_detector_emitter.hpp" -#include "emitters/snippets/x64/verbose.hpp" +# include "emitters/snippets/x64/jit_debug_emitter.hpp" +# include "emitters/snippets/x64/jit_perf_count_chrono_emitters.hpp" +# include "emitters/snippets/x64/jit_perf_count_rdtsc_emitters.hpp" +# include "emitters/snippets/x64/jit_segfault_detector_emitter.hpp" +# include "emitters/snippets/x64/verbose.hpp" +# include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" #endif #ifdef SNIPPETS_LIBXSMM_TPP -#include "transformations/tpp/x64/op/brgemm.hpp" -#include "transformations/tpp/x64/op/eltwise.hpp" -#include "transformations/tpp/x64/op/reduce.hpp" -#include "transformations/tpp/x64/op/modifiers.hpp" -#include "transformations/tpp/x64/op/scalar.hpp" -#include "transformations/tpp/x64/op/equation.hpp" -#include "emitters/tpp/x64/jit_eltwise_emitters.hpp" -#include "emitters/tpp/x64/jit_brgemm_emitter.hpp" -#include "emitters/tpp/x64/jit_scalar_emitter.hpp" -#include "emitters/tpp/x64/jit_equation_emitter.hpp" -#include "emitters/tpp/x64/jit_debug_emitter.hpp" +# include "emitters/tpp/x64/jit_brgemm_emitter.hpp" +# include "emitters/tpp/x64/jit_debug_emitter.hpp" +# include "emitters/tpp/x64/jit_eltwise_emitters.hpp" +# include "emitters/tpp/x64/jit_equation_emitter.hpp" +# include "emitters/tpp/x64/jit_scalar_emitter.hpp" +# include "transformations/tpp/x64/op/brgemm.hpp" +# include "transformations/tpp/x64/op/eltwise.hpp" +# include "transformations/tpp/x64/op/equation.hpp" +# include "transformations/tpp/x64/op/modifiers.hpp" +# include "transformations/tpp/x64/op/reduce.hpp" +# include "transformations/tpp/x64/op/scalar.hpp" // Note: for reference implementations -#include +# include #endif namespace ov { #ifdef SNIPPETS_DEBUG_CAPS -static bool is_load_emitter(const intel_cpu::jit_emitter *emitter) { +static bool is_load_emitter(const intel_cpu::jit_emitter* emitter) { bool ret = false; if (dynamic_cast(emitter) || dynamic_cast(emitter)) { @@ -69,7 +67,7 @@ static bool is_load_emitter(const intel_cpu::jit_emitter *emitter) { return ret; } -static bool is_store_emitter(const intel_cpu::jit_emitter *emitter) { +static bool is_store_emitter(const intel_cpu::jit_emitter* emitter) { bool ret = false; if (dynamic_cast(emitter)) { return true; @@ -77,72 +75,82 @@ static bool is_store_emitter(const intel_cpu::jit_emitter *emitter) { return ret; } -static bool is_segfault_detector_emitter(const intel_cpu::jit_emitter *emitter) { +static bool is_segfault_detector_emitter(const intel_cpu::jit_emitter* emitter) { // default active for typical tensor memory access emitters bool ret = false; - ret = is_load_emitter(emitter) || - is_store_emitter(emitter) || - dynamic_cast(emitter) || - dynamic_cast(emitter) || - dynamic_cast(emitter); + ret = is_load_emitter(emitter) || is_store_emitter(emitter) || + dynamic_cast(emitter) || + dynamic_cast(emitter) || + dynamic_cast(emitter); return ret; // use below code to active all emitters for extend usage // return !dynamic_cast(emitter); } -#define CREATE_SNIPPETS_EMITTER(e_type, ...) 
{ \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - auto emitter = std::make_shared(h.get(), isa, expr, ##__VA_ARGS__); \ - if (debug_config.enable_segfault_detector && is_segfault_detector_emitter(emitter.get())) { \ - auto segfault_emitter = std::make_shared(h.get(), isa, emitter.get(), \ - is_load_emitter(emitter.get()), is_store_emitter(emitter.get()), expr->get_node()->get_friendly_name()); \ - return std::make_shared(emitter, segfault_emitter, jit_debug_emitter::EmissionLocation::preamble); \ - } else { \ - return emitter; \ - } \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return e_type::get_supported_precisions(n); \ - } \ -} +# define CREATE_SNIPPETS_EMITTER(e_type, ...) \ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + auto emitter = std::make_shared(h.get(), isa, expr, ##__VA_ARGS__); \ + if (debug_config.enable_segfault_detector && is_segfault_detector_emitter(emitter.get())) { \ + auto segfault_emitter = \ + std::make_shared(h.get(), \ + isa, \ + emitter.get(), \ + is_load_emitter(emitter.get()), \ + is_store_emitter(emitter.get()), \ + expr->get_node()->get_friendly_name()); \ + return std::make_shared(emitter, \ + segfault_emitter, \ + jit_debug_emitter::EmissionLocation::preamble); \ + } else { \ + return emitter; \ + } \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ + } #else -#define CREATE_SNIPPETS_EMITTER(e_type, ...) { \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return std::make_shared(h.get(), isa, expr, ##__VA_ARGS__); \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return e_type::get_supported_precisions(n); \ - } \ -} +# define CREATE_SNIPPETS_EMITTER(e_type, ...) 
\ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + return std::make_shared(h.get(), isa, expr, ##__VA_ARGS__); \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ + } #endif -#define CREATE_DEBUG_TPP_EMITTER(e_type) { \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return std::make_shared(expr, std::make_shared(h.get(), isa, expr)); \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return e_type::get_supported_precisions(n); \ - } \ -} - +#define CREATE_DEBUG_TPP_EMITTER(e_type) \ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + return std::make_shared(expr, std::make_shared(h.get(), isa, expr)); \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ + } -#define CREATE_CPU_EMITTER(e_type) { \ - [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return std::make_shared(h.get(), isa, expr->get_node()); \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return e_type::get_supported_precisions(n); \ - } \ -} +#define CREATE_CPU_EMITTER(e_type) \ + { \ + [this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + return std::make_shared(h.get(), isa, expr->get_node()); \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return e_type::get_supported_precisions(n); \ + } \ + } -#define CREATE_UNDEFINED_EMITTER(supported_precisions) { \ - [](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ - return nullptr; \ - }, \ - [](const std::shared_ptr& n) -> std::set> { \ - return supported_precisions; \ - } \ -} +#define CREATE_UNDEFINED_EMITTER(supported_precisions) \ + { \ + [](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr { \ + return nullptr; \ + }, \ + [](const std::shared_ptr& n) -> std::set> { \ + return supported_precisions; \ + } \ + } class jit_snippet : public dnnl::impl::cpu::x64::jit_generator { public: @@ -157,30 +165,43 @@ class jit_snippet : public dnnl::impl::cpu::x64::jit_generator { intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::intel_cpu::MultiCacheWeakPtr cache) - : TargetMachine(std::make_shared()), h(new jit_snippet()), isa(host_isa), compiled_kernel_cache(std::move(cache)) { + : TargetMachine(std::make_shared()), + h(new jit_snippet()), + isa(host_isa), + compiled_kernel_cache(std::move(cache)) { // data movement jitters[op::v0::Parameter::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); jitters[op::v0::Result::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); jitters[snippets::op::Buffer::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); jitters[snippets::op::VectorBuffer::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); - jitters[snippets::op::RankNormalization::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); + jitters[snippets::op::RankNormalization::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); jitters[snippets::op::Reshape::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_nop_emitter); jitters[snippets::op::Load::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); - jitters[snippets::op::LoadReshape::get_type_info_static()] = 
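
// The jitters table populated below maps an op's static type info to a pair
// of callbacks: one creates the emitter for a concrete expression, the other
// reports supported precisions. A reduced sketch of that registry shape;
// the string keys and stub types are illustrative only:

#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct emitter_stub {
    std::string name;
};
struct jitter_entry {
    std::function<std::shared_ptr<emitter_stub>()> create;
    std::function<std::vector<std::string>()> supported_precisions;
};

int main() {
    std::map<std::string, jitter_entry> jitters;
    jitters["Add"] = {[]() { return std::make_shared<emitter_stub>(emitter_stub{"jit_add_emitter"}); },
                      []() { return std::vector<std::string>{"f32"}; }};
    // The generator performs the same lookup-then-create step per lowered expression:
    auto emitter = jitters.at("Add").create();
    (void)emitter;
    return 0;
}
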
CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); - jitters[snippets::op::BroadcastLoad::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_broadcast_emitter); - jitters[intel_cpu::LoadConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); - jitters[intel_cpu::LoadConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); + jitters[snippets::op::LoadReshape::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); + jitters[snippets::op::BroadcastLoad::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_broadcast_emitter); + jitters[intel_cpu::LoadConvertSaturation::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); + jitters[intel_cpu::LoadConvertTruncation::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_load_memory_emitter); jitters[snippets::op::Store::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_store_memory_emitter); - jitters[intel_cpu::StoreConvertSaturation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_store_memory_emitter); - jitters[intel_cpu::StoreConvertTruncation::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_store_memory_emitter); + jitters[intel_cpu::StoreConvertSaturation::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_store_memory_emitter); + jitters[intel_cpu::StoreConvertTruncation::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_store_memory_emitter); jitters[snippets::op::Scalar::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_scalar_emitter); - jitters[snippets::op::BroadcastMove::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_broadcast_move_emitter); + jitters[snippets::op::BroadcastMove::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_broadcast_move_emitter); - jitters[snippets::op::ConvertTruncation::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_convert_truncation_emitter); - jitters[snippets::op::ConvertSaturation::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_convert_saturation_emitter); + jitters[snippets::op::ConvertTruncation::get_type_info_static()] = + CREATE_CPU_EMITTER(intel_cpu::jit_convert_truncation_emitter); + jitters[snippets::op::ConvertSaturation::get_type_info_static()] = + CREATE_CPU_EMITTER(intel_cpu::jit_convert_saturation_emitter); // ternary jitters[op::v1::Select::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_select_emitter); @@ -203,10 +224,12 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho jitters[op::v1::Mod::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_mod_emitter); jitters[op::v1::Multiply::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_multiply_emitter); jitters[op::v1::NotEqual::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_not_equal_emitter); - jitters[snippets::op::PowerStatic::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_power_static_emitter); + jitters[snippets::op::PowerStatic::get_type_info_static()] = + CREATE_CPU_EMITTER(intel_cpu::jit_power_static_emitter); jitters[op::v1::Power::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_power_dynamic_emitter); jitters[op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_prelu_emitter); - 
jitters[op::v0::SquaredDifference::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_squared_difference_emitter); + jitters[op::v0::SquaredDifference::get_type_info_static()] = + CREATE_CPU_EMITTER(intel_cpu::jit_squared_difference_emitter); jitters[op::v1::Subtract::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_subtract_emitter); jitters[op::v0::Xor::get_type_info_static()] = CREATE_CPU_EMITTER(intel_cpu::jit_logical_xor_emitter); @@ -235,25 +258,35 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho jitters[snippets::op::HorizonMax::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_horizon_emitter); jitters[snippets::op::HorizonSum::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_horizon_emitter); - jitters[snippets::op::KernelStatic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_static_emitter); - jitters[snippets::op::KernelDynamic::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_dynamic_emitter); - jitters[snippets::op::LoopBegin::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_begin_emitter); + jitters[snippets::op::KernelStatic::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_static_emitter); + jitters[snippets::op::KernelDynamic::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_kernel_dynamic_emitter); + jitters[snippets::op::LoopBegin::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_begin_emitter); jitters[snippets::op::LoopEnd::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_loop_end_emitter); - // Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes additional arguments - jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter, - configurator->get_kernel_executor_table(), - compiled_kernel_cache); - jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter, - configurator->get_kernel_executor_table(), - compiled_kernel_cache); + // Note: jit_brgemm_emitter and jit_brgemm_copy_b_emitter support runtime recompilation, so their constructor takes + // additional arguments + jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter, + configurator->get_kernel_executor_table(), + compiled_kernel_cache); + jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = + CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter, + configurator->get_kernel_executor_table(), + compiled_kernel_cache); jitters[snippets::op::ReduceMax::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); jitters[snippets::op::ReduceSum::get_type_info_static()] = CREATE_UNDEFINED_EMITTER({{ov::element::f32}}); #ifdef SNIPPETS_DEBUG_CAPS - jitters[snippets::op::PerfCountBegin::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_chrono_start_emitter); - jitters[snippets::op::PerfCountEnd::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_chrono_end_emitter); - jitters[ov::intel_cpu::PerfCountRdtscBegin::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_rdtsc_start_emitter); - jitters[ov::intel_cpu::PerfCountRdtscEnd::get_type_info_static()] = CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_rdtsc_end_emitter); + 
jitters[snippets::op::PerfCountBegin::get_type_info_static()] =
+        CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_chrono_start_emitter);
+    jitters[snippets::op::PerfCountEnd::get_type_info_static()] =
+        CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_chrono_end_emitter);
+    jitters[ov::intel_cpu::PerfCountRdtscBegin::get_type_info_static()] =
+        CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_rdtsc_start_emitter);
+    jitters[ov::intel_cpu::PerfCountRdtscEnd::get_type_info_static()] =
+        CREATE_CPU_EMITTER(ov::intel_cpu::jit_perf_count_rdtsc_end_emitter);
 #endif

 #ifdef SNIPPETS_LIBXSMM_TPP
@@ -267,8 +300,8 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
     // Note: you can register Debug emitter for Unary/Binary operations as shown below:
     // jitters[intel_cpu::tpp::op::Add::get_type_info_static()] = CREATE_DEBUG_TPP_EMITTER(UnaryEltwiseTppEmitter);
     //
-    // Note: you can register Reference emitter for Unary operations using std::function or lambda function as shown below:
-    // jitters[intel_cpu::tpp::op::Exp::get_type_info_static()] =
+    // Note: you can register Reference emitter for Unary operations using std::function or lambda function as shown
+    // below: jitters[intel_cpu::tpp::op::Exp::get_type_info_static()] =
     //      CREATE_SNIPPETS_EMITTER(ReferenceUnaryEltwiseTppEmitter, static_cast(std::exp));
     // jitters[intel_cpu::tpp::op::Reciprocal::get_type_info_static()] =
     //      CREATE_SNIPPETS_EMITTER(ReferenceUnaryEltwiseTppEmitter, [](float x){ return 1.f/x; });
@@ -292,10 +325,14 @@ std::shared_ptr intel_cpu::CPUTargetMachine::clone() co

 size_t intel_cpu::CPUTargetMachine::get_lanes() const {
     switch (isa) {
-    case dnnl::impl::cpu::x64::avx2 : return dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float);
-    case dnnl::impl::cpu::x64::sse41 : return dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float);
-    case dnnl::impl::cpu::x64::avx512_core : return dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float);
-    default : OPENVINO_THROW("unknown isa ", isa);
+    case dnnl::impl::cpu::x64::avx2:
+        return dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float);
+    case dnnl::impl::cpu::x64::sse41:
+        return dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float);
+    case dnnl::impl::cpu::x64::avx512_core:
+        return dnnl::impl::cpu::x64::cpu_isa_traits::vlen / sizeof(float);
+    default:
+        OPENVINO_THROW("unknown isa ", isa);
     }
 }
@@ -315,13 +352,15 @@ snippets::CompiledSnippetPtr intel_cpu::CPUTargetMachine::get_snippet() {
     if (h->create_kernel() != dnnl::impl::status::success) {
         OPENVINO_THROW("Failed to create jit_kernel in get_snippet()");
     }
-    const auto& result = std::make_shared(std::unique_ptr(h.release()));
+    const auto& result =
+        std::make_shared(std::unique_ptr(h.release()));
     // Note that we reset all the generated code, since it was copied into CompiledSnippetCPU
     h.reset(new jit_snippet());
     return result;
 }

-intel_cpu::CompiledSnippetCPU::CompiledSnippetCPU(std::unique_ptr h) : h_compiled(std::move(h)) {
+intel_cpu::CompiledSnippetCPU::CompiledSnippetCPU(std::unique_ptr h)
+    : h_compiled(std::move(h)) {
     OPENVINO_ASSERT(h_compiled && h_compiled->jit_ker(), "Got invalid jit generator or kernel was not compiled");
 }
@@ -337,15 +376,14 @@ bool intel_cpu::CompiledSnippetCPU::empty() const {
     return get_code_size() == 0;
 }

-intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_, ov::intel_cpu::MultiCacheWeakPtr cache) :
-    Generator(std::make_shared(isa_, std::move(cache))) {
-}
-intel_cpu::CPUGenerator::CPUGenerator(const
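
// The get_lanes() arithmetic above is simply vector register width in bytes
// divided by sizeof(float): 16 / 4 = 4 for SSE4.1, 32 / 4 = 8 for AVX2, and
// 64 / 4 = 16 for AVX-512. Sketch with plain constants standing in for
// dnnl's cpu_isa_traits:

#include <cstddef>

enum class isa_sketch { sse41, avx2, avx512_core };
constexpr size_t vlen_bytes(isa_sketch isa) {
    return isa == isa_sketch::sse41 ? 16 : (isa == isa_sketch::avx2 ? 32 : 64);
}
constexpr size_t lanes(isa_sketch isa) {
    return vlen_bytes(isa) / sizeof(float);
}
static_assert(lanes(isa_sketch::sse41) == 4, "an xmm register holds 4 floats");
static_assert(lanes(isa_sketch::avx2) == 8, "a ymm register holds 8 floats");
static_assert(lanes(isa_sketch::avx512_core) == 16, "a zmm register holds 16 floats");
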
std::shared_ptr& target) : Generator(target) { -} +intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_, ov::intel_cpu::MultiCacheWeakPtr cache) + : Generator(std::make_shared(isa_, std::move(cache))) {} +intel_cpu::CPUGenerator::CPUGenerator(const std::shared_ptr& target) : Generator(target) {} std::shared_ptr intel_cpu::CPUGenerator::clone() const { const auto& cpu_target_machine = std::dynamic_pointer_cast(target->clone()); - OPENVINO_ASSERT(cpu_target_machine, "Failed to clone CPUGenerator: the instance contains incompatible TargetMachine type"); + OPENVINO_ASSERT(cpu_target_machine, + "Failed to clone CPUGenerator: the instance contains incompatible TargetMachine type"); return std::make_shared(cpu_target_machine); } @@ -358,12 +396,11 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons #endif std::dynamic_pointer_cast(op)) return ov::snippets::RegType::gpr; - else if ( - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) + else if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) return ov::snippets::RegType::vec; else - return ov::snippets::RegType::undefined; + return ov::snippets::RegType::undefined; } bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr& e) const { @@ -383,4 +420,4 @@ bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr h_compiled; + public: const uint8_t* get_code() const override; size_t get_code_size() const override; @@ -31,8 +30,7 @@ class CompiledSnippetCPU : public snippets::CompiledSnippet { class CPUTargetMachine : public snippets::TargetMachine { public: - explicit CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::intel_cpu::MultiCacheWeakPtr); + explicit CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t host_isa, ov::intel_cpu::MultiCacheWeakPtr); std::shared_ptr clone() const override; bool is_supported() const override; snippets::CompiledSnippetPtr get_snippet() override; @@ -60,5 +58,5 @@ class CPUGenerator : public snippets::Generator { bool uses_precompiled_kernel(const std::shared_ptr& emitter) const override; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp index 53d8fea05a8adf..6df658d8d72d0c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.cpp @@ -4,18 +4,15 @@ #include "jit_brgemm_copy_b_emitter.hpp" +#include +#include + #include "emitters/plugin/x64/utils.hpp" #include "emitters/snippets/x64/utils.hpp" - -#include "snippets/utils/utils.hpp" #include "snippets/lowered/expression.hpp" - +#include "snippets/utils/utils.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include -#include - - using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; @@ -34,7 +31,9 @@ bool get_is_transposed(const ov::snippets::lowered::ExpressionPtr& expr) { } } // namespace -jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr, +jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, + cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr, const snippets::KernelExecutorTablePtr& kernel_table, const ov::intel_cpu::MultiCacheWeakPtr& 
compiled_kernel_cache) : jit_emitter(h, isa) { @@ -57,17 +56,20 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t m_with_comp = with_compensations(brgemm_type); BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, primitive_isa, m_with_comp, is_transposed, wei_N_blk); - m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); + m_kernel_executor = + kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()}; - m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_output_port(0))}; + m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), + utils::get_buffer_cluster_id(expr->get_output_port(0))}; if (m_with_comp) { m_memory_offsets.push_back(brgemm_repack->get_offset_compensations()); m_buffer_ids.push_back(utils::get_buffer_cluster_id(expr->get_output_port(1))); } } -void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { +void jit_brgemm_copy_b_emitter::validate_arguments(const std::vector& in, + const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1, "expects 1 input"); OV_CPU_JIT_EMITTER_ASSERT((m_with_comp && out.size() == 2) || (!m_with_comp && out.size() == 1), "expects 2 outputs if there are compensations"); @@ -87,14 +89,20 @@ void jit_brgemm_copy_b_emitter::emit_impl(const std::vector& in, const s // Reserve memory on the stack h->sub(h->rsp, reserved_stack_size); - const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); + const bool is_dynamic_case = + std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value); Xbyak::Reg64 aux_reg = is_dynamic_case ? 
ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64(); - const std::vector args_offsets {GET_OFF_BRGEMM_COPY_B_ARGS(src), GET_OFF_BRGEMM_COPY_B_ARGS(tr_src), GET_OFF_BRGEMM_COPY_B_ARGS(compensation_ptr)}; + const std::vector args_offsets{GET_OFF_BRGEMM_COPY_B_ARGS(src), + GET_OFF_BRGEMM_COPY_B_ARGS(tr_src), + GET_OFF_BRGEMM_COPY_B_ARGS(compensation_ptr)}; const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs); for (size_t i = 0; i < mem_ptrs.size(); i++) { if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) - utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg, + utils::push_ptr_with_runtime_offset_on_stack(h, + args_offsets[i], + mem_ptrs[i], + aux_reg, GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t)); else utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]); @@ -116,5 +124,5 @@ void jit_brgemm_copy_b_emitter::emit_impl(const std::vector& in, const s spill.postamble(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp index ef53efe6081217..96a80153bba4b6 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp @@ -5,38 +5,40 @@ #pragma once #include "emitters/plugin/x64/jit_emitter.hpp" - #include "kernel_executors/brgemm_copy_b.hpp" - namespace ov { namespace intel_cpu { class jit_brgemm_copy_b_emitter : public jit_emitter { public: - jit_brgemm_copy_b_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_brgemm_copy_b_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr, const snippets::KernelExecutorTablePtr& kernel_table, const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); - size_t get_inputs_num() const override {return 1;} - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + size_t get_inputs_num() const override { + return 1; + } + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr) { return {{element::i8}, {element::bf16}, {element::f32}}; } private: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const std::vector& in, const std::vector& out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; std::vector m_memory_offsets{}; std::vector m_buffer_ids{}; - std::shared_ptr m_kernel_executor {nullptr}; - bool m_with_comp {false}; + std::shared_ptr m_kernel_executor{nullptr}; + bool m_with_comp{false}; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter); + friend std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter* emitter); #endif }; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp index 6e70cbf2e8fe81..172a1cc0b98284 100644 --- 
a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp @@ -4,13 +4,12 @@ #include "jit_brgemm_emitter.hpp" -#include "transformations/snippets/x64/op/brgemm_cpu.hpp" -#include "transformations/snippets/x64/op/brgemm_utils.hpp" +#include "emitters/plugin/x64/utils.hpp" #include "emitters/snippets/x64/kernel_executors/brgemm.hpp" #include "emitters/snippets/x64/kernel_executors/brgemm_amx.hpp" -#include "emitters/plugin/x64/utils.hpp" - #include "snippets/utils/utils.hpp" +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "transformations/snippets/x64/op/brgemm_utils.hpp" #include "utils.hpp" using namespace Xbyak; @@ -20,11 +19,12 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, +jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, + cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr, const snippets::KernelExecutorTablePtr& kernel_table, - const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) : - jit_emitter(h, isa) { + const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache) + : jit_emitter(h, isa) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; const auto& brgemm_node = as_type_ptr(expr->get_node()); const auto& brg0Prc = brgemm_node->get_input_element_type(0); @@ -33,20 +33,26 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa, m_is_with_amx = brgemm_utils::with_amx(brgemm_type); if (m_is_with_amx) { BrgemmAMXKernelConfig kernel_config(brg0Prc, brg1Prc, brgemm_utils::get_primitive_isa(brg0Prc, true)); - m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); + m_kernel_executor = + kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); } else { - BrgemmKernelConfig kernel_config(brg0Prc, brg1Prc, with_compensations(brgemm_type), brgemm_utils::get_primitive_isa(brg0Prc, false)); - m_kernel_executor = kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); + BrgemmKernelConfig kernel_config(brg0Prc, + brg1Prc, + with_compensations(brgemm_type), + brgemm_utils::get_primitive_isa(brg0Prc, false)); + m_kernel_executor = + kernel_table->register_kernel(expr, compiled_kernel_cache, kernel_config); } // Note: even if the Brgemm node is dynamic, the first shapeInfer and RuntimeConfigurator::update() // are performed before the BrgemmKernelExecutor registration. So we have to trigger update() manually // for both static and the 1st dynamic shapes. 
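
// The kernel-executor-table pattern referenced above: each Brgemm expression
// registers an executor keyed by its kernel config, compilation happens on
// first registration, and later update() calls recompile only when the config
// actually changes. A minimal sketch; every name here is invented rather than
// the real snippets API:

#include <map>
#include <memory>
#include <string>

struct executor_sketch {
    explicit executor_sketch(std::string cfg) : config(std::move(cfg)) {}
    void update(const std::string& new_cfg) {
        if (new_cfg != config) {
            config = new_cfg;  // stands in for kernel recompilation
        }
    }
    std::string config;
};

inline std::shared_ptr<executor_sketch> register_kernel(
    std::map<std::string, std::shared_ptr<executor_sketch>>& table,
    const std::string& expr_id,
    const std::string& config) {
    auto& slot = table[expr_id];
    if (!slot)
        slot = std::make_shared<executor_sketch>(config);  // compile once, on first registration
    return slot;                                           // later calls for the same expression reuse it
}
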
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()) &&
-                                  !snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
+                                  !snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(1)->get_shape()),
                              "Jit emitter is called when the shapes are unknown");

     m_memory_offsets = {brgemm_node->get_offset_a(), brgemm_node->get_offset_b(), brgemm_node->get_offset_c()};
-    m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_input_port(1)),
+    m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)),
+                    utils::get_buffer_cluster_id(expr->get_input_port(1)),
                     utils::get_buffer_cluster_id(expr->get_output_port(0))};
     if (with_scratchpad(brgemm_type)) {
         m_memory_offsets.push_back(brgemm_node->get_offset_scratch());
@@ -54,7 +60,8 @@ jit_brgemm_emitter::jit_brgemm_emitter(jit_generator* h, cpu_isa_t isa,
     }
 }

-std::set> jit_brgemm_emitter::get_supported_precisions(const std::shared_ptr& node) {
+std::set> jit_brgemm_emitter::get_supported_precisions(
+    const std::shared_ptr& node) {
     const auto brgemm = as_type_ptr(node);
     OV_CPU_JIT_EMITTER_ASSERT(brgemm, "get_supported_precisions() expects BrgemmCPU node");
     using brgemm_utils::BRGEMM_TYPE;
@@ -77,7 +84,7 @@ std::set> jit_brgemm_emitter::get_supported_precision
     OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
 }

-void jit_brgemm_emitter::validate_arguments(const std::vector &in, const std::vector &out) const {
+void jit_brgemm_emitter::validate_arguments(const std::vector& in, const std::vector& out) const {
     OV_CPU_JIT_EMITTER_ASSERT(m_memory_offsets.size() == in.size() + 1 && (out.size() == 1),
                               "expects 3 inputs if there are compensations/wsp");
 }
@@ -96,8 +103,7 @@ void jit_brgemm_emitter::emit_impl(const std::vector& in, const std::vec
     OV_CPU_JIT_EMITTER_THROW("unknown executor type");
 }

-template::value, bool>::type>
+template ::value, bool>::type>
 void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) const {
     EmitABIRegSpills spill(h);
     spill.preamble();
@@ -107,17 +113,24 @@ void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) con
     // Reserve memory on the stack
     h->sub(h->rsp, reserved_stack_size);

-    const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value);
+    const bool is_dynamic_case =
+        std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value);
     Xbyak::Reg64 aux_reg = is_dynamic_case ?
ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64(); #define GET_OFF_CALL_ARGS(field) offsetof(typename T::call_args, field) - const std::vector brgemm_args_offsets = { GET_OFF_CALL_ARGS(A), GET_OFF_CALL_ARGS(B), GET_OFF_CALL_ARGS(C), GET_OFF_CALL_ARGS(scratch) }; + const std::vector brgemm_args_offsets = {GET_OFF_CALL_ARGS(A), + GET_OFF_CALL_ARGS(B), + GET_OFF_CALL_ARGS(C), + GET_OFF_CALL_ARGS(scratch)}; #undef GET_OFF_CALL_ARGS const auto& mem_ptrs = utils::transform_idxs_to_regs(mem_ptrs_idxs); for (size_t i = 0; i < mem_ptrs.size(); i++) { if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i])) - utils::push_ptr_with_runtime_offset_on_stack(h, brgemm_args_offsets[i], mem_ptrs[i], aux_reg, + utils::push_ptr_with_runtime_offset_on_stack(h, + brgemm_args_offsets[i], + mem_ptrs[i], + aux_reg, GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t)); else utils::push_ptr_with_static_offset_on_stack(h, brgemm_args_offsets[i], mem_ptrs[i], m_memory_offsets[i]); @@ -145,5 +158,5 @@ void jit_brgemm_emitter::emit_call(const std::vector& mem_ptrs_idxs) con spill.postamble(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp index ccec1b68b18b20..9d072065c0fe52 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.hpp @@ -12,34 +12,39 @@ namespace intel_cpu { class jit_brgemm_emitter : public jit_emitter { public: - jit_brgemm_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_brgemm_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr, const snippets::KernelExecutorTablePtr& kernel_table, const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); - size_t get_inputs_num() const override { return m_memory_offsets.size() - 1; } - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); + size_t get_inputs_num() const override { + return m_memory_offsets.size() - 1; + } + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr); private: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const std::vector& in, const std::vector& out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; template ::value, bool>::type = true> void emit_call(const std::vector& mem_ptrs_idxs) const; - // Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in runtime + // Note: offsets order: A, B, C (+ scratchpad, if needed). Values can be dynamic_value if offset is calculated in + // runtime std::vector m_memory_offsets{}; // Note: cluster ids order: A, B, C (+ scratchpad, if needed). 
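
// "dynamic_value" in the comment above is a sentinel: an offset unknown at
// compile time is stored as a reserved marker and resolved from the runtime
// call args instead of being baked into the kernel. Sketch of the pattern;
// the sentinel constant and helper names are assumptions, not the exact
// snippets utilities:

#include <cstddef>
#include <cstdint>
#include <limits>

constexpr int64_t dynamic_value_sketch = std::numeric_limits<int64_t>::max();
constexpr bool is_dynamic_value_sketch(int64_t v) {
    return v == dynamic_value_sketch;
}
inline int64_t resolve_offset(int64_t compiled_offset, const int64_t* runtime_offsets, size_t buffer_id) {
    // Static offsets come from compilation; dynamic ones are read on every kernel call.
    return is_dynamic_value_sketch(compiled_offset) ? runtime_offsets[buffer_id] : compiled_offset;
}
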
Values can be dynamic_value if there is no buffer std::vector m_buffer_ids{}; std::shared_ptr m_kernel_executor = nullptr; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter); + friend std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter* emitter); #endif - bool m_is_with_amx {false}; + bool m_is_with_amx{false}; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.cpp index 05b9d15786157b..45ebfc83899dba 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.cpp @@ -4,9 +4,11 @@ #ifdef SNIPPETS_DEBUG_CAPS -#include "jit_debug_emitter.hpp" -#include -#include "utils/general_utils.h" +# include "jit_debug_emitter.hpp" + +# include + +# include "utils/general_utils.h" using namespace dnnl::impl::cpu; using namespace dnnl::impl; @@ -27,8 +29,10 @@ size_t jit_debug_emitter::aux_gprs_count() const { return m_target_emitter->aux_gprs_count(); } -void jit_debug_emitter::emitter_preamble(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_debug_emitter::emitter_preamble(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { m_target_emitter->emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs); } @@ -52,12 +56,14 @@ void jit_debug_emitter::register_table_entries() { m_target_emitter->register_table_entries(); } -void jit_debug_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_debug_emitter::emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const { m_target_emitter->emit_impl(in_idxs, out_idxs); } -void jit_debug_emitter::emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_debug_emitter::emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { if (m_decorator_emit_loc == EmissionLocation::preamble || m_decorator_emit_loc == EmissionLocation::both) m_decorator_emitter->emit_code(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs); @@ -67,7 +73,7 @@ void jit_debug_emitter::emit_code(const std::vector &in_idxs, const std: m_decorator_emitter->emit_code(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov #endif \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.hpp index fe7cc527418587..2591af119cc3b5 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_debug_emitter.hpp @@ -4,29 +4,33 @@ #ifdef SNIPPETS_DEBUG_CAPS -#pragma once - -#include "emitters/plugin/x64/jit_emitter.hpp" +# pragma once +# include "emitters/plugin/x64/jit_emitter.hpp" namespace ov { namespace intel_cpu { class jit_debug_emitter : public jit_emitter { public: - enum class EmissionLocation { - preamble, - postamble, - both - }; - 
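
// jit_debug_emitter is a decorator: depending on EmissionLocation it emits
// the decorator's code before the target, after it, or on both sides. Plain
// sketch of that ordering with std::function stand-ins for the emitter
// interface:

#include <functional>

enum class emission_location_sketch { preamble, postamble, both };

inline void emit_with_decorator(emission_location_sketch loc,
                                const std::function<void()>& target,
                                const std::function<void()>& decorator) {
    if (loc == emission_location_sketch::preamble || loc == emission_location_sketch::both)
        decorator();  // e.g. install segfault-detector bookkeeping before the payload
    target();
    if (loc == emission_location_sketch::postamble || loc == emission_location_sketch::both)
        decorator();
}
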
jit_debug_emitter(const std::shared_ptr& target_emitter, const std::shared_ptr& decorator_emitter, const EmissionLocation& loc) - : jit_emitter(target_emitter->h, target_emitter->host_isa_, target_emitter->exec_prc_, target_emitter->in_out_type_), - m_target_emitter(target_emitter), m_decorator_emitter(decorator_emitter), m_decorator_emit_loc(loc) { - prepare_table(); - } - - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + enum class EmissionLocation { preamble, postamble, both }; + jit_debug_emitter(const std::shared_ptr& target_emitter, + const std::shared_ptr& decorator_emitter, + const EmissionLocation& loc) + : jit_emitter(target_emitter->h, + target_emitter->host_isa_, + target_emitter->exec_prc_, + target_emitter->in_out_type_), + m_target_emitter(target_emitter), + m_decorator_emitter(decorator_emitter), + m_decorator_emit_loc(loc) { + prepare_table(); + } + + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; void emit_data() const override; size_t get_inputs_num() const override; @@ -38,10 +42,12 @@ class jit_debug_emitter : public jit_emitter { void prepare_table() override; void register_table_entries() override; - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; - void emitter_preamble(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const override; + void emitter_preamble(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const override; void emitter_postamble() const override; private: @@ -54,7 +60,7 @@ class jit_debug_emitter : public jit_emitter { EmissionLocation m_decorator_emit_loc; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov #endif \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.cpp index 1c05100317ae5f..687917acbabc5a 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.cpp @@ -4,16 +4,15 @@ #include "jit_fill_emitter.hpp" - using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; - namespace ov { namespace intel_cpu { -jit_fill_emitter::jit_fill_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, +jit_fill_emitter::jit_fill_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) : jit_emitter(h, isa, ov::element::f32, emitter_in_out_map::vec_to_vec) { const auto fill = ov::as_type_ptr(expr->get_node()); @@ -52,9 +51,9 @@ void jit_fill_emitter::emit_impl(const std::vector& in, const std::vecto } template -void jit_fill_emitter::emit_isa(const std::vector &in, const std::vector &out) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; +void jit_fill_emitter::emit_isa(const std::vector& in, const std::vector& out) const { + using Vmm = typename dnnl::impl::utils:: + conditional3::type; Vmm src_vmm = Vmm(in[0]); Vmm 
dst_vmm = Vmm(out[0]); @@ -62,7 +61,8 @@ void jit_fill_emitter::emit_isa(const std::vector &in, const std::vector const size_t supported_et_size = 4; const auto register_capacity = (src_vmm.getBit() / 8) / supported_et_size; if (offset == register_capacity) { - // WA: since AssignRegisters doesn't support inplace logic, Fill ops with offset = register_capacity can't be removed from the LIR + // WA: since AssignRegisters doesn't support inplace logic, Fill ops with offset = register_capacity can't be + // removed from the LIR // TODO: when inplace is supported, remove such Fill ops from the LIR and remove this logic. // Ticket: 126270 if (src_vmm.getIdx() != dst_vmm.getIdx()) @@ -105,5 +105,5 @@ void jit_fill_emitter::fill_tail(const Vmm& src_vmm, const Vmm& dst_vmm) const { } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.hpp index 79e9a0e4027a5d..23b929cc161ca7 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_fill_emitter.hpp @@ -6,15 +6,18 @@ #include "emitters/plugin/x64/jit_emitter.hpp" - namespace ov { namespace intel_cpu { class jit_fill_emitter : public jit_emitter { public: - jit_fill_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + jit_fill_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 1;} + size_t get_inputs_num() const override { + return 1; + } protected: size_t aux_gprs_count() const override; @@ -23,18 +26,22 @@ class jit_fill_emitter : public jit_emitter { void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; template void fill_full(const Vmm& vmm_dst) const; template void fill_tail(const Vmm& vmm_src, const Vmm& vmm_dst) const; - bool is_full_reg() const { return offset == 0; } - bool is_optimized() const { return is_full_reg() && fill_value == uint32_t(0x0); } + bool is_full_reg() const { + return offset == 0; + } + bool is_optimized() const { + return is_full_reg() && fill_value == uint32_t(0x0); + } size_t offset = 0; uint32_t fill_value = 0x0; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.cpp index a4f5cbe16d7e1f..34e9c2f71fd148 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.cpp @@ -4,7 +4,6 @@ #include "jit_horizon_emitter.hpp" - using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; @@ -12,7 +11,8 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -jit_horizon_emitter::jit_horizon_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, +jit_horizon_emitter::jit_horizon_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const 
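
// conditional3 above picks the vector register type from the ISA at compile
// time (sse41: xmm, avx2: ymm, otherwise zmm). An equivalent sketch built
// from nested std::conditional with stand-in register tags:

#include <type_traits>

struct xmm_tag {};
struct ymm_tag {};
struct zmm_tag {};
enum class isa_tag { sse41, avx2, avx512_core };

template <isa_tag isa>
using vmm_t = typename std::conditional<
    isa == isa_tag::sse41,
    xmm_tag,
    typename std::conditional<isa == isa_tag::avx2, ymm_tag, zmm_tag>::type>::type;

static_assert(std::is_same<vmm_t<isa_tag::avx2>, ymm_tag>::value, "AVX2 kernels operate on ymm");
static_assert(std::is_same<vmm_t<isa_tag::avx512_core>, zmm_tag>::value, "AVX-512 kernels operate on zmm");
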
ov::snippets::lowered::ExpressionPtr& expr) : jit_emitter(h, isa, ov::element::f32, emitter_in_out_map::vec_to_vec) { if (ov::is_type(expr->get_node())) { @@ -24,8 +24,7 @@ jit_horizon_emitter::jit_horizon_emitter(dnnl::impl::cpu::x64::jit_generator* h, } } -void jit_horizon_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_horizon_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::x64::sse41) { emit_isa(in, out); } else if (host_isa_ == dnnl::impl::cpu::x64::avx2) { @@ -38,9 +37,12 @@ void jit_horizon_emitter::emit_impl(const std::vector& in, } template -void jit_horizon_emitter::emit_isa(const std::vector &in, const std::vector &out) const { +void jit_horizon_emitter::emit_isa(const std::vector& in, const std::vector& out) const { using Vmm = typename dnnl::impl::utils::conditional3::type; + Xbyak::Xmm, + isa == dnnl::impl::cpu::x64::avx2, + Xbyak::Ymm, + Xbyak::Zmm>::type; Vmm src_vmm = Vmm(in[0]); Vmm dst_vmm = Vmm(out[0]); @@ -67,19 +69,19 @@ void jit_horizon_emitter::emit_isa(const std::vector &in, const std::vec perform_op(dst_vmm, dst_vmm, aux_vmm); } -template -void jit_horizon_emitter::perform_op(const Vmm &vmm1, const Vmm &vmm2, const Vmm &vmm3) const { +template +void jit_horizon_emitter::perform_op(const Vmm& vmm1, const Vmm& vmm2, const Vmm& vmm3) const { switch (m_op_type) { - case OpType::max: - h->uni_vmaxps(vmm1, vmm2, vmm3); - break; - case OpType::sum: - h->uni_vaddps(vmm1, vmm2, vmm3); - break; - default: - OV_CPU_JIT_EMITTER_THROW("Unsupported horizontal operation."); + case OpType::max: + h->uni_vmaxps(vmm1, vmm2, vmm3); + break; + case OpType::sum: + h->uni_vaddps(vmm1, vmm2, vmm3); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported horizontal operation."); } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.hpp index 1b222cb2a86776..df74b2ad9783a4 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_horizon_emitter.hpp @@ -6,34 +6,40 @@ #include "emitters/plugin/x64/jit_emitter.hpp" - namespace ov { namespace intel_cpu { class jit_horizon_emitter : public jit_emitter { public: - jit_horizon_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + jit_horizon_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 1;} - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + size_t get_inputs_num() const override { + return 1; + } + static std::set> get_supported_precisions( + const std::shared_ptr& node = nullptr) { return {{element::f32}}; } protected: - size_t aux_vecs_count() const override {return 1;} + size_t aux_vecs_count() const override { + return 1; + } private: void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; - template - void perform_op(const Vmm &vmm1, const Vmm &vmm2, const Vmm &vmm3) const; + template + void perform_op(const Vmm& vmm1, const Vmm& vmm2, 
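
// A horizon (horizontal) operation folds every lane of one register into a
// single value through repeated shuffles plus perform_op. Scalar reference of
// the two supported flavors, for intuition only:

#include <algorithm>
#include <vector>

enum class op_type_sketch { max, sum };

inline float horizon_reference(const std::vector<float>& lanes, op_type_sketch op) {
    float acc = (op == op_type_sketch::max) ? lanes.at(0) : 0.0f;
    for (float v : lanes)
        acc = (op == op_type_sketch::max) ? std::max(acc, v) : acc + v;
    return acc;
}
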
const Vmm& vmm3) const; enum class OpType { max, sum }; OpType m_op_type = OpType::max; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp index 476123355abe70..bd5a3227e1e125 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_kernel_emitter.cpp @@ -14,8 +14,11 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), reg_runtime_params_idx(abi_param1.getIdx()) { +jit_kernel_emitter::jit_kernel_emitter(jit_generator* h, + cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr) + : jit_emitter(h, isa), + reg_runtime_params_idx(abi_param1.getIdx()) { const auto kernel = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "invoked with invalid op argument"); OV_CPU_JIT_EMITTER_ASSERT(!kernel->region->empty(), "invoked with empty body"); @@ -59,8 +62,12 @@ void jit_kernel_emitter::init_reg_pools(const std::set& gpr_blacklist, c gp_regs_pool[i] = vec_regs_pool[i] = 15 - i; auto remove_regs_from_pool = [](std::vector& pool, const std::set& to_remove) { // It's important to keep the order of other elements - pool.erase(std::remove_if(pool.begin(), pool.end(), - [&](size_t x) {return to_remove.count(x) != 0;}), pool.end()); + pool.erase(std::remove_if(pool.begin(), + pool.end(), + [&](size_t x) { + return to_remove.count(x) != 0; + }), + pool.end()); }; // Reserve stack base and pointer for push(...) and pop(...) 
operations
     std::set gprs_blacklist_extended{Xbyak::Operand::RSP, Xbyak::Operand::RBP};
@@ -70,25 +77,31 @@
     remove_regs_from_pool(vec_regs_pool, vec_blacklist);
 }

-void jit_kernel_emitter::emit_code(const std::vector &in, const std::vector &out,
-                                   const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const {
+void jit_kernel_emitter::emit_code(const std::vector& in,
+                                   const std::vector& out,
+                                   const std::vector& pool_vec_idxs,
+                                   const std::vector& pool_gpr_idxs) const {
     validate_arguments(in, out);
     emit_impl(in, out);
 }

-void jit_kernel_emitter::validate_arguments(const std::vector &in, const std::vector &out) const {
+void jit_kernel_emitter::validate_arguments(const std::vector& in, const std::vector& out) const {
     OV_CPU_JIT_EMITTER_ASSERT(in.empty() && out.empty(), ": expects 0 registers on input and output");
     const auto num_params = num_inputs + num_outputs + num_unique_buffers;
     // The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount
     OV_CPU_JIT_EMITTER_ASSERT(data_ptr_regs_idx.size() == num_params,
-                              "number of inputs and outputs is inconsistent with the number of allocated registers ", num_params,
-                              " data_ptr_regs_idx.size() = ", data_ptr_regs_idx.size());
+                              "number of inputs and outputs is inconsistent with the number of allocated registers ",
+                              num_params,
+                              " data_ptr_regs_idx.size() = ",
+                              data_ptr_regs_idx.size());
 }

 void jit_kernel_emitter::init_body_regs(const std::set& kernel_regs,
-                                        const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) {
+                                        const std::vector& pool_vec_idxs,
+                                        const std::vector& pool_gpr_idxs) {
     // Initialize pools of gp and vec registers
-    // Reserve kernel regs (abi_param1 and, if there is, abi_param2), since they'll be used to pass runtime call args to kernel
+    // Reserve kernel regs (abi_param1 and, if there is, abi_param2), since they'll be used to pass runtime call args to
+    // kernel
     init_reg_pools(kernel_regs, {});

     mapping_info gpr_map_pool({}, gp_regs_pool);
@@ -122,9 +135,11 @@ void jit_kernel_emitter::emit_impl(const std::vector& in, const std::vec
     h->postamble();
 }

-jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
+jit_kernel_static_emitter::jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_generator* h,
+                                                     dnnl::impl::cpu::x64::cpu_isa_t isa,
                                                      const ov::snippets::lowered::ExpressionPtr& expr)
-    : jit_kernel_emitter(h, isa, expr), reg_indexes_idx(abi_param2.getIdx()) {
+    : jit_kernel_emitter(h, isa, expr),
+      reg_indexes_idx(abi_param2.getIdx()) {
     const auto kernel = ov::as_type_ptr(expr->get_node());
     OV_CPU_JIT_EMITTER_ASSERT(kernel != nullptr, "expects KernelStatic expression");
     jcp = *reinterpret_cast(kernel->compile_params);
@@ -158,12 +173,12 @@ void jit_kernel_static_emitter::init_data_pointers(const std::vector(*spare_corruptable_gpr));
+    Reg64 reg_tmp =
+        last_iter_explicitly ? data_ptr_regs[num_params - 1] : Reg64(static_cast(*spare_corruptable_gpr));
     // Vector "data_ptr_regs" is sorted by abstract regs.
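
// remove_regs_from_pool above relies on the erase-remove idiom: remove_if
// compacts the surviving registers to the front in their original order and
// erase trims the tail, which is exactly why the comment stresses keeping the
// element order. Standalone sketch:

#include <algorithm>
#include <cstddef>
#include <set>
#include <vector>

inline void remove_from_pool(std::vector<size_t>& pool, const std::set<size_t>& blacklist) {
    pool.erase(std::remove_if(pool.begin(),
                              pool.end(),
                              [&](size_t reg) {
                                  return blacklist.count(reg) != 0;
                              }),
               pool.end());
}
// E.g. filtering {15, 14, 13} with blacklist {14} leaves {15, 13}, order preserved.
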
// It means that the vector contains the physical registers in order [src, .., src, dst, .., dst, buffer] // So we can initialize the buffer register first, as the last value of vector "data_ptr_regs" @@ -193,13 +208,15 @@ void jit_kernel_static_emitter::init_data_pointers(const std::vector(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(kernel, "expects KernelDynamic expression"); - // - Reserve abi_param1, since it wll be used to pass runtime call args to all dynamic emitters that needs runtime args + // - Reserve abi_param1, since it will be used to pass runtime call args to all dynamic emitters that need runtime + // args // - We cannot assign this register to the body emitters since runtime params MUST be valid during whole execution // for all dynamic emitters init_body_regs({reg_runtime_params_idx}); @@ -220,5 +237,5 @@ void jit_kernel_dynamic_emitter::init_data_pointers(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + jit_kernel_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); + + size_t get_inputs_num() const override { + return 0; + } + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; protected: void validate_arguments(const std::vector& in, const std::vector& out) const override; - void init_body_regs(const std::set& kernel_regs, const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}); + void init_body_regs(const std::set& kernel_regs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}); /** - * @brief populates physical registers pools for x86 (both vec and gp). + * @brief populates physical register pools for x86 (both vec and gp). * Skips stack-related gprs and extra gprs passed as arguments.
* @arg gpr_blacklist - set of gp registers that should not be added to register pool * @arg vec_blacklist - set of vec registers that should not be added to register pool - */ + */ void init_reg_pools(const std::set& gpr_blacklist, const std::set& vec_blacklist); virtual void init_data_pointers(const std::vector& data_ptr_regs) const = 0; @@ -70,13 +77,15 @@ class jit_kernel_emitter : public jit_emitter, public jit_container_emitter { std::shared_ptr body; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_kernel_emitter(const jit_kernel_emitter *emitter); + friend std::string init_info_jit_kernel_emitter(const jit_kernel_emitter* emitter); #endif }; class jit_kernel_static_emitter : public jit_kernel_emitter { public: - jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + jit_kernel_static_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); private: void init_data_pointers(const std::vector& data_ptr_regs) const override; @@ -86,21 +95,23 @@ class jit_kernel_static_emitter : public jit_kernel_emitter { std::vector> data_offsets; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter *emitter); + friend std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter* emitter); #endif }; class jit_kernel_dynamic_emitter : public jit_kernel_emitter { public: - jit_kernel_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); + jit_kernel_dynamic_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); private: void init_data_pointers(const std::vector& data_ptr_regs) const override; #ifdef SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter *emitter); + friend std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter* emitter); #endif }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp index f3151d0df4ccb1..86421678a29011 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.cpp @@ -18,8 +18,11 @@ namespace intel_cpu { namespace { class jit_aux_gpr_holder { public: - jit_aux_gpr_holder(dnnl::impl::cpu::x64::jit_generator* host, std::vector& pool_gpr_idxs, const std::vector& used_gpr_idxs) - : m_h(host), m_pool_gpr_idxs(pool_gpr_idxs) { + jit_aux_gpr_holder(dnnl::impl::cpu::x64::jit_generator* host, + std::vector& pool_gpr_idxs, + const std::vector& used_gpr_idxs) + : m_h(host), + m_pool_gpr_idxs(pool_gpr_idxs) { // If the pool is empty, let's manually allocate the gpr and push the original value on the stack if (m_pool_gpr_idxs.empty()) { m_aux_gpr_idx = ov::intel_cpu::utils::get_aux_gpr(used_gpr_idxs); @@ -39,21 +42,26 @@ class jit_aux_gpr_holder { } } - const Reg64& get_reg() const { return m_aux_gpr_idx; } + const Reg64& get_reg() const { + return m_aux_gpr_idx; + } private: dnnl::impl::cpu::x64::jit_generator* m_h; std::vector& m_pool_gpr_idxs; - Reg64 m_aux_gpr_idx {}; + Reg64
m_aux_gpr_idx{}; bool m_is_preserved = false; }; } // namespace /* ================== jit_loop_begin_emitter ====================== */ -jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, +jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), loop_begin_label{new Xbyak::Label()}, loop_end_label(nullptr) { + : jit_emitter(h, isa), + loop_begin_label{new Xbyak::Label()}, + loop_end_label(nullptr) { const auto loop_begin = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(loop_begin, "expects LoopBegin expression"); const auto loop_end = loop_begin->get_loop_end(); @@ -65,7 +73,7 @@ jit_loop_begin_emitter::jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generat in_out_type_ = emitter_in_out_map::gpr_to_gpr; } -void jit_loop_begin_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { +void jit_loop_begin_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { OV_CPU_JIT_EMITTER_ASSERT(in.empty(), "Invalid inputs size: expected 0 got " + std::to_string(in.size())); // Note: the only expected output is the work amount register (communicated to jit_loop_end_emitter) OV_CPU_JIT_EMITTER_ASSERT(out.size() == 1, "Invalid outputs size: expected 1 got " + std::to_string(out.size())); @@ -74,21 +82,24 @@ void jit_loop_begin_emitter::validate_arguments(const std::vector &in, c "loop increment might be dynamic only if loop evaluates once!"); } -void jit_loop_begin_emitter::emit_code(const std::vector &in, const std::vector &out, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_loop_begin_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); jit_emitter::emit_code(in, out, pool_vec_idxs, pool_gpr_idxs); } void jit_loop_begin_emitter::emit_impl(const std::vector& in, const std::vector& out) const { // If the loop evaluates once, we can skip loop begin code emission - // If work_amount is dynamic, we should get runtime `work_amount` - it might be `zero` and we should skip loop evaluation + // If work_amount is dynamic, we should get runtime `work_amount` - it might be `zero` and we should skip loop + // evaluation if (evaluate_once && !is_work_amount_dynamic) return; Reg64 reg_work_amount = Reg64(static_cast(out.back())); if (is_work_amount_dynamic) { - jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, out); // loop_begin has only output registers + jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, out); // loop_begin has only output registers Reg64 reg_loop_args_ptr = gpr_holder.get_reg(); const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t); h->mov(reg_loop_args_ptr, h->ptr[abi_param1 + GET_OFF(loop_args)]); @@ -113,9 +124,12 @@ void jit_loop_begin_emitter::emit_impl(const std::vector& in, const std: /* ================== jit_loop_end_emitter ====================== */ -jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, +jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr) - : jit_emitter(h, isa), loop_begin_label{nullptr}, loop_end_label{new Xbyak::Label()} { + :
jit_emitter(h, isa), + loop_begin_label{nullptr}, + loop_end_label{new Xbyak::Label()} { in_out_type_ = emitter_in_out_map::gpr_to_gpr; const auto loop_end = ov::as_type_ptr(expr->get_node()); OV_CPU_JIT_EMITTER_ASSERT(loop_end != nullptr, "expected LoopEnd expr"); @@ -132,8 +146,9 @@ jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::x64::jit_generator* are_ptr_increments_dynamic = std::any_of(ptr_increments.cbegin(), ptr_increments.cend(), ov::snippets::utils::is_dynamic_value); - are_final_offsets_dynamic = - std::any_of(finalization_offsets.cbegin(), finalization_offsets.cend(), ov::snippets::utils::is_dynamic_value); + are_final_offsets_dynamic = std::any_of(finalization_offsets.cbegin(), + finalization_offsets.cend(), + ov::snippets::utils::is_dynamic_value); are_ptr_shifts_dynamic = are_ptr_increments_dynamic || are_final_offsets_dynamic; const auto begin_expr = get_loop_begin_expr(expr); @@ -143,29 +158,51 @@ jit_loop_end_emitter::jit_loop_end_emitter(dnnl::impl::cpu::x64::jit_generator* loop_begin_label = loop_begin_emitter->get_begin_label(); } -ov::snippets::lowered::ExpressionPtr jit_loop_end_emitter::get_loop_begin_expr(const ov::snippets::lowered::ExpressionPtr& expr) { +ov::snippets::lowered::ExpressionPtr jit_loop_end_emitter::get_loop_begin_expr( + const ov::snippets::lowered::ExpressionPtr& expr) { const auto begin_expr = expr->get_input_port_connectors().back()->get_source().get_expr(); OV_CPU_JIT_EMITTER_ASSERT(ov::is_type(begin_expr->get_node()), "LoopEnd expression must have the last port connector to LoopBegin"); return begin_expr; } -void jit_loop_end_emitter::validate_arguments(const std::vector &in, const std::vector &out) const { +void jit_loop_end_emitter::validate_arguments(const std::vector& in, const std::vector& out) const { const auto io_size = num_inputs + num_outputs; OV_CPU_JIT_EMITTER_ASSERT(out.size() == 0, "Invalid number of out arguments: expected ", 0, " got ", out.size()); - OV_CPU_JIT_EMITTER_ASSERT(in.size() == io_size + 1, "Invalid number of in arguments: expected ", io_size + 1, " got ", in.size()); - OV_CPU_JIT_EMITTER_ASSERT(is_incremented.size() == io_size, "Invalid is_incremented size: expected ", io_size, " got ", is_incremented.size()); - OV_CPU_JIT_EMITTER_ASSERT(ptr_increments.size() == io_size, "Invalid ptr_increments size: expected ", io_size, " got ", ptr_increments.size()); + OV_CPU_JIT_EMITTER_ASSERT(in.size() == io_size + 1, + "Invalid number of in arguments: expected ", + io_size + 1, + " got ", + in.size()); + OV_CPU_JIT_EMITTER_ASSERT(is_incremented.size() == io_size, + "Invalid is_incremented size: expected ", + io_size, + " got ", + is_incremented.size()); + OV_CPU_JIT_EMITTER_ASSERT(ptr_increments.size() == io_size, + "Invalid ptr_increments size: expected ", + io_size, + " got ", + ptr_increments.size()); OV_CPU_JIT_EMITTER_ASSERT(finalization_offsets.size() == io_size, - "Invalid finalization_offsets size: expected: ", io_size, " got ", finalization_offsets.size()); - OV_CPU_JIT_EMITTER_ASSERT(data_sizes.size() == io_size, "Invalid data_sizes size: expected: ", io_size, " got ", data_sizes.size()); + "Invalid finalization_offsets size: expected: ", + io_size, + " got ", + finalization_offsets.size()); + OV_CPU_JIT_EMITTER_ASSERT(data_sizes.size() == io_size, + "Invalid data_sizes size: expected: ", + io_size, + " got ", + data_sizes.size()); OV_CPU_JIT_EMITTER_ASSERT(loop_end_label != nullptr && loop_begin_label != nullptr, "labels are not initialized!");
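// Presumably the work-amount step must be known at compile time whenever the loop actually iterates, since the emitted counter update uses it as an immediate; with evaluate_once there is no backward jump, so a dynamic increment is tolerated. The check below mirrors the identical assertion in jit_loop_begin_emitter::validate_arguments.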
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_value(wa_increment) || evaluate_once, "loop increment might be dynamic only if loop evaluates once!"); } -void jit_loop_end_emitter::emit_code(const std::vector &in, const std::vector &out, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_loop_end_emitter::emit_code(const std::vector& in, + const std::vector& out, + const std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { validate_arguments(in, out); jit_emitter::emit_code(in, out, pool_vec_idxs, pool_gpr_idxs); } @@ -176,34 +213,38 @@ void jit_loop_end_emitter::emit_impl(const std::vector& in, const std::v data_ptr_reg_idxs.reserve(num_inputs + num_outputs); std::copy(in.begin(), in.end() - 1, std::back_inserter(data_ptr_reg_idxs)); - auto apply_increments = [&](bool use_runtime_args, size_t field_offset, const std::vector& increments, size_t scale) { - Reg64 reg_increments; - auto add_increments = [&]() { - for (size_t idx = 0; idx < data_ptr_reg_idxs.size(); idx++) { - const auto& increment = increments[idx]; - if (is_incremented[idx] && increment != 0) { - if (ov::snippets::utils::is_dynamic_value(increment)) { - OV_CPU_JIT_EMITTER_ASSERT(use_runtime_args, "Loop argument structure cannot be pushed to aux GPR"); - h->add(Reg64(static_cast(data_ptr_reg_idxs[idx])), h->ptr[reg_increments + idx * sizeof(int64_t)]); - } else { - h->add(Reg64(static_cast(data_ptr_reg_idxs[idx])), increment * scale * data_sizes[idx]); + auto apply_increments = + [&](bool use_runtime_args, size_t field_offset, const std::vector& increments, size_t scale) { + Reg64 reg_increments; + auto add_increments = [&]() { + for (size_t idx = 0; idx < data_ptr_reg_idxs.size(); idx++) { + const auto& increment = increments[idx]; + if (is_incremented[idx] && increment != 0) { + if (ov::snippets::utils::is_dynamic_value(increment)) { + OV_CPU_JIT_EMITTER_ASSERT(use_runtime_args, + "Loop argument structure cannot be pushed to aux GPR"); + h->add(Reg64(static_cast(data_ptr_reg_idxs[idx])), + h->ptr[reg_increments + idx * sizeof(int64_t)]); + } else { + h->add(Reg64(static_cast(data_ptr_reg_idxs[idx])), + increment * scale * data_sizes[idx]); + } } } + }; + + const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t); + if (use_runtime_args) { + jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, in); // loop_end has only input registers + reg_increments = gpr_holder.get_reg(); + h->mov(reg_increments, h->ptr[abi_param1 + GET_OFF(loop_args)]); + h->mov(reg_increments, h->ptr[reg_increments + id_offset + field_offset]); + add_increments(); + } else { + add_increments(); } }; - const auto id_offset = loop_id * sizeof(jit_snippets_call_args::loop_args_t); - if (use_runtime_args) { - jit_aux_gpr_holder gpr_holder(h, aux_gpr_idxs, in); // loop_end has only input registers - reg_increments = gpr_holder.get_reg(); - h->mov(reg_increments, h->ptr[abi_param1 + GET_OFF(loop_args)]); - h->mov(reg_increments, h->ptr[reg_increments + id_offset + field_offset]); - add_increments(); - } else { - add_increments(); - } - }; - if (!evaluate_once) { apply_increments(are_ptr_increments_dynamic, GET_OFF_LOOP_ARGS(m_ptr_increments), ptr_increments, wa_increment); @@ -220,5 +261,5 @@ void jit_loop_end_emitter::emit_impl(const std::vector& in, const std::v /* ============================================================== */ -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git 
a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.hpp index 262bba39b7d74c..c0a2b53b100c62 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_loop_emitters.hpp @@ -5,7 +5,6 @@ #pragma once #include "emitters/plugin/x64/jit_emitter.hpp" - #include "snippets/op/loop.hpp" #include "snippets/utils/utils.hpp" @@ -14,25 +13,36 @@ namespace intel_cpu { /* ================== jit_loop_begin_emitter ====================== */ -class jit_loop_begin_emitter: public jit_emitter { +class jit_loop_begin_emitter : public jit_emitter { public: - jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_loop_begin_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override { return 0; } + size_t get_inputs_num() const override { + return 0; + } - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; - void set_loop_end_label(const std::shared_ptr& label) { loop_end_label = label; } - std::shared_ptr get_begin_label() { return loop_begin_label; } + void set_loop_end_label(const std::shared_ptr& label) { + loop_end_label = label; + } + std::shared_ptr get_begin_label() { + return loop_begin_label; + } protected: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const std::vector& in, const std::vector& out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; // `jit_loop_begin_emitter` manually handles aux_gpr allocation using `jit_aux_gpr_holder` - size_t aux_gprs_count() const override { return 0; } + size_t aux_gprs_count() const override { + return 0; + } std::shared_ptr loop_begin_label = nullptr; std::shared_ptr loop_end_label = nullptr; @@ -43,27 +53,33 @@ class jit_loop_begin_emitter: public jit_emitter { bool is_work_amount_dynamic = false; }; - /* ============================================================== */ /* ================== jit_loop_end_emitter ====================== */ -class jit_loop_end_emitter: public jit_emitter { +class jit_loop_end_emitter : public jit_emitter { public: - jit_loop_end_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr); + jit_loop_end_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override { return 0; } + size_t get_inputs_num() const override { + return 0; + } - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; protected: - void validate_arguments(const std::vector &in, const std::vector &out) const override; + void validate_arguments(const
std::vector& in, const std::vector& out) const override; void emit_impl(const std::vector& in, const std::vector& out) const override; // `jit_loop_end_emitter` manually handles aux_gpr allocation using `jit_aux_gpr_holder` - size_t aux_gprs_count() const override { return 0; } + size_t aux_gprs_count() const override { + return 0; + } static ov::snippets::lowered::ExpressionPtr get_loop_begin_expr(const ov::snippets::lowered::ExpressionPtr& expr); @@ -86,5 +102,5 @@ class jit_loop_end_emitter: public jit_emitter { /* ============================================================== */ -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp index b7a5fc2e993398..307ef63a8e6a2e 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.cpp @@ -5,10 +5,9 @@ #include "jit_memory_emitters.hpp" #include "emitters/snippets/jit_snippets_call_args.hpp" +#include "snippets/op/buffer.hpp" #include "transformations/snippets/x64/op/load_convert.hpp" #include "transformations/snippets/x64/op/store_convert.hpp" -#include "snippets/op/buffer.hpp" - using namespace Xbyak; using namespace dnnl::impl; @@ -21,7 +20,10 @@ using jit_generator = dnnl::impl::cpu::x64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t; using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; -jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr, emitter_in_out_map in_out_type) +jit_memory_emitter::jit_memory_emitter(jit_generator* h, + cpu_isa_t isa, + const ExpressionPtr& expr, + emitter_in_out_map in_out_type) : jit_emitter(h, isa) { in_out_type_ = in_out_type; @@ -36,7 +38,8 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex compiled_byte_offset = memory_access->get_input_offset(); buffer_cluster_id = get_parent_buffer_cluster_id(expr); } else if (in_out_type_ == emitter_in_out_map::vec_to_gpr) { - OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_output_port(0), "must be output port - memory access"); + OV_CPU_JIT_EMITTER_ASSERT(memory_access->is_memory_access_output_port(0), + "must be output port - memory access"); count = memory_access->get_output_count(); compiled_byte_offset = memory_access->get_output_offset(); buffer_cluster_id = get_consumer_buffer_cluster_id(expr); @@ -46,7 +49,8 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex if (ov::snippets::utils::is_dynamic_value(compiled_byte_offset)) { is_offset_runtime = true; - // Compiled byte offset is zero to manually `add` runtime offset before operation and `sub` after to reset pointer in the register + // Compiled byte offset is zero so we manually `add` the runtime offset before the operation and `sub` it after + // to reset the pointer in the register compiled_byte_offset = 0; OV_CPU_JIT_EMITTER_ASSERT(buffer_cluster_id != SIZE_MAX, "Incorrect buffer offset in call_args"); } @@ -84,8 +88,10 @@ std::vector jit_memory_emitter::get_available_aux_gprs() const { return available_aux_gprs; } -void jit_memory_emitter::emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs, const std::vector &pool_gpr_idxs) const { +void jit_memory_emitter::emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const
std::vector& pool_vec_idxs, + const std::vector& pool_gpr_idxs) const { emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs); Reg64 reg_runtime_params = abi_param1; // defined by jit_kernel_emitter @@ -152,19 +158,26 @@ void jit_load_broadcast_emitter::emit_impl(const std::vector& in, const } template -void jit_load_broadcast_emitter::emit_isa(const std::vector &in, const std::vector &out) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; +void jit_load_broadcast_emitter::emit_isa(const std::vector& in, const std::vector& out) const { + using Vmm = typename dnnl::impl::utils:: + conditional3::type; Reg64 in_reg(in[0]); Vmm vmm_dst = Vmm(out[0]); - // It doesn't really matter if we broadcast or `movss` for vector tails so keep only one version for `BroadcastLoad`, - // key point here is not to add post-increment, it might be fixed by some other approach in future + // It doesn't really matter if we broadcast or `movss` for vector tails, so keep only one version for + // `BroadcastLoad`; the key point here is not to add a post-increment. It might be fixed by some other approach in future switch (src_prc.size()) { - case 4: h->uni_vbroadcastss(vmm_dst, h->ptr[in_reg + compiled_byte_offset]); break; - case 2: h->vpbroadcastw(vmm_dst, h->ptr[in_reg + compiled_byte_offset]); break; - case 1: h->vpbroadcastb(vmm_dst, h->ptr[in_reg + compiled_byte_offset]); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported data type"); + case 4: + h->uni_vbroadcastss(vmm_dst, h->ptr[in_reg + compiled_byte_offset]); + break; + case 2: + h->vpbroadcastw(vmm_dst, h->ptr[in_reg + compiled_byte_offset]); + break; + case 1: + h->vpbroadcastb(vmm_dst, h->ptr[in_reg + compiled_byte_offset]); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported data type"); } } @@ -190,5 +203,5 @@ void jit_store_memory_emitter::emit_data() const { store_emitter->emit_data(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.hpp index 55a41c977dd67c..d21e85d53e7193 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_memory_emitters.hpp @@ -7,17 +7,20 @@ #include "emitters/plugin/x64/jit_emitter.hpp" #include "emitters/plugin/x64/jit_load_store_emitters.hpp" - namespace ov { namespace intel_cpu { -class jit_memory_emitter : public jit_emitter { +class jit_memory_emitter : public jit_emitter { public: - jit_memory_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr, emitter_in_out_map in_out_type); + jit_memory_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr, + emitter_in_out_map in_out_type); - void emit_code(const std::vector &in_idxs, const std::vector &out_idxs, - const std::vector &pool_vec_idxs = {}, const std::vector &pool_gpr_idxs = {}) const override; + void emit_code(const std::vector& in_idxs, + const std::vector& out_idxs, + const std::vector& pool_vec_idxs = {}, + const std::vector& pool_gpr_idxs = {}) const override; protected: static size_t get_parent_buffer_cluster_id(const ov::snippets::lowered::ExpressionPtr& expr); @@ -36,16 +39,19 @@ class jit_memory_emitter : public jit_emitter { bool is_offset_runtime = false; #ifdef
SNIPPETS_DEBUG_CAPS - friend std::string init_info_jit_memory_emitter(const jit_memory_emitter *emitter); + friend std::string init_info_jit_memory_emitter(const jit_memory_emitter* emitter); #endif }; class jit_load_memory_emitter : public jit_memory_emitter { public: - jit_load_memory_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_load_memory_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 0;} + size_t get_inputs_num() const override { + return 0; + } private: void emit_impl(const std::vector& in, const std::vector& out) const override; @@ -58,24 +64,30 @@ class jit_load_memory_emitter : public jit_memory_emitter { class jit_load_broadcast_emitter : public jit_memory_emitter { public: - jit_load_broadcast_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_load_broadcast_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 0;} + size_t get_inputs_num() const override { + return 0; + } private: void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; }; -class jit_store_memory_emitter : public jit_memory_emitter { +class jit_store_memory_emitter : public jit_memory_emitter { public: - jit_store_memory_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_store_memory_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 1;} + size_t get_inputs_num() const override { + return 1; + } private: void emit_impl(const std::vector& in, const std::vector& out) const override; @@ -86,5 +98,5 @@ class jit_store_memory_emitter : public jit_memory_emitter { std::unique_ptr store_emitter = nullptr; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp index f89e906ce57593..ccb4da742e38d6 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.cpp @@ -3,9 +3,9 @@ // #ifdef SNIPPETS_DEBUG_CAPS -#include "jit_perf_count_chrono_emitters.hpp" +# include "jit_perf_count_chrono_emitters.hpp" -#include "emitters/plugin/x64/utils.hpp" +# include "emitters/plugin/x64/utils.hpp" using namespace dnnl::impl; using namespace dnnl::impl::utils; @@ -17,8 +17,10 @@ using namespace Xbyak::util; namespace ov { namespace intel_cpu { -jit_perf_count_chrono_start_emitter::jit_perf_count_chrono_start_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n) : jit_emitter(host, host_isa) { +jit_perf_count_chrono_start_emitter::jit_perf_count_chrono_start_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n) + : jit_emitter(host, host_isa) { m_start_node = ov::as_type_ptr(n); 
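// Note: ov::as_type_ptr yields nullptr when `n` is not the expected PerfCountBegin node; emit_impl later passes m_start_node.get() through abi_param1, so this emitter is presumably only constructed for matching nodes.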
} @@ -30,11 +32,12 @@ void jit_perf_count_chrono_start_emitter::set_start_time(snippets::op::PerfCount start_node->set_start_time(); } -void jit_perf_count_chrono_start_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_perf_count_chrono_start_emitter::emit_impl(const std::vector& in_idxs, + const std::vector& out_idxs) const { EmitABIRegSpills spill(h); spill.preamble(); - const auto &set_start_time_overload = static_cast(set_start_time); + const auto& set_start_time_overload = static_cast(set_start_time); h->mov(h->rax, reinterpret_cast(set_start_time_overload)); h->mov(abi_param1, reinterpret_cast(m_start_node.get())); @@ -46,8 +49,10 @@ void jit_perf_count_chrono_start_emitter::emit_impl(const std::vector &i } ///////////////////jit_perf_count_chrono_end_emitter//////////////////////////////////// -jit_perf_count_chrono_end_emitter::jit_perf_count_chrono_end_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n) : jit_emitter(host, host_isa) { +jit_perf_count_chrono_end_emitter::jit_perf_count_chrono_end_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n) + : jit_emitter(host, host_isa) { m_end_node = ov::as_type_ptr(n); } @@ -59,11 +64,13 @@ void jit_perf_count_chrono_end_emitter::set_accumulated_time(snippets::op::PerfC end_node->set_accumulated_time(); } -void jit_perf_count_chrono_end_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_perf_count_chrono_end_emitter::emit_impl(const std::vector& in_idxs, + const std::vector& out_idxs) const { EmitABIRegSpills spill(h); spill.preamble(); - const auto &set_accumulated_time_overload = static_cast(set_accumulated_time); + const auto& set_accumulated_time_overload = + static_cast(set_accumulated_time); h->mov(h->rax, reinterpret_cast(set_accumulated_time_overload)); h->mov(abi_param1, reinterpret_cast(m_end_node.get())); @@ -74,6 +81,6 @@ void jit_perf_count_chrono_end_emitter::emit_impl(const std::vector &in_ spill.postamble(); } -} // namespace intel_cpu -} // namespace ov -#endif // SNIPPETS_DEBUG_CAPS +} // namespace intel_cpu +} // namespace ov +#endif // SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.hpp index e8608afc7f1428..817c0583609778 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_chrono_emitters.hpp @@ -3,24 +3,23 @@ // #ifdef SNIPPETS_DEBUG_CAPS -#pragma once - -#include "emitters/plugin/x64/jit_emitter.hpp" - -#include "snippets/op/perf_count.hpp" +# pragma once +# include "emitters/plugin/x64/jit_emitter.hpp" +# include "snippets/op/perf_count.hpp" namespace ov { namespace intel_cpu { class jit_perf_count_chrono_start_emitter : public jit_emitter { public: - jit_perf_count_chrono_start_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_perf_count_chrono_start_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; private: - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const 
override; static void set_start_time(snippets::op::PerfCountBegin* start_node); std::shared_ptr m_start_node = nullptr; @@ -28,17 +27,18 @@ class jit_perf_count_chrono_start_emitter : public jit_emitter { class jit_perf_count_chrono_end_emitter : public jit_emitter { public: - jit_perf_count_chrono_end_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_perf_count_chrono_end_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, const std::shared_ptr& n); size_t get_inputs_num() const override; private: - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; static void set_accumulated_time(snippets::op::PerfCountEnd* end_node); std::shared_ptr m_end_node = nullptr; }; -} // namespace intel_cpu -} // namespace ov -#endif // SNIPPETS_DEBUG_CAPS +} // namespace intel_cpu +} // namespace ov +#endif // SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.cpp index c469c052ce3ef6..e951f8042ad762 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.cpp @@ -3,7 +3,7 @@ // #ifdef SNIPPETS_DEBUG_CAPS -#include "jit_perf_count_rdtsc_emitters.hpp" +# include "jit_perf_count_rdtsc_emitters.hpp" using namespace dnnl::impl; using namespace dnnl::impl::utils; @@ -15,8 +15,10 @@ using namespace Xbyak::util; namespace ov { namespace intel_cpu { -jit_perf_count_rdtsc_start_emitter::jit_perf_count_rdtsc_start_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n) : jit_emitter(host, host_isa) { +jit_perf_count_rdtsc_start_emitter::jit_perf_count_rdtsc_start_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n) + : jit_emitter(host, host_isa) { m_start_node = ov::as_type_ptr(n); } @@ -24,16 +26,18 @@ size_t jit_perf_count_rdtsc_start_emitter::get_inputs_num() const { return 0; } -void jit_perf_count_rdtsc_start_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_perf_count_rdtsc_start_emitter::emit_impl(const std::vector& in_idxs, + const std::vector& out_idxs) const { h->push(h->rax); h->push(h->rdx); - // The EDX register is loaded with the high-order 32 bits of the MSR and the EAX register is loaded with the low-order 32 bits. + // The EDX register is loaded with the high-order 32 bits of the MSR and the EAX register is loaded with the + // low-order 32 bits. 
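// The fenced rdtsc sequence below composes the 64-bit timestamp from EDX:EAX. A C-level sketch of what the emitted
// assembly computes (variable names are illustrative only, not part of this patch):
//   uint64_t tsc = (static_cast<uint64_t>(edx) << 32) | eax;
// The surrounding lfence instructions keep rdtsc from being reordered with the neighboring instruction stream,
// which is a common serialization pattern for reading this counter.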
h->lfence(); h->rdtsc(); h->lfence(); - h->shl(h->rdx, 0x20); // shift to higher half of rdx 0x20(32) - h->or_(h->rdx, h->rax); // rdx has current tsc + h->shl(h->rdx, 0x20); // shift to higher half of rdx 0x20(32) + h->or_(h->rdx, h->rax); // rdx has current tsc h->mov(h->rax, reinterpret_cast(&m_start_node->start_count)); h->mov(qword[h->rax], h->rdx); @@ -43,16 +47,19 @@ void jit_perf_count_rdtsc_start_emitter::emit_impl(const std::vector &in } ///////////////////jit_perf_count_rdtsc_end_emitter//////////////////////////////////// -jit_perf_count_rdtsc_end_emitter::jit_perf_count_rdtsc_end_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n) : jit_emitter(host, host_isa) { - m_end_node = ov::as_type_ptr(n); +jit_perf_count_rdtsc_end_emitter::jit_perf_count_rdtsc_end_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n) + : jit_emitter(host, host_isa) { + m_end_node = ov::as_type_ptr(n); } size_t jit_perf_count_rdtsc_end_emitter::get_inputs_num() const { return 0; } -void jit_perf_count_rdtsc_end_emitter::emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const { +void jit_perf_count_rdtsc_end_emitter::emit_impl(const std::vector& in_idxs, + const std::vector& out_idxs) const { h->push(h->rax); h->push(h->rdx); @@ -79,6 +86,6 @@ void jit_perf_count_rdtsc_end_emitter::emit_impl(const std::vector &in_i h->pop(h->rax); } -} // namespace intel_cpu -} // namespace ov -#endif // SNIPPETS_DEBUG_CAPS +} // namespace intel_cpu +} // namespace ov +#endif // SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.hpp index c3ae1aac01ab9d..343807bdfcd076 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_perf_count_rdtsc_emitters.hpp @@ -3,40 +3,40 @@ // #ifdef SNIPPETS_DEBUG_CAPS -#pragma once - -#include "emitters/plugin/x64/jit_emitter.hpp" - -#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" +# pragma once +# include "emitters/plugin/x64/jit_emitter.hpp" +# include "transformations/snippets/x64/op/perf_count_rdtsc.hpp" namespace ov { namespace intel_cpu { class jit_perf_count_rdtsc_start_emitter : public jit_emitter { public: - jit_perf_count_rdtsc_start_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n); + jit_perf_count_rdtsc_start_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n); size_t get_inputs_num() const override; private: - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; + void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; std::shared_ptr m_start_node = nullptr; }; class jit_perf_count_rdtsc_end_emitter : public jit_emitter { public: - jit_perf_count_rdtsc_end_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - const std::shared_ptr& n); + jit_perf_count_rdtsc_end_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + const std::shared_ptr& n); size_t get_inputs_num() const override; private: - void emit_impl(const std::vector &in_idxs, const std::vector &out_idxs) const override; + 
void emit_impl(const std::vector& in_idxs, const std::vector& out_idxs) const override; std::shared_ptr m_end_node = nullptr; }; -} // namespace intel_cpu -} // namespace ov -#endif // SNIPPETS_DEBUG_CAPS +} // namespace intel_cpu +} // namespace ov +#endif // SNIPPETS_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp index f88c345ff055b5..c513e969144d1c 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.cpp @@ -4,8 +4,9 @@ #ifdef SNIPPETS_DEBUG_CAPS -#include "jit_segfault_detector_emitter.hpp" -#include "emitters/plugin/x64/utils.hpp" +# include "jit_segfault_detector_emitter.hpp" + +# include "emitters/plugin/x64/utils.hpp" using namespace dnnl::impl::utils; using namespace dnnl::impl; @@ -18,22 +19,28 @@ namespace intel_cpu { std::shared_ptr> g_custom_segfault_handler = std::make_shared>(); -jit_uni_segfault_detector_emitter::jit_uni_segfault_detector_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - jit_emitter* target_emitter, bool is_load, bool is_store, std::string target_node_name) : - jit_emitter(host, host_isa), - m_target_emitter(target_emitter), - is_target_use_load_emitter(is_load), - is_target_use_store_emitter(is_store), - m_target_node_name(target_node_name) { +jit_uni_segfault_detector_emitter::jit_uni_segfault_detector_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_emitter* target_emitter, + bool is_load, + bool is_store, + std::string target_node_name) + : jit_emitter(host, host_isa), + m_target_emitter(target_emitter), + is_target_use_load_emitter(is_load), + is_target_use_store_emitter(is_store), + m_target_node_name(target_node_name) {} + +size_t jit_uni_segfault_detector_emitter::get_inputs_num() const { + return 1; } -size_t jit_uni_segfault_detector_emitter::get_inputs_num() const { return 1; } - const jit_emitter* jit_uni_segfault_detector_emitter::get_target_emitter() const { return m_target_emitter; } -void jit_uni_segfault_detector_emitter::emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const { +void jit_uni_segfault_detector_emitter::emit_impl(const std::vector& in_vec_idxs, + const std::vector& out_vec_idxs) const { save_target_emitter(); if (is_target_use_load_emitter) { memory_track(in_vec_idxs[0]); @@ -47,7 +54,8 @@ void jit_uni_segfault_detector_emitter::save_target_emitter() const { EmitABIRegSpills spill(h); spill.preamble(); - const auto &set_local_handler_overload = static_cast(set_local_handler); + const auto& set_local_handler_overload = + static_cast(set_local_handler); h->mov(h->rax, reinterpret_cast(set_local_handler_overload)); h->mov(abi_param1, reinterpret_cast(this)); @@ -85,7 +93,7 @@ void jit_uni_segfault_detector_emitter::memory_track(size_t gpr_idx_for_mem_addr h->pop(h->r15); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov #endif diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.hpp index 21ffaa84cf3db8..86191ae865fe38 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.hpp +++ 
b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_segfault_detector_emitter.hpp @@ -4,11 +4,12 @@ #ifdef SNIPPETS_DEBUG_CAPS -#pragma once +# pragma once -#include -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "openvino/runtime/threading/thread_local.hpp" +# include + +# include "emitters/plugin/x64/jit_emitter.hpp" +# include "openvino/runtime/threading/thread_local.hpp" using namespace ov::threading; @@ -20,18 +21,22 @@ extern std::shared_ptr> g_custom class jit_uni_segfault_detector_emitter : public jit_emitter { public: - jit_uni_segfault_detector_emitter(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - jit_emitter* target_emitter, bool is_load, bool is_store, std::string target_node_name); + jit_uni_segfault_detector_emitter(dnnl::impl::cpu::x64::jit_generator* host, + dnnl::impl::cpu::x64::cpu_isa_t host_isa, + jit_emitter* target_emitter, + bool is_load, + bool is_store, + std::string target_node_name); size_t get_inputs_num() const override; const jit_emitter* get_target_emitter() const; private: - // emit code is to save "this" pointer(jit_uni_segfault_detector_emitter) to global handler, then print info w/ it's target_emitter. - // and to save tracked memory address, iteration, etc to print + // The emitted code saves the "this" pointer (jit_uni_segfault_detector_emitter) to a global handler so info can be + // printed with its target_emitter, and records the tracked memory address, iteration, etc. for printing void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override; - jit_emitter *m_target_emitter = nullptr; + jit_emitter* m_target_emitter = nullptr; bool is_target_use_load_emitter = false; bool is_target_use_store_emitter = false; std::string m_target_node_name = ""; @@ -44,10 +49,10 @@ class jit_uni_segfault_detector_emitter : public jit_emitter { mutable size_t current_address = 0; mutable size_t iteration = 0; - friend std::string init_info_jit_uni_segfault_detector_emitter(const jit_uni_segfault_detector_emitter *emitter); + friend std::string init_info_jit_uni_segfault_detector_emitter(const jit_uni_segfault_detector_emitter* emitter); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov #endif \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.cpp index d8066f9a126543..ba4012de86d83d 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.cpp @@ -15,7 +15,10 @@ using jit_generator = dnnl::impl::cpu::x64::jit_generator; using cpu_isa_t = dnnl::impl::cpu::x64::cpu_isa_t; using ExpressionPtr = ov::snippets::lowered::ExpressionPtr; -jit_nop_emitter::jit_nop_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr, emitter_in_out_map emitter_type) +jit_nop_emitter::jit_nop_emitter(jit_generator* h, + cpu_isa_t isa, + const ExpressionPtr& expr, + emitter_in_out_map emitter_type) : jit_emitter(h, isa) { in_out_type_ = emitter_type; } @@ -25,7 +28,8 @@ jit_parameter_emitter::jit_parameter_emitter(jit_generator* h, cpu_isa_t isa, co in_out_type_ = emitter_in_out_map::gpr_to_gpr; } -jit_result_emitter::jit_result_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_nop_emitter(h, isa, expr) { +jit_result_emitter::jit_result_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + :
jit_nop_emitter(h, isa, expr) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; } @@ -34,14 +38,13 @@ jit_broadcast_move_emitter::jit_broadcast_move_emitter(jit_generator* h, cpu_isa const auto n = expr->get_node(); if (n->get_input_element_type(0) != n->get_output_element_type(0)) OV_CPU_JIT_EMITTER_THROW("supports only equal input and output types but gets: ", - n->get_input_element_type(0), - " and ", - n->get_output_element_type(0)); + n->get_input_element_type(0), + " and ", + n->get_output_element_type(0)); byte_size = n->get_input_element_type(0).size(); } -void jit_broadcast_move_emitter::emit_impl(const std::vector& in, - const std::vector& out) const { +void jit_broadcast_move_emitter::emit_impl(const std::vector& in, const std::vector& out) const { if (host_isa_ == dnnl::impl::cpu::x64::sse41) { emit_isa(in, out); } else if (host_isa_ == dnnl::impl::cpu::x64::avx2) { @@ -54,17 +57,24 @@ void jit_broadcast_move_emitter::emit_impl(const std::vector& in, } template -void jit_broadcast_move_emitter::emit_isa(const std::vector &in, const std::vector &out) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; +void jit_broadcast_move_emitter::emit_isa(const std::vector& in, const std::vector& out) const { + using Vmm = typename dnnl::impl::utils:: + conditional3::type; Xmm xmm_src0 = Xmm(in[0]); - Vmm vmm_dst = Vmm(out[0]); + Vmm vmm_dst = Vmm(out[0]); switch (byte_size) { - case 4: h->uni_vbroadcastss(vmm_dst, xmm_src0); break; - case 2: h->vpbroadcastw(vmm_dst, xmm_src0); break; - case 1: h->vpbroadcastb(vmm_dst, xmm_src0); break; - default: OV_CPU_JIT_EMITTER_THROW("unsupported data type"); + case 4: + h->uni_vbroadcastss(vmm_dst, xmm_src0); + break; + case 2: + h->vpbroadcastw(vmm_dst, xmm_src0); + break; + case 1: + h->vpbroadcastb(vmm_dst, xmm_src0); + break; + default: + OV_CPU_JIT_EMITTER_THROW("unsupported data type"); } } @@ -74,14 +84,20 @@ int32_t jit_scalar_emitter::read_value(const ov::snippets::lowered::ExpressionPt const auto& precision = n->get_output_element_type(0); int32_t res = INT_MIN; switch (precision) { - case element::i32: res = n->cast_vector(1)[0]; break; - case element::f32: res = dnnl::impl::cpu::x64::float2int(n->cast_vector(1)[0]); break; - default: OV_CPU_JIT_EMITTER_THROW("doesn't support ", precision); + case element::i32: + res = n->cast_vector(1)[0]; + break; + case element::f32: + res = dnnl::impl::cpu::x64::float2int(n->cast_vector(1)[0]); + break; + default: + OV_CPU_JIT_EMITTER_THROW("doesn't support ", precision); } return res; } -jit_scalar_emitter::jit_scalar_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) : jit_emitter(h, isa) { +jit_scalar_emitter::jit_scalar_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr) + : jit_emitter(h, isa) { push_arg_entry_of("scalar", read_value(expr), true); prepare_table(); } @@ -89,21 +105,27 @@ jit_scalar_emitter::jit_scalar_emitter(jit_generator* h, cpu_isa_t isa, const Ex void jit_scalar_emitter::emit_impl(const std::vector& in, const std::vector& out) const { using isa = cpu_isa_t; switch (host_isa_) { - case isa::sse41: emit_isa(in, out); break; - case isa::avx2: emit_isa(in, out); break; - case isa::avx512_core: emit_isa(in, out); break; - default: OV_CPU_JIT_EMITTER_THROW("Unsupported isa ", host_isa_); + case isa::sse41: + emit_isa(in, out); + break; + case isa::avx2: + emit_isa(in, out); + break; + case isa::avx512_core: + emit_isa(in, out); + break; + default: + OV_CPU_JIT_EMITTER_THROW("Unsupported isa ", host_isa_); } } template -void 
jit_scalar_emitter::emit_isa(const std::vector &in, const std::vector &out) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; - Vmm vmm_dst = Vmm(out[0]); +void jit_scalar_emitter::emit_isa(const std::vector& in, const std::vector& out) const { + using Vmm = typename dnnl::impl::utils:: + conditional3::type; + Vmm vmm_dst = Vmm(out[0]); h->uni_vbroadcastss(vmm_dst, table_val("scalar")); } - } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.hpp index c75f071c4ec7e0..6a91e3b7c47d3d 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_snippets_emitters.hpp @@ -6,16 +6,19 @@ #include "emitters/plugin/x64/jit_emitter.hpp" - namespace ov { namespace intel_cpu { class jit_nop_emitter : public jit_emitter { public: - jit_nop_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const ov::snippets::lowered::ExpressionPtr& expr, emitter_in_out_map emitter_type = gpr_to_gpr); + jit_nop_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, + const ov::snippets::lowered::ExpressionPtr& expr, + emitter_in_out_map emitter_type = gpr_to_gpr); - size_t get_inputs_num() const override {return 0;} + size_t get_inputs_num() const override { + return 0; + } private: void emit_impl(const std::vector& in, const std::vector& out) const override {} @@ -23,31 +26,40 @@ class jit_nop_emitter : public jit_emitter { class jit_parameter_emitter : public jit_nop_emitter { public: - jit_parameter_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_parameter_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override { return 0; } + size_t get_inputs_num() const override { + return 0; + } }; class jit_result_emitter : public jit_nop_emitter { public: - jit_result_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_result_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 1;} + size_t get_inputs_num() const override { + return 1; + } }; class jit_broadcast_move_emitter : public jit_emitter { public: - jit_broadcast_move_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_broadcast_move_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 1;} + size_t get_inputs_num() const override { + return 1; + } private: void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; private: size_t byte_size = 0lu; @@ -55,18 +67,23 @@ class jit_broadcast_move_emitter : public jit_emitter { class jit_scalar_emitter : public jit_emitter { public: - jit_scalar_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, + jit_scalar_emitter(dnnl::impl::cpu::x64::jit_generator* h, + dnnl::impl::cpu::x64::cpu_isa_t isa, const 
ov::snippets::lowered::ExpressionPtr& expr); - size_t get_inputs_num() const override {return 0;} - size_t aux_gprs_count() const override {return 1;} + size_t get_inputs_num() const override { + return 0; + } + size_t aux_gprs_count() const override { + return 1; + } static int32_t read_value(const ov::snippets::lowered::ExpressionPtr& expr); private: void emit_impl(const std::vector& in, const std::vector& out) const override; template - void emit_isa(const std::vector &in, const std::vector &out) const; + void emit_isa(const std::vector& in, const std::vector& out) const; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp index c57824526d6e20..58a31a1804782a 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.cpp @@ -6,13 +6,10 @@ #include "common/utils.hpp" #include "dnnl_extension_utils.h" - #include "snippets/lowered/pass/insert_specific_iterations.hpp" - #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" - using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; @@ -20,15 +17,21 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : BrgemmBaseKernelConfig(), m_static_params(std::make_shared(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) { +BrgemmKernelConfig::BrgemmKernelConfig(const element::Type& in0_dtype, + const element::Type& in1_dtype, + bool is_with_comp, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : BrgemmBaseKernelConfig(), + m_static_params(std::make_shared(in0_dtype, in1_dtype, is_with_comp, primitive_isa)) { m_hash = compute_hash(); } -BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(is_with_comp)), is_with_comp(is_with_comp) {} +BrgemmKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + bool is_with_comp, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(is_with_comp)), + is_with_comp(is_with_comp) {} bool BrgemmKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { return StaticBaseParams::operator==(rhs) && is_with_comp == rhs.is_with_comp; @@ -47,8 +50,8 @@ std::string BrgemmKernelConfig::StaticParams::to_string() const { } #endif -BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) : - CPUKernelExecutor(std::move(kernel_cache), std::move(config)) { } +BrgemmKernelExecutor::BrgemmKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) + : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {} std::shared_ptr BrgemmKernelExecutor::compile_kernel(const BrgemmKernelConfig& config) const { std::shared_ptr compiled_kernel = std::make_shared(); @@ -57,8 +60,17 @@ std::shared_ptr 
BrgemmKernelExecutor::compile_kernel(const if (config.is_empty()) return compiled_kernel; - create_brgemm_kernel(compiled_kernel->brgemm_kernel, config.get_dt_in0(), config.get_dt_in1(), config.get_isa(), - config.get_M(), config.get_N(), config.get_K(), config.get_LDA(), config.get_LDB(), config.get_LDC(), config.get_beta()); + create_brgemm_kernel(compiled_kernel->brgemm_kernel, + config.get_dt_in0(), + config.get_dt_in1(), + config.get_isa(), + config.get_M(), + config.get_N(), + config.get_K(), + config.get_LDA(), + config.get_LDB(), + config.get_LDC(), + config.get_beta()); return compiled_kernel; } @@ -81,8 +93,9 @@ void BrgemmKernelExecutor::execute(const BrgemmKernelExecutor* executor, call_ar } #ifdef SNIPPETS_DEBUG_CAPS -BrgemmKernelReferenceExecutor::BrgemmKernelReferenceExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config) : - BrgemmKernelExecutor(std::move(kernel_cache), std::move(config)) {} +BrgemmKernelReferenceExecutor::BrgemmKernelReferenceExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, + BrgemmKernelConfig config) + : BrgemmKernelExecutor(std::move(kernel_cache), std::move(config)) {} std::shared_ptr BrgemmKernelReferenceExecutor::compile_kernel(const BrgemmKernelConfig& c) const { const auto& res = std::make_shared(); @@ -91,11 +104,10 @@ std::shared_ptr BrgemmKernelReferenceExecutor::compile_ker } brgemm_ref_kernel::brgemm_ref_kernel(BrgemmKernelConfig c) : m_config(std::move(c)) { - OV_CPU_JIT_EMITTER_ASSERT(!m_config.is_with_comp(), - "brgemm_ref_kernel doesn't currently support compensations"); - OV_CPU_JIT_EMITTER_ASSERT(m_config.get_dt_in0() == m_config.get_dt_in1() && - m_config.get_dt_in0() == dnnl_data_type_t::dnnl_f32, - "brgemm_ref_kernel currently supports only fp32 inputs"); + OV_CPU_JIT_EMITTER_ASSERT(!m_config.is_with_comp(), "brgemm_ref_kernel doesn't currently support compensations"); + OV_CPU_JIT_EMITTER_ASSERT( + m_config.get_dt_in0() == m_config.get_dt_in1() && m_config.get_dt_in0() == dnnl_data_type_t::dnnl_f32, + "brgemm_ref_kernel currently supports only fp32 inputs"); } void brgemm_ref_kernel::operator()(dnnl::impl::cpu::x64::brgemm_kernel_params_t* args) const { @@ -115,5 +127,5 @@ void brgemm_ref_kernel::operator()(dnnl::impl::cpu::x64::brgemm_kernel_params_t* } #endif -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp index 1c3d1e18872aea..9cc17049c4d3ae 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm.hpp @@ -11,24 +11,33 @@ namespace intel_cpu { struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { public: - BrgemmKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, - bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + BrgemmKernelConfig(const element::Type& in0_dtype, + const element::Type& in1_dtype, + bool is_with_comp, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); BrgemmKernelConfig() = delete; std::unique_ptr get_clone_ptr() const override { return std::unique_ptr(new BrgemmKernelConfig(*this)); } - bool is_with_comp() const { return m_static_params->is_with_comp; } + bool is_with_comp() const { + return m_static_params->is_with_comp; + } private: struct StaticParams : StaticBaseParams { - StaticParams(const element::Type& 
in0_dtype, const element::Type& in1_dtype, bool is_with_comp, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + StaticParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + bool is_with_comp, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); - const bool is_with_comp {false}; + const bool is_with_comp{false}; bool operator==(const StaticParams& rhs) const; - bool operator!=(const StaticParams& rhs) const { return !(*this == rhs); } + bool operator!=(const StaticParams& rhs) const { + return !(*this == rhs); + } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const; #endif @@ -36,9 +45,11 @@ struct BrgemmKernelConfig : public BrgemmBaseKernelConfig { static size_t compute_hash(bool is_with_comp); }; - std::shared_ptr get_static_params() const override { return m_static_params; } + std::shared_ptr get_static_params() const override { + return m_static_params; + } - std::shared_ptr m_static_params {nullptr}; + std::shared_ptr m_static_params{nullptr}; }; // The `update_kernel` method verifies that a compiled kernel is not nullptr. @@ -76,21 +87,25 @@ class BrgemmKernelReferenceExecutor : public BrgemmKernelExecutor { public: BrgemmKernelReferenceExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmKernelConfig config); using BrgemmKernelExecutor::execute; + protected: std::shared_ptr compile_kernel(const BrgemmKernelConfig& c) const override; }; struct brgemm_ref_kernel : public dnnl::impl::cpu::x64::brgemm_kernel_t { brgemm_ref_kernel(BrgemmKernelConfig c); - void operator()(dnnl::impl::cpu::x64::brgemm_kernel_params_t *) const override; - dnnl_status_t create_kernel() override { return dnnl_status_t::dnnl_success; } - const dnnl::impl::cpu::x64::jit_generator *get_jit_generator() const override { + void operator()(dnnl::impl::cpu::x64::brgemm_kernel_params_t*) const override; + dnnl_status_t create_kernel() override { + return dnnl_status_t::dnnl_success; + } + const dnnl::impl::cpu::x64::jit_generator* get_jit_generator() const override { OV_CPU_JIT_EMITTER_THROW("get_jit_generator should not be called for reference kernel"); return nullptr; } + private: BrgemmKernelConfig m_config; }; #endif -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp index 62c7236735f70e..12c52d43b2c4b8 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.cpp @@ -4,35 +4,40 @@ #include "brgemm_amx.hpp" -#include "transformations/snippets/x64/op/brgemm_utils.hpp" -#include "transformations/snippets/x64/op/brgemm_cpu.hpp" - #include +#include "transformations/snippets/x64/op/brgemm_cpu.hpp" +#include "transformations/snippets/x64/op/brgemm_utils.hpp" #define INNER_K_BLK(dtype) static_cast((brgemm_utils::repacking::compute_inner_k_block(in0_dtype))) #define VNNI_FACTOR(dtype) static_cast((brgemm_utils::compute_vnni_factor(in0_dtype))) -#define EQ(X) X == rhs.X -#define HASH(X) seed = hash_combine(seed, X) - +#define EQ(X) X == rhs.X +#define HASH(X) seed = hash_combine(seed, X) using namespace Xbyak; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; - namespace ov { namespace intel_cpu { -BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t 
primitive_isa) - : BrgemmBaseKernelConfig(), m_static_params(std::make_shared(in0_dtype, in1_dtype, primitive_isa)) { +BrgemmAMXKernelConfig::BrgemmAMXKernelConfig(const element::Type& in0_dtype, + const element::Type& in1_dtype, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) + : BrgemmBaseKernelConfig(), + m_static_params(std::make_shared(in0_dtype, in1_dtype, primitive_isa)) { m_hash = compute_hash(); } -BrgemmAMXKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, +BrgemmAMXKernelConfig::StaticParams::StaticParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa) - : StaticBaseParams(in0_dtype, in1_dtype, primitive_isa, compute_hash(INNER_K_BLK(in0_dtype), VNNI_FACTOR(in0_dtype))), - inner_k_blk(INNER_K_BLK(in0_dtype)), vnni_factor(VNNI_FACTOR(in0_dtype)) {} + : StaticBaseParams(in0_dtype, + in1_dtype, + primitive_isa, + compute_hash(INNER_K_BLK(in0_dtype), VNNI_FACTOR(in0_dtype))), + inner_k_blk(INNER_K_BLK(in0_dtype)), + vnni_factor(VNNI_FACTOR(in0_dtype)) {} bool BrgemmAMXKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { return StaticBaseParams::operator==(rhs) && EQ(inner_k_blk) && EQ(vnni_factor); @@ -40,7 +45,8 @@ bool BrgemmAMXKernelConfig::StaticParams::operator==(const StaticParams& rhs) co size_t BrgemmAMXKernelConfig::StaticParams::compute_hash(dnnl_dim_t inner_k_blk, dnnl_dim_t vnni_factor) { size_t seed = 0; - HASH(inner_k_blk); HASH(vnni_factor); + HASH(inner_k_blk); + HASH(vnni_factor); return seed; } @@ -58,30 +64,50 @@ std::string BrgemmAMXKernelConfig::StaticParams::to_string() const { } #endif -BrgemmAMXKernelExecutor::BrgemmAMXKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmAMXKernelConfig config) : - CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {} +BrgemmAMXKernelExecutor::BrgemmAMXKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, + BrgemmAMXKernelConfig config) + : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {} namespace { struct BrgemmCopyAKey { - BrgemmCopyAKey(cpu_isa_t isa, dnnl_data_type_t dt, dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t K_tail, dnnl_dim_t src_stride, dnnl_dim_t LDA) - : isa(isa), dt(dt), K{K}, K_blk{K_blk}, K_tail{K_tail}, src_stride{src_stride}, LDA{LDA} {} + BrgemmCopyAKey(cpu_isa_t isa, + dnnl_data_type_t dt, + dnnl_dim_t K, + dnnl_dim_t K_blk, + dnnl_dim_t K_tail, + dnnl_dim_t src_stride, + dnnl_dim_t LDA) + : isa(isa), + dt(dt), + K{K}, + K_blk{K_blk}, + K_tail{K_tail}, + src_stride{src_stride}, + LDA{LDA} {} size_t hash() const { size_t seed = 0; - HASH(isa); HASH(dt); HASH(K); HASH(K_blk); HASH(K_tail); HASH(src_stride); HASH(LDA); + HASH(isa); + HASH(dt); + HASH(K); + HASH(K_blk); + HASH(K_tail); + HASH(src_stride); + HASH(LDA); return seed; } bool operator==(const BrgemmCopyAKey& rhs) const { return EQ(isa) && EQ(dt) && EQ(K) && EQ(K_blk) && EQ(K_tail) && EQ(src_stride) && EQ(LDA); } - cpu_isa_t isa {cpu_isa_t::isa_undef}; - dnnl_data_type_t dt {dnnl_data_type_t::dnnl_data_type_undef}; - dnnl_dim_t K {0}, K_blk {0}, K_tail {0}, src_stride {0}, LDA {0}; + cpu_isa_t isa{cpu_isa_t::isa_undef}; + dnnl_data_type_t dt{dnnl_data_type_t::dnnl_data_type_undef}; + dnnl_dim_t K{0}, K_blk{0}, K_tail{0}, src_stride{0}, LDA{0}; }; -} // namespace +} // namespace -std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel(const BrgemmAMXKernelConfig& config) const { +std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel( + const 
BrgemmAMXKernelConfig& config) const { std::shared_ptr compiled_kernel = std::make_shared(); // Brgemm is not executable - nothing to compile @@ -98,14 +124,26 @@ std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel }; auto brgemm_builder = [](const BrgemmAMXKernelConfig& k) { - std::shared_ptr ker = std::make_shared(); - create_brgemm_kernel(ker->brgemm_kernel, k.get_dt_in0(), k.get_dt_in1(), k.get_isa(), k.get_M(), k.get_N(), k.get_K(), - k.get_LDA(), k.get_LDB(), k.get_LDC(), k.get_beta(), true, ker->palette); + std::shared_ptr ker = + std::make_shared(); + create_brgemm_kernel(ker->brgemm_kernel, + k.get_dt_in0(), + k.get_dt_in1(), + k.get_isa(), + k.get_M(), + k.get_N(), + k.get_K(), + k.get_LDA(), + k.get_LDB(), + k.get_LDC(), + k.get_beta(), + true, + ker->palette); return ker; }; auto brgemm_copy_a_builder = [](const BrgemmCopyAKey& k) { - std::shared_ptr ker {nullptr}; + std::shared_ptr ker{nullptr}; create_brgemm_copy_a_kernel(ker, k.isa, k.dt, k.K, k.K_blk, k.K_tail, k.src_stride, k.LDA); return ker; }; @@ -130,7 +168,13 @@ std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel K_tail = ov::snippets::utils::rnd_up(K_tail, config.get_vnni_factor()); LDA = K_tail; - const auto key = BrgemmCopyAKey(config.get_isa(), config.get_dt_in0(), config.get_K(), config.get_inner_K_blk(), K_tail, copy_A_src_stride, LDA); + const auto key = BrgemmCopyAKey(config.get_isa(), + config.get_dt_in0(), + config.get_K(), + config.get_inner_K_blk(), + K_tail, + copy_A_src_stride, + LDA); const auto result = cache->getOrCreate(key, brgemm_copy_a_builder); compiled_kernel->brgemm_copy_a_kernel = result.first; } @@ -142,11 +186,17 @@ std::shared_ptr BrgemmAMXKernelExecutor::compile_kernel return compiled_kernel; } -void BrgemmAMXKernelExecutor::create_brgemm_copy_a_kernel(std::shared_ptr& kernel, - dnnl::impl::cpu::x64::cpu_isa_t isa, dnnl_data_type_t dt, - dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t K_tail, dnnl_dim_t src_stride, dnnl_dim_t LDA) { +void BrgemmAMXKernelExecutor::create_brgemm_copy_a_kernel( + std::shared_ptr& kernel, + dnnl::impl::cpu::x64::cpu_isa_t isa, + dnnl_data_type_t dt, + dnnl_dim_t K, + dnnl_dim_t K_blk, + dnnl_dim_t K_tail, + dnnl_dim_t src_stride, + dnnl_dim_t LDA) { matmul::brgemm_matmul_conf_t conf_; - conf_.src_tag = dnnl_abcd; // unused + conf_.src_tag = dnnl_abcd; // unused conf_.K = K; conf_.K_tail = K_tail; conf_.K_blk = K_blk; @@ -176,18 +226,28 @@ void BrgemmAMXKernelExecutor::update_config(const ov::snippets::lowered::Express return BrgemmBaseKernelExecutor::update_config(expr, linear_ir, config); } -void BrgemmAMXKernelExecutor::configure_tiles_if_needed(amx_tile_config_t* config, const char* palette, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K) { +void BrgemmAMXKernelExecutor::configure_tiles_if_needed(amx_tile_config_t* config, + const char* palette, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K) { auto compatible = [&](amx_tile_config_t* rhs) { return rhs && rhs->M == M && rhs->N == N && rhs->K == K; }; if (config && !compatible(config)) { - config->M = M; config->N = N; config->K = K; + config->M = M; + config->N = N; + config->K = K; cpu::x64::amx_tile_configure(palette); } } -void BrgemmAMXKernelExecutor::execute_brgemm_copy_a_kernel(const std::shared_ptr& kernel, - const void* src, const void* tr_src, dnnl_dim_t M, dnnl_dim_t K) { +void BrgemmAMXKernelExecutor::execute_brgemm_copy_a_kernel( + const std::shared_ptr& kernel, + const void* src, + const void* tr_src, + dnnl_dim_t M, + dnnl_dim_t K) { auto ctx = matmul::jit_brgemm_matmul_copy_a_t::ctx_t(); 
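// The two cache->getOrCreate(key, builder) calls above give compile-once caching: the key carries hash() and
// operator==, and the builder lambda is invoked only on a cache miss. A minimal self-contained sketch of that idiom
// follows; SketchCache and its unordered_map storage are hypothetical stand-ins for the plugin's MultiCache
// machinery, not its real API.
#include <cstddef>
#include <functional>
#include <unordered_map>
#include <utility>

template <typename Key, typename Value>
class SketchCache {
public:
    struct KeyHash {
        std::size_t operator()(const Key& k) const {
            return k.hash();  // keys expose hash(), like BrgemmCopyAKey above
        }
    };
    // Returns {value, hit-flag}; the builder runs only when the key is absent.
    std::pair<Value, bool> getOrCreate(const Key& key, const std::function<Value(const Key&)>& builder) {
        auto it = m_storage.find(key);
        if (it != m_storage.end())
            return {it->second, true};  // hit: reuse the previously built kernel
        auto res = m_storage.emplace(key, builder(key));  // miss: build exactly once
        return {res.first->second, false};
    }

private:
    std::unordered_map<Key, Value, KeyHash> m_storage;  // Key must also define operator==
};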
ctx.current_M_blk = M; @@ -219,7 +279,11 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c if (K_body != 0) { const auto& K_body_kernel = kernel->K_body_kernel; - configure_tiles_if_needed(args->amx_tile_config, K_body_kernel->palette, config.get_M(), config.get_N(), K_body); + configure_tiles_if_needed(args->amx_tile_config, + K_body_kernel->palette, + config.get_M(), + config.get_N(), + K_body); execute_brgemm_kernel(K_body_kernel->brgemm_kernel, src_ptr, wei_ptr, args->C, scratch, false); src_ptr = src_ptr + K_body * dnnl_data_type_size(config.get_dt_in0()); @@ -235,7 +299,11 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c } const auto& K_tail_kernel = kernel->K_tail_kernel; - configure_tiles_if_needed(args->amx_tile_config, K_tail_kernel->palette, config.get_M(), config.get_N(), K_tail); + configure_tiles_if_needed(args->amx_tile_config, + K_tail_kernel->palette, + config.get_M(), + config.get_N(), + K_tail); execute_brgemm_kernel(K_tail_kernel->brgemm_kernel, src_ptr, wei_ptr, args->C, scratch, false); } } @@ -245,5 +313,5 @@ void BrgemmAMXKernelExecutor::execute(const BrgemmAMXKernelExecutor* executor, c #undef EQ #undef HASH -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp index a8544e5343b0ce..733295ec995583 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_amx.hpp @@ -4,42 +4,50 @@ #pragma once -#include "brgemm_base.hpp" - -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/jit_snippets_call_args.hpp" -#include "emitters/snippets/cpu_kernel_executor_table.hpp" - #include #include +#include "brgemm_base.hpp" +#include "emitters/plugin/x64/jit_emitter.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" namespace ov { namespace intel_cpu { struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig { public: - BrgemmAMXKernelConfig(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + BrgemmAMXKernelConfig(const element::Type& in0_dtype, + const element::Type& in1_dtype, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); BrgemmAMXKernelConfig() = delete; std::unique_ptr get_clone_ptr() const override { return std::unique_ptr(new BrgemmAMXKernelConfig(*this)); } - dnnl_dim_t get_inner_K_blk() const { return m_static_params->inner_k_blk; } - dnnl_dim_t get_vnni_factor() const { return m_static_params->vnni_factor; } + dnnl_dim_t get_inner_K_blk() const { + return m_static_params->inner_k_blk; + } + dnnl_dim_t get_vnni_factor() const { + return m_static_params->vnni_factor; + } bool need_copy_a(dnnl_dim_t K) const; private: struct StaticParams : StaticBaseParams { - StaticParams(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); + StaticParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa); - const dnnl_dim_t inner_k_blk {0}; - const dnnl_dim_t vnni_factor {0}; + const dnnl_dim_t inner_k_blk{0}; + const dnnl_dim_t vnni_factor{0}; bool operator==(const StaticParams& rhs) const; - bool operator!=(const 
StaticParams& rhs) const { return !(*this == rhs); } + bool operator!=(const StaticParams& rhs) const { + return !(*this == rhs); + } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const; #endif @@ -47,22 +55,24 @@ struct BrgemmAMXKernelConfig : public BrgemmBaseKernelConfig { static size_t compute_hash(dnnl_dim_t inner_k_blk, dnnl_dim_t vnni_factor); }; - std::shared_ptr get_static_params() const override { return m_static_params; } + std::shared_ptr get_static_params() const override { + return m_static_params; + } - std::shared_ptr m_static_params {nullptr}; + std::shared_ptr m_static_params{nullptr}; }; struct BrgemmAMXCompiledKernel { struct BrgemmKernel { - std::shared_ptr brgemm_kernel {nullptr}; + std::shared_ptr brgemm_kernel{nullptr}; // Note: Palette is treated as a part of a kernel because it is initialized during the kernel compilation stage. // Each kernel need to store the pallet it was compiled with. char palette[64] = {}; }; - std::shared_ptr K_body_kernel {nullptr}; - std::shared_ptr K_tail_kernel {nullptr}; - std::shared_ptr brgemm_copy_a_kernel {nullptr}; + std::shared_ptr K_body_kernel{nullptr}; + std::shared_ptr K_tail_kernel{nullptr}; + std::shared_ptr brgemm_copy_a_kernel{nullptr}; }; class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor, @@ -87,16 +97,30 @@ class BrgemmAMXKernelExecutor : public BrgemmBaseKernelExecutor, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmAMXKernelConfig& config) const override; - static void configure_tiles_if_needed(amx_tile_config_t* config, const char* palette, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K); - - static void create_brgemm_copy_a_kernel(std::shared_ptr& kernel, - dnnl::impl::cpu::x64::cpu_isa_t isa, dnnl_data_type_t dt, - dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t K_tail, dnnl_dim_t src_stride, dnnl_dim_t LDA); - - static void execute_brgemm_copy_a_kernel(const std::shared_ptr& kernel, - const void* src, const void* tr_src, dnnl_dim_t M, dnnl_dim_t K); + static void configure_tiles_if_needed(amx_tile_config_t* config, + const char* palette, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K); + + static void create_brgemm_copy_a_kernel( + std::shared_ptr& kernel, + dnnl::impl::cpu::x64::cpu_isa_t isa, + dnnl_data_type_t dt, + dnnl_dim_t K, + dnnl_dim_t K_blk, + dnnl_dim_t K_tail, + dnnl_dim_t src_stride, + dnnl_dim_t LDA); + + static void execute_brgemm_copy_a_kernel( + const std::shared_ptr& kernel, + const void* src, + const void* tr_src, + dnnl_dim_t M, + dnnl_dim_t K); }; #define GET_OFF_BRGEMM_AMX_ARGS(field) offsetof(BrgemmAMXKernelExecutor::call_args, field) -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp index 17b1f0e053b577..008237780de3f6 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_base.cpp @@ -9,11 +9,11 @@ #include "transformations/snippets/x64/op/brgemm_cpu.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" -#define DIM_CAST(X) static_cast(X) +#define DIM_CAST(X) static_cast(X) #define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) -#define PRINT(X) ss << #X << " = " << X << "\n" -#define EQ(X) X == rhs.X -#define HASH(X) seed = hash_combine(seed, X) +#define PRINT(X) ss << #X << " = " << X << "\n" +#define 
EQ(X) X == rhs.X +#define HASH(X) seed = hash_combine(seed, X) using namespace Xbyak; using namespace dnnl::impl; @@ -31,22 +31,34 @@ bool BrgemmBaseKernelConfig::is_empty() const { } bool BrgemmBaseKernelConfig::operator==(const BrgemmBaseKernelConfig& rhs) const { - return EQ(m_hash) && EQ(m_beta) && - EQ(m_M) && EQ(m_N) && EQ(m_K) && - EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) && + return EQ(m_hash) && EQ(m_beta) && EQ(m_M) && EQ(m_N) && EQ(m_K) && EQ(m_LDA) && EQ(m_LDB) && EQ(m_LDC) && (EQ(get_static_params()) || *get_static_params() == *(rhs.get_static_params())); } -void BrgemmBaseKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta) { +void BrgemmBaseKernelConfig::update(dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K, + dnnl_dim_t LDA, + dnnl_dim_t LDB, + dnnl_dim_t LDC, + float beta) { // If M is zero, it means that Brgemm won't be executed (in Loop with work_amount = 0, for example) // To process this case, we have to make this Config as empty (nullify runtime parameters) if (utils::one_of(0, M, N, K)) { - m_M = 0; m_N = 0; m_K = 0; - m_LDA = 0; m_LDB = 0; m_LDC = 0; + m_M = 0; + m_N = 0; + m_K = 0; + m_LDA = 0; + m_LDB = 0; + m_LDC = 0; m_beta = 0; } else { - m_M = M; m_N = N; m_K = K; - m_LDA = LDA; m_LDB = LDB; m_LDC = LDC; + m_M = M; + m_N = N; + m_K = K; + m_LDA = LDA; + m_LDB = LDB; + m_LDC = LDC; m_beta = beta; } m_hash = compute_hash(); @@ -54,30 +66,45 @@ void BrgemmBaseKernelConfig::update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dn size_t BrgemmBaseKernelConfig::compute_hash() const { size_t seed = get_static_params()->hash(); - HASH(m_M); HASH(m_N); HASH(m_K); - HASH(m_LDA); HASH(m_LDB); HASH(m_LDC); + HASH(m_M); + HASH(m_N); + HASH(m_K); + HASH(m_LDA); + HASH(m_LDB); + HASH(m_LDC); HASH(m_beta); return seed; } -BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, const element::Type& in1_dtype, - cpu_isa_t primitive_isa, size_t hash_seed) - : dt_in0(DTYPE_CAST(in0_dtype)), dt_in1(DTYPE_CAST(in1_dtype)), isa(primitive_isa), m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {} +BrgemmBaseKernelConfig::StaticBaseParams::StaticBaseParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + cpu_isa_t primitive_isa, + size_t hash_seed) + : dt_in0(DTYPE_CAST(in0_dtype)), + dt_in1(DTYPE_CAST(in1_dtype)), + isa(primitive_isa), + m_hash(compute_hash(hash_seed, dt_in0, dt_in1, isa)) {} bool BrgemmBaseKernelConfig::StaticBaseParams::operator==(const StaticBaseParams& rhs) const { return EQ(hash()) && EQ(dt_in0) && EQ(dt_in1) && EQ(isa); } -size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, cpu_isa_t isa) { +size_t BrgemmBaseKernelConfig::StaticBaseParams::compute_hash(size_t hash_seed, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + cpu_isa_t isa) { size_t seed = hash_seed; - HASH(dt_in0); HASH(dt_in1); HASH(isa); + HASH(dt_in0); + HASH(dt_in1); + HASH(isa); return seed; } #ifdef SNIPPETS_DEBUG_CAPS std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const { std::stringstream ss; - PRINT(dt_in0); PRINT(dt_in1); + PRINT(dt_in0); + PRINT(dt_in1); PRINT(isa); return ss.str(); } @@ -85,26 +112,33 @@ std::string BrgemmBaseKernelConfig::StaticBaseParams::to_string() const { std::string BrgemmBaseKernelConfig::to_string() const { std::stringstream ss; ss << get_static_params()->to_string() << "\n"; - PRINT(m_M); PRINT(m_N); PRINT(m_K); - PRINT(m_LDA); PRINT(m_LDB); 
PRINT(m_LDC); +    PRINT(m_M); +    PRINT(m_N); +    PRINT(m_K); +    PRINT(m_LDA); +    PRINT(m_LDB); +    PRINT(m_LDC); PRINT(m_beta); return ss.str(); } #endif
-float BrgemmBaseKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id,
+float BrgemmBaseKernelExecutor::get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager,
+                                         int loop_id, const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info) { // Find all Expanded loops with the same Unified loop information -> they were decomposed from this Unified Loop. // Note that LoopInfo are normalized and sorted (due to NormalizedLoopIDs pass). // It means that previous executed Loops have Loop ID less the current Loop ID.
-    // - If there is executed Loop (work_amount > 0) and evaluated before the current -> the current Brgemm should have `beta = 1`.
+    // - If there is an executed Loop (work_amount > 0) evaluated before the current one -> the current Brgemm should
+    // have `beta = 1`. // - If there is not this Loop -> the current executed Brgemm should have `beta = 0`. if (loop_id > 0) { const auto& current_unified_loop_info = current_expanded_loop_info->get_unified_loop_info(); // Check the previous Loops --loop_id; while (loop_id >= 0) {
-        const auto& expanded_loop_info = loop_manager->get_loop_info(loop_id);
+        const auto& expanded_loop_info =
+            loop_manager->get_loop_info(loop_id); if (expanded_loop_info->get_unified_loop_info() != current_unified_loop_info) return 0; if (expanded_loop_info->get_work_amount() > 0) { @@ -143,7 +177,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres size_t loop_idx = 0; const auto& loop_ids = expr->get_loop_ids(); const auto& loop_manager = linear_ir->get_loop_manager();
-    auto get_loop_info = [&](){
+    auto get_loop_info = [&]() { OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed"); return loop_manager->get_loop_info(loop_ids[loop_idx++]); }; @@ -160,9 +194,11 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres // to avoid extra checks, we validate only first input port // Note: We check `is_incremented` attribute only for not incremented ports because // this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization
-    auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 1; };
+    auto check_port = [&](const ov::snippets::lowered::LoopPort& p) {
+        return p.dim_idx == 1;
+    }; OPENVINO_ASSERT(in_ports.size() > 1 && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) &&
-                    out_ports.size() == 1 && check_port(out_ports.back()),
+                        out_ports.size() == 1 && check_port(out_ports.back()), "Incorrect Loop by Brgemm dimension M"); M = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; input_pds[0]->set_subtensor_dim(1, M); @@ -179,9 +215,12 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres // Quick validation check: Should we check that port is really Brgemm port? 
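// The get_beta() rule spelled out in the comments above reduces to: scan earlier expanded loops decomposed from the
// same unified loop; if any of them actually executed, the current Brgemm accumulates into C (beta = 1), otherwise it
// overwrites (beta = 0). A standalone sketch, assuming loops are sorted by ID as the NormalizedLoopIDs note states;
// SketchLoopInfo and pick_beta are illustrative stand-ins, not the snippets LoopManager API.
#include <cstddef>
#include <vector>

struct SketchLoopInfo {
    std::size_t unified_id;   // unified loop this expanded loop was decomposed from
    std::size_t work_amount;  // 0 means this expanded loop never executes
};

inline float pick_beta(const std::vector<SketchLoopInfo>& loops, int loop_id) {
    const auto unified = loops[loop_id].unified_id;
    for (int prev = loop_id - 1; prev >= 0; --prev) {
        if (loops[prev].unified_id != unified)
            return 0.f;  // left the cluster decomposed from the current unified loop
        if (loops[prev].work_amount > 0)
            return 1.f;  // an earlier piece already wrote C, so accumulate
    }
    return 0.f;  // no earlier executed piece, so overwrite C
}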
// Note: We check `is_incremented` attribute only for not incremented ports because // this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization - auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { return p.dim_idx == 0; }; - OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_incremented && std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) && - out_ports.size() == 1 && check_port(out_ports.back()), + auto check_port = [&](const ov::snippets::lowered::LoopPort& p) { + return p.dim_idx == 0; + }; + OPENVINO_ASSERT(in_ports.size() >= 2 && !in_ports.front().is_incremented && + std::all_of(in_ports.cbegin(), in_ports.cend(), check_port) && out_ports.size() == 1 && + check_port(out_ports.back()), "Incorrect Loop by Brgemm dimension N"); N = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; input_pds[1]->set_subtensor_dim(0, N); @@ -204,7 +243,7 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres // Note: We check `is_incremented` attribute only for not incremented ports because // this `is_incremented = true` can be changed by `CleanRepeatedDataPointerShifts` optimization OPENVINO_ASSERT(in_ports.size() >= 2 && in_ports.front().dim_idx == 0 && in_ports.back().dim_idx == 1 && - out_ports.size() == 1 && !out_ports.front().is_incremented, + out_ports.size() == 1 && !out_ports.front().is_incremented, "Incorrect Loop by Brgemm dimension K"); K = current_expanded_loop_info->get_work_amount() > 0 ? current_expanded_loop_info->get_increment() : 0; input_pds[0]->set_subtensor_dim(0, K); @@ -226,13 +265,37 @@ void BrgemmBaseKernelExecutor::update_config(const ov::snippets::lowered::Expres config.update(DIM_CAST(M), DIM_CAST(N), DIM_CAST(K), LDA, LDB, LDC, beta); } -void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr& kernel, dnnl_data_type_t dt0, dnnl_data_type_t dt1, - cpu_isa_t isa, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, - dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta, bool with_amx, char* palette) { +void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr& kernel, + dnnl_data_type_t dt0, + dnnl_data_type_t dt1, + cpu_isa_t isa, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K, + dnnl_dim_t LDA, + dnnl_dim_t LDB, + dnnl_dim_t LDC, + float beta, + bool with_amx, + char* palette) { cpu::x64::brgemm_desc_t desc; - OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc, isa, cpu::x64::brgemm_strd, dt0, dt1, - false, false, cpu::x64::brgemm_row_major, 1.f, - beta, LDA, LDB, LDC, M, N, K, nullptr) == dnnl_success, + OV_CPU_JIT_EMITTER_ASSERT(brgemm_desc_init(&desc, + isa, + cpu::x64::brgemm_strd, + dt0, + dt1, + false, + false, + cpu::x64::brgemm_row_major, + 1.f, + beta, + LDA, + LDB, + LDC, + M, + N, + K, + nullptr) == dnnl_success, "Cannot initialize brgemm descriptor due to invalid params"); if (with_amx) { @@ -241,12 +304,18 @@ void BrgemmBaseKernelExecutor::create_brgemm_kernel(std::shared_ptr(kernel_); } -void BrgemmBaseKernelExecutor::execute_brgemm_kernel(const std::shared_ptr& kernel, - const void* src, const void* wei, void* dst, void* scratch, bool with_comp) { +void BrgemmBaseKernelExecutor::execute_brgemm_kernel( + const std::shared_ptr& kernel, + const void* src, + const void* wei, + void* dst, + void* scratch, + bool with_comp) { cpu::x64::brgemm_kernel_params_t brgemm_p; brgemm_p.batch = nullptr; // default value brgemm_p.ptr_A = src; @@ -269,5 +338,5 @@ void 
BrgemmBaseKernelExecutor::execute_brgemm_kernel(const std::shared_ptr #include "cpu/x64/cpu_isa_traits.hpp" - #include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/jit_snippets_call_args.hpp" #include "emitters/snippets/cpu_kernel_executor_table.hpp" -#include - -#include "snippets/lowered/loop_manager.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" +#include "openvino/core/type/element_type.hpp" #include "snippets/lowered/loop_info.hpp" +#include "snippets/lowered/loop_manager.hpp" namespace ov { namespace intel_cpu { @@ -24,27 +22,51 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf BrgemmBaseKernelConfig() = default; bool is_completed() const override; - size_t hash() const override { return m_hash; } + size_t hash() const override { + return m_hash; + } bool is_empty() const; void update(dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta); bool operator==(const BrgemmBaseKernelConfig& rhs) const; - bool operator!=(const BrgemmBaseKernelConfig& rhs) const {return !(*this == rhs);} - - dnnl_data_type_t get_dt_in0() const { return get_static_params()->dt_in0; } - dnnl_data_type_t get_dt_in1() const { return get_static_params()->dt_in1; } - - dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return get_static_params()->isa; } - float get_beta() const { return m_beta; } - - dnnl_dim_t get_M() const { return m_M; } - dnnl_dim_t get_N() const { return m_N; } - dnnl_dim_t get_K() const { return m_K; } - - dnnl_dim_t get_LDA() const { return m_LDA; } - dnnl_dim_t get_LDB() const { return m_LDB; } - dnnl_dim_t get_LDC() const { return m_LDC; } + bool operator!=(const BrgemmBaseKernelConfig& rhs) const { + return !(*this == rhs); + } + + dnnl_data_type_t get_dt_in0() const { + return get_static_params()->dt_in0; + } + dnnl_data_type_t get_dt_in1() const { + return get_static_params()->dt_in1; + } + + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { + return get_static_params()->isa; + } + float get_beta() const { + return m_beta; + } + + dnnl_dim_t get_M() const { + return m_M; + } + dnnl_dim_t get_N() const { + return m_N; + } + dnnl_dim_t get_K() const { + return m_K; + } + + dnnl_dim_t get_LDA() const { + return m_LDA; + } + dnnl_dim_t get_LDB() const { + return m_LDB; + } + dnnl_dim_t get_LDC() const { + return m_LDC; + } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const override; @@ -52,51 +74,77 @@ struct BrgemmBaseKernelConfig : public snippets::KernelExecutorBase::GenericConf protected: struct StaticBaseParams { - StaticBaseParams(const element::Type& in0_dtype, const element::Type& in1_dtype, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, size_t hash_seed); + StaticBaseParams(const element::Type& in0_dtype, + const element::Type& in1_dtype, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, + size_t hash_seed); virtual ~StaticBaseParams() = default; - const dnnl_data_type_t dt_in0 {dnnl_f32}, dt_in1 {dnnl_f32}; - const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef}; + const dnnl_data_type_t dt_in0{dnnl_f32}, dt_in1{dnnl_f32}; + const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::isa_undef}; - size_t hash() const { return m_hash; } + size_t hash() const { + return m_hash; + } bool operator==(const StaticBaseParams& rhs) const; - bool operator!=(const StaticBaseParams& rhs) const { return !(*this == rhs); } + bool operator!=(const StaticBaseParams& rhs) const { + return !(*this == rhs); + } #ifdef SNIPPETS_DEBUG_CAPS std::string 
to_string() const; #endif protected: - static size_t compute_hash(size_t hash_seed, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1, dnnl::impl::cpu::x64::cpu_isa_t isa); + static size_t compute_hash(size_t hash_seed, + dnnl_data_type_t dt_in0, + dnnl_data_type_t dt_in1, + dnnl::impl::cpu::x64::cpu_isa_t isa); - const size_t m_hash {0}; + const size_t m_hash{0}; }; virtual std::shared_ptr get_static_params() const = 0; size_t compute_hash() const; - dnnl_dim_t m_M {0}, m_N {0}, m_K {0}, m_LDA {0}, m_LDB {0}, m_LDC {0}; - float m_beta {0}; - size_t m_hash {SIZE_MAX}; + dnnl_dim_t m_M{0}, m_N{0}, m_K{0}, m_LDA{0}, m_LDB{0}, m_LDC{0}; + float m_beta{0}; + size_t m_hash{SIZE_MAX}; }; class BrgemmBaseKernelExecutor { public: virtual ~BrgemmBaseKernelExecutor() = default; + protected: - static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, int loop_id, + static float get_beta(const ov::snippets::lowered::LoopManagerPtr& loop_manager, + int loop_id, const ov::snippets::lowered::ExpandedLoopInfoPtr& current_expanded_loop_info); static void update_config(const ov::snippets::lowered::ExpressionPtr& expr, const ov::snippets::lowered::LinearIRCPtr& linear_ir, BrgemmBaseKernelConfig& config); - static void create_brgemm_kernel(std::shared_ptr& kernel, dnnl_data_type_t dt0, dnnl_data_type_t dt1, - dnnl::impl::cpu::x64::cpu_isa_t isa, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, - dnnl_dim_t LDA, dnnl_dim_t LDB, dnnl_dim_t LDC, float beta, bool with_amx = false, char* palette = nullptr); - - static void execute_brgemm_kernel(const std::shared_ptr& kernel, const void* src, const void* wei, - void* dst, void* scratch, bool with_comp); + static void create_brgemm_kernel(std::shared_ptr& kernel, + dnnl_data_type_t dt0, + dnnl_data_type_t dt1, + dnnl::impl::cpu::x64::cpu_isa_t isa, + dnnl_dim_t M, + dnnl_dim_t N, + dnnl_dim_t K, + dnnl_dim_t LDA, + dnnl_dim_t LDB, + dnnl_dim_t LDC, + float beta, + bool with_amx = false, + char* palette = nullptr); + + static void execute_brgemm_kernel(const std::shared_ptr& kernel, + const void* src, + const void* wei, + void* dst, + void* scratch, + bool with_comp); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp index cc79458c7c4c64..dd216517ace12e 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.cpp @@ -4,8 +4,8 @@ #include "brgemm_copy_b.hpp" -#include "snippets/lowered/loop_manager.hpp" #include "emitters/plugin/x64/utils.hpp" +#include "snippets/lowered/loop_manager.hpp" #include "transformations/snippets/x64/op/brgemm_utils.hpp" #define DTYPE_CAST(X) static_cast(DnnlExtensionUtils::ElementTypeToDataType(X)) @@ -16,8 +16,12 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { -BrgemmCopyBKernelConfig::BrgemmCopyBKernelConfig(const element::Type& src_dt, const element::Type& wei_dt, cpu_isa_t isa, - bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk) +BrgemmCopyBKernelConfig::BrgemmCopyBKernelConfig(const element::Type& src_dt, + const element::Type& wei_dt, + cpu_isa_t isa, + bool is_with_comp, + bool is_transposed_B, + dnnl_dim_t wei_N_blk) : m_static_params(std::make_shared(src_dt, wei_dt, isa, is_with_comp, is_transposed_B, wei_N_blk)) { m_hash = 
compute_hash(); } @@ -37,17 +41,28 @@ bool BrgemmCopyBKernelConfig::operator==(const BrgemmCopyBKernelConfig& rhs) con #undef EQ }
-void BrgemmCopyBKernelConfig::update(dnnl_dim_t N, dnnl_dim_t N_blk, dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t copy_B_wei_stride, dnnl_dim_t LDB) {
-    // If one of the dims is zero, it means that BrgemmCopyB won't be executed (in Loop with work_amount = 0, for example)
-    // To process this case, we have to make this Config as empty (nullify runtime parameters)
+void BrgemmCopyBKernelConfig::update(dnnl_dim_t N,
+                                     dnnl_dim_t N_blk,
+                                     dnnl_dim_t K,
+                                     dnnl_dim_t K_blk,
+                                     dnnl_dim_t copy_B_wei_stride,
+                                     dnnl_dim_t LDB) {
+    // If one of the dims is zero, it means that BrgemmCopyB won't be executed (in Loop with work_amount = 0, for
+    // example). To process this case, we have to make this Config as empty (nullify runtime parameters) if (utils::one_of(0, N, K)) {
-        m_N = 0; m_N_blk = 0;
-        m_K = 0; m_K_blk = 0;
-        m_copy_B_wei_stride = 0; m_LDB = 0;
+        m_N = 0;
+        m_N_blk = 0;
+        m_K = 0;
+        m_K_blk = 0;
+        m_copy_B_wei_stride = 0;
+        m_LDB = 0; } else {
-        m_N = N; m_N_blk = N_blk;
-        m_K = K; m_K_blk = K_blk;
-        m_copy_B_wei_stride = copy_B_wei_stride; m_LDB = LDB;
+        m_N = N;
+        m_N_blk = N_blk;
+        m_K = K;
+        m_K_blk = K_blk;
+        m_copy_B_wei_stride = copy_B_wei_stride;
+        m_LDB = LDB; } m_hash = compute_hash(); } @@ -55,59 +70,94 @@ void BrgemmCopyBKernelConfig::update(dnnl_dim_t N, dnnl_dim_t N_blk, dnnl_dim_t size_t BrgemmCopyBKernelConfig::compute_hash() const { size_t seed = m_static_params->hash; #define HASH(X) seed = hash_combine(seed, X)
-    HASH(m_N); HASH(m_N_blk);
-    HASH(m_K); HASH(m_K_blk);
-    HASH(m_copy_B_wei_stride); HASH(m_LDB);
+    HASH(m_N);
+    HASH(m_N_blk);
+    HASH(m_K);
+    HASH(m_K_blk);
+    HASH(m_copy_B_wei_stride);
+    HASH(m_LDB); #undef HASH return seed; }
-BrgemmCopyBKernelConfig::StaticParams::StaticParams(const element::Type& src_type, const element::Type& wei_type, cpu_isa_t isa,
-                                                    bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_n_blk)
-    : src_dt(DTYPE_CAST(src_type)), wei_dt(DTYPE_CAST(wei_type)), isa(isa),
-      is_with_comp(is_with_comp), is_transposed_B(is_transposed_B), wei_N_blk(wei_n_blk),
+BrgemmCopyBKernelConfig::StaticParams::StaticParams(const element::Type& src_type,
+                                                    const element::Type& wei_type,
+                                                    cpu_isa_t isa,
+                                                    bool is_with_comp,
+                                                    bool is_transposed_B,
+                                                    dnnl_dim_t wei_n_blk)
+    : src_dt(DTYPE_CAST(src_type)),
+      wei_dt(DTYPE_CAST(wei_type)),
+      isa(isa),
+      is_with_comp(is_with_comp),
+      is_transposed_B(is_transposed_B),
+      wei_N_blk(wei_n_blk), hash(init_hash(src_dt, wei_dt, isa, is_with_comp, is_transposed_B, wei_N_blk)) {} bool BrgemmCopyBKernelConfig::StaticParams::operator==(const StaticParams& rhs) const { #define EQ(X) X == rhs.X
-    return EQ(hash) && EQ(src_dt) && EQ(wei_dt)&& EQ(isa) && EQ(is_with_comp) && EQ(is_transposed_B) && EQ(wei_N_blk);
+    return EQ(hash) && EQ(src_dt) && EQ(wei_dt) && EQ(isa) && EQ(is_with_comp) && EQ(is_transposed_B) && EQ(wei_N_blk); #undef EQ }
-size_t BrgemmCopyBKernelConfig::StaticParams::init_hash(const dnnl_data_type_t& src_dt, const dnnl_data_type_t& wei_dt, cpu_isa_t isa,
-                                                        bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk) {
+size_t BrgemmCopyBKernelConfig::StaticParams::init_hash(const dnnl_data_type_t& src_dt,
+                                                        const dnnl_data_type_t& wei_dt,
+                                                        cpu_isa_t isa,
+                                                        bool is_with_comp,
+                                                        bool is_transposed_B,
+                                                        dnnl_dim_t wei_N_blk) { size_t seed = 0; #define HASH(X) seed = hash_combine(seed, X)
-    HASH(src_dt); HASH(wei_dt); HASH(isa);
-    HASH(is_with_comp); HASH(is_transposed_B); 
HASH(wei_N_blk); + HASH(src_dt); + HASH(wei_dt); + HASH(isa); + HASH(is_with_comp); + HASH(is_transposed_B); + HASH(wei_N_blk); #undef HASH return seed; } #ifdef SNIPPETS_DEBUG_CAPS -#define PRINT(X) ss << #X << " = " << X << "\n" +# define PRINT(X) ss << #X << " = " << X << "\n" std::string BrgemmCopyBKernelConfig::to_string() const { std::stringstream ss; ss << m_static_params->to_string() << "\n"; - PRINT(m_hash); PRINT(m_N); PRINT(m_N_blk); - PRINT(m_K); PRINT(m_K_blk); PRINT(m_LDB); PRINT(m_copy_B_wei_stride); + PRINT(m_hash); + PRINT(m_N); + PRINT(m_N_blk); + PRINT(m_K); + PRINT(m_K_blk); + PRINT(m_LDB); + PRINT(m_copy_B_wei_stride); return ss.str(); } std::string BrgemmCopyBKernelConfig::StaticParams::to_string() const { std::stringstream ss; - PRINT(src_dt); PRINT(wei_dt); PRINT(isa); - PRINT(is_with_comp); PRINT(is_transposed_B); PRINT(wei_N_blk); + PRINT(src_dt); + PRINT(wei_dt); + PRINT(isa); + PRINT(is_with_comp); + PRINT(is_transposed_B); + PRINT(wei_N_blk); return ss.str(); } -#undef PRINT +# undef PRINT #endif BrgemmCopyBKernel::BrgemmCopyBKernel() : jit_generator(jit_name()), ker_(nullptr) {} BrgemmCopyBKernel::BrgemmCopyBKernel(const BrgemmCopyBKernelConfig& conf) - : jit_generator(jit_name()), is_with_comp(conf.is_with_comp()), is_transpose(conf.is_transposed_B()), - wei_data_size(dnnl_data_type_size(conf.get_wei_dt())), vnni_factor(data_type_vnni_granularity(conf.get_wei_dt())), - K(conf.get_K()), N_blk(conf.get_N_blk()), wei_N_blk(conf.get_wei_N_blk()), wei_N_tail(conf.get_wei_N_tail()), ker_(nullptr) { + : jit_generator(jit_name()), + is_with_comp(conf.is_with_comp()), + is_transpose(conf.is_transposed_B()), + wei_data_size(dnnl_data_type_size(conf.get_wei_dt())), + vnni_factor(data_type_vnni_granularity(conf.get_wei_dt())), + K(conf.get_K()), + N_blk(conf.get_N_blk()), + wei_N_blk(conf.get_wei_N_blk()), + wei_N_tail(conf.get_wei_N_tail()), + ker_(nullptr) { init_brgemm_copy_b_kernel(dnnl_brgemm_copy_b_kernel, conf); OV_CPU_JIT_EMITTER_ASSERT(dnnl_brgemm_copy_b_kernel, "Kernel is missed!"); } @@ -124,8 +174,9 @@ void BrgemmCopyBKernel::operator()(const call_args* args) const { ker_(args); } -void BrgemmCopyBKernel::init_brgemm_copy_b_kernel(std::unique_ptr& kernel, - const BrgemmCopyBKernelConfig& conf) const { +void BrgemmCopyBKernel::init_brgemm_copy_b_kernel( + std::unique_ptr& kernel, + const BrgemmCopyBKernelConfig& conf) const { matmul::brgemm_matmul_conf_t brgCopyKernelConf; brgCopyKernelConf.src_dt = conf.get_src_dt(); brgCopyKernelConf.wei_dt = conf.get_wei_dt(); @@ -143,8 +194,10 @@ void BrgemmCopyBKernel::init_brgemm_copy_b_kernel(std::unique_ptr(brgCopyKernelConf.wei_dt)); - brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.wei_dt)); + brgCopyKernelConf.b_dt_sz = + DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.wei_dt)); + brgCopyKernelConf.tr_b_dt_sz = + DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.wei_dt)); brgCopyKernelConf.req_wei_vnni_downconvert = false; @@ -191,28 +244,35 @@ void BrgemmCopyBKernel::generate() { postamble(); } -void BrgemmCopyBKernel::emit_brgemm_copy_b_kernel_call(size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) { +void BrgemmCopyBKernel::emit_brgemm_copy_b_kernel_call(size_t N, + size_t K, + size_t offset_in, + size_t offset_out, + size_t offset_comp) { EmitABIRegSpills spill(this); spill.preamble(); const auto add_offset = [&](Xbyak::Reg64 reg, size_t bytes_offset) { - if (bytes_offset) add(reg, bytes_offset); + if 
(bytes_offset)
+        add(reg, bytes_offset); }; // save function address in gpr to pass in call instruction
-    const auto& kernel_overload = static_cast(execute);
+    const auto& kernel_overload = static_cast<
+        void (*)(matmul::jit_brgemm_matmul_copy_b_t*, const void*, const void*, const void*, size_t, size_t)>(execute); mov(rbp, reinterpret_cast(kernel_overload)); mov(abi_param1, reinterpret_cast(dnnl_brgemm_copy_b_kernel.get()));
-    add_offset(src_reg, offset_in);  // abi_param2
-    add_offset(tr_src_reg, offset_out);  // abi_param3
-    if (is_with_comp)  // abi_param4
+    add_offset(src_reg, offset_in);      // abi_param2
+    add_offset(tr_src_reg, offset_out);  // abi_param3
+    if (is_with_comp)                    // abi_param4 add_offset(comp_reg, offset_comp); else mov(comp_reg, reinterpret_cast(nullptr)); #ifdef _WIN32
-    // Note: ABI requires that the remaining parameters (except the first for) are pushed to the stack in right-to-left order
+    // Note: ABI requires that the remaining parameters (except the first four) are pushed to the stack in right-to-left
+    // order // Shadow space will be allocated inside internal_call_rsp_align() push(K); push(N); @@ -233,7 +293,12 @@ void BrgemmCopyBKernel::emit_brgemm_copy_b_kernel_call(size_t N, size_t K, size_ spill.postamble(); }
-void BrgemmCopyBKernel::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp, size_t N, size_t K) {
+void BrgemmCopyBKernel::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel,
+                                const void* src,
+                                const void* dst,
+                                const void* comp,
+                                size_t N,
+                                size_t K) { auto ctx = matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); ctx.current_N_blk = N; ctx.src = src; @@ -248,10 +313,12 @@ void BrgemmCopyBKernel::execute(matmul::jit_brgemm_matmul_copy_b_t* kernel, cons (*kernel)(&ctx); }
-BrgemmCopyBKernelExecutor::BrgemmCopyBKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache, BrgemmCopyBKernelConfig config)
-    : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) { }
+BrgemmCopyBKernelExecutor::BrgemmCopyBKernelExecutor(ov::intel_cpu::MultiCacheWeakPtr kernel_cache,
+                                                     BrgemmCopyBKernelConfig config)
+    : CPUKernelExecutor(std::move(kernel_cache), std::move(config)) {}
-std::shared_ptr BrgemmCopyBKernelExecutor::compile_kernel(const BrgemmCopyBKernelConfig& config) const {
+std::shared_ptr BrgemmCopyBKernelExecutor::compile_kernel(
+    const BrgemmCopyBKernelConfig& config) const { std::shared_ptr compiled_kernel = std::make_shared(); // BrgemmCopyB is not executable - nothing to compile if (!config.is_empty()) { @@ -283,14 +350,16 @@ void BrgemmCopyBKernelExecutor::update_config(const ov::snippets::lowered::Expre const auto& loop_manager = linear_ir->get_loop_manager(); auto init = [&](size_t& dim, size_t& blk, size_t idx) {
-        OPENVINO_ASSERT(idx < planar_shape.size() && idx < in_subtensor.size(), "Index must be less than shape/subtensor rank!");
+        OPENVINO_ASSERT(idx < planar_shape.size() && idx < in_subtensor.size(),
+                        "Index must be less than shape/subtensor rank!"); dim = *(planar_shape.rbegin() + idx); blk = *(in_subtensor.rbegin() + idx); if (ov::snippets::utils::is_full_dim_value(blk)) { blk = dim; } else { OPENVINO_ASSERT(loop_idx < loop_ids.size(), "Loop is missed");
-            const auto& current_expanded_loop_info = loop_manager->get_loop_info(loop_ids[loop_idx++]);
+            const auto& current_expanded_loop_info =
+                loop_manager->get_loop_info(loop_ids[loop_idx++]); blk = current_expanded_loop_info->get_increment(); input_desc->set_subtensor_dim(idx, blk); output_desc->set_subtensor_dim(idx, blk); @@ -306,7 
+375,9 @@ void BrgemmCopyBKernelExecutor::update_config(const ov::snippets::lowered::Expre const auto& brg_weight_etype = expr->get_node()->get_input_element_type(0); const auto LDB = brgemm_utils::repacking::compute_LDB(N_dim, brg_weight_etype); - const auto copy_B_wei_stride = ov::snippets::utils::get_dim_stride(expr->get_input_port(0), config.is_transposed_B() ? 0 : 1) * brg_weight_etype.size(); + const auto copy_B_wei_stride = + ov::snippets::utils::get_dim_stride(expr->get_input_port(0), config.is_transposed_B() ? 0 : 1) * + brg_weight_etype.size(); config.update(N_dim, N_blk, K_dim, K_blk, copy_B_wei_stride, LDB); } @@ -318,5 +389,5 @@ void BrgemmCopyBKernelExecutor::execute(const BrgemmCopyBKernelExecutor* executo (*kernel)(args); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp index c4e3f3622ad88f..b3b107cd676705 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/kernel_executors/brgemm_copy_b.hpp @@ -4,13 +4,12 @@ #pragma once -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/snippets/jit_snippets_call_args.hpp" -#include "emitters/snippets/cpu_kernel_executor_table.hpp" - #include #include +#include "emitters/plugin/x64/jit_emitter.hpp" +#include "emitters/snippets/cpu_kernel_executor_table.hpp" +#include "emitters/snippets/jit_snippets_call_args.hpp" namespace ov { namespace intel_cpu { @@ -18,11 +17,17 @@ namespace intel_cpu { struct BrgemmCopyBKernelConfig : public snippets::KernelExecutorBase::GenericConfig { public: BrgemmCopyBKernelConfig() = default; - BrgemmCopyBKernelConfig(const element::Type& src_dt, const element::Type& wei_dt, dnnl::impl::cpu::x64::cpu_isa_t isa, - bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk); + BrgemmCopyBKernelConfig(const element::Type& src_dt, + const element::Type& wei_dt, + dnnl::impl::cpu::x64::cpu_isa_t isa, + bool is_with_comp, + bool is_transposed_B, + dnnl_dim_t wei_N_blk); bool operator==(const BrgemmCopyBKernelConfig& rhs) const; - bool operator!=(const BrgemmCopyBKernelConfig& rhs) const {return !(*this == rhs);} + bool operator!=(const BrgemmCopyBKernelConfig& rhs) const { + return !(*this == rhs); + } std::unique_ptr get_clone_ptr() const override { return std::unique_ptr(new BrgemmCopyBKernelConfig(*this)); @@ -31,26 +36,61 @@ struct BrgemmCopyBKernelConfig : public snippets::KernelExecutorBase::GenericCon bool is_empty() const; bool is_completed() const override; - void update(dnnl_dim_t N, dnnl_dim_t N_blk, dnnl_dim_t K, dnnl_dim_t K_blk, dnnl_dim_t copy_B_wei_stride, dnnl_dim_t LDB); + void update(dnnl_dim_t N, + dnnl_dim_t N_blk, + dnnl_dim_t K, + dnnl_dim_t K_blk, + dnnl_dim_t copy_B_wei_stride, + dnnl_dim_t LDB); - size_t hash() const override { return m_hash; } + size_t hash() const override { + return m_hash; + } - dnnl_data_type_t get_src_dt() const { return m_static_params->src_dt; } - dnnl_data_type_t get_wei_dt() const { return m_static_params->wei_dt; } + dnnl_data_type_t get_src_dt() const { + return m_static_params->src_dt; + } + dnnl_data_type_t get_wei_dt() const { + return m_static_params->wei_dt; + } - dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { return m_static_params->isa; } - bool is_with_comp() const { return m_static_params->is_with_comp; } - bool 
is_transposed_B() const { return m_static_params->is_transposed_B; } + dnnl::impl::cpu::x64::cpu_isa_t get_isa() const { + return m_static_params->isa; + } + bool is_with_comp() const { + return m_static_params->is_with_comp; + } + bool is_transposed_B() const { + return m_static_params->is_transposed_B; + } - dnnl_dim_t get_N() const { return m_N; } - dnnl_dim_t get_N_blk() const { return m_N_blk; } - dnnl_dim_t get_N_tail() const { return m_N % m_N_blk; } - dnnl_dim_t get_wei_N_blk() const { return m_static_params->wei_N_blk; } - dnnl_dim_t get_wei_N_tail() const { return m_N_blk % m_static_params->wei_N_blk; } - dnnl_dim_t get_K() const { return m_K; } - dnnl_dim_t get_K_blk() const { return m_K_blk; } - dnnl_dim_t get_copy_B_wei_stride() const { return m_copy_B_wei_stride; } - dnnl_dim_t get_LDB() const { return m_LDB; } + dnnl_dim_t get_N() const { + return m_N; + } + dnnl_dim_t get_N_blk() const { + return m_N_blk; + } + dnnl_dim_t get_N_tail() const { + return m_N % m_N_blk; + } + dnnl_dim_t get_wei_N_blk() const { + return m_static_params->wei_N_blk; + } + dnnl_dim_t get_wei_N_tail() const { + return m_N_blk % m_static_params->wei_N_blk; + } + dnnl_dim_t get_K() const { + return m_K; + } + dnnl_dim_t get_K_blk() const { + return m_K_blk; + } + dnnl_dim_t get_copy_B_wei_stride() const { + return m_copy_B_wei_stride; + } + dnnl_dim_t get_LDB() const { + return m_LDB; + } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const override; @@ -58,35 +98,45 @@ struct BrgemmCopyBKernelConfig : public snippets::KernelExecutorBase::GenericCon private: struct StaticParams { - StaticParams(const element::Type& src_dt, const element::Type& wei_dt, dnnl::impl::cpu::x64::cpu_isa_t isa, - bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk); - - const dnnl_data_type_t src_dt {dnnl_data_type_undef}, wei_dt {dnnl_data_type_undef}; - const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::isa_undef}; - const bool is_with_comp {false}; - const bool is_transposed_B {false}; - const dnnl_dim_t wei_N_blk {0}; - const size_t hash {0}; + StaticParams(const element::Type& src_dt, + const element::Type& wei_dt, + dnnl::impl::cpu::x64::cpu_isa_t isa, + bool is_with_comp, + bool is_transposed_B, + dnnl_dim_t wei_N_blk); + + const dnnl_data_type_t src_dt{dnnl_data_type_undef}, wei_dt{dnnl_data_type_undef}; + const dnnl::impl::cpu::x64::cpu_isa_t isa{dnnl::impl::cpu::x64::isa_undef}; + const bool is_with_comp{false}; + const bool is_transposed_B{false}; + const dnnl_dim_t wei_N_blk{0}; + const size_t hash{0}; bool operator==(const StaticParams& rhs) const; - bool operator!=(const StaticParams& rhs) const { return !(*this == rhs); } + bool operator!=(const StaticParams& rhs) const { + return !(*this == rhs); + } #ifdef SNIPPETS_DEBUG_CAPS std::string to_string() const; #endif private: - static size_t init_hash(const dnnl_data_type_t& src_dt, const dnnl_data_type_t& wei_dt, dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, - bool is_with_comp, bool is_transposed_B, dnnl_dim_t wei_N_blk); + static size_t init_hash(const dnnl_data_type_t& src_dt, + const dnnl_data_type_t& wei_dt, + dnnl::impl::cpu::x64::cpu_isa_t primitive_isa, + bool is_with_comp, + bool is_transposed_B, + dnnl_dim_t wei_N_blk); }; size_t compute_hash() const; std::shared_ptr m_static_params; - dnnl_dim_t m_N {0}, m_N_blk {0}; - dnnl_dim_t m_K {0}, m_K_blk {0}; - dnnl_dim_t m_copy_B_wei_stride {0}, m_LDB {0}; - size_t m_hash {SIZE_MAX}; + dnnl_dim_t m_N{0}, m_N_blk{0}; + dnnl_dim_t m_K{0}, m_K_blk{0}; + dnnl_dim_t 
m_copy_B_wei_stride{0}, m_LDB{0};
+    size_t m_hash{SIZE_MAX}; }; struct BrgemmCopyBKernel : public dnnl::impl::cpu::x64::jit_generator { @@ -109,8 +159,12 @@ struct BrgemmCopyBKernel : public dnnl::impl::cpu::x64::jit_generator { void emit_brgemm_copy_b_kernel_call(size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp);
-    static void execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, const void* src, const void* dst, const void* comp,
-                        size_t N, size_t K);
+    static void execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel,
+                        const void* src,
+                        const void* dst,
+                        const void* comp,
+                        size_t N,
+                        size_t K); void init_brgemm_copy_b_kernel(std::unique_ptr& kernel, const BrgemmCopyBKernelConfig& conf) const; @@ -151,5 +205,5 @@ class BrgemmCopyBKernelExecutor : public CPUKernelExecutor(); size_t id = SIZE_MAX; switch (port.get_type()) {
-    case ov::snippets::lowered::ExpressionPort::Type::Input:
-        offset = ma_op->get_input_offset(port.get_index());
-        id = get_cluster_id(port.get_port_connector_ptr()->get_source());
-        break;
-    case ov::snippets::lowered::ExpressionPort::Type::Output:
-        offset = ma_op->get_output_offset(port.get_index());
-        for (const auto& child : port.get_connected_ports())
-            if (!ov::is_type(child.get_expr()->get_node()))
-                id = get_cluster_id(child);
-        break;
-    default:
-        OV_CPU_JIT_EMITTER_THROW("Uknown type of expression port!");
+    case ov::snippets::lowered::ExpressionPort::Type::Input:
+        offset = ma_op->get_input_offset(port.get_index());
+        id = get_cluster_id(port.get_port_connector_ptr()->get_source());
+        break;
+    case ov::snippets::lowered::ExpressionPort::Type::Output:
+        offset = ma_op->get_output_offset(port.get_index());
+        for (const auto& child : port.get_connected_ports())
+            if (!ov::is_type(child.get_expr()->get_node()))
+                id = get_cluster_id(child);
+        break;
+    default:
+        OV_CPU_JIT_EMITTER_THROW("Unknown type of expression port!"); } OV_CPU_JIT_EMITTER_ASSERT(IMPLICATION(ov::snippets::utils::is_dynamic_value(offset), id != SIZE_MAX), "In dynamic case Buffer Cluster ID must be known!"); @@ -46,31 +45,41 @@ size_t get_buffer_cluster_id(const ov::snippets::lowered::ExpressionPort& port) Xbyak::Reg64 get_aux_gpr(const std::vector& used_gpr_idxs) { // RSP, RBP - stack-related registers, abi_param1 - runtime parameter register in the kernel
-    static std::unordered_set blacklist_gpr_idxs = { Xbyak::Operand::RSP, Xbyak::Operand::RBP, static_cast(abi_param1.getIdx()) };
+    static std::unordered_set blacklist_gpr_idxs = {Xbyak::Operand::RSP,
+                                                    Xbyak::Operand::RBP,
+                                                    static_cast(abi_param1.getIdx())}; for (size_t gpr_idx = 0; gpr_idx <= Xbyak::Operand::R15; ++gpr_idx) {
-        size_t _idx = Xbyak::Operand::R15 - gpr_idx; // we allocate from the end
-        if (std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), _idx) != used_gpr_idxs.cend()) continue;
-        if (blacklist_gpr_idxs.count(_idx) > 0) continue;
+        size_t _idx = Xbyak::Operand::R15 - gpr_idx;  // we allocate from the end
+        if (std::find(used_gpr_idxs.cbegin(), used_gpr_idxs.cend(), _idx) != used_gpr_idxs.cend())
+            continue;
+        if (blacklist_gpr_idxs.count(_idx) > 0)
+            continue; return Xbyak::Reg64(_idx); } OV_CPU_JIT_EMITTER_THROW("Failed to allocate aux GPR"); }
-void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset,
-                                           Xbyak::Reg64 ptr_reg, Xbyak::Reg64 aux_reg, size_t runtime_offset) {
+void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h,
+                                           size_t stack_offset,
+                                           Xbyak::Reg64 ptr_reg,
+                                           
Xbyak::Reg64 aux_reg, + size_t runtime_offset) { const auto stack_frame = h->qword[h->rsp + stack_offset]; h->mov(aux_reg, ptr_reg); h->add(aux_reg, h->ptr[abi_param1 + runtime_offset]); h->mov(stack_frame, aux_reg); } -void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, - Xbyak::Reg64 ptr_reg, size_t ptr_offset) { +void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, + size_t stack_offset, + Xbyak::Reg64 ptr_reg, + size_t ptr_offset) { const auto stack_frame = h->qword[h->rsp + stack_offset]; h->mov(stack_frame, ptr_reg); - if (ptr_offset != 0) h->add(stack_frame, ptr_offset); + if (ptr_offset != 0) + h->add(stack_frame, ptr_offset); } -} // namespace utils -} // namespace intel_cpu -} // namespace ov +} // namespace utils +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp index 97ea86f404fd67..3d8026ea33c750 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/utils.hpp @@ -13,13 +13,17 @@ namespace utils { inline static std::vector transform_idxs_to_regs(const std::vector& idxs) { std::vector regs(idxs.size()); - std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx){return Xbyak::Reg64(static_cast(idx));}); + std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx) { + return Xbyak::Reg64(static_cast(idx)); + }); return regs; } inline static std::vector transform_snippets_regs_to_idxs(const std::vector& regs) { std::vector idxs(regs.size()); - std::transform(regs.cbegin(), regs.cend(), idxs.begin(), [](const snippets::Reg& reg) { return reg.idx; }); + std::transform(regs.cbegin(), regs.cend(), idxs.begin(), [](const snippets::Reg& reg) { + return reg.idx; + }); return idxs; } @@ -46,8 +50,11 @@ Xbyak::Reg64 get_aux_gpr(const std::vector& used_gpr_idxs); * @param aux_reg aux register * @param runtime_offset offset in runtime params `abi_param1` */ -void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, - Xbyak::Reg64 ptr_reg, Xbyak::Reg64 aux_reg, size_t runtime_offset); +void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, + size_t stack_offset, + Xbyak::Reg64 ptr_reg, + Xbyak::Reg64 aux_reg, + size_t runtime_offset); /** * @brief Push data pointer on stack adding static offset `ptr_offset` @@ -56,9 +63,11 @@ void push_ptr_with_runtime_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* * @param ptr_reg register contains data pointer * @param ptr_offset offset which will be added to data pointer */ -void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, size_t stack_offset, - Xbyak::Reg64 ptr_reg, size_t ptr_offset); +void push_ptr_with_static_offset_on_stack(dnnl::impl::cpu::x64::jit_generator* h, + size_t stack_offset, + Xbyak::Reg64 ptr_reg, + size_t ptr_offset); -} // namespace utils -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace utils +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp index 269212edf1ab9b..9ac7f0d5cd0ffc 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.cpp @@ -4,20 +4,20 @@ #ifdef 
SNIPPETS_DEBUG_CAPS -#include "verbose.hpp" -#include "jit_segfault_detector_emitter.hpp" -#include "jit_memory_emitters.hpp" -#include "jit_brgemm_emitter.hpp" -#include "jit_brgemm_copy_b_emitter.hpp" -#include "jit_kernel_emitter.hpp" -#include "jit_snippets_emitters.hpp" -#include "kernel_executors/brgemm.hpp" -#include "kernel_executors/brgemm_amx.hpp" +# include "verbose.hpp" +# include "jit_brgemm_copy_b_emitter.hpp" +# include "jit_brgemm_emitter.hpp" +# include "jit_kernel_emitter.hpp" +# include "jit_memory_emitters.hpp" +# include "jit_segfault_detector_emitter.hpp" +# include "jit_snippets_emitters.hpp" +# include "kernel_executors/brgemm.hpp" +# include "kernel_executors/brgemm_amx.hpp" -#ifndef _WIN32 -#include -#endif +# ifndef _WIN32 +# include +# endif namespace ov { namespace intel_cpu { @@ -44,50 +44,44 @@ std::string vector_to_string(const T& v) { std::string get_emitter_type_name(const jit_emitter* emitter) { std::string name = typeid(*emitter).name(); -#ifndef _WIN32 +# ifndef _WIN32 int status; - std::unique_ptr demangled_name( - abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), - std::free); + std::unique_ptr demangled_name(abi::__cxa_demangle(name.c_str(), nullptr, nullptr, &status), + std::free); name = demangled_name.get(); -#endif +# endif return name; } -std::string init_info_jit_memory_emitter(const jit_memory_emitter *emitter) { +std::string init_info_jit_memory_emitter(const jit_memory_emitter* emitter) { std::stringstream ss; - ss << " src_precision:" << emitter->src_prc - << " dst_precision:" << emitter->dst_prc - << " load/store_element_number:" << emitter->count - << " byte_offset:" << emitter->compiled_byte_offset; + ss << " src_precision:" << emitter->src_prc << " dst_precision:" << emitter->dst_prc + << " load/store_element_number:" << emitter->count << " byte_offset:" << emitter->compiled_byte_offset; return ss.str(); } -static std::string init_info_jit_load_memory_emitter(const jit_load_memory_emitter *emitter) { +static std::string init_info_jit_load_memory_emitter(const jit_load_memory_emitter* emitter) { std::stringstream ss; std::string memory_emitter_info = init_info_jit_memory_emitter(emitter); - ss << "Emitter_type_name:jit_load_memory_emitter" - << memory_emitter_info; + ss << "Emitter_type_name:jit_load_memory_emitter" << memory_emitter_info; return ss.str(); } -static std::string init_info_jit_load_broadcast_emitter(const jit_load_broadcast_emitter *emitter) { +static std::string init_info_jit_load_broadcast_emitter(const jit_load_broadcast_emitter* emitter) { std::stringstream ss; std::string memory_emitter_info = init_info_jit_memory_emitter(emitter); - ss << "Emitter_type_name:jit_load_broadcast_emitter" - << memory_emitter_info; + ss << "Emitter_type_name:jit_load_broadcast_emitter" << memory_emitter_info; return ss.str(); } -static std::string init_info_jit_store_memory_emitter(const jit_store_memory_emitter *emitter) { +static std::string init_info_jit_store_memory_emitter(const jit_store_memory_emitter* emitter) { std::stringstream ss; std::string memory_emitter_info = init_info_jit_memory_emitter(emitter); - ss << "Emitter_type_name:jit_store_memory_emitter" - << memory_emitter_info; + ss << "Emitter_type_name:jit_store_memory_emitter" << memory_emitter_info; return ss.str(); } -std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter) { +std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter* emitter) { std::stringstream ss; ss << "Emitter_type_name:jit_brgemm_emitter"; if (const auto& 
common = std::dynamic_pointer_cast(emitter->m_kernel_executor)) @@ -100,10 +94,9 @@ std::string init_info_jit_brgemm_emitter(const jit_brgemm_emitter *emitter) { return ss.str(); } -std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter *emitter) { +std::string init_info_jit_brgemm_copy_b_emitter(const jit_brgemm_copy_b_emitter* emitter) { std::stringstream ss; - ss << "Emitter_type_name:jit_brgemm_copy_b_emitter" - << emitter->m_kernel_executor->to_string() + ss << "Emitter_type_name:jit_brgemm_copy_b_emitter" << emitter->m_kernel_executor->to_string() << " m_memory_offset:" << vector_to_string(emitter->m_memory_offsets) << " m_buffer_ids:" << vector_to_string(emitter->m_buffer_ids); @@ -114,11 +107,9 @@ std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter* std::stringstream ss; ss << "Emitter_type_name:jit_kernel_static_emitter" << " jcp.exec_domain:" << vector_to_string(emitter->jcp.exec_domain) - << " gp_regs_pool:"<< vector_to_string(emitter->gp_regs_pool) - << " master_shape:" << vector_to_string(emitter->master_shape) - << " num_inputs:" << emitter->num_inputs - << " num_outputs:" << emitter->num_outputs - << " num_unique_buffers:" << emitter->num_unique_buffers + << " gp_regs_pool:" << vector_to_string(emitter->gp_regs_pool) + << " master_shape:" << vector_to_string(emitter->master_shape) << " num_inputs:" << emitter->num_inputs + << " num_outputs:" << emitter->num_outputs << " num_unique_buffers:" << emitter->num_unique_buffers << " data_ptr_regs_idx:" << vector_to_string(emitter->data_ptr_regs_idx) << " vec_regs_pool:" << vector_to_string(emitter->vec_regs_pool) << " reg_indexes_idx:" << emitter->reg_indexes_idx @@ -131,24 +122,20 @@ std::string init_info_jit_kernel_static_emitter(const jit_kernel_static_emitter* std::string init_info_jit_kernel_dynamic_emitter(const jit_kernel_dynamic_emitter* emitter) { std::stringstream ss; ss << "Emitter_type_name:jit_kernel_dynamic_emitter" - << " gp_regs_pool:"<< vector_to_string(emitter->gp_regs_pool) - << " num_inputs:" << emitter->num_inputs - << " num_outputs:" << emitter->num_outputs - << " num_unique_buffers:" << emitter->num_unique_buffers + << " gp_regs_pool:" << vector_to_string(emitter->gp_regs_pool) << " num_inputs:" << emitter->num_inputs + << " num_outputs:" << emitter->num_outputs << " num_unique_buffers:" << emitter->num_unique_buffers << " data_ptr_regs_idx:" << vector_to_string(emitter->data_ptr_regs_idx) << " vec_regs_pool:" << vector_to_string(emitter->vec_regs_pool) << " reg_runtime_params_idx:" << emitter->reg_runtime_params_idx; return ss.str(); } -std::string init_info_jit_uni_segfault_detector_emitter(const jit_uni_segfault_detector_emitter *emitter) { +std::string init_info_jit_uni_segfault_detector_emitter(const jit_uni_segfault_detector_emitter* emitter) { std::stringstream ss; - ss << "Node_name:" << emitter->m_target_node_name - << " use_load_emitter:"<< emitter->is_target_use_load_emitter - << " use_store_emitter:"<< emitter->is_target_use_store_emitter; + ss << "Node_name:" << emitter->m_target_node_name << " use_load_emitter:" << emitter->is_target_use_load_emitter + << " use_store_emitter:" << emitter->is_target_use_store_emitter; if (emitter->is_target_use_load_emitter || emitter->is_target_use_store_emitter) { - ss << " start_address:" << emitter->start_address - << " current_address:" << emitter->current_address + ss << " start_address:" << emitter->start_address << " current_address:" << emitter->current_address << " iteration:" << emitter->iteration << 
" "; } // traget emitter info @@ -158,14 +145,15 @@ std::string init_info_jit_uni_segfault_detector_emitter(const jit_uni_segfault_d return ss.str(); } -static std::string init_info_jit_emitter_general(const jit_emitter *emitter) { +static std::string init_info_jit_emitter_general(const jit_emitter* emitter) { std::stringstream ss; ss << "Emitter_type_name:" << get_emitter_type_name(emitter); return ss.str(); } -void jit_emitter_info_t::init(const jit_emitter *emitter) { - if (is_initialized_) return; +void jit_emitter_info_t::init(const jit_emitter* emitter) { + if (is_initialized_) + return; if (auto e_type = dynamic_cast(emitter)) { str_ = init_info_jit_load_memory_emitter(e_type); } else if (auto e_type = dynamic_cast(emitter)) { @@ -188,7 +176,7 @@ void jit_emitter_info_t::init(const jit_emitter *emitter) { is_initialized_ = true; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov #endif \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.hpp index a81364039b98a7..ffbe210f75d2ff 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/verbose.hpp @@ -4,27 +4,30 @@ #ifdef SNIPPETS_DEBUG_CAPS -#pragma once +# pragma once -#include +# include namespace ov { namespace intel_cpu { class jit_emitter; struct jit_emitter_info_t { jit_emitter_info_t() = default; - jit_emitter_info_t(const jit_emitter_info_t &rhs) - : str_(rhs.str_), is_initialized_(rhs.is_initialized_) {} - jit_emitter_info_t &operator=(const jit_emitter_info_t &rhs) { + jit_emitter_info_t(const jit_emitter_info_t& rhs) : str_(rhs.str_), is_initialized_(rhs.is_initialized_) {} + jit_emitter_info_t& operator=(const jit_emitter_info_t& rhs) { is_initialized_ = rhs.is_initialized_; str_ = rhs.str_; return *this; } - const char *c_str() const { return str_.c_str(); } - bool is_initialized() const { return is_initialized_; } + const char* c_str() const { + return str_.c_str(); + } + bool is_initialized() const { + return is_initialized_; + } - void init(const jit_emitter *emitter); + void init(const jit_emitter* emitter); private: std::string str_; @@ -33,7 +36,7 @@ struct jit_emitter_info_t { std::string get_emitter_type_name(const jit_emitter* emitter); -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov #endif \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/emitters/utils.cpp b/src/plugins/intel_cpu/src/emitters/utils.cpp index b92277ae643218..43172e1b600843 100644 --- a/src/plugins/intel_cpu/src/emitters/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/utils.cpp @@ -7,28 +7,29 @@ namespace ov { namespace intel_cpu { -std::string jit_emitter_pretty_name(const std::string &pretty_func) { -#define SAFE_SYMBOL_FINDING(idx, find) \ - auto idx = (find); \ +std::string jit_emitter_pretty_name(const std::string& pretty_func) { +#define SAFE_SYMBOL_FINDING(idx, find) \ + auto idx = (find); \ if (idx == std::string::npos || idx == 0) \ return pretty_func; // Example: - // pretty_func := void ov::intel_cpu::jit_load_memory_emitter::emit_impl(const std::vector& in, const std::vector& out) const - // begin := -----------| - // end := ---------------------------------------------------| - // result := ov::intel_cpu::jit_load_memory_emitter + // pretty_func := void ov::intel_cpu::jit_load_memory_emitter::emit_impl(const std::vector& in, const + // std::vector& out) const 
+ // begin := index right after the return type, i.e. the start of "ov::intel_cpu::jit_load_memory_emitter" + // end := index of the last "::" before the argument list (the start of "::emit_impl") + // result := ov::intel_cpu::jit_load_memory_emitter // Signatures: // GCC: void foo() [with T = {type}] // clang: void foo() [T = {type}] // MSVC: void __cdecl foo<{type}>(void) SAFE_SYMBOL_FINDING(parenthesis, pretty_func.find("(")) - if (pretty_func[parenthesis - 1] == '>') { // To cover template on MSVC + if (pretty_func[parenthesis - 1] == '>') { // To cover template on MSVC parenthesis--; size_t counter = 1; while (counter != 0 && parenthesis > 0) { parenthesis--; - if (pretty_func[parenthesis] == '>') counter++; - if (pretty_func[parenthesis] == '<') counter--; + if (pretty_func[parenthesis] == '>') + counter++; + if (pretty_func[parenthesis] == '<') + counter--; } } SAFE_SYMBOL_FINDING(end, pretty_func.substr(0, parenthesis).rfind("::")) @@ -38,5 +39,5 @@ std::string jit_emitter_pretty_name(const std::string &pretty_func) { return end > begin ? pretty_func.substr(begin, end - begin) : pretty_func; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/emitters/utils.hpp b/src/plugins/intel_cpu/src/emitters/utils.hpp index 4c3210579d7fd2..7c89b720159dde 100644 --- a/src/plugins/intel_cpu/src/emitters/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/utils.hpp @@ -5,21 +5,22 @@ #pragma once #include + #include "openvino/core/except.hpp" namespace ov { namespace intel_cpu { -std::string jit_emitter_pretty_name(const std::string &pretty_func); +std::string jit_emitter_pretty_name(const std::string& pretty_func); #ifdef __GNUC__ -#define OV_CPU_JIT_EMITTER_NAME jit_emitter_pretty_name(__PRETTY_FUNCTION__) +# define OV_CPU_JIT_EMITTER_NAME jit_emitter_pretty_name(__PRETTY_FUNCTION__) #else /* __GNUC__ */ -#define OV_CPU_JIT_EMITTER_NAME jit_emitter_pretty_name(__FUNCSIG__) +# define OV_CPU_JIT_EMITTER_NAME jit_emitter_pretty_name(__FUNCSIG__) #endif /* __GNUC__ */ -#define OV_CPU_JIT_EMITTER_THROW(...) OPENVINO_THROW(OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__) +#define OV_CPU_JIT_EMITTER_THROW(...) OPENVINO_THROW(OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__)
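Reviewer aside: a self-contained sketch of the name extraction that jit_emitter_pretty_name() performs, reduced to the GCC-style signature shown in the comment above; the real function also handles clang and MSVC template signatures via the SAFE_SYMBOL_FINDING bail-outs. The helper name below is hypothetical.

#include <iostream>
#include <string>

// Simplified sketch: extract "namespace::class" from a GCC-style __PRETTY_FUNCTION__ string.
std::string pretty_name_sketch(const std::string& pretty_func) {
    const auto paren = pretty_func.find('(');
    if (paren == std::string::npos || paren == 0)
        return pretty_func;
    const auto end = pretty_func.substr(0, paren).rfind("::");  // last "::" before the argument list
    if (end == std::string::npos || end == 0)
        return pretty_func;
    const auto begin = pretty_func.substr(0, end).rfind(' ') + 1;  // skip past the return type
    return end > begin ? pretty_func.substr(begin, end - begin) : pretty_func;
}

int main() {
    std::cout << pretty_name_sketch("void ov::intel_cpu::jit_load_memory_emitter::emit_impl(...) const")
              << '\n';  // prints: ov::intel_cpu::jit_load_memory_emitter
}

#define OV_CPU_JIT_EMITTER_ASSERT(cond, ...)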
OPENVINO_ASSERT((cond), OV_CPU_JIT_EMITTER_NAME, ": ", __VA_ARGS__) -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index e6dbc04b0ca6a4..bdb5211009a22a 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -9,8 +9,8 @@ #include "ov_ops/augru_sequence.hpp" #include "ov_ops/fully_connected.hpp" #include "ov_ops/fully_connected_compressed.hpp" -#include "ov_ops/fully_connected_quantized_legacy.hpp" #include "ov_ops/fully_connected_quantized.hpp" +#include "ov_ops/fully_connected_quantized_legacy.hpp" #include "ov_ops/gather_compressed.hpp" #include "ov_ops/multiclass_nms_ie_internal.hpp" #include "ov_ops/nms_ie_internal.hpp" @@ -26,8 +26,8 @@ #include "transformations/cpu_opset/common/op/sdpa.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/cpu_opset/x64/op/interaction.hpp" -#include "transformations/cpu_opset/x64/op/mha.hpp" #include "transformations/cpu_opset/x64/op/llm_mlp.hpp" +#include "transformations/cpu_opset/x64/op/mha.hpp" #include "transformations/cpu_opset/x64/op/qkv_proj.hpp" #include "transformations/snippets/x64/op/brgemm_copy_b.hpp" #include "transformations/snippets/x64/op/brgemm_cpu.hpp" @@ -40,8 +40,7 @@ namespace { template class TypeRelaxedExtension : public ov::OpExtension> { public: - TypeRelaxedExtension() - : m_ext_type(Op::get_type_info_static().name, "type_relaxed_opset") {} + TypeRelaxedExtension() : m_ext_type(Op::get_type_info_static().name, "type_relaxed_opset") {} ~TypeRelaxedExtension() override = default; const ov::DiscreteTypeInfo& get_type_info() const override { @@ -159,31 +158,31 @@ class TypeRelaxedExtension : public ov::OpExtension> { # define SNIPPETS_DEBUG_CAPS_EXTENSIONS #endif -#define SNIPPETS_EXTENSIONS \ - OP_EXTENSION(ov::snippets::op::Brgemm) \ - OP_EXTENSION(ov::snippets::op::BroadcastLoad) \ - OP_EXTENSION(ov::snippets::op::BroadcastMove) \ - OP_EXTENSION(ov::snippets::op::ConvertSaturation) \ - OP_EXTENSION(ov::snippets::op::ConvertTruncation) \ - OP_EXTENSION(ov::snippets::op::Fill) \ - OP_EXTENSION(ov::snippets::op::HorizonMax) \ - OP_EXTENSION(ov::snippets::op::HorizonSum) \ - OP_EXTENSION(ov::snippets::op::KernelStatic) \ - OP_EXTENSION(ov::snippets::op::KernelDynamic) \ - OP_EXTENSION(ov::snippets::op::Load) \ - OP_EXTENSION(ov::snippets::op::LoadReshape) \ - OP_EXTENSION(ov::snippets::op::LoopBegin) \ - OP_EXTENSION(ov::snippets::op::LoopEnd) \ - OP_EXTENSION(ov::snippets::op::Buffer) \ - OP_EXTENSION(ov::snippets::op::Nop) \ - OP_EXTENSION(ov::snippets::op::PowerStatic) \ - OP_EXTENSION(ov::snippets::op::Scalar) \ - OP_EXTENSION(ov::snippets::op::Store) \ - OP_EXTENSION(ov::snippets::op::Subgraph) \ - OP_EXTENSION(ov::snippets::op::VectorBuffer) \ - OP_EXTENSION(ov::snippets::op::RankNormalization) \ - OP_EXTENSION(ov::snippets::op::ReduceMax) \ - OP_EXTENSION(ov::snippets::op::ReduceSum) \ +#define SNIPPETS_EXTENSIONS \ + OP_EXTENSION(ov::snippets::op::Brgemm) \ + OP_EXTENSION(ov::snippets::op::BroadcastLoad) \ + OP_EXTENSION(ov::snippets::op::BroadcastMove) \ + OP_EXTENSION(ov::snippets::op::ConvertSaturation) \ + OP_EXTENSION(ov::snippets::op::ConvertTruncation) \ + OP_EXTENSION(ov::snippets::op::Fill) \ + OP_EXTENSION(ov::snippets::op::HorizonMax) \ + OP_EXTENSION(ov::snippets::op::HorizonSum) \ + OP_EXTENSION(ov::snippets::op::KernelStatic) \ + OP_EXTENSION(ov::snippets::op::KernelDynamic) \ + 
OP_EXTENSION(ov::snippets::op::Load) \ + OP_EXTENSION(ov::snippets::op::LoadReshape) \ + OP_EXTENSION(ov::snippets::op::LoopBegin) \ + OP_EXTENSION(ov::snippets::op::LoopEnd) \ + OP_EXTENSION(ov::snippets::op::Buffer) \ + OP_EXTENSION(ov::snippets::op::Nop) \ + OP_EXTENSION(ov::snippets::op::PowerStatic) \ + OP_EXTENSION(ov::snippets::op::Scalar) \ + OP_EXTENSION(ov::snippets::op::Store) \ + OP_EXTENSION(ov::snippets::op::Subgraph) \ + OP_EXTENSION(ov::snippets::op::VectorBuffer) \ + OP_EXTENSION(ov::snippets::op::RankNormalization) \ + OP_EXTENSION(ov::snippets::op::ReduceMax) \ + OP_EXTENSION(ov::snippets::op::ReduceSum) \ OP_EXTENSION(ov::snippets::op::Reshape) OPENVINO_CREATE_EXTENSIONS(std::vector( diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index 6aa4644f902bc9..7fb5f512227cf9 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -16,6 +17,7 @@ #include #include +#include "common/primitive_desc_iface.hpp" #include "edge.h" #include "graph_dumper.h" #include "graph_optimizer.h" @@ -28,25 +30,21 @@ #include "nodes/common/cpu_memcpy.h" #include "nodes/convert.h" #include "nodes/input.h" -#include "nodes/reorder.h" #include "nodes/memory.hpp" +#include "nodes/reorder.h" #include "openvino/core/except.hpp" #include "openvino/core/model.hpp" #include "openvino/core/node.hpp" +#include "openvino/core/parallel.hpp" #include "openvino/core/type/element_type.hpp" +#include "openvino/runtime/exception.hpp" +#include "openvino/runtime/threading/cpu_streams_executor.hpp" #include "utils/debug_capabilities.h" #include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" #include "utils/node_dumper.h" -#include "utils/verbose.h" #include "utils/precision_support.h" - -#include -#include "common/primitive_desc_iface.hpp" - -#include "openvino/runtime/exception.hpp" -#include "openvino/runtime/threading/cpu_streams_executor.hpp" -#include "openvino/core/parallel.hpp" +#include "utils/verbose.h" #if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO) # include @@ -61,8 +59,8 @@ Graph::~Graph() { CPU_DEBUG_CAP_ENABLE(average_counters(*this)); } -template -void Graph::CreateGraph(NET &model, const GraphContext::CPtr context) { +template +void Graph::CreateGraph(NET& model, const GraphContext::CPtr context) { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "CreateGraph"); Init(model, context); @@ -104,7 +102,7 @@ void Graph::CreateGraph(const std::vector& graphNodes, template void Graph::CreateGraph(const std::shared_ptr&, const GraphContext::CPtr); -void Graph::Replicate(const std::shared_ptr &model, +void Graph::Replicate(const std::shared_ptr& model, const std::vector& inputConfigs, const std::vector& outputConfigs) { OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "Graph::Replicate", "ov::Model"); @@ -135,7 +133,9 @@ void Graph::Replicate(const std::shared_ptr &model, if (op->get_type_info() == op::v0::Parameter::get_type_info_static()) { auto input_index = model->get_parameter_index(std::dynamic_pointer_cast(op)); OPENVINO_ASSERT(input_index >= 0, - "CPU plugin cannot find op: ", op->get_friendly_name(), " in model parameter list!"); + "CPU plugin cannot find op: ", + op->get_friendly_name(), + " in model parameter list!"); const auto& config = static_cast(input_index) < inputConfigs.size() ? 
inputConfigs[input_index] : node::Input::InputConfig{}; @@ -152,7 +152,9 @@ void Graph::Replicate(const std::shared_ptr &model, if (op->get_type_info() == op::v0::Result::get_type_info_static()) { auto output_index = model->get_result_index(std::dynamic_pointer_cast(op)); OPENVINO_ASSERT(output_index >= 0, - "CPU plugin cannot find op: ", op->get_friendly_name(), " in model result list!"); + "CPU plugin cannot find op: ", + op->get_friendly_name(), + " in model result list!"); const auto& config = static_cast(output_index) < outputConfigs.size() ? outputConfigs[output_index] : node::Input::OutputConfig{}; @@ -179,9 +181,9 @@ void Graph::Replicate(const std::shared_ptr &model, } if (!one_of(op->get_type_info(), - op::v0::Result::get_type_info_static(), - op::v3::Assign::get_type_info_static(), - op::v6::Assign::get_type_info_static())) { + op::v0::Result::get_type_info_static(), + op::v3::Assign::get_type_info_static(), + op::v6::Assign::get_type_info_static())) { for (size_t oi = 0; oi < op->get_output_size(); oi++) { if (op->get_output_target_inputs(oi).empty()) { unusedOutputs.push_back(op->output(oi)); @@ -194,10 +196,13 @@ void Graph::Replicate(const std::shared_ptr &model, for (auto unusedOutput : unusedOutputs) { auto parentNode = op2node[unusedOutput.get_node_shared_ptr()]; const auto port = unusedOutput.get_index(); - const auto nodeName = std::string("stub_") + std::to_string(unusedOutput.get_index()) + "_" + parentNode->getName(); + const auto nodeName = + std::string("stub_") + std::to_string(unusedOutput.get_index()) + "_" + parentNode->getName(); const NodePtr outNode = std::make_shared(parentNode->outputShapes[port], parentNode->getOriginalOutputPrecisionAtPort(port), - nodeName, "Result", m_context); + nodeName, + "Result", + m_context); CreateEdge(parentNode, outNode, port, 0); AddNode(outNode); } @@ -216,7 +221,7 @@ void Graph::Replicate(const std::shared_ptr &model, EnforceInferencePrecision(); // update input precisions of consumers to avoid extra reorders - for (auto &input : inputNodesMap) { + for (auto& input : inputNodesMap) { const auto& inputNode = input.second; const auto precToSet = inputNode->getOriginalOutputPrecisionAtPort(0); const auto childEdges = inputNode->getChildEdgesAtPort(0); @@ -233,7 +238,7 @@ void Graph::Replicate(const std::shared_ptr &model, // update output precisions of producers to avoid extra reorders // do this only in case output configuration is not provided explicitly if (outputConfigs.empty()) { - for (auto &output : outputNodesMap) { + for (auto& output : outputNodesMap) { const auto& outputNode = output.second; const auto precToSet = outputNode->getOriginalInputPrecisionAtPort(0); const auto parentEdge = outputNode->getParentEdgeAt(0); @@ -254,11 +259,12 @@ static std::vector IdentifySyncPoints(const std::vector& graphN continue; if (node->outputShapeDataDependency() || - // WA: for convolution plus sum(broadcast). Due to the fact that a convolution with sum use the same memory for second sum term and the output - // tensors (inPlace) resizing the output tensor, may lead to reallocation of this second term memory and possible data lost. The reallocation - // may happen when the second term shape is broadcasted to the output tensor shape. To avoid the data loss, we have a special processing for - // such cases inside the convolution node, but it works properly only when dynamic shapes inference, preparation and execution a called - // for this node sequentially. + // WA: for convolution plus sum(broadcast). 
Due to the fact that a convolution with sum uses the same memory + for the second sum term and the output tensors (inPlace), resizing the output tensor may lead to reallocation + of this second term's memory and possible data loss. The reallocation may happen when the second term shape + is broadcasted to the output tensor shape. To avoid the data loss, we have special processing for such + cases inside the convolution node, but it works properly only when dynamic shapes inference, preparation + and execution are called for this node sequentially. (node->getType() == Type::Convolution && node->isInPlace()) || // Due to the special handling of the internal states and initialization subgraphs, MemoryInput nodes must // be processed as an internal dynamism node, allowing to hide the aforementioned complexity inside the @@ -271,15 +277,17 @@ static std::vector IdentifySyncPoints(const std::vector& graphN return syncNodesInds; } -static std::tuple, std::vector> ExtractExecutableNodesAndSyncPoints(const std::vector& syncNodesInds, - const std::vector& graphNodes) { +static std::tuple, std::vector> ExtractExecutableNodesAndSyncPoints( + const std::vector& syncNodesInds, + const std::vector& graphNodes) { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::ExtractExecutableNodesAndSyncPoints"); std::unordered_map graphIdToExecutableId; std::vector executableGraphNodes; for (size_t i = 0; i < graphNodes.size(); i++) { const auto& graphNode = graphNodes[i]; - if ((!graphNode->isConstant() && graphNode->isExecutable()) || // non-constant executable or - (graphNode->isDynamicNode() && !one_of(graphNode->getType(), Type::Input, Type::Output))) { // dynamic, except inputs / outputs + if ((!graphNode->isConstant() && graphNode->isExecutable()) || // non-constant executable or + (graphNode->isDynamicNode() && + !one_of(graphNode->getType(), Type::Input, Type::Output))) { // dynamic, except inputs / outputs graphIdToExecutableId[i] = executableGraphNodes.size(); executableGraphNodes.emplace_back(graphNode); } @@ -291,17 +299,17 @@ static std::tuple, std::vector> ExtractExecutableNo auto it = graphIdToExecutableId.find(syncNodesInd); if (it != graphIdToExecutableId.end()) { uniqueExecutableSyncNodesInds.insert(it->second); - // since sometimes we need to run the synchronization node alone (for example in the case of internal dynamism) - // let's add another sync index after the sync point node + // since sometimes we need to run the synchronization node alone (for example in the case of internal + // dynamism) let's add another sync index after the sync point node uniqueExecutableSyncNodesInds.insert(it->second + 1); } } uniqueExecutableSyncNodesInds.insert(executableGraphNodes.size()); // convert to a vector to reduce runtime overhead - std::vector executableSyncNodesInds(uniqueExecutableSyncNodesInds.begin(), uniqueExecutableSyncNodesInds.end()); + std::vector executableSyncNodesInds(uniqueExecutableSyncNodesInds.begin(), + uniqueExecutableSyncNodesInds.end()); - return std::make_tuple(std::move(executableGraphNodes), - std::move(executableSyncNodesInds)); + return std::make_tuple(std::move(executableGraphNodes), std::move(executableSyncNodesInds)); } void Graph::Init(const std::shared_ptr& model, @@ -346,7 +354,7 @@ static void UseExternalOutputMemory(const std::map& output }
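Reviewer aside: a self-contained sketch of the index remapping performed by ExtractExecutableNodesAndSyncPoints() above. Executable nodes are compacted into a new numbering; each sync point is translated into that numbering and duplicated as (id, id + 1) so the sync node can run in its own interval; the executable-node count terminates the list. All names below are illustrative stand-ins, not the plugin's types.

#include <iostream>
#include <set>
#include <unordered_map>
#include <vector>

// Sketch: compact executable nodes and remap sync indices into the compacted numbering.
std::pair<std::vector<int>, std::vector<size_t>> extract_executable(const std::vector<bool>& is_executable,
                                                                    const std::vector<size_t>& sync_inds) {
    std::unordered_map<size_t, size_t> graph_to_exec;
    std::vector<int> exec_nodes;  // stand-in for the NodePtr vector; holds original graph indices
    for (size_t i = 0; i < is_executable.size(); i++) {
        if (is_executable[i]) {
            graph_to_exec[i] = exec_nodes.size();
            exec_nodes.push_back(static_cast<int>(i));
        }
    }
    std::set<size_t> unique_sync;  // ordered and deduplicated
    for (auto ind : sync_inds) {
        auto it = graph_to_exec.find(ind);
        if (it != graph_to_exec.end()) {
            unique_sync.insert(it->second);
            unique_sync.insert(it->second + 1);  // lets the sync node run in its own interval
        }
    }
    unique_sync.insert(exec_nodes.size());  // terminating index
    std::vector<size_t> sync(unique_sync.begin(), unique_sync.end());
    return {exec_nodes, sync};
}

int main() {
    // nodes 0..5, of which 1, 3 and 4 are executable; node 3 is a sync point
    auto res = extract_executable({false, true, false, true, true, false}, {3});
    for (auto i : res.second)
        std::cout << i << ' ';  // prints: 1 2 3
}

void Graph::Activate(const std::vector& externalInputMemory, - const std::vector& externalOutputMemory) { + const std::vector& externalOutputMemory) { OPENVINO_ASSERT(status == Status::Initialized, "Invalid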
graph status"); const bool hasDynNodes = ProcessDynNodes(); @@ -360,12 +368,13 @@ void Graph::Activate(const std::vector& externalInputMemory, CreatePrimitivesAndExecConstants(); #ifndef CPU_DEBUG_CAPS - for (auto &graphNode : graphNodes) { + for (auto& graphNode : graphNodes) { graphNode->cleanup(); } #endif - std::tie(m_executableGraphNodes, m_executableSyncNodesInds) = ExtractExecutableNodesAndSyncPoints(syncNodesInds, graphNodes); + std::tie(m_executableGraphNodes, m_executableSyncNodesInds) = + ExtractExecutableNodesAndSyncPoints(syncNodesInds, graphNodes); if (hasDynNodes) { status = Status::ReadyDynamic; @@ -424,7 +433,7 @@ void Graph::Configure(bool optimize) { void Graph::InitNodes() { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::InitNodes"); - for (auto &node : graphNodes) { + for (auto& node : graphNodes) { node->init(); } } @@ -432,7 +441,7 @@ void Graph::InitNodes() { void Graph::InitDescriptors() { OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "InitDescriptors", "Prepare"); - for (auto &node : graphNodes) { + for (auto& node : graphNodes) { OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.getSupportedDescriptors); DEBUG_LOG("Get supported primitive descriptors for node: ", node->getName()); node->getSupportedDescriptors(); @@ -445,15 +454,15 @@ void Graph::InitDescriptors() { const auto& SPDs = node->getSupportedPrimitiveDescriptors(); for (size_t i = 0; i < SPDs.size(); i++) { DEBUG_LOG("#", - node->getExecIndex(), - " ", - node->getName(), - " Before filter, SupportedPrimitiveDescriptors [", - i, - "/", - SPDs.size(), - "]: \n", - SPDs[i]); + node->getExecIndex(), + " ", + node->getName(), + " Before filter, SupportedPrimitiveDescriptors [", + i, + "/", + SPDs.size(), + "]: \n", + SPDs[i]); } } #endif @@ -478,7 +487,7 @@ void Graph::InitDescriptors() { #endif } - for (auto &node : graphNodes) { + for (auto& node : graphNodes) { OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, node->profiling.selectOptimalPrimitiveDescriptor); DEBUG_LOG("Select optimal primitive descriptors for node: ", node->getName()); node->selectOptimalPrimitiveDescriptor(); @@ -495,12 +504,18 @@ void Graph::ResolveInplaceDirections() { void Graph::InitOptimalPrimitiveDescriptors() { OV_ITT_SCOPED_TASK(itt::domains::intel_cpu, "Graph::InitOptimalPrimitiveDescriptors"); - for (auto &node : graphNodes) { + for (auto& node : graphNodes) { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, node->profiling.initOptimalPrimitiveDescriptor); DEBUG_LOG("Init optimal primitive descriptors for node: ", node->getName()); node->initOptimalPrimitiveDescriptor(); - DEBUG_LOG("#", node->getExecIndex(), " ", node->getName(), "\n", - *node->getSelectedPrimitiveDescriptor(), "selectedPrimitiveDescriptorIdx = ", node->selectedPrimitiveDescriptorIndex); + DEBUG_LOG("#", + node->getExecIndex(), + " ", + node->getName(), + "\n", + *node->getSelectedPrimitiveDescriptor(), + "selectedPrimitiveDescriptorIdx = ", + node->selectedPrimitiveDescriptorIndex); } } @@ -508,7 +523,7 @@ void Graph::CreatePrimitivesAndExecConstants() const { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::CreatePrimitivesAndExecConstants"); using shared_memory_ptr = WeightsSharing::SharedMemory::Ptr; - auto acquireSharedOutputs = [this](const NodePtr & node) { + auto acquireSharedOutputs = [this](const NodePtr& node) { std::vector outputs; bool hasLocalAllocatedEdges = false; bool hasExternalInvalidEdges = false; @@ -530,7 +545,7 @@ void 
Graph::CreatePrimitivesAndExecConstants() const { return std::make_tuple(hasExternalInvalidEdges, hasLocalAllocatedEdges, outputs); }; - for (const auto &node : graphNodes) { + for (const auto& node : graphNodes) { { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, node->profiling.createPrimitive); DEBUG_LOG(*node); @@ -547,7 +562,7 @@ void Graph::CreatePrimitivesAndExecConstants() const { if (std::get<0>(sharedOutputs) || std::get<1>(sharedOutputs)) { ExecuteNodeWithCatch(node); - for (auto & output : std::get<2>(sharedOutputs)) + for (auto& output : std::get<2>(sharedOutputs)) output->valid(true); } } else { @@ -556,7 +571,9 @@ void Graph::CreatePrimitivesAndExecConstants() const { } } -static bool isReorderAvailable(const MemoryDescPtr& parentDesc, const MemoryDescPtr& childDesc, const dnnl::engine& eng) { +static bool isReorderAvailable(const MemoryDescPtr& parentDesc, + const MemoryDescPtr& childDesc, + const dnnl::engine& eng) { auto definedParentDesc = parentDesc->isDefined() ? parentDesc : MemoryDescUtils::makeDummyDesc(*parentDesc); memory::desc srcMemDesc = MemoryDescUtils::convertToDnnlMemoryDesc(definedParentDesc)->getDnnlDesc(); @@ -566,14 +583,16 @@ static bool isReorderAvailable(const MemoryDescPtr& parentDesc, const MemoryDesc dnnl::primitive_attr attr; dnnl_primitive_desc_t result = nullptr; - auto status = dnnl_reorder_primitive_desc_create(&result, srcMemDesc.get(), eng.get(), dstMemDesc.get(), eng.get(), + auto status = dnnl_reorder_primitive_desc_create(&result, + srcMemDesc.get(), + eng.get(), + dstMemDesc.get(), + eng.get(), attr.get()); #if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64) // temporary WA for slow FP32->FP16 conversion reorder in oneDNN on ARM // pretend the reorder is not available to use Convert node instead - if (hasHardwareSupport(ov::element::f16) && - result && - parse_impl_name(result->impl()->name()) == ref_any) { + if (hasHardwareSupport(ov::element::f16) && result && parse_impl_name(result->impl()->name()) == ref_any) { dnnl_primitive_desc_destroy(result); return false; } @@ -587,8 +606,8 @@ static bool isReorderAvailable(const MemoryDescPtr& parentDesc, const MemoryDesc void Graph::insertReorder(EdgePtr& edge, bool isOptimized, std::unordered_set& uniqueLayerNames) { std::string basicLayerName = edge->getParent()->getName() + "_" + - node::Reorder::getReorderArgs(edge->getInputDesc(), edge->getOutputDesc()) + "_" + - edge->getChild()->getName(); + node::Reorder::getReorderArgs(edge->getInputDesc(), edge->getOutputDesc()) + "_" + + edge->getChild()->getName(); std::string layerName = basicLayerName; int idx = 0; while (uniqueLayerNames.find(layerName) != uniqueLayerNames.end()) { @@ -605,11 +624,14 @@ void Graph::insertConvert(EdgePtr& edge) { const auto& inDesc = edge->getInputDesc(); const auto& outDesc = edge->getOutputDesc(); - std::string convertName = edge->getParent()->getName() + "_" + - inDesc.getPrecision().get_type_name() + "_" + outDesc.getPrecision().get_type_name(); + std::string convertName = edge->getParent()->getName() + "_" + inDesc.getPrecision().get_type_name() + "_" + + outDesc.getPrecision().get_type_name(); - auto convertNode = std::make_shared(inDesc.getShape(), inDesc.getPrecision(), outDesc.getPrecision(), - convertName, m_context); + auto convertNode = std::make_shared(inDesc.getShape(), + inDesc.getPrecision(), + outDesc.getPrecision(), + convertName, + m_context); convertNode->setDescs(inDesc, outDesc); InsertNode(edge, convertNode, true); } @@ -720,9 +742,9 @@ void 
Graph::AllocateWithReuse(const std::vector& syncNodesInds) { // Resolve special cases: for (size_t i = 0; i < remaining_edge_clusters_count;) { - auto &cluster = edge_clusters[i]; + auto& cluster = edge_clusters[i]; bool erase = false; - for (auto &edge : cluster) { + for (auto& edge : cluster) { // Remove already allocated edges from the mem reuse algo if (edge->getStatus() == Edge::Status::Allocated) { erase = true; @@ -730,18 +752,23 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { } // Special allocation for string tensors - if (edge->getDesc().getPrecision() == element::string && edge->getStatus() == Edge::Status::NeedAllocation) { + if (edge->getDesc().getPrecision() == element::string && + edge->getStatus() == Edge::Status::NeedAllocation) { StringMemory::StringMemoryBlockPtr memBlcok; if (edge->getParent()->isConstant()) { if (edge->getParent()->getType() == Type::Input) { - auto constNode = static_cast(edge->getParent().get()); + auto constNode = static_cast(edge->getParent().get()); edge->reuse(std::const_pointer_cast(constNode->getMemoryPtr())); } else { edge->externalAllocate(m_context->getWeightsCache()); } - auto stringMemory = dynamic_cast(edge->getMemoryPtr().get()); - OPENVINO_ASSERT(stringMemory, "[CPU] Edge between nodes '", - edge->getParent()->getName(), "' and '", edge->getChild()->getName(), "' must have StringMemory."); + auto stringMemory = dynamic_cast(edge->getMemoryPtr().get()); + OPENVINO_ASSERT(stringMemory, + "[CPU] Edge between nodes '", + edge->getParent()->getName(), + "' and '", + edge->getChild()->getName(), + "' must have StringMemory."); memBlcok = stringMemory->getStringMemoryBlockPtr(); } else { auto memory = std::make_shared(getEngine(), edge->getDesc()); @@ -752,13 +779,18 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { if (edge_c == edge) { continue; } - OPENVINO_ASSERT(edge_c->getDesc().getPrecision() == element::string, "All edges in the cluster must be string."); + OPENVINO_ASSERT(edge_c->getDesc().getPrecision() == element::string, + "All edges in the cluster must be string."); if (edge_c->getStatus() == Edge::Status::NotAllocated) { auto memory = std::make_shared(getEngine(), edge_c->getDesc(), memBlcok); edge_c->reuse(memory); } else { - OPENVINO_THROW("[CPU] String tensors allocation in the cluster. Edge between nodes '", edge_c->getParent()->getName(), "' and '", - edge_c->getChild()->getName(), "' has an unexpected status: ", static_cast(edge_c->getStatus())); + OPENVINO_THROW("[CPU] String tensors allocation in the cluster. 
Edge between nodes '", + edge_c->getParent()->getName(), + "' and '", + edge_c->getChild()->getName(), + "' has an unexpected status: ", + static_cast(edge_c->getStatus())); } } erase = true; @@ -800,14 +832,15 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { int64_t boxSize = 0; bool isConst = false, isOutput = false, isInput = false; - for (auto &edge : edge_clusters[i]) { + for (auto& edge : edge_clusters[i]) { int e_start = edge->getParent()->getExecIndex(); int e_finish = edge->getChild()->getExecIndex(); auto&& desc = edge->getDesc(); if (boxSize != -1 && desc.isDefined()) { - int64_t e_size = desc.getCurrentMemSize(); // size in bytes (from the beginning of data to the last element) + int64_t e_size = + desc.getCurrentMemSize(); // size in bytes (from the beginning of data to the last element) boxSize = std::max(e_size, boxSize); } else { boxSize = -1; @@ -824,9 +857,9 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { } reg.alloc_type = allocType; - isConst |= isConstOutput(edge); + isConst |= isConstOutput(edge); isOutput |= edge->getChild()->getType() == Type::Output; - isInput |= edge->getParent()->getType() == Type::Input; + isInput |= edge->getParent()->getType() == Type::Input; } reg.size = boxSize; @@ -878,7 +911,7 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { memoryRegions.erase(it, memoryRegions.end()); - //Set up the memory control subsystem. + // Set up the memory control subsystem. this->m_pMemoryControl = &(getGraphContext()->getNetworkMemoryControl()->createMemoryControlUnit(syncNodesInds)); auto memoryBlocks = m_pMemoryControl->insert(memoryRegions); @@ -911,9 +944,8 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { } std::vector edges_to_process; edges_to_process.push_back(edge); - for (auto next_edge = edge->getSharedEdge(std::nothrow); - next_edge; - next_edge = next_edge->getSharedEdge(std::nothrow)) { + for (auto next_edge = edge->getSharedEdge(std::nothrow); next_edge; + next_edge = next_edge->getSharedEdge(std::nothrow)) { edges_to_process.push_back(next_edge); } std::for_each(edges_to_process.rbegin(), edges_to_process.rend(), [](const EdgePtr& edge) { @@ -937,16 +969,15 @@ void Graph::AllocateWithReuse(const std::vector& syncNodesInds) { void Graph::Allocate(const std::vector& syncNodesInds) { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::Allocate"); - //resolve inplace dead end nodes + // resolve inplace dead end nodes for (const auto& edge : graphEdges) { if (edge->getStatus() == Edge::Status::Uninitialized) { if (edge->getParent()->getParentEdges().empty() && - one_of(edge->getParent()->getType(), Type::Input, Type::MemoryInput) && - edge->inPlace(Edge::LOOK_UP)) { + one_of(edge->getParent()->getType(), Type::Input, Type::MemoryInput) && edge->inPlace(Edge::LOOK_UP)) { edge->getParent()->resolveInPlaceEdges(Edge::LOOK_UP); } else if (edge->getChild()->getChildEdges().empty() && - one_of(edge->getChild()->getType(), Type::Output, Type::MemoryOutput) && - edge->inPlace(Edge::LOOK_DOWN)) { + one_of(edge->getChild()->getType(), Type::Output, Type::MemoryOutput) && + edge->inPlace(Edge::LOOK_DOWN)) { edge->getChild()->resolveInPlaceEdges(Edge::LOOK_DOWN); } } @@ -955,13 +986,15 @@ void Graph::Allocate(const std::vector& syncNodesInds) { // resolve edges. 
Define which will be a view on others // NeedAllocation - real blob // NotAllocated - view on other blob, peer or in-place - for (auto& edge : graphEdges) edge->init(); + for (auto& edge : graphEdges) + edge->init(); // Allocate memory space for all edges marked with NeedAllocation AllocateWithReuse(syncNodesInds); // Check all getters. Should work. - for (auto& edge : graphEdges) edge->validate(); + for (auto& edge : graphEdges) + edge->validate(); } bool Graph::ProcessDynNodes() { @@ -975,7 +1008,8 @@ bool Graph::ProcessDynNodes() { } void Graph::PushInputData(const std::size_t& index, const ov::SoPtr& input) { - if (!IsReady()) OPENVINO_THROW("Wrong state. Topology not ready."); + if (!IsReady()) + OPENVINO_THROW("Wrong state. Topology not ready."); auto input_itr = inputNodesMap.find(index); if (input_itr != inputNodesMap.end()) { auto node = input_itr->second; @@ -1010,7 +1044,7 @@ void Graph::PullOutputData(std::unordered_map>& if (!IsReady()) OPENVINO_THROW("Wrong state. Topology not ready."); - for (auto &outputMap : outputNodesMap) { + for (auto& outputMap : outputNodesMap) { auto output_index = outputMap.first; auto node = outputMap.second; auto parentEdge = node->getParentEdgeAt(0); @@ -1040,17 +1074,32 @@ void Graph::PullOutputData(std::unordered_map>& if (ext_blob->get_shape() != outDims && !isScalarOutput) { // WA: because input/output info initially contains non empty dims, order etc. // and setDims (called inside setShape) can't correct modify blocked desc for desc with blocked layout - DEBUG_LOG(output_index, ", tensor data addr ", static_cast(output[output_index]->data()), - " dims ", PartialShape(output[output_index]->get_shape()), " -> ", PartialShape(outDims), - ", intr ptr ", intr_blob.getData(), " , parentedge's memory object ", parentEdge->getMemoryPtr().get()); + DEBUG_LOG(output_index, + ", tensor data addr ", + static_cast(output[output_index]->data()), + " dims ", + PartialShape(output[output_index]->get_shape()), + " -> ", + PartialShape(outDims), + ", intr ptr ", + intr_blob.getData(), + " , parentedge's memory object ", + parentEdge->getMemoryPtr().get()); ext_blob->set_shape(outDims); - DEBUG_LOG(output_index, ", tensor data addr ", static_cast(output[output_index]->data()), - " dims ", PartialShape(output[output_index]->get_shape()), ", intr ptr ", intr_blob.getData()); + DEBUG_LOG(output_index, + ", tensor data addr ", + static_cast(output[output_index]->data()), + " dims ", + PartialShape(output[output_index]->get_shape()), + ", intr ptr ", + intr_blob.getData()); expected_desc_ptr = MemoryDescUtils::generateCpuBlockedMemoryDesc(ext_blob); } // check for empty output blob - if (std::any_of(outDims.begin(), outDims.end(), [](const Dim dim) {return dim == 0;})) { + if (std::any_of(outDims.begin(), outDims.end(), [](const Dim dim) { + return dim == 0; + })) { continue; } @@ -1063,12 +1112,22 @@ void Graph::PullOutputData(std::unordered_map>& intr_blob.getSize(), ")."); - void *ext_blob_ptr = ext_blob->data(); - void *intr_blob_ptr = intr_blob.getData(); - DEBUG_LOG(output_index, " @ ", intr_blob_ptr, " -> ", ext_blob_ptr, " zero-copy: ", intr_blob_ptr == ext_blob_ptr, " graph ", this, "\r\n"); + void* ext_blob_ptr = ext_blob->data(); + void* intr_blob_ptr = intr_blob.getData(); + DEBUG_LOG(output_index, + " @ ", + intr_blob_ptr, + " -> ", + ext_blob_ptr, + " zero-copy: ", + intr_blob_ptr == ext_blob_ptr, + " graph ", + this, + "\r\n"); // That is the same memory. 
No need to copy - if (ext_blob_ptr == intr_blob_ptr) continue; + if (ext_blob_ptr == intr_blob_ptr) + continue; if (actualDesc->getPrecision() == element::string) { StringMemory outBloMem(getEngine(), expected_desc_ptr, ext_blob_ptr); @@ -1077,7 +1136,10 @@ void Graph::PullOutputData(std::unordered_map>& Memory outBloMem(getEngine(), expected_desc_ptr, ext_blob_ptr, false); outBloMem.load(intr_blob, false); } else { - OPENVINO_ASSERT(srcPrec == dstPrec, "The precision of the CPU output tensor index", output_index, " is different from the external one"); + OPENVINO_ASSERT(srcPrec == dstPrec, + "The precision of the CPU output tensor index", + output_index, + " is different from the external one"); size_t size_to_copy = intr_blob.getSize(); cpu_parallel_memcpy(ext_blob_ptr, intr_blob_ptr, size_to_copy); } @@ -1108,7 +1170,8 @@ namespace { class UpdateNodesSeq { public: - explicit UpdateNodesSeq(std::vector& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {} + explicit UpdateNodesSeq(std::vector& executableGraphNodes) + : m_executableGraphNodes(executableGraphNodes) {} void operator()(size_t stopIndx) { for (; prepareCounter < stopIndx; ++prepareCounter) { @@ -1126,7 +1189,7 @@ class UpdateNodesSeq { }; #if (OV_THREAD == OV_THREAD_SEQ) - using UpdateNodes = UpdateNodesSeq; +using UpdateNodes = UpdateNodesSeq; #endif #if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO || OV_THREAD == OV_THREAD_OMP) @@ -1143,7 +1206,8 @@ class UpdateNodesSeq { class UpdateNodesBase { public: - explicit UpdateNodesBase(std::vector& executableGraphNodes) : m_executableGraphNodes(executableGraphNodes) {} + explicit UpdateNodesBase(std::vector& executableGraphNodes) + : m_executableGraphNodes(executableGraphNodes) {} void updateShapes(size_t node_indx, size_t stop_indx) { try { for (size_t i = node_indx; i < stop_indx; i++) { @@ -1153,8 +1217,7 @@ class UpdateNodesBase { } m_prepareCounter.store(i, ov_memory_order_release); } - } - catch(...) { + } catch (...) 
{ m_completion.store(true, ov_memory_order_relaxed); throw; } @@ -1185,13 +1248,16 @@ class UpdateNodesBase { std::vector& m_executableGraphNodes; }; -#if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO) -#if (TBB_VERSION_MAJOR > 2020) +# if (OV_THREAD == OV_THREAD_TBB || OV_THREAD == OV_THREAD_TBB_AUTO) +# if (TBB_VERSION_MAJOR > 2020) template class AsyncTask : public tbb::detail::d1::task { public: - AsyncTask(Body& body, tbb::detail::d1::wait_context& wait, size_t node_indx, size_t stop_indx) : - m_body(body), m_wait(wait), m_node_indx(node_indx), m_stop_indx(stop_indx) {} + AsyncTask(Body& body, tbb::detail::d1::wait_context& wait, size_t node_indx, size_t stop_indx) + : m_body(body), + m_wait(wait), + m_node_indx(node_indx), + m_stop_indx(stop_indx) {} task* execute(tbb::detail::d1::execution_data&) override { m_body(m_node_indx, m_stop_indx); m_wait.release(); @@ -1235,11 +1301,14 @@ class UpdateNodes : public UpdateNodesBase { private: tbb::task_group_context ctx; }; -#else +# else template class AsyncTask : public tbb::task { public: - AsyncTask(Body& body, size_t node_indx, size_t stop_indx) : m_body(body), m_node_indx(node_indx), m_stop_indx(stop_indx) {} + AsyncTask(Body& body, size_t node_indx, size_t stop_indx) + : m_body(body), + m_node_indx(node_indx), + m_stop_indx(stop_indx) {} task* execute() override { m_body(m_node_indx, m_stop_indx); return nullptr; @@ -1257,28 +1326,30 @@ class UpdateNodes : public UpdateNodesBase { void operator()(size_t stopIndx) { m_completion.store(false); auto startCounter = m_prepareCounter.load(); - tbb::task& root = *new(tbb::task::allocate_root()) tbb::empty_task; - root.set_ref_count(3); // two for children and one preserved + tbb::task& root = *new (tbb::task::allocate_root()) tbb::empty_task; + root.set_ref_count(3); // two for children and one preserved auto task1 = [this](size_t start, size_t stop) { this->updateShapes(start, stop); }; - AsyncTask& a = *new (root.allocate_child()) AsyncTask(task1, startCounter, stopIndx); + AsyncTask& a = + *new (root.allocate_child()) AsyncTask(task1, startCounter, stopIndx); auto task2 = [this](size_t start, size_t stop) { this->updateDynParams(start, stop); }; - AsyncTask& b = *new (root.allocate_child()) AsyncTask(task2, startCounter, stopIndx); + AsyncTask& b = + *new (root.allocate_child()) AsyncTask(task2, startCounter, stopIndx); - b.set_affinity(2); // slot 1 plus 1 + b.set_affinity(2); // slot 1 plus 1 tbb::task::spawn(b); root.spawn_and_wait_for_all(a); } }; -#endif -#endif +# endif +# endif -#if (OV_THREAD == OV_THREAD_OMP) +# if (OV_THREAD == OV_THREAD_OMP) class UpdateNodes : public UpdateNodesBase { public: using UpdateNodesBase::UpdateNodesBase; @@ -1293,14 +1364,15 @@ class UpdateNodes : public UpdateNodesBase { if (origin_nested_levels < 2) { set_max_nested_levels(2); } - // In OpenMP, an exception that is thrown in a parallel region must be caught and handled in the same region by the same thread. - // Therefore, need to pass the error message and throw a new exception outside the parallel region. + // In OpenMP, an exception that is thrown in a parallel region must be caught and handled in the same region by + // the same thread. Therefore, need to pass the error message and throw a new exception outside the parallel + // region. 
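Reviewer aside: the comment above states the OpenMP constraint this hunk works around. Below is a minimal, self-contained sketch of the same pattern (compile with -fopenmp; without OpenMP the pragmas are ignored and the program still runs sequentially): each section catches locally and records a message, and the failure is reported only after the parallel region ends. The error strings are illustrative.

#include <iostream>
#include <stdexcept>

int main() {
    const char* what = nullptr;
#pragma omp parallel sections
    {
#pragma omp section
        {
            try {
                throw std::runtime_error("boom");  // stands in for updateDynParams()
            } catch (...) {
                what = "[ CPU ] Could not update dynamic parameters.";  // must not leave the section
            }
        }
#pragma omp section
        {
            try {
                // stands in for updateShapes(); succeeds in this sketch
            } catch (...) {
                what = "[ CPU ] Could not update shapes.";
            }
        }
    }
    if (what != nullptr)  // the real code rethrows here, outside the parallel region
        std::cerr << what << '\n';
}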
const char* what = nullptr; - #pragma omp parallel - #pragma omp sections +# pragma omp parallel +# pragma omp sections { - #pragma omp section +# pragma omp section { try { updateDynParams(startCounter, stopIndx); @@ -1310,7 +1382,7 @@ class UpdateNodes : public UpdateNodesBase { what = "[ CPU ] Could not update dynamic parameters."; } } - #pragma omp section +# pragma omp section { try { updateShapes(startCounter, stopIndx); @@ -1329,18 +1401,18 @@ class UpdateNodes : public UpdateNodesBase { OPENVINO_ASSERT(what == nullptr, what); } }; -#endif +# endif #endif -} // namespace +} // namespace /* group all the profiling macros into a single one * to avoid cluttering a core logic */ #define VERBOSE_PERF_DUMP_ITT_DEBUG_LOG(ittScope, node, config) \ - VERBOSE(node, config.debugCaps.verbose); \ - PERF(node, config.collectPerfCounters); \ - DUMP(node, config.debugCaps, infer_count); \ - OV_ITT_SCOPED_TASK(ittScope, node->profiling.execute); \ + VERBOSE(node, config.debugCaps.verbose); \ + PERF(node, config.collectPerfCounters); \ + DUMP(node, config.debugCaps, infer_count); \ + OV_ITT_SCOPED_TASK(ittScope, node->profiling.execute); \ DEBUG_LOG(*node); inline void Graph::ExecuteNode(const NodePtr& node, SyncInferRequest* request, int numaId) const { @@ -1362,7 +1434,7 @@ inline void Graph::ExecuteNodeWithCatch(const NodePtr& node, SyncInferRequest* r } } -template +template void Graph::InferDynamic(SyncInferRequest* request, int numaId, UpdateStrategy&& update) { size_t inferCounter = 0; for (auto stopIndx : m_executableSyncNodesInds) { @@ -1410,17 +1482,20 @@ void Graph::Infer(SyncInferRequest* request) { InferStatic(request, numaId); break; default: - OPENVINO_ASSERT(IsReady(), "Wrong state of the ov::intel_cpu::Graph. Topology is not ready: ", static_cast(status)); + OPENVINO_ASSERT(IsReady(), + "Wrong state of the ov::intel_cpu::Graph. 
Topology is not ready: ", + static_cast(status)); } - if (infer_count != -1) infer_count++; + if (infer_count != -1) + infer_count++; } void Graph::SortTopologically() { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "Graph::SortTopologically"); // Set execIndex of all nodes to default invalid value - for (auto &node : graphNodes) { + for (auto& node : graphNodes) { node->execIndex = -1; } @@ -1433,7 +1508,7 @@ void Graph::SortTopologically() { std::function visit; visit = [&execIndexCnt, &sorted, &visit](const NodePtr node) { if (node->execIndex >= 0) - return; // already visited + return; // already visited for (size_t i = 0; i < node->getParentEdges().size(); i++) { visit(node->getParentEdgeAt(i)->getParent()); @@ -1467,7 +1542,7 @@ void Graph::SortTopologically() { // Sort in / out child edges by port index // Make first N (N == port_num) edge indexes match with port index - for (auto &node : graphNodes) { + for (auto& node : graphNodes) { int port_num = node->outputShapes.size(); std::vector res(port_num); @@ -1512,10 +1587,7 @@ void Graph::GetPerfData(std::vector& perfMap) const { } } -void Graph::CreateEdge(const NodePtr& parent, - const NodePtr& child, - int parentPort, - int childPort) { +void Graph::CreateEdge(const NodePtr& parent, const NodePtr& child, int parentPort, int childPort) { assert(parentPort >= 0 && childPort >= 0); auto edge = std::make_shared(parent, child, parentPort, childPort); @@ -1539,24 +1611,28 @@ void Graph::AddNode(NodePtr node) { graphNodes.push_back(node); } -void Graph::DropNode(const NodePtr &node) { +void Graph::DropNode(const NodePtr& node) { auto children = node->childEdges; auto parents = node->parentEdges; for (size_t i = 0; i < parents.size(); i++) { auto p_edge = parents[i].lock(); - if (!p_edge) continue; + if (!p_edge) + continue; auto parent = p_edge->getParent(); - if (!parent) continue; + if (!parent) + continue; const int inNum = p_edge->getInputNum(); RemoveEdge(p_edge); for (size_t j = 0; j < children.size(); j++) { auto c_edge = children[j].lock(); - if (!c_edge) continue; + if (!c_edge) + continue; auto child = c_edge->getChild(); - if (!child) continue; + if (!child) + continue; const int outNum = c_edge->getOutputNum(); RemoveEdge(c_edge); @@ -1565,31 +1641,37 @@ void Graph::DropNode(const NodePtr &node) { } } -void Graph::DropDWConvNode(const NodePtr &node) { +void Graph::DropDWConvNode(const NodePtr& node) { auto children = node->childEdges; auto parents = node->parentEdges; auto parentConvEdge = parents[0].lock(); - if (!parentConvEdge) return; + if (!parentConvEdge) + return; auto parentConv = parentConvEdge->getParent(); - if (!parentConv) return; + if (!parentConv) + return; parentConv->outputShapes[0] = node->outputShapes[0]; for (size_t i = 0; i < 1; i++) { auto p_edge = parents[i].lock(); - if (!p_edge) continue; + if (!p_edge) + continue; auto parent = p_edge->getParent(); - if (!parent) continue; + if (!parent) + continue; const int inNum = p_edge->getInputNum(); RemoveEdge(p_edge); for (size_t j = 0; j < children.size(); j++) { auto c_edge = children[j].lock(); - if (!c_edge) continue; + if (!c_edge) + continue; auto child = c_edge->getChild(); - if (!child) continue; + if (!child) + continue; const int outNum = c_edge->getOutputNum(); RemoveEdge(c_edge); @@ -1599,9 +1681,11 @@ void Graph::DropDWConvNode(const NodePtr &node) { for (size_t i = 1; i < parents.size(); i++) { auto p_edge = parents[i].lock(); - if (!p_edge) continue; + if (!p_edge) + continue; auto parent = p_edge->getParent(); - if (!parent) 
@@ -1512,10 +1587,7 @@ void Graph::GetPerfData(std::vector<ov::ProfilingInfo>& perfMap) const {
     }
 }
 
-void Graph::CreateEdge(const NodePtr& parent,
-                       const NodePtr& child,
-                       int parentPort,
-                       int childPort) {
+void Graph::CreateEdge(const NodePtr& parent, const NodePtr& child, int parentPort, int childPort) {
     assert(parentPort >= 0 && childPort >= 0);
 
     auto edge = std::make_shared<Edge>(parent, child, parentPort, childPort);
@@ -1539,24 +1611,28 @@ void Graph::AddNode(NodePtr node) {
     graphNodes.push_back(node);
 }
 
-void Graph::DropNode(const NodePtr &node) {
+void Graph::DropNode(const NodePtr& node) {
     auto children = node->childEdges;
     auto parents = node->parentEdges;
 
     for (size_t i = 0; i < parents.size(); i++) {
         auto p_edge = parents[i].lock();
-        if (!p_edge) continue;
+        if (!p_edge)
+            continue;
         auto parent = p_edge->getParent();
-        if (!parent) continue;
+        if (!parent)
+            continue;
 
         const int inNum = p_edge->getInputNum();
         RemoveEdge(p_edge);
 
         for (size_t j = 0; j < children.size(); j++) {
             auto c_edge = children[j].lock();
-            if (!c_edge) continue;
+            if (!c_edge)
+                continue;
             auto child = c_edge->getChild();
-            if (!child) continue;
+            if (!child)
+                continue;
 
             const int outNum = c_edge->getOutputNum();
             RemoveEdge(c_edge);
@@ -1565,31 +1641,37 @@ void Graph::DropNode(const NodePtr &node) {
     }
 }
 
-void Graph::DropDWConvNode(const NodePtr &node) {
+void Graph::DropDWConvNode(const NodePtr& node) {
     auto children = node->childEdges;
     auto parents = node->parentEdges;
 
     auto parentConvEdge = parents[0].lock();
-    if (!parentConvEdge) return;
+    if (!parentConvEdge)
+        return;
     auto parentConv = parentConvEdge->getParent();
-    if (!parentConv) return;
+    if (!parentConv)
+        return;
 
     parentConv->outputShapes[0] = node->outputShapes[0];
 
     for (size_t i = 0; i < 1; i++) {
         auto p_edge = parents[i].lock();
-        if (!p_edge) continue;
+        if (!p_edge)
+            continue;
         auto parent = p_edge->getParent();
-        if (!parent) continue;
+        if (!parent)
+            continue;
 
         const int inNum = p_edge->getInputNum();
         RemoveEdge(p_edge);
 
         for (size_t j = 0; j < children.size(); j++) {
             auto c_edge = children[j].lock();
-            if (!c_edge) continue;
+            if (!c_edge)
+                continue;
             auto child = c_edge->getChild();
-            if (!child) continue;
+            if (!child)
+                continue;
 
             const int outNum = c_edge->getOutputNum();
             RemoveEdge(c_edge);
@@ -1599,9 +1681,11 @@ void Graph::DropDWConvNode(const NodePtr &node) {
 
     for (size_t i = 1; i < parents.size(); i++) {
         auto p_edge = parents[i].lock();
-        if (!p_edge) continue;
+        if (!p_edge)
+            continue;
         auto parent = p_edge->getParent();
-        if (!parent) continue;
+        if (!parent)
+            continue;
 
         const int inNum = p_edge->getInputNum();
         const int portCandidate = p_edge->getOutputNum();
@@ -1615,14 +1699,20 @@ void Graph::DropDWConvNode(const NodePtr &node) {
 }
 
 void Graph::RemoveDroppedNodes() {
-    graphNodes.erase(std::remove_if(graphNodes.begin(), graphNodes.end(),
-                                    [](const NodePtr& node){ return node->isDropped(); }),
+    graphNodes.erase(std::remove_if(graphNodes.begin(),
+                                    graphNodes.end(),
+                                    [](const NodePtr& node) {
+                                        return node->isDropped();
+                                    }),
                      graphNodes.end());
 }
 
 void Graph::RemoveDroppedEdges() {
-    graphEdges.erase(std::remove_if(graphEdges.begin(), graphEdges.end(),
-                                    [](const EdgePtr& node){ return node->isDropped(); }),
+    graphEdges.erase(std::remove_if(graphEdges.begin(),
                                     graphEdges.end(),
+                                    [](const EdgePtr& node) {
+                                        return node->isDropped();
+                                    }),
                      graphEdges.end());
 }
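Both Remove* methods above are instances of the standard erase-remove idiom; for reference, a compact self-contained version with a toy element type:

    #include <algorithm>
    #include <memory>
    #include <vector>

    struct Item {
        bool dropped = false;
        bool isDropped() const { return dropped; }
    };
    using ItemPtr = std::shared_ptr<Item>;

    // std::remove_if packs the surviving elements to the front and returns
    // the new logical end; erase() then trims the tail, O(n) in total.
    void removeDropped(std::vector<ItemPtr>& items) {
        items.erase(std::remove_if(items.begin(),
                                   items.end(),
                                   [](const ItemPtr& item) {
                                       return item->isDropped();
                                   }),
                    items.end());
    }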
@@ -1631,19 +1721,28 @@ NodePtr Graph::InsertReorder(EdgePtr edge,
                              const MemoryDesc& inDesc,
                              const MemoryDesc& outDesc,
                              bool isOptimized,
-                             const std::vector<int> & src_perm) {
+                             const std::vector<int>& src_perm) {
     auto reorder = std::make_shared<node::Reorder>(inDesc, outDesc, layerName, m_context);
     reorder->setOptimized(isOptimized);
     reorder->setSrcPermutation(src_perm);
 
     DEBUG_LOG(reorder->getName(), " edge=", *edge, " isOptimized=", isOptimized);
-    DEBUG_LOG("    inDesc: ", inDesc.getShape().toString(), inDesc.getPrecision().get_type_name(), " ", inDesc.serializeFormat());
-    DEBUG_LOG("   outDesc: ", outDesc.getShape().toString(), outDesc.getPrecision().get_type_name(), " ", outDesc.serializeFormat());
+    DEBUG_LOG("    inDesc: ",
+              inDesc.getShape().toString(),
+              inDesc.getPrecision().get_type_name(),
+              " ",
+              inDesc.serializeFormat());
+    DEBUG_LOG("   outDesc: ",
+              outDesc.getShape().toString(),
+              outDesc.getPrecision().get_type_name(),
+              " ",
+              outDesc.serializeFormat());
 
     InsertNode(edge, reorder, true);
 
     // Using the method Edge::getDesc() we can check that input and output tensor descriptors are equal.
-    // Due to the specificity of GraphOptimizer::MergeTransposeAndReorder() that isOptimized flag uses, we shouldn't do these checks.
+    // Due to the specificity of GraphOptimizer::MergeTransposeAndReorder(), which uses the isOptimized flag, we
+    // shouldn't do these checks.
     if (!isOptimized) {
         reorder->getParentEdgeAt(0)->getDesc();
         reorder->getChildEdgeAt(0)->getDesc();
@@ -1692,10 +1791,10 @@ void Graph::EnforceInferencePrecision() {
     const auto inferPrec = getConfig().inferencePrecision;
 
     if (one_of(inferPrec, element::f32, element::undefined, ov::element::f16))
-        return; // nothing to do, only precision reduction is currently allowed
+        return;  // nothing to do, only precision reduction is currently allowed
 #if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64)
     if (inferPrec == ov::element::f16)
-        return; // precision of configured by ov::pass::ConvertPrecision
+        return;  // precision is already configured by ov::pass::ConvertPrecision
 #endif
 
     std::function<void(const NodePtr&, std::unordered_set<NodePtr>& skipNodes)> searchForNodesToSkip;
     searchForNodesToSkip = [&](const NodePtr& node, std::unordered_set<NodePtr>& skipNodes) -> void {
@@ -1703,35 +1802,35 @@ void Graph::EnforceInferencePrecision() {
             const auto& parent = node->getParentEdgeAt(i)->getParent();
             if (inferPrec == ov::element::bf16) {
                 /* list of node types that must be forced to be executed in BF16 precision
-                 * because of performance gains */
+                 * because of performance gains */
                 if (one_of(parent->getType(),
-                           Type::Convolution,    // conv nets
-                           Type::FullyConnected, // conv / bert nets
-                           Type::RNNCell,        // recurent nets
-                           Type::RNNSeq,         // recurent nets
-                           Type::MatMul,         // bert nets
-                           Type::ROIPooling,     // object detection nets
-                           Type::Interpolate,    // super resolution nets
-                           Type::PagedAttention, // page attention
-                           Type::QKVProjection,
-                           Type::LLMMLP))
-                    continue;   // stop at significant nodes
+                           Type::Convolution,     // conv nets
+                           Type::FullyConnected,  // conv / bert nets
+                           Type::RNNCell,         // recurrent nets
+                           Type::RNNSeq,          // recurrent nets
+                           Type::MatMul,          // bert nets
+                           Type::ROIPooling,      // object detection nets
+                           Type::Interpolate,     // super resolution nets
+                           Type::PagedAttention,  // page attention
+                           Type::QKVProjection,
+                           Type::LLMMLP))
+                    continue;  // stop at significant nodes
             } else if (inferPrec == ov::element::f16) {
                 /* list of node types that must be forced to be executed in FP16 precision
-                 * because of performance gains */
+                 * because of performance gains */
                 if (one_of(parent->getType(),
-                           Type::Convolution,    // conv nets
-                           Type::Deconvolution,  // deconv
-                           Type::FullyConnected, // conv / bert nets
-                           Type::MatMul,         // bert nets
-                           Type::Pooling,
-                           Type::MVN))
-                    continue;   // stop at significant nodes
+                           Type::Convolution,     // conv nets
+                           Type::Deconvolution,   // deconv
+                           Type::FullyConnected,  // conv / bert nets
+                           Type::MatMul,          // bert nets
+                           Type::Pooling,
+                           Type::MVN))
+                    continue;  // stop at significant nodes
             }
 
             const auto res = skipNodes.insert(parent);
-            if (res.second) // node not visited yet
+            if (res.second)  // node not visited yet
                 searchForNodesToSkip(parent, skipNodes);
         }
     };
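The searchForNodesToSkip lambda above climbs from a node towards its producers, collecting everything it passes into skipNodes, and stops each branch at the first "significant" (compute-heavy) node type, which is exactly where reduced precision pays off. A stripped-down sketch of that traversal (toy types, illustrative only):

    #include <memory>
    #include <unordered_set>
    #include <vector>

    enum class Kind { Light, Heavy };  // toy stand-in for the node Type enum

    struct Node {
        Kind kind = Kind::Light;
        std::vector<std::shared_ptr<Node>> parents;
    };
    using NodePtr = std::shared_ptr<Node>;

    // Collect producers that should stay in f32; stop each branch at the
    // first compute-heavy node, which benefits from bf16/f16 execution.
    void collectNodesToSkip(const NodePtr& node, std::unordered_set<NodePtr>& skip) {
        for (const auto& parent : node->parents) {
            if (parent->kind == Kind::Heavy)
                continue;  // significant node: stop ascending this branch
            if (skip.insert(parent).second)  // true means "not visited yet"
                collectNodesToSkip(parent, skip);
        }
    }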
@@ -1772,10 +1871,10 @@ void Graph::EnforceInferencePrecision() {
             // kvcache of PagedAttention should be written directly
             if (node->getType() == Type::PagedAttention && (inPort == 3 || inPort == 4))
                 return true;
-            const auto &parent = node->getParentEdgeAt(inPort)->getParent();
+            const auto& parent = node->getParentEdgeAt(inPort)->getParent();
             /* Skip BF16 enforcement for nodes after Constant Inputs for maintaining precision for fusing.
-             * Element type conversion to bf16 is done automatically, if convolution follows up after Constant Inputs
-             * and activation is bf16 */
+             * Element type conversion to bf16 is done automatically, if convolution follows up after Constant
+             * Inputs and activation is bf16 */
             if (parent->getType() == Type::Input && parent->isConstant() &&
                 // Concatenation node is exception because it doesn't change an accuracy for BF16 activation
                 node->getType() != Type::Concatenation)
@@ -1815,7 +1914,7 @@ void Graph::EnforceInferencePrecision() {
             // exclude Convert before Range since it may cause precision loss when integer type to LP.
             // TODO: Incorrect subgraph is generated by ONNX FE + ticket 117861.
-            const auto &child = node->getChildEdgeAt(i)->getChild();
+            const auto& child = node->getChildEdgeAt(i)->getChild();
             if (child->getType() == Type::Range && node->getType() == Type::Convert)
                 continue;
             // skip second output of PagedAttention
@@ -1845,5 +1944,5 @@ const std::unordered_map& Graph::getInterna
     return m_context->getMemoryStatesRegister()->getMemoryStates();
 }
 
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/graph.h b/src/plugins/intel_cpu/src/graph.h
index d50ccc152c9186..bdf3205d2baaaf 100644
--- a/src/plugins/intel_cpu/src/graph.h
+++ b/src/plugins/intel_cpu/src/graph.h
@@ -4,22 +4,20 @@
 #pragma once
 
+#include
+#include
+#include
+#include
+
 #include "config.h"
 #include "cpu_memory.h"
-#include "nodes/input.h"
-#include "openvino/core/node_vector.hpp"
-#include "openvino/runtime/profiling_info.hpp"
-#include "node.h"
 #include "edge.h"
 #include "graph_context.h"
 #include "memory_control.hpp"
+#include "node.h"
+#include "nodes/input.h"
+#include "openvino/core/node_vector.hpp"
+#include "openvino/runtime/profiling_info.hpp"
-
-#include
-#include
-#include
-#include
-
 #include "openvino/runtime/so_ptr.hpp"
 #include "proxy_mem_blk.h"
 
@@ -29,7 +27,7 @@ namespace intel_cpu {
 class SyncInferRequest;
 namespace node {
 class MemoryStateNode;
-} // namespace node
+}  // namespace node
 
 class Graph {
 public:
@@ -61,15 +59,15 @@ class Graph {
         return IsStatic() || IsDynamic();
     }
 
-    const Config & getConfig() const {
+    const Config& getConfig() const {
         return m_context->getConfig();
     }
 
-    template<typename NET>
-    void CreateGraph(NET &model, const GraphContext::CPtr context);
+    template <typename NET>
+    void CreateGraph(NET& model, const GraphContext::CPtr context);
 
-    void CreateGraph(const std::vector<NodePtr> &graphNodes,
-                     const std::vector<EdgePtr> &graphEdges,
+    void CreateGraph(const std::vector<NodePtr>& graphNodes,
+                     const std::vector<EdgePtr>& graphEdges,
                      const GraphContext::CPtr context,
                      std::string name);
 
@@ -97,14 +95,14 @@ class Graph {
         return outputNodesMap;
     }
 
-    NodePtr getInputNodeByIndex(const std::size_t &index) {
+    NodePtr getInputNodeByIndex(const std::size_t& index) {
         auto input = inputNodesMap.find(index);
         if (input == inputNodesMap.end())
             OPENVINO_THROW("CPU execution graph doesn't contain input node with index: ", index);
         return input->second;
     }
 
-    NodePtr getOutputNodeByIndex(const std::size_t &index) {
+    NodePtr getOutputNodeByIndex(const std::size_t& index) {
         auto output = outputNodesMap.find(index);
         if (output == outputNodesMap.end())
             OPENVINO_THROW("CPU execution graph doesn't contain output node with index: ", index);
         return output->second;
     }
@@ -119,12 +117,9 @@ class Graph {
         return m_context;
     }
 
-    void GetPerfData(std::vector<ov::ProfilingInfo> &perfMap) const;
+    void GetPerfData(std::vector<ov::ProfilingInfo>& perfMap) const;
 
-    void CreateEdge(const NodePtr& parent,
-                    const NodePtr& child,
-                    int parentPort = 0,
-                    int childPort = 
0); + void CreateEdge(const NodePtr& parent, const NodePtr& child, int parentPort = 0, int childPort = 0); void RemoveEdge(const EdgePtr& edge); void RemoveDroppedNodes(); void RemoveDroppedEdges(); @@ -134,9 +129,9 @@ class Graph { /** * @brief Insert Reorder node at the edge-specified location. - * The Reorder node must be inserted in case when there are inplace conflicts or the input and output tensor descriptors do not match. - * The Reorder node rearranges the elements in memory according to inDesc and outDesc, or reinterprets memory descriptor without - * rearrangement of elements if isOptimized is true. + * The Reorder node must be inserted in case when there are inplace conflicts or the input and output tensor + * descriptors do not match. The Reorder node rearranges the elements in memory according to inDesc and outDesc, or + * reinterprets memory descriptor without rearrangement of elements if isOptimized is true. * @param edge * pointer to the edge in the graph where Reorder node will be inserted * @param layerName @@ -153,14 +148,18 @@ class Graph { * pointer to the blob containing scales * @return pointer to the new Reorder node. */ - NodePtr InsertReorder(EdgePtr edge, std::string layerName, const MemoryDesc& inDesc, - const MemoryDesc& outDesc, bool isOptimized = false, const std::vector & src_perm = {}); + NodePtr InsertReorder(EdgePtr edge, + std::string layerName, + const MemoryDesc& inDesc, + const MemoryDesc& outDesc, + bool isOptimized = false, + const std::vector& src_perm = {}); /** * @brief Insert Node at the edge-specified location. - * This method supports two regimes. First, the node is inserted without initialization (i.e. supported descriptors initialization, - * supported primitive descriptors selection, etc.), which can be useful after the ResolveEdgeConflicts() completes. The second is just inserting the - * node without initialization. + * This method supports two regimes. First, the node is inserted without initialization (i.e. supported descriptors + * initialization, supported primitive descriptors selection, etc.), which can be useful after the + * ResolveEdgeConflicts() completes. The second is just inserting the node without initialization. * @param edge * pointer to the edge in the graph where the node will be inserted * @param node @@ -173,10 +172,10 @@ class Graph { /** * @brief Insert Node between two specified nodes. - * This procedure creates two edges that link the parent and child nodes to the inserted one and adds all created objects to the graph. - * This method supports two regimes. First, the node is inserted without initialization (i.e. supported descriptors initialization, - * supported primitive descriptors selection, etc.), which can be useful after the ResolveEdgeConflicts() completes. The second is just inserting the - * node without initialization. + * This procedure creates two edges that link the parent and child nodes to the inserted one and adds all created + * objects to the graph. This method supports two regimes. First, the node is inserted without initialization (i.e. + * supported descriptors initialization, supported primitive descriptors selection, etc.), which can be useful after + * the ResolveEdgeConflicts() completes. The second is just inserting the node without initialization. 
* @param parent * pointer to the parent node * @param child @@ -193,7 +192,9 @@ class Graph { std::shared_ptr dump() const; - void ResetInferCount() { infer_count = 0; } + void ResetInferCount() { + infer_count = 0; + } void SortTopologically(); @@ -215,7 +216,7 @@ class Graph { * Activate execution graph using \p externalInputMemory and \p externalOutputMemory */ void Activate(const std::vector& externalInputMemory = {}, - const std::vector& externalOutputMemory = {}); + const std::vector& externalOutputMemory = {}); const std::unordered_map& getOutputNodesMemBlocksMap() const { return outputNodesMemBlocksMap; @@ -231,7 +232,7 @@ class Graph { graphEdges.clear(); m_executableSyncNodesInds.clear(); } - Status status { Status::NotReady }; + Status status{Status::NotReady}; // For dumping purposes. -1 - no counting, all other positive // values mean increment it within each Infer() call @@ -244,7 +245,7 @@ class Graph { bool graphHasDynamicInput = false; - void Replicate(const std::shared_ptr &subgraph, + void Replicate(const std::shared_ptr& subgraph, const std::vector& inputConfigs = {}, const std::vector& outputConfigs = {}); @@ -281,10 +282,10 @@ class Graph { void ExecuteNode(const NodePtr& node, SyncInferRequest* request = nullptr, int numaId = -1) const; void InferStatic(SyncInferRequest* request, int numaId); - template + template void InferDynamic(SyncInferRequest* request, int numaId, UpdateStrategy&& update); - friend std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph); + friend std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph& graph); private: using event_t = void (Graph::*)(void); diff --git a/src/plugins/intel_cpu/src/graph_context.cpp b/src/plugins/intel_cpu/src/graph_context.cpp index 5b967ed58a7918..462cdab2a9b5c0 100644 --- a/src/plugins/intel_cpu/src/graph_context.cpp +++ b/src/plugins/intel_cpu/src/graph_context.cpp @@ -1,10 +1,11 @@ // Copyright (C) 2018-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // -#include "dnnl_types.h" #include "graph_context.h" -#include "nodes/memory.hpp" + +#include "dnnl_types.h" #include "memory_control.hpp" +#include "nodes/memory.hpp" namespace ov { namespace intel_cpu { @@ -42,5 +43,5 @@ const dnnl::engine& GraphContext::getEngine() { return eng; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/graph_context.h b/src/plugins/intel_cpu/src/graph_context.h index ce51af0c81b4bd..d13872129325b4 100644 --- a/src/plugins/intel_cpu/src/graph_context.h +++ b/src/plugins/intel_cpu/src/graph_context.h @@ -4,11 +4,11 @@ #pragma once -#include "openvino/runtime/threading/cpu_streams_executor.hpp" -#include "sub_memory_manager.hpp" #include "cache/multi_cache.h" #include "config.h" #include "dnnl_scratch_pad.h" +#include "openvino/runtime/threading/cpu_streams_executor.hpp" +#include "sub_memory_manager.hpp" #include "weights_cache.hpp" namespace ov { @@ -16,7 +16,7 @@ namespace intel_cpu { namespace node { class MemoryStatesRegister; -} // namespace node +} // namespace node class NetworkMemoryControl; @@ -39,7 +39,6 @@ class GraphContext { return weightsCache; } - MultiCachePtr getParamsCache() const { return rtParamsCache; } @@ -81,7 +80,7 @@ class GraphContext { private: Config config; // network-level config - WeightsSharing::Ptr weightsCache; // per NUMA node caches for sharing weights data + WeightsSharing::Ptr weightsCache; // per NUMA node caches for sharing weights data MultiCachePtr rtParamsCache; // primitive cache 
DnnlScratchPadPtr rtScratchPad; // scratch pad @@ -90,9 +89,9 @@ class GraphContext { std::vector rtScratchPads; // scratch pad (each sub-stream has its own copy) - ov::threading::IStreamsExecutor::Ptr streamExecutor; // stream executor for current graph + ov::threading::IStreamsExecutor::Ptr streamExecutor; // stream executor for current graph - ov::threading::CPUStreamsExecutor::Ptr cpuStreamExecutor; // cpu stream executor for current graph + ov::threading::CPUStreamsExecutor::Ptr cpuStreamExecutor; // cpu stream executor for current graph std::shared_ptr subMemoryManager; diff --git a/src/plugins/intel_cpu/src/graph_dumper.cpp b/src/plugins/intel_cpu/src/graph_dumper.cpp index 04c15408743c71..5a3a95362267fe 100644 --- a/src/plugins/intel_cpu/src/graph_dumper.cpp +++ b/src/plugins/intel_cpu/src/graph_dumper.cpp @@ -4,28 +4,28 @@ #include "graph_dumper.h" -#include "dnnl_debug.h" -#include "openvino/pass/manager.hpp" -#include "openvino/pass/serialize.hpp" -#include "openvino/runtime/exec_model_info.hpp" -#include "utils/debug_capabilities.h" - #include +#include #include #include #include #include -#include + +#include "dnnl_debug.h" +#include "openvino/pass/manager.hpp" +#include "openvino/pass/serialize.hpp" +#include "openvino/runtime/exec_model_info.hpp" +#include "utils/debug_capabilities.h" namespace ov { namespace intel_cpu { -void serializeToCout(const Graph &graph); -void serializeToXML(const Graph &graph, const std::string& path); +void serializeToCout(const Graph& graph); +void serializeToXML(const Graph& graph, const std::string& path); namespace { -std::map extract_node_metadata(const NodePtr &node) { +std::map extract_node_metadata(const NodePtr& node) { std::map serialization_info; if (node->getType() == Type::Input && node->isConstant()) { @@ -47,7 +47,8 @@ std::map extract_node_metadata(const NodePtr &node) { bool isAllEqual = true; for (size_t i = 1; i < node->getChildEdges().size(); i++) { - if (node->getChildEdgeAt(i - 1)->getMemory().getDesc().getPrecision() != node->getChildEdgeAt(i)->getMemory().getDesc().getPrecision()) { + if (node->getChildEdgeAt(i - 1)->getMemory().getDesc().getPrecision() != + node->getChildEdgeAt(i)->getMemory().getDesc().getPrecision()) { isAllEqual = false; break; } @@ -56,7 +57,8 @@ std::map extract_node_metadata(const NodePtr &node) { // If all output precisions are the same, we store the name only once if (!isAllEqual) { for (size_t i = 1; i < node->getChildEdges().size(); i++) - outputPrecisionsStr += "," + std::string(node->getChildEdgeAt(i)->getMemory().getDesc().getPrecision().get_type_name()); + outputPrecisionsStr += + "," + std::string(node->getChildEdgeAt(i)->getMemory().getDesc().getPrecision().get_type_name()); } } else { // Branch to correctly handle output nodes @@ -107,8 +109,8 @@ std::map extract_node_metadata(const NodePtr &node) { } // namespace -std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { - std::map > node2layer; +std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph& graph) { + std::map> node2layer; ov::ResultVector results; ov::ParameterVector params; @@ -117,7 +119,7 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { std::map> paramsMap; std::map> resultsMap; - auto get_inputs = [&] (const NodePtr & node) { + auto get_inputs = [&](const NodePtr& node) { auto pr_edges = node->getParentEdges(); ov::OutputVector inputs(pr_edges.size()); @@ -136,10 +138,10 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { return inputs; }; - auto create_ngraph_node = [&](const 
NodePtr &node) { + auto create_ngraph_node = [&](const NodePtr& node) { bool is_input = false, is_output = false, should_be_hold = false; size_t input_index = -1, output_index = -1; - for (auto && kvp : graph.inputNodesMap) { + for (auto&& kvp : graph.inputNodesMap) { if (kvp.second == node) { is_input = true; input_index = kvp.first; @@ -147,7 +149,7 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { } } - for (auto && kvp : graph.outputNodesMap) { + for (auto&& kvp : graph.outputNodesMap) { if (kvp.second == node) { is_output = true; output_index = kvp.first; @@ -174,7 +176,8 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { return_node = result; } else { return_node = std::make_shared( - get_inputs(node), node->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size()); + get_inputs(node), + node->getSelectedPrimitiveDescriptor()->getConfig().outConfs.size()); for (size_t port = 0; port < return_node->get_output_size(); ++port) { auto& desc = node->getChildEdgeAt(port)->getMemory().getDesc(); @@ -186,7 +189,7 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { to_hold.push_back(return_node); } - for (auto && kvp : meta_data) + for (auto&& kvp : meta_data) return_node->get_rt_info()[kvp.first] = kvp.second; return_node->set_friendly_name(node->getName()); @@ -195,18 +198,18 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { ov::NodeVector nodes; nodes.reserve(graph.graphNodes.size()); - for (auto &node : graph.graphNodes) { // important: graph.graphNodes are in topological order + for (auto& node : graph.graphNodes) { // important: graph.graphNodes are in topological order nodes.emplace_back(create_ngraph_node(node)); node2layer[node] = nodes.back(); } - for (auto && kvp : paramsMap) + for (auto&& kvp : paramsMap) params.push_back(kvp.second); - for (auto && kvp : resultsMap) + for (auto&& kvp : resultsMap) results.push_back(kvp.second); auto holder = !results.empty() ? 
results[0] : std::make_shared(); - for (auto &node : to_hold) { + for (auto& node : to_hold) { holder->add_control_dependency(node); } @@ -214,7 +217,7 @@ std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph) { } #ifdef CPU_DEBUG_CAPS -void serialize(const Graph &graph) { +void serialize(const Graph& graph) { const std::string& path = graph.getConfig().debugCaps.execGraphPath; if (path.empty()) @@ -231,19 +234,17 @@ void serialize(const Graph &graph) { } } -void serializeToXML(const Graph &graph, const std::string& path) { +void serializeToXML(const Graph& graph, const std::string& path) { if (path.empty()) return; std::string binPath; ov::pass::Manager manager; - manager.register_pass(path, - binPath, - ov::pass::Serialize::Version::IR_V10); + manager.register_pass(path, binPath, ov::pass::Serialize::Version::IR_V10); manager.run_passes(graph.dump()); } -void serializeToCout(const Graph &graph) { +void serializeToCout(const Graph& graph) { for (const auto& node : graph.GetNodes()) { std::cout << "name: " << node->getName() << " [ "; auto nodeDesc = node->getSelectedPrimitiveDescriptor(); @@ -251,8 +252,7 @@ void serializeToCout(const Graph &graph) { auto& inConfs = nodeDesc->getConfig().inConfs; if (!inConfs.empty()) { std::cout << "in: " << inConfs.front().getMemDesc()->getPrecision().get_type_name() - << "/l=" << inConfs.front().getMemDesc()->serializeFormat() - << "; "; + << "/l=" << inConfs.front().getMemDesc()->serializeFormat() << "; "; } auto& outConfs = nodeDesc->getConfig().outConfs; if (!outConfs.empty()) { @@ -260,11 +260,11 @@ void serializeToCout(const Graph &graph) { << "/l=" << outConfs.front().getMemDesc()->serializeFormat(); } } - std::cout << " ]" << std::endl; + std::cout << " ]" << std::endl; } } -void summary_perf(const Graph &graph) { +void summary_perf(const Graph& graph) { if (!graph.getGraphContext()) { return; } @@ -277,7 +277,7 @@ void summary_perf(const Graph &graph) { std::map perf_by_node; double total_avg = 0; uint64_t total = 0; - for (auto &node : graph.GetNodes()) { // important: graph.graphNodes are in topological order + for (auto& node : graph.GetNodes()) { // important: graph.graphNodes are in topological order double avg = node->PerfCounter().avg(); auto type = node->getTypeStr() + "_" + node->getPrimitiveDescriptorType(); auto name = node->getName(); @@ -296,59 +296,60 @@ void summary_perf(const Graph &graph) { perf_by_node[node] = avg; } - if (total_avg < 1) return; + if (total_avg < 1) + return; std::cout << "======= ENABLE_DEBUG_CAPS:OV_CPU_SUMMARY_PERF ======" << std::endl; - std::cout << "Summary of " << graph.GetName() << " @" << std::hash{}(reinterpret_cast(&graph)) << std::endl; + std::cout << "Summary of " << graph.GetName() << " @" << std::hash{}(reinterpret_cast(&graph)) + << std::endl; std::cout << " Total(us): " << (uint64_t)(total) << std::endl; std::cout << " Total_avg(us): " << (uint64_t)(total_avg) << std::endl; { std::cout << " perf_by_type:" << std::endl; - std::vector > A; + std::vector> A; for (auto& it : perf_by_type) A.push_back(it); - sort(A.begin(), A.end(), - [](std::pair& a, - std::pair& b){ - return a.second > b.second; - }); + sort(A.begin(), A.end(), [](std::pair& a, std::pair& b) { + return a.second > b.second; + }); for (auto& it : A) { std::stringstream ss; - int percentage = static_cast(it.second*100/total_avg); - if (percentage == 0) break; - ss << std::setw(10) << std::right << percentage << " % : " << std::setw(8) << std::right << it.second << "(us) " << it.first << std::endl; + int percentage = 
static_cast(it.second * 100 / total_avg); + if (percentage == 0) + break; + ss << std::setw(10) << std::right << percentage << " % : " << std::setw(8) << std::right << it.second + << "(us) " << it.first << std::endl; std::cout << ss.str(); } } { std::cout << " perf_by_node:" << std::endl; - std::vector > A; + std::vector> A; for (auto& it : perf_by_node) A.push_back(it); - sort(A.begin(), A.end(), - [](std::pair& a, - std::pair& b){ + sort(A.begin(), A.end(), [](std::pair& a, std::pair& b) { return a.second > b.second; }); for (auto& it : A) { std::stringstream ss; - auto percentage = it.second*100/total_avg; + auto percentage = it.second * 100 / total_avg; auto node = it.first; - if (node->PerfCounter().count() == 0) continue; - if (node->PerfCounter().avg() < 1) continue; + if (node->PerfCounter().count() == 0) + continue; + if (node->PerfCounter().avg() < 1) + continue; ss << std::setw(10) << std::right << std::fixed << std::setprecision(2) << percentage << " % " - << std::setw(8) << std::right << node->PerfCounter().avg() << "(us)x" << node->PerfCounter().count() - << " #" << node->getExecIndex() - << " " << node->getName() - << " " << node->getTypeStr() + "_" + node->getPrimitiveDescriptorType() << std::endl; + << std::setw(8) << std::right << node->PerfCounter().avg() << "(us)x" << node->PerfCounter().count() + << " #" << node->getExecIndex() << " " << node->getName() << " " + << node->getTypeStr() + "_" + node->getPrimitiveDescriptorType() << std::endl; std::cout << ss.str(); } } } -void average_counters(const Graph &graph) { +void average_counters(const Graph& graph) { /** * @todo improve logic for a graph with inner graphs: * - collect counters only for the outer graph if full path is specified @@ -359,7 +360,8 @@ void average_counters(const Graph &graph) { static int graphIndex = 0; std::ofstream file; - std::string fileName = graph.getConfig().debugCaps.averageCountersPath + "_" + std::to_string(graphIndex++) + ".csv"; + std::string fileName = + graph.getConfig().debugCaps.averageCountersPath + "_" + std::to_string(graphIndex++) + ".csv"; file.open(fileName); @@ -379,18 +381,14 @@ void average_counters(const Graph &graph) { const auto cpuTime = toMs(avg); const auto realTime = cpuTime; - file << node->getName() << ";" - << status << ";" - << node->getTypeStr() << ";" - << node->getPrimitiveDescriptorType() << ";" - << realTime << ";" - << cpuTime << ";" - << "\n"; + file << node->getName() << ";" << status << ";" << node->getTypeStr() << ";" + << node->getPrimitiveDescriptorType() << ";" << realTime << ";" << cpuTime << ";" + << "\n"; return avg; }; - for (auto &node : graph.GetNodes()) { + for (auto& node : graph.GetNodes()) { if (node->isConstant()) continue; @@ -399,11 +397,12 @@ void average_counters(const Graph &graph) { const auto totalMs = toMs(total); - file << "Total;;;;" << totalMs << ";" << totalMs << ";" << "\n"; + file << "Total;;;;" << totalMs << ";" << totalMs << ";" + << "\n"; file.close(); } #endif -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/graph_dumper.h b/src/plugins/intel_cpu/src/graph_dumper.h index 417db7e4c3cdc5..40af2fd44c61e6 100644 --- a/src/plugins/intel_cpu/src/graph_dumper.h +++ b/src/plugins/intel_cpu/src/graph_dumper.h @@ -4,19 +4,19 @@ #pragma once -#include "graph.h" - #include +#include "graph.h" + namespace ov { namespace intel_cpu { -std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph &graph); +std::shared_ptr dump_graph_as_ie_ngraph_net(const Graph& 
graph); #ifdef CPU_DEBUG_CAPS -void serialize(const Graph &graph); -void summary_perf(const Graph &graph); -void average_counters(const Graph &graph); -#endif // CPU_DEBUG_CAPS +void serialize(const Graph& graph); +void summary_perf(const Graph& graph); +void average_counters(const Graph& graph); +#endif // CPU_DEBUG_CAPS -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/graph_optimizer.cpp b/src/plugins/intel_cpu/src/graph_optimizer.cpp index 94f54fc4c59b55..fe0df309dc32f1 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.cpp +++ b/src/plugins/intel_cpu/src/graph_optimizer.cpp @@ -4,6 +4,7 @@ #include "graph_optimizer.h" +#include "cpu_types.h" #include "dnnl_extension_utils.h" #include "nodes/bin_conv.h" #include "nodes/common/cpu_convert.h" @@ -22,28 +23,26 @@ #include "nodes/transpose.h" #include "onednn/dnnl.h" #include "openvino/opsets/opset1.hpp" -#include "cpu_types.h" #include "utils/cpu_utils.hpp" #include "utils/debug_capabilities.h" #include "utils/general_utils.h" // WA for xbyak.h #ifdef _WIN32 -# ifndef _WINSOCKAPI_ -# define _WINSOCKAPI_ -# endif -# ifndef _WINSOCK2API_ -# define _WINSOCK2API_ -#endif +# ifndef _WINSOCKAPI_ +# define _WINSOCKAPI_ +# endif +# ifndef _WINSOCK2API_ +# define _WINSOCK2API_ +# endif #endif -#include "cpu/x64/cpu_isa_traits.hpp" - -#include +#include #include #include #include -#include +#include +#include "cpu/x64/cpu_isa_traits.hpp" #include "itt.h" #include "memory_desc/cpu_memory_desc_utils.h" @@ -55,11 +54,15 @@ namespace intel_cpu { GraphOptimizer::GraphOptimizer() {} -void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) { +void GraphOptimizer::ApplyCommonGraphOptimizations(Graph& graph) { // For conv with input zp, canBeExecutedInInt8() check has dependency on input zero point check. - // Also zero point node is the input of computing-intensive nodes. Most others fusing are the output of computing-intensive nodes. - // So Locate the FuseConvolutionAndZeroPoints() as the first optimization. - OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, taskChain, itt::domains::intel_cpu_LT, "ApplyCommonGraphOptimizations", "FuseConvolutionAndZeroPoints"); + // Also zero point node is the input of computing-intensive nodes. Most others fusing are the output of + // computing-intensive nodes. So Locate the FuseConvolutionAndZeroPoints() as the first optimization. 
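As an aside, the ordering constraint described in the comment above amounts to a pipeline of the following shape: every fusing pass mutates the graph and is followed by a sweep of whatever it marked as dropped, with the zero-point fusing placed first. A toy runner, not the plugin's actual control flow:

    #include <functional>
    #include <vector>

    struct Graph;  // opaque toy graph type

    void runPasses(Graph& graph,
                   const std::vector<std::function<void(Graph&)>>& passes,
                   const std::function<void(Graph&)>& sweep) {
        // the zero-point pass goes first in `passes`: its Subtract node feeds
        // the compute-intensive nodes that the later passes match against
        for (const auto& pass : passes) {
            pass(graph);
            sweep(graph);  // RemoveDroppedNodes()/RemoveDroppedEdges() here
        }
    }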
+ OV_ITT_SCOPE_CHAIN(FIRST_INFERENCE, + taskChain, + itt::domains::intel_cpu_LT, + "ApplyCommonGraphOptimizations", + "FuseConvolutionAndZeroPoints"); FuseConvolutionAndZeroPoints(graph); graph.RemoveDroppedNodes(); @@ -187,7 +190,7 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) { graph.RemoveDroppedEdges(); } -void GraphOptimizer::ApplyImplSpecificGraphOptimizations(Graph &graph) { +void GraphOptimizer::ApplyImplSpecificGraphOptimizations(Graph& graph) { OV_ITT_SCOPE(FIRST_INFERENCE, itt::domains::intel_cpu_LT, "GraphOptimizer::ApplyImplSpecificGraphOptimizations"); DropDoubleReorders(graph); @@ -202,7 +205,7 @@ void GraphOptimizer::ApplyImplSpecificGraphOptimizations(Graph &graph) { graph.RemoveDroppedEdges(); } -void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) { +void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isDQScaleGraphPattern = [](NodePtr node) { @@ -211,13 +214,12 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) { } auto parentNode = node->getParentEdgeAt(0)->getParent(); auto scaleNode = node->getParentEdgeAt(1)->getParent(); - if (!(parentNode->getType() == Type::Convolution - || parentNode->getType() == Type::MatMul - || parentNode->getType() == Type::Deconvolution)) + if (!(parentNode->getType() == Type::Convolution || parentNode->getType() == Type::MatMul || + parentNode->getType() == Type::Deconvolution)) return false; if (!scaleNode->isConstant()) return false; - //Only Fusing scales for INT8 precision. + // Only Fusing scales for INT8 precision. if (!parentNode->canBeExecutedInInt8()) return false; return (parentNode->getParentEdges().size() == 2); @@ -233,8 +235,7 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) { if (!node->getFusedWith().empty() || !scales->getFusedWith().empty()) return false; - const auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(), - nodeOutDims.size()); + const auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(), nodeOutDims.size()); if (nodeOutDims.size() != scalesDims.size() || scalesDims.size() < 2) return false; @@ -261,7 +262,7 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) { if (scalesData == nullptr) OPENVINO_THROW("scalesBlob has not allocated buffer"); auto scalesDims = getNormalizedDimsBySize(scales->getOutputShapeAtPort(0).getDims(), - node->getOutputShapeAtPort(0).getDims().size()); + node->getOutputShapeAtPort(0).getDims().size()); auto scaleSize = std::accumulate(scalesDims.begin(), scalesDims.end(), 1, std::multiplies()); node->fuseDQScales(scalesData, scaleSize); return true; @@ -269,16 +270,21 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) { for (size_t i = 0; i < graphNodes.size(); i++) { auto mul = graphNodes[i]; - if (!isDQScaleGraphPattern(mul)) continue; + if (!isDQScaleGraphPattern(mul)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvMatmulFCDeconvAndDQScales); auto node = mul->getParentEdgeAt(0)->getParent(); auto scales = mul->getParentEdgeAt(1)->getParent(); - if (!scaleDimsCheck(node, scales)) continue; + if (!scaleDimsCheck(node, scales)) + continue; if (initializeDeQuantizedScales(node, scales)) { - DEBUG_LOG("GraphOptimizer##FusingDQ: Node ##", mul->getName(), " optimized as DQ scales of Node ##", node->getName()); + DEBUG_LOG("GraphOptimizer##FusingDQ: Node ##", + mul->getName(), + " optimized as DQ scales of Node ##", + node->getName()); 
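Numerically, handing the constant Multiply over to the producer as dequantization scales relies on the scales being applied once per output channel, which is the same arithmetic the standalone Multiply performed. A toy illustration (plain NCHW loop, hypothetical layout, not the plugin code):

    #include <cstddef>
    #include <vector>

    // After the fusion, out[c][i] is multiplied by scales[c] inside the
    // producer itself rather than by a separate Multiply node; either way
    // the result is identical.
    void applyDQScales(std::vector<float>& out,           // C * HW elements
                       const std::vector<float>& scales,  // one per channel
                       std::size_t C,
                       std::size_t HW) {
        for (std::size_t c = 0; c < C; ++c)
            for (std::size_t i = 0; i < HW; ++i)
                out[c * HW + i] *= scales[c];
    }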
                node->addOriginalLayer(mul->getOriginalLayers());
                auto p_edge = mul->getParentEdgeAt(1);
                graph.RemoveEdge(p_edge);
@@ -287,7 +293,7 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
     }
 }
 
-void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
+void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph& graph) {
     auto& graphNodes = graph.GetNodes();
 
     auto isSuitableParentNode = [](const NodePtr& node) {
@@ -300,16 +306,14 @@ void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
             return false;
 
         if (!deconv)
-            return (one_of(node->getType(), Type::Convolution, Type::MatMul) &&
-                    node->getParentEdges().size() == 2);
+            return (one_of(node->getType(), Type::Convolution, Type::MatMul) && node->getParentEdges().size() == 2);
         else
             return deconv->canFuseBias();
     };
 
     auto isSuitableChildNode = [&](const NodePtr& parentNode, const NodePtr& childNode) {
-        if (childNode->getAlgorithm() != Algorithm::EltwiseAdd
-            || !childNode->getFusedWith().empty()
-            || childNode->getParentEdges().size() != 2)
+        if (childNode->getAlgorithm() != Algorithm::EltwiseAdd || !childNode->getFusedWith().empty() ||
+            childNode->getParentEdges().size() != 2)
             return false;
 
         auto biasPort = childNode->getParentEdgeAt(0)->getParent() == parentNode ? 1 : 0;
@@ -318,10 +322,11 @@ void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
             return false;
 
         const auto parentOutDims = parentNode->getOutputShapeAtPort(0).getDims();
-        const auto biasDims = getNormalizedDimsBySize(biasNode->getOutputShapeAtPort(0).getDims(),
-                                                      parentOutDims.size());
-        // TODO [NM]: Legacy ConvBias fusion transformation supports both per-tensor (via explicit broadcasing) and per-channel cases.
-        // Most of the real models contain per-channel bias, so we need to reavaluate the need to support per-tensor variant.
+        const auto biasDims =
+            getNormalizedDimsBySize(biasNode->getOutputShapeAtPort(0).getDims(), parentOutDims.size());
+        // TODO [NM]: Legacy ConvBias fusion transformation supports both per-tensor (via explicit broadcasting) and
+        // per-channel cases. Most of the real models contain per-channel bias, so we need to re-evaluate the need to
+        // support per-tensor variant.
         if (parentOutDims.size() != biasDims.size() || biasDims.size() < 2)
             return false;
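Two shape manipulations drive this check and the reshape workaround further below: the bias dims are first rank-aligned with the producer output (what getNormalizedDimsBySize() is used for here, assuming it left-pads with ones), and later the whole bias is flattened to its element count. Hypothetical minimal versions of both helpers:

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Left-pad dims with 1s up to the target rank,
    // e.g. {64, 1, 1} with rank 4 becomes {1, 64, 1, 1}.
    std::vector<std::size_t> normalizeDims(std::vector<std::size_t> dims, std::size_t rank) {
        while (dims.size() < rank)
            dims.insert(dims.begin(), 1);
        return dims;
    }

    // Flatten a bias shape into the single dimension oneDNN expects:
    // just the total number of elements.
    std::size_t flattenToElementCount(const std::vector<std::size_t>& dims) {
        return std::accumulate(dims.begin(), dims.end(), std::size_t{1}, std::multiplies<std::size_t>());
    }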
@@ -357,9 +362,11 @@ void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
 
         for (size_t i = 0; i < parents.size(); i++) {
             auto p_edge = parents[i].lock();
-            if (!p_edge) continue;
+            if (!p_edge)
+                continue;
             auto parent = p_edge->getParent();
-            if (!parent) continue;
+            if (!parent)
+                continue;
 
             if (parent == parentNode) {
                 for (size_t j = 0; j < childs.size(); j++) {
 
                     if (!child)
                         continue;
 
-                    EdgePtr &remEdge = p_edge;
+                    EdgePtr& remEdge = p_edge;
                     int inNum = 0;
                     if (remEdge) {
                         inNum = remEdge->getInputNum();
@@ -384,7 +391,7 @@ void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
                     graph.CreateEdge(parent, child, inNum, outNum);
                 }
             } else {
-                EdgePtr &remEdge = p_edge;
+                EdgePtr& remEdge = p_edge;
                 int inNum = 0;
                 if (remEdge) {
                     inNum = remEdge->getInputNum();
@@ -398,48 +405,57 @@ void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
                 // ONEDNN Conv, Deconv, FC would need the bias to be flattened into 1D tensor.
                 // Usually the bias output shape would be normalized to align rank with Conv/Deconv/FC output.
                 // To avoid duplicate reshape WA code in nodes, here we flatten the shape.
-                // Most bias nodes are const Input and bias memory primitive has been initialized as const memory when constructing CPU Input node.
-                // Const memory is not allowed to be modified after initialized. It means we can't redefine const bias memory primitive.
-                // So let's insert a reshape node to flatten the bias shape into 1D and const folding node will be executed during the compiling stage.
-                const bool needReshape = (targetNode->getType() != Type::MatMul &&
-                                          biasOutputShape.getRank() != 1);
+                // Most bias nodes are const Input and bias memory primitive has been initialized as const memory when
+                // constructing CPU Input node. Const memory is not allowed to be modified after initialized. It means
+                // we can't redefine const bias memory primitive. So let's insert a reshape node to flatten the bias
+                // shape into 1D and const folding node will be executed during the compiling stage.
+                const bool needReshape = (targetNode->getType() != Type::MatMul && biasOutputShape.getRank() != 1);
                 if (needReshape) {
                     // Bias -> Reshape -> Conv/Deconv/FC
                     const VectorDims flattenShape = {biasOutputShape.getElementsCount()};
                     // Construct Ngraph Reshape node and CPU Reshape node.
-                    auto reshapeConstInput = std::make_shared<ov::opset1::Constant>(ov::element::i32, ov::Shape{1}, flattenShape);
-                    auto reshapeDummyInput = std::make_shared<ov::opset1::Parameter>(
-                        biasNode->getOriginalOutputPrecisionAtPort(0),
-                        biasOutputShape.toPartialShape());
-                    const auto reshape = std::make_shared<ov::opset1::Reshape>(reshapeDummyInput, reshapeConstInput, false);
+                    auto reshapeConstInput =
+                        std::make_shared<ov::opset1::Constant>(ov::element::i32, ov::Shape{1}, flattenShape);
+                    auto reshapeDummyInput =
+                        std::make_shared<ov::opset1::Parameter>(biasNode->getOriginalOutputPrecisionAtPort(0),
+                                                                biasOutputShape.toPartialShape());
+                    const auto reshape =
+                        std::make_shared<ov::opset1::Reshape>(reshapeDummyInput, reshapeConstInput, false);
                     reshape->set_friendly_name(biasNode->getName() + "_flatten_reshape");
-                    const auto cpuReshapeNode = std::make_shared<node::Reshape>(reshape, graph.getGraphContext());
+                    const auto cpuReshapeNode =
+                        std::make_shared<node::Reshape>(reshape, graph.getGraphContext());
                     // Insert Reshape between bias node and Conv/Deconv/FC
                     graph.InsertNode(biasNode, targetNode, cpuReshapeNode, inNum, outNum, false);
 
                     // Insert the Reshape const input node and edge into CPU graph.
-                    const auto cpuReshapeConstInput = std::make_shared<node::Input>(reshapeConstInput, graph.getGraphContext());
+                    const auto cpuReshapeConstInput =
+                        std::make_shared<node::Input>(reshapeConstInput, graph.getGraphContext());
                     graph.AddNode(cpuReshapeConstInput);
                     graph.CreateEdge(cpuReshapeConstInput, cpuReshapeNode, 0, 1);
-                    DEBUG_LOG("GraphOptimizer##FusingBias:Flatten Bias node from shape ", PartialShape{biasOutputShape.getDims()},
-                              " to ", PartialShape{flattenShape});
+                    DEBUG_LOG("GraphOptimizer##FusingBias:Flatten Bias node from shape ",
+                              PartialShape{biasOutputShape.getDims()},
+                              " to ",
+                              PartialShape{flattenShape});
                     // Update bias output shape to be flatten shape.
                     biasOutputShape = Shape{flattenShape};
                 } else {
                     // Bias is connected as input edge.
                     graph.CreateEdge(biasNode, targetNode, inNum, outNum);
                 }
-                //Add the Bias inputshape into conv/FC/Deconv/Matmul.
+                // Add the Bias inputshape into conv/FC/Deconv/Matmul.
targetNode->inputShapes.push_back(biasOutputShape); } } - DEBUG_LOG("GraphOptimizer##FusingBias:Node ##: ", childNode->getName(), " initialize as Bias of Node ##", parentNode->getName()); + DEBUG_LOG("GraphOptimizer##FusingBias:Node ##: ", + childNode->getName(), + " initialize as Bias of Node ##", + parentNode->getName()); parentNode->addOriginalLayer(childNode->getOriginalLayers()); parentNode->addOriginalInputPrecision(childNode->getOriginalInputPrecisionAtPort(biasPort)); graph.DropNode(childNode); } } -void GraphOptimizer::FuseDeconvolutionAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseDeconvolutionAndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -485,7 +501,7 @@ void GraphOptimizer::FuseDeconvolutionAndSimpleOperation(Graph &graph) { childNode->fuseInto(parentNode); auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::Deconvolution) continue; @@ -497,7 +513,7 @@ void GraphOptimizer::FuseDeconvolutionAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseMultiplyAndAdd(Graph &graph) { +void GraphOptimizer::FuseMultiplyAndAdd(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableSecondInput = [](const NodePtr& node, VectorDims dataDims) { @@ -509,9 +525,9 @@ void GraphOptimizer::FuseMultiplyAndAdd(Graph &graph) { auto getChannelAxis = [](const VectorDims& dims) { auto channelAxis = -1; - for (size_t i = 0; i < dims.size(); i ++) { + for (size_t i = 0; i < dims.size(); i++) { if (dims[i] != 1) { - if (channelAxis != -1) // more than one axis is != 1 + if (channelAxis != -1) // more than one axis is != 1 return -1; else channelAxis = i; @@ -539,11 +555,13 @@ void GraphOptimizer::FuseMultiplyAndAdd(Graph &graph) { }; auto isSuitableChildNode = [&](const NodePtr& parentNode, const NodePtr& childNode) { - if (childNode->getAlgorithm() != Algorithm::EltwiseAdd || !childNode->getFusedWith().empty() || childNode->getParentEdges().size() != 2) + if (childNode->getAlgorithm() != Algorithm::EltwiseAdd || !childNode->getFusedWith().empty() || + childNode->getParentEdges().size() != 2) return false; - return isSuitableSecondInput(childNode->getParentEdgeAt(1)->getParent(), childNode->getInputShapeAtPort(0).getDims()) && - parentNode->canFuse(childNode); + return isSuitableSecondInput(childNode->getParentEdgeAt(1)->getParent(), + childNode->getInputShapeAtPort(0).getDims()) && + parentNode->canFuse(childNode); }; auto parent = graphNodes.begin(); @@ -569,9 +587,11 @@ void GraphOptimizer::FuseMultiplyAndAdd(Graph &graph) { for (size_t i = 0; i < parents.size(); i++) { auto p_edge = parents[i].lock(); - if (!p_edge) continue; + if (!p_edge) + continue; auto parent = p_edge->getParent(); - if (!parent) continue; + if (!parent) + continue; if (parent == parentNode) { for (size_t j = 0; j < childs.size(); j++) { @@ -581,7 +601,7 @@ void GraphOptimizer::FuseMultiplyAndAdd(Graph &graph) { if (!child) continue; - EdgePtr &remEdge = p_edge; + EdgePtr& remEdge = p_edge; int inNum = 0; if (remEdge) { inNum = remEdge->getInputNum(); @@ -596,7 +616,7 @@ void GraphOptimizer::FuseMultiplyAndAdd(Graph &graph) { graph.CreateEdge(parent, child, inNum, outNum); } } else { - EdgePtr &remEdge = p_edge; + EdgePtr& remEdge = p_edge; int inNum = 0; if (remEdge) { inNum = remEdge->getInputNum(); @@ -652,9 +672,11 @@ void GraphOptimizer::MergeConvertAndScaleShift(Graph& 
graph) { const auto parents = parentNode->parentEdges; for (size_t i = 0; i < parents.size(); i++) { auto p_edge = parents[i].lock(); - if (!p_edge) continue; + if (!p_edge) + continue; auto parent = p_edge->getParent(); - if (!parent) continue; + if (!parent) + continue; if (!parentNode->childEdges[0].lock()) continue; @@ -688,8 +710,8 @@ void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) { return; #endif - // This optimization fuses Convert (fp16 -> bf16/fp32) on weights directly to FC input to allow precision conversion handling based on internal logic - // (e.g. fuse conversion with weights reordering) + // This optimization fuses Convert (fp16 -> bf16/fp32) on weights directly to FC input to allow precision conversion + // handling based on internal logic (e.g. fuse conversion with weights reordering) auto& graphNodes = graph.GetNodes(); for (const auto& fullyConnected : graphNodes) { if (fullyConnected->getType() != Type::FullyConnected) { @@ -722,14 +744,13 @@ void GraphOptimizer::FuseFCAndTransposeOnWeights(Graph& graph) { return; #endif - // This optimization allows us to avoid transposing the weights in Transpose node and do it directly along with reordering in FC node + // This optimization allows us to avoid transposing the weights in Transpose node and do it directly along with + // reordering in FC node auto& graphNodes = graph.GetNodes(); auto isSuitablePattern = [](NodePtr parent) { - bool res = true && parent->getType() == Type::Transpose - && parent->getChildEdges().size() == 1 - && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected - && parent->isConstant(); + bool res = true && parent->getType() == Type::Transpose && parent->getChildEdges().size() == 1 && + parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected && parent->isConstant(); return res; }; @@ -744,7 +765,7 @@ void GraphOptimizer::FuseFCAndTransposeOnWeights(Graph& graph) { } } -void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { +void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableConvNode = [](NodePtr node) { @@ -777,9 +798,10 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { return false; // The plug-in doesn't support FP32 convolution with input/weights zero points. - // In case weights are in FP32 (or we have zero points on weights which are not supported by INT8 convolution) we cannot use - // INT8 implementation so we have to disable input zero points fusing as well. - if (parent1->getType() != Type::Input || !parent1->isConstant() || parent1->getOriginalOutputPrecisionAtPort(0) != ov::element::i8) { + // In case weights are in FP32 (or we have zero points on weights which are not supported by INT8 convolution) + // we cannot use INT8 implementation so we have to disable input zero points fusing as well. 
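Background for the compensation loop computed later in this function: with an input zero point izp, an int8 accumulator expands as sum_k w_k * (x_k - izp) = sum_k w_k * x_k - izp * sum_k w_k, so the correction depends only on the weights and can be precomputed per output channel (the real code additionally folds in the weights zero point wzp). A scalar sketch with toy types:

    #include <cstdint>
    #include <vector>

    // Precompute -izp * sum(w) once per output channel so the kernel can run
    // on the raw unsigned activations and add the correction afterwards.
    int32_t zeroPointCompensation(const std::vector<int8_t>& weights, int32_t izp) {
        int32_t sum = 0;
        for (int8_t w : weights)
            sum += static_cast<int32_t>(w);
        return -izp * sum;
    }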
+ if (parent1->getType() != Type::Input || !parent1->isConstant() || + parent1->getOriginalOutputPrecisionAtPort(0) != ov::element::i8) { return false; } @@ -827,7 +849,7 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { if (zeroPointsData == nullptr) OPENVINO_THROW("zeroPointsBlob has not allocated buffer"); - auto zeroPointDataSize = parent0->getInputShapeAtPort(1).getDims()[1]; + auto zeroPointDataSize = parent0->getInputShapeAtPort(1).getDims()[1]; if (Shape::UNDEFINED_DIM == zeroPointDataSize) { return false; } @@ -863,8 +885,10 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { auto OC = weightsConstantDims[0 + groupOffset]; auto IC = weightsConstantDims[1 + groupOffset]; - auto KD = weightsConstantDims.size() == (5 + groupOffset) ? weightsConstantDims[weightsConstantDims.size() - 3] : 1; - auto KH = weightsConstantDims.size() == (3 + groupOffset) ? 1 : weightsConstantDims[weightsConstantDims.size() - 2]; + auto KD = + weightsConstantDims.size() == (5 + groupOffset) ? weightsConstantDims[weightsConstantDims.size() - 3] : 1; + auto KH = + weightsConstantDims.size() == (3 + groupOffset) ? 1 : weightsConstantDims[weightsConstantDims.size() - 2]; auto KW = weightsConstantDims[weightsConstantDims.size() - 1]; for (size_t g = 0; g < G; g++) { @@ -874,20 +898,19 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { for (size_t kd = 0; kd < KD; kd++) { for (size_t kh = 0; kh < KH; kh++) { for (size_t kw = 0; kw < KW; kw++) { - size_t widx = g * OC * IC * KD * KH * KW + - oc * IC * KD * KH * KW + - ic * KD * KH * KW + - kd * KH * KW + - kh * KW + - kw; + size_t widx = g * OC * IC * KD * KH * KW + oc * IC * KD * KH * KW + ic * KD * KH * KW + + kd * KH * KW + kh * KW + kw; auto w = static_cast(weightsPtr[widx]); - auto izp = !convNode->legacyInputZeroPoints.empty() ? static_cast(convNode->legacyInputZeroPoints[g * IC + ic]) : 0; + auto izp = !convNode->legacyInputZeroPoints.empty() + ? static_cast(convNode->legacyInputZeroPoints[g * IC + ic]) + : 0; a += w * izp; - auto wzp = !convNode->legacyWeightsZeroPoints.empty() ? - static_cast(convNode->legacyWeightsZeroPoints[g * OC + oc]) : 0; + auto wzp = !convNode->legacyWeightsZeroPoints.empty() + ? 
static_cast(convNode->legacyWeightsZeroPoints[g * OC + oc]) + : 0; a -= wzp * izp; } } @@ -900,7 +923,8 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { for (size_t i = 0; i < graphNodes.size(); i++) { auto conv = graphNodes[i]; - if (!isSuitableConvNode(conv)) continue; + if (!isSuitableConvNode(conv)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvolutionAndZeroPoints_ConvNode); @@ -908,8 +932,10 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { auto weightsEltwise = conv->getParentEdgeAt(1)->getParent(); if (initializeInputZeroPoints(conv, dataEltwise, weightsEltwise)) { auto p_edge = dataEltwise->getParentEdgeAt(1); - DEBUG_LOG("[GraphOptimizer##FusingZeorPoint]:Eltwise Subtract Node ##", dataEltwise->getName(), - " is optimized as zeropoint of Conv ##", conv->getName()); + DEBUG_LOG("[GraphOptimizer##FusingZeorPoint]:Eltwise Subtract Node ##", + dataEltwise->getName(), + " is optimized as zeropoint of Conv ##", + conv->getName()); graph.RemoveEdge(p_edge); graph.DropNode(dataEltwise); initializeOutputCompensation(conv); @@ -917,7 +943,7 @@ void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) { } } -void GraphOptimizer::FuseFullyConnectedAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseFullyConnectedAndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -944,7 +970,7 @@ void GraphOptimizer::FuseFullyConnectedAndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::FullyConnected) continue; @@ -957,7 +983,7 @@ void GraphOptimizer::FuseFullyConnectedAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseMatMulAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseMatMulAndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSutableParentNode = [](const NodePtr& node) { @@ -984,7 +1010,7 @@ void GraphOptimizer::FuseMatMulAndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::MatMul) continue; @@ -997,14 +1023,14 @@ void GraphOptimizer::FuseMatMulAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { +void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph& graph) { auto& graphNodes = graph.GetNodes(); - auto isConvolutionNode = [](const NodePtr &node) { + auto isConvolutionNode = [](const NodePtr& node) { return node->getType() == Type::Convolution; }; - auto is1x1Convolution = [](const std::shared_ptr &conv) { + auto is1x1Convolution = [](const std::shared_ptr& conv) { const auto weightRank = conv->getWeightDims().size(); return conv->getWeightDims()[weightRank - 1] == 1 && conv->getWeightDims()[weightRank - 2] == 1; }; @@ -1023,10 +1049,10 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { if (!conv->legacyWeightsZeroPoints.empty()) return false; - const auto &strides = conv->getStride(); - const auto &paddings = conv->getPaddingL(); - const auto &inDims = node->getInputShapeAtPort(0).getDims(); - const auto 
&outDims = node->getOutputShapeAtPort(0).getDims(); + const auto& strides = conv->getStride(); + const auto& paddings = conv->getPaddingL(); + const auto& inDims = node->getInputShapeAtPort(0).getDims(); + const auto& outDims = node->getOutputShapeAtPort(0).getDims(); bool isSupportedParams = conv->getGroupNum() == 1 && inDims.size() == 4 && dimsEqualStrong(inDims[inDims.size() - 1], outDims[outDims.size() - 1]) && @@ -1039,12 +1065,13 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { static_cast(paddings[paddings.size() - 1]), static_cast(paddings[paddings.size() - 2])) && !conv->canBeExecutedInInt8(); - if (!isSupportedParams) return false; + if (!isSupportedParams) + return false; return node->getChildEdges().size() == 1 && isConvolutionNode(node->getChildEdgeAt(0)->getChild()); }; - auto isSuitableChildConvolution = [&](const NodePtr &parentNode, const NodePtr &childNode) { + auto isSuitableChildConvolution = [&](const NodePtr& parentNode, const NodePtr& childNode) { if (parentNode->isDropped() || childNode->isDropped()) return false; @@ -1059,15 +1086,19 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { if (convParent == nullptr) OPENVINO_THROW("Cannot cast to convolution node ", parentNode->getName()); - if (!everyone_is(ov::element::f32, convParent->getOriginalOutputPrecisionAtPort(0), convChild->getOriginalInputPrecisionAtPort(0), - convChild->getOriginalOutputPrecisionAtPort(0))) + if (!everyone_is(ov::element::f32, + convParent->getOriginalOutputPrecisionAtPort(0), + convChild->getOriginalInputPrecisionAtPort(0), + convChild->getOriginalOutputPrecisionAtPort(0))) return false; - auto parentOutputPrecision = !parentNode->fusedWith.empty() + auto parentOutputPrecision = + !parentNode->fusedWith.empty() ? parentNode->fusedWith[parentNode->fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0) : parentNode->getOriginalOutputPrecisionAtPort(0); - auto childOutputPrecision = !childNode->fusedWith.empty() + auto childOutputPrecision = + !childNode->fusedWith.empty() ? 
childNode->fusedWith[childNode->fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0) : childNode->getOriginalOutputPrecisionAtPort(0); @@ -1103,7 +1134,7 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { return isSupportedParams; }; - auto isFusingWorthwhile = [&](const NodePtr &parentNode, const NodePtr &childNode) { + auto isFusingWorthwhile = [&](const NodePtr& parentNode, const NodePtr& childNode) { if (!childNode->inputShapes[0].isStatic() || !childNode->outputShapes[0].isStatic()) { return false; } @@ -1114,7 +1145,7 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { int L3_cache_size = dnnl::utils::get_cache_size(3, false); int dw_conv_input_size = inDims[0] * inDims[1] * inDims[2] * inDims[3] * elemSize; - int dw_conv_output_size = outDims[0] * outDims[1]* outDims[2] * outDims[3] * elemSize; + int dw_conv_output_size = outDims[0] * outDims[1] * outDims[2] * outDims[3] * elemSize; auto parentConvolutionNode = std::dynamic_pointer_cast(parentNode); if (parentConvolutionNode == nullptr) @@ -1127,19 +1158,23 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { }; for (size_t i = 0; i < graphNodes.size(); i++) { - if (!isConvolutionNode(graphNodes[i])) continue; + if (!isConvolutionNode(graphNodes[i])) + continue; auto parentConvNode = graphNodes[i]; - if (!isSuitableParentConvolution(parentConvNode)) continue; + if (!isSuitableParentConvolution(parentConvNode)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvolutionAndDWConvolution_ParentConv); auto childConvNode = parentConvNode->getChildEdgeAt(0)->getChild(); - if (!isSuitableChildConvolution(parentConvNode, childConvNode)) continue; + if (!isSuitableChildConvolution(parentConvNode, childConvNode)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FuseConvolutionAndDWConvolution_ChildConv); - if (!isFusingWorthwhile(parentConvNode, childConvNode)) continue; + if (!isFusingWorthwhile(parentConvNode, childConvNode)) + continue; parentConvNode->addFusedNode(childConvNode); @@ -1153,12 +1188,12 @@ void GraphOptimizer::FuseConvolutionAndDWConvolution(Graph &graph) { } // TODO [NM]: unite with FuseConvolutionAndSimpleOperation -void GraphOptimizer::FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &graph) { +void GraphOptimizer::FuseConvolutionAndSimpleOperationThroughMaxPool(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { - return (node->getType() == Type::Convolution || node->getType() == Type::BinaryConvolution) && node->getChildEdges().size() == 1 && - node->getOriginalOutputPrecisionAtPort(0) == ov::element::f32; + return (node->getType() == Type::Convolution || node->getType() == Type::BinaryConvolution) && + node->getChildEdges().size() == 1 && node->getOriginalOutputPrecisionAtPort(0) == ov::element::f32; }; auto parent = graphNodes.begin(); @@ -1197,7 +1232,7 @@ void GraphOptimizer::FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &grap parentNode->addFusedNode(fuseCandidate); parentNode->addOriginalLayer(fuseCandidate->getOriginalLayers()); auto parentEdges = fuseCandidate->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent() == childNode) continue; @@ -1208,11 +1243,12 @@ void GraphOptimizer::FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &grap } } -void GraphOptimizer::FuseConvolutionAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseConvolutionAndSimpleOperation(Graph& graph) { auto& graphNodes = 
graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { - return (node->getType() == Type::Convolution || node->getType() == Type::BinaryConvolution) && node->getChildEdges().size() == 1; + return (node->getType() == Type::Convolution || node->getType() == Type::BinaryConvolution) && + node->getChildEdges().size() == 1; }; auto parent = graphNodes.begin(); @@ -1237,7 +1273,7 @@ void GraphOptimizer::FuseConvolutionAndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == parentNodeType) continue; @@ -1250,7 +1286,7 @@ void GraphOptimizer::FuseConvolutionAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FusePoolingAndFakeQuantize(Graph &graph) { +void GraphOptimizer::FusePoolingAndFakeQuantize(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -1268,12 +1304,14 @@ void GraphOptimizer::FusePoolingAndFakeQuantize(Graph &graph) { for (size_t i = 0; i < graphNodes.size(); i++) { auto parent = graphNodes[i]; - if (!isSuitableParentNode(parent)) continue; + if (!isSuitableParentNode(parent)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FusePoolingAndFakeQuantize_ParentNode); auto child = parent->getChildEdgeAt(0)->getChild(); - if (!isSuitableChildNode(child)) continue; + if (!isSuitableChildNode(child)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FusePoolingAndFakeQuantize_ChildNode); @@ -1300,14 +1338,14 @@ void GraphOptimizer::FusePoolingAndFakeQuantize(Graph &graph) { * @param child node we try to find * @return True if child is one of data supplier */ -static bool is_data_dependency(const std::shared_ptr &parent, - const std::shared_ptr &child) { +static bool is_data_dependency(const std::shared_ptr& parent, const std::shared_ptr& child) { std::set visited; - std::list nextLayers {parent.get()}; + std::list nextLayers{parent.get()}; for (; !nextLayers.empty();) { auto layer = *nextLayers.begin(); - if (layer == child.get()) return true; + if (layer == child.get()) + return true; for (auto& oe : layer->getChildEdges()) { auto nn = oe.lock()->getChild(); if (visited.find(nn.get()) == visited.end()) { @@ -1358,19 +1396,18 @@ static bool is_data_dependency(const std::shared_ptr &parent, * *** */ -void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) { +void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph& graph) { #if !defined(OPENVINO_ARCH_X86) && !defined(OPENVINO_ARCH_X86_64) return; #endif - auto &graphNodes = graph.GetNodes(); + auto& graphNodes = graph.GetNodes(); auto isFusingSupported = [&](NodePtr conv, NodePtr child) { - return child->getType() == Type::Eltwise && - DnnlExtensionUtils::isUnarySupportedAsPostOp(child->getAlgorithm()); + return child->getType() == Type::Eltwise && DnnlExtensionUtils::isUnarySupportedAsPostOp(child->getAlgorithm()); }; - for (auto &graphNode : graphNodes) { + for (auto& graphNode : graphNodes) { const auto eltwiseNode = std::dynamic_pointer_cast(graphNode); if (graphNode->getType() != Type::Eltwise || graphNode->getAlgorithm() != Algorithm::EltwiseAdd || !eltwiseNode || eltwiseNode->isWithBroadcast()) @@ -1384,12 +1421,12 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) auto parent1 = graphNode->getParentEdgeAt(0)->getParent(); auto parent2 = 
graphNode->getParentEdgeAt(1)->getParent(); - bool isSuitableParent1 = parent1->getType() == Type::Convolution - || parent1->getType() == Type::BinaryConvolution; - bool isSuitableParent2 = parent2->getType() == Type::Convolution - || parent2->getType() == Type::BinaryConvolution; + bool isSuitableParent1 = + parent1->getType() == Type::Convolution || parent1->getType() == Type::BinaryConvolution; + bool isSuitableParent2 = + parent2->getType() == Type::Convolution || parent2->getType() == Type::BinaryConvolution; - auto canFuseSum = [](node::BinaryConvolution *binConv, NodePtr fuseCandidate) { + auto canFuseSum = [](node::BinaryConvolution* binConv, NodePtr fuseCandidate) { if (binConv->getImplType() == impl_desc_type::ref) return false; @@ -1408,12 +1445,12 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) return false; }; - auto* binConvNode1 = dynamic_cast(parent1.get()); + auto* binConvNode1 = dynamic_cast(parent1.get()); if (binConvNode1) { isSuitableParent1 = isSuitableParent1 && canFuseSum(binConvNode1, graphNode); } - auto* binConvNode2 = dynamic_cast(parent2.get()); + auto* binConvNode2 = dynamic_cast(parent2.get()); if (binConvNode2) { isSuitableParent2 = isSuitableParent2 && canFuseSum(binConvNode2, graphNode); } @@ -1427,7 +1464,7 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) return false; }; - auto* convNode1 = dynamic_cast(parent1.get()); + auto* convNode1 = dynamic_cast(parent1.get()); if (convNode1) { if (!convNode1->canBeExecutedInInt8()) { isSuitableParent1 = isSuitableParent1 && convNode1->getFusedWith().empty(); @@ -1436,7 +1473,7 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) } } - auto* convNode2 = dynamic_cast(parent2.get()); + auto* convNode2 = dynamic_cast(parent2.get()); if (convNode2) { if (!convNode2->canBeExecutedInInt8()) { isSuitableParent2 = isSuitableParent2 && convNode2->getFusedWith().empty(); @@ -1455,9 +1492,9 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) // not merged operation (peerNode) has to be in low precision const auto isBranchQuantized = [](const NodePtr& branchParent) { const auto& fused = branchParent->getFusedWith(); - const auto branchPrecision = fused.empty() ? - branchParent->getOriginalOutputPrecisionAtPort(0) : - fused[fused.size() - 1]->getOriginalOutputPrecisionAtPort(0); + const auto branchPrecision = fused.empty() + ? branchParent->getOriginalOutputPrecisionAtPort(0) + : fused[fused.size() - 1]->getOriginalOutputPrecisionAtPort(0); return (branchPrecision == ov::element::i8) || (branchPrecision == ov::element::u8); }; @@ -1527,15 +1564,16 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) // be overwritten. Should verify that all other consumer already read it and // we can spoil input data. 
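// ---------------------------------------------------------------------------------------------
// [Editor's note: illustrative sketch, not part of the patch.] The hazard described above comes
// from the in-place semantics of the oneDNN sum post-op: the fused convolution accumulates
// straight into the peer branch's output buffer, roughly
//
//     dnnl::post_ops ops;
//     ops.append_sum(1.0f);        // dst := conv(src) + dst_initial (the peer data)
//     dnnl::primitive_attr attr;
//     attr.set_post_ops(ops);
//
// so the peer data is destroyed by the write. The fusion is therefore allowed only if every
// other consumer of the peer output is itself a (transitive) data supplier of the sum, i.e. it
// is guaranteed to have read the buffer before the fused convolution writes; the loop below
// checks exactly this via is_data_dependency().
// ---------------------------------------------------------------------------------------------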
// TODO: rewrite once we add "Inplace" reporting mechanism - for (auto & edge : peerNode->getChildEdges()) { + for (auto& edge : peerNode->getChildEdges()) { if (!fuse_allowed) break; fuse_allowed &= is_data_dependency(edge.lock()->getChild(), sum); } - if (!fuse_allowed) continue; + if (!fuse_allowed) + continue; if (graphNode->getChildEdges().size() == 1 && - isFusingSupported(graphNode, graphNode->getChildEdgeAt(0)->getChild())) { + isFusingSupported(graphNode, graphNode->getChildEdgeAt(0)->getChild())) { auto relu_shared = graphNode->getChildEdgeAt(0)->getChild(); lastNode = relu_shared; if (mergedConv->isConstant() && !lastNode->isConstant()) @@ -1545,8 +1583,8 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) lastNode->fuseInto(mergedConv); - if (mergedConv->fusedWith.size() > 0 && - (mergedConv->fusedWith[0]->getType() == Type::Convolution || mergedConv->fusedWith[0]->getType() == Type::BinaryConvolution)) { + if (mergedConv->fusedWith.size() > 0 && (mergedConv->fusedWith[0]->getType() == Type::Convolution || + mergedConv->fusedWith[0]->getType() == Type::BinaryConvolution)) { // Merged with DW_conv. Shape may change mergedConv->inputShapes.push_back(mergedConv->fusedWith[0]->getOutputShapeAtPort(0)); } else { @@ -1577,7 +1615,7 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) graph.CreateEdge(peerNode, mergedConv, peer_port, childPort); std::vector edges_to_reconnect = lastNode->getChildEdges(); - for (auto &edge_w : edges_to_reconnect) { + for (auto& edge_w : edges_to_reconnect) { auto edge = edge_w.lock(); auto child = edge->getChild(); int idxParent = edge->getInputNum(); @@ -1597,7 +1635,7 @@ void GraphOptimizer::FuseConvolutionSumAndConvolutionSumActivation(Graph &graph) } } -void GraphOptimizer::FuseMVNAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseMVNAndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -1624,7 +1662,7 @@ void GraphOptimizer::FuseMVNAndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::MVN) continue; @@ -1637,7 +1675,7 @@ void GraphOptimizer::FuseMVNAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseInterpolateAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseInterpolateAndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -1646,8 +1684,8 @@ void GraphOptimizer::FuseInterpolateAndSimpleOperation(Graph &graph) { auto isSuitableChildNode = [&](NodePtr parentNode, NodePtr childNode) { // Avoid cycle dependencies - for (auto &childParentEdge : childNode->getParentEdges()) { - for (auto &parentParentEdge : parentNode->getParentEdges()) { + for (auto& childParentEdge : childNode->getParentEdges()) { + for (auto& parentParentEdge : parentNode->getParentEdges()) { if (childParentEdge.lock()->getParent() == parentParentEdge.lock()->getParent()) return false; } @@ -1683,7 +1721,7 @@ void GraphOptimizer::FuseInterpolateAndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : 
parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::Interpolate) continue; @@ -1696,7 +1734,7 @@ void GraphOptimizer::FuseInterpolateAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseNormalizeL2AndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseNormalizeL2AndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -1723,7 +1761,7 @@ void GraphOptimizer::FuseNormalizeL2AndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::NormalizeL2) continue; @@ -1736,7 +1774,7 @@ void GraphOptimizer::FuseNormalizeL2AndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseReduceAndSimpleOperation(Graph &graph) { +void GraphOptimizer::FuseReduceAndSimpleOperation(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -1763,7 +1801,7 @@ void GraphOptimizer::FuseReduceAndSimpleOperation(Graph &graph) { if (childNode->getType() == Type::FakeQuantize || childNode->getType() == Type::Eltwise) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge == nullptr) OPENVINO_THROW("Cannot get parent edge ", childNode->getName()); @@ -1778,7 +1816,7 @@ void GraphOptimizer::FuseReduceAndSimpleOperation(Graph &graph) { } } -void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { +void GraphOptimizer::FuseEltwiseAndSimple(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { @@ -1788,14 +1826,14 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { auto isSuitableChildNode = [&](NodePtr parentNode, NodePtr childNode) { if (parentNode->isConstant() && !childNode->isConstant()) return false; - for (auto &childParentEdge : childNode->getParentEdges()) { + for (auto& childParentEdge : childNode->getParentEdges()) { // WA to prevent unsupported reorder exception issue in some cases if (childParentEdge.lock()->getParent()->getType() == Type::Split) { return false; } // Avoid cycle dependencies - for (auto &parentParentEdge : parentNode->getParentEdges()) { + for (auto& parentParentEdge : parentNode->getParentEdges()) { if (childParentEdge.lock()->getParent() == parentParentEdge.lock()->getParent()) return false; } @@ -1819,7 +1857,8 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { auto childNode = parentNode->getChildEdgeAt(0)->getChild(); - if ((parentNode->isDynamicNode() && !childNode->isDynamicNode()) || (!parentNode->isDynamicNode() && childNode->isDynamicNode())) { + if ((parentNode->isDynamicNode() && !childNode->isDynamicNode()) || + (!parentNode->isDynamicNode() && childNode->isDynamicNode())) { parent++; continue; } @@ -1835,7 +1874,7 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { if (childNode->getType() == Type::FakeQuantize) { auto parentEdges = childNode->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (p_edge->getParent()->getType() == Type::Eltwise) continue; @@ -1851,9 +1890,11 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { for (size_t i = 0; i < parents.size(); i++) { auto 
p_edge = parents[i].lock(); - if (!p_edge) continue; + if (!p_edge) + continue; auto parent = p_edge->getParent(); - if (!parent) continue; + if (!parent) + continue; if (parent == parentNode) { for (size_t j = 0; j < children.size(); j++) { @@ -1863,7 +1904,7 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { if (!child) continue; - EdgePtr &remEdge = p_edge; + EdgePtr& remEdge = p_edge; int inNum = 0; if (remEdge) { inNum = remEdge->getInputNum(); @@ -1879,7 +1920,7 @@ void GraphOptimizer::FuseEltwiseAndSimple(Graph &graph) { graph.CreateEdge(parent, child, inNum, outNum); } } else { - EdgePtr &remEdge = p_edge; + EdgePtr& remEdge = p_edge; int inNum = 0; int outNum = parentNode->getParentEdges().size(); if (remEdge) { @@ -1970,15 +2011,14 @@ void GraphOptimizer::ShareReorders(Graph& graph) { } } -void GraphOptimizer::DropDoubleReorders(Graph &graph) { +void GraphOptimizer::DropDoubleReorders(Graph& graph) { std::set processed; auto& nodes = graph.GetNodes(); for (size_t i = 0; i < nodes.size(); i++) { auto node = nodes[i]; - if (processed.find(node) == processed.end() && node->getType() == Type::Reorder - && node->getChildEdges().size() == 1 - && node->getChildEdgeAt(0)->getChild()->getType() == Type::Reorder ) { + if (processed.find(node) == processed.end() && node->getType() == Type::Reorder && + node->getChildEdges().size() == 1 && node->getChildEdgeAt(0)->getChild()->getType() == Type::Reorder) { auto nextNode = node->getChildEdgeAt(0)->getChild(); Reorder* n = dynamic_cast(node.get()); if (n == nullptr) @@ -2003,7 +2043,8 @@ void GraphOptimizer::DropDoubleReorders(Graph &graph) { if (cur->getChild() == c) edge = cur; } - if (!edge) OPENVINO_THROW("Inappropriate graph processing"); + if (!edge) + OPENVINO_THROW("Inappropriate graph processing"); std::string layerName = edge->getParent()->getName() + "_ScaleReorder_" + edge->getChild()->getName(); graph.InsertReorder(edge, layerName, n->getInput(), nn->getOutput(), false); @@ -2012,11 +2053,12 @@ void GraphOptimizer::DropDoubleReorders(Graph &graph) { } } -void GraphOptimizer::FuseClampAndFakeQuantize(Graph &graph) { +void GraphOptimizer::FuseClampAndFakeQuantize(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableClampNode = [](NodePtr node) { - return node->getType() == Type::Eltwise && node->getChildEdges().size() == 1 && node->getAlgorithm() == Algorithm::EltwiseClamp; + return node->getType() == Type::Eltwise && node->getChildEdges().size() == 1 && + node->getAlgorithm() == Algorithm::EltwiseClamp; }; auto isSuitableFakeQuantizeNode = [](NodePtr node) { @@ -2024,7 +2066,7 @@ void GraphOptimizer::FuseClampAndFakeQuantize(Graph &graph) { }; auto fuseClampAndFakeQuantizeNodes = [](NodePtr parent, NodePtr child) { - auto* eltwiseNode = dynamic_cast(parent.get()); + auto* eltwiseNode = dynamic_cast(parent.get()); if (eltwiseNode == nullptr) OPENVINO_THROW("Cannot cast ", parent->getName(), " to Eltwise node"); @@ -2050,12 +2092,14 @@ void GraphOptimizer::FuseClampAndFakeQuantize(Graph &graph) { for (size_t i = 0; i < graphNodes.size(); i++) { auto parent = graphNodes[i]; - if (!isSuitableClampNode(parent)) continue; + if (!isSuitableClampNode(parent)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FuseClampAndFakeQuantize_ClalmpNode); auto child = parent->getChildEdgeAt(0)->getChild(); - if (!isSuitableFakeQuantizeNode(child)) continue; + if (!isSuitableFakeQuantizeNode(child)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FuseClampAndFakeQuantize_QuantizeNode); @@ -2065,7 +2109,7 @@ void 
GraphOptimizer::FuseClampAndFakeQuantize(Graph &graph) { } } -void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph) { +void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto getNonConstPort = [](const NodePtr& node) { @@ -2083,11 +2127,12 @@ void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph) { }; auto isSuitableScaleShiftNode = [getNonConstPort](const NodePtr& node) { - if (!one_of(node->getAlgorithm(), Algorithm::EltwiseAdd, - Algorithm::EltwiseSubtract, - Algorithm::EltwiseMultiply, - Algorithm::EltwiseDivide, - Algorithm::EltwiseMulAdd)) + if (!one_of(node->getAlgorithm(), + Algorithm::EltwiseAdd, + Algorithm::EltwiseSubtract, + Algorithm::EltwiseMultiply, + Algorithm::EltwiseDivide, + Algorithm::EltwiseMulAdd)) return false; const auto nonConstPort = getNonConstPort(node); @@ -2117,7 +2162,7 @@ void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph) { const NodePtr eltwiseInput = parentEltwise->getParentEdgeAt(getNonConstPort(parent))->getParent(); std::tie(scalesBuffer, shiftsBuffer) = parentEltwise->getScalesAndShifts(eltwiseInput.get()); - const auto &outputShape = child->getOutputShapeAtPort(0); + const auto& outputShape = child->getOutputShapeAtPort(0); VectorDims outputDims = outputShape.getDims(); // We need to compute explicitly port with unfolded parent, @@ -2178,7 +2223,7 @@ void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph) { std::vector zeroShift(newInputScale.size(), 0.f); const auto isSubnormal = [](const float value) { - const uint32_t *u32data = reinterpret_cast(&value); + const uint32_t* u32data = reinterpret_cast(&value); return (*u32data) && (((*u32data) & (0xFF << 23)) == 0); }; @@ -2220,18 +2265,20 @@ void GraphOptimizer::FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph) { for (size_t i = 0; i < graphNodes.size(); i++) { auto parent = graphNodes[i]; - if (!isSuitableScaleShiftNode(parent)) continue; + if (!isSuitableScaleShiftNode(parent)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FusePerformedAsScaleShiftAndFakeQuantize_ShiftNode); auto child = parent->getChildEdgeAt(0)->getChild(); - if (!isSuitableFakeQuantizeNode(child)) continue; + if (!isSuitableFakeQuantizeNode(child)) + continue; CPU_GRAPH_OPTIMIZER_SCOPE(FusePerformedAsScaleShiftAndFakeQuantize_QuantizeNode); if (fuseScaleShiftAndFakeQuantizeNodes(parent, child)) { auto parentEdges = parent->parentEdges; - for (auto &parentEdge : parentEdges) { + for (auto& parentEdge : parentEdges) { auto p_edge = parentEdge.lock(); if (!p_edge->getParent()->isConstant()) continue; @@ -2355,7 +2402,12 @@ void GraphOptimizer::mergeTransposeReshapeReorder(Graph& graph, transposeNode->getName(), " is not a transpose node"); - const auto& inOrder = transposeNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->as()->getOrder(); + const auto& inOrder = transposeNode->getSelectedPrimitiveDescriptor() + ->getConfig() + .inConfs[0] + .getMemDesc() + ->as() + ->getOrder(); const auto& outOrder = reorderOutDesc->as()->getOrder(); // Permutation must be set and reorder mustn't be optimized in 2 cases: // 1. 
Transpose has blocked input & non-blocked output @@ -2371,11 +2423,13 @@ void GraphOptimizer::mergeTransposeReshapeReorder(Graph& graph, } } - std::string reorderName = nodeBeforeSequence->getName() + "_" + Reorder::getReorderArgs(*reorderInDesc, *reorderOutDesc); + std::string reorderName = + nodeBeforeSequence->getName() + "_" + Reorder::getReorderArgs(*reorderInDesc, *reorderOutDesc); if (isOptimized) - reorderName += "_fake"; + reorderName += "_fake"; DEBUG_LOG("mergeTransposeAndReorder ", parentNode->getName(), " and ", childNode->getName(), " -> ", reorderName); - auto reorder_layout = std::make_shared(*reorderInDesc, *reorderOutDesc, reorderName, graph.getGraphContext()); + auto reorder_layout = + std::make_shared(*reorderInDesc, *reorderOutDesc, reorderName, graph.getGraphContext()); reorder_layout->setOptimized(isOptimized); reorder_layout->setSrcPermutation(srcPerm); @@ -2388,10 +2442,8 @@ void GraphOptimizer::mergeTransposeReshapeReorder(Graph& graph, Reorder::getReorderArgs(*reorderOutDesc, *finalDesc) + "_" + nodeAfterSequence->getName(); - reorder_last = std::make_shared(*reorderOutDesc, - *finalDesc, - reorderLayerName2, - graph.getGraphContext()); + reorder_last = + std::make_shared(*reorderOutDesc, *finalDesc, reorderLayerName2, graph.getGraphContext()); reorder_last->setOptimized(false); reorder_last->setSrcPermutation(srcPerm); graph.CreateEdge(reorder_layout, reorder_last, 0, 0); @@ -2445,10 +2497,10 @@ void GraphOptimizer::MergeTransposeAndReorder(Graph& graph) { return false; }; - return node->getType() == Type::Transpose - && node->getChildEdges().size() == 1 - && !node->isDynamicNode() // TODO [DS]: enable for dynamic shapes when inPlace in the dynamic case is available (CVS-74863) - && !prevNodeIsConvSum(node); + return node->getType() == Type::Transpose && node->getChildEdges().size() == 1 && + !node->isDynamicNode() // TODO [DS]: enable for dynamic shapes when inPlace in the dynamic case is + // available (CVS-74863) + && !prevNodeIsConvSum(node); }; auto isSuitableReshape = [](NodePtr node) { @@ -2473,8 +2525,9 @@ void GraphOptimizer::MergeTransposeAndReorder(Graph& graph) { }; auto isSuitableReorder = [](NodePtr node) { - return node->getType() == Type::Reorder - && !node->isDynamicNode(); // TODO [DS]: enable for dynamic shapes when inPlace in the dynamic case is available (CVS-74863) + return node->getType() == Type::Reorder && + !node->isDynamicNode(); // TODO [DS]: enable for dynamic shapes when inPlace in the dynamic case is + // available (CVS-74863) }; auto updateOrder = [](const VectorDims& originalOrder, NodePtr reshape) { @@ -2542,17 +2595,28 @@ void GraphOptimizer::MergeTransposeAndReorder(Graph& graph) { const auto transposeNode = std::dynamic_pointer_cast(parentNode); const auto reorderNode = std::dynamic_pointer_cast(childNode); - std::shared_ptr reshapeNode = intermNode != nullptr ? std::dynamic_pointer_cast(intermNode) : nullptr; + std::shared_ptr reshapeNode = + intermNode != nullptr ? 
std::dynamic_pointer_cast(intermNode) : nullptr; if (!transposeNode || !reorderNode || (intermNode && !reshapeNode)) { continue; } auto transposeOrder = updateOrder(transposeNode->getOrder(), reshapeNode); - auto descBeforeReorder = reorderNode->getParentEdgeAt(0)->getParent()->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].getMemDesc(); + auto descBeforeReorder = reorderNode->getParentEdgeAt(0) + ->getParent() + ->getSelectedPrimitiveDescriptor() + ->getConfig() + .outConfs[0] + .getMemDesc(); auto layoutOrder = descBeforeReorder->as()->getOrder(); - auto inBlockedDesc = reorderNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->as(); - auto outBlockedDesc = reorderNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].getMemDesc()->as(); + auto inBlockedDesc = + reorderNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->as(); + auto outBlockedDesc = reorderNode->getSelectedPrimitiveDescriptor() + ->getConfig() + .outConfs[0] + .getMemDesc() + ->as(); auto& inOrder = inBlockedDesc->getOrder(); auto& outOrder = outBlockedDesc->getOrder(); @@ -2563,13 +2627,11 @@ void GraphOptimizer::MergeTransposeAndReorder(Graph& graph) { } } -void GraphOptimizer::MergeReorderAndTranspose(Graph &graph) { +void GraphOptimizer::MergeReorderAndTranspose(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableTranspose = [](NodePtr node) { - return node->getType() == Type::Transpose - && node->getChildEdges().size() == 1 - && !node->isDynamicNode(); + return node->getType() == Type::Transpose && node->getChildEdges().size() == 1 && !node->isDynamicNode(); }; auto isSuitableReshape = [](NodePtr node) { @@ -2659,7 +2721,8 @@ void GraphOptimizer::MergeReorderAndTranspose(Graph &graph) { auto transposeNode = std::dynamic_pointer_cast(childNode); auto reorderNode = std::dynamic_pointer_cast(parentNode); - std::shared_ptr reshapeNode = intermNode != nullptr ? std::dynamic_pointer_cast(intermNode) : nullptr; + std::shared_ptr reshapeNode = + intermNode != nullptr ? std::dynamic_pointer_cast(intermNode) : nullptr; if (!transposeNode || !reorderNode || (intermNode && !reshapeNode)) { continue; } @@ -2668,15 +2731,20 @@ void GraphOptimizer::MergeReorderAndTranspose(Graph &graph) { auto descAfterTranspose = transposeNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].getMemDesc(); auto layoutOrder = updateOrder(descAfterTranspose->as()->getOrder(), reshapeNode); - auto inBlockedDesc = reorderNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->as(); - auto outBlockedDesc = reorderNode->getSelectedPrimitiveDescriptor()->getConfig().outConfs[0].getMemDesc()->as(); + auto inBlockedDesc = + reorderNode->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->as(); + auto outBlockedDesc = reorderNode->getSelectedPrimitiveDescriptor() + ->getConfig() + .outConfs[0] + .getMemDesc() + ->as(); auto& inOrder = inBlockedDesc->getOrder(); auto& outOrder = outBlockedDesc->getOrder(); if (checkAscendingFinalOrder(transposeOrder, layoutOrder, inOrder, outOrder)) { - // Reorder node doesn't support (with rare exceptions) reordering in case of different ranks on input and output. - // So the merge can be performed only in the case when the fused reorder will be optimized. + // Reorder node doesn't support (with rare exceptions) reordering in case of different ranks on input and + // output. So the merge can be performed only in the case when the fused reorder will be optimized. 
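// ---------------------------------------------------------------------------------------------
// [Editor's note: illustrative only; names are the ones used in this file.] An "optimized"
// Reorder is a zero-copy one: the merged node merely reinterprets the memory descriptor instead
// of moving data, along the lines of the construction in mergeTransposeReshapeReorder():
//
//     auto fused = std::make_shared<Reorder>(*inDesc, *outDesc, name, graph.getGraphContext());
//     fused->setOptimized(true);          // descriptor change only, no data movement
//     fused->setSrcPermutation(srcPerm);  // the transpose is folded into the reorder
//
// Hence a rank change between input and output is tolerable only when both can share a single
// buffer, which is what the canBeInplaced() check below guards.
// ---------------------------------------------------------------------------------------------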
if (parentNode->getInputShapeAtPort(0).getRank() != childNode->getOutputShapeAtPort(0).getRank() && !canBeInplaced(parentNode, childNode)) { continue; @@ -2686,14 +2754,15 @@ void GraphOptimizer::MergeReorderAndTranspose(Graph &graph) { } } -void GraphOptimizer::reshapeRnnSeq(Graph &graph) { +void GraphOptimizer::reshapeRnnSeq(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableParentNode = [](NodePtr node) { if (node->type != Type::RNNSeq) return false; auto rnnNode = std::dynamic_pointer_cast(node); - return rnnNode && !rnnNode->hasNativeOrder() && node->outputShapes[0].getRank() == 4 && node->outputShapes[0].getDims()[1] == 1; + return rnnNode && !rnnNode->hasNativeOrder() && node->outputShapes[0].getRank() == 4 && + node->outputShapes[0].getDims()[1] == 1; }; for (size_t i = 0; i < graphNodes.size(); i++) { @@ -2715,10 +2784,12 @@ void GraphOptimizer::reshapeRnnSeq(Graph &graph) { auto edge = childrenEdges[j]; auto childNode = edge->getChild(); - const auto secondInput = std::make_shared(ov::element::i32, ov::Shape{1}, std::vector{1}); + const auto secondInput = + std::make_shared(ov::element::i32, ov::Shape{1}, std::vector{1}); const auto unsqueeze = std::make_shared( std::make_shared(parentNode->getOriginalOutputPrecisionAtPort(0), - parentNode->getOutputShapeAtPort(0).toPartialShape()), secondInput); + parentNode->getOutputShapeAtPort(0).toPartialShape()), + secondInput); unsqueeze->set_friendly_name(parentNode->getName() + "_abc_a1bc_" + std::to_string(j)); const auto cpuUnsqueeze = std::make_shared(unsqueeze, graph.getGraphContext()); @@ -2758,7 +2829,7 @@ void GraphOptimizer::RemoveSameConvert(Graph& graph) { } } -void GraphOptimizer::RemoveMemoryInputConvert(Graph &graph) { +void GraphOptimizer::RemoveMemoryInputConvert(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableNode = [](const NodePtr& node) { @@ -2784,7 +2855,7 @@ void GraphOptimizer::RemoveMemoryInputConvert(Graph &graph) { } } -void GraphOptimizer::RemoveConvertMemoryOutput(Graph &graph) { +void GraphOptimizer::RemoveConvertMemoryOutput(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableNode = [](const NodePtr& node) { @@ -2812,7 +2883,7 @@ void GraphOptimizer::RemoveConvertMemoryOutput(Graph &graph) { } } -void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { +void GraphOptimizer::MatchSdpaKvCache(Graph& graph) { auto& graphNodes = graph.GetNodes(); auto isSuitableMemInput = [](const NodePtr& node) -> bool { @@ -2829,7 +2900,7 @@ void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { if (Type::ScaledDotProductAttention == childNode->getType()) { if (childSdpa && childSdpa != childNode) { - //only one child SDPA supported + // only one child SDPA supported return false; } childSdpa = childNode; @@ -2872,7 +2943,7 @@ void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { input_prc = ov::optional(node->getOriginalInputPrecisionAtPort(0)); } - //search for SDPA + // search for SDPA std::shared_ptr sdpa; for (auto&& edge : node->getChildEdgesAtPort(0)) { auto child = edge->getChild(); @@ -2886,19 +2957,18 @@ void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { } } - //capture reference to the original mem output before graph transformations + // capture reference to the original mem output before graph transformations auto& memOutput = memInputNode->getOutputNode(); - auto memInputSdpa = std::make_shared( - memInputNode->getId(), - memInputNode->getName(), - memInputNode->getTypeStr(), - memInputNode->getOutputShapeAtPort(0), - 
memInputNode->getOriginalOutputPrecisionAtPort(0), - graph.getGraphContext(), - input_shape, - input_prc, - sdpa); + auto memInputSdpa = std::make_shared(memInputNode->getId(), + memInputNode->getName(), + memInputNode->getTypeStr(), + memInputNode->getOutputShapeAtPort(0), + memInputNode->getOriginalOutputPrecisionAtPort(0), + graph.getGraphContext(), + input_shape, + input_prc, + sdpa); if (!memInputNode->getParentEdges().empty()) { auto parentEdge = memInputNode->getParentEdgeAt(0); @@ -2915,14 +2985,13 @@ void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { graph.CreateEdge(memInputSdpa, child, 0, outputNum); } - //create a stub memory output - auto memOutputStub = std::make_shared( - memOutput.getId(), - memOutput.getName(), - memOutput.getTypeStr(), - memOutput.getInputShapeAtPort(0), - memOutput.getOriginalInputPrecisionAtPort(0), - graph.getGraphContext()); + // create a stub memory output + auto memOutputStub = std::make_shared(memOutput.getId(), + memOutput.getName(), + memOutput.getTypeStr(), + memOutput.getInputShapeAtPort(0), + memOutput.getOriginalInputPrecisionAtPort(0), + graph.getGraphContext()); auto memOutputEdge = memOutput.getParentEdgeAt(0); const auto inputNum = memOutputEdge->getInputNum(); @@ -2934,7 +3003,7 @@ void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { } } -void GraphOptimizer::DropRedundantMemoryOutput(Graph &graph) { +void GraphOptimizer::DropRedundantMemoryOutput(Graph& graph) { // When we have a MemoryInput->MemoryOutput pair, that means that the state is immediately populated with the init // subgraph values when the init subgraph exists. In all the other cases the state is simply a read only object. // We can optimize such a case removing the MemoryOutput node and transferring the state values update @@ -2975,7 +3044,7 @@ void GraphOptimizer::DropRedundantMemoryOutput(Graph &graph) { } if (MemoryOutput && MemoryOutput != childNode) { - //only one child MemoryOutput is expected + // only one child MemoryOutput is expected return false; } MemoryOutput = childNode; @@ -3003,7 +3072,7 @@ void GraphOptimizer::DropRedundantMemoryOutput(Graph &graph) { inputPrc = ov::optional(node->getOriginalInputPrecisionAtPort(0)); } - //search for the MemoryOutputNode + // search for the MemoryOutputNode NodePtr memoryOutputNode; for (auto&& edge : node->getChildEdgesAtPort(0)) { auto child = edge->getChild(); @@ -3046,5 +3115,5 @@ void GraphOptimizer::DropRedundantMemoryOutput(Graph &graph) { } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/graph_optimizer.h b/src/plugins/intel_cpu/src/graph_optimizer.h index 536ef468a09816..90cf9c41c0907e 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.h +++ b/src/plugins/intel_cpu/src/graph_optimizer.h @@ -16,42 +16,42 @@ class GraphOptimizer { public: void ApplyCommonGraphOptimizations(Graph& graph); void ApplyImplSpecificGraphOptimizations(Graph& graph); - void ShareReorders(Graph &graph); + void ShareReorders(Graph& graph); private: - void FuseConvMatmulFCDeconvAndDQScales(Graph &graph); - void FuseConvolutionMatMulDeconvAndBias(Graph &graph); - void FuseDeconvolutionAndSimpleOperation(Graph &graph); - void FuseMultiplyAndAdd(Graph &graph); + void FuseConvMatmulFCDeconvAndDQScales(Graph& graph); + void FuseConvolutionMatMulDeconvAndBias(Graph& graph); + void FuseDeconvolutionAndSimpleOperation(Graph& graph); + void FuseMultiplyAndAdd(Graph& graph); void MergeConvertAndScaleShift(Graph& graph); void FuseFCAndConvertOnWeights(Graph& 
graph); void FuseFCAndTransposeOnWeights(Graph& graph); - void FuseFullyConnectedAndSimpleOperation(Graph &graph); - void FuseMatMulAndSimpleOperation(Graph &graph); - void FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &graph); - void FuseConvolutionAndSimpleOperation(Graph &graph); - void FuseConvolutionAndDWConvolution(Graph &graph); - void FusePoolingAndFakeQuantize(Graph &graph); - void FuseConvolutionSumAndConvolutionSumActivation(Graph &graph); - void FuseMVNAndSimpleOperation(Graph &graph); - void FuseInterpolateAndSimpleOperation(Graph &graph); - void FuseNormalizeL2AndSimpleOperation(Graph &graph); - void FuseReduceAndSimpleOperation(Graph &graph); + void FuseFullyConnectedAndSimpleOperation(Graph& graph); + void FuseMatMulAndSimpleOperation(Graph& graph); + void FuseConvolutionAndSimpleOperationThroughMaxPool(Graph& graph); + void FuseConvolutionAndSimpleOperation(Graph& graph); + void FuseConvolutionAndDWConvolution(Graph& graph); + void FusePoolingAndFakeQuantize(Graph& graph); + void FuseConvolutionSumAndConvolutionSumActivation(Graph& graph); + void FuseMVNAndSimpleOperation(Graph& graph); + void FuseInterpolateAndSimpleOperation(Graph& graph); + void FuseNormalizeL2AndSimpleOperation(Graph& graph); + void FuseReduceAndSimpleOperation(Graph& graph); void DropDoubleReorders(Graph& graph); - void FuseConvolutionAndZeroPoints(Graph &graph); - void FuseBroadcastAndEltwise(Graph &graph); - void FuseEltwiseAndSimple(Graph &graph); - void FusePerformedAsScaleShiftAndFakeQuantize(Graph &graph); - void FuseClampAndFakeQuantize(Graph &graph); - void MergeTransposeAndReorder(Graph &graph); - void MergeReorderAndTranspose(Graph &graph); - void reshapeRnnSeq(Graph &graph); - void RemoveSameConvert(Graph &graph); - void RemoveMemoryInputConvert(Graph &graph); - void RemoveConvertMemoryOutput(Graph &graph); - void MatchSdpaKvCache(Graph &graph); - void DropRedundantMemoryOutput(Graph &graph); + void FuseConvolutionAndZeroPoints(Graph& graph); + void FuseBroadcastAndEltwise(Graph& graph); + void FuseEltwiseAndSimple(Graph& graph); + void FusePerformedAsScaleShiftAndFakeQuantize(Graph& graph); + void FuseClampAndFakeQuantize(Graph& graph); + void MergeTransposeAndReorder(Graph& graph); + void MergeReorderAndTranspose(Graph& graph); + void reshapeRnnSeq(Graph& graph); + void RemoveSameConvert(Graph& graph); + void RemoveMemoryInputConvert(Graph& graph); + void RemoveConvertMemoryOutput(Graph& graph); + void MatchSdpaKvCache(Graph& graph); + void DropRedundantMemoryOutput(Graph& graph); bool canBeInplaced(const NodePtr& parentNode, const NodePtr& childNode); // Method checks that after the sequential execution of Transpose and Reorder nodes, @@ -68,19 +68,22 @@ class GraphOptimizer { // Examples: // 1. Direct order, no Reshape node. // Before: [N,C,H,W]abcd==>Transpose(0312)==>[N,W,C,H]abcd==>Reorder(abcd->acdb)==>[N,W,C,H]acdb - // [N,C,H,W]abcd is equivalent to the [N,W,C,H]acdb, so the Transpose and Reorder can be fused into single optimized Reorder: - // After: [N,C,H,W]abcd==>Reorder(abcd->acdb, isOptimized=true)==>[N,W,C,H]acdb + // [N,C,H,W]abcd is equivalent to the [N,W,C,H]acdb, so the Transpose and Reorder can be fused into single + // optimized Reorder: After: [N,C,H,W]abcd==>Reorder(abcd->acdb, isOptimized=true)==>[N,W,C,H]acdb // 2. Reverse order, no Reshape node. 
// Before: [N,W,C,H]acdb==>Reorder(acdb->abcd)==>[N,W,C,H]abcd==>Transpose(0231)==>[N,C,H,W]abcd - [N,W,C,H]acdb is equivalent to the [N,C,H,W]abcd, so the Transpose and Reorder can be fused into single optimized Reorder: - After: [N,W,C,H]acdb==>Reorder(acdb->abcd, isOptimized=true)==>[N,C,H,W]abcd + [N,W,C,H]acdb is equivalent to the [N,C,H,W]abcd, so the Transpose and Reorder can be fused into single + // optimized Reorder: After: [N,W,C,H]acdb==>Reorder(acdb->abcd, isOptimized=true)==>[N,C,H,W]abcd // 3. Direct order with Reshape node (L = H x W). - Before: [N,L,C]abc==>Transpose(021)==>[N,C,L]abc==>Reshape==>[N,C,H,W]abcd==>Reorder(abcd->acdb)==>[N,C,H,W]acdb - After: [N,L,C]abc==>Reorder(abc->acdb, isOptimized=true)==>[N,C,H,W]acdb + // Before: + // [N,L,C]abc==>Transpose(021)==>[N,C,L]abc==>Reshape==>[N,C,H,W]abcd==>Reorder(abcd->acdb)==>[N,C,H,W]acdb After: + // [N,L,C]abc==>Reorder(abc->acdb, isOptimized=true)==>[N,C,H,W]acdb // 4. Reverse order with Reshape node (L = H x W). - Before: [N,C,H,W]acdb==>Reorder(acdb->abcd)==>[N,C,H,W]abcd==>Reshape==>[N,C,L]abc==>Transpose(021)==>[N,L,C]abc + // Before: + // [N,C,H,W]acdb==>Reorder(acdb->abcd)==>[N,C,H,W]abcd==>Reshape==>[N,C,L]abc==>Transpose(021)==>[N,L,C]abc // After: [N,C,H,W]acdb==>Reorder(acdb->abc, isOptimized=true)==>[N,L,C]abc - // Note: in some cases (inplace conflicts or transpose with blocked input and non-blocked output) the merged Reorder can not be optimized. + // Note: in some cases (inplace conflicts or transpose with blocked input and non-blocked output) the merged Reorder + // can not be optimized. void mergeTransposeReshapeReorder(Graph& graph, const NodePtr& transposeNode, const NodePtr& reshapeNode, @@ -88,5 +91,5 @@ class GraphOptimizer { const bool reverseOrder); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/infer_request.cpp b/src/plugins/intel_cpu/src/infer_request.cpp index 26cdaf0860168a..3cfc34589623d2 100644 --- a/src/plugins/intel_cpu/src/infer_request.cpp +++ b/src/plugins/intel_cpu/src/infer_request.cpp @@ -8,17 +8,17 @@ #include "compiled_model.h" #include "dnnl_extension_utils.h" #include "itt.h" +#include "memory_desc/cpu_memory_desc_utils.h" #include "memory_state.h" #include "nodes/common/cpu_convert.h" -#include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/memory_state_base.h" #include "openvino/core/shape.hpp" #include "openvino/runtime/make_tensor.hpp" #include "openvino/runtime/tensor.hpp" +#include "openvino/runtime/threading/cpu_message.hpp" #include "proxy_mem_blk.h" #include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" -#include "openvino/runtime/threading/cpu_message.hpp" using OvString = ov::element_type_traits::value_type; @@ -56,7 +56,7 @@ void SyncInferRequest::create_infer_request() { init_tensor(it.first, ov::ISyncInferRequest::FoundPort::Type::OUTPUT); } - //create states according to the list of the MemoryStateNodes + // create states according to the list of the MemoryStateNodes for (auto&& node : m_graph->getInternalStateNodes()) { m_memory_states.emplace_back(node.second->makeState()); } @@ -162,7 +162,7 @@ static inline void change_edge_ptr(const EdgePtr& edge, ov::SoPtr& OPENVINO_ASSERT(mem != nullptr, "Edge with name '", *edge, "' doesn't have allocated memory object."); if (tensor->get_element_type() == element::string) { - auto memBlock = 
dynamic_cast(mem.get())->getStringMemoryBlockPtr(); OPENVINO_ASSERT(memBlock); memBlock->setExtBuff(tensor->data(), tensor->get_size()); } else { @@ -177,14 +177,14 @@ void SyncInferRequest::change_default_ptr() { const auto& outputNodesMap = m_graph->GetOutputNodesMap(); std::unordered_set inputPtrs; - std::function& tensor)> changeInpPtr; + std::function& tensor)> changeInpPtr; if (m_graph->IsDynamic()) { - changeInpPtr = [&inputPtrs](const EdgePtr &edge, ov::SoPtr& tensor) { + changeInpPtr = [&inputPtrs](const EdgePtr& edge, ov::SoPtr& tensor) { change_edge_ptr(edge, tensor); inputPtrs.insert(tensor->data()); }; } else { - changeInpPtr = [](const EdgePtr &edge, ov::SoPtr& tensor) { + changeInpPtr = [](const EdgePtr& edge, ov::SoPtr& tensor) { change_edge_ptr(edge, tensor); }; } @@ -279,7 +279,7 @@ void SyncInferRequest::change_default_ptr() { } if (m_graph->IsDynamic()) { - const auto &outMemBlocksMap = m_graph->getOutputNodesMemBlocksMap(); + const auto& outMemBlocksMap = m_graph->getOutputNodesMemBlocksMap(); for (auto&& item : outMemBlocksMap) { const auto& name = item.first; @@ -291,20 +291,32 @@ void SyncInferRequest::change_default_ptr() { if (controlBlockItr != m_outputControlBlocks.end()) { auto output = outputNodesMap.find(name); - OPENVINO_ASSERT(outputNodesMap.end() != output, "Node with name: ", name, " is absent in the outputNodesMap"); + OPENVINO_ASSERT(outputNodesMap.end() != output, + "Node with name: ", + name, + " is absent in the outputNodesMap"); auto parentEdge = output->second->getParentEdgeAt(0); - //avoid cyclic memory use + // avoid cyclic memory use auto&& controlBlock = controlBlockItr->second; - std::shared_ptr memBlock = inputPtrs.count(controlBlock.rawPtr()) ? // same memory is used on the input and output - controlBlock.nextMemBlock() : // then swap internal buffer to avoid data corruption - controlBlock.currentMemBlock(); // else reuse the existing buffer + std::shared_ptr memBlock = + inputPtrs.count(controlBlock.rawPtr()) ? 
// same memory is used on the input and output + controlBlock.nextMemBlock() + : // then swap internal buffer to avoid data corruption + controlBlock.currentMemBlock(); // else reuse the existing buffer outputMemBlock->setMemBlockResize(memBlock); - DEBUG_LOG("reset proxy ", outputMemBlock, ", actual ", controlBlock.currentMemBlock(), " graph ", m_graph, " inferrequest ", this); + DEBUG_LOG("reset proxy ", + outputMemBlock, + ", actual ", + controlBlock.currentMemBlock(), + " graph ", + m_graph, + " inferrequest ", + this); DEBUG_LOG(name, ", tensor ", controlBlock.tensor()); } else { - outputMemBlock->reset(); // switch to the internal memory since memory sharing is no longer possible + outputMemBlock->reset(); // switch to the internal memory since memory sharing is no longer possible } } } @@ -456,12 +468,13 @@ void SyncInferRequest::set_tensor(const ov::Output& in_port, con } m_outputs[output_index] = tensor; - m_outputControlBlocks.erase(output_index); // now the memory is under user's control + m_outputControlBlocks.erase(output_index); // now the memory is under user's control } ov::ISyncInferRequest::set_tensor(port, tensor); } -void SyncInferRequest::set_tensors_impl(const ov::Output port, const std::vector>& tensors) { +void SyncInferRequest::set_tensors_impl(const ov::Output port, + const std::vector>& tensors) { if (find_port(port).is_input()) { m_batched_tensors[port.get_tensor_ptr()] = tensors; return; @@ -541,7 +554,8 @@ void SyncInferRequest::init_tensor(const std::size_t& port_index, const ov::ISyn } dnnl::engine eng(dnnl::engine::kind::cpu, 0); - CpuBlockedMemoryDescPtr desc = std::make_shared(model_prec, Shape{memDims}); + CpuBlockedMemoryDescPtr desc = + std::make_shared(model_prec, Shape{memDims}); auto memory = std::make_shared(eng, desc); tensor = std::make_shared(memory); @@ -551,12 +565,12 @@ void SyncInferRequest::init_tensor(const std::size_t& port_index, const ov::ISyn OutputControlBlock control_block{model_prec, Shape{shape}}; DEBUG_LOG(port_index, - ", tensor ", - control_block.tensor(), - ", memBlock ", - control_block.tensor()->get_memory()->getMemoryBlock(), - "memory object ", - control_block.tensor()->get_memory().get()); + ", tensor ", + control_block.tensor(), + ", memBlock ", + control_block.tensor()->get_memory()->getMemoryBlock(), + "memory object ", + control_block.tensor()->get_memory().get()); tensor = control_block.tensor(); if (model_prec == graph_prec) @@ -602,7 +616,7 @@ SyncInferRequest::OutputControlBlock::OutputControlBlock(const ov::element::Type m_proxyMemBlock = std::make_shared(m_buffers[m_buffIndx]); VectorDims memDims; - if (shape.isDynamic()) { // this is a WA since the ITensor doesn't allow dyn shapes + if (shape.isDynamic()) { // this is a WA since the ITensor doesn't allow dyn shapes for (auto&& item : shape.getDims()) { memDims.push_back(item != Shape::UNDEFINED_DIM ? 
item : 0); } @@ -610,8 +624,7 @@ SyncInferRequest::OutputControlBlock::OutputControlBlock(const ov::element::Type memDims = shape.getStaticDims(); } - CpuBlockedMemoryDescPtr desc = - std::make_shared(precision, Shape{memDims}); + CpuBlockedMemoryDescPtr desc = std::make_shared(precision, Shape{memDims}); auto memory = std::make_shared(eng, desc, m_proxyMemBlock); m_tensor = std::make_shared(memory); @@ -649,6 +662,5 @@ void SyncInferRequest::sub_streams_infer() { } } -} // namespace intel_cpu -} // namespace ov - +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/infer_request.h b/src/plugins/intel_cpu/src/infer_request.h index a9def63d359744..b66387ecc4d4d5 100644 --- a/src/plugins/intel_cpu/src/infer_request.h +++ b/src/plugins/intel_cpu/src/infer_request.h @@ -4,11 +4,11 @@ #pragma once -#include "graph.h" #include "cpu_tensor.h" +#include "graph.h" +#include "memory_state.h" #include "openvino/runtime/iinfer_request.hpp" #include "openvino/runtime/isync_infer_request.hpp" -#include "memory_state.h" namespace ov { namespace intel_cpu { @@ -29,7 +29,8 @@ class SyncInferRequest : public ov::ISyncInferRequest { void set_tensor(const ov::Output& port, const ov::SoPtr& tensor) override; - void set_tensors_impl(const ov::Output port, const std::vector>& tensors) override; + void set_tensors_impl(const ov::Output port, + const std::vector>& tensors) override; ov::SoPtr get_tensor(const ov::Output& port) const override; std::vector> get_tensors(const ov::Output& _port) const override; diff --git a/src/plugins/intel_cpu/src/memory_control.cpp b/src/plugins/intel_cpu/src/memory_control.cpp index 0f202c296891c1..26cd8459458b9d 100644 --- a/src/plugins/intel_cpu/src/memory_control.cpp +++ b/src/plugins/intel_cpu/src/memory_control.cpp @@ -16,8 +16,7 @@ namespace { class StaticPartitionMemoryBlock : public IMemoryBlockObserver { public: - StaticPartitionMemoryBlock(MemoryBlockPtr pBlock, ptrdiff_t offset) - : m_pBlock(pBlock), m_offset(offset) { + StaticPartitionMemoryBlock(MemoryBlockPtr pBlock, ptrdiff_t offset) : m_pBlock(pBlock), m_offset(offset) { OPENVINO_ASSERT(m_pBlock, "Memory block is uninitialized"); } @@ -92,7 +91,7 @@ class IMemoryManager { using MemoryManagerPtr = std::shared_ptr; -template +template std::shared_ptr makeDnnlMemoryBlock(Args&&... 
args) { return std::make_shared(make_unique(std::forward(args)...)); } @@ -152,10 +151,12 @@ class MemoryManagerStatic : public IMemoryManager { } void allocate() override { - if (m_workspace) m_workspace->resize(m_totalSize); + if (m_workspace) + m_workspace->resize(m_totalSize); } void release() override { - if (m_workspace) m_workspace->free(); + if (m_workspace) + m_workspace->free(); } private: @@ -171,14 +172,13 @@ class MemoryManageNonOverlapingSets : public IMemoryManager { void insert(const MemoryRegion& reg) override { MemorySolver::Box box = {reg.start, reg.finish, reg.size, reg.id}; if (-1 != reg.finish) { - //We have to extend the lifespan of tensors that are crossing a sync point border in order to save - //the intermediate computation results from possible loss due to the tensor resize - auto itr_upper = - std::upper_bound(m_syncInds.begin(), m_syncInds.end(), box.finish, [](int y, int x) { - return y <= x; - }); + // We have to extend the lifespan of tensors that are crossing a sync point border in order to save + // the intermediate computation results from possible loss due to the tensor resize + auto itr_upper = std::upper_bound(m_syncInds.begin(), m_syncInds.end(), box.finish, [](int y, int x) { + return y <= x; + }); auto itr_lower = std::lower_bound(m_syncInds.begin(), m_syncInds.end(), box.start); - if (itr_lower != itr_upper) { // across sections + if (itr_lower != itr_upper) { // across sections if (itr_upper == m_syncInds.end()) { box.finish = -1; } else { @@ -201,7 +201,7 @@ class MemoryManageNonOverlapingSets : public IMemoryManager { void solve() { ov::MemorySolver::normalize_boxes(m_boxes); - std::vector> groups; //groups of nonoverlapping boxes + std::vector> groups; // groups of nonoverlapping boxes groups.push_back({m_boxes.front()}); for (size_t i = 1; i < m_boxes.size(); ++i) { const auto& box = m_boxes[i]; @@ -229,7 +229,7 @@ class MemoryManageNonOverlapingSets : public IMemoryManager { } void allocate() override { - //nothing to do + // nothing to do } void release() override { for (auto&& item : m_internalBlocks) { @@ -305,15 +305,17 @@ MemoryControl::MemoryControl(std::vector syncInds) { })); // handler for static tensors - m_handlers.emplace_back(buildHandler([](const MemoryRegion& reg) { - if (reg.size >= 0 || MemoryRegion::RegionType::VARIABLE != reg.type || - MemoryRegion::AllocType::POD != reg.alloc_type) { - return false; - } - return true; - }, std::move(syncInds))); + m_handlers.emplace_back(buildHandler( + [](const MemoryRegion& reg) { + if (reg.size >= 0 || MemoryRegion::RegionType::VARIABLE != reg.type || + MemoryRegion::AllocType::POD != reg.alloc_type) { + return false; + } + return true; + }, + std::move(syncInds))); - //handler for I/O tensors, so far simply individual blocks + // handler for I/O tensors, so far simply individual blocks m_handlers.emplace_back(buildHandler([](const MemoryRegion& reg) { if (MemoryRegion::RegionType::VARIABLE == reg.type || reg.alloc_type != MemoryRegion::AllocType::POD) { return false; diff --git a/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.cpp b/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.cpp index 4b75d5c5263398..7dff6905df09d9 100644 --- a/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.cpp +++ b/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.cpp @@ -15,9 +15,9 @@ namespace intel_cpu { constexpr BlockedMemoryDesc::CmpMask BlockedMemoryDesc::FULL_MASK; constexpr BlockedMemoryDesc::CmpMask BlockedMemoryDesc::EMPTY_MASK; constexpr 
BlockedMemoryDesc::CmpMask BlockedMemoryDesc::SKIP_OFFSET_MASK; -constexpr size_t BlockedMemoryDesc::OFFSET_MASK_POS; +constexpr size_t BlockedMemoryDesc::OFFSET_MASK_POS; -bool BlockedMemoryDesc::isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const { +bool BlockedMemoryDesc::isCompatibleInternal(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const { if (this->getShape() != rhs.getShape() || this->getPrecision() != rhs.getPrecision()) return false; @@ -77,5 +77,5 @@ std::string BlockedMemoryDesc::serializeFormat() const { return result.str(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.h b/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.h index d938a4ba585602..9ff132965bdc0b 100644 --- a/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.h +++ b/src/plugins/intel_cpu/src/memory_desc/blocked_memory_desc.h @@ -21,7 +21,7 @@ class BlockedMemoryDesc : public virtual MemoryDesc { static constexpr CmpMask FULL_MASK{0xffffffff}; static constexpr CmpMask EMPTY_MASK{0x0}; static constexpr CmpMask SKIP_OFFSET_MASK{0x7fffffff}; - static constexpr size_t OFFSET_MASK_POS{31}; + static constexpr size_t OFFSET_MASK_POS{31}; /** * @brief Returns the blocked dimensions @@ -73,7 +73,7 @@ class BlockedMemoryDesc : public virtual MemoryDesc { * * @return the result of the compatibility check */ - virtual bool isCompatible(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const = 0; + virtual bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const = 0; using MemoryDesc::isCompatible; ~BlockedMemoryDesc() override = default; @@ -88,7 +88,7 @@ class BlockedMemoryDesc : public virtual MemoryDesc { * Doesn't perform descs specific attributes check * @return true if compatible, otherwise false */ - bool isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMask cmpMask = FULL_MASK) const; + bool isCompatibleInternal(const BlockedMemoryDesc& rhs, CmpMask cmpMask = FULL_MASK) const; mutable VectorDims blockedDims; mutable VectorDims strides; @@ -99,5 +99,5 @@ class BlockedMemoryDesc : public virtual MemoryDesc { using BlockedMemoryDescPtr = std::shared_ptr; using BlockedMemoryDescCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp b/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp index d1c50d0048c57d..c95463207a9c46 100644 --- a/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp +++ b/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.cpp @@ -3,6 +3,7 @@ // #include "cpu_blocked_memory_desc.h" + #include "dnnl_blocked_memory_desc.h" #include "utils/general_utils.h" @@ -15,17 +16,27 @@ static VectorDims makeRange(size_t size) { return retVec; } -CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& shape) : - CpuBlockedMemoryDesc(prc, shape, shape.getDims(), makeRange(shape.getDims().size())) {} - -CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& shape, const VectorDims& blockedDims, - const VectorDims& order, size_t offsetPadding, const VectorDims& offsetPaddingToData, - const VectorDims& strides) : MemoryDesc(shape, Blocked), precision(prc) { - if (std::any_of(order.begin(), order.end(), [](size_t val) { return val == Shape::UNDEFINED_DIM; })) { +CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, 
const Shape& shape) + : CpuBlockedMemoryDesc(prc, shape, shape.getDims(), makeRange(shape.getDims().size())) {} + +CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, + const Shape& shape, + const VectorDims& blockedDims, + const VectorDims& order, + size_t offsetPadding, + const VectorDims& offsetPaddingToData, + const VectorDims& strides) + : MemoryDesc(shape, Blocked), + precision(prc) { + if (std::any_of(order.begin(), order.end(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + })) { OPENVINO_THROW("CpuBlockedMemoryDesc do not support undefined order."); } - if (std::any_of(blockedDims.begin() + shape.getRank(), blockedDims.end(), [](size_t val) { return val == Shape::UNDEFINED_DIM; })) { + if (std::any_of(blockedDims.begin() + shape.getRank(), blockedDims.end(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + })) { OPENVINO_THROW("CpuBlockedMemoryDesc doesn't support undefined blockedDims."); } @@ -51,29 +62,43 @@ CpuBlockedMemoryDesc::CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& s if (strides.empty() && !order.empty()) { if (shape.hasZeroDims()) { this->strides.resize(order.size(), 0); - } else if (std::any_of(this->blockedDims.begin(), this->blockedDims.end(), [](size_t val) { return val == Shape::UNDEFINED_DIM; })) { + } else if (std::any_of(this->blockedDims.begin(), this->blockedDims.end(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + })) { this->strides.resize(order.size(), Shape::UNDEFINED_DIM); } else { this->strides.resize(order.size(), 1); for (size_t i = 2; i <= order.size(); i++) { - this->strides[order.size() - i] = this->strides[order.size() - (i - 1)] * this->blockedDims[blockedDims.size() - (i - 1)]; + this->strides[order.size() - i] = + this->strides[order.size() - (i - 1)] * this->blockedDims[blockedDims.size() - (i - 1)]; } } } else { this->strides = strides; } - if (!everyone_is(this->order.size(), this->blockedDims.size(), this->offsetPaddingToData.size(), this->strides.size())) { + if (!everyone_is(this->order.size(), + this->blockedDims.size(), + this->offsetPaddingToData.size(), + this->strides.size())) { OPENVINO_THROW("Order, blocked dims, offset padding to data and strides must have equals size"); } } bool CpuBlockedMemoryDesc::isDefinedImp() const { bool defined = true; - defined = defined && std::none_of(blockedDims.cbegin(), blockedDims.cend(), [](size_t val) { return val == Shape::UNDEFINED_DIM; }); - defined = defined && std::none_of(strides.cbegin(), strides.cend(), [](size_t val) { return val == Shape::UNDEFINED_DIM; }); - defined = defined && std::none_of(order.cbegin(), order.cend(), [](size_t val) { return val == Shape::UNDEFINED_DIM; }); - defined = defined && std::none_of(offsetPaddingToData.cbegin(), offsetPaddingToData.cend(), [](size_t val) { return val == Shape::UNDEFINED_DIM; }); + defined = defined && std::none_of(blockedDims.cbegin(), blockedDims.cend(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + }); + defined = defined && std::none_of(strides.cbegin(), strides.cend(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + }); + defined = defined && std::none_of(order.cbegin(), order.cend(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + }); + defined = defined && std::none_of(offsetPaddingToData.cbegin(), offsetPaddingToData.cend(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + }); defined = defined && offsetPadding != Shape::UNDEFINED_DIM; return defined; @@ -90,15 +115,15 @@ bool CpuBlockedMemoryDesc::isCompatible(const MemoryDesc& rhs) const { 
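// ---------------------------------------------------------------------------------------------
// [Editor's note: illustrative only.] CmpMask in the isCompatible() overloads below is a
// bitmask selecting which attributes of a blocked descriptor take part in the comparison.
// Bit 31 (OFFSET_MASK_POS) gates the offset check, which is why SKIP_OFFSET_MASK equals
// 0x7fffffff. A sketch of the assumed usage (descA/descB are hypothetical):
//
//     // compare two descriptors while ignoring their offsets
//     bool ok = descA->isCompatible(*descB, BlockedMemoryDesc::SKIP_OFFSET_MASK);
//
// ---------------------------------------------------------------------------------------------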
} } -bool CpuBlockedMemoryDesc::isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask) const { +bool CpuBlockedMemoryDesc::isCompatible(const CpuBlockedMemoryDesc& rhs, CmpMask cmpMask) const { return BlockedMemoryDesc::isCompatibleInternal(rhs, cmpMask); } -bool CpuBlockedMemoryDesc::isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask) const { +bool CpuBlockedMemoryDesc::isCompatible(const DnnlBlockedMemoryDesc& rhs, CmpMask cmpMask) const { return rhs.isCompatible(*this, cmpMask); } -bool CpuBlockedMemoryDesc::isCompatible(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const { +bool CpuBlockedMemoryDesc::isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const { const BlockedMemoryDesc* pRhs = &rhs; if (auto cpuBlkDesc = dynamic_cast(pRhs)) { return isCompatible(*cpuBlkDesc, cmpMask); @@ -149,7 +174,9 @@ size_t CpuBlockedMemoryDesc::getMaxMemSize() const { } const auto& maxDims = shape.getMaxDims(); - if (std::any_of(maxDims.begin(), maxDims.end(), [](size_t x){ return Shape::UNDEFINED_DIM == x; })) { + if (std::any_of(maxDims.begin(), maxDims.end(), [](size_t x) { + return Shape::UNDEFINED_DIM == x; + })) { return UNDEFINED_SIZE; } @@ -193,16 +220,16 @@ size_t CpuBlockedMemoryDesc::getElementOffset(size_t elemNumber) const { bool CpuBlockedMemoryDesc::hasLayoutType(LayoutType layoutType) const { switch (layoutType) { - case LayoutType::ncsp: - return isPlainFormat(); - case LayoutType::nspc: - return isTailCFormat(); - case LayoutType::nCsp8c: - return isBlockedCFormat(8); - case LayoutType::nCsp16c: - return isBlockedCFormat(16); - default: - return false; + case LayoutType::ncsp: + return isPlainFormat(); + case LayoutType::nspc: + return isTailCFormat(); + case LayoutType::nCsp8c: + return isBlockedCFormat(8); + case LayoutType::nCsp16c: + return isBlockedCFormat(16); + default: + return false; } } @@ -252,13 +279,15 @@ bool CpuBlockedMemoryDesc::isTailCFormat() const { return true; } -MemoryDescPtr CpuBlockedMemoryDesc::cloneWithNewDimsImp(const VectorDims &dims) const { - if (std::any_of(dims.begin(), dims.end(), [](size_t x){ return Shape::UNDEFINED_DIM == x; })) { +MemoryDescPtr CpuBlockedMemoryDesc::cloneWithNewDimsImp(const VectorDims& dims) const { + if (std::any_of(dims.begin(), dims.end(), [](size_t x) { + return Shape::UNDEFINED_DIM == x; + })) { OPENVINO_THROW("Can't clone desc if new dims are undefined"); } // TODO [DS]: add stride recalculation for strided blobs - for (int i = strides.size() - 2; i >= 0 ; i--) { + for (int i = strides.size() - 2; i >= 0; i--) { if (strides[i] == Shape::UNDEFINED_DIM) break; @@ -280,11 +309,18 @@ MemoryDescPtr CpuBlockedMemoryDesc::cloneWithNewDimsImp(const VectorDims &dims) } VectorDims newOffsetPaddingToData; - if (std::none_of(offsetPaddingToData.begin(), offsetPaddingToData.end(), [](size_t x){ return x == Shape::UNDEFINED_DIM;})) { + if (std::none_of(offsetPaddingToData.begin(), offsetPaddingToData.end(), [](size_t x) { + return x == Shape::UNDEFINED_DIM; + })) { newOffsetPaddingToData = offsetPaddingToData; } - return std::make_shared(precision, Shape(dims), newBlockedDims, order, offsetPadding, newOffsetPaddingToData); + return std::make_shared(precision, + Shape(dims), + newBlockedDims, + order, + offsetPadding, + newOffsetPaddingToData); } bool CpuBlockedMemoryDesc::blocksExtended() const { @@ -311,7 +347,9 @@ size_t CpuBlockedMemoryDesc::getPaddedElementsCount() const { if (getShape().hasZeroDims()) { return 0; } - if (std::any_of(blockedDims.begin(), blockedDims.end(), [](Dim dim) { return dim == 
Shape::UNDEFINED_DIM; })) { + if (std::any_of(blockedDims.begin(), blockedDims.end(), [](Dim dim) { + return dim == Shape::UNDEFINED_DIM; + })) { OPENVINO_THROW("Can't compute padded elements count for non undefined blocked dims"); } return std::accumulate(blockedDims.begin(), blockedDims.end(), size_t{1}, std::multiplies()); @@ -323,5 +361,5 @@ MemoryDescPtr CpuBlockedMemoryDesc::cloneWithNewPrecision(const ov::element::Typ return newDesc; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.h b/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.h index 28badb4dac15fb..fdf931a262e854 100644 --- a/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.h +++ b/src/plugins/intel_cpu/src/memory_desc/cpu_blocked_memory_desc.h @@ -16,8 +16,12 @@ class CpuBlockedMemoryDesc : public BlockedMemoryDesc { public: CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& shape); - CpuBlockedMemoryDesc(ov::element::Type prc, const Shape& shape, const VectorDims& blockedDims, - const VectorDims& order, size_t offsetPadding = 0, const VectorDims& offsetPaddingToData = {}, + CpuBlockedMemoryDesc(ov::element::Type prc, + const Shape& shape, + const VectorDims& blockedDims, + const VectorDims& order, + size_t offsetPadding = 0, + const VectorDims& offsetPaddingToData = {}, const VectorDims& strides = {}); MemoryDescPtr clone() const override { @@ -26,8 +30,8 @@ class CpuBlockedMemoryDesc : public BlockedMemoryDesc { bool isCompatible(const MemoryDesc& rhs) const override; bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const override; - bool isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK) const; - bool isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK) const; + bool isCompatible(const CpuBlockedMemoryDesc& rhs, CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK) const; + bool isCompatible(const DnnlBlockedMemoryDesc& rhs, CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK) const; ov::element::Type getPrecision() const override { return precision; @@ -105,5 +109,5 @@ class CpuBlockedMemoryDesc : public BlockedMemoryDesc { using CpuBlockedMemoryDescPtr = std::shared_ptr; using CpuBlockedMemoryDescCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc.h b/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc.h index c3936528abed7b..e6d260066118ee 100644 --- a/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc.h +++ b/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc.h @@ -22,7 +22,7 @@ namespace ov { namespace intel_cpu { namespace node { class Split; -} // namespace node +} // namespace node class MemoryDesc; @@ -39,10 +39,10 @@ enum MemoryDescType { }; enum class LayoutType : unsigned { - nspc, // general per channels format - ncsp, // general planar - nCsp8c, // general channels blocked by 8 - nCsp16c // general channels blocked by 16 + nspc, // general per channels format + ncsp, // general planar + nCsp8c, // general channels blocked by 8 + nCsp16c // general channels blocked by 16 }; class MemoryDesc { @@ -70,8 +70,8 @@ class MemoryDesc { /** * @brief Clone descriptor with new dims. - * Throws an exception if relaxedCheck is false and some of the new dims conflicts with the internal shape (i.e. 
its defined dims ,rank, upper bounds) - * or if internal shape and dims have different ranks + * Throws an exception if relaxedCheck is false and some of the new dims conflict with the internal shape (i.e. its + * defined dims, rank, upper bounds) or if internal shape and dims have different ranks * @param dims new dims * @param relaxedCheck flag which defined must we check dims with internal desc on compatibility * @return MemoryDescPtr with new dims @@ -136,8 +136,8 @@ } template <typename T, - typename std::enable_if<!std::is_pointer<T>::value && !std::is_reference<T>::value, int>::type = 0, - typename std::enable_if<std::is_base_of<MemoryDesc, T>::value, int>::type = 0> + typename std::enable_if<!std::is_pointer<T>::value && !std::is_reference<T>::value, int>::type = 0, + typename std::enable_if<std::is_base_of<MemoryDesc, T>::value, int>::type = 0> T* as() { T* casted = dynamic_cast<T*>(this); if (!casted) @@ -146,8 +146,8 @@ } template <typename T, - typename std::enable_if<!std::is_pointer<T>::value && !std::is_reference<T>::value, int>::type = 0, - typename std::enable_if<std::is_base_of<MemoryDesc, T>::value, int>::type = 0> + typename std::enable_if<!std::is_pointer<T>::value && !std::is_reference<T>::value, int>::type = 0, + typename std::enable_if<std::is_base_of<MemoryDesc, T>::value, int>::type = 0> const T* as() const { const T* casted = dynamic_cast<const T*>(this); if (!casted) @@ -159,17 +159,16 @@ protected: MemoryDesc() : type(MemoryDescType::Undef) {} - MemoryDesc(Shape shape, MemoryDescType type) - : type(type), shape(std::move(shape)) {} + MemoryDesc(Shape shape, MemoryDescType type) : type(type), shape(std::move(shape)) {} - MemoryDesc(const VectorDims& dims, MemoryDescType type) - : type(type), shape(dims) {} + MemoryDesc(const VectorDims& dims, MemoryDescType type) : type(type), shape(dims) {} virtual void setPrecision(ov::element::Type prc) = 0; virtual size_t getCurrentMemSizeImp() const = 0; - // Get offset to the n'th element. Returns physical index of the element by the logical one considering padding, layout, blocking etc. + // Get offset to the n'th element. Returns physical index of the element by the logical one considering padding, + // layout, blocking etc. 
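As an illustration of the contract described in the comment above, the sketch below maps a logical element number to a physical offset for a simple blocked layout. It is a hedged standalone example: the function name, the no-padding assumption, and the plain size_t vectors are ours, not the plugin's implementation (which also accounts for offset padding).

#include <cstddef>
#include <iostream>
#include <vector>

// Illustrative only: logical element number -> physical offset for a blocked
// layout given blockedDims/strides/order, assuming zero padding offsets.
size_t elementOffset(size_t elemNumber,
                     const std::vector<size_t>& dims,         // logical dims
                     const std::vector<size_t>& blockedDims,  // blocked dims (incl. inner blocks)
                     const std::vector<size_t>& strides,      // strides per blocked dim
                     const std::vector<size_t>& order) {      // blocked dim -> logical dim
    // Recover the logical multi-index from the flat element number.
    std::vector<size_t> idx(dims.size());
    for (size_t i = dims.size(); i-- > 0;) {
        idx[i] = elemNumber % dims[i];
        elemNumber /= dims[i];
    }
    // Walk blocked dims innermost-first, peeling each logical index into
    // per-block indices and accumulating index * stride.
    size_t offset = 0;
    for (size_t j = blockedDims.size(); j-- > 0;) {
        const size_t logicalDim = order[j];
        offset += (idx[logicalDim] % blockedDims[j]) * strides[j];
        idx[logicalDim] /= blockedDims[j];
    }
    return offset;
}

int main() {
    // nChw8c blocking of {1, 16, 4, 4}: element (n=0, c=9, h=0, w=0) is
    // logical #144 and lands at physical offset 129 (channel block 1, inner lane 1).
    std::cout << elementOffset(144, {1, 16, 4, 4}, {1, 2, 4, 4, 8}, {256, 128, 32, 8, 1}, {0, 1, 2, 3, 1});
}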
virtual size_t getElementOffset(size_t elemNumber) const = 0; virtual bool canComputeMemSizeZeroDims() const = 0; @@ -195,5 +194,5 @@ class MemoryDesc { friend class node::Split; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.cpp b/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.cpp index 0ae17d6c00322b..2937b73409b67d 100644 --- a/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.cpp +++ b/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.cpp @@ -4,29 +4,33 @@ #include "memory_desc/cpu_memory_desc_utils.h" -#include "memory_desc/cpu_blocked_memory_desc.h" -#include "memory_desc/dnnl_blocked_memory_desc.h" -#include "graph_context.h" -#include "cpu_memory_desc.h" -#include "memory_desc/empty_memory_desc.h" -#include -#include #include #include + #include #include +#include "cpu_memory_desc.h" +#include "graph_context.h" +#include "memory_desc/cpu_blocked_memory_desc.h" +#include "memory_desc/dnnl_blocked_memory_desc.h" +#include "memory_desc/empty_memory_desc.h" + using namespace dnnl; namespace ov { namespace intel_cpu { -DnnlMemoryDescPtr MemoryDescUtils::convertToDnnlMemoryDesc(const MemoryDescPtr &desc) { +DnnlMemoryDescPtr MemoryDescUtils::convertToDnnlMemoryDesc(const MemoryDescPtr& desc) { if (MemoryDescType::Blocked == desc->getType()) { const auto cpuDesc = desc->as(); - return std::shared_ptr(new DnnlBlockedMemoryDesc(cpuDesc->getPrecision(), cpuDesc->getShape(), cpuDesc->getBlockDims(), - cpuDesc->getOrder(), cpuDesc->getOffsetPadding(), - cpuDesc->getOffsetPaddingToData(), cpuDesc->getStrides())); + return std::shared_ptr(new DnnlBlockedMemoryDesc(cpuDesc->getPrecision(), + cpuDesc->getShape(), + cpuDesc->getBlockDims(), + cpuDesc->getOrder(), + cpuDesc->getOffsetPadding(), + cpuDesc->getOffsetPaddingToData(), + cpuDesc->getStrides())); } else if (MemoryDescType::Empty == desc->getType()) { return DnnlExtensionUtils::makeDescriptor(dnnl::memory::desc()); } else if (MemoryDescType::Dnnl & desc->getType()) { @@ -41,14 +45,19 @@ DnnlBlockedMemoryDesc MemoryDescUtils::convertToDnnlBlockedMemoryDesc(const Memo return DnnlBlockedMemoryDesc(*desc.as()); } else if (MemoryDescType::Blocked == desc.getType()) { const auto cpuDesc = desc.as(); - return DnnlBlockedMemoryDesc(cpuDesc->getPrecision(), cpuDesc->getShape(), cpuDesc->getBlockDims(), cpuDesc->getOrder(), cpuDesc->getOffsetPadding(), - cpuDesc->getOffsetPaddingToData(), cpuDesc->getStrides()); + return DnnlBlockedMemoryDesc(cpuDesc->getPrecision(), + cpuDesc->getShape(), + cpuDesc->getBlockDims(), + cpuDesc->getOrder(), + cpuDesc->getOffsetPadding(), + cpuDesc->getOffsetPaddingToData(), + cpuDesc->getStrides()); } else { OPENVINO_THROW("Cannot convert MemoryDesc to DnnlBlockedMemoryDesc"); } } -BlockedMemoryDescPtr MemoryDescUtils::convertToBlockedMemoryDesc(const MemoryDescPtr &desc) { +BlockedMemoryDescPtr MemoryDescUtils::convertToBlockedMemoryDesc(const MemoryDescPtr& desc) { if (desc->getType() & MemoryDescType::Blocked) { return std::dynamic_pointer_cast(desc); } else { @@ -57,7 +66,7 @@ BlockedMemoryDescPtr MemoryDescUtils::convertToBlockedMemoryDesc(const MemoryDes } CpuBlockedMemoryDescPtr MemoryDescUtils::generateCpuBlockedMemoryDesc(const ov::SoPtr& tensor) { - const auto& shape = tensor->get_shape().empty() ? ov::Shape{tensor->get_size()} : tensor->get_shape(); + const auto& shape = tensor->get_shape().empty() ? 
ov::Shape{tensor->get_size()} : tensor->get_shape(); VectorDims blk_order(shape.size()); std::iota(blk_order.begin(), blk_order.end(), 0); @@ -87,17 +96,16 @@ CpuBlockedMemoryDescPtr MemoryDescUtils::generateCpuBlockedMemoryDesc(const ov:: }); } - return std::make_shared( - element_type, - Shape{shape}, - shape, - blk_order, - 0UL, - VectorDims{}, - blk_strides); + return std::make_shared(element_type, + Shape{shape}, + shape, + blk_order, + 0UL, + VectorDims{}, + blk_strides); } -std::shared_ptr MemoryDescUtils::makeDummyDesc(const MemoryDesc &desc, Dim dummyVal) { +std::shared_ptr MemoryDescUtils::makeDummyDesc(const MemoryDesc& desc, Dim dummyVal) { auto dummyShape = makeDummyShape(desc.getShape(), dummyVal); return desc.cloneWithNewDims(dummyShape.getStaticDims()); } @@ -111,7 +119,7 @@ std::shared_ptr MemoryDescUtils::makeEmptyMemory(const GraphContext::CP return std::make_shared(context->getEngine(), makeEmptyDesc(), nullptr); } -Shape MemoryDescUtils::makeDummyShape(const Shape &shape, Dim dummyVal) { +Shape MemoryDescUtils::makeDummyShape(const Shape& shape, Dim dummyVal) { const auto& minDims = shape.getMinDims(); const auto& maxDims = shape.getMaxDims(); const auto& dims = shape.getDims(); @@ -122,7 +130,7 @@ Shape MemoryDescUtils::makeDummyShape(const Shape &shape, Dim dummyVal) { return Shape(dummyDims); } -Shape MemoryDescUtils::makeDummyShape(const Shape &shape, const VectorDims& dummyVals) { +Shape MemoryDescUtils::makeDummyShape(const Shape& shape, const VectorDims& dummyVals) { if (shape.getRank() != dummyVals.size()) { OPENVINO_THROW("makeDummyShape(): dummyVals vector size and shape ranks mismatch"); } @@ -131,9 +139,10 @@ Shape MemoryDescUtils::makeDummyShape(const Shape &shape, const VectorDims& dumm const auto& dims = shape.getDims(); VectorDims dummyDims(dims.size()); for (size_t i = 0; i < dims.size(); ++i) { - dummyDims[i] = dims[i] == Shape::UNDEFINED_DIM ? std::min(maxDims[i], std::max(minDims[i], dummyVals[i])) : dims[i]; + dummyDims[i] = + dims[i] == Shape::UNDEFINED_DIM ? 
std::min(maxDims[i], std::max(minDims[i], dummyVals[i])) : dims[i]; } return Shape(dummyDims); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.h b/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.h index a4acd3eb2aa778..388c9a21c5df8e 100644 --- a/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.h +++ b/src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc_utils.h @@ -5,11 +5,12 @@ #pragma once #include + #include "cpu_shape.h" #include "cpu_types.h" +#include "graph_context.h" #include "openvino/runtime/itensor.hpp" #include "openvino/runtime/so_ptr.hpp" -#include "graph_context.h" namespace ov { namespace intel_cpu { @@ -32,7 +33,7 @@ class MemoryDescUtils { * @param desc MemoryDesc to be converted * @return converted DnnlMemoryDesc */ - static std::shared_ptr convertToDnnlMemoryDesc(const std::shared_ptr &desc); + static std::shared_ptr convertToDnnlMemoryDesc(const std::shared_ptr& desc); /** * @brief Converts MemoryDesc to DnnlBlockedMemoryDesc @@ -46,7 +47,7 @@ class MemoryDescUtils { * @param desc MemoryDesc to be converted * @return converted BlockedMemoryDesc */ - static std::shared_ptr convertToBlockedMemoryDesc(const std::shared_ptr &desc); + static std::shared_ptr convertToBlockedMemoryDesc(const std::shared_ptr& desc); /** * @brief Builds CpuBlockedMemoryDesc for given ov::ITensor @@ -58,7 +59,8 @@ class MemoryDescUtils { static constexpr Dim DEFAULT_DUMMY_VAL = 64; /** - * @brief Makes a dummy descriptor where all undefined values are replaced with the smallest value between the parameter and the upper bound dim + * @brief Makes a dummy descriptor where all undefined values are replaced with the smallest value between the + * parameter and the upper bound dim * @param desc MemoryDesc from which the new descriptor is generated * @param dummyVal Dim value to replace undefined dimensions * @return a new MemoryDesc with dummy values instead of undefined dims @@ -66,27 +68,29 @@ class MemoryDescUtils { static std::shared_ptr makeDummyDesc(const MemoryDesc& desc, Dim dummyVal = DEFAULT_DUMMY_VAL); /** - * @brief Make an empty memory descriptor - * @note Shape{0}, undefined - * @return empty memory descriptor - */ + * @brief Make an empty memory descriptor + * @note Shape{0}, undefined + * @return empty memory descriptor + */ static std::shared_ptr makeEmptyDesc(); static std::shared_ptr makeEmptyMemory(const GraphContext::CPtr context); /** - * @brief Makes a static dummy shape where all undefined values are replaced with the smallest value between the parameter and the upper bound dim - * @param shape a Shape object from which the new static shape is generated - * @param dummyVal Dim value to replace undefined dimensions - * @return a new Shape with dummy values instead of undefined dims - */ + * @brief Makes a static dummy shape where all undefined values are replaced with the smallest value between the + * parameter and the upper bound dim + * @param shape a Shape object from which the new static shape is generated + * @param dummyVal Dim value to replace undefined dimensions + * @return a new Shape with dummy values instead of undefined dims + */ static Shape makeDummyShape(const Shape& shape, Dim dummyVal = DEFAULT_DUMMY_VAL); /** - * @brief Makes a static dummy shape where all undefined values are replaced with the smallest value between the parameter and the upper bound dim - * @param shape a Shape object from which the new static shape 
is generated - * @param dummyVals vector of values to replace undefined dimensions - * @return a new Shape with dummy values instead of undefined dims - */ + * @brief Makes a static dummy shape where all undefined values are replaced with the smallest value between the + * parameter and the upper bound dim + * @param shape a Shape object from which the new static shape is generated + * @param dummyVals vector of values to replace undefined dimensions + * @return a new Shape with dummy values instead of undefined dims + */ static Shape makeDummyShape(const Shape& shape, const VectorDims& dummyVals); /** @@ -104,5 +108,5 @@ class MemoryDescUtils { static std::string dims2str(const VectorDims& dims); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.cpp b/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.cpp index a24b55831c2c7c..38c020674c7168 100644 --- a/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.cpp +++ b/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.cpp @@ -4,26 +4,28 @@ #include "memory_desc/dnnl_blocked_memory_desc.h" +#include #include +#include #include + #include "cpu_types.h" #include "dnnl_extension_utils.h" #include "memory_desc/cpu_blocked_memory_desc.h" #include "utils/general_utils.h" -#include -#include - namespace ov { namespace intel_cpu { DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& shape, const VectorDims& strides) : MemoryDesc(shape, DnnlBlocked) { const auto ndims = shape.getRank(); - const auto &dims = shape.getDims(); + const auto& dims = shape.getDims(); - if (!strides.empty()) { // custom strides - if (shape.hasZeroDims() && std::any_of(strides.begin(), strides.end(), [](size_t stride) { return stride != 0; } )) { + if (!strides.empty()) { // custom strides + if (shape.hasZeroDims() && std::any_of(strides.begin(), strides.end(), [](size_t stride) { + return stride != 0; + })) { OPENVINO_THROW("Can't create DnnlBlockedMemoryDesc with zero dim, but with non zero strides"); } desc = {DnnlExtensionUtils::convertToDnnlDims(dims), @@ -33,16 +35,20 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& dnnl::memory::dims plain_strides; if (shape.hasZeroDims()) { plain_strides.resize(ndims, 0); - } else if (std::any_of(dims.begin(), dims.end(), [](size_t val) { return val == Shape::UNDEFINED_DIM; })) { + } else if (std::any_of(dims.begin(), dims.end(), [](size_t val) { + return val == Shape::UNDEFINED_DIM; + })) { plain_strides.resize(ndims, DNNL_RUNTIME_DIM_VAL); } else { plain_strides.resize(ndims, 1); for (size_t i = 1; i < ndims; i++) { - plain_strides[ndims - i -1] = plain_strides[ndims - i] * dims[ndims - i]; + plain_strides[ndims - i - 1] = plain_strides[ndims - i] * dims[ndims - i]; } } - desc = {DnnlExtensionUtils::convertToDnnlDims(dims), DnnlExtensionUtils::ElementTypeToDataType(prc), plain_strides}; + desc = {DnnlExtensionUtils::convertToDnnlDims(dims), + DnnlExtensionUtils::ElementTypeToDataType(prc), + plain_strides}; } order.resize(ndims); @@ -55,11 +61,12 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& * Construct from blocked parameters * * OV IOhw_4i16o4i dims(N) = {32, 64, 128, 128} - * blockedDims {4, 2, 128, 128, 4, 16, 4} // total dims(inner, outermost, auto blocked/padded). Generally sorted by strides. 
- * strides {8388608, 4194304, 32768, 256, 64, 4, 1} // strides for blockedDims, growing sequence - * order {1, 0, 2, 3, 1, 0, 1} // matching to original dims + * blockedDims {4, 2, 128, 128, 4, 16, 4} // total dims(inner, outermost, auto blocked/padded). + * Generally sorted by strides. + * strides {8388608, 4194304, 32768, 256, 64, 4, 1} // strides for blockedDims, growing sequence + * order {1, 0, 2, 3, 1, 0, 1} // matching to original dims * - * All vectors blockedDims/strides/order have same size equals total num of internal blocked dims(inner_dims + outer_dims) + * All vectors blockedDims/strides/order have same size equals total num of internal blocked dims(inner_dims + + * outer_dims) * * Tensor descriptor filing is not deterministic. It allows any permutation of index which keeps order of * real dims spliting. @@ -70,9 +77,14 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& * * Limitation of conversion first N elements of order should be permutation of [0,1,2 ... N] */ -DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& shape, const VectorDims& blockedDims, - const VectorDims& order, size_t offsetPadding, const VectorDims& offsetPaddingToData, - const VectorDims& strides) : MemoryDesc(shape, DnnlBlocked) { +DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, + const Shape& shape, + const VectorDims& blockedDims, + const VectorDims& order, + size_t offsetPadding, + const VectorDims& offsetPaddingToData, + const VectorDims& strides) + : MemoryDesc(shape, DnnlBlocked) { using namespace dnnl; // scalar case if (shape.getRank() == 0) { @@ -128,7 +140,9 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& const bool emptyDesc = shape.hasZeroDims(); if (!strides.empty()) { - if (emptyDesc && std::any_of(strides.begin(), strides.end(), [](size_t dim) { return dim != 0; } )) { + if (emptyDesc && std::any_of(strides.begin(), strides.end(), [](size_t dim) { + return dim != 0; + })) { OPENVINO_THROW("Can't create DnnlBlockedMemoryDesc with zero dim, but with non zero strides"); } @@ -143,7 +157,9 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& OPENVINO_THROW("Can not construct DnnlBlockedMemoryDesc from strides: ", vec2str(strides)); } - if (!strides.empty() && !emptyDesc && std::none_of(strides.begin(), strides.end(), [](size_t x) { return Shape::UNDEFINED_DIM == x; })) { + if (!strides.empty() && !emptyDesc && std::none_of(strides.begin(), strides.end(), [](size_t x) { + return Shape::UNDEFINED_DIM == x; + })) { bool inner_block_are_dense = one_of(strides.back(), 0u, 1u); // stride 1 - is dense case, 0 - broad casted for (size_t i = outer_ndims; i < strides.size() - 1; i++) { inner_block_are_dense &= (strides[i] == strides[i + 1] * blockedDims[i + 1]); @@ -164,8 +180,10 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& std::copy(dims.begin(), dims.end(), desc.get()->dims); if (!offsetPaddingToData.empty()) { - bool inner_pad_offsets_is_zero = std::all_of(offsetPaddingToData.begin() + outer_ndims, offsetPaddingToData.end(), - [](size_t pad) { return pad == 0; }); + bool inner_pad_offsets_is_zero = + std::all_of(offsetPaddingToData.begin() + outer_ndims, offsetPaddingToData.end(), [](size_t pad) { + return pad == 0; + }); if (!inner_pad_offsets_is_zero) OPENVINO_THROW("Can not construct DnnlBlockedMemoryDesc, inner pad offsets is not zero: ", @@ -189,7 +207,7 @@ 
DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& } // Fill blocking desc - auto &dnn_blk_desc = desc.get()->format_desc.blocking; + auto& dnn_blk_desc = desc.get()->format_desc.blocking; dnn_blk_desc.inner_nblks = inner_ndims; std::copy(dnnlBlkDims.end() - inner_ndims, dnnlBlkDims.end(), dnn_blk_desc.inner_blks); std::copy(order.end() - inner_ndims, order.end(), dnn_blk_desc.inner_idxs); @@ -209,8 +227,10 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& } } -DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const Shape& shape, dnnl::memory::data_type dataType, dnnl::memory::format_tag format) : - MemoryDesc(shape, DnnlBlocked) { +DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const Shape& shape, + dnnl::memory::data_type dataType, + dnnl::memory::format_tag format) + : MemoryDesc(shape, DnnlBlocked) { using namespace dnnl; if (format == memory::format_tag::any || format == memory::format_tag::undef) OPENVINO_THROW("Unexpected: Can't create dnnl::desc with any or undef format"); @@ -249,7 +269,7 @@ bool DnnlBlockedMemoryDesc::isCompatible(const MemoryDesc& rhs) const { } } -bool DnnlBlockedMemoryDesc::isCompatible(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const { +bool DnnlBlockedMemoryDesc::isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const { if (auto desc = dynamic_cast(&rhs)) { return isCompatible(*desc, cmpMask); } else if (auto desc = dynamic_cast(&rhs)) { @@ -261,7 +281,8 @@ bool DnnlBlockedMemoryDesc::isCompatible(const BlockedMemoryDesc &rhs, CmpMask c bool DnnlBlockedMemoryDesc::isCompatible(const CpuBlockedMemoryDesc& rhs, CmpMask cmpMask) const { dnnl::impl::memory_desc_wrapper wrapped(desc.get()); - return wrapped.extra().flags == dnnl_memory_extra_flag_none && BlockedMemoryDesc::isCompatibleInternal(rhs, cmpMask); + return wrapped.extra().flags == dnnl_memory_extra_flag_none && + BlockedMemoryDesc::isCompatibleInternal(rhs, cmpMask); } bool DnnlBlockedMemoryDesc::isCompatible(const DnnlBlockedMemoryDesc& rhs, CmpMask cmpMask) const { @@ -288,8 +309,10 @@ bool DnnlBlockedMemoryDesc::isCompatible(const DnnlBlockedMemoryDesc& rhs, CmpMa const auto thisExtra = wrappedThis.extra(); const auto rhsExtra = wrappedRhs.extra(); - return this->getOrder() == rhs.getOrder() && (thisExtra.flags == rhsExtra.flags && thisExtra.compensation_mask == rhsExtra.compensation_mask && - thisExtra.scale_adjust == rhsExtra.scale_adjust) && wrappedThis.similar_to(wrappedRhs, true, true, 0, true, checkOffset, stride_mask); + return this->getOrder() == rhs.getOrder() && + (thisExtra.flags == rhsExtra.flags && thisExtra.compensation_mask == rhsExtra.compensation_mask && + thisExtra.scale_adjust == rhsExtra.scale_adjust) && + wrappedThis.similar_to(wrappedRhs, true, true, 0, true, checkOffset, stride_mask); } static VectorDims extractOrder(const dnnl::memory::desc& desc) { @@ -300,7 +323,7 @@ static VectorDims extractOrder(const dnnl::memory::desc& desc) { OPENVINO_THROW("Unexpected: Cannot calculate order from undefined dims or strides"); } - const auto &blk_desc = descWrapped.blocking_desc(); + const auto& blk_desc = descWrapped.blocking_desc(); const size_t outer_ndims = dims.size(); const size_t inner_ndims = blk_desc.inner_nblks; @@ -319,11 +342,11 @@ static VectorDims extractOrder(const dnnl::memory::desc& desc) { // order of outer dims. 
In case of IOhw_ will be {1, 0, 2, 3} VectorDims outer_order(outer_ndims); std::iota(outer_order.begin(), outer_order.end(), 0); - std::sort(outer_order.begin(), outer_order.end(), - [&blk_desc, &outer_block_dims](size_t ind_l, size_t ind_r) { - return (blk_desc.strides[ind_l] > blk_desc.strides[ind_r]) || - (blk_desc.strides[ind_l] == blk_desc.strides[ind_r] && outer_block_dims[ind_l] > outer_block_dims[ind_r]); - }); + std::sort(outer_order.begin(), outer_order.end(), [&blk_desc, &outer_block_dims](size_t ind_l, size_t ind_r) { + return (blk_desc.strides[ind_l] > blk_desc.strides[ind_r]) || + (blk_desc.strides[ind_l] == blk_desc.strides[ind_r] && + outer_block_dims[ind_l] > outer_block_dims[ind_r]); + }); // blocked order // [new_outer_order] U [inner_idxs] @@ -333,8 +356,8 @@ static VectorDims extractOrder(const dnnl::memory::desc& desc) { return blk_order; } -DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const_dnnl_memory_desc_t cdesc) : - MemoryDesc(DnnlExtensionUtils::convertToVectorDims(cdesc->dims, cdesc->ndims), DnnlBlocked) { +DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const_dnnl_memory_desc_t cdesc) + : MemoryDesc(DnnlExtensionUtils::convertToVectorDims(cdesc->dims, cdesc->ndims), DnnlBlocked) { desc = dnnl::memory::desc(DnnlExtensionUtils::clone_desc(cdesc)); if (desc.get_format_kind() == dnnl::memory::format_kind::any) @@ -356,16 +379,16 @@ DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const_dnnl_memory_desc_t cdesc) : bool DnnlBlockedMemoryDesc::hasLayoutType(LayoutType layoutType) const { switch (layoutType) { - case LayoutType::ncsp: - return isPlainFormat(); - case LayoutType::nspc: - return isTailCFormat(); - case LayoutType::nCsp8c: - return isBlockedCFormat(8); - case LayoutType::nCsp16c: - return isBlockedCFormat(16); - default: - return false; + case LayoutType::ncsp: + return isPlainFormat(); + case LayoutType::nspc: + return isTailCFormat(); + case LayoutType::nCsp8c: + return isBlockedCFormat(8); + case LayoutType::nCsp16c: + return isBlockedCFormat(16); + default: + return false; } } @@ -382,8 +405,7 @@ bool DnnlBlockedMemoryDesc::isPlainFormat() const { } bool DnnlBlockedMemoryDesc::isBlockedCFormat(size_t blk_size) const { - if (desc.get_format_kind() != dnnl::memory::format_kind::blocked || - desc.get_inner_nblks() != 1 || + if (desc.get_format_kind() != dnnl::memory::format_kind::blocked || desc.get_inner_nblks() != 1 || desc.get_inner_idxs()[0] != 1) return false; @@ -452,13 +474,15 @@ static dnnl::memory::desc cloneDescWithNewDims(const dnnl::memory::desc& desc, return newMklDesc; } -MemoryDescPtr DnnlBlockedMemoryDesc::cloneWithNewDimsImp(const VectorDims &dims) const { - if (std::any_of(dims.begin(), dims.end(), [](size_t x){ return Shape::UNDEFINED_DIM == x; })) { +MemoryDescPtr DnnlBlockedMemoryDesc::cloneWithNewDimsImp(const VectorDims& dims) const { + if (std::any_of(dims.begin(), dims.end(), [](size_t x) { + return Shape::UNDEFINED_DIM == x; + })) { OPENVINO_THROW("Can't clone desc if new dims are undefined"); } // TODO [DS]: add stride recalculation for strided blobs - for (int i = strides.size() - 2; i >= 0 ; i--) { + for (int i = strides.size() - 2; i >= 0; i--) { if (strides[i] == Shape::UNDEFINED_DIM) break; @@ -499,7 +523,7 @@ bool DnnlBlockedMemoryDesc::isSame(dnnl::memory::format_tag fmt) const { { const auto dims = desc.get_dims(); VectorDims total_block_per_dim(dims.size(), 1); - const auto &blk_desc = desc.get()->format_desc.blocking; + const auto& blk_desc = desc.get()->format_desc.blocking; for (int i = 0; i < blk_desc.inner_nblks; i++) { 
total_block_per_dim[blk_desc.inner_idxs[i]] *= blk_desc.inner_blks[i]; } @@ -509,10 +533,12 @@ bool DnnlBlockedMemoryDesc::isSame(dnnl::memory::format_tag fmt) const { } std::iota(actualOrder.begin(), actualOrder.end(), 0); - std::sort(actualOrder.begin(), actualOrder.end(), - [&actualStrides, &outer_block_dims] (size_t ind_l, size_t ind_r) { + std::sort(actualOrder.begin(), + actualOrder.end(), + [&actualStrides, &outer_block_dims](size_t ind_l, size_t ind_r) { return (actualStrides[ind_l] > actualStrides[ind_r]) || - (actualStrides[ind_l] == actualStrides[ind_r] && outer_block_dims[ind_l] > outer_block_dims[ind_r]); + (actualStrides[ind_l] == actualStrides[ind_r] && + outer_block_dims[ind_l] > outer_block_dims[ind_r]); }); } @@ -520,7 +546,7 @@ bool DnnlBlockedMemoryDesc::isSame(dnnl::memory::format_tag fmt) const { { const auto dims = refDesc.get_dims(); VectorDims total_block_per_dim(dims.size(), 1); - const auto &blk_desc = refDesc.get()->format_desc.blocking; + const auto& blk_desc = refDesc.get()->format_desc.blocking; for (int i = 0; i < blk_desc.inner_nblks; i++) { total_block_per_dim[blk_desc.inner_idxs[i]] *= blk_desc.inner_blks[i]; } @@ -530,11 +556,10 @@ bool DnnlBlockedMemoryDesc::isSame(dnnl::memory::format_tag fmt) const { } std::iota(refOrder.begin(), refOrder.end(), 0); - std::sort(refOrder.begin(), refOrder.end(), - [&refStrides, &outer_block_dims] (size_t ind_l, size_t ind_r) { - return (refStrides[ind_l] > refStrides[ind_r]) || - (refStrides[ind_l] == refStrides[ind_r] && outer_block_dims[ind_l] > outer_block_dims[ind_r]); - }); + std::sort(refOrder.begin(), refOrder.end(), [&refStrides, &outer_block_dims](size_t ind_l, size_t ind_r) { + return (refStrides[ind_l] > refStrides[ind_r]) || + (refStrides[ind_l] == refStrides[ind_r] && outer_block_dims[ind_l] > outer_block_dims[ind_r]); + }); } if (actualOrder != refOrder) { @@ -549,7 +574,9 @@ size_t DnnlBlockedMemoryDesc::getMaxMemSize() const { } const auto& maxDims = shape.getMaxDims(); - if (std::any_of(maxDims.begin(), maxDims.end(), [](size_t x){ return Shape::UNDEFINED_DIM == x; })) { + if (std::any_of(maxDims.begin(), maxDims.end(), [](size_t x) { + return Shape::UNDEFINED_DIM == x; + })) { return UNDEFINED_SIZE; } @@ -563,11 +590,13 @@ size_t DnnlBlockedMemoryDesc::getPaddedElementsCount() const { } auto padded_dims = desc.get_padded_dims(); - if (std::any_of(std::begin(padded_dims), std::begin(padded_dims) + desc.get_ndims(), - [](dnnl_dim_t dim) { return dim == DNNL_RUNTIME_DIM_VAL; })) { + if (std::any_of(std::begin(padded_dims), std::begin(padded_dims) + desc.get_ndims(), [](dnnl_dim_t dim) { + return dim == DNNL_RUNTIME_DIM_VAL; + })) { OPENVINO_THROW("Can't compute padded elements count for non undefined blocked dims"); } - return std::accumulate(std::begin(padded_dims), std::begin(padded_dims) + desc.get_ndims(), + return std::accumulate(std::begin(padded_dims), + std::begin(padded_dims) + desc.get_ndims(), size_t{1}, std::multiplies()); } @@ -586,7 +615,7 @@ void DnnlBlockedMemoryDesc::initBlockDims() { const auto dims = desc.get_dims(); const size_t outer_ndims = dims.size(); - const auto inner_ndims = desc.get_inner_nblks(); + const auto inner_ndims = desc.get_inner_nblks(); const size_t total_ndims = outer_ndims + inner_ndims; // total inner block size. 
in case of 4i16o4i will be {16, 16, 1, 1} @@ -612,10 +641,10 @@ void DnnlBlockedMemoryDesc::initBlockDims() { std::copy(order.begin(), order.begin() + outer_ndims, outer_order.begin()); blockedDims.resize(total_ndims, 0); - std::copy(inner_blks.begin(), inner_blks.begin() + inner_nblks, - blockedDims.end() - inner_nblks); - std::transform(outer_order.begin(), outer_order.end(), blockedDims.begin(), - [&] (size_t i) { return outer_block_dims[i]; }); + std::copy(inner_blks.begin(), inner_blks.begin() + inner_nblks, blockedDims.end() - inner_nblks); + std::transform(outer_order.begin(), outer_order.end(), blockedDims.begin(), [&](size_t i) { + return outer_block_dims[i]; + }); } void DnnlBlockedMemoryDesc::initStrides() { @@ -623,7 +652,7 @@ void DnnlBlockedMemoryDesc::initStrides() { const size_t outer_ndims = dims.size(); const size_t inner_nblks = desc.get_inner_nblks(); - const auto inner_blks = desc.get_inner_blks(); + const auto inner_blks = desc.get_inner_blks(); const size_t total_ndims = outer_ndims + inner_nblks; // strides of inner dims. In case of 4i16o4i will be {64, 4, 1} @@ -642,8 +671,9 @@ void DnnlBlockedMemoryDesc::initStrides() { std::copy(inner_strides.rbegin(), inner_strides.rend(), strides.rbegin()); const auto desc_strides = desc.get_strides(); - std::transform(outer_order.begin(), outer_order.end(), strides.begin(), - [&](size_t i) { return desc_strides[i] == DNNL_RUNTIME_DIM_VAL ? Shape::UNDEFINED_DIM : desc_strides[i]; }); + std::transform(outer_order.begin(), outer_order.end(), strides.begin(), [&](size_t i) { + return desc_strides[i] == DNNL_RUNTIME_DIM_VAL ? Shape::UNDEFINED_DIM : desc_strides[i]; + }); } void DnnlBlockedMemoryDesc::initOffsetPadding() { @@ -659,15 +689,17 @@ MemoryDescPtr DnnlBlockedMemoryDesc::cloneWithNewPrecision(const ov::element::Ty } void DnnlBlockedMemoryDesc::recomputeDefaultStrides() { - const auto &rank = getShape().getRank(); + const auto& rank = getShape().getRank(); if (order.size() != blockedDims.size()) OPENVINO_THROW("Can't recompute stride: order size != blocked dims size"); - auto &oneDnnStrides = desc.get()->format_desc.blocking.strides; + auto& oneDnnStrides = desc.get()->format_desc.blocking.strides; if (getShape().hasZeroDims()) { std::fill(std::begin(oneDnnStrides), std::begin(oneDnnStrides) + getShape().getRank(), 0); - } else if (std::any_of(blockedDims.begin(), blockedDims.end(), [](Dim val) { return val == Shape::UNDEFINED_DIM; })) { + } else if (std::any_of(blockedDims.begin(), blockedDims.end(), [](Dim val) { + return val == Shape::UNDEFINED_DIM; + })) { std::fill(std::begin(oneDnnStrides), std::begin(oneDnnStrides) + rank, DNNL_RUNTIME_DIM_VAL); initStrides(); } else { @@ -682,8 +714,8 @@ void DnnlBlockedMemoryDesc::recomputeDefaultStrides() { } } -DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const dnnl::memory::desc& mdesc, const Shape& shape) : - MemoryDesc(shape, DnnlBlocked) { +DnnlBlockedMemoryDesc::DnnlBlockedMemoryDesc(const dnnl::memory::desc& mdesc, const Shape& shape) + : MemoryDesc(shape, DnnlBlocked) { if (mdesc.get_format_kind() == dnnl::memory::format_kind::any) OPENVINO_THROW("Unexpected: Memory format any is prohibited!"); @@ -715,5 +747,5 @@ std::string DnnlBlockedMemoryDesc::serializeFormat() const { return BlockedMemoryDesc::serializeFormat(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.h b/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.h index 
a6c6a3297ba044..91388c12e2abf7 100644 --- a/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.h +++ b/src/plugins/intel_cpu/src/memory_desc/dnnl_blocked_memory_desc.h @@ -4,19 +4,20 @@ #pragma once +#include + +#include "dnnl_extension_utils.h" #include "dnnl_memory_desc.h" #include "memory_desc/blocked_memory_desc.h" #include "openvino/util/util.hpp" -#include "dnnl_extension_utils.h" -#include namespace ov { namespace intel_cpu { class CpuBlockedMemoryDesc; -OPENVINO_DISABLE_WARNING_MSVC_BEGIN(4250) // Visual Studio warns us about inheritance via dominance but it's done intentionally - // so turn it off +OPENVINO_DISABLE_WARNING_MSVC_BEGIN(4250) // Visual Studio warns us about inheritance via dominance but it's done + // intentionally so turn it off class DnnlBlockedMemoryDesc : public BlockedMemoryDesc, public DnnlMemoryDesc { public: // Creates planar DnnlBlockedMemoryDesc @@ -30,8 +31,8 @@ class DnnlBlockedMemoryDesc : public BlockedMemoryDesc, public DnnlMemoryDesc { bool isCompatible(const MemoryDesc& rhs) const override; bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const override; - bool isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask = FULL_MASK) const; - bool isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask = FULL_MASK) const; + bool isCompatible(const CpuBlockedMemoryDesc& rhs, CmpMask cmpMask = FULL_MASK) const; + bool isCompatible(const DnnlBlockedMemoryDesc& rhs, CmpMask cmpMask = FULL_MASK) const; const VectorDims& getBlockDims() const override { return blockedDims; @@ -63,17 +64,22 @@ class DnnlBlockedMemoryDesc : public BlockedMemoryDesc, public DnnlMemoryDesc { MemoryDescPtr cloneWithNewPrecision(const ov::element::Type prec) const override; - using DnnlMemoryDesc::setPrecision; using DnnlMemoryDesc::getPrecision; + using DnnlMemoryDesc::setPrecision; private: - DnnlBlockedMemoryDesc(ov::element::Type prc, const Shape& shape, const VectorDims& blockedDims, - const VectorDims& order, size_t offsetPadding = 0, const VectorDims& offsetPaddingToData = {}, + DnnlBlockedMemoryDesc(ov::element::Type prc, + const Shape& shape, + const VectorDims& blockedDims, + const VectorDims& order, + size_t offsetPadding = 0, + const VectorDims& offsetPaddingToData = {}, const VectorDims& strides = {}); - // Creates DnnlBlockedMemoryDesc using the shape parameter as a true shape but all other params (layout, blocks, etc.) are used from the mdesc, but - // the mdesc own shape is ignored. The main purpose of this constructor is making dynamic descriptor from some dummy mdesc, which stores info about - // layout, blocking, strides, etc., and the provided dynamic shape. + // Creates DnnlBlockedMemoryDesc using the shape parameter as a true shape but all other params (layout, blocks, + // etc.) are used from the mdesc, but the mdesc own shape is ignored. The main purpose of this constructor is making + // dynamic descriptor from some dummy mdesc, which stores info about layout, blocking, strides, etc., and the + // provided dynamic shape. 
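The "dummy" descriptor mentioned in this comment is built by substituting bounded defaults for undefined dims, the same rule MemoryDescUtils::makeDummyShape applies in the utils diff above: dummyVal clamped into [minDims[i], maxDims[i]]. A rough standalone sketch of just that clamping (the UNDEFINED_DIM stand-in and function name are illustrative):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for Shape::UNDEFINED_DIM.
constexpr size_t UNDEFINED_DIM = std::numeric_limits<size_t>::max();

// Replace each undefined dim with dummyVal clamped into [minDims[i], maxDims[i]],
// mirroring the rule in MemoryDescUtils::makeDummyShape.
std::vector<size_t> makeDummyDims(const std::vector<size_t>& dims,
                                  const std::vector<size_t>& minDims,
                                  const std::vector<size_t>& maxDims,
                                  size_t dummyVal = 64) {
    std::vector<size_t> out(dims.size());
    for (size_t i = 0; i < dims.size(); ++i) {
        out[i] = dims[i] == UNDEFINED_DIM ? std::min(maxDims[i], std::max(minDims[i], dummyVal)) : dims[i];
    }
    return out;
}

int main() {
    // Shape {1, [3..3], [1..inf]}: the bounded dim clamps to 3, the open one to 64.
    for (size_t d : makeDummyDims({1, UNDEFINED_DIM, UNDEFINED_DIM}, {1, 3, 1}, {1, 3, UNDEFINED_DIM}))
        std::cout << d << ' ';  // prints: 1 3 64
}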
DnnlBlockedMemoryDesc(const dnnl::memory::desc& mdesc, const Shape& shape); explicit DnnlBlockedMemoryDesc(const_dnnl_memory_desc_t cdesc); @@ -84,7 +90,8 @@ class DnnlBlockedMemoryDesc : public BlockedMemoryDesc, public DnnlMemoryDesc { bool isBlockedCFormat(size_t blk_size = UNREACHABLE_DIM) const; bool isTailCFormat() const; - // WA: we need to initialize blocked params into ctor to avoid bugs when we calculate these params in throughput mode + // WA: we need to initialize blocked params into ctor to avoid bugs when we calculate these params in throughput + // mode // TODO [DS]: should be reimplemented to avoid useless calculation void initBlockedParams() { initBlockDims(); @@ -99,7 +106,8 @@ class DnnlBlockedMemoryDesc : public BlockedMemoryDesc, public DnnlMemoryDesc { void recomputeDefaultStrides(); friend DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor(const_dnnl_memory_desc_t desc); - friend std::shared_ptr DnnlExtensionUtils::makeUndefinedDesc(const dnnl::memory::desc &desc, const Shape& shape); + friend std::shared_ptr DnnlExtensionUtils::makeUndefinedDesc(const dnnl::memory::desc& desc, + const Shape& shape); friend class MemoryDescUtils; }; OPENVINO_DISABLE_WARNING_MSVC_END(4250) @@ -107,5 +115,5 @@ OPENVINO_DISABLE_WARNING_MSVC_END(4250) using DnnlBlockedMemoryDescPtr = std::shared_ptr; using DnnlBlockedMemoryDescCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.cpp b/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.cpp index 3e3af41cfc523a..375b218272ed57 100644 --- a/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.cpp +++ b/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.cpp @@ -3,20 +3,21 @@ // #include "dnnl_memory_desc.h" -#include "dnnl_extension_utils.h" + #include #include + +#include "dnnl_extension_utils.h" #include "onednn/dnnl.h" namespace ov { namespace intel_cpu { -DnnlMemoryDesc::DnnlMemoryDesc(const dnnl::memory::desc& desc) : - DnnlMemoryDesc(desc.get()) {} +DnnlMemoryDesc::DnnlMemoryDesc(const dnnl::memory::desc& desc) : DnnlMemoryDesc(desc.get()) {} -DnnlMemoryDesc::DnnlMemoryDesc(const_dnnl_memory_desc_t cdesc) : - MemoryDesc(Shape(DnnlExtensionUtils::convertToVectorDims(cdesc->dims, cdesc->ndims)), Dnnl), - desc(DnnlExtensionUtils::clone_desc(cdesc)) { +DnnlMemoryDesc::DnnlMemoryDesc(const_dnnl_memory_desc_t cdesc) + : MemoryDesc(Shape(DnnlExtensionUtils::convertToVectorDims(cdesc->dims, cdesc->ndims)), Dnnl), + desc(DnnlExtensionUtils::clone_desc(cdesc)) { if (getFormatKind() == dnnl::memory::format_kind::any) OPENVINO_THROW("Unexpected: Memory format any is prohibited!"); } @@ -35,7 +36,7 @@ MemoryDescPtr DnnlMemoryDesc::cloneWithNewPrecision(const ov::element::Type prec return newDesc; } -bool DnnlMemoryDesc::isCompatible(const MemoryDesc &rhs) const { +bool DnnlMemoryDesc::isCompatible(const MemoryDesc& rhs) const { if (MemoryDescType::Dnnl & rhs.getType()) { auto* dnnMemDesc = rhs.as(); return isCompatible(*dnnMemDesc); @@ -52,17 +53,25 @@ std::string DnnlMemoryDesc::serializeFormat() const { dnnl::impl::memory_desc_wrapper wrapped(desc.get()); if (wrapped.is_wino_desc()) { switch (desc.get()->format_desc.wino_desc.wino_format) { - case dnnl::impl::wino_memory_format_t::wino_wei_aaOio: return "wino_aaOio"; - case dnnl::impl::wino_memory_format_t::wino_wei_aaOBiOo: return "wino_aaOBiOo"; - case dnnl::impl::wino_memory_format_t::wino_wei_OBaaIBOIio: return "wino_OBaaIBOIio"; - default: return 
"wino_undef"; + case dnnl::impl::wino_memory_format_t::wino_wei_aaOio: + return "wino_aaOio"; + case dnnl::impl::wino_memory_format_t::wino_wei_aaOBiOo: + return "wino_aaOBiOo"; + case dnnl::impl::wino_memory_format_t::wino_wei_OBaaIBOIio: + return "wino_OBaaIBOIio"; + default: + return "wino_undef"; } } else if (wrapped.is_rnn_packed_desc()) { switch (desc.get()->format_desc.rnn_packed_desc.format) { - case dnnl::impl::rnn_packed_format::ldigo_p: return "packed_ldigo"; - case dnnl::impl::rnn_packed_format::ldgoi_p: return "packed_ldgoi"; - case dnnl::impl::rnn_packed_format::ldio_p: return "packed_ldio"; - default: return "packed_undef"; + case dnnl::impl::rnn_packed_format::ldigo_p: + return "packed_ldigo"; + case dnnl::impl::rnn_packed_format::ldgoi_p: + return "packed_ldgoi"; + case dnnl::impl::rnn_packed_format::ldio_p: + return "packed_ldio"; + default: + return "packed_undef"; } } return "undef"; @@ -116,7 +125,7 @@ bool DnnlMemoryDesc::isDefinedImp() const { return wrappedThis.offset0() != DNNL_RUNTIME_DIM_VAL; } -MemoryDescPtr DnnlMemoryDesc::cloneWithNewDimsImp(const VectorDims &dims) const { +MemoryDescPtr DnnlMemoryDesc::cloneWithNewDimsImp(const VectorDims& dims) const { OPENVINO_THROW("Unexpected: Cannot clone non blocked oneDNN desc with new dims"); } @@ -125,6 +134,5 @@ size_t DnnlMemoryDesc::getOffsetPadding() const { return DnnlExtensionUtils::convertToDim(wrap.offset0()); } - -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.h b/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.h index f2722a5170f871..6b3692c5663078 100644 --- a/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.h +++ b/src/plugins/intel_cpu/src/memory_desc/dnnl_memory_desc.h @@ -4,11 +4,11 @@ #pragma once -#include "dnnl_extension_utils.h" #include #include -#include "memory_desc/cpu_memory_desc.h" + #include "dnnl_extension_utils.h" +#include "memory_desc/cpu_memory_desc.h" namespace ov { namespace intel_cpu { @@ -29,13 +29,17 @@ class DnnlMemoryDesc : public virtual MemoryDesc { bool isCompatible(const MemoryDesc& rhs) const override; bool isCompatible(const DnnlMemoryDesc& rhs) const; - bool hasLayoutType(LayoutType layoutType) const override { return false; } + bool hasLayoutType(LayoutType layoutType) const override { + return false; + } std::string serializeFormat() const override; size_t getMaxMemSize() const override; - virtual bool isSame(dnnl::memory::format_tag fmt) const { return false; } + virtual bool isSame(dnnl::memory::format_tag fmt) const { + return false; + } const dnnl::memory::desc& getDnnlDesc() const { return desc; @@ -70,10 +74,9 @@ class DnnlMemoryDesc : public virtual MemoryDesc { bool isDefinedImp() const override; MemoryDescPtr cloneWithNewDimsImp(const VectorDims& dims) const override; - friend DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor(const dnnl::memory::desc &desc); + friend DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor(const dnnl::memory::desc& desc); friend DnnlMemoryDescPtr DnnlExtensionUtils::makeDescriptor(const_dnnl_memory_desc_t desc); }; -} // namespace intel_cpu -} // namespace ov - +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_desc/empty_memory_desc.h b/src/plugins/intel_cpu/src/memory_desc/empty_memory_desc.h index 1575841cb2be9e..c26cc6aa33a251 100644 --- a/src/plugins/intel_cpu/src/memory_desc/empty_memory_desc.h +++ b/src/plugins/intel_cpu/src/memory_desc/empty_memory_desc.h @@ 
-5,7 +5,6 @@ #pragma once #include "cpu_memory_desc.h" - #include "cpu_shape.h" #include "openvino/core/except.hpp" #include "openvino/core/type/element_type.hpp" @@ -23,8 +22,7 @@ namespace intel_cpu { */ class EmptyMemoryDesc : public MemoryDesc { public: - EmptyMemoryDesc(): - MemoryDesc(Shape{0}, Empty) { + EmptyMemoryDesc() : MemoryDesc(Shape{0}, Empty) { /* status never changes for an empty memory desc * so "define" beforehand to ensure isDefined() is thread safe */ status = MemoryDesc::descStatus::Defined; @@ -60,7 +58,9 @@ class EmptyMemoryDesc : public MemoryDesc { MemoryDescPtr cloneWithNewPrecision(const ov::element::Type prec) const override { OPENVINO_ASSERT(prec == ov::element::undefined, - "Clone an empty memory desc with defined precision: ", prec, " is prohibited"); + "Clone an empty memory desc with defined precision: ", + prec, + " is prohibited"); return clone(); } @@ -92,5 +92,5 @@ class EmptyMemoryDesc : public MemoryDesc { using EmptyMemoryDescPtr = std::shared_ptr; using EmptyMemoryDescCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_state.cpp b/src/plugins/intel_cpu/src/memory_state.cpp index aa06f4ebd82957..c0dc85c4103ce4 100644 --- a/src/plugins/intel_cpu/src/memory_state.cpp +++ b/src/plugins/intel_cpu/src/memory_state.cpp @@ -5,30 +5,33 @@ #include "memory_state.h" #include + #include "cpu_memory.h" +#include "cpu_tensor.h" +#include "dnnl_extension_utils.h" #include "memory_desc/cpu_blocked_memory_desc.h" #include "memory_desc/cpu_memory_desc_utils.h" -#include "dnnl_extension_utils.h" -#include "cpu_tensor.h" -#include "utils/plain_tensor.hpp" -#include "openvino/core/parallel.hpp" #include "nodes/common/cpu_convert.h" #include "nodes/kernels/scaled_attn/attn_quant.hpp" +#include "openvino/core/parallel.hpp" +#include "utils/plain_tensor.hpp" using namespace ov::Extensions::Cpu::XARCH; namespace ov { namespace intel_cpu { -VariableStateBase::VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc) : - IVariableState{name} , m_external_desc{external_desc} {} +VariableStateBase::VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc) + : IVariableState{name}, + m_external_desc{external_desc} {} MemoryDescPtr VariableStateBase::to_static(const MemoryDescPtr& desc) { if (!desc->isDefined()) { auto&& current_dims = desc->getShape().getDims(); VectorDims new_dims(current_dims.size()); std::transform(current_dims.begin(), current_dims.end(), new_dims.begin(), [](Dim x) { - return x == Shape::UNDEFINED_DIM ? 0 : x; }); + return x == Shape::UNDEFINED_DIM ? 
0 : x; + }); return desc->cloneWithNewDims(new_dims, true); } @@ -71,21 +74,26 @@ ov::SoPtr VariableStateBase::get_state() const { return std::make_shared(internal_state_mem()); } - //test precision + // test precision { auto internal_prc = current_internal_desc->getPrecision(); auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc); if (tmp_desc->isCompatible(*current_internal_desc)) { auto mem = std::make_shared(get_engine(), current_ext_desc); - size_t elements_to_convert = internal_state_mem()->getDescWithType()->getPaddedElementsCount(); + size_t elements_to_convert = + internal_state_mem()->getDescWithType()->getPaddedElementsCount(); auto external_prc = current_ext_desc->getPrecision(); - cpu_convert(internal_state_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert); + cpu_convert(internal_state_mem()->getData(), + mem->getData(), + internal_prc, + external_prc, + elements_to_convert); return std::make_shared(mem); } } - //reorder + // reorder auto mem = std::make_shared(get_engine(), current_ext_desc); mem->load(*(internal_state_mem())); return std::make_shared(mem); @@ -108,19 +116,19 @@ void VariableStateBase::commit() { VariableStateDoubleBuffer::VariableStateDoubleBuffer(const std::string& name, const MemoryPtr& first_buffer, const MemoryPtr& second_buffer, - const MemoryDescPtr& external_desc) : - VariableStateBase(name, external_desc) { + const MemoryDescPtr& external_desc) + : VariableStateBase(name, external_desc) { OPENVINO_ASSERT(first_buffer && second_buffer); reset_prime_mem(first_buffer); reset_second_mem(second_buffer); m_internal_desc = prime_mem()->getDescPtr(); auto&& shape = m_internal_desc->getShape(); - //TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible? + // TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible? 
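For intuition about the double-buffered state this constructor sets up: two memories alternate roles, and commit() flips which one the graph reads next. A deliberately simplified toy sketch of that ping-pong pattern (not the plugin's API; the real class swaps MemoryPtr objects and tracks descriptors):

#include <array>
#include <cstddef>
#include <iostream>
#include <vector>

// Toy ping-pong state: writes go to the "second" buffer, and commit() makes it
// the "prime" (readable) one, as VariableStateDoubleBuffer does with MemoryPtr.
class DoubleBuffer {
public:
    std::vector<float>& prime() { return m_buffers[m_idx]; }             // read side
    std::vector<float>& second() { return m_buffers[(m_idx + 1) % 2]; }  // write side
    void commit() { m_idx = (m_idx + 1) % 2; }                           // swap roles
private:
    std::array<std::vector<float>, 2> m_buffers{};
    size_t m_idx = 0;
};

int main() {
    DoubleBuffer state;
    state.second() = {1.f, 2.f, 3.f};  // produced during the current inference
    state.commit();                    // published: now visible as prime
    std::cout << state.prime()[1];     // prints: 2
}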
if (shape.isStatic()) { prime_mem()->nullify(); } else { - //in the case of the original desc has dynamic shape we create an empty tensor + // in the case of the original desc has dynamic shape we create an empty tensor auto new_desc = to_static(m_internal_desc); prime_mem()->redefineDesc(new_desc); } @@ -199,11 +207,11 @@ void VariableStateSingleBuffer::commit_impl() { // nothing to do } -VariableStateKVcache::VariableStateKVcache( - const std::string& name, - const MemoryDescPtr& external_desc, - const BlockedMemoryDescPtr& dense_internal_desc) : - VariableStateBase(name, external_desc), m_dense_internal_desc(dense_internal_desc) { +VariableStateKVcache::VariableStateKVcache(const std::string& name, + const MemoryDescPtr& external_desc, + const BlockedMemoryDescPtr& dense_internal_desc) + : VariableStateBase(name, external_desc), + m_dense_internal_desc(dense_internal_desc) { auto&& shape = external_desc->getShape(); OPENVINO_ASSERT(shape.isDynamic(), "VariableStateKVcache is unexpectedly initalized with a static tensor"); @@ -227,7 +235,7 @@ ov::SoPtr VariableStateKVcache::get_state() const { OPENVINO_ASSERT(actual_external_desc->getShape().getRank() == 4); auto&& actual_internal_order = actual_internal_desc->getOrder(); - //sanity check + // sanity check OPENVINO_ASSERT(actual_internal_order == m_dense_internal_desc->getOrder()); PlainTensor output, pastkv, beam_table; @@ -253,20 +261,12 @@ ov::SoPtr VariableStateKVcache::get_state() const { S, m_scale_zp.ptr(m, b_kv, h)[0], m_scale_zp.ptr(m, b_kv, h)[1]); - cpu_convert(buffers[ithr].ptr(), - output.ptr_v(m, b, h), - element::f32, - output.m_dt, - S); + cpu_convert(buffers[ithr].ptr(), output.ptr_v(m, b, h), element::f32, output.m_dt, S); }); } else { parallel_for3d(L0, B, H, [&](size_t m, size_t b, size_t h) { auto b_kv = static_cast(beam_table.at({b, m})); - cpu_convert(pastkv.ptr_v(m, b_kv, h), - output.ptr_v(m, b, h), - pastkv.m_dt, - output.m_dt, - S); + cpu_convert(pastkv.ptr_v(m, b_kv, h), output.ptr_v(m, b, h), pastkv.m_dt, output.m_dt, S); }); } @@ -274,11 +274,11 @@ ov::SoPtr VariableStateKVcache::get_state() const { } void VariableStateKVcache::set_state_impl(const ov::SoPtr& state) { - //1. reset the memory object - m_state = state; // simply to extend the lifetime + // 1. 
reset the memory object + m_state = state; // simply to extend the lifetime auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state); - //May be optimized by reusing the state tensor underlining memory pointer, but corner cases should be considered + // May be optimized by reusing the state tensor underlying memory pointer, but corner cases should be considered auto dense_internal_desc = m_dense_internal_desc->cloneWithNewDims(state_desc->getShape().getStaticDims()); m_internal_mem = std::make_shared<Memory>(get_engine(), dense_internal_desc); @@ -287,7 +287,10 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) { if (dense_internal_desc->getPrecision() == element::u8) { PlainTensor external, internal; auto&& actual_internal_order = m_dense_internal_desc->getOrder(); - external.resize(external_mem.getStaticDims(), state_desc->getPrecision().size(), state_desc->getPrecision(), m_state->data()); + external.resize(external_mem.getStaticDims(), + state_desc->getPrecision().size(), + state_desc->getPrecision(), + m_state->data()); internal.reset(m_internal_mem); external = external.permute(actual_internal_order); internal = internal.permute(actual_internal_order); @@ -300,11 +303,7 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) { m_scale_zp.resize({L0, B, H, 2}); parallel_for3d(B, H, L0, [&](size_t ithr, size_t b, size_t h, size_t m) { buffers[ithr].resize({S}); - cpu_convert(external.ptr_v(m, b, h), - buffers[ithr].ptr(), - external.m_dt, - element::f32, - S); + cpu_convert(external.ptr_v(m, b, h), buffers[ithr].ptr(), external.m_dt, element::f32, S); attn_quant_u8(buffers[ithr].ptr(), internal.ptr(m, b, h), S, @@ -315,14 +314,13 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) { m_internal_mem->load(external_mem); } - //2. Reset the beam search table + // 2. 
- //2. Reset the beam search table + // 2. Reset the beam search table auto&& state_dims = dense_internal_desc->getShape().getStaticDims(); auto&& order = m_dense_internal_desc->getOrder(); const size_t size_B = state_dims[order.at(1)]; const size_t size_L = state_dims[order.at(0)]; - auto mem_desc = - std::make_shared(ov::element::i32, Shape{size_B, size_L}); + auto mem_desc = std::make_shared(ov::element::i32, Shape{size_B, size_L}); m_hidden_state = std::make_shared(get_engine(), mem_desc); auto buff = m_hidden_state->getDataAs(); @@ -336,11 +334,11 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr& state) { } void VariableStateKVcache::reset_impl() { - //nothing to do + // nothing to do } void VariableStateKVcache::commit_impl() { - //nothing to do + // nothing to do } MemoryPtr VariableStateKVcache::input_mem() { @@ -352,7 +350,7 @@ MemoryPtr VariableStateKVcache::output_mem() { } MemoryDescPtr VariableStateKVcache::internal_desc() const { - return m_dense_internal_desc; //since we don't store initial one + return m_dense_internal_desc; // since we don't store initial one } MemoryPtr VariableStateKVcache::internal_state_mem() const { diff --git a/src/plugins/intel_cpu/src/memory_state.h b/src/plugins/intel_cpu/src/memory_state.h index e7493f327e93fa..f35e78989b02f8 100644 --- a/src/plugins/intel_cpu/src/memory_state.h +++ b/src/plugins/intel_cpu/src/memory_state.h @@ -29,12 +29,12 @@ class VariableStateBase : public IVariableState { public: VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc); - //ov::IVariableState - void set_state(const ov::SoPtr& state) override final; // NOLINT + // ov::IVariableState + void set_state(const ov::SoPtr& state) override final; // NOLINT ov::SoPtr get_state() const override; - void reset() override final; // NOLINT - bool is_reset_state() const override final; // NOLINT - void commit() override final; // NOLINT + void reset() override final; // NOLINT + bool is_reset_state() const override final; // NOLINT + void commit() override final; // NOLINT protected: virtual MemoryPtr internal_state_mem() const = 0; @@ -66,7 +66,7 @@ class VariableStateDoubleBuffer : public VariableStateBase { MemoryDescPtr internal_desc() const override; private: - //ov::intel_cpu::VariableStateBase + // ov::intel_cpu::VariableStateBase void reset_impl() override; void commit_impl() override; @@ -89,7 +89,7 @@ class VariableStateDoubleBuffer : public VariableStateBase { MemoryPtr internal_state_mem() const override; private: - MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor + MemoryDescPtr m_internal_desc; // mem desc required by the graph internal tensor std::array m_internal_mem{}; size_t buffer_num = 0; }; @@ -111,7 +111,7 @@ class VariableStateSingleBuffer : public VariableStateBase { MemoryPtr internal_state_mem() const override; private: - MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor + MemoryDescPtr m_internal_desc; // mem desc required by the graph internal tensor MemoryPtr m_internal_mem; }; @@ -121,10 +121,10 @@ class VariableStateKVcache : public VariableStateBase { const MemoryDescPtr& external_desc, const BlockedMemoryDescPtr& dense_internal_desc); - //ov::IVariableState + // ov::IVariableState ov::SoPtr get_state() const override; - //ov::intel_cpu::VariableStateBase + // ov::intel_cpu::VariableStateBase MemoryPtr input_mem() override; MemoryPtr output_mem() override; MemoryDescPtr internal_desc() const override; @@ -158,14 +158,14 @@ class VariableStateKVcache : public VariableStateBase { } private: -
//ov::intel_cpu::VariableStateBase + // ov::intel_cpu::VariableStateBase void set_state_impl(const ov::SoPtr& state) override; void reset_impl() override; void commit_impl() override; private: - MemoryPtr m_internal_mem; // kv cache - MemoryPtr m_hidden_state; // beam access table + MemoryPtr m_internal_mem; // kv cache + MemoryPtr m_hidden_state; // beam access table size_t m_internal_mem_max_size = 0; size_t m_hidden_state_max_size = 0; @@ -178,5 +178,5 @@ class VariableStateKVcache : public VariableStateBase { using MemStatePtr = std::shared_ptr; using MemStateCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/mlas/thread_pool.hpp b/src/plugins/intel_cpu/src/mlas/thread_pool.hpp index 536b3746be1d69..5af8b0cce915fa 100644 --- a/src/plugins/intel_cpu/src/mlas/thread_pool.hpp +++ b/src/plugins/intel_cpu/src/mlas/thread_pool.hpp @@ -7,6 +7,7 @@ #include #include #include + #include "mlas.h" namespace ov { @@ -17,6 +18,7 @@ class OVMlasThreadPool : public IMlasThreadPool { explicit OVMlasThreadPool(const size_t& threadNum) : threadNum(threadNum) {} size_t DegreeOfParallelism() override; void TrySimpleParallelFor(const std::ptrdiff_t total, const std::function& fn) override; + public: // the actual threads used for sgemm size_t threadNum = 0; diff --git a/src/plugins/intel_cpu/src/node.cpp b/src/plugins/intel_cpu/src/node.cpp index ee0a99c3bba44e..ddf8d068f920a2 100644 --- a/src/plugins/intel_cpu/src/node.cpp +++ b/src/plugins/intel_cpu/src/node.cpp @@ -3,39 +3,38 @@ // #include "node.h" -#include "cpu_types.h" -#include "edge.h" -#include "partitioned_mem_blk.h" -#include "openvino/core/type/element_type.hpp" +#include +#include + +#include +#include +#include #include #include -#include +#include #include -#include #include +#include +#include "cpu_types.h" +#include "dnnl_extension_utils.h" +#include "edge.h" +#include "memory_desc/cpu_memory_desc_utils.h" +#include "memory_desc/dnnl_blocked_memory_desc.h" +#include "nodes/common/cpu_convert.h" #include "nodes/conv.h" #include "nodes/eltwise.h" #include "nodes/input.h" -#include "nodes/reorder.h" #include "nodes/reference.h" -#include "dnnl_extension_utils.h" - +#include "nodes/reorder.h" +#include "openvino/core/type/element_type.hpp" +#include "partitioned_mem_blk.h" +#include "utils/cpu_utils.hpp" #include "utils/debug_capabilities.h" +#include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" #include "utils/rt_info/memory_formats_attribute.hpp" -#include - -#include -#include -#include "utils/general_utils.h" -#include "utils/cpu_utils.hpp" -#include "nodes/common/cpu_convert.h" -#include "memory_desc/cpu_memory_desc_utils.h" -#include "memory_desc/dnnl_blocked_memory_desc.h" -#include -#include using namespace dnnl; using namespace openvino; @@ -44,7 +43,7 @@ using namespace ov::intel_cpu::node; namespace ov { namespace intel_cpu { -Node::NodesFactory & Node::factory() { +Node::NodesFactory& Node::factory() { static NodesFactory factoryInstance; return factoryInstance; } @@ -63,7 +62,7 @@ Node::Node(const std::shared_ptr& op, type(TypeFromName(op->get_type_name())), profiling(op->get_friendly_name()) { for (size_t i = 0; i < op->get_input_size(); i++) { - const auto &shape = op->get_input_partial_shape(i); + const auto& shape = op->get_input_partial_shape(i); if (shape.rank().is_dynamic()) { OPENVINO_THROW("Unexpected: CPU plug-in doesn't support ", getTypeStr(), @@ -83,7 +82,7 @@ Node::Node(const std::shared_ptr& op, 
OPENVINO_THROW("Node with type '", typeStr, "' and name '", name, "' does not have any outputs."); } for (size_t i = 0; i < op->get_output_size(); i++) { - const auto &shape = op->get_output_partial_shape(i); + const auto& shape = op->get_output_partial_shape(i); if (shape.rank().is_dynamic()) { OPENVINO_THROW("Unexpected: CPU plug-in doesn't support ", getTypeStr(), @@ -99,8 +98,14 @@ Node::Node(const std::shared_ptr& op, childEdges.reserve(outputShapes.size()); } - isDynamic = std::any_of(inputShapes.begin(), inputShapes.end(), [](const Shape& shape){ return shape.isDynamic(); }) || - std::any_of(outputShapes.begin(), outputShapes.end(), [](const Shape& shape){ return shape.isDynamic(); }); + isDynamic = std::any_of(inputShapes.begin(), + inputShapes.end(), + [](const Shape& shape) { + return shape.isDynamic(); + }) || + std::any_of(outputShapes.begin(), outputShapes.end(), [](const Shape& shape) { + return shape.isDynamic(); + }); if (isDynamic) { shapeInference = shapeInferFactory.makeShapeInfer(); @@ -127,12 +132,13 @@ Node::Node(const std::shared_ptr& op, if (str.substr(0, 4) != "cpu:") continue; customImplPriorities.push_back(parse_impl_name(str)); - if (customImplPriorities.back() == impl_desc_type::unknown && - str != "cpu:unknown") + if (customImplPriorities.back() == impl_desc_type::unknown && str != "cpu:unknown") OPENVINO_THROW("Unsupported CPU implementation ", str, " for node ", getName()); } const auto& defaultImplPriorities = getDefaultImplPriority(); - customImplPriorities.insert(customImplPriorities.end(), defaultImplPriorities.begin(), defaultImplPriorities.end()); + customImplPriorities.insert(customImplPriorities.end(), + defaultImplPriorities.begin(), + defaultImplPriorities.end()); } std::string inputMemoryFormats = getInputMemoryFormats(op); @@ -199,10 +205,11 @@ void Node::addEdge(const EdgePtr& edge) { } void Node::remove() { - auto drop = [](std::vector edges){ + auto drop = [](std::vector edges) { for (auto& edge : edges) { auto edgePtr = edge.lock(); - if (!edgePtr) continue; + if (!edgePtr) + continue; edgePtr->getParent()->removeChildEdge(edgePtr); edgePtr->getChild()->removeParentEdge(edgePtr); } @@ -213,7 +220,7 @@ void Node::remove() { } bool Node::isEdgesEmpty(const std::vector& edges) const { - for (auto &edge : edges) { + for (auto& edge : edges) { if (edge.lock()) return false; } @@ -265,7 +272,8 @@ void Node::selectPreferPrimitiveDescriptor(const std::vector& pr auto parentEdge = getParentEdgeAt(j); auto parentPtr = parentEdge->getParent(); - // We don't take into account constant edges since reorders on them will be executed on load network stage + // We don't take into account constant edges since reorders on them will be executed on load network + // stage if (ignoreConstInputs && j > 0 && parentPtr->isConstant()) { equalsLocalFormatCount++; continue; @@ -286,10 +294,20 @@ void Node::selectPreferPrimitiveDescriptor(const std::vector& pr equalsLocalFormatCount++; } - DEBUG_LOG(getName(), " pd[", i, "].inConfs[", j, "]" - " is ", (isCompatible ? "compatible" : "not compatible"), - " with parent ", parentPtr->getName(), - " outConfs[", inNum, "], equalsLocalFormatCount add to ", equalsLocalFormatCount); + DEBUG_LOG(getName(), + " pd[", + i, + "].inConfs[", + j, + "]" + " is ", + (isCompatible ? 
"compatible" : "not compatible"), + " with parent ", + parentPtr->getName(), + " outConfs[", + inNum, + "], equalsLocalFormatCount add to ", + equalsLocalFormatCount); } if (equalsLocalFormatCount > equalsFormatCount) { @@ -334,7 +352,8 @@ bool Node::isReorderRequired(ov::intel_cpu::MemoryDescPtr desc1, ov::intel_cpu:: return !(isOneDimShape1 && isOneDimShape2 && samePrec); } -void Node::selectPreferPrimitiveDescriptorWithShape(const std::vector& priority, bool ignoreConstInputs) { +void Node::selectPreferPrimitiveDescriptorWithShape(const std::vector& priority, + bool ignoreConstInputs) { // Filter out dynamic shape. if (isDynamic) { return selectPreferPrimitiveDescriptor(priority, ignoreConstInputs); @@ -371,11 +390,22 @@ void Node::selectPreferPrimitiveDescriptorWithShape(const std::vectorgetShape().toPartialShape()) ? "one dim shape" : "not one dim shape"), - " with parent ", parentPtr->getName(), - " outConfs[", inNum, "], estimate add to ", estimate); + DEBUG_LOG(getName(), + " pd[", + i, + "].inConfs[", + j, + "]" + " is ", + (isCompatible ? "compatible" : "not compatible"), + " shape is ", + (isOneDimShape(curDesc->getShape().toPartialShape()) ? "one dim shape" : "not one dim shape"), + " with parent ", + parentPtr->getName(), + " outConfs[", + inNum, + "], estimate add to ", + estimate); } } return estimate; @@ -443,7 +473,7 @@ bool Node::canBeInPlace() const { } if (getParentEdges().size() != 1 || getParentEdgeAt(0)->getParent()->getChildEdges().size() != 1 || - (getParentEdgeAt(0)->getParent()->isConstant() && !getParentEdgeAt(0)->getChild()->isConstant())) + (getParentEdgeAt(0)->getParent()->isConstant() && !getParentEdgeAt(0)->getChild()->isConstant())) return false; // TODO: we need to extend this logic to properly handle all possible inplace conflicts @@ -463,7 +493,7 @@ bool Node::canBeInPlace() const { } void Node::resolveInPlaceEdges(Edge::LOOK look) { - const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor(); + const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor(); if (!selected_pd) OPENVINO_THROW("Cannot find selected primitive descriptor for node: ", getName()); if (look & Edge::LOOK_DOWN) { @@ -478,16 +508,19 @@ void Node::resolveInPlaceEdges(Edge::LOOK look) { " Unexpected inplace resolve call to an allocated edge: ", *parentEdge); - //search for already allocated edge + // search for already allocated edge const auto& childEdges = getChildEdgesAtPort(inplaceOutIndx); - auto itr = std::find_if(childEdges.begin(), childEdges.end(), [](const EdgePtr& edge) { return edge->getStatus() == Edge::Status::Allocated; }); + auto itr = std::find_if(childEdges.begin(), childEdges.end(), [](const EdgePtr& edge) { + return edge->getStatus() == Edge::Status::Allocated; + }); OPENVINO_ASSERT(itr != childEdges.end(), " Could not find an allocated edge to resolve in-place for node: ", getName()); auto baseMemBlock = (*itr)->getMemory().getMemoryBlock(); auto memBlock = std::make_shared(baseMemBlock); - auto newMem = std::make_shared(getEngine(), selected_pd->getConfig().inConfs[i].getMemDesc(), memBlock); + auto newMem = + std::make_shared(getEngine(), selected_pd->getConfig().inConfs[i].getMemDesc(), memBlock); parentEdge->reuse(newMem); } } @@ -506,7 +539,8 @@ void Node::resolveInPlaceEdges(Edge::LOOK look) { OPENVINO_ASSERT(childEdge->getStatus() == Edge::Status::NotAllocated, " Unexpected inplace resolve call to an allocated edge: ", *childEdge); - auto newMem = std::make_shared(getEngine(), selected_pd->getConfig().outConfs[i].getMemDesc(), memBlock); + auto 
newMem = + std::make_shared(getEngine(), selected_pd->getConfig().outConfs[i].getMemDesc(), memBlock); childEdge->reuse(newMem); } } @@ -566,9 +600,9 @@ std::string Node::getPrimitiveDescriptorType() const { str_type += t; }; -#define SEARCH_TYPE(_type) \ - if ((type & impl_desc_type::_type) == impl_desc_type::_type) \ - add_type(#_type) +#define SEARCH_TYPE(_type) \ + if ((type & impl_desc_type::_type) == impl_desc_type::_type) \ + add_type(#_type) SEARCH_TYPE(undef); SEARCH_TYPE(reorder); @@ -609,13 +643,19 @@ std::string Node::getPrimitiveDescriptorType() const { if (selectedPrimitiveDesc) { if (!selectedPrimitiveDesc->getConfig().inConfs.empty()) { if (selectedPrimitiveDesc->getConfig().inConfs[0].getMemDesc()->getPrecision() != ov::element::u8) { - str_type += "_" + std::string(selectedPrimitiveDesc->getConfig().inConfs[0].getMemDesc()->getPrecision().get_type_name()); + str_type += + "_" + + std::string( + selectedPrimitiveDesc->getConfig().inConfs[0].getMemDesc()->getPrecision().get_type_name()); } else { str_type += "_I8"; } } else { if (selectedPrimitiveDesc->getConfig().outConfs[0].getMemDesc()->getPrecision() != ov::element::u8) { - str_type += "_" + std::string(selectedPrimitiveDesc->getConfig().outConfs[0].getMemDesc()->getPrecision().get_type_name()); + str_type += + "_" + + std::string( + selectedPrimitiveDesc->getConfig().outConfs[0].getMemDesc()->getPrecision().get_type_name()); } else { str_type += "_I8"; } @@ -651,7 +691,7 @@ std::vector Node::getChildEdgesAtPort(int inputNum) const { OPENVINO_THROW("Node ", getName(), " contains less output ports than ", inputNum); std::vector res; - for (auto &edge_w : childEdges) { + for (auto& edge_w : childEdges) { auto edge = edge_w.lock(); if (!edge) OPENVINO_THROW("Node ", getName(), " contains dead weak ptr"); @@ -661,7 +701,7 @@ std::vector Node::getChildEdgesAtPort(int inputNum) const { return res; } -std::vector Node::getAvailableFormatsForDims(const Shape &dims) const { +std::vector Node::getAvailableFormatsForDims(const Shape& dims) const { if (dims.getRank() == 0) return {memory::format_tag::x}; else if (dims.getRank() == 1) @@ -669,8 +709,11 @@ std::vector Node::getAvailableFormatsForDims(const Shape &di else if (dims.getRank() == 2) return {memory::format_tag::nc}; else if (dims.getRank() == 3) - return {memory::format_tag::tnc, memory::format_tag::ntc, - memory::format_tag::ncw, memory::format_tag::nCw8c, memory::format_tag::nCw16c }; + return {memory::format_tag::tnc, + memory::format_tag::ntc, + memory::format_tag::ncw, + memory::format_tag::nCw8c, + memory::format_tag::nCw16c}; else if (dims.getRank() == 4) return {memory::format_tag::nchw, memory::format_tag::nChw8c, memory::format_tag::nChw16c}; else if (dims.getRank() == 5) @@ -695,36 +738,36 @@ void Node::updateShapes() { getTypeStr(), " with name: ", getName()); - try { - if (needShapeInfer()) { - auto result = shapeInfer(); - if (ShapeInferStatus::success == result.status) { - redefineOutputMemory(result.dims); + try { + if (needShapeInfer()) { + auto result = shapeInfer(); + if (ShapeInferStatus::success == result.status) { + redefineOutputMemory(result.dims); + } + } else { + // guard check for internal dynamic nodes to avoid possible overestimation of the required memory size + if (shapeInference && FULL_PORT_MASK == shapeInference->get_port_mask()) + return; + + for (auto&& edge : getChildEdges()) { + auto edge_ptr = edge.lock(); + CPU_NODE_ASSERT(edge_ptr, " has null edge"); + if (edge_ptr->inPlace(Edge::LOOK_UP)) { + continue; } - } else { - //guard check 
for internal dynamic nodes to avoid possible overestimation of the required memory size - if (shapeInference && FULL_PORT_MASK == shapeInference->get_port_mask()) - return; - - for (auto&& edge : getChildEdges()) { - auto edge_ptr = edge.lock(); - CPU_NODE_ASSERT(edge_ptr, " has null edge"); - if (edge_ptr->inPlace(Edge::LOOK_UP)) { - continue; - } - auto mem = edge_ptr->getMemoryPtr(); - CPU_NODE_ASSERT(mem, " has null output memory"); + auto mem = edge_ptr->getMemoryPtr(); + CPU_NODE_ASSERT(mem, " has null output memory"); - if (mem->getShape().hasZeroDims()) { - continue; - } - fetchRawMemory(mem); + if (mem->getShape().hasZeroDims()) { + continue; } + fetchRawMemory(mem); } - } catch (const std::exception& exp) { - THROW_CPU_NODE_ERR(exp.what()); } + } catch (const std::exception& exp) { + THROW_CPU_NODE_ERR(exp.what()); + } } void Node::updateDynamicParams() { @@ -736,10 +779,17 @@ void Node::updateDynamicParams() { try { if (isExecutable()) { if (needPrepareParams()) { - OPENVINO_ASSERT(inputShapesDefined(), - "Input shapes are not defined."); - DEBUG_LOG(" prepareParams() on #", getExecIndex(), " ", getTypeStr(), " ", algToString(getAlgorithm()), - " ", getName(), " ", getOriginalLayers()); + OPENVINO_ASSERT(inputShapesDefined(), "Input shapes are not defined."); + DEBUG_LOG(" prepareParams() on #", + getExecIndex(), + " ", + getTypeStr(), + " ", + algToString(getAlgorithm()), + " ", + getName(), + " ", + getOriginalLayers()); prepareParams(); } } @@ -782,7 +832,7 @@ bool Node::outputShapeDataDependency() const { return false; } -void Node::redefineOutputMemory(const std::vector &newOutputShapes) { +void Node::redefineOutputMemory(const std::vector& newOutputShapes) { if (newOutputShapes.size() != outputShapes.size()) { OPENVINO_THROW("Number shapes mismatch with real outputs number for node with name: ", getName()); } @@ -841,34 +891,45 @@ void Node::initSupportedPrimitiveDescriptors() { }; /* When custom implementation priorities are NOT defined it is enough to - * just use the first implementation from the priority list. - * When custom implementation priorities are defined, all the implementations should be considered, - * since custom implementations can be not available at all, so a fallback to the default ones must happen - * To achive the fallback, it is necessary to create a supported primitive descriptor for each implementation - * since oneDNN primitive is mutating while iterating */ + * just use the first implementation from the priority list. 
+ * When custom implementation priorities are defined, all the implementations should be considered, + * since custom implementations may not be available at all, so a fallback to the default ones must happen. + * To achieve the fallback, it is necessary to create a supported primitive descriptor for each implementation + * since oneDNN primitive is mutating while iterating */ #ifdef CPU_DEBUG_CAPS { - if (!customImplPriorities.empty()) { - DEBUG_LOG("#", getName(), " customImplPriorities [", 0 , "/", customImplPriorities.size(), - "]: ", impl_type_to_string(customImplPriorities[0])); - } + if (!customImplPriorities.empty()) { + DEBUG_LOG("#", + getName(), + " customImplPriorities [", + 0, + "/", + customImplPriorities.size(), + "]: ", + impl_type_to_string(customImplPriorities[0])); + } } #endif for (auto& desc : descs) { auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get())); const bool first_match = customImplPriorities.empty(); - DEBUG_LOG("#", getName(), - ", itpd.impl_info_str(): ", desc.impl_info_str(), - ", parsed imp_type: ", impl_type_to_string(parse_impl_name(desc.impl_info_str())), - ", first_match: ", first_match ? "true" : "false"); - DnnlExtensionUtils::for_each_implementation(desc, - first_match, - [&](impl_desc_type implType) { - return contains(getImplPriority(), implType); - }, - [&](dnnl::primitive_desc& desc) { - addSupportedPrimitiveDescriptor(desc); - }); + DEBUG_LOG("#", + getName(), + ", itpd.impl_info_str(): ", + desc.impl_info_str(), + ", parsed imp_type: ", + impl_type_to_string(parse_impl_name(desc.impl_info_str())), + ", first_match: ", + first_match ? "true" : "false"); + DnnlExtensionUtils::for_each_implementation( + desc, + first_match, + [&](impl_desc_type implType) { + return contains(getImplPriority(), implType); + }, + [&](dnnl::primitive_desc& desc) { + addSupportedPrimitiveDescriptor(desc); + }); // fallback. if none of the primitive types is present in the priority list just add first implementation // @todo this fallback is not necessary if primitive priority list is filled correctly @@ -889,22 +950,29 @@ void Node::filterSupportedPrimitiveDescriptors() { }; auto isNotSuitableDesc = [&](const NodeDesc& desc) { - const auto &config = desc.getConfig(); - if (inputMemoryFormatsFilter.size() > config.inConfs.size() || outputMemoryFormatsFilter.size() > config.outConfs.size()) + const auto& config = desc.getConfig(); + if (inputMemoryFormatsFilter.size() > config.inConfs.size() || + outputMemoryFormatsFilter.size() > config.outConfs.size()) OPENVINO_THROW("Incorrect number of input or output memory formats"); for (size_t i = 0; i < inputMemoryFormatsFilter.size(); i++) { if (!areCompatible(*config.inConfs[i].getMemDesc(), inputMemoryFormatsFilter[i])) { - DEBUG_LOG(getName(), " input memory format filter: ", inputMemoryFormatsFilter[i], - " not matched. Erase desc from supported primitive descriptors: ", desc); + DEBUG_LOG(getName(), + " input memory format filter: ", + inputMemoryFormatsFilter[i], + " not matched. Erase desc from supported primitive descriptors: ", + desc); return true; } } for (size_t i = 0; i < outputMemoryFormatsFilter.size(); i++) { if (!areCompatible(*config.outConfs[i].getMemDesc(), outputMemoryFormatsFilter[i])) { - DEBUG_LOG(getName(), " Output memory format filter: ", outputMemoryFormatsFilter[i], - " not matched. Erase desc from supported primitive descriptors: ", desc); + DEBUG_LOG(getName(), + " Output memory format filter: ", + outputMemoryFormatsFilter[i], + " not matched. 
Erase desc from supported primitive descriptors: ", + desc); return true; } } @@ -932,7 +1000,8 @@ void Node::initDescriptor(const NodeConfig& config) { if (descs.empty()) { const auto& selectedConfig = selectedPD->getConfig(); - if (selectedConfig.inConfs.size() != config.inConfs.size() || selectedConfig.outConfs.size() != config.outConfs.size()) + if (selectedConfig.inConfs.size() != config.inConfs.size() || + selectedConfig.outConfs.size() != config.outConfs.size()) return; for (size_t i = 0; i < selectedConfig.inConfs.size(); i++) { @@ -949,19 +1018,19 @@ void Node::initDescriptor(const NodeConfig& config) { return; } - auto updateNodeConfig = [&](const NodeConfig& cfg){ + auto updateNodeConfig = [&](const NodeConfig& cfg) { auto updatedConfig = cfg; for (size_t i = 0; i < descInputNumbers(); i++) { PortConfig& dataConfig = updatedConfig.inConfs[i]; - dataConfig.inPlace(canBeInPlace() ? 0 : -1); // update inPlace - dataConfig.setMemDesc(dataConfig.getMemDesc()); // reset desc with default compatibility mask + dataConfig.inPlace(canBeInPlace() ? 0 : -1); // update inPlace + dataConfig.setMemDesc(dataConfig.getMemDesc()); // reset desc with default compatibility mask } for (size_t i = 0; i < descOutputNumbers(); i++) { PortConfig& dataConfig = updatedConfig.outConfs[i]; - dataConfig.inPlace(-1); // update inPlace - dataConfig.setMemDesc(dataConfig.getMemDesc()); // reset desc with default compatibility mask + dataConfig.inPlace(-1); // update inPlace + dataConfig.setMemDesc(dataConfig.getMemDesc()); // reset desc with default compatibility mask } return updatedConfig; @@ -1017,8 +1086,8 @@ void Node::prepareMemory(const DnnlMemoryDescPtr& intDesc, size_t indx) { MemoryPtr ptr; auto weightCache = context->getWeightsCache(); if (weightCache != nullptr && memory::format_kind::blocked == intDesc->getDnnlDesc().get_format_kind()) { - const auto string_hash = - name + "_" + std::to_string(indx) + "_" + DnnlExtensionUtils::computeWeightsStringHash(internalBlob, intDesc); + const auto string_hash = name + "_" + std::to_string(indx) + "_" + + DnnlExtensionUtils::computeWeightsStringHash(internalBlob, intDesc); ptr = *weightCache->findOrCreate(string_hash, create); } else { ptr = create(); @@ -1043,7 +1112,7 @@ void Node::prepareMemory(const std::vector& intDescs) { void Node::prepareMemory(dnnl::primitive_desc_iterator& itpd) { std::vector intDescs; - for (auto &it : internalBlobDesc) + for (auto& it : internalBlobDesc) intDescs.push_back(it(itpd, 0)); Node::prepareMemory(intDescs); @@ -1063,8 +1132,8 @@ MemoryPtr Node::prepareWeightMemory(DnnlMemoryDescPtr dstWeightDesc, DnnlMemoryD srcWeightDesc = DnnlExtensionUtils::makeDescriptor(weightSrcDesc); } - auto create = [&] () { - Memory srcMemory{ getEngine(), srcWeightDesc, edgeMem->getData() }; + auto create = [&]() { + Memory srcMemory{getEngine(), srcWeightDesc, edgeMem->getData()}; MemoryPtr _ptr = std::make_shared(getEngine(), dstWeightDesc); node::Reorder::reorderData(srcMemory, *_ptr, context->getParamsCache()); @@ -1128,13 +1197,13 @@ bool Node::isInPlace() const { inplace = InPlaceType::NoInPlace; auto config = selected_pd->getConfig(); - for (auto &in : config.inConfs) { + for (auto& in : config.inConfs) { if (in.inPlace() >= 0) { inplace = InPlaceType::InPlace; break; } } - for (auto &out : config.outConfs) { + for (auto& out : config.outConfs) { if (out.inPlace() >= 0) { inplace = InPlaceType::InPlace; break; @@ -1165,7 +1234,7 @@ void Node::updateConstantType() { const auto prevConstantType = constant; constant = isConst ? 
ConstantType::Const : ConstantType::NoConst; if (constant == prevConstantType) - return; // state has not changed, no reason to continue + return; // state has not changed, no reason to continue for (const auto& childEdge : getChildEdges()) { const auto childNode = childEdge.lock()->getChild(); @@ -1174,7 +1243,8 @@ void Node::updateConstantType() { } void Node::addOriginalLayer(const std::string& layerName) { - if (layerName.empty()) return; + if (layerName.empty()) + return; if (originalLayers.empty()) { originalLayers = layerName; } else { @@ -1197,46 +1267,25 @@ void Node::cleanup() { const std::vector& Node::getDefaultImplPriority() { static const std::vector priorities { impl_desc_type::unknown, - // Undef impl type is used to express use-cases there real type is unkown during compilation - // Undef has higher priority than defined types in order to force primitive selection logic to make decision based on other properties - impl_desc_type::undef, - impl_desc_type::brgconv_avx512_amx_1x1, - impl_desc_type::brgconv_avx512_amx, - impl_desc_type::jit_avx512_amx_dw, - impl_desc_type::jit_avx512_amx_1x1, - impl_desc_type::jit_avx512_amx, - // Brgconv kernels disabled in order to prevent perf degradations on non AMX HW - // impl_desc_type::brgconv_avx512_1x1, - // impl_desc_type::brgconv_avx512, - impl_desc_type::jit_uni_dw, - impl_desc_type::jit_uni_1x1, - impl_desc_type::jit_uni, - impl_desc_type::jit_avx512_dw, - impl_desc_type::jit_avx512_1x1, - impl_desc_type::jit_avx512, - impl_desc_type::jit_avx2_dw, - impl_desc_type::jit_avx2_1x1, - impl_desc_type::jit_avx2, - impl_desc_type::jit_avx_dw, - impl_desc_type::jit_avx_1x1, - impl_desc_type::jit_avx, - impl_desc_type::jit_sse42_dw, - impl_desc_type::jit_sse42_1x1, - impl_desc_type::jit_sse42, + // Undef impl type is used to express use-cases where the real type is unknown during compilation + // Undef has higher priority than defined types in order to force primitive selection logic to make a decision + // based on other properties + impl_desc_type::undef, impl_desc_type::brgconv_avx512_amx_1x1, impl_desc_type::brgconv_avx512_amx, + impl_desc_type::jit_avx512_amx_dw, impl_desc_type::jit_avx512_amx_1x1, impl_desc_type::jit_avx512_amx, + // Brgconv kernels disabled in order to prevent perf degradations on non AMX HW + // impl_desc_type::brgconv_avx512_1x1, + // impl_desc_type::brgconv_avx512, + impl_desc_type::jit_uni_dw, impl_desc_type::jit_uni_1x1, impl_desc_type::jit_uni, + impl_desc_type::jit_avx512_dw, impl_desc_type::jit_avx512_1x1, impl_desc_type::jit_avx512, + impl_desc_type::jit_avx2_dw, impl_desc_type::jit_avx2_1x1, impl_desc_type::jit_avx2, + impl_desc_type::jit_avx_dw, impl_desc_type::jit_avx_1x1, impl_desc_type::jit_avx, + impl_desc_type::jit_sse42_dw, impl_desc_type::jit_sse42_1x1, impl_desc_type::jit_sse42, #if defined(OPENVINO_ARCH_ARM64) - impl_desc_type::jit_asimd, + impl_desc_type::jit_asimd, #endif - impl_desc_type::gemm_any, - impl_desc_type::gemm_blas, - impl_desc_type::gemm_avx512, - impl_desc_type::gemm_avx2, - impl_desc_type::gemm_avx, - impl_desc_type::gemm_sse42, - impl_desc_type::gemm_acl, - impl_desc_type::acl, - impl_desc_type::jit_gemm, - impl_desc_type::ref_any, - impl_desc_type::ref, + impl_desc_type::gemm_any, impl_desc_type::gemm_blas, impl_desc_type::gemm_avx512, impl_desc_type::gemm_avx2, + impl_desc_type::gemm_avx, impl_desc_type::gemm_sse42, impl_desc_type::gemm_acl, impl_desc_type::acl, + impl_desc_type::jit_gemm, impl_desc_type::ref_any, impl_desc_type::ref, }; return priorities;
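The hunk above only reflows the priority table, but that table is the heart of implementation selection: custom cpu:* priorities, when present, are prepended to this default list, and the first entry a given oneDNN descriptor supports wins. A minimal sketch of that selection rule, using simplified stand-ins rather than the plugin's impl_desc_type machinery (all names here are illustrative):

#include <algorithm>
#include <string>
#include <vector>

// Simplified stand-in for impl_desc_type.
using ImplType = std::string;

// Custom priorities (possibly empty) are tried first, then the defaults,
// mirroring how customImplPriorities is extended with getDefaultImplPriority().
ImplType selectImplementation(const std::vector<ImplType>& custom,
                              const std::vector<ImplType>& defaults,
                              const std::vector<ImplType>& available) {
    std::vector<ImplType> priority(custom);
    priority.insert(priority.end(), defaults.begin(), defaults.end());
    for (const auto& impl : priority) {
        if (std::find(available.begin(), available.end(), impl) != available.end())
            return impl;  // first supported entry in priority order wins
    }
    // Fallback, analogous to "just add first implementation" above.
    return available.empty() ? ImplType{"none"} : available.front();
}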
@@ -1246,30 +1295,31 @@ const std::vector& Node::getImplPriority() { if (!customImplPriorities.empty()) return customImplPriorities; - return getDefaultImplPriority(); } -PortDescBasePtr Node::getConsistentInputDesc(const NodeConfig &config, size_t idx) const { +PortDescBasePtr Node::getConsistentInputDesc(const NodeConfig& config, size_t idx) const { const auto& inConf = config.inConfs[idx]; - if (inConf.inPlace() >= 0) { // node have inplace input + if (inConf.inPlace() >= 0) { // node has an in-place input auto inplaceIndx = static_cast(inConf.inPlace()); PortDescBasePtr outPortDesc; const auto& outConf = config.outConfs[inplaceIndx]; - if (outConf.inPlace() == static_cast(idx)) { // the input desc port is the same port used for inplace output - outPortDesc = outConf.getPortDesc(); // just use desc from this output port + if (outConf.inPlace() == + static_cast(idx)) { // the input desc port is the same port used for inplace output + outPortDesc = outConf.getPortDesc(); // just use desc from this output port } else { - outPortDesc = getConsistentOutputDesc(config, inplaceIndx); // get consistent desc otherwise + outPortDesc = getConsistentOutputDesc(config, inplaceIndx); // get consistent desc otherwise } - if (inConf.getPortDesc()->isCompatible(*outPortDesc)) { // use the desc if compatible + if (inConf.getPortDesc()->isCompatible(*outPortDesc)) { // use the desc if compatible return outPortDesc; } } - auto *parentSelectedPD = getParentEdgeAt(idx)->getParent()->getSelectedPrimitiveDescriptor(); + auto* parentSelectedPD = getParentEdgeAt(idx)->getParent()->getSelectedPrimitiveDescriptor(); if (!parentSelectedPD) - OPENVINO_THROW("Cannot get selected primitive descriptor for node: ", getParentEdgeAt(idx)->getParent()->getName()); + OPENVINO_THROW("Cannot get selected primitive descriptor for node: ", + getParentEdgeAt(idx)->getParent()->getName()); int num = getParentEdgeAt(idx)->getInputNum(); if (num >= 0) { @@ -1290,26 +1340,28 @@ PortDescBasePtr Node::getConsistentInputDesc(const NodeConfig &config, size_t id return inConf.getPortDesc(); } -PortDescBasePtr Node::getConsistentOutputDesc(const NodeConfig &config, size_t idx) const { +PortDescBasePtr Node::getConsistentOutputDesc(const NodeConfig& config, size_t idx) const { const auto& outConf = config.outConfs[idx]; - if (outConf.inPlace() >= 0) { // node have inplace output + if (outConf.inPlace() >= 0) { // node has an in-place output auto inplaceIndx = static_cast(outConf.inPlace()); PortDescBasePtr inpPortDesc; const auto& inpConf = config.inConfs[inplaceIndx]; - if (inpConf.inPlace() == static_cast(idx)) { // the input desc port is the same port used for inplace output - inpPortDesc = inpConf.getPortDesc(); // just use desc from this output port + if (inpConf.inPlace() == + static_cast(idx)) { // the input desc port is the same port used for inplace output + inpPortDesc = inpConf.getPortDesc(); // just use desc from this output port } else { - inpPortDesc = getConsistentInputDesc(config, inplaceIndx); // get consistent desc otherwise + inpPortDesc = getConsistentInputDesc(config, inplaceIndx); // get consistent desc otherwise } - if (outConf.getPortDesc()->isCompatible(*inpPortDesc)) { // use the desc if compatible + if (outConf.getPortDesc()->isCompatible(*inpPortDesc)) { // use the desc if compatible return inpPortDesc; } } - auto *childSelectedPD = getChildEdgeAt(idx)->getChild()->getSelectedPrimitiveDescriptor(); + auto* childSelectedPD = getChildEdgeAt(idx)->getChild()->getSelectedPrimitiveDescriptor(); if (!childSelectedPD) - OPENVINO_THROW("Cannot get 
selected primitive descriptor for node: ", getChildEdgeAt(idx)->getChild()->getName()); + OPENVINO_THROW("Cannot get selected primitive descriptor for node: ", + getChildEdgeAt(idx)->getChild()->getName()); int num = getChildEdgeAt(idx)->getOutputNum(); if (num >= 0) { @@ -1331,7 +1383,7 @@ PortDescBasePtr Node::getConsistentOutputDesc(const NodeConfig &config, size_t i } void Node::initOptimalPrimitiveDescriptor() { - if (one_of(getType(), Type::RNNCell, Type::RNNSeq)) // can be skipped for RNN node + if (one_of(getType(), Type::RNNCell, Type::RNNSeq)) // can be skipped for RNN node return; auto selected_pd = getSelectedPrimitiveDescriptor(); @@ -1358,7 +1410,8 @@ void Node::initOptimalPrimitiveDescriptor() { // it is assumed that the nodes will define dense tensors on output edges // if it is not the case the implementation must redefine this behaviour if (outMemDesc->getType() & Blocked) { - config.outConfs[i].setMemDesc(std::dynamic_pointer_cast(outMemDesc), BlockedMemoryDesc::FULL_MASK); + config.outConfs[i].setMemDesc(std::dynamic_pointer_cast(outMemDesc), + BlockedMemoryDesc::FULL_MASK); } } } @@ -1366,9 +1419,9 @@ void Node::initOptimalPrimitiveDescriptor() { initDescriptor(config); } -bool Node::isConfigDefined(const NodeConfig &config) const { +bool Node::isConfigDefined(const NodeConfig& config) const { for (const auto& configs : {config.inConfs, config.outConfs}) { - for (const auto &dc : configs) { + for (const auto& dc : configs) { if (!dc.getMemDesc()->isDefined()) return false; } @@ -1376,14 +1429,14 @@ bool Node::isConfigDefined(const NodeConfig &config) const { return true; } -MemoryDescPtr Node::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const { +MemoryDescPtr Node::getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const { if (getInputShapeAtPort(idx).isDynamic()) { return DnnlExtensionUtils::makeUndefinedDesc(prim_desc.src_desc(idx), getInputShapeAtPort(idx)); } return DnnlExtensionUtils::makeDescriptor(prim_desc.src_desc(idx)); } -MemoryDescPtr Node::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const { +MemoryDescPtr Node::getDstMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const { if (getOutputShapeAtPort(idx).isDynamic()) { return DnnlExtensionUtils::makeUndefinedDesc(prim_desc.dst_desc(idx), getOutputShapeAtPort(idx)); } @@ -1393,7 +1446,7 @@ MemoryDescPtr Node::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t void Node::appendPostOpArgs(const dnnl::primitive_attr& attr, std::unordered_map& primArgs, const std::unordered_map& postOpsArgs) { - for (auto & entry : postOpsArgs) { + for (auto& entry : postOpsArgs) { primArgs[entry.first] = entry.second->getPrimitive(); } } @@ -1426,11 +1479,17 @@ dnnl::memory::format_tag Node::getWeightsFormatTagByDims(const VectorDims& dims) } } -void Node::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::unordered_map& postOpsMem, const int channelAxis) { +void Node::appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::unordered_map& postOpsMem, + const int channelAxis) { OPENVINO_THROW("Fusing of ", NameFromType(this->getType()), " operation is not implemented"); } -void Node::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem, const int channelAxis) { +void Node::appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis) { OPENVINO_THROW("Fusing of ", NameFromType(this->getType()), " operation is not implemented"); 
} @@ -1474,12 +1533,12 @@ ov::element::Type Node::getRuntimePrecision() const { } Node* Node::NodesFactory::create(const std::shared_ptr& op, const GraphContext::CPtr context) { - // getExceptionDescWithoutStatus removes redundant information from the exception message. For instance, the NotImplemented - // exception is generated in the form: full_path_to_src_file:line_number [ NOT_IMPLEMENTED ] reason. + // getExceptionDescWithoutStatus removes redundant information from the exception message. For instance, the + // NotImplemented exception is generated in the form: full_path_to_src_file:line_number [ NOT_IMPLEMENTED ] reason. // An example for gather node: - // /path-to-openVino-root/src/plugins/intel_cpu/nodes/gather.cpp:42 [ NOT_IMPLEMENTED ] Only opset7 Gather operation is supported - // The most important part of the message is the reason, so the lambda trims everything up to "]" - // Note that the op type and its friendly name will also be provided if we fail to create the node. + // /path-to-openVino-root/src/plugins/intel_cpu/nodes/gather.cpp:42 [ NOT_IMPLEMENTED ] Only opset7 Gather operation + // is supported. The most important part of the message is the reason, so the lambda trims everything up to "]". Note + // that the op type and its friendly name will also be provided if we fail to create the node. auto getExceptionDescWithoutStatus = [](const ov::Exception& ex) { std::string desc = ex.what(); size_t pos = desc.find(']'); @@ -1492,7 +1551,7 @@ Node* Node::NodesFactory::create(const std::shared_ptr& op, const Grap } return desc; }; - Node *newNode = nullptr; + Node* newNode = nullptr; std::string errorMessage; if (newNode == nullptr) { try { @@ -1539,7 +1598,7 @@ Node* Node::NodesFactory::create(const std::shared_ptr& op, const Grap return newNode; } -bool Node::canBePerformedAsScaleShift(const Node *parentNode) const { +bool Node::canBePerformedAsScaleShift(const Node* parentNode) const { #if defined(OPENVINO_ARCH_X86_64) OPENVINO_ASSERT(parentNode); @@ -1547,7 +1606,7 @@ bool Node::canBePerformedAsScaleShift(const Node *parentNode) const { const auto channelAxis = parentNode->getFusingAxis(); for (size_t i = 0; i < getParentEdges().size(); i++) { - Node *node = getParentEdgeAt(i)->getParent().get(); + Node* node = getParentEdgeAt(i)->getParent().get(); if (node == nullptr) { OPENVINO_THROW("Cannot get parent node for ", getName(), " on ", i, " port"); } @@ -1575,7 +1634,7 @@ bool Node::canBePerformedAsScaleShift(const Node *parentNode) const { const auto isConvertablePowerStatic = [&]() { if (getAlgorithm() == Algorithm::EltwisePowerStatic) { - const auto eltwise = dynamic_cast(this); + const auto eltwise = dynamic_cast(this); if (!eltwise) { OPENVINO_THROW("Cannot cast ", getName(), " to Eltwise"); } @@ -1584,13 +1643,15 @@ bool Node::canBePerformedAsScaleShift(const Node *parentNode) const { return false; }; - return (one_of(getAlgorithm(), Algorithm::EltwiseAdd, - Algorithm::EltwiseMultiply, - Algorithm::EltwiseSubtract, - Algorithm::EltwiseDivide, - Algorithm::EltwisePrelu, - Algorithm::EltwiseMulAdd) && isBroadcastableToDataInput()) - || isConvertablePowerStatic(); + return (one_of(getAlgorithm(), + Algorithm::EltwiseAdd, + Algorithm::EltwiseMultiply, + Algorithm::EltwiseSubtract, + Algorithm::EltwiseDivide, + Algorithm::EltwisePrelu, + Algorithm::EltwiseMulAdd) && + isBroadcastableToDataInput()) || + isConvertablePowerStatic(); #else // TODO: provide correct list of operations for other backends return false;
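canBePerformedAsScaleShift above, and getScalesAndShifts in the hunks below, rely on a simple algebraic folding: a Subtract is an Add with negated shifts and a Divide is a Multiply with reciprocal scales, so all four eltwise flavours map onto one scale-shift form y = x * scale + shift. A compact sketch of that normalization, using plain std::vector stand-ins rather than the node machinery (the enum is a simplified placeholder):

#include <vector>

enum class Alg { Add, Subtract, Multiply, Divide };

// Fold the four eltwise flavours into (scales, shifts) so that y = x * scale + shift,
// mirroring the switch in getScalesAndShifts: Subtract negates shifts, Divide inverts scales.
void normalizeScaleShift(Alg alg, std::vector<float>& scales, std::vector<float>& shifts) {
    switch (alg) {
    case Alg::Add:
        scales.assign(shifts.size(), 1.0f);
        break;
    case Alg::Subtract:
        scales.assign(shifts.size(), 1.0f);
        for (auto& s : shifts) s = -s;  // x - c == x * 1 + (-c)
        break;
    case Alg::Multiply:
        shifts.assign(scales.size(), 0.0f);
        break;
    case Alg::Divide:
        shifts.assign(scales.size(), 0.0f);
        for (auto& s : scales) s = 1.0f / s;  // x / c == x * (1 / c) + 0
        break;
    }
}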
@@ -1600,11 +1661,11 @@ bool Node::canBePerformedAsScaleShift(const Node *parentNode) const { // @todo shifts for Subtract and scales for Divide are replaced with // Add (with opposite sign) and Multiply (with inverse value) for legacy depthwise post ops // This can be avoided after depthwise post ops are gone -std::pair, std::vector> Node::getScalesAndShifts(const Node *parentNode) const { +std::pair, std::vector> Node::getScalesAndShifts(const Node* parentNode) const { std::vector scales, shifts; const auto fillValuesFrom = [&](const NodePtr& constInput, std::vector& buffer) { - auto *constInputNode = dynamic_cast(constInput.get()); + auto* constInputNode = dynamic_cast(constInput.get()); if (!constInputNode) { OPENVINO_THROW("Cannot cast ", constInput->getName(), " to Input"); } @@ -1628,7 +1689,7 @@ std::pair, std::vector> Node::getScalesAndShifts(const fillValuesFrom(getParentEdgeAt(1)->getParent(), scales); fillValuesFrom(getParentEdgeAt(2)->getParent(), shifts); } else if (one_of(getAlgorithm(), Algorithm::EltwisePowerStatic)) { - const auto power = dynamic_cast(this); + const auto power = dynamic_cast(this); if (!power) { OPENVINO_THROW("Cannot cast ", getName(), " to Eltwise"); } @@ -1639,25 +1700,30 @@ std::pair, std::vector> Node::getScalesAndShifts(const } switch (getAlgorithm()) { - case Algorithm::EltwiseAdd: { - scales.resize(shifts.size(), 1.0f); - break; - } - case Algorithm::EltwiseSubtract: { - scales.resize(shifts.size(), 1.0f); - std::transform(shifts.begin(), shifts.end(), shifts.begin(), [](float shift){ return -1.0f * shift; }); - break; - } - case Algorithm::EltwiseMultiply: { - shifts.resize(scales.size(), 0.0f); - break; - } - case Algorithm::EltwiseDivide: { - shifts.resize(scales.size(), 0.0f); - std::transform(scales.begin(), scales.end(), scales.begin(), [](float scale){ return 1.0f / scale; }); - break; - } - default: break; + case Algorithm::EltwiseAdd: { + scales.resize(shifts.size(), 1.0f); + break; + } + case Algorithm::EltwiseSubtract: { + scales.resize(shifts.size(), 1.0f); + std::transform(shifts.begin(), shifts.end(), shifts.begin(), [](float shift) { + return -1.0f * shift; + }); + break; + } + case Algorithm::EltwiseMultiply: { + shifts.resize(scales.size(), 0.0f); + break; + } + case Algorithm::EltwiseDivide: { + shifts.resize(scales.size(), 0.0f); + std::transform(scales.begin(), scales.end(), scales.begin(), [](float scale) { + return 1.0f / scale; + }); + break; + } + default: + break; } return {scales, shifts}; @@ -1824,22 +1890,25 @@ bool Node::canFuseSimpleOperation(const NodePtr& node) const { return ret; } else if (node->getType() == Type::Eltwise) { return DnnlExtensionUtils::isUnarySupportedAsPostOp(node->getAlgorithm()) || - node->canBePerformedAsScaleShift(this); + node->canBePerformedAsScaleShift(this); } return false; } -void Node::addFusedNode(const NodePtr &fusingNode) { +void Node::addFusedNode(const NodePtr& fusingNode) { fusedWith.push_back(fusingNode); } void Node::addSupportedPrimDesc(const std::vector& inPortConfigs, const std::vector& outPortConfigs, impl_desc_type implType) { - auto fill_port = [] (const PortConfigurator& portConfigurator, const Shape& shape, - ov::element::Type prc, std::vector& port) -> bool { - // In order to simplify particular node initialization logic we just don't add config in case target shape is not supported by blockedDescCreator. - // This should be suitable for major of scenarios since almost all nodes add `ncsp` blockedDescCreator which supports any shape rank.
+ auto fill_port = [](const PortConfigurator& portConfigurator, + const Shape& shape, + ov::element::Type prc, + std::vector& port) -> bool { + // In order to simplify particular node initialization logic we just don't add config in case target shape is + // not supported by blockedDescCreator. This should be suitable for the majority of scenarios since almost all nodes + // add `ncsp` blockedDescCreator which supports any shape rank. if (shape.getRank() < portConfigurator.blockedDescCreator->getMinimalRank()) return false; @@ -1856,14 +1925,16 @@ void Node::addSupportedPrimDesc(const std::vector& inPortConfi NodeConfig config; for (size_t i = 0; i < inPortConfigs.size(); i++) { auto shape = inPortConfigs[i].shape.getRank() == 0 ? getInputShapeAtPort(i) : inPortConfigs[i].shape; - auto prc = inPortConfigs[i].prc == ov::element::undefined ? getOriginalInputPrecisionAtPort(i) : inPortConfigs[i].prc; + auto prc = + inPortConfigs[i].prc == ov::element::undefined ? getOriginalInputPrecisionAtPort(i) : inPortConfigs[i].prc; if (!fill_port(inPortConfigs[i], shape, prc, config.inConfs)) return; } for (size_t i = 0; i < outPortConfigs.size(); i++) { auto dims = outPortConfigs[i].shape.getRank() == 0 ? getOutputShapeAtPort(i) : outPortConfigs[i].shape; - auto prc = outPortConfigs[i].prc == ov::element::undefined ? getOriginalOutputPrecisionAtPort(i) : outPortConfigs[i].prc; + auto prc = outPortConfigs[i].prc == ov::element::undefined ? getOriginalOutputPrecisionAtPort(i) + : outPortConfigs[i].prc; if (!fill_port(outPortConfigs[i], dims, prc, config.outConfs)) return; } @@ -1884,23 +1955,27 @@ void Node::fuseDQScales(const float* scaleData, const size_t scaleSize) { if (scaleSize > DQScales.size()) DQScales.resize(scaleSize, DQScales[0]); if (1 == scaleSize) { - std::transform(DQScales.begin(), DQScales.end(), DQScales.begin(), [=](float val){ return (scaleData[0] * val); }); - } else { - for (size_t i = 0; i < DQScales.size(); i++) { - DQScales[i] *= scaleData[i]; - } - } - if (std::all_of(DQScales.begin(), DQScales.end(), [OV_CAPTURE_CPY_AND_THIS](float val){ return (val == DQScales[0]);})) + std::transform(DQScales.begin(), DQScales.end(), DQScales.begin(), [=](float val) { + return (scaleData[0] * val); + }); + } else { + for (size_t i = 0; i < DQScales.size(); i++) { + DQScales[i] *= scaleData[i]; + } + } + if (std::all_of(DQScales.begin(), DQScales.end(), [OV_CAPTURE_CPY_AND_THIS](float val) { + return (val == DQScales[0]); + })) DQScales.resize(1); } int Node::inPlaceInputPort(int portIdx) const { if (inputShapes.empty()) { - //special case - a dead end node + // special case - a dead end node return -1; } - const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor(); + const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor(); if (!selected_pd) OPENVINO_THROW("Cannot find selected primitive descriptor for node: ", getName()); @@ -1918,11 +1993,11 @@ int Node::inPlaceInputPort(int portIdx) const { int Node::inPlaceOutPort(int portIdx) const { if (outputShapes.empty()) { - //special case - a dead end node + // special case - a dead end node return -1; } - const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor(); + const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor(); if (!selected_pd) OPENVINO_THROW("Cannot find selected primitive descriptor for node: ", getName()); @@ -1939,8 +2014,8 @@ int Node::inPlaceOutPort(int portIdx) const { } void Node::resolveInPlaceDirection() { - enum 
InplaceDirectionType { UP, DOWN, CYCLIC, NONE }; + enum PortType { INPUT, OUTPUT }; auto inPlaceDirection = [](const Node* node, PortType portType, int portNum) -> InplaceDirectionType { if (PortType::INPUT == portType) { @@ -1990,7 +2065,8 @@ void Node::resolveInPlaceDirection() { if (auto pEdge = wEdge.lock()) { auto inpPort = pEdge->getOutputNum(); auto inPlaceInpPort = inPlaceInputPort(inpPort); - if (inPlaceInpPort < 0 || inPlaceDirection(this, PortType::INPUT, inpPort) != InplaceDirectionType::CYCLIC) { + if (inPlaceInpPort < 0 || + inPlaceDirection(this, PortType::INPUT, inpPort) != InplaceDirectionType::CYCLIC) { continue; } // inPlace memory cyclic dependency detected, need to resolve @@ -2002,12 +2078,14 @@ void Node::resolveInPlaceDirection() { config.inConfs[inpPort].inPlace(-1); initDescriptor(config); } else if (parentInPlaceDirection == InplaceDirectionType::DOWN) { - //search if siblings already have downstream direction + // search if siblings already have downstream direction auto downstreamPeers = [&] { for (auto& peerEdge : pParent->getChildEdgesAtPort(pEdge->getInputNum())) { auto peerNode = peerEdge->getChild().get(); - if (peerNode == this) continue; - if (inPlaceDirection(peerNode, PortType::INPUT, peerEdge->getOutputNum()) == InplaceDirectionType::DOWN) { + if (peerNode == this) + continue; + if (inPlaceDirection(peerNode, PortType::INPUT, peerEdge->getOutputNum()) == + InplaceDirectionType::DOWN) { return true; } } @@ -2068,7 +2146,8 @@ void Node::resolveInPlaceDirection() { // note: there are only non-inplace or cyclic-inplace descendants at the moment. std::function searchReferencingOutput; searchReferencingOutput = [&](const Node* node, int portIdx) -> void { - if (numConflicts > 1) return; // early stop + if (numConflicts > 1) + return; // early stop auto childEdges = node->getChildEdgesAtPort(portIdx); for (auto& edge : childEdges) { auto pChild = edge->getChild().get(); @@ -2077,7 +2156,8 @@ void Node::resolveInPlaceDirection() { } else { auto result = inPlaceDirection(pChild, PortType::INPUT, edge->getOutputNum()); if (InplaceDirectionType::CYCLIC == result) { - return searchReferencingOutput(pChild, pChild->inPlaceInputPort(edge->getOutputNum())); + return searchReferencingOutput(pChild, + pChild->inPlaceInputPort(edge->getOutputNum())); } } } @@ -2090,7 +2170,8 @@ void Node::resolveInPlaceDirection() { // note: the parent node does not use inPlace memory at the moment, let's check the siblings for (auto& peerEdge : pParent->getChildEdgesAtPort(pEdge->getInputNum())) { auto peerNode = peerEdge->getChild().get(); - if (peerNode == this) continue; + if (peerNode == this) + continue; if (Type::Output == peerNode->getType()) { numConflicts++; } else { @@ -2102,11 +2183,11 @@ void Node::resolveInPlaceDirection() { } } - if (numConflicts == 1) { // downstream to make the only output edge be referenced. + if (numConflicts == 1) { // downstream to make the only output edge be referenced. 
auto config = getSelectedPrimitiveDescriptor()->getConfig(); config.outConfs[inPlaceInpPort].inPlace(-1); initDescriptor(config); - } else { // the default direction of upstream + } else { // the default direction of upstream auto config = getSelectedPrimitiveDescriptor()->getConfig(); config.inConfs[inpPort].inPlace(-1); initDescriptor(config); @@ -2121,8 +2202,7 @@ void Node::resolveInPlaceDirection() { #ifndef CPU_DEBUG_CAPS std::ostream& operator<<(std::ostream& out, const Node& node) { - return out << "Node " << node.getName() << - " of type " << node.getTypeStr() << "\n"; + return out << "Node " << node.getName() << " of type " << node.getTypeStr() << "\n"; } std::ostream& operator<<(std::ostream& out, const Node* node) { @@ -2130,5 +2210,5 @@ std::ostream& operator<<(std::ostream& out, const Node* node) { } #endif -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/node.h b/src/plugins/intel_cpu/src/node.h index 453b8323fe9e66..9166e87dbf50e1 100644 --- a/src/plugins/intel_cpu/src/node.h +++ b/src/plugins/intel_cpu/src/node.h @@ -4,37 +4,38 @@ #pragma once +#include + #include +#include #include +#include +#include +#include +#include + #include "cpu_memory.h" #include "cpu_shape.h" #include "cpu_types.h" #include "edge.h" +#include "graph_context.h" #include "memory_desc/cpu_memory_desc.h" -#include "selective_build.h" #include "memory_desc/dnnl_memory_desc.h" +#include "nodes/executors/executor.hpp" +#include "nodes/node_config.h" #include "onednn/dnnl.h" #include "onednn/iml_type_mapper.h" -#include #include "openvino/cc/factory.h" #include "openvino/core/node.hpp" -#include -#include "nodes/node_config.h" -#include #include "perf_count.h" -#include "utils/debug_capabilities.h" +#include "selective_build.h" #include "utils/bit_util.hpp" #include "utils/debug_capabilities.h" -#include "graph_context.h" -#include "nodes/executors/executor.hpp" - -#include -#include -#include - -#define THROW_CPU_NODE_ERR(...) OPENVINO_THROW("[CPU] ", getTypeStr(), " node with name '", getName(), "' ", __VA_ARGS__) -#define CPU_NODE_ASSERT(condition, ...) OPENVINO_ASSERT(condition, getTypeStr(), " node with name '", getName(), "' ", __VA_ARGS__) +#define THROW_CPU_NODE_ERR(...) \ + OPENVINO_THROW("[CPU] ", getTypeStr(), " node with name '", getName(), "' ", __VA_ARGS__) +#define CPU_NODE_ASSERT(condition, ...) 
\ + OPENVINO_ASSERT(condition, getTypeStr(), " node with name '", getName(), "' ", __VA_ARGS__) namespace ov { namespace intel_cpu { @@ -45,13 +46,25 @@ using NodeWeakPtr = std::weak_ptr; class PortConfigurator { public: - PortConfigurator(ov::intel_cpu::LayoutType blockedDescType, ov::element::Type prc, const Shape& shape, - bool constant = false, int inPlace = -1) : - blockedDescCreator(getBlockedDescCreator(blockedDescType)), prc(prc), shape(shape), constant(constant), inPlace(inPlace) {} - - PortConfigurator(ov::intel_cpu::LayoutType blockedDescType, ov::element::Type prc = ov::element::undefined, - bool constant = false, int inPlace = -1) : - blockedDescCreator(getBlockedDescCreator(blockedDescType)), prc(prc), constant(constant), inPlace(inPlace) {} + PortConfigurator(ov::intel_cpu::LayoutType blockedDescType, + ov::element::Type prc, + const Shape& shape, + bool constant = false, + int inPlace = -1) + : blockedDescCreator(getBlockedDescCreator(blockedDescType)), + prc(prc), + shape(shape), + constant(constant), + inPlace(inPlace) {} + + PortConfigurator(ov::intel_cpu::LayoutType blockedDescType, + ov::element::Type prc = ov::element::undefined, + bool constant = false, + int inPlace = -1) + : blockedDescCreator(getBlockedDescCreator(blockedDescType)), + prc(prc), + constant(constant), + inPlace(inPlace) {} ov::intel_cpu::BlockedDescCreator::CreatorConstPtr blockedDescCreator; const ov::element::Type prc; @@ -60,7 +73,8 @@ class PortConfigurator { int inPlace = -1; private: - static ov::intel_cpu::BlockedDescCreator::CreatorConstPtr getBlockedDescCreator(ov::intel_cpu::LayoutType blockedDescType) { + static ov::intel_cpu::BlockedDescCreator::CreatorConstPtr getBlockedDescCreator( + ov::intel_cpu::LayoutType blockedDescType) { auto& creators = ov::intel_cpu::BlockedDescCreator::getCommonCreators(); if (creators.find(blockedDescType) == creators.end()) { OPENVINO_THROW("Cannot find tensor descriptor creator"); @@ -71,11 +85,15 @@ class PortConfigurator { class NodeDesc { public: - NodeDesc(NodeConfig conf, impl_desc_type type): - config(std::move(conf)), implementationType(type), executorFactory(nullptr) {} + NodeDesc(NodeConfig conf, impl_desc_type type) + : config(std::move(conf)), + implementationType(type), + executorFactory(nullptr) {} - NodeDesc(NodeConfig conf, impl_desc_type type, ExecutorFactoryLegacyPtr factory): - config(std::move(conf)), implementationType(type), executorFactory(factory) {} + NodeDesc(NodeConfig conf, impl_desc_type type, ExecutorFactoryLegacyPtr factory) + : config(std::move(conf)), + implementationType(type), + executorFactory(factory) {} const NodeConfig& getConfig() const { return config; @@ -98,8 +116,8 @@ class NodeDesc { } template ::value && !std::is_reference::value, int>::type = 0, - typename std::enable_if::value, int>::type = 0> + typename std::enable_if::value && !std::is_reference::value, int>::type = 0, + typename std::enable_if::value, int>::type = 0> std::shared_ptr getExecutorFactoryAs() { auto casted = std::dynamic_pointer_cast(executorFactory); if (!casted) @@ -119,34 +137,41 @@ class NodeDesc { class Node { public: - Node(const Node &) = delete; - Node & operator = (const Node &) = delete; + Node(const Node&) = delete; + Node& operator=(const Node&) = delete; using AttrPtr = std::shared_ptr; public: - template + template struct Tag {}; struct PerfCounters { PerfCounters(std::string const& name) - : execute(openvino::itt::handle(name)) - , getSupportedDescriptors(openvino::itt::handle>("Node::getSupportedDescriptors")) - , 
initSupportedPrimitiveDescriptors(openvino::itt::handle>("Node::initSupportedPrimitiveDescriptors")) - , filterSupportedPrimitiveDescriptors(openvino::itt::handle>("Node::filterSupportedPrimitiveDescriptors")) - , selectOptimalPrimitiveDescriptor(openvino::itt::handle>("Node::selectOptimalPrimitiveDescriptor")) - , createPrimitive(openvino::itt::handle>("Node::createPrimitive")) - , initOptimalPrimitiveDescriptor(openvino::itt::handle>("Node::initOptimalPrimitiveDescriptor")) - {} - - template + : execute(openvino::itt::handle(name)), + getSupportedDescriptors(openvino::itt::handle>("Node::getSupportedDescriptors")), + initSupportedPrimitiveDescriptors( + openvino::itt::handle>("Node::initSupportedPrimitiveDescriptors")), + filterSupportedPrimitiveDescriptors( + openvino::itt::handle>("Node::filterSupportedPrimitiveDescriptors")), + selectOptimalPrimitiveDescriptor( + openvino::itt::handle>("Node::selectOptimalPrimitiveDescriptor")), + createPrimitive(openvino::itt::handle>("Node::createPrimitive")), + initOptimalPrimitiveDescriptor( + openvino::itt::handle>("Node::initOptimalPrimitiveDescriptor")) {} + + template void buildClassCounters(const std::string& type_name) { getSupportedDescriptors = openvino::itt::handle>(type_name + "::getSupportedDescriptors"); - initSupportedPrimitiveDescriptors = openvino::itt::handle>(type_name + "::initSupportedPrimitiveDescriptors"); - filterSupportedPrimitiveDescriptors = openvino::itt::handle>(type_name + "::filterSupportedPrimitiveDescriptors"); - selectOptimalPrimitiveDescriptor = openvino::itt::handle>(type_name + "::selectOptimalPrimitiveDescriptor"); + initSupportedPrimitiveDescriptors = + openvino::itt::handle>(type_name + "::initSupportedPrimitiveDescriptors"); + filterSupportedPrimitiveDescriptors = + openvino::itt::handle>(type_name + "::filterSupportedPrimitiveDescriptors"); + selectOptimalPrimitiveDescriptor = + openvino::itt::handle>(type_name + "::selectOptimalPrimitiveDescriptor"); createPrimitive = openvino::itt::handle>(type_name + "::createPrimitive"); - initOptimalPrimitiveDescriptor = openvino::itt::handle>(type_name + "::initOptimalPrimitiveDescriptor"); + initOptimalPrimitiveDescriptor = + openvino::itt::handle>(type_name + "::initOptimalPrimitiveDescriptor"); } openvino::itt::handle_t execute; @@ -159,7 +184,7 @@ class Node { }; class NodesFactory; - static NodesFactory & factory(); + static NodesFactory& factory(); virtual ~Node() = default; @@ -171,11 +196,12 @@ class Node { void remove(); void addParentEdge(const EdgePtr& edge) { - assert(std::none_of(parentEdges.begin(), parentEdges.end(), - [&edge](const EdgeWeakPtr& _edge){ - return _edge.lock()->getOutputNum() == edge->getOutputNum(); - })); - parentEdges.insert(std::upper_bound(parentEdges.begin(), parentEdges.end(), edge, + assert(std::none_of(parentEdges.begin(), parentEdges.end(), [&edge](const EdgeWeakPtr& _edge) { + return _edge.lock()->getOutputNum() == edge->getOutputNum(); + })); + parentEdges.insert(std::upper_bound(parentEdges.begin(), + parentEdges.end(), + edge, [](const EdgeWeakPtr& lhs, const EdgeWeakPtr& rhs) { return lhs.lock()->getOutputNum() < rhs.lock()->getOutputNum(); }), @@ -196,11 +222,11 @@ class Node { removeEdge(edge, childEdges); } - const std::vector &getParentEdges() const noexcept { + const std::vector& getParentEdges() const noexcept { return parentEdges; } - const std::vector &getChildEdges() const noexcept { + const std::vector& getChildEdges() const noexcept { return childEdges; } @@ -238,7 +264,7 @@ class Node { return 
getSrcMemoryAtPort(idx)->getData(); } - template + template T* getSrcDataAtPortAs(size_t idx) const { return getSrcMemoryAtPort(idx)->getDataAs(); } @@ -247,7 +273,7 @@ class Node { return getDstMemoryAtPort(idx)->getData(); } - template + template T* getDstDataAtPortAs(size_t idx) const { return getDstMemoryAtPort(idx)->getDataAs(); } @@ -273,7 +299,8 @@ class Node { enum class ConstantType { Const, // Node is placed in a constant subgraph NoConst, // Node is placed in a non-constant subgraph - StrictNoConst, // Node produces non-constant subgraph: this type can't be changed and it does not depend on the parent nodes' ConstantType. + StrictNoConst, // Node produces non-constant subgraph: this type can't be changed and it does not depend on the + // parent nodes' ConstantType. }; ConstantType getConstantType() const; void updateConstantType(); @@ -290,10 +317,11 @@ class Node { bool isFusedWith(Type type) const; - virtual void addFusedNode(const NodePtr &fusingNode); + virtual void addFusedNode(const NodePtr& fusingNode); virtual void fuseInto(NodePtr& parentNode) { - // The graph supports fusing only of consecutive nodes and some graph logic requires to know through which input port a node was fused into parent one. + // The graph supports fusing only of consecutive nodes and some graph logic requires to know through which input + // port a node was fused into parent one. for (size_t i = 0; i < getParentEdges().size(); i++) { if (getParentEdgeAt(i)->getParent().get() == parentNode.get()) { setFusingPort(i); @@ -323,15 +351,15 @@ class Node { fusedWith.clear(); } - void mergeWith(const NodePtr &merge) { + void mergeWith(const NodePtr& merge) { mergedWith.push_back(merge); } - const std::vector &getMergeWith() { + const std::vector& getMergeWith() { return mergedWith; } - const std::vector &getFusedWith() { + const std::vector& getFusedWith() { return fusedWith; } @@ -343,17 +371,17 @@ class Node { this->fusingPort = fusingPort; } - const std::string &getName() const { + const std::string& getName() const { return name; } void addOriginalLayer(const std::string& layerName); - const std::string &getOriginalLayers() const { + const std::string& getOriginalLayers() const { return originalLayers; } - const std::string &getParallelDomain() const { + const std::string& getParallelDomain() const { return parallelDomain; } @@ -437,7 +465,9 @@ class Node { virtual std::string getPrimitiveDescriptorType() const; - PerfCount &PerfCounter() { return perfCounter; } + PerfCount& PerfCounter() { + return perfCounter; + } virtual void resolveInPlaceEdges(Edge::LOOK look = Edge::LOOK_BOTH); @@ -448,7 +478,7 @@ class Node { void updateShapes(); void updateDynamicParams(); void executeDynamic(dnnl::stream strm, int numaId = -1); - virtual void redefineOutputMemory(const std::vector &newShapes); + virtual void redefineOutputMemory(const std::vector& newShapes); void redefineOutputMemory(const size_t port, const VectorDims& new_output_shape); bool outputShapeDataDependency() const; @@ -475,7 +505,8 @@ class Node { /** * @brief Performs Node initialization based on graph context. - * This is an auxiliary method that allows to use information not available in Node constructor (e.g. connection information with other nodes) + * This is an auxiliary method that allows to use information not available in Node constructor (e.g. 
connection + * information with other nodes) */ virtual void init() {} @@ -483,11 +514,11 @@ class Node { return execIndex; } - const std::string & getTypeStr() const { + const std::string& getTypeStr() const { return typeStr; } - void setTypeStr(const std::string &typeStr) { + void setTypeStr(const std::string& typeStr) { this->typeStr = typeStr; } @@ -499,11 +530,11 @@ class Node { return 1; } - const PerfCounters & perfCounters() const { + const PerfCounters& perfCounters() const { return profiling; } - PerfCounters & perfCounters() { + PerfCounters& perfCounters() { return profiling; } @@ -588,7 +619,7 @@ class Node { return false; } - bool canBePerformedAsScaleShift(const Node *parentNode = nullptr) const; + bool canBePerformedAsScaleShift(const Node* parentNode = nullptr) const; bool isDynamicNode() const { return isDynamic; @@ -613,14 +644,14 @@ class Node { } /** - * @brief Return scales and shift if nodes can be executed as ScaleShift, else raise exception - * If node has only scale or shift value, fill missing value with default values - * i.e. EltwiseAdd: fill shifts from constant, fill scales with default values = 1.0f - * @param parentNode - * node from which data comes - * @return pair of scales and shifts - */ - std::pair, std::vector> getScalesAndShifts(const Node *parentNode) const; + * @brief Return scales and shift if nodes can be executed as ScaleShift, else raise exception + * If node has only scale or shift value, fill missing value with default values + * i.e. EltwiseAdd: fill shifts from constant, fill scales with default values = 1.0f + * @param parentNode + * node from which data comes + * @return pair of scales and shifts + */ + std::pair, std::vector> getScalesAndShifts(const Node* parentNode) const; void fuseDQScales(const float* scaleData, const size_t scaleSize); const std::vector& getDQScales() const { @@ -631,8 +662,14 @@ class Node { * Seed node should call this routine and pass its post operations list as parameter. 
* @param ops List of fused post operations */ - virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::unordered_map& postOpsMem, const int channelAxis = 1); - virtual void appendPostOps(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector& postOpsMem, const int channelAxis = 1); + virtual void appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::unordered_map& postOpsMem, + const int channelAxis = 1); + virtual void appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis = 1); virtual bool canBeExecutedInInt8() const { OPENVINO_THROW_NOT_IMPLEMENTED("canBeExecutedInInt8 not implemented for node with type ", NameFromType(getType())); @@ -649,22 +686,24 @@ class Node { this->type = type; } - virtual PortDescBasePtr getConsistentInputDesc(const NodeConfig &config, size_t idx) const; - virtual PortDescBasePtr getConsistentOutputDesc(const NodeConfig &config, size_t idx) const; - virtual MemoryDescPtr getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const; - virtual MemoryDescPtr getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const; + virtual PortDescBasePtr getConsistentInputDesc(const NodeConfig& config, size_t idx) const; + virtual PortDescBasePtr getConsistentOutputDesc(const NodeConfig& config, size_t idx) const; + virtual MemoryDescPtr getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const; + virtual MemoryDescPtr getDstMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const; - virtual AttrPtr initPrimitiveAttr() { return nullptr; } + virtual AttrPtr initPrimitiveAttr() { + return nullptr; + } - typedef std::function - GetPrimitiveMemoryFormatFunc; + typedef std::function + GetPrimitiveMemoryFormatFunc; std::vector internalBlobDesc; std::vector inputShapes; std::vector outputShapes; - std::vector fusedWith; - std::vector mergedWith; + std::vector fusedWith; + std::vector mergedWith; int curNumaNode = -1; @@ -672,11 +711,11 @@ class Node { virtual void toNumaNodeImpl(int numaID); std::string primitivesPriority; - std::vector customImplPriorities; - std::vector inputMemoryFormatsFilter; - std::vector outputMemoryFormatsFilter; + std::vector customImplPriorities; + std::vector inputMemoryFormatsFilter; + std::vector outputMemoryFormatsFilter; bool enforceBF16evenForGraphTail = false; - bool keepOriginalPrecision = false; + bool keepOriginalPrecision = false; std::string originalLayers; // contains names of the original layers separated by comma std::string parallelDomain; @@ -692,11 +731,7 @@ class Node { int selectedPrimitiveDescriptorIndex = -1; - enum class InPlaceType { - Unknown, - InPlace, - NoInPlace - }; + enum class InPlaceType { Unknown, InPlace, NoInPlace }; mutable InPlaceType inplace = InPlaceType::Unknown; ConstantType constant = ConstantType::NoConst; std::vector internalBlobs; @@ -718,7 +753,7 @@ class Node { void selectPreferPrimitiveDescriptorWithShape(const std::vector& priority, bool ignoreConstInputs); bool isOneDimShape(const ov::PartialShape& pshape); bool isReorderRequired(ov::intel_cpu::MemoryDescPtr desc1, ov::intel_cpu::MemoryDescPtr desc2); - bool isConfigDefined(const NodeConfig &config) const; + bool isConfigDefined(const NodeConfig& config) const; virtual bool canBeInPlace() const; /* returns default implementaion prioirity */ @@ -733,13 +768,15 @@ class Node { /** * @brief Auxiliary function to get node input precisions - * @return Vector of precisions based on information from node 
input edges. Return empty vector in case edges are not initialized yet. + * @return Vector of precisions based on information from node input edges. Return empty vector in case edges are + * not initialized yet. */ virtual std::vector getInputPrecisions() const; /** * @brief Auxiliary function to get node output precisions - * @return Vector of precisions based on information from node output edges. Return empty vector in case edges are not initialized yet. + * @return Vector of precisions based on information from node output edges. Return empty vector in case edges are + * not initialized yet. */ virtual std::vector getOutputPrecisions() const; @@ -803,13 +840,14 @@ class Node { // is still under control of strong references outside of cache. // privateWeightCache is for holding strong references to constant weight // copies of same content with different layouts. - std::shared_ptr> privateWeightCache - = std::make_shared>(); + std::shared_ptr> privateWeightCache = + std::make_shared>(); private: - static void removeEdge(const EdgePtr edge, std::vector &edges) { - edges.erase(std::remove_if(edges.begin(), edges.end(), - [&edge] (EdgeWeakPtr _edge) { + static void removeEdge(const EdgePtr edge, std::vector& edges) { + edges.erase(std::remove_if(edges.begin(), + edges.end(), + [&edge](EdgeWeakPtr _edge) { return _edge.lock() == edge; }), edges.end()); @@ -856,22 +894,20 @@ constexpr uint64_t PortMask(T... rest) { return util::bit::mask(rest...); } -class Node::NodesFactory : public openvino::cc::Factory& op, - const GraphContext::CPtr)> { +class Node::NodesFactory + : public openvino::cc::Factory& op, const GraphContext::CPtr)> { public: NodesFactory(); Node* create(const std::shared_ptr& op, const GraphContext::CPtr context); }; -template +template struct NodeImpl : public NodeType { - NodeImpl(const std::shared_ptr& op, const GraphContext::CPtr context) - : NodeType(op, context) { + NodeImpl(const std::shared_ptr& op, const GraphContext::CPtr context) : NodeType(op, context) { NodeType::perfCounters().template buildClassCounters(NameFromType(NodeType::getType())); } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/adaptive_pooling.cpp b/src/plugins/intel_cpu/src/nodes/adaptive_pooling.cpp index f4e7f6217a8dec..f4af11b0f2362a 100644 --- a/src/plugins/intel_cpu/src/nodes/adaptive_pooling.cpp +++ b/src/plugins/intel_cpu/src/nodes/adaptive_pooling.cpp @@ -3,18 +3,21 @@ // #include "adaptive_pooling.h" -#include "openvino/core/parallel.hpp" -#include "cpu/x64/cpu_isa_traits.hpp" + #include -#include "onednn/dnnl.h" -#include "dnnl_extension_utils.h" -#include "selective_build.h" + #include #include #include -#include "utils/general_utils.h" #include + +#include "cpu/x64/cpu_isa_traits.hpp" +#include "dnnl_extension_utils.h" +#include "onednn/dnnl.h" +#include "openvino/core/parallel.hpp" +#include "selective_build.h" #include "shape_inference/custom/adaptive_pooling.hpp" +#include "utils/general_utils.h" using namespace dnnl; using namespace dnnl::impl::cpu::x64; @@ -23,7 +26,8 @@ namespace ov { namespace intel_cpu { namespace node { -bool AdaptivePooling::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool AdaptivePooling::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { if (one_of(op->get_type_info(), ov::op::v8::AdaptiveAvgPool::get_type_info_static())) { auto adaPool = std::dynamic_pointer_cast(op); @@ -51,9 +55,9 @@ 
AdaptivePooling::AdaptivePooling(const std::shared_ptr& op, const Grap : Node(op, context, AdaptivePoolingShapeInferFactory(op)) { std::string errorMessage; if (isSupportedOperation(op, errorMessage)) { - errorPrefix = "Adaptive Pooling layer with name '" + getName() + "' "; + errorPrefix = "Adaptive Pooling layer with name '" + getName() + "' "; } else { - OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); + OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); } if (one_of(op->get_type_info(), ov::op::v8::AdaptiveAvgPool::get_type_info_static())) { algorithm = Algorithm::AdaptivePoolingAvg; @@ -104,14 +108,14 @@ void AdaptivePooling::initSupportedPrimitiveDescriptors() { // we supports only fp32 currently precision = ov::element::f32; - std::vector dataFormats{ LayoutType::ncsp }; - const auto &inDims = getInputShapeAtPort(0).getDims(); + std::vector dataFormats{LayoutType::ncsp}; + const auto& inDims = getInputShapeAtPort(0).getDims(); if (inDims[1] != Shape::UNDEFINED_DIM && inDims[1] != 1) { dataFormats.push_back(LayoutType::nspc); dataFormats.push_back(LayoutType::nCsp16c); dataFormats.push_back(LayoutType::nCsp8c); } - for (const auto &df : dataFormats) { + for (const auto& df : dataFormats) { if (algorithm == Algorithm::AdaptivePoolingAvg) { addSupportedPrimDesc({{df, precision}, {LayoutType::ncsp, ov::element::i32}}, {{df, precision}}, @@ -134,9 +138,9 @@ void AdaptivePooling::execute(dnnl::stream strm) { if (!(inputPrec == dnnl_f32 && outputPrec == dnnl_f32)) OPENVINO_THROW(errorPrefix, "doesn't support demanded precisions"); - auto &srcMemory0 = getParentEdgeAt(0)->getMemory(); - auto &srcMemory1 = getParentEdgeAt(1)->getMemory(); - int *indexDst = nullptr; + auto& srcMemory0 = getParentEdgeAt(0)->getMemory(); + auto& srcMemory1 = getParentEdgeAt(1)->getMemory(); + int* indexDst = nullptr; if (algorithm == Algorithm::AdaptivePoolingMax) { indexDst = getDstDataAtPortAs(1); @@ -144,14 +148,15 @@ void AdaptivePooling::execute(dnnl::stream strm) { auto isPlainFmt = srcMemory0.getDesc().hasLayoutType(LayoutType::ncsp); auto isTailCFmt = srcMemory0.getDesc().hasLayoutType(LayoutType::nspc); - auto isBlkFmt = srcMemory0.getDesc().hasLayoutType(LayoutType::nCsp16c) || srcMemory0.getDesc().hasLayoutType(LayoutType::nCsp8c); + auto isBlkFmt = srcMemory0.getDesc().hasLayoutType(LayoutType::nCsp16c) || + srcMemory0.getDesc().hasLayoutType(LayoutType::nCsp8c); auto srcBlockDesc = srcMemory0.getDescWithType(); int blockSize = isBlkFmt ? srcBlockDesc->getBlockDims().back() : 1; - const auto *src = getSrcDataAtPortAs(0); - const auto *srcPooledSpatialShapes = getSrcDataAtPortAs(1); - auto *dst = getDstDataAtPortAs(0); + const auto* src = getSrcDataAtPortAs(0); + const auto* srcPooledSpatialShapes = getSrcDataAtPortAs(1); + auto* dst = getDstDataAtPortAs(0); if (static_cast(srcMemory1.getShape().getElementsCount()) != spatialDimsCount) OPENVINO_THROW(errorPrefix, @@ -175,8 +180,9 @@ void AdaptivePooling::execute(dnnl::stream strm) { const int iHW = IH * IW; const int oDHW = OD * OH * OW, oHW = OH * OW; - const int chPadding = blockSize * (isBlkFmt ? srcBlockDesc->getBlockDims()[1] : srcMemory0.getShape().getStaticDims()[1]); - const int blockCount = (isTailCFmt ? 1 : chPadding / blockSize); + const int chPadding = + blockSize * (isBlkFmt ? srcBlockDesc->getBlockDims()[1] : srcMemory0.getShape().getStaticDims()[1]); + const int blockCount = (isTailCFmt ? 
1 : chPadding / blockSize); auto selectedPrimitiveDescriptor = getSelectedPrimitiveDescriptor(); if (!selectedPrimitiveDescriptor) OPENVINO_THROW(errorPrefix, "doesn't have primitive descriptors."); @@ -186,27 +192,26 @@ void AdaptivePooling::execute(dnnl::stream strm) { // unified strides array const size_t tailDimsOffset = (isTailCFmt ? -1 : 0); - const size_t inStrides[5] = { - srcStrides[0], - (isTailCFmt ? 1 : srcStrides[1]), - (spatialDimsCount == 3 ? srcStrides[2 + tailDimsOffset] : 0), - (spatialDimsCount >= 2 ? srcStrides[spatialDimsCount + tailDimsOffset] : 0), - srcStrides[spatialDimsCount + 1 + tailDimsOffset] }; - const size_t outStrides[5] = { - dstStrides[0], - (isTailCFmt ? 1 : dstStrides[1]), - (spatialDimsCount == 3 ? dstStrides[2 + tailDimsOffset] : 0), - (spatialDimsCount >= 2 ? dstStrides[spatialDimsCount + tailDimsOffset] : 0), - dstStrides[spatialDimsCount + 1 + tailDimsOffset] }; - - std::function pool; - auto poolMax = [&] (const float *srcData, float *dstData, int od, int oh, int ow, size_t spatIndOff) { + const size_t inStrides[5] = {srcStrides[0], + (isTailCFmt ? 1 : srcStrides[1]), + (spatialDimsCount == 3 ? srcStrides[2 + tailDimsOffset] : 0), + (spatialDimsCount >= 2 ? srcStrides[spatialDimsCount + tailDimsOffset] : 0), + srcStrides[spatialDimsCount + 1 + tailDimsOffset]}; + const size_t outStrides[5] = {dstStrides[0], + (isTailCFmt ? 1 : dstStrides[1]), + (spatialDimsCount == 3 ? dstStrides[2 + tailDimsOffset] : 0), + (spatialDimsCount >= 2 ? dstStrides[spatialDimsCount + tailDimsOffset] : 0), + dstStrides[spatialDimsCount + 1 + tailDimsOffset]}; + + std::function pool; + auto poolMax = [&](const float* srcData, float* dstData, int od, int oh, int ow, size_t spatIndOff) { size_t dStart, dEnd, hStart, hEnd, wStart, wEnd; setBinBorders(&dStart, &dEnd, od, ID, OD); setBinBorders(&hStart, &hEnd, oh, IH, OH); setBinBorders(&wStart, &wEnd, ow, IW, OW); - float res = srcData[dStart * inStrides[2] + hStart * inStrides[3] + wStart * inStrides[4]]; // initial max value - int resIndex = dStart * iHW + hStart * IW + wStart; // initial max index + float res = + srcData[dStart * inStrides[2] + hStart * inStrides[3] + wStart * inStrides[4]]; // initial max value + int resIndex = dStart * iHW + hStart * IW + wStart; // initial max index for (size_t pixD = dStart; pixD < dEnd; pixD++) { for (size_t pixH = hStart; pixH < hEnd; pixH++) { for (size_t pixW = wStart; pixW < wEnd; pixW++) { @@ -219,7 +224,7 @@ void AdaptivePooling::execute(dnnl::stream strm) { *dstData = res; indexDst[spatIndOff * oDHW + od * oHW + oh * OW + ow] = resIndex; }; - auto poolAvg = [&] (const float *srcData, float *dstData, int od, int oh, int ow, size_t spatIndOff) { + auto poolAvg = [&](const float* srcData, float* dstData, int od, int oh, int ow, size_t spatIndOff) { size_t dStart, dEnd, hStart, hEnd, wStart, wEnd; setBinBorders(&dStart, &dEnd, od, ID, OD); setBinBorders(&hStart, &hEnd, oh, IH, OH); @@ -245,11 +250,10 @@ void AdaptivePooling::execute(dnnl::stream strm) { pool = poolAvg; } - parallel_for5d(N, blockCount, OD, OH, OW, - [&](int n, int blkIdx, int od, int oh, int ow) { + parallel_for5d(N, blockCount, OD, OH, OW, [&](int n, int blkIdx, int od, int oh, int ow) { auto srcData = src + n * inStrides[0] + blkIdx * inStrides[1]; - auto dstData = dst + n * outStrides[0] + blkIdx * outStrides[1] + - od * outStrides[2] + oh * outStrides[3] + ow * outStrides[4]; + auto dstData = dst + n * outStrides[0] + blkIdx * outStrides[1] + od * outStrides[2] + oh * outStrides[3] + + ow * outStrides[4]; int 
cStart = 0, cEnd = C, inResidual = 0, outResidual = 0; if (!isTailCFmt) { cStart = blkIdx * blockSize; @@ -263,18 +267,23 @@ void AdaptivePooling::execute(dnnl::stream strm) { inResidual = outResidual = c % blockSize; } pool(srcData + inResidual, dstData + outResidual, od, oh, ow, n * C + c); - }}); + } + }); } bool AdaptivePooling::created() const { return getType() == Type::AdaptivePooling; } -inline void AdaptivePooling::setBinBorders(size_t *startPtr, size_t *endPtr, size_t idx, size_t inputLength, size_t outputLength) { +inline void AdaptivePooling::setBinBorders(size_t* startPtr, + size_t* endPtr, + size_t idx, + size_t inputLength, + size_t outputLength) { *(startPtr) = idx * inputLength / outputLength; *(endPtr) = ceil(static_cast((idx + 1) * inputLength) / outputLength); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/adaptive_pooling.h b/src/plugins/intel_cpu/src/nodes/adaptive_pooling.h index c88c9b5989aef9..04b628a5da5cee 100644 --- a/src/plugins/intel_cpu/src/nodes/adaptive_pooling.h +++ b/src/plugins/intel_cpu/src/nodes/adaptive_pooling.h @@ -5,9 +5,11 @@ #pragma once #include -#include + #include +#include #include + #include "dnnl_extension_utils.h" namespace ov { @@ -29,16 +31,18 @@ class AdaptivePooling : public Node { int spatialDimsCount; mutable std::vector spatialDimsValue = {}; ov::element::Type precision = ov::element::f32; - inline void setBinBorders(size_t *startPtr, size_t *endPtr, size_t idx, size_t inputLength, size_t outputLength); + inline void setBinBorders(size_t* startPtr, size_t* endPtr, size_t idx, size_t inputLength, size_t outputLength); std::string errorPrefix; protected: bool needShapeInfer() const override; - bool needPrepareParams() const override { return false; }; + bool needPrepareParams() const override { + return false; + }; void executeDynamicImpl(dnnl::stream strm) override; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/batch_to_space.cpp b/src/plugins/intel_cpu/src/nodes/batch_to_space.cpp index 80713e90750e2d..50665c083ec930 100644 --- a/src/plugins/intel_cpu/src/nodes/batch_to_space.cpp +++ b/src/plugins/intel_cpu/src/nodes/batch_to_space.cpp @@ -2,14 +2,16 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include "batch_to_space.h" + +#include #include +#include + #include "dnnl_types.h" +#include "nodes/common/blocked_desc_creator.h" #include "openvino/core/parallel.hpp" #include "selective_build.h" -#include "batch_to_space.h" -#include "nodes/common/blocked_desc_creator.h" -#include namespace ov { namespace intel_cpu { @@ -40,8 +42,8 @@ BatchToSpace::BatchToSpace(const std::shared_ptr& op, const GraphConte if (inputShapes.size() != 4 || outputShapes.size() != 1) OPENVINO_THROW(errorPrefix, " has incorrect number of input or output edges!"); - const auto &inDims = getInputShapeAtPort(0).getDims(); - const auto &outDims = getOutputShapeAtPort(0).getDims(); + const auto& inDims = getInputShapeAtPort(0).getDims(); + const auto& outDims = getOutputShapeAtPort(0).getDims(); if (inDims.size() < 4 || inDims.size() > 5) OPENVINO_THROW(errorPrefix, " has unsupported 'data' input rank: ", inDims.size()); if (inDims.size() != outDims.size()) @@ -52,7 +54,7 @@ void BatchToSpace::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; - const auto 
&inDims = getInputShapeAtPort(0).getDims(); + const auto& inDims = getInputShapeAtPort(0).getDims(); const auto precision = getOriginalInputPrecisionAtPort(0); const std::set supported_precision_sizes = {1, 2, 4, 8}; if (supported_precision_sizes.find(precision.size()) == supported_precision_sizes.end()) @@ -88,7 +90,7 @@ void BatchToSpace::initSupportedPrimitiveDescriptors() { } } -static std::vector getShape5D(const VectorDims &shape) { +static std::vector getShape5D(const VectorDims& shape) { std::vector shape5D(5, 1); for (int i = 0; i < 2; i++) { shape5D[i] = shape[i]; @@ -98,26 +100,26 @@ static std::vector getShape5D(const VectorDims &shape) { return shape5D; } -template +template void BatchToSpace::batchToSpaceKernel() { - const auto *srcData = getSrcDataAtPortAs(0); - const auto *blockShapesPtr = getSrcDataAtPortAs(1); + const auto* srcData = getSrcDataAtPortAs(0); + const auto* blockShapesPtr = getSrcDataAtPortAs(1); size_t dataRank = getSrcMemoryAtPort(0)->getShape().getRank(); blockShapeIn.clear(); for (size_t i = 0; i < dataRank; i++) { blockShapeIn.push_back(*(blockShapesPtr + i)); } - const auto *padsBeginPtr = getSrcDataAtPortAs(2); + const auto* padsBeginPtr = getSrcDataAtPortAs(2); cropsBeginIn.clear(); for (size_t i = 0; i < dataRank; i++) { cropsBeginIn.push_back(*(padsBeginPtr + i)); } - auto *dstData = getDstDataAtPortAs(0); + auto* dstData = getDstDataAtPortAs(0); - const auto &inDims = getParentEdgeAt(0)->getMemory().getStaticDims(); - const auto &outDims = getChildEdgeAt(0)->getMemory().getStaticDims(); + const auto& inDims = getParentEdgeAt(0)->getMemory().getStaticDims(); + const auto& outDims = getChildEdgeAt(0)->getMemory().getStaticDims(); auto srcDesc = getParentEdgeAt(0)->getMemory().getDescWithType(); @@ -193,8 +195,8 @@ void BatchToSpace::batchToSpaceKernel() { const int64_t addTmpOC = blocked ? 0lu : oAdd[1]; const int64_t addTmpOc = blocked ? oAdd[1] : 0lu; - const size_t firstI1 = i0 == 0 ? std::max(begin[1], indxStart[1]) : begin[1]; - const size_t lastI1 = i0 == indxEnd[0] ? std::min(indxEnd[1] + 1, finish[1]) : finish[1]; + const size_t firstI1 = i0 == 0 ? std::max(begin[1], indxStart[1]) : begin[1]; + const size_t lastI1 = i0 == indxEnd[0] ? std::min(indxEnd[1] + 1, finish[1]) : finish[1]; for (size_t i1 = firstI1; i1 < lastI1; ++i1) { const size_t block = i1 == finish[1] ? lastBlock : blockSize; @@ -216,12 +218,13 @@ void BatchToSpace::batchToSpaceKernel() { const size_t dstIdx4 = dstIdx3 + tmpOw * blockSize; for (size_t it = 0; it < itEnd + 1; ++it) { const size_t i5Begin = it == 0 ? 0 : (it * blockSize - 1 - oAdd[1]) / blockShape[1] + 1; - const size_t i5End = it == itEnd ? (block - 1) : ((it + 1) * blockSize - 1 - oAdd[1]) / blockShape[1]; + const size_t i5End = + it == itEnd ? 
(block - 1) : ((it + 1) * blockSize - 1 - oAdd[1]) / blockShape[1]; for (size_t i5 = i5Begin; i5 < i5End + 1; ++i5) { const int64_t tmpOc = i5 * blockShape[1] + addTmpOc; const size_t srcIdx5 = srcIdx4 + i5; const size_t dstIdx5 = - dstIdx4 + it * outSpatialStep * blockSize + (tmpOc - it * blockSize); + dstIdx4 + it * outSpatialStep * blockSize + (tmpOc - it * blockSize); dstData[dstIdx5] = srcData[srcIdx5]; } } @@ -239,13 +242,19 @@ void BatchToSpace::executeDynamicImpl(dnnl::stream strm) { void BatchToSpace::execute(dnnl::stream strm) { switch (getParentEdgeAt(0)->getMemory().getDesc().getPrecision().size()) { - case 1: batchToSpaceKernel::value_type>(); break; - case 2: batchToSpaceKernel::value_type>(); break; - case 4: batchToSpaceKernel::value_type>(); break; - default: - OPENVINO_THROW("BatchToSpace layer does not support precision '", - std::string(getParentEdgeAt(0)->getMemory().getDesc().getPrecision().get_type_name()), - "'"); + case 1: + batchToSpaceKernel::value_type>(); + break; + case 2: + batchToSpaceKernel::value_type>(); + break; + case 4: + batchToSpaceKernel::value_type>(); + break; + default: + OPENVINO_THROW("BatchToSpace layer does not support precision '", + std::string(getParentEdgeAt(0)->getMemory().getDesc().getPrecision().get_type_name()), + "'"); } } @@ -253,6 +262,6 @@ bool BatchToSpace::created() const { return getType() == Type::BatchToSpace; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/batch_to_space.h b/src/plugins/intel_cpu/src/nodes/batch_to_space.h index 1b583f74bd7905..5211e0c0b5dd10 100644 --- a/src/plugins/intel_cpu/src/nodes/batch_to_space.h +++ b/src/plugins/intel_cpu/src/nodes/batch_to_space.h @@ -14,7 +14,7 @@ class BatchToSpace : public Node { public: BatchToSpace(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; // output shape can potentially be empty @@ -25,14 +25,18 @@ class BatchToSpace : public Node { void execute(dnnl::stream strm) override; bool created() const override; - bool needPrepareParams() const override { return false; }; - bool needShapeInfer() const override {return true;}; + bool needPrepareParams() const override { + return false; + }; + bool needShapeInfer() const override { + return true; + }; void executeDynamicImpl(dnnl::stream strm) override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: - template + template void batchToSpaceKernel(); private: @@ -42,6 +46,6 @@ class BatchToSpace : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/bin_conv.cpp b/src/plugins/intel_cpu/src/nodes/bin_conv.cpp index d1e82235ba9bb1..336a370374a9f9 100644 --- a/src/plugins/intel_cpu/src/nodes/bin_conv.cpp +++ b/src/plugins/intel_cpu/src/nodes/bin_conv.cpp @@ -3,34 +3,35 @@ // #include "bin_conv.h" -#include "eltwise.h" -#include "fake_quantize.h" -#include "conv.h" + #include #include #include -#include "dnnl_types.h" + +#include "conv.h" +#include "cpu/x64/cpu_isa_traits.hpp" +#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/jit_generator.hpp" 
#include "dnnl_extension_utils.h" +#include "dnnl_types.h" +#include "eltwise.h" +#include "fake_quantize.h" #include "openvino/core/parallel.hpp" -#include "cpu/x64/jit_generator.hpp" -#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" -#include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp" -#include "cpu/x64/cpu_isa_traits.hpp" -#include "utils/general_utils.h" #include "openvino/opsets/opset1.hpp" +#include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" // WA for xbyak.h #ifdef _WIN32 -# ifndef _WINSOCKAPI_ -# define _WINSOCKAPI_ -# endif -# ifndef _WINSOCK2API_ -# define _WINSOCK2API_ -# endif +# ifndef _WINSOCKAPI_ +# define _WINSOCKAPI_ +# endif +# ifndef _WINSOCK2API_ +# define _WINSOCK2API_ +# endif #endif - using namespace dnnl; using namespace dnnl::impl; using namespace dnnl::impl::cpu; @@ -42,14 +43,17 @@ namespace ov { namespace intel_cpu { namespace node { #if defined(OPENVINO_ARCH_X86_64) -#define GET_OFF(field) offsetof(jit_bin_conv_call_args, field) +# define GET_OFF(field) offsetof(jit_bin_conv_call_args, field) template struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_bin_conv_kernel_f32) - explicit jit_uni_bin_conv_kernel_f32(jit_bin_conv_params jcp, jit_dw_conv_params jcp_dw_conv, const dnnl_primitive_attr &attr) : - jit_uni_bin_conv_kernel(jcp, jcp_dw_conv, attr), jit_generator(jit_name()) {} + explicit jit_uni_bin_conv_kernel_f32(jit_bin_conv_params jcp, + jit_dw_conv_params jcp_dw_conv, + const dnnl_primitive_attr& attr) + : jit_uni_bin_conv_kernel(jcp, jcp_dw_conv, attr), + jit_generator(jit_name()) {} void create_ker() override { jit_generator::create_kernel(); @@ -57,16 +61,19 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } void generate() override { - const auto &p = attr_.post_ops_; + const auto& p = attr_.post_ops_; int end_idx = jcp_.with_dw_conv ? 
p.find(primitive_kind::convolution) : p.len(); for (int i = 0; i < end_idx; i++) { - auto &post_op = p.entry_[i]; + auto& post_op = p.entry_[i]; if (post_op.is_eltwise()) { - eltwise_injectors.push_back(std::make_shared>( - this, post_op.eltwise, true, eltwise_reserved, mask_post_op_reserved)); + eltwise_injectors.push_back(std::make_shared>(this, + post_op.eltwise, + true, + eltwise_reserved, + mask_post_op_reserved)); } else if (post_op.is_depthwise()) { - depthwise_injectors.push_back(std::make_shared>( - this, post_op, mask_post_op_reserved)); + depthwise_injectors.push_back( + std::make_shared>(this, post_op, mask_post_op_reserved)); } } @@ -80,7 +87,7 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ mov(reg_oc_work, ptr[this->param1 + GET_OFF(oc_work)]); mov(reg_post_ops_data, ptr[this->param1 + GET_OFF(post_op_data)]); - mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); + mov(reg_oc_off, ptr[param1 + GET_OFF(oc_off)]); mov(reg_table, l_table); Label main_loop_label; @@ -98,14 +105,16 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ int nbits = 8; - L(main_loop_label); { + L(main_loop_label); + { cmp(reg_oc_work, jcp_.oc_block); jl(tail_label, T_NEAR); solve_common(1, jcp_.oc_block); sub(reg_oc_work, jcp_.oc_block); - add(reg_kernel_base, jcp_.oc_block * jcp_.nb_ic * jcp_.kh * jcp_.kw * div_up(jcp_.ic_block, nbits) * jcp_.typesize_in); + add(reg_kernel_base, + jcp_.oc_block * jcp_.nb_ic * jcp_.kh * jcp_.kw * div_up(jcp_.ic_block, nbits) * jcp_.typesize_in); if (jcp_.with_dw_conv) { add(reg_output_base, jcp_.oc_block * jcp_dw_conv_.kh * jcp_.ow * jcp_.typesize_out); @@ -137,8 +146,7 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } private: - using Vmm = typename conditional3::type; + using Vmm = typename conditional3::type; using Ymm = const Xbyak::Ymm; using reg8_t = const Xbyak::Reg8; @@ -212,100 +220,108 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ nstl::vector>> eltwise_injectors; nstl::vector>> depthwise_injectors; - void cvt2ps(dnnl::memory::data_type type_in, Vmm vmm_in, const Xbyak::Operand &op, bool scalar_load) { + void cvt2ps(dnnl::memory::data_type type_in, Vmm vmm_in, const Xbyak::Operand& op, bool scalar_load) { Xmm xmm_in = Xmm(vmm_in.getIdx()); switch (type_in) { - case memory::data_type::f32: - case memory::data_type::s32: - if (scalar_load) { - mov(reg_tmp_32, op); - uni_vmovq(xmm_in, reg_tmp_64); - } else { - uni_vmovups(vmm_in, op); - } - break; - case memory::data_type::s8: - if (scalar_load) { - movsx(reg_tmp_32, op); - uni_vmovq(xmm_in, reg_tmp_64); - } else { - uni_vpmovsxbd(vmm_in, op); - } - break; - case memory::data_type::u8: - if (scalar_load) { - movzx(reg_tmp_32, op); - uni_vmovq(xmm_in, reg_tmp_64); - } else { - uni_vpmovzxbd(vmm_in, op); - } - break; - default: assert(!"unsupported data type"); + case memory::data_type::f32: + case memory::data_type::s32: + if (scalar_load) { + mov(reg_tmp_32, op); + uni_vmovq(xmm_in, reg_tmp_64); + } else { + uni_vmovups(vmm_in, op); + } + break; + case memory::data_type::s8: + if (scalar_load) { + movsx(reg_tmp_32, op); + uni_vmovq(xmm_in, reg_tmp_64); + } else { + uni_vpmovsxbd(vmm_in, op); + } + break; + case memory::data_type::u8: + if (scalar_load) { + movzx(reg_tmp_32, op); + uni_vmovq(xmm_in, reg_tmp_64); + } else { + uni_vpmovzxbd(vmm_in, op); + } + break; + default: + assert(!"unsupported data type"); } if (type_in != data_type::f32) uni_vcvtdq2ps(vmm_in, vmm_in); } - 
void store_dst(const Xbyak::Address &op, Vmm vmm_dst, bool scalar_store) { + void store_dst(const Xbyak::Address& op, Vmm vmm_dst, bool scalar_store) { Ymm ymm_dst = Ymm(vmm_dst.getIdx()); Xmm xmm_dst = Xmm(vmm_dst.getIdx()); switch (jcp_.dst_dt) { - case memory::data_type::f32: - case memory::data_type::s32: - if (scalar_store) { - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_32); - } else { - uni_vmovups(op, vmm_dst); - } - break; - case memory::data_type::s8: - uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); + case memory::data_type::f32: + case memory::data_type::s32: + if (scalar_store) { + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_32); + } else { + uni_vmovups(op, vmm_dst); + } + break; + case memory::data_type::s8: + uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41 && !scalar_store) - vpermq(ymm_dst, ymm_dst, 0x08); + if (isa != x64::sse41 && !scalar_store) + vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); + uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); - if (scalar_store) { - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - } else { - if (isa != x64::sse41) - vmovq(op, xmm_dst); - else - movd(op, xmm_dst); - } - break; - case memory::data_type::u8: - case memory::data_type::bin: - uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); + if (scalar_store) { + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + } else { + if (isa != x64::sse41) + vmovq(op, xmm_dst); + else + movd(op, xmm_dst); + } + break; + case memory::data_type::u8: + case memory::data_type::bin: + uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41 && !scalar_store) - vpermq(ymm_dst, ymm_dst, 0x08); + if (isa != x64::sse41 && !scalar_store) + vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); + uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); - if (scalar_store) { - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - } else { - if (isa != x64::sse41) - vmovq(op, xmm_dst); - else - movd(op, xmm_dst); - } + if (scalar_store) { + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + } else { + if (isa != x64::sse41) + vmovq(op, xmm_dst); + else + movd(op, xmm_dst); + } - break; - default: - assert(!"unknown dst_dt"); + break; + default: + assert(!"unknown dst_dt"); } } - void apply_filter(int ur_w, int pad_l, int pad_r, int oc_blocks, int oc_step, int ic_blocks, bool last_icb, bool h_padded) { + void apply_filter(int ur_w, + int pad_l, + int pad_r, + int oc_blocks, + int oc_step, + int ic_blocks, + bool last_icb, + bool h_padded) { int kw = jcp_.kw; int kh = jcp_.kh; int stride_w = jcp_.stride_w; @@ -318,15 +334,16 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ for (int ki = 0; ki < kw; ki++) { int jj_start = nstl::max(0, div_up(pad_l - ki * dilate_w, stride_w)); - int jj_end = ur_w - nstl::max(0, div_up(ki*dilate_w+pad_r-(kw-1)*dilate_w, stride_w)); + int jj_end = ur_w - nstl::max(0, div_up(ki * dilate_w + pad_r - (kw - 1) * dilate_w, stride_w)); int _start = (!jcp_.exclude_pad) ? 0 : jj_start; int _end = (!jcp_.exclude_pad) ? 
ur_w : jj_end; for (int ifm2 = 0; ifm2 < ic_blocks; ifm2++) { for (int jj = _start; jj < _end; jj++) { - int inp_off = ((ki*dilate_w + jj*stride_w - pad_l)*div_up(jcp_.ic, nbits) + - ifm2 * div_up(ic_blk, nbits)) * jcp_.typesize_in; + int inp_off = ((ki * dilate_w + jj * stride_w - pad_l) * div_up(jcp_.ic, nbits) + + ifm2 * div_up(ic_blk, nbits)) * + jcp_.typesize_in; if (h_padded || jj < jj_start || jj >= jj_end) { uni_vmovups(vmm_src, ptr[reg_table + 8 * vlen]); @@ -336,10 +353,11 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ for (int r = 0; r < repeats; r++) { for (int ii = 0; ii < oc_blocks; ii++) { - int ker_off = (ifm2 * kh * kw * div_up(ic_blk, nbits) * oc_blk - + ii * jcp_.nb_ic * div_up(ic_blk, nbits) * kh * kw * oc_blk - + ki * div_up(ic_blk, nbits) * oc_blk - + r * div_up(ic_blk, nbits) * (oc_blk / 2)) * jcp_.typesize_in; + int ker_off = + (ifm2 * kh * kw * div_up(ic_blk, nbits) * oc_blk + + ii * jcp_.nb_ic * div_up(ic_blk, nbits) * kh * kw * oc_blk + + ki * div_up(ic_blk, nbits) * oc_blk + r * div_up(ic_blk, nbits) * (oc_blk / 2)) * + jcp_.typesize_in; uni_vmovups(vmm_tmp, ptr[aux1_reg_kernel + ker_off]); @@ -350,7 +368,8 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ if (mayiuse(x64::avx512_vpopcnt)) { vpopcntd(vmm_tmp, vmm_tmp); uni_vpaddd(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), - Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), vmm_tmp); + Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), + vmm_tmp); } else { if (isa == x64::sse41) { movups(vmm_tmp1, vmm_tmp); @@ -375,12 +394,15 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } if (mayiuse(avx512_core_vnni)) { - vpdpbusd(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), vmm_tmp, vmm_one_u8); + vpdpbusd(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), + vmm_tmp, + vmm_one_u8); } else { uni_vpmaddubsw(vmm_tmp, vmm_tmp, vmm_one_u8); uni_vpmaddwd(vmm_tmp, vmm_tmp, vmm_one_s16); uni_vpaddd(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), - Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), vmm_tmp); + Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), + vmm_tmp); } } } @@ -431,22 +453,22 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ int nbits = 8; const int inp_mult = dilate_h * div_up(jcp_.ic, nbits); - Label t_overflow_label, no_t_overflow_label, - b_overflow_label, no_b_overflow_label; + Label t_overflow_label, no_t_overflow_label, b_overflow_label, no_b_overflow_label; mov(aux_reg_input, reg_input); mov(aux_reg_kernel, reg_kernel_base); - uni_vmovups(vmm_lookup, ptr[reg_table + 0 * vlen]); - uni_vmovups(vmm_mask, ptr[reg_table + 1 * vlen]); - uni_vmovups(vmm_one_u8, ptr[reg_table + 5 * vlen]); + uni_vmovups(vmm_lookup, ptr[reg_table + 0 * vlen]); + uni_vmovups(vmm_mask, ptr[reg_table + 1 * vlen]); + uni_vmovups(vmm_one_u8, ptr[reg_table + 5 * vlen]); uni_vmovups(vmm_one_s16, ptr[reg_table + 6 * vlen]); if (!jcp_.exclude_pad) { - mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); + mov(reg_overflow, ptr[param1 + GET_OFF(t_overflow)]); cmp(reg_overflow, 0); je(no_t_overflow_label, T_NEAR); - L(t_overflow_label); { + L(t_overflow_label); + { oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true); add(aux_reg_kernel, jcp_.typesize_in * kw * jcp_.oc_block * div_up(jcp_.ic_block, nbits)); @@ -459,8 +481,8 @@ struct jit_uni_bin_conv_kernel_f32 : public 
jit_uni_bin_conv_kernel, public jit_ Label skip_kh_loop; mov(reg_kh, ptr[this->param1 + GET_OFF(kh_padding)]); - if (!jcp_.exclude_pad || (jcp_.exclude_pad && - (jcp_.kh - 1) * (jcp_.dilate_h + 1) < nstl::max(jcp_.t_pad, jcp_.b_pad))) { + if (!jcp_.exclude_pad || + (jcp_.exclude_pad && (jcp_.kh - 1) * (jcp_.dilate_h + 1) < nstl::max(jcp_.t_pad, jcp_.b_pad))) { cmp(reg_kh, 0); je(skip_kh_loop, T_NEAR); } @@ -481,10 +503,11 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ L(skip_kh_loop); if (!jcp_.exclude_pad) { - mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); + mov(reg_overflow, ptr[param1 + GET_OFF(b_overflow)]); cmp(reg_overflow, 0); je(no_b_overflow_label, T_NEAR); - L(b_overflow_label); { + L(b_overflow_label); + { oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks, oc_step, true); add(aux_reg_kernel, jcp_.typesize_in * kw * jcp_.oc_block * div_up(jcp_.ic_block, nbits)); @@ -515,7 +538,7 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ kmovw(ktail_mask, reg_tmp_32); } - const auto &p = attr_.post_ops_; + const auto& p = attr_.post_ops_; for (int r = 0; r < repeats; r++) { int tail_size = isa == x64::sse41 ? nstl::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; bool is_scalar_store = isa == x64::sse41 ? tail_size < jcp_.oc_block / 2 : tail_size < jcp_.oc_block; @@ -524,15 +547,17 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ if (jcp_.exclude_pad) { mov(reg_tmp_32, jcp_.ic); - imul(reg_tmp_32, ptr[param1 + GET_OFF(kh_padding)]); + imul(reg_tmp_32, ptr[param1 + GET_OFF(kh_padding)]); for (int jj = 0; jj < ur_w; jj++) kw_padding[jj] = 0; for (int ki = 0; ki < jcp_.kw; ki++) { int jj_start = nstl::max(0, div_up(pad_l - ki * (jcp_.dilate_w + 1), jcp_.stride_w)); - int jj_end = ur_w - nstl::max(0, div_up(ki * (jcp_.dilate_w + 1) + pad_r - - (jcp_.kw - 1) * (jcp_.dilate_w + 1), jcp_.stride_w)); + int jj_end = + ur_w - nstl::max(0, + div_up(ki * (jcp_.dilate_w + 1) + pad_r - (jcp_.kw - 1) * (jcp_.dilate_w + 1), + jcp_.stride_w)); for (int jj = jj_start; jj < jj_end; jj++) { kw_padding[jj]++; } @@ -552,8 +577,11 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } for (int ii = 0; ii < oc_blocks; ii++) { - uni_vcvtdq2ps(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj)); - uni_vfmadd213ps(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), vmm_scale, vmm_shift); + uni_vcvtdq2ps(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), + Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj)); + uni_vfmadd213ps(Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj), + vmm_scale, + vmm_shift); } } @@ -580,7 +608,9 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ for (int ii = 0; ii < oc_blocks; ii++) { depthwise_injectors[depthwise_inj_idx]->compute_vector_range(start_idx + ur_w * ii, - start_idx + ur_w * ii + ur_w, reg_d_weights, reg_d_weights); + start_idx + ur_w * ii + ur_w, + reg_d_weights, + reg_d_weights); add(reg_d_weights, jcp_.oc_block * sizeof(float)); } @@ -596,7 +626,7 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ if (is_scalar_store) { if (isa == x64::avx512_core) { - int o_off = jj * jcp_.oc * jcp_.ngroups; + int o_off = jj * jcp_.oc * jcp_.ngroups; Vmm vmm_in = vmm_sum | ktail_mask | T_z; @@ -604,7 +634,7 @@ struct jit_uni_bin_conv_kernel_f32 : 
public jit_uni_bin_conv_kernel, public jit_ uni_vaddps(vmm_dst, vmm_dst, vmm_sum); } else { for (int oc = 0; oc < tail_size; oc++) { - int o_off = jj * jcp_.oc * jcp_.ngroups + r * (jcp_.oc_block / 2) + oc; + int o_off = jj * jcp_.oc * jcp_.ngroups + r * (jcp_.oc_block / 2) + oc; uni_vpxor(vmm_sum, vmm_sum, vmm_sum); cvt2ps(jcp_.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp_.typesize_out], true); @@ -621,7 +651,8 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } } } else { - size_t o_off = ii * jcp_.oc_block + jj * jcp_.oc * jcp_.ngroups + r * (jcp_.oc_block / 2); + size_t o_off = + ii * jcp_.oc_block + jj * jcp_.oc * jcp_.ngroups + r * (jcp_.oc_block / 2); cvt2ps(jcp_.dst_dt, vmm_sum, ptr[reg_output + o_off * jcp_.typesize_out], false); uni_vaddps(vmm_dst, vmm_dst, vmm_sum); @@ -649,10 +680,15 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ for (int ii = 0; ii < oc_blocks; ii++) { for (int jj = 0; jj < ur_w; jj++) { for (int r = 0; r < repeats; r++) { - int tail_size = isa == x64::sse41 ? nstl::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; + int tail_size = + isa == x64::sse41 ? nstl::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; mov(reg_b_mask, (1 << tail_size) - 1); - uni_vmovups(vmm_thr, ptr[reg_b_weights + (ii * jcp_.oc_block + r * (jcp_.oc_block / 2)) * sizeof(float)]); - uni_vmovups(vmm_out_mask, ptr[reg_b_out_mask + (ii * jcp_.oc_block + r * (jcp_.oc_block / 2)) * sizeof(float)]); + uni_vmovups( + vmm_thr, + ptr[reg_b_weights + (ii * jcp_.oc_block + r * (jcp_.oc_block / 2)) * sizeof(float)]); + uni_vmovups( + vmm_out_mask, + ptr[reg_b_out_mask + (ii * jcp_.oc_block + r * (jcp_.oc_block / 2)) * sizeof(float)]); Vmm vmm_dst = Vmm(1 + r * jcp_.ur_w * jcp_.nb_oc_blocking + ur_w * ii + jj); @@ -693,7 +729,8 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } } else { for (int r = 0; r < repeats; r++) { - int tail_size = isa == x64::sse41 ? nstl::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; + int tail_size = + isa == x64::sse41 ? nstl::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; bool is_scalar_store = isa == x64::sse41 ? tail_size < jcp_.oc_block / 2 : tail_size < jcp_.oc_block; if (is_scalar_store) { for (int jj = 0; jj < ur_w; jj++) { @@ -735,7 +772,7 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ size_t o_off; if (jcp_.with_dw_conv) - o_off = ((size_t) ii * jcp_dw_conv_.kh * jcp_.ow + jj) * jcp_.oc_block + + o_off = ((size_t)ii * jcp_dw_conv_.kh * jcp_.ow + jj) * jcp_.oc_block + r * (jcp_.oc_block / 2); else o_off = ii * jcp_.oc_block + jj * jcp_.oc * jcp_.ngroups + r * (jcp_.oc_block / 2); @@ -759,14 +796,15 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ int nbits = 8; const int inp_mult = div_up(jcp_.ic, nbits); - const int out_mult = jcp_.with_dw_conv ? jcp_.oc_block : jcp_.with_binarization ? div_up(jcp_.oc, nbits) : jcp_.oc; + const int out_mult = jcp_.with_dw_conv ? jcp_.oc_block + : jcp_.with_binarization ? 
div_up(jcp_.oc, nbits) + : jcp_.oc; int l_pad = jcp_.l_pad; - int r_pad = nstl::max(0, (jcp_.ow - 1) * str_w + (kw - 1) * dilate_w - - (iw + l_pad - 1)); - int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w - - (iw + l_pad - 1); - if (r_pad1 > 0) n_oi--; + int r_pad = nstl::max(0, (jcp_.ow - 1) * str_w + (kw - 1) * dilate_w - (iw + l_pad - 1)); + int r_pad1 = (ur_w * n_oi - 1) * str_w + (kw - 1) * dilate_w - (iw + l_pad - 1); + if (r_pad1 > 0) + n_oi--; mov(reg_input, reg_input_base); mov(reg_output, reg_output_base); @@ -779,9 +817,9 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ if (l_pad > 0) { n_oi--; if (n_oi < 0 && r_pad1 > 0) - width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad" + width_blk_step(ur_w, l_pad, r_pad1, oc_blocks, oc_step); // "lrpad" else - width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad" + width_blk_step(ur_w, l_pad, 0, oc_blocks, oc_step); // "lpad" add(reg_input, jcp_.typesize_in * (ur_w * str_w - l_pad) * inp_mult); add(reg_output, jcp_.typesize_out * ur_w * out_mult); } @@ -792,7 +830,7 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ if (n_oi > 0) { L(ow_loop_label); - width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle" + width_blk_step(ur_w, 0, 0, oc_blocks, oc_step); // "middle" add(reg_input, jcp_.typesize_in * ur_w * str_w * inp_mult); add(reg_output, jcp_.typesize_out * ur_w * out_mult); @@ -801,14 +839,14 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ jl(ow_loop_label, T_NEAR); } - if (r_pad1 > 0 && n_oi >=0) { - width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad" + if (r_pad1 > 0 && n_oi >= 0) { + width_blk_step(ur_w, 0, r_pad1, oc_blocks, oc_step); // "rpad" add(reg_input, jcp_.typesize_in * ur_w * str_w * inp_mult); add(reg_output, jcp_.typesize_out * ur_w * out_mult); } if (ur_w_tail != 0) - width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail" + width_blk_step(ur_w_tail, 0, r_pad, oc_blocks, oc_step); // "tail" pop(reg_oc_off); pop(reg_oc_work); @@ -817,17 +855,15 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } void prepare_table() { - const unsigned int cvals[] = { - 0x02010100, // 0 1 1 2 - 0x03020201, // 1 2 2 3 - 0x03020201, // 1 2 2 3 - 0x04030302, // 2 3 3 4 - 0x0f0f0f0f, - 0x000000ff, - 0xc0000000, // -2.0f - 0x01010101, - 0x00010001 - }; + const unsigned int cvals[] = {0x02010100, // 0 1 1 2 + 0x03020201, // 1 2 2 3 + 0x03020201, // 1 2 2 3 + 0x04030302, // 2 3 3 4 + 0x0f0f0f0f, + 0x000000ff, + 0xc0000000, // -2.0f + 0x01010101, + 0x00010001}; size_t simd_w = vlen / sizeof(int32_t); @@ -876,7 +912,8 @@ struct jit_uni_bin_conv_kernel_f32 : public jit_uni_bin_conv_kernel, public jit_ } }; #endif -bool BinaryConvolution::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool BinaryConvolution::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { if (isDynamicNgraphNode(op)) { errorMessage = "Doesn't support op with dynamic shapes"; @@ -934,7 +971,7 @@ void BinaryConvolution::getSupportedDescriptors() { withSum = false; size_t expectedInputEdgesNum = 2; for (size_t i = 0; i < fusedWith.size(); i++) { - auto *eltwiseNode = dynamic_cast(fusedWith[i].get()); + auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); if (eltwiseNode && eltwiseNode->isSpecialConvolutionAddFusing()) { withSum = true; expectedInputEdgesNum++; @@ -979,22 +1016,30 @@ void 
BinaryConvolution::initSupportedPrimitiveDescriptors() { if (implType != impl_desc_type::ref) { // optimzed implementation -// auto weiFormat = implType == impl_desc_type::jit_avx512 ? memory::format_tag::OhIw16o32i : memory::format_tag::OhIw8o32i; + // auto weiFormat = implType == impl_desc_type::jit_avx512 ? memory::format_tag::OhIw16o32i : + // memory::format_tag::OhIw8o32i; - //activation + // activation auto nspcCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::nspc); config.inConfs[0].setMemDesc(nspcCreator->createSharedDesc(ov::element::u1, getInputShapeAtPort(0))); - //weights - size_t weiFirstDimBlockSize = implType == impl_desc_type::jit_avx512 ? 16 : 8; //memory::format_tag::OIhw16o32i : memory::format_tag::OIhw8o32i; + // weights + size_t weiFirstDimBlockSize = implType == impl_desc_type::jit_avx512 + ? 16 + : 8; // memory::format_tag::OIhw16o32i : memory::format_tag::OIhw8o32i; auto weiDims = getInputShapeAtPort(1).getStaticDims(); - std::vector weiBlockDims = {div_up(weiDims[0], weiFirstDimBlockSize), div_up(weiDims[1], 32), - weiDims[2], weiDims[3], weiFirstDimBlockSize, 32}; + std::vector weiBlockDims = {div_up(weiDims[0], weiFirstDimBlockSize), + div_up(weiDims[1], 32), + weiDims[2], + weiDims[3], + weiFirstDimBlockSize, + 32}; std::vector weiOrder = {0, 1, 2, 3, 0, 1}; - config.inConfs[1].setMemDesc(std::make_shared(ov::element::u1, Shape(weiDims), weiBlockDims, weiOrder)); + config.inConfs[1].setMemDesc( + std::make_shared(ov::element::u1, Shape(weiDims), weiBlockDims, weiOrder)); - //result + // result auto outputPrecision = withBinarization ? ov::element::u1 : ov::element::f32; config.outConfs[0].setMemDesc(nspcCreator->createSharedDesc(outputPrecision, getOutputShapeAtPort(0))); if (withSum) { @@ -1056,14 +1101,15 @@ void BinaryConvolution::createPrimitive() { jcp.with_dw_conv = false; jcp.with_binarization = withBinarization; - const auto &p = (*attr.get()).post_ops_; + const auto& p = (*attr.get()).post_ops_; jcp.with_sum = p.find(primitive_kind::sum) != -1; jcp.with_binarization = p.find(primitive_kind::binarization) != -1; int simd_w = implType == impl_desc_type::jit_avx512 ? 16 : 8; jcp.ur_w = implType == impl_desc_type::jit_avx512 ? 4 : 2; - if (jcp.ow < jcp.ur_w) jcp.ur_w = jcp.ow; + if (jcp.ow < jcp.ur_w) + jcp.ur_w = jcp.ow; jcp.ur_w_tail = jcp.ow % jcp.ur_w; jcp.ic_block = 32; @@ -1073,7 +1119,10 @@ void BinaryConvolution::createPrimitive() { jcp.oc_block = simd_w; jcp.nb_oc = div_up(jcp.oc, jcp.oc_block); - jcp.nb_oc_blocking = nstl::min(implType == impl_desc_type::jit_sse42 ? 2 : implType == impl_desc_type::jit_avx2 ? 4 : 6, jcp.nb_oc); + jcp.nb_oc_blocking = nstl::min(implType == impl_desc_type::jit_sse42 ? 2 + : implType == impl_desc_type::jit_avx2 ? 4 + : 6, + jcp.nb_oc); auto srcPrecision = getParentEdgeAt(0)->getMemory().getDesc().getPrecision(); auto dstPrecision = getChildEdgeAt(0)->getMemory().getDesc().getPrecision(); @@ -1082,11 +1131,13 @@ void BinaryConvolution::createPrimitive() { jcp.typesize_in = srcPrecision == ov::element::u1 ? 1 : srcPrecision.size(); jcp.typesize_out = dstPrecision == ov::element::u1 ? 
1 : dstPrecision.size(); - int r_pad_no_tail = nstl::max(0, (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w - + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); + int r_pad_no_tail = nstl::max( + 0, + (jcp.ow - jcp.ur_w_tail - 1) * jcp.stride_w + (jcp.kw - 1) * (jcp.dilate_w + 1) - (jcp.iw + jcp.l_pad - 1)); - bool args_ok = (jcp.l_pad <= jcp.ur_w) && (r_pad_no_tail <= jcp.ur_w) && - IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) || (jcp.stride_w == 1 && jcp.stride_h == 1)); + bool args_ok = + (jcp.l_pad <= jcp.ur_w) && (r_pad_no_tail <= jcp.ur_w) && + IMPLICATION(jcp.kw > 7, (jcp.t_pad == 0 && jcp.l_pad == 0) || (jcp.stride_w == 1 && jcp.stride_h == 1)); if (!args_ok) OPENVINO_THROW("BinaryConvolution with name '", getName(), "' has unsupported parameters"); #if defined(OPENVINO_ARCH_X86_64) @@ -1122,12 +1173,12 @@ bool BinaryConvolution::canFuse(const NodePtr& node) const { } } -void BinaryConvolution::setPostOps(dnnl::primitive_attr &attr) { +void BinaryConvolution::setPostOps(dnnl::primitive_attr& attr) { dnnl::post_ops ops; postOpsDataPtrs.clear(); - for (auto &node : fusedWith) { - auto* eltwiseNode = dynamic_cast(node.get()); + for (auto& node : fusedWith) { + auto* eltwiseNode = dynamic_cast(node.get()); if (eltwiseNode) { if (eltwiseNode->isSpecialConvolutionAddFusing()) { ops.append_sum(1.0); @@ -1138,7 +1189,7 @@ void BinaryConvolution::setPostOps(dnnl::primitive_attr &attr) { continue; } - auto* fakeQuantizeNode = dynamic_cast(node.get()); + auto* fakeQuantizeNode = dynamic_cast(node.get()); if (fakeQuantizeNode) { fakeQuantizeNode->appendPostOps(ops, getOutputShapeAtPort(0).getStaticDims(), postOpsDataPtrs); continue; @@ -1154,9 +1205,13 @@ void BinaryConvolution::setPostOps(dnnl::primitive_attr &attr) { attr.set_post_ops(ops); } -void BinaryConvolution::executeOptimized(const uint8_t* src, const uint8_t* weights, uint8_t* dst, - const std::vector& s_str, const std::vector& w_str, const std::vector& d_str) { - auto dst_f32 = reinterpret_cast(dst); +void BinaryConvolution::executeOptimized(const uint8_t* src, + const uint8_t* weights, + uint8_t* dst, + const std::vector& s_str, + const std::vector& w_str, + const std::vector& d_str) { + auto dst_f32 = reinterpret_cast(dst); const int MB = jcp.mb; @@ -1170,26 +1225,28 @@ void BinaryConvolution::executeOptimized(const uint8_t* src, const uint8_t* weig auto par_conv = jit_bin_conv_call_args(); const int ij = oh * jcp.stride_h; - const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h+1))); - const int i_b_overflow = nstl::min(jcp.kh, div_up(nstl::max(jcp.ih, ij + (jcp.kh-1) * (jcp.dilate_h+1) - - jcp.t_pad+1) - jcp.ih, (jcp.dilate_h + 1))); + const int i_t_overflow = nstl::min(jcp.kh, div_up(nstl::max(0, jcp.t_pad - ij), (jcp.dilate_h + 1))); + const int i_b_overflow = + nstl::min(jcp.kh, + div_up(nstl::max(jcp.ih, ij + (jcp.kh - 1) * (jcp.dilate_h + 1) - jcp.t_pad + 1) - jcp.ih, + (jcp.dilate_h + 1))); const size_t _oc = g * jcp.nb_oc + ocb; const size_t _ic = g * jcp.nb_ic; const int ih = nstl::max(ij - jcp.t_pad + i_t_overflow * (jcp.dilate_h + 1), 0); - par_conv.src = &src[(n * s_str[0] + _ic*jcp.ic_block * s_str[1] + ih * s_str[2]) / nbits]; + par_conv.src = &src[(n * s_str[0] + _ic * jcp.ic_block * s_str[1] + ih * s_str[2]) / nbits]; if (jcp.with_binarization) { - par_conv.dst = &dst[(n * d_str[0] + _oc*jcp.oc_block * d_str[1] + oh * d_str[2]) / nbits]; + par_conv.dst = &dst[(n * d_str[0] + _oc * jcp.oc_block * d_str[1] + oh * d_str[2]) / nbits]; } else { - 
par_conv.dst = &dst_f32[n * d_str[0] + _oc*jcp.oc_block * d_str[1] + oh * d_str[2]]; + par_conv.dst = &dst_f32[n * d_str[0] + _oc * jcp.oc_block * d_str[1] + oh * d_str[2]]; } const int wh = jcp.exclude_pad ? i_t_overflow : 0; par_conv.filt = &weights[(ocb * w_str[0] + wh * w_str[2]) / nbits]; - par_conv.oc_work = nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb*jcp.oc_block; + par_conv.oc_work = nstl::min((ocb + ocb_num) * jcp.oc_block, jcp.oc) - ocb * jcp.oc_block; par_conv.kw_padding = 0; const int kh_padding = jcp.kh - i_t_overflow - i_b_overflow; @@ -1204,9 +1261,13 @@ void BinaryConvolution::executeOptimized(const uint8_t* src, const uint8_t* weig }); } -void BinaryConvolution::executeReference(const uint8_t* src, const uint8_t* weights, uint8_t* dst, - const std::vector& s_str, const std::vector& w_str, const std::vector& d_str) { - auto dst_fp = reinterpret_cast(dst); +void BinaryConvolution::executeReference(const uint8_t* src, + const uint8_t* weights, + uint8_t* dst, + const std::vector& s_str, + const std::vector& w_str, + const std::vector& d_str) { + auto dst_fp = reinterpret_cast(dst); const bool with_groups = jcp.ngroups > 1; @@ -1240,7 +1301,7 @@ void BinaryConvolution::executeReference(const uint8_t* src, const uint8_t* weig return (uint8_t)((val >> bit) & 0x0001); }; - auto ker = [=](int32_t &d, int g, int mb, int oc, int oh, int ow) { + auto ker = [=](int32_t& d, int g, int mb, int oc, int oh, int ow) { for (int ic = 0; ic < IC; ++ic) { for (int kh = 0; kh < KH; ++kh) { for (int kw = 0; kw < KW; ++kw) { @@ -1259,14 +1320,14 @@ void BinaryConvolution::executeReference(const uint8_t* src, const uint8_t* weig if (pad_value == 0) continue; else - s = pad_value == 1.0f ? (uint8_t) 1 : (uint8_t) 0; + s = pad_value == 1.0f ? (uint8_t)1 : (uint8_t)0; } else { - s = extract_bit(src[iidx / nbits], (uint8_t) (iidx % nbits)); + s = extract_bit(src[iidx / nbits], (uint8_t)(iidx % nbits)); } - uint8_t w = extract_bit(weights[widx / nbits], (uint8_t) (widx % nbits)); + uint8_t w = extract_bit(weights[widx / nbits], (uint8_t)(widx % nbits)); - d += (int32_t) (s ^ w); + d += (int32_t)(s ^ w); } } } @@ -1280,13 +1341,11 @@ void BinaryConvolution::executeReference(const uint8_t* src, const uint8_t* weig if (pad_value == 0.0f) { const int i_left_overflow = nstl::max(0, (padL - ow * KSW)); const int i_right_overflow = nstl::max(IW, (ow * KSW + (KW - 1) * (KDW + 1) - padL + 1)) - IW; - const int kw_padding = - KW - div_up(i_left_overflow, (KDW + 1)) - div_up(i_right_overflow, (KDW + 1)); + const int kw_padding = KW - div_up(i_left_overflow, (KDW + 1)) - div_up(i_right_overflow, (KDW + 1)); const int i_top_overflow = nstl::max(0, (padT - oh * KSH)); const int i_bottom_overflow = nstl::max(IH, (oh * KSH + (KH - 1) * (KDH + 1) - padT + 1)) - IH; - const int kh_padding = - KH - div_up(i_top_overflow, (KDH + 1)) - div_up(i_bottom_overflow, (KDH + 1)); + const int kh_padding = KH - div_up(i_top_overflow, (KDH + 1)) - div_up(i_bottom_overflow, (KDH + 1)); base_value = IC * kh_padding * kw_padding; } else { @@ -1295,7 +1354,7 @@ void BinaryConvolution::executeReference(const uint8_t* src, const uint8_t* weig float a_fp = base_value - static_cast(2 * a); - dst_fp[mb * d_str[0] + (g*OC + oc) * d_str[1] + oh * d_str[2] + ow * d_str[3]] = a_fp; + dst_fp[mb * d_str[0] + (g * OC + oc) * d_str[1] + oh * d_str[2] + ow * d_str[3]] = a_fp; }); } @@ -1342,6 +1401,6 @@ bool BinaryConvolution::created() const { return getType() == Type::BinaryConvolution; } -} // namespace node -} // namespace intel_cpu 
-} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/bin_conv.h b/src/plugins/intel_cpu/src/nodes/bin_conv.h index 86b5cb41b2bf6d..661e075b680ec7 100644 --- a/src/plugins/intel_cpu/src/nodes/bin_conv.h +++ b/src/plugins/intel_cpu/src/nodes/bin_conv.h @@ -39,9 +39,9 @@ struct jit_dw_conv_params { }; struct jit_bin_conv_call_args { - const void *src; - const void *dst; - const void *filt; + const void* src; + const void* dst; + const void* filt; size_t kh_padding; size_t kw_padding; size_t oc_work; @@ -52,15 +52,20 @@ struct jit_bin_conv_call_args { }; struct jit_uni_bin_conv_kernel { - void (*ker_)(const jit_bin_conv_call_args *); + void (*ker_)(const jit_bin_conv_call_args*); - void operator()(const jit_bin_conv_call_args *args) { + void operator()(const jit_bin_conv_call_args* args) { assert(ker_); ker_(args); } - explicit jit_uni_bin_conv_kernel(jit_bin_conv_params jcp, jit_dw_conv_params jcp_dw_conv, const dnnl_primitive_attr &attr) : - ker_(nullptr), jcp_(jcp), jcp_dw_conv_(jcp_dw_conv), attr_(attr) {} + explicit jit_uni_bin_conv_kernel(jit_bin_conv_params jcp, + jit_dw_conv_params jcp_dw_conv, + const dnnl_primitive_attr& attr) + : ker_(nullptr), + jcp_(jcp), + jcp_dw_conv_(jcp_dw_conv), + attr_(attr) {} virtual ~jit_uni_bin_conv_kernel() {} virtual void create_ker() = 0; @@ -68,7 +73,7 @@ struct jit_uni_bin_conv_kernel { jit_bin_conv_params jcp_; jit_dw_conv_params jcp_dw_conv_; - const dnnl_primitive_attr &attr_; + const dnnl_primitive_attr& attr_; }; class BinaryConvolution : public Node { @@ -83,12 +88,14 @@ class BinaryConvolution : public Node { bool canBeInPlace() const override { return false; } - void setPostOps(dnnl::primitive_attr &attr); + void setPostOps(dnnl::primitive_attr& attr); bool canFuse(const NodePtr& node) const override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; - impl_desc_type getImplType() { return implType; } + impl_desc_type getImplType() { + return implType; + } private: bool withSum = false; @@ -110,14 +117,22 @@ class BinaryConvolution : public Node { impl_desc_type implType = impl_desc_type::ref; - void executeOptimized(const uint8_t* src, const uint8_t* weights, uint8_t* dst, - const std::vector& s_str, const std::vector& w_str, const std::vector& d_str); - void executeReference(const uint8_t* src, const uint8_t* weights, uint8_t* dst, - const std::vector& s_str, const std::vector& w_str, const std::vector& d_str); + void executeOptimized(const uint8_t* src, + const uint8_t* weights, + uint8_t* dst, + const std::vector& s_str, + const std::vector& w_str, + const std::vector& d_str); + void executeReference(const uint8_t* src, + const uint8_t* weights, + uint8_t* dst, + const std::vector& s_str, + const std::vector& w_str, + const std::vector& d_str); std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/broadcast.cpp b/src/plugins/intel_cpu/src/nodes/broadcast.cpp index c88803e07de601..646e186922b397 100644 --- a/src/plugins/intel_cpu/src/nodes/broadcast.cpp +++ b/src/plugins/intel_cpu/src/nodes/broadcast.cpp @@ -2,15 +2,18 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include "broadcast.h" + +#include + #include +#include + +#include "common/cpu_memcpy.h" #include "dnnl_types.h" -#include "openvino/core/parallel.hpp" -#include -#include "broadcast.h" #include 
"nodes/common/blocked_desc_creator.h" +#include "openvino/core/parallel.hpp" #include "openvino/opsets/opset1.hpp" -#include "common/cpu_memcpy.h" #include "utils/ngraph_utils.hpp" namespace ov { @@ -24,19 +27,20 @@ bool Broadcast::isSupportedOperation(const std::shared_ptr& op, return false; } if (!one_of(ov::as_type_ptr(op)->get_broadcast_spec().m_type, - ov::op::AutoBroadcastType::NUMPY, ov::op::AutoBroadcastType::EXPLICIT)) { + ov::op::AutoBroadcastType::NUMPY, + ov::op::AutoBroadcastType::EXPLICIT)) { errorMessage = "Only NUMPY and EXPLICIT broadcast types are supported."; return false; } if (op->get_input_partial_shape(TARGET_SHAPE_IDX).is_dynamic() || - (op->get_input_size() > AXES_MAPPING_IDX && op->get_input_partial_shape(AXES_MAPPING_IDX).is_dynamic())) { + (op->get_input_size() > AXES_MAPPING_IDX && op->get_input_partial_shape(AXES_MAPPING_IDX).is_dynamic())) { errorMessage = "Only static shapes are supported for target shape and axes mapping inputs."; return false; } if (!isDynamicNgraphNode(op) && - (!ov::is_type(op->get_input_node_ptr(TARGET_SHAPE_IDX)) || - (op->get_input_size() > AXES_MAPPING_IDX && - !ov::is_type(op->get_input_node_ptr(AXES_MAPPING_IDX))))) { + (!ov::is_type(op->get_input_node_ptr(TARGET_SHAPE_IDX)) || + (op->get_input_size() > AXES_MAPPING_IDX && + !ov::is_type(op->get_input_node_ptr(AXES_MAPPING_IDX))))) { errorMessage = "Only constant target shapes and axis mapping inputs are supported for static shapes."; return false; } @@ -72,12 +76,13 @@ Broadcast::Broadcast(const std::shared_ptr& op, const GraphContext::CP if (ov::is_type(op->get_input_node_ptr(TARGET_SHAPE_IDX))) { constMap[TARGET_SHAPE_IDX] = true; - targetShape = (ov::as_type(op->get_input_node_ptr(TARGET_SHAPE_IDX)))->get_vector(); + targetShape = + (ov::as_type(op->get_input_node_ptr(TARGET_SHAPE_IDX)))->get_vector(); } - if (broadcastType == EXPLICIT && - ov::is_type(op->get_input_node_ptr(AXES_MAPPING_IDX))) { + if (broadcastType == EXPLICIT && ov::is_type(op->get_input_node_ptr(AXES_MAPPING_IDX))) { constMap[AXES_MAPPING_IDX] = true; - axesMapping = ov::as_type(op->get_input_node_ptr(AXES_MAPPING_IDX))->get_vector(); + axesMapping = + ov::as_type(op->get_input_node_ptr(AXES_MAPPING_IDX))->get_vector(); } } @@ -126,7 +131,8 @@ void Broadcast::prepareParams() { repeats.assign(targetShape.begin(), targetShape.end()); const auto ndims = repeats.size(); - auto srcBlockedDims = getParentEdgeAt(INPUT_DATA_IDX)->getMemory().getDescWithType()->getBlockDims(); + auto srcBlockedDims = + getParentEdgeAt(INPUT_DATA_IDX)->getMemory().getDescWithType()->getBlockDims(); auto dstBlockedDims = getChildEdgeAt(0)->getMemory().getDescWithType()->getBlockDims(); if (broadcastType == NUMPY) { @@ -227,8 +233,8 @@ void Broadcast::plainExecute(dnnl::stream strm) { } const size_t workAmountDst = dstStrides[0] * dstDims[0]; - const auto *srcData = getSrcDataAtPortAs(INPUT_DATA_IDX); - auto *dstData = getDstDataAtPortAs(0); + const auto* srcData = getSrcDataAtPortAs(INPUT_DATA_IDX); + auto* dstData = getDstDataAtPortAs(0); parallel_nt(0, [&](const int ithr, const int nthr) { size_t i = 0lu, srcIdx = 0lu, start = 0lu, end = 0lu; @@ -246,7 +252,8 @@ void Broadcast::plainExecute(dnnl::stream strm) { for (int j = dataDstRank - 1; j >= 0; j--) { counters[j] = (counters[j] + 1) % dstDims[j]; - if (counters[j] != 0) break; + if (counters[j] != 0) + break; } } }); @@ -256,6 +263,6 @@ bool Broadcast::created() const { return getType() == Type::Broadcast; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // 
namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/broadcast.h b/src/plugins/intel_cpu/src/nodes/broadcast.h index 1435314ee08776..df9ad4614e311d 100644 --- a/src/plugins/intel_cpu/src/nodes/broadcast.h +++ b/src/plugins/intel_cpu/src/nodes/broadcast.h @@ -4,12 +4,12 @@ #pragma once -#include "common/tile_broadcast_utils.h" - #include #include #include +#include "common/tile_broadcast_utils.h" + namespace ov { namespace intel_cpu { namespace node { @@ -35,10 +35,7 @@ class Broadcast : public Node, public TileBroadcastCommon { private: void plainExecute(dnnl::stream strm); - enum AutoBroadcastType { - NUMPY, - EXPLICIT - }; + enum AutoBroadcastType { NUMPY, EXPLICIT }; AutoBroadcastType broadcastType = NUMPY; static constexpr size_t INPUT_DATA_IDX = 0; @@ -51,6 +48,6 @@ class Broadcast : public Node, public TileBroadcastCommon { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/bucketize.cpp b/src/plugins/intel_cpu/src/nodes/bucketize.cpp index a71255c0d531e4..cfa4bb031501ef 100644 --- a/src/plugins/intel_cpu/src/nodes/bucketize.cpp +++ b/src/plugins/intel_cpu/src/nodes/bucketize.cpp @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "bucketize.h" + +#include +#include #include #include -#include -#include "openvino/opsets/opset3.hpp" -#include #include "openvino/core/parallel.hpp" -#include "bucketize.h" +#include "openvino/opsets/opset3.hpp" namespace ov { namespace intel_cpu { @@ -70,16 +71,15 @@ void Bucketize::initSupportedPrimitiveDescriptors() { output_precision = ov::element::i32; } - addSupportedPrimDesc({{LayoutType::ncsp, input_precision}, - {LayoutType::ncsp, boundaries_precision}}, + addSupportedPrimDesc({{LayoutType::ncsp, input_precision}, {LayoutType::ncsp, boundaries_precision}}, {{LayoutType::ncsp, output_precision}}, impl_desc_type::ref_any); } inline constexpr uint32_t getElementsMask(ov::element::Type precision1, - ov::element::Type precision2, - ov::element::Type precision3 = ov::element::undefined, - ov::element::Type precision4 = ov::element::undefined) { + ov::element::Type precision2, + ov::element::Type precision3 = ov::element::undefined, + ov::element::Type precision4 = ov::element::undefined) { return static_cast(ov::element::Type_t(precision1)) | (static_cast(ov::element::Type_t(precision2)) << 8) | (static_cast(ov::element::Type_t(precision3)) << 16) | @@ -90,98 +90,98 @@ void Bucketize::execute(dnnl::stream strm) { auto precision_mask = getElementsMask(input_precision, boundaries_precision, output_precision); switch (precision_mask) { - case getElementsMask(ov::element::f32, ov::element::f32, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::f32, ov::element::f32, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::f32, ov::element::i32, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::f32, ov::element::i32, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::f32, ov::element::i64, 
ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::f32, ov::element::i64, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i32, ov::element::f32, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i32, ov::element::f32, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i32, ov::element::i32, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i32, ov::element::i32, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i32, ov::element::i64, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i32, ov::element::i64, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i64, ov::element::f32, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i64, ov::element::f32, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i64, ov::element::i32, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i64, ov::element::i32, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i64, ov::element::i64, ov::element::i32): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - case getElementsMask(ov::element::i64, ov::element::i64, ov::element::i64): - bucketize::value_type, - element_type_traits::value_type, - element_type_traits::value_type>(); - break; - default: - OPENVINO_THROW(errorPrefix, " has unsupported precision: ", precision_mask); + case getElementsMask(ov::element::f32, ov::element::f32, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::f32, ov::element::f32, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::f32, ov::element::i32, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::f32, ov::element::i32, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::f32, ov::element::i64, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case 
getElementsMask(ov::element::f32, ov::element::i64, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i32, ov::element::f32, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i32, ov::element::f32, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i32, ov::element::i32, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i32, ov::element::i32, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i32, ov::element::i64, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i32, ov::element::i64, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i64, ov::element::f32, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i64, ov::element::f32, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i64, ov::element::i32, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i64, ov::element::i32, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i64, ov::element::i64, ov::element::i32): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + case getElementsMask(ov::element::i64, ov::element::i64, ov::element::i64): + bucketize::value_type, + element_type_traits::value_type, + element_type_traits::value_type>(); + break; + default: + OPENVINO_THROW(errorPrefix, " has unsupported precision: ", precision_mask); } } @@ -222,9 +222,9 @@ bool Bucketize::isExecutable() const { template void Bucketize::bucketize() { - const auto *input_data = getSrcDataAtPortAs(0); - const auto *boundaries_data = getSrcDataAtPortAs(1); - auto *output_data = getDstDataAtPortAs(0); + const auto* input_data = getSrcDataAtPortAs(0); + const auto* boundaries_data = getSrcDataAtPortAs(1); + auto* output_data = getDstDataAtPortAs(0); if (!with_bins) { memset(output_data, 0, num_values * sizeof(T_IND)); @@ -248,6 +248,6 @@ bool Bucketize::created() const { return getType() == Type::Bucketize; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/bucketize.h b/src/plugins/intel_cpu/src/nodes/bucketize.h index c834921a38ce54..0ecdd633838950 100644 --- a/src/plugins/intel_cpu/src/nodes/bucketize.h +++ b/src/plugins/intel_cpu/src/nodes/bucketize.h @@ -14,7 +14,7 @@ class Bucketize : public Node { public: Bucketize(const std::shared_ptr& op, const GraphContext::CPtr 
context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -46,6 +46,6 @@ class Bucketize : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.cpp b/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.cpp index 674d77265c9219..fd015a372ed1db 100644 --- a/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.cpp +++ b/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.cpp @@ -4,16 +4,16 @@ #include "causal_mask_preprocess.h" +#include +#include +#include + #include "common/bfloat16.hpp" #include "common/cpu_memcpy.h" #include "cpu/x64/cpu_isa_traits.hpp" #include "shape_inference/shape_inference_internal_dyn.hpp" #include "utils/plain_tensor.hpp" -#include -#include -#include - namespace ov { namespace intel_cpu { namespace node { @@ -48,7 +48,7 @@ The functionality is equivalent to following python code: template struct CausalMaskPreprocess::ExecutorCausalMaskPreprocess : public CausalMaskPreprocess::Executor { void execute(dnnl::stream strm, - intel_cpu::Node * pnode, + intel_cpu::Node* pnode, const intel_cpu::CausalMaskPreprocessNode::Config& config) override { ov::intel_cpu::PlainTensor t_attention_mask(pnode->getSrcMemoryAtPort(0)); ov::intel_cpu::PlainTensor t_batch_size(pnode->getSrcMemoryAtPort(1)); @@ -64,7 +64,14 @@ struct CausalMaskPreprocess::ExecutorCausalMaskPreprocess : public CausalMaskPre pnode->redefineOutputMemory({newDims}); ov::intel_cpu::PlainTensor t_dst(pnode->getDstMemoryAtPort(0)); - DEBUG_LOG("CausalMaskPreprocess::execute", config.type, " batch_size=", batch_size, " qLen=", qLen, " kvLen=", kvLen); + DEBUG_LOG("CausalMaskPreprocess::execute", + config.type, + " batch_size=", + batch_size, + " qLen=", + qLen, + " kvLen=", + kvLen); DEBUG_LOG("CausalMaskPreprocess::execute attention_mask=", t_attention_mask); DEBUG_LOG("CausalMaskPreprocess::execute cache_positions=", t_cache_positions); @@ -81,7 +88,7 @@ struct CausalMaskPreprocess::ExecutorCausalMaskPreprocess : public CausalMaskPre bool cmask_eq0 = (j <= row); bool amask_eq0 = (pamask[j] == 0); bool padding_mask = (cmask_eq0 && amask_eq0); - pdst[j] = (padding_mask | (!cmask_eq0))? min_dtype : T(0); + pdst[j] = (padding_mask | (!cmask_eq0)) ? 
min_dtype : T(0); } for (; j < kvLen; j++) { bool cmask_eq0 = (j <= row); @@ -103,7 +110,8 @@ CausalMaskPreprocess::CausalMaskPreprocess(const std::shared_ptr& op, m_config = node->get_config(); } -bool CausalMaskPreprocess::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool CausalMaskPreprocess::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto node = std::dynamic_pointer_cast(op); if (!node) { @@ -133,7 +141,8 @@ void CausalMaskPreprocess::initSupportedPrimitiveDescriptors() { oprecs[0] = ov::element::f32; } // all input precisions must be int32 - for (auto& prec : iprecs) prec = ov::element::i32; + for (auto& prec : iprecs) + prec = ov::element::i32; } else { OPENVINO_THROW("CPU: CausalMaskPreprocess type not supported : " + m_config.type); } diff --git a/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.h b/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.h index eeb997c4cefb9f..444f242b0597a7 100644 --- a/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.h +++ b/src/plugins/intel_cpu/src/nodes/causal_mask_preprocess.h @@ -32,7 +32,7 @@ class CausalMaskPreprocess : public Node { private: struct Executor { virtual void execute(dnnl::stream strm, - intel_cpu::Node * pnode, + intel_cpu::Node* pnode, const intel_cpu::CausalMaskPreprocessNode::Config& config) = 0; virtual ~Executor() = default; }; diff --git a/src/plugins/intel_cpu/src/nodes/col2im.cpp b/src/plugins/intel_cpu/src/nodes/col2im.cpp index 4b83e78fd82505..409607ea6bb89c 100644 --- a/src/plugins/intel_cpu/src/nodes/col2im.cpp +++ b/src/plugins/intel_cpu/src/nodes/col2im.cpp @@ -3,8 +3,9 @@ // #include "col2im.h" -#include "openvino/reference/col2im.hpp" + #include "openvino/op/col2im.hpp" +#include "openvino/reference/col2im.hpp" namespace ov { namespace intel_cpu { @@ -62,42 +63,42 @@ void Col2Im::executeDynamicImpl(dnnl::stream strm) { template void Col2Im::executeImpl() { - ov::reference::col2im( - getSrcDataAtPortAs(0), - ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()}, - getSrcDataAtPortAs(1), - getSrcDataAtPortAs(2), - getDstDataAtPortAs(0), - strides, - dilations, - padsBegin, - padsEnd); + ov::reference::col2im(getSrcDataAtPortAs(0), + ov::Shape{getSrcMemoryAtPort(0)->getStaticDims()}, + getSrcDataAtPortAs(1), + getSrcDataAtPortAs(2), + getDstDataAtPortAs(0), + strides, + dilations, + padsBegin, + padsEnd); } namespace { struct Col2ImContext { - Col2Im &node; + Col2Im& node; }; -} +} // namespace -template +template struct Col2Im::Col2ImExecute { using TData = typename std::tuple_element<0, T>::type; using TIndex = typename std::tuple_element<1, T>::type; - void operator()(Col2ImContext & ctx) { - ctx.node.executeImpl(); - } + void operator()(Col2ImContext& ctx) { + ctx.node.executeImpl(); + } }; void Col2Im::execute(dnnl::stream strm) { auto dataPrecision = getParentEdgeAt(0)->getMemory().getDesc().getPrecision(); auto indexPrecision = getParentEdgeAt(1)->getMemory().getDesc().getPrecision(); - Col2ImContext ctx = { - *this - }; + Col2ImContext ctx = {*this}; - OV_SWITCH(intel_cpu, Col2ImExecute, ctx, std::tie(dataPrecision, indexPrecision), + OV_SWITCH(intel_cpu, + Col2ImExecute, + ctx, + std::tie(dataPrecision, indexPrecision), OV_CASE2(ov::element::f32, ov::element::i32, float, int32_t), OV_CASE2(ov::element::f16, ov::element::i32, ov::float16, int32_t), OV_CASE2(ov::element::bf16, ov::element::i32, ov::bfloat16, int32_t), diff --git a/src/plugins/intel_cpu/src/nodes/col2im.h 
b/src/plugins/intel_cpu/src/nodes/col2im.h index 9904689e53be0f..b56b4bb78469aa 100644 --- a/src/plugins/intel_cpu/src/nodes/col2im.h +++ b/src/plugins/intel_cpu/src/nodes/col2im.h @@ -26,7 +26,7 @@ class Col2Im : public Node { template void executeImpl(); - template + template struct Col2ImExecute; ov::Strides strides; diff --git a/src/plugins/intel_cpu/src/nodes/color_convert.cpp b/src/plugins/intel_cpu/src/nodes/color_convert.cpp index ea3c8e2c774944..a06214b768d6b4 100644 --- a/src/plugins/intel_cpu/src/nodes/color_convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/color_convert.cpp @@ -3,14 +3,17 @@ // #include "color_convert.h" + #include -#include -#include -#include -#include + #include -#include "openvino/core/parallel.hpp" +#include +#include +#include +#include + #include "kernels/x64/jit_kernel.hpp" +#include "openvino/core/parallel.hpp" #include "shape_inference/custom/color_convert.hpp" using namespace dnnl::impl; @@ -39,7 +42,7 @@ class Converter : public ColorConvert::Converter { using Base = ColorConvert::Converter; public: - Converter(Node *node); + Converter(Node* node); bool singlePlane() const; @@ -47,12 +50,12 @@ class Converter : public ColorConvert::Converter { std::tuple yuv_to_rgb(float y, float u, float v); }; -Converter::Converter(Node *node) - : Base(node, node->getAlgorithm() == Algorithm::ColorConvertNV12toRGB - || node->getAlgorithm() == Algorithm::ColorConvertI420toRGB - ? ColorFormat { { 0, 1, 2 } } - : ColorFormat { { 2, 1, 0 } }) { -} +Converter::Converter(Node* node) + : Base(node, + node->getAlgorithm() == Algorithm::ColorConvertNV12toRGB || + node->getAlgorithm() == Algorithm::ColorConvertI420toRGB + ? ColorFormat{{0, 1, 2}} + : ColorFormat{{2, 1, 0}}) {} bool Converter::singlePlane() const { return _node->getOriginalInputsNumber() == 1; @@ -81,46 +84,43 @@ struct jit_uni_converter : public jit_kernel { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_converter) struct Params { - const void * y; - const void * u; - const void * v; - void * dst; + const void* y; + const void* u; + const void* v; + void* dst; size_t width; - uint8_t colorFormat; // RGB: 0, BGR: !=0 + uint8_t colorFormat; // RGB: 0, BGR: !=0 }; - typedef void (*function_t)(const Params *); + typedef void (*function_t)(const Params*); void init(); - void operator()(const Params & args) const { + void operator()(const Params& args) const { _fn(&args); } protected: jit_uni_converter(); - template - void yuv_to_rgb(const variable & y, - const variable & u, - const variable & v, - const variable & color_format, + template + void yuv_to_rgb(const variable& y, + const variable& u, + const variable& v, + const variable& color_format, bool round); - template - void store_tail(const variable & dst, - const variable & a, - const variable & b, - const variable & c, - const variable & size); + template + void store_tail(const variable& dst, + const variable& a, + const variable& b, + const variable& c, + const variable& size); function_t _fn; variable _consts; }; -jit_uni_converter::jit_uni_converter() - : jit_kernel(jit_name()), - _consts(*this) { -} +jit_uni_converter::jit_uni_converter() : jit_kernel(jit_name()), _consts(*this) {} void jit_uni_converter::init() { if (create_kernel() != status::success) @@ -128,15 +128,13 @@ void jit_uni_converter::init() { _fn = (function_t)jit_ker(); } -template -void jit_uni_converter::yuv_to_rgb(const variable & y, - const variable & u, - const variable & v, - const variable & color_format, +template +void jit_uni_converter::yuv_to_rgb(const variable& y, + const variable& 
u, + const variable& v, + const variable& color_format, bool round) { - auto clip = [&](const variable & op, - const variable & a, - const variable & b) { + auto clip = [&](const variable& op, const variable& a, const variable& b) { if (round) uni_vroundps(op, op, 0); uni_vmaxps(op, op, a); @@ -144,8 +142,12 @@ void jit_uni_converter::yuv_to_rgb(const variable & y, }; // blend r,g,b and put to r0,r1,r2 - auto blend = [&](const variable & r, const variable & g, const variable & b, - const variable & r0, const variable & r1, const variable & r2) { + auto blend = [&](const variable& r, + const variable& g, + const variable& b, + const variable& r0, + const variable& r1, + const variable& r2) { /* Input: r0,r1,r2,r3,r4,r5,r6,r7 @@ -174,7 +176,7 @@ void jit_uni_converter::yuv_to_rgb(const variable & y, */ auto genPermutationMask = [&](int offset) { - std::array mask {}; + std::array mask{}; for (uint8_t i = 0; i < mask.size(); ++i) mask[(i * 3 + offset) % mask.size()] = i; return mask; @@ -184,11 +186,8 @@ void jit_uni_converter::yuv_to_rgb(const variable & y, g.permute(genPermutationMask(1)); b.permute(genPermutationMask(2)); - auto blendWithMask = [&](int offset, const variable & result) { - static const uint32_t blendMasks[2] = { - 0x92492492, - 0x24924924 - }; + auto blendWithMask = [&](int offset, const variable& result) { + static const uint32_t blendMasks[2] = {0x92492492, 0x24924924}; const uint16_t mask0 = static_cast(blendMasks[0] >> ((offset * N) % 3)); const uint16_t mask1 = static_cast(blendMasks[1] >> ((offset * N) % 3)); @@ -208,29 +207,29 @@ void jit_uni_converter::yuv_to_rgb(const variable & y, auto b = var(); auto tmp = var(); - uni_vbroadcastss(tmp, ptr[_consts + 0 * sizeof(float)]); // tmp = [16.0f,16.0f,...] - uni_vsubps(y, y, tmp); // y = y - tmp - uni_vbroadcastss(tmp, ptr[_consts + 1 * sizeof(float)]); // tmp = [128.f,128.f,...] - uni_vsubps(u, u, tmp); // u = u - tmp - uni_vsubps(v, v, tmp); // v = v - tmp + uni_vbroadcastss(tmp, ptr[_consts + 0 * sizeof(float)]); // tmp = [16.0f,16.0f,...] + uni_vsubps(y, y, tmp); // y = y - tmp + uni_vbroadcastss(tmp, ptr[_consts + 1 * sizeof(float)]); // tmp = [128.f,128.f,...] + uni_vsubps(u, u, tmp); // u = u - tmp + uni_vsubps(v, v, tmp); // v = v - tmp - uni_vbroadcastss(tmp, ptr[_consts + 2 * sizeof(float)]); // tmp = [1.164f,1.164f,...] - uni_vmulps(y, y, tmp); // y = y * tmp + uni_vbroadcastss(tmp, ptr[_consts + 2 * sizeof(float)]); // tmp = [1.164f,1.164f,...] + uni_vmulps(y, y, tmp); // y = y * tmp - uni_vbroadcastss(r, ptr[_consts + 3 * sizeof(float)]); // r = [1.596f,1.596f,...] - uni_vmulps(r, r, v); // r = r * v - uni_vaddps(r, r, y); // r = r + y + uni_vbroadcastss(r, ptr[_consts + 3 * sizeof(float)]); // r = [1.596f,1.596f,...] + uni_vmulps(r, r, v); // r = r * v + uni_vaddps(r, r, y); // r = r + y - uni_vbroadcastss(g, ptr[_consts + 4 * sizeof(float)]); // g = [0.391f,0.391f,...] - uni_vmulps(g, g, u); // g = g * u - uni_vsubps(g, y, g); // g = y - g - uni_vbroadcastss(tmp, ptr[_consts + 6 * sizeof(float)]); // tmp = [0.813f,0.813f,...] - uni_vmulps(tmp, tmp, v); // tmp = tmp * v - uni_vsubps(g, g, tmp); // g = g - tmp + uni_vbroadcastss(g, ptr[_consts + 4 * sizeof(float)]); // g = [0.391f,0.391f,...] + uni_vmulps(g, g, u); // g = g * u + uni_vsubps(g, y, g); // g = y - g + uni_vbroadcastss(tmp, ptr[_consts + 6 * sizeof(float)]); // tmp = [0.813f,0.813f,...] 
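For readers following this hunk: the vector arithmetic is the standard BT.601-style YUV-to-RGB conversion, with the constants drawn from the _consts table ({16, 128, 1.164, 1.596, 0.391, 2.018, 0.813, 255}). A minimal scalar sketch of the same math, not the plugin's code; the function name is hypothetical:

#include <algorithm>

// One pixel of the conversion the JIT kernel vectorizes.
static void yuv_to_rgb_scalar(float y, float u, float v, float rgb[3]) {
    y = 1.164f * (y - 16.f);  // remove luma offset, rescale
    u -= 128.f;               // center chroma
    v -= 128.f;
    const float r = y + 1.596f * v;
    const float g = y - 0.391f * u - 0.813f * v;
    const float b = y + 2.018f * u;
    rgb[0] = std::min(std::max(r, 0.f), 255.f);  // clip, as the kernel does before storing
    rgb[1] = std::min(std::max(g, 0.f), 255.f);
    rgb[2] = std::min(std::max(b, 0.f), 255.f);
}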
+ uni_vmulps(tmp, tmp, v); // tmp = tmp * v + uni_vsubps(g, g, tmp); // g = g - tmp - uni_vbroadcastss(b, ptr[_consts + 5 * sizeof(float)]); // b = [2.018f,2.018f,...] - uni_vmulps(b, b, u); // b = b * u - uni_vaddps(b, b, y); // b = b + y + uni_vbroadcastss(b, ptr[_consts + 5 * sizeof(float)]); // b = [2.018f,2.018f,...] + uni_vmulps(b, b, u); // b = b * u + uni_vaddps(b, b, y); // b = b + y // clip uni_vxorps(y, y, y); @@ -241,24 +240,30 @@ void jit_uni_converter::yuv_to_rgb(const variable & y, clip(b, y, u); _if(color_format == 0) - ._then([&]{ blend(r, g, b, y, u, v); }) - ._else([&]{ blend(b, g, r, y, u, v); }); + ._then([&] { + blend(r, g, b, y, u, v); + }) + ._else([&] { + blend(b, g, r, y, u, v); + }); } -template -void jit_uni_converter::store_tail(const variable & dst, - const variable & a, - const variable & b, - const variable & c, - const variable & size) { +template +void jit_uni_converter::store_tail(const variable& dst, + const variable& a, + const variable& b, + const variable& c, + const variable& size) { const size_t step = N * sizeof(T); auto s = stack(3 * step); auto sptr = var(); sptr = s.pointer(); - store(sptr, a); sptr += step; - store(sptr, b); sptr += step; + store(sptr, a); + sptr += step; + store(sptr, b); + sptr += step; store(sptr, c); auto copy_size = size * size_t(3u); @@ -269,36 +274,33 @@ void jit_uni_converter::store_tail(const variable & dst, namespace nv12 { -ColorConvert::Converter::PrimitiveDescs supportedPrimitiveDescs(Node *node) { - const LayoutType layout = LayoutType::ncsp; // 0,1,2,3 +ColorConvert::Converter::PrimitiveDescs supportedPrimitiveDescs(Node* node) { + const LayoutType layout = LayoutType::ncsp; // 0,1,2,3 - const ov::element::Type precision = node->getOriginalInputPrecisionAtPort(0) == ov::element::u8 - ? ov::element::u8 - : ov::element::f32; + const ov::element::Type precision = + node->getOriginalInputPrecisionAtPort(0) == ov::element::u8 ? ov::element::u8 : ov::element::f32; ColorConvert::Converter::PrimitiveDescs descs; - descs.emplace_back(std::vector { node->getOriginalInputsNumber(), { layout, precision } }, - std::vector { { layout, precision } }, - mayiuse(cpu_isa_t::sse41) - ? impl_desc_type::jit_uni - : impl_desc_type::ref, - true); + descs.emplace_back(std::vector{node->getOriginalInputsNumber(), {layout, precision}}, + std::vector{{layout, precision}}, + mayiuse(cpu_isa_t::sse41) ? impl_desc_type::jit_uni : impl_desc_type::ref, + true); return descs; } -template +template class SinglePlaneConvert; -template +template class TwoPlaneConvert; class RefConverter : public Converter { public: - RefConverter(Node *node); + RefConverter(Node* node); protected: - template + template void convert(const T* y, const T* uv, T* dst, @@ -309,15 +311,14 @@ class RefConverter : public Converter { size_t stride_uv); }; -RefConverter::RefConverter(Node *node) - : Converter(node) { +RefConverter::RefConverter(Node* node) : Converter(node) { if (node->getOriginalInputsNumber() != (singlePlane() ? 
1 : 2)) OPENVINO_THROW("NV12Converter node has incorrect number of inputs"); if (!node->getOriginalOutputsNumber()) OPENVINO_THROW("NV12Converter node has incorrect number of outputs"); } -template +template void RefConverter::convert(const T* y, const T* uv, T* dst, @@ -346,13 +347,13 @@ void RefConverter::convert(const T* y, }); } -template +template class SinglePlaneConvert : public RefConverter { public: using RefConverter::RefConverter; void execute(dnnl::stream strm) override { - const auto & dims = inputDims(0); + const auto& dims = inputDims(0); const size_t batch_size = dims[N_DIM]; const size_t height = dims[H_DIM] * 2 / 3; @@ -362,22 +363,17 @@ class SinglePlaneConvert : public RefConverter { const T* uv = y + width * height; T* dst = static_cast(output(0)); - convert(y, uv, dst, - batch_size, - height, - width, - height * width * 3 / 2, - height * width * 3 / 2); + convert(y, uv, dst, batch_size, height, width, height * width * 3 / 2, height * width * 3 / 2); } }; -template +template class TwoPlaneConvert : public RefConverter { public: using RefConverter::RefConverter; void execute(dnnl::stream strm) override { - const auto & dims = inputDims(0); + const auto& dims = inputDims(0); const T* y = static_cast(input(0)); const T* uv = static_cast(input(1)); @@ -387,34 +383,24 @@ class TwoPlaneConvert : public RefConverter { const size_t height = dims[H_DIM]; const size_t width = dims[W_DIM]; - convert(y, uv, dst, - batch_size, - height, - width, - height * width, - height * width / 2); + convert(y, uv, dst, batch_size, height, width, height * width, height * width / 2); } }; #if defined(OPENVINO_ARCH_X86_64) -template +template class JitConverter; -template +template class JitConverter : public jit_uni_converter { private: void generate() override; - std::tuple, - variable, - variable> - load_yuv(const variable & src_y, - const variable & src_uv); - std::tuple, - variable> - unpack_uv(const variable & uv); + std::tuple, variable, variable> load_yuv(const variable& src_y, + const variable& src_uv); + std::tuple, variable> unpack_uv(const variable& uv); }; -template +template void JitConverter::generate() { preamble(); @@ -425,7 +411,7 @@ void JitConverter::generate() { auto width = arg(&Params::width); auto colorFormat = arg(&Params::colorFormat); - static const float data[8] = { 16.f, 128.f, 1.164f, 1.596f, 0.391f, 2.018f, 0.813f, 255.f }; + static const float data[8] = {16.f, 128.f, 1.164f, 1.596f, 0.391f, 2.018f, 0.813f, 255.f}; _consts = data; const size_t reg_capacity_log = static_cast(std::logb(N)); @@ -433,26 +419,29 @@ void JitConverter::generate() { width >>= reg_capacity_log; - foreach(0, width, [&](const Reg64 & idx) { + foreach (0, width, [&](const Reg64& idx) { auto yuv = load_yuv(src_y, src_uv); // Aliases - const auto & y = std::get<0>(yuv); - const auto & u = std::get<1>(yuv); - const auto & v = std::get<2>(yuv); + const auto& y = std::get<0>(yuv); + const auto& u = std::get<1>(yuv); + const auto& v = std::get<2>(yuv); yuv_to_rgb(y, u, v, colorFormat, std::is_integral::value); - store(dst, y); dst += step; - store(dst, u); dst += step; - store(dst, v); dst += step; - }); + store(dst, y); + dst += step; + store(dst, u); + dst += step; + store(dst, v); + dst += step; + }) + ; mov(width, argPtr(&Params::width)); width &= N - 1; - _if(width != 0) - ._then([&] { + _if(width != 0)._then([&] { auto y = var(); auto uv = var(); @@ -462,8 +451,8 @@ void JitConverter::generate() { auto uv_pair = unpack_uv(uv); // Aliases - const auto & u = std::get<0>(uv_pair); - const auto & v 
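The planar-to-interleaved blend in jit_uni_converter::yuv_to_rgb is easiest to sanity-check against a scalar model. The sketch below reuses the mask construction from genPermutationMask() and assumes 8 float lanes and the dst[j] = src[mask[j]] permute convention that construction implies; everything else is illustrative:

#include <array>
#include <cassert>
#include <cstdint>

int main() {
    constexpr int N = 8;  // lanes per register

    // Same construction as genPermutationMask(offset) in the kernel:
    // plane element i must land on interleaved position (i * 3 + offset) % N.
    auto genPermutationMask = [](int offset) {
        std::array<uint8_t, N> mask{};
        for (uint8_t i = 0; i < mask.size(); ++i)
            mask[(i * 3 + offset) % mask.size()] = i;
        return mask;
    };

    float plane[3][N];  // planar r, g, b
    for (int c = 0; c < 3; ++c)
        for (int i = 0; i < N; ++i)
            plane[c][i] = 100.f * c + i;

    float permuted[3][N];  // models r.permute(...), g.permute(...), b.permute(...)
    for (int c = 0; c < 3; ++c) {
        const auto mask = genPermutationMask(c);
        for (int j = 0; j < N; ++j)
            permuted[c][j] = plane[c][mask[j]];
    }

    // The rotating 0x92492492 / 0x24924924 blend masks select every third lane,
    // so output register k, lane j carries channel (k * N + j) % 3: interleaved RGB.
    for (int pos = 0; pos < 3 * N; ++pos)
        assert(permuted[pos % 3][pos % N] == plane[pos % 3][pos / 3]);
    return 0;
}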
= std::get<1>(uv_pair); + const auto& u = std::get<0>(uv_pair); + const auto& v = std::get<1>(uv_pair); yuv_to_rgb(y, u, v, colorFormat, std::is_integral::value); @@ -473,12 +462,9 @@ void JitConverter::generate() { postamble(); } -template -std::tuple, - jit_kernel::variable, - jit_kernel::variable> -JitConverter::load_yuv(const variable & src_y, - const variable & src_uv) { +template +std::tuple, jit_kernel::variable, jit_kernel::variable> +JitConverter::load_yuv(const variable& src_y, const variable& src_uv) { auto y = var(); auto uv = var(); @@ -490,29 +476,26 @@ JitConverter::load_yuv(const variable & src_y, src_y += N * sizeof(T); src_uv += N * sizeof(T); - return std::make_tuple(std::move(y), - std::move(std::get<0>(uv_pair)), - std::move(std::get<1>(uv_pair))); + return std::make_tuple(std::move(y), std::move(std::get<0>(uv_pair)), std::move(std::get<1>(uv_pair))); } -template -std::tuple, - jit_kernel::variable> -JitConverter::unpack_uv(const variable & uv) { +template +std::tuple, jit_kernel::variable> JitConverter::unpack_uv( + const variable& uv) { auto u = var(); auto v = var(); - const uint8_t even_mask = 0xA0; // 0b10100000 - const uint8_t odd_mask = 0xF5; // 0b11110101 + const uint8_t even_mask = 0xA0; // 0b10100000 + const uint8_t odd_mask = 0xF5; // 0b11110101 - uni_vshufps(u, uv, uv, even_mask); // u = uv[0,0,2,2,4,4,6,6] - uni_vshufps(v, uv, uv, odd_mask); // v = uv[1,1,3,3,5,5,7,7] + uni_vshufps(u, uv, uv, even_mask); // u = uv[0,0,2,2,4,4,6,6] + uni_vshufps(v, uv, uv, odd_mask); // v = uv[1,1,3,3,5,5,7,7] return std::make_tuple(std::move(u), std::move(v)); } -template -const jit_uni_converter & jit_converter_create() { +template +const jit_uni_converter& jit_converter_create() { auto createKernel = []() { std::unique_ptr kernel; @@ -540,22 +523,21 @@ const jit_uni_converter & jit_converter_create() { return *kernel; } -template -const jit_uni_converter & jit_converter_get() { +template +const jit_uni_converter& jit_converter_get() { return jit_converter_create(); } -template +template class SinglePlaneConvert : public Converter { public: - SinglePlaneConvert(Node *node) - : Converter(node) { + SinglePlaneConvert(Node* node) : Converter(node) { jit_converter_create(); } void execute(dnnl::stream strm) override { - const auto & kernel = jit_converter_get(); - const auto & dims = inputDims(0); + const auto& kernel = jit_converter_get(); + const auto& dims = inputDims(0); const size_t batch_size = dims[N_DIM]; const size_t height = dims[H_DIM] * 2 / 3; @@ -574,23 +556,22 @@ class SinglePlaneConvert : public Converter { args.u = args.v = uv + batch * stride_uv + (h / 2) * width; args.dst = dst + (batch * width * height + h * width) * 3; args.width = width; - args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. + args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. 
kernel(args); }); } }; -template +template class TwoPlaneConvert : public Converter { public: - TwoPlaneConvert(Node *node) - : Converter(node) { + TwoPlaneConvert(Node* node) : Converter(node) { jit_converter_create(); } void execute(dnnl::stream strm) override { - const auto & kernel = jit_converter_get(); - const auto & dims = inputDims(0); + const auto& kernel = jit_converter_get(); + const auto& dims = inputDims(0); const size_t batch_size = dims[N_DIM]; const size_t height = dims[H_DIM]; @@ -609,46 +590,43 @@ class TwoPlaneConvert : public Converter { args.u = args.v = uv + batch * stride_uv + (h / 2) * width; args.dst = dst + (batch * width * height + h * width) * 3; args.width = width; - args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. + args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. kernel(args); }); } }; #endif -} // namespace nv12 +} // namespace nv12 namespace i420 { -ColorConvert::Converter::PrimitiveDescs supportedPrimitiveDescs(Node *node) { - const LayoutType layout = LayoutType::ncsp; // 0,1,2,3 +ColorConvert::Converter::PrimitiveDescs supportedPrimitiveDescs(Node* node) { + const LayoutType layout = LayoutType::ncsp; // 0,1,2,3 - const ov::element::Type precision = node->getOriginalInputPrecisionAtPort(0) == ov::element::u8 - ? ov::element::u8 - : ov::element::f32; + const ov::element::Type precision = + node->getOriginalInputPrecisionAtPort(0) == ov::element::u8 ? ov::element::u8 : ov::element::f32; ColorConvert::Converter::PrimitiveDescs descs; - descs.emplace_back(std::vector { node->getOriginalInputsNumber(), { layout, precision } }, - std::vector { { layout, precision } }, - mayiuse(cpu_isa_t::sse41) - ? impl_desc_type::jit_uni - : impl_desc_type::ref, - true); + descs.emplace_back(std::vector{node->getOriginalInputsNumber(), {layout, precision}}, + std::vector{{layout, precision}}, + mayiuse(cpu_isa_t::sse41) ? impl_desc_type::jit_uni : impl_desc_type::ref, + true); return descs; } -template +template class SinglePlaneConvert; -template +template class ThreePlaneConvert; class RefConverter : public Converter { public: - RefConverter(Node *node); + RefConverter(Node* node); protected: - template + template void convert(const T* y, const T* u, const T* v, @@ -660,15 +638,14 @@ class RefConverter : public Converter { size_t stride_uv); }; -RefConverter::RefConverter(Node *node) - : Converter(node) { - if (node->getOriginalInputsNumber() != (singlePlane() ? 1: 3)) +RefConverter::RefConverter(Node* node) : Converter(node) { + if (node->getOriginalInputsNumber() != (singlePlane() ? 
1 : 3)) OPENVINO_THROW("I420Converter node has incorrect number of inputs"); if (!node->getOriginalOutputsNumber()) OPENVINO_THROW("I420Converter node has incorrect number of outputs"); } -template +template void RefConverter::convert(const T* y, const T* u, const T* v, @@ -699,13 +676,13 @@ void RefConverter::convert(const T* y, }); } -template +template class SinglePlaneConvert : public RefConverter { public: using RefConverter::RefConverter; void execute(dnnl::stream strm) override { - const auto & dims = inputDims(0); + const auto& dims = inputDims(0); const size_t batch_size = dims[N_DIM]; const size_t height = dims[H_DIM] * 2 / 3; @@ -716,22 +693,17 @@ class SinglePlaneConvert : public RefConverter { const T* v = y + 5 * width * height / 4; T* dst = static_cast(output(0)); - convert(y, u, v, dst, - batch_size, - height, - width, - height * width * 3 / 2, - height * width * 3 / 2); + convert(y, u, v, dst, batch_size, height, width, height * width * 3 / 2, height * width * 3 / 2); } }; -template +template class ThreePlaneConvert : public RefConverter { public: using RefConverter::RefConverter; void execute(dnnl::stream strm) override { - const auto & dims = inputDims(0); + const auto& dims = inputDims(0); const T* y = static_cast(input(0)); const T* u = static_cast(input(1)); @@ -742,34 +714,25 @@ class ThreePlaneConvert : public RefConverter { const size_t height = dims[H_DIM]; const size_t width = dims[W_DIM]; - convert(y, u, v, dst, - batch_size, - height, - width, - height * width, - height * width / 4); + convert(y, u, v, dst, batch_size, height, width, height * width, height * width / 4); } }; #if defined(OPENVINO_ARCH_X86_64) -template +template class JitConverter; -template +template class JitConverter : public jit_uni_converter { private: void generate() override; - std::tuple, - variable, - variable> - load_yuv(const variable & src_y, - const variable & src_u, - const variable & src_v); - void unpack_uv(const variable & u, - const variable & v); + std::tuple, variable, variable> load_yuv(const variable& src_y, + const variable& src_u, + const variable& src_v); + void unpack_uv(const variable& u, const variable& v); }; -template +template void JitConverter::generate() { preamble(); @@ -781,7 +744,7 @@ void JitConverter::generate() { auto width = arg(&Params::width); auto colorFormat = arg(&Params::colorFormat); - static const float data[8] = { 16.f, 128.f, 1.164f, 1.596f, 0.391f, 2.018f, 0.813f, 255.f }; + static const float data[8] = {16.f, 128.f, 1.164f, 1.596f, 0.391f, 2.018f, 0.813f, 255.f}; _consts = data; const size_t reg_capacity_log = static_cast(std::logb(N)); @@ -789,26 +752,29 @@ void JitConverter::generate() { width >>= reg_capacity_log; - foreach(0, width, [&](const Reg64 & idx) { + foreach (0, width, [&](const Reg64& idx) { auto yuv = load_yuv(src_y, src_u, src_v); // Aliases - const auto & y = std::get<0>(yuv); - const auto & u = std::get<1>(yuv); - const auto & v = std::get<2>(yuv); + const auto& y = std::get<0>(yuv); + const auto& u = std::get<1>(yuv); + const auto& v = std::get<2>(yuv); yuv_to_rgb(y, u, v, colorFormat, std::is_integral::value); - store(dst, y); dst += step; - store(dst, u); dst += step; - store(dst, v); dst += step; - }); + store(dst, y); + dst += step; + store(dst, u); + dst += step; + store(dst, v); + dst += step; + }) + ; mov(width, argPtr(&Params::width)); width &= N - 1; - _if(width != 0) - ._then([&] { + _if(width != 0)._then([&] { auto y = var(); auto u = var(); auto v = var(); @@ -829,13 +795,11 @@ void JitConverter::generate() { 
postamble(); } -template -std::tuple, - jit_kernel::variable, - jit_kernel::variable> -JitConverter::load_yuv(const variable & src_y, - const variable & src_u, - const variable & src_v) { +template +std::tuple, jit_kernel::variable, jit_kernel::variable> +JitConverter::load_yuv(const variable& src_y, + const variable& src_u, + const variable& src_v) { auto y = var(); auto u = var(); auto v = var(); @@ -853,16 +817,15 @@ JitConverter::load_yuv(const variable & src_y, return std::make_tuple(std::move(y), std::move(u), std::move(v)); } -template -void JitConverter::unpack_uv(const variable & u, - const variable & v) { - static const uint8_t order[] = { 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7 }; +template +void JitConverter::unpack_uv(const variable& u, const variable& v) { + static const uint8_t order[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7}; u.permute(order); v.permute(order); } -template -const jit_uni_converter & jit_converter_create() { +template +const jit_uni_converter& jit_converter_create() { auto createKernel = []() { std::unique_ptr kernel; @@ -890,22 +853,21 @@ const jit_uni_converter & jit_converter_create() { return *kernel; } -template -const jit_uni_converter & jit_converter_get() { +template +const jit_uni_converter& jit_converter_get() { return jit_converter_create(); } -template +template class SinglePlaneConvert : public Converter { public: - SinglePlaneConvert(Node *node) - : Converter(node) { + SinglePlaneConvert(Node* node) : Converter(node) { jit_converter_create(); } void execute(dnnl::stream strm) override { - const auto & kernel = jit_converter_get(); - const auto & dims = inputDims(0); + const auto& kernel = jit_converter_get(); + const auto& dims = inputDims(0); const size_t batch_size = dims[N_DIM]; const size_t height = dims[H_DIM] * 2 / 3; @@ -926,23 +888,22 @@ class SinglePlaneConvert : public Converter { args.v = v + batch * stride_uv + (h / 2) * (width / 2); args.dst = dst + (batch * width * height + h * width) * 3; args.width = width; - args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. + args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. kernel(args); }); } }; -template +template class ThreePlaneConvert : public Converter { public: - ThreePlaneConvert(Node *node) - : Converter(node) { + ThreePlaneConvert(Node* node) : Converter(node) { jit_converter_create(); } void execute(dnnl::stream strm) override { - const auto & kernel = jit_converter_get(); - const auto & dims = inputDims(0); + const auto& kernel = jit_converter_get(); + const auto& dims = inputDims(0); const T* y = static_cast(input(0)); const T* u = static_cast(input(1)); @@ -963,20 +924,19 @@ class ThreePlaneConvert : public Converter { args.v = v + batch * stride_uv + (h / 2) * (width / 2); args.dst = dst + (batch * width * height + h * width) * 3; args.width = width; - args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. + args.colorFormat = _colorFormat[0]; // The first byte is enough to determine the RGB or BGR format. 
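The offsets used by the i420 converters (u = y + width * height, v = y + 5 * width * height / 4, chroma index (h / 2) * (width / 2) + w / 2) follow the standard I420 layout: a full-resolution Y plane followed by quarter-resolution U and V planes, one chroma sample per 2x2 luma block. A sketch of that addressing for the u8 single-plane case; the I420View helper is hypothetical:

#include <cstddef>
#include <cstdint>

struct I420View {
    const uint8_t* base;
    size_t width, height;  // luma dimensions

    const uint8_t* yPlane() const { return base; }
    const uint8_t* uPlane() const { return base + width * height; }
    const uint8_t* vPlane() const { return base + 5 * width * height / 4; }

    uint8_t Y(size_t h, size_t w) const { return yPlane()[h * width + w]; }
    uint8_t U(size_t h, size_t w) const { return uPlane()[(h / 2) * (width / 2) + w / 2]; }
    uint8_t V(size_t h, size_t w) const { return vPlane()[(h / 2) * (width / 2) + w / 2]; }
};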
kernel(args); }); } }; #endif -} // namespace i420 +} // namespace i420 -} // namespace +} // namespace -ColorConvert::Converter::Converter(Node *node, const ColorFormat & colorFormat) - : _node(node) - , _colorFormat(colorFormat) { -} +ColorConvert::Converter::Converter(Node* node, const ColorFormat& colorFormat) + : _node(node), + _colorFormat(colorFormat) {} ov::element::Type ColorConvert::Converter::inputPrecision(size_t idx) const { return _node->getParentEdgeAt(idx)->getMemory().getDesc().getPrecision(); @@ -986,15 +946,15 @@ ov::element::Type ColorConvert::Converter::outputPrecision(size_t idx) const { return _node->getChildEdgeAt(idx)->getMemory().getDesc().getPrecision(); } -const void * ColorConvert::Converter::input(size_t idx) const { +const void* ColorConvert::Converter::input(size_t idx) const { return _node->getSrcDataAtPort(idx); } -void * ColorConvert::Converter::output(size_t idx) const { +void* ColorConvert::Converter::output(size_t idx) const { return _node->getDstDataAtPort(idx); } -const VectorDims & ColorConvert::Converter::inputDims(size_t idx) const { +const VectorDims& ColorConvert::Converter::inputDims(size_t idx) const { return _node->getParentEdgeAt(idx)->getMemory().getStaticDims(); } @@ -1019,42 +979,42 @@ void ColorConvert::initSupportedPrimitiveDescriptors() { return; switch (algorithm) { - case Algorithm::ColorConvertNV12toRGB: - case Algorithm::ColorConvertNV12toBGR: { - for (const auto &desc : nv12::supportedPrimitiveDescs(this)) { - const auto & inPortConfigs = std::get<0>(desc); - const auto & outPortConfigs = std::get<1>(desc); - const auto implType = std::get<2>(desc); - addSupportedPrimDesc(inPortConfigs, outPortConfigs, implType); - } - initSupportedNV12Impls(); - break; + case Algorithm::ColorConvertNV12toRGB: + case Algorithm::ColorConvertNV12toBGR: { + for (const auto& desc : nv12::supportedPrimitiveDescs(this)) { + const auto& inPortConfigs = std::get<0>(desc); + const auto& outPortConfigs = std::get<1>(desc); + const auto implType = std::get<2>(desc); + addSupportedPrimDesc(inPortConfigs, outPortConfigs, implType); } - case Algorithm::ColorConvertI420toRGB: - case Algorithm::ColorConvertI420toBGR: { - for (const auto &desc : i420::supportedPrimitiveDescs(this)) { - const auto & inPortConfigs = std::get<0>(desc); - const auto & outPortConfigs = std::get<1>(desc); - const auto implType = std::get<2>(desc); - addSupportedPrimDesc(inPortConfigs, outPortConfigs, implType); - } - initSupportedI420Impls(); - break; + initSupportedNV12Impls(); + break; + } + case Algorithm::ColorConvertI420toRGB: + case Algorithm::ColorConvertI420toBGR: { + for (const auto& desc : i420::supportedPrimitiveDescs(this)) { + const auto& inPortConfigs = std::get<0>(desc); + const auto& outPortConfigs = std::get<1>(desc); + const auto implType = std::get<2>(desc); + addSupportedPrimDesc(inPortConfigs, outPortConfigs, implType); } - default: - break; + initSupportedI420Impls(); + break; + } + default: + break; } } void ColorConvert::initSupportedNV12Impls() { - #define SUPPORTED_IMPL(Impl, type, desc_type) \ - [](Node *node) { \ - return new nv12::Impl(node); \ - }; +#define SUPPORTED_IMPL(Impl, type, desc_type) \ + [](Node* node) { \ + return new nv12::Impl(node); \ + }; // ref { - auto &impls = _supportedImpls[impl_desc_type::ref][algorithm]; + auto& impls = _supportedImpls[impl_desc_type::ref][algorithm]; impls[ov::element::Type_t::u8][true] = SUPPORTED_IMPL(SinglePlaneConvert, uint8_t, ref); impls[ov::element::Type_t::u8][false] = SUPPORTED_IMPL(TwoPlaneConvert, uint8_t, 
ref); impls[ov::element::Type_t::f32][true] = SUPPORTED_IMPL(SinglePlaneConvert, float, ref); @@ -1064,25 +1024,25 @@ void ColorConvert::initSupportedNV12Impls() { #if defined(OPENVINO_ARCH_X86_64) // jit_uni { - auto &impls = _supportedImpls[impl_desc_type::jit_uni][algorithm]; + auto& impls = _supportedImpls[impl_desc_type::jit_uni][algorithm]; impls[ov::element::Type_t::u8][true] = SUPPORTED_IMPL(SinglePlaneConvert, uint8_t, jit_uni); impls[ov::element::Type_t::u8][false] = SUPPORTED_IMPL(TwoPlaneConvert, uint8_t, jit_uni); impls[ov::element::Type_t::f32][true] = SUPPORTED_IMPL(SinglePlaneConvert, float, jit_uni); impls[ov::element::Type_t::f32][false] = SUPPORTED_IMPL(TwoPlaneConvert, float, jit_uni); } #endif - #undef SUPPORTED_IMPL +#undef SUPPORTED_IMPL } void ColorConvert::initSupportedI420Impls() { - #define SUPPORTED_IMPL(Impl, type, desc_type) \ - [](Node *node) { \ - return new i420::Impl(node); \ - }; +#define SUPPORTED_IMPL(Impl, type, desc_type) \ + [](Node* node) { \ + return new i420::Impl(node); \ + }; // ref { - auto &impls = _supportedImpls[impl_desc_type::ref][algorithm]; + auto& impls = _supportedImpls[impl_desc_type::ref][algorithm]; impls[ov::element::Type_t::u8][true] = SUPPORTED_IMPL(SinglePlaneConvert, uint8_t, ref); impls[ov::element::Type_t::u8][false] = SUPPORTED_IMPL(ThreePlaneConvert, uint8_t, ref); impls[ov::element::Type_t::f32][true] = SUPPORTED_IMPL(SinglePlaneConvert, float, ref); @@ -1092,32 +1052,29 @@ void ColorConvert::initSupportedI420Impls() { #if defined(OPENVINO_ARCH_X86_64) // jit_uni { - auto &impls = _supportedImpls[impl_desc_type::jit_uni][algorithm]; + auto& impls = _supportedImpls[impl_desc_type::jit_uni][algorithm]; impls[ov::element::Type_t::u8][true] = SUPPORTED_IMPL(SinglePlaneConvert, uint8_t, jit_uni); impls[ov::element::Type_t::u8][false] = SUPPORTED_IMPL(ThreePlaneConvert, uint8_t, jit_uni); impls[ov::element::Type_t::f32][true] = SUPPORTED_IMPL(SinglePlaneConvert, float, jit_uni); impls[ov::element::Type_t::f32][false] = SUPPORTED_IMPL(ThreePlaneConvert, float, jit_uni); } #endif - #undef SUPPORTED_IMPL +#undef SUPPORTED_IMPL } void ColorConvert::createPrimitive() { - const NodeDesc *desc = getSelectedPrimitiveDescriptor(); + const NodeDesc* desc = getSelectedPrimitiveDescriptor(); if (!desc) OPENVINO_THROW(getTypeStr() + " node with name '" + getName() + "' ", "no optimal primitive descriptor selected"); if (!_impl) { - const auto & cfg = desc->getConfig(); + const auto& cfg = desc->getConfig(); const auto precision = cfg.inConfs[0].getMemDesc()->getPrecision(); const bool isSinglePlane = cfg.inConfs.size() == 1; - _impl = std::unique_ptr(_supportedImpls - .at(desc->getImplementationType()) - .at(algorithm) - .at(precision) - .at(isSinglePlane)(this)); + _impl = std::unique_ptr( + _supportedImpls.at(desc->getImplementationType()).at(algorithm).at(precision).at(isSinglePlane)(this)); } } @@ -1139,6 +1096,6 @@ void ColorConvert::executeDynamicImpl(dnnl::stream strm) { execute(strm); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/color_convert.h b/src/plugins/intel_cpu/src/nodes/color_convert.h index 19df1209dd4bab..9bd27c7cf9dffa 100644 --- a/src/plugins/intel_cpu/src/nodes/color_convert.h +++ b/src/plugins/intel_cpu/src/nodes/color_convert.h @@ -5,10 +5,11 @@ #pragma once #include -#include + +#include #include #include -#include +#include namespace ov { namespace intel_cpu { @@ -35,11 +36,11 @@ class 
ColorConvert : public Node { void initSupportedI420Impls(); private: - using ConverterBuilder = std::function; - using SupportedImpls = multidim_map; + using SupportedImpls = multidim_map; std::unique_ptr _impl; @@ -48,10 +49,11 @@ class ColorConvert : public Node { class ColorConvert::Converter { public: - using PrimitiveDescs = std::vector, // Input port configurator - std::vector, // Output port configurator - impl_desc_type, // Implementation type - bool>>; // // true - SinglePlaneConvert, false - TwoPlaneConvert/ThreePlaneConvert + using PrimitiveDescs = + std::vector, // Input port configurator + std::vector, // Output port configurator + impl_desc_type, // Implementation type + bool>>; // // true - SinglePlaneConvert, false - TwoPlaneConvert/ThreePlaneConvert using Shapes = std::vector; static constexpr size_t N_DIM = 0; @@ -61,20 +63,20 @@ class ColorConvert::Converter { using ColorFormat = std::array; - Converter(Node *node, const ColorFormat & colorFormat); + Converter(Node* node, const ColorFormat& colorFormat); virtual ~Converter() = default; ov::element::Type inputPrecision(size_t idx) const; ov::element::Type outputPrecision(size_t idx) const; - const void * input(size_t idx) const; - void * output(size_t idx) const; - const VectorDims & inputDims(size_t idx) const; + const void* input(size_t idx) const; + void* output(size_t idx) const; + const VectorDims& inputDims(size_t idx) const; virtual void execute(dnnl::stream strm) = 0; protected: - Node *_node; - ColorFormat _colorFormat; // RGB: {0,1,2}, BGR: {2,1,0} + Node* _node; + ColorFormat _colorFormat; // RGB: {0,1,2}, BGR: {2,1,0} }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp index a7d3adc50d62e3..5887900ce8fa9e 100644 --- a/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp @@ -3,26 +3,26 @@ // #include "arbitrary_order_desc_creator.h" + #include "utils/general_utils.h" namespace ov { namespace intel_cpu { -ArbitraryOrderDescCreator::ArbitraryOrderDescCreator(VectorDims order) : - m_order(std::move(order)) { +ArbitraryOrderDescCreator::ArbitraryOrderDescCreator(VectorDims order) : m_order(std::move(order)) { OPENVINO_ASSERT(std::adjacent_find(m_order.begin(), m_order.end()) == m_order.end(), - "Can't construct ArbitraryOrderDescCreator, order vector contains repetitive elements", - vec2str(m_order)); + "Can't construct ArbitraryOrderDescCreator, order vector contains repetitive elements", + vec2str(m_order)); } -CpuBlockedMemoryDesc -ArbitraryOrderDescCreator::createDesc(const ov::element::Type& precision, const Shape& srcShape) const { +CpuBlockedMemoryDesc ArbitraryOrderDescCreator::createDesc(const ov::element::Type& precision, + const Shape& srcShape) const { auto&& dims = srcShape.getDims(); OPENVINO_ASSERT(dims.size() == m_order.size(), - "Couldn't create a tensor descriptor, shape and order size mismatch. Shape: ", - vec2str(dims), - " order: ", - vec2str(m_order)); + "Couldn't create a tensor descriptor, shape and order size mismatch. 
Shape: ", + vec2str(dims), + " order: ", + vec2str(m_order)); VectorDims blkDims(dims.size()); for (size_t i = 0; i < dims.size(); ++i) { @@ -36,5 +36,5 @@ size_t ArbitraryOrderDescCreator::getMinimalRank() const { return m_order.size(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h index aaf5a7d5560799..c7341169fd9187 100644 --- a/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h +++ b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h @@ -20,5 +20,5 @@ class ArbitraryOrderDescCreator : public BlockedDescCreator { VectorDims m_order; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.cpp b/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.cpp index 88c351ecafbdc1..a7398cac1e9940 100644 --- a/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.cpp @@ -3,9 +3,8 @@ // #include "blocked_desc_creator.h" -#include - +#include namespace ov { namespace intel_cpu { @@ -15,17 +14,19 @@ constexpr size_t channelsPos = 1lu; class PlainFormatCreator : public BlockedDescCreator { public: - CpuBlockedMemoryDesc createDesc(const ov::element::Type &precision, const Shape& srcShape) const override { + CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override { VectorDims order(srcShape.getRank()); std::iota(order.begin(), order.end(), 0); return CpuBlockedMemoryDesc(precision, srcShape, srcShape.getDims(), order); } - size_t getMinimalRank() const override { return 0lu; } + size_t getMinimalRank() const override { + return 0lu; + } }; class PerChannelCreator : public BlockedDescCreator { public: - CpuBlockedMemoryDesc createDesc(const ov::element::Type &precision, const Shape& srcShape) const override { + CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override { VectorDims order(srcShape.getRank()); std::iota(order.begin(), order.end(), 0); VectorDims blkDims = srcShape.getDims(); @@ -41,7 +42,9 @@ class PerChannelCreator : public BlockedDescCreator { return CpuBlockedMemoryDesc(precision, srcShape, blkDims, order); } - size_t getMinimalRank() const override { return 3lu; } + size_t getMinimalRank() const override { + return 3lu; + } }; class ChannelBlockedCreator : public BlockedDescCreator { @@ -64,24 +67,27 @@ class ChannelBlockedCreator : public BlockedDescCreator { return CpuBlockedMemoryDesc(precision, srcShape, blkDims, order); } - size_t getMinimalRank() const override { return 3lu; } + size_t getMinimalRank() const override { + return 3lu; + } private: size_t _blockSize; }; -} // namespace +} // namespace const BlockedDescCreator::CreatorsMap& BlockedDescCreator::getCommonCreators() { - static const CreatorsMap map{ { LayoutType::nspc, CreatorConstPtr(new PerChannelCreator) }, - { LayoutType::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8)) }, - { LayoutType::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16)) }, - { LayoutType::ncsp, CreatorConstPtr(new PlainFormatCreator) } }; + static const CreatorsMap map{{LayoutType::nspc, CreatorConstPtr(new PerChannelCreator)}, + {LayoutType::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8))}, + 
{LayoutType::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16))}, + {LayoutType::ncsp, CreatorConstPtr(new PlainFormatCreator)}}; return map; } -std::pair -BlockedDescCreator::makeFilteredRange(const CreatorsMap &map, unsigned int rank) { +std::pair BlockedDescCreator::makeFilteredRange( + const CreatorsMap& map, + unsigned int rank) { auto rankFilter = [rank](const CreatorsMap::value_type& item) { if (item.second->getMinimalRank() > rank) { return false; @@ -94,8 +100,10 @@ BlockedDescCreator::makeFilteredRange(const CreatorsMap &map, unsigned int rank) return std::make_pair(first, last); } -std::pair -BlockedDescCreator::makeFilteredRange(const CreatorsMap& map, unsigned rank, const std::vector& supportedTypes) { +std::pair BlockedDescCreator::makeFilteredRange( + const CreatorsMap& map, + unsigned rank, + const std::vector& supportedTypes) { unsigned bitMask = 0ul; for (auto& item : supportedTypes) { bitMask |= 1 << static_cast(item); @@ -116,12 +124,13 @@ BlockedDescCreator::makeFilteredRange(const CreatorsMap& map, unsigned rank, con return std::make_pair(first, last); } -std::pair -BlockedDescCreator::makeFilteredRange(const CreatorsMap &map, BlockedDescCreator::Predicate predicate) { +std::pair BlockedDescCreator::makeFilteredRange( + const CreatorsMap& map, + BlockedDescCreator::Predicate predicate) { auto first = CreatorsMapFilterConstIterator(std::move(predicate), map.begin(), map.end()); auto last = first.end(); return std::make_pair(first, last); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.h b/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.h index 1fd7a02dff984b..9f8b15b430c727 100644 --- a/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.h +++ b/src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.h @@ -5,6 +5,7 @@ #pragma once #include + #include "cpu_shape.h" #include "memory_desc/cpu_blocked_memory_desc.h" @@ -22,15 +23,18 @@ class BlockedDescCreator { public: static const CreatorsMap& getCommonCreators(); - static std::pair - makeFilteredRange(const CreatorsMap &map, unsigned rank); + static std::pair makeFilteredRange( + const CreatorsMap& map, + unsigned rank); static std::pair makeFilteredRange(const CreatorsMap& map, unsigned rank, const std::vector& supportedTypes); - static std::pair - makeFilteredRange(const CreatorsMap& map, Predicate predicate); + static std::pair makeFilteredRange( + const CreatorsMap& map, + Predicate predicate); virtual CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const = 0; - std::shared_ptr createSharedDesc(const ov::element::Type& precision, const Shape& srcShape) const { + std::shared_ptr createSharedDesc(const ov::element::Type& precision, + const Shape& srcShape) const { return std::make_shared(createDesc(precision, srcShape)); } @@ -49,7 +53,10 @@ class CreatorsMapFilterConstIterator { typedef std::function predicate_type; public: - CreatorsMapFilterConstIterator(predicate_type filter, Iterator begin, Iterator end) : _iter(begin), _end(end), _filter(std::move(filter)) { + CreatorsMapFilterConstIterator(predicate_type filter, Iterator begin, Iterator end) + : _iter(begin), + _end(end), + _filter(std::move(filter)) { while (_iter != _end && !_filter(*_iter)) { ++_iter; } @@ -93,5 +100,5 @@ class CreatorsMapFilterConstIterator { predicate_type _filter; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace 
ov diff --git a/src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp b/src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp index ad0738e9d57558..a0590827006eb4 100644 --- a/src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/cpu_convert.cpp @@ -5,16 +5,16 @@ #include "cpu_convert.h" #include "cpu_memcpy.h" -#include "utils/bfloat16.hpp" #include "openvino/core/type/nf4.hpp" +#include "utils/bfloat16.hpp" #if defined(OPENVINO_ARCH_X86_64) -#include "nodes/kernels/x64/jit_kernel.hpp" +# include "nodes/kernels/x64/jit_kernel.hpp" #else -#include "cpu_memory.h" -#include "openvino/core/type/element_type_traits.hpp" -#include "selective_build.h" -#include "utils/general_utils.h" +# include "cpu_memory.h" +# include "openvino/core/type/element_type_traits.hpp" +# include "selective_build.h" +# include "utils/general_utils.h" #endif namespace ov { @@ -28,16 +28,12 @@ using namespace dnnl::impl::cpu::x64; using namespace Xbyak; template -void convert_vec(jit_generator & gen, - const RegExp & src, - const RegExp & dst); +void convert_vec(jit_generator& gen, const RegExp& src, const RegExp& dst); template <> -void convert_vec(jit_generator & gen, - const RegExp & src, - const RegExp & dst) { - auto const & f16vec = gen.xmm3; - auto const & f32vec = gen.ymm4; +void convert_vec(jit_generator& gen, const RegExp& src, const RegExp& dst) { + auto const& f16vec = gen.xmm3; + auto const& f32vec = gen.ymm4; gen.movdqu(f16vec, gen.xword[src]); gen.vcvtph2ps(f32vec, f16vec); @@ -45,11 +41,9 @@ void convert_vec(jit_generator & gen, } template <> -void convert_vec(jit_generator & gen, - const RegExp & src, - const RegExp & dst) { - auto const & f16vec = gen.xmm3; - auto const & f32vec = gen.ymm4; +void convert_vec(jit_generator& gen, const RegExp& src, const RegExp& dst) { + auto const& f16vec = gen.xmm3; + auto const& f32vec = gen.ymm4; gen.vmovups(f32vec, gen.yword[src]); gen.vcvtps2ph(f16vec, f32vec, 0); @@ -72,18 +66,18 @@ class jit_convert_array : public jit_kernel { size >>= vlen_log2; - foreach(0, size, [&, this](const Xbyak::Reg64& idx) { + foreach (0, size, [&, this](const Xbyak::Reg64& idx) { _convert_vec(*this, src, dst); src += _src_size * vlen; dst += _dst_size * vlen; - }); + }) + ; mov(size, argPtr(&args_t::count)); size &= vlen - 1; // Tail conversion - _if(size != 0) - ._then([&] { + _if(size != 0)._then([&] { auto tmp = stack(vlen * sizeof(float)); tmp.clear(); @@ -112,24 +106,19 @@ class jit_convert_array : public jit_kernel { typedef void (*fn_t)(const args_t*); - typedef void (*convert_vec_t)(jit_generator &, - const RegExp &, - const RegExp &); + typedef void (*convert_vec_t)(jit_generator&, const RegExp&, const RegExp&); - jit_convert_array(convert_vec_t convert_vec, - size_t src_size, - size_t dst_size) - : jit_kernel(jit_name()) - , _convert_vec(convert_vec) - , _src_size(src_size) - , _dst_size(dst_size) {} + jit_convert_array(convert_vec_t convert_vec, size_t src_size, size_t dst_size) + : jit_kernel(jit_name()), + _convert_vec(convert_vec), + _src_size(src_size), + _dst_size(dst_size) {} - template + template static fn_t get() { - if (mayiuse(cpu_isa_t::avx2) - && dnnl::impl::cpu::x64::cpu().has(Xbyak::util::Cpu::tF16C)) { + if (mayiuse(cpu_isa_t::avx2) && dnnl::impl::cpu::x64::cpu().has(Xbyak::util::Cpu::tF16C)) { static jit_convert_array converter(convert_vec, sizeof(src_t), sizeof(dst_t)); - auto & generator = static_cast(converter); + auto& generator = static_cast(converter); generator.create_kernel(); return 
(fn_t)generator.jit_ker(); } @@ -148,7 +137,7 @@ void jit_convert(const TI* arg, TO* out, size_t count) { static auto converter = jit_impl::get(); if (converter) { - typename jit_impl::args_t args = { arg, out, count }; + typename jit_impl::args_t args = {arg, out, count}; converter(&args); } else { for (size_t i = 0; i < count; ++i) { @@ -179,44 +168,41 @@ struct PrecisionInfo { using value_type = uint8_t; }; -template::value - || std::is_same::value, - float, T>::type> +template ::value || + std::is_same::value, + float, + T>::type> struct Range { - const std::tuple & fit(const ov::element::Type & prec); + const std::tuple& fit(const ov::element::Type& prec); private: - std::tuple _range { - std::numeric_limits::lowest(), - std::numeric_limits::max() - }; + std::tuple _range{std::numeric_limits::lowest(), std::numeric_limits::max()}; }; -template -const std::tuple & Range::fit(const ov::element::Type & prec) { +template +const std::tuple& Range::fit(const ov::element::Type& prec) { if (prec.is_real()) { double lbound, ubound; switch (prec) { - case ov::element::bf16: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::f16: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::f32: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::f64: - lbound = std::numeric_limits::lowest(); - ubound = std::numeric_limits::max(); - break; - default: - OPENVINO_THROW("Unsupported precision"); + case ov::element::bf16: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::f16: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::f32: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::f64: + lbound = std::numeric_limits::lowest(); + ubound = std::numeric_limits::max(); + break; + default: + OPENVINO_THROW("Unsupported precision"); } // If U is integral, its range is always within float's range, so there is no need to update _range // Otherwise it would overflow; for example, static_cast from double to int64_t: @@ -224,73 +210,71 @@ const std::tuple & Range::fit(const ov::element::Type & prec) { // double dd_ubound = static_cast(ubound) // static_cast(dd_ubound) will return -9223372036854775808 if (!std::is_integral::value) { - std::get<0>(_range) = static_cast(std::max(static_cast(std::get<0>(_range)), lbound)); - std::get<1>(_range) = static_cast(std::min(static_cast(std::get<1>(_range)), ubound)); + std::get<0>(_range) = static_cast(std::max(static_cast(std::get<0>(_range)), lbound)); + std::get<1>(_range) = static_cast(std::min(static_cast(std::get<1>(_range)), ubound)); } } else { int64_t lbound; uint64_t ubound; switch (prec) { - case ov::element::boolean: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::u8: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::i8: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::u16: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = 
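The comment above is worth a concrete illustration: double has a 53-bit mantissa, so static_cast<double>(INT64_MAX) rounds up to exactly 2^63, and casting that value back to int64_t is out of range (on x86 it typically wraps to -9223372036854775808). A self-contained sketch of the clamp-before-cast pattern, assuming int64_t as the target type (my own helper, not part of the diff):

#include <cstdint>
#include <limits>

inline int64_t saturating_from_double(double v) {
    // hi is the rounded 2^63, so v >= hi must saturate instead of casting.
    constexpr double lo = static_cast<double>(std::numeric_limits<int64_t>::lowest());
    constexpr double hi = static_cast<double>(std::numeric_limits<int64_t>::max());
    if (v >= hi)
        return std::numeric_limits<int64_t>::max();
    if (v <= lo)
        return std::numeric_limits<int64_t>::lowest();
    return static_cast<int64_t>(v);  // now guaranteed in range
}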
static_cast(std::numeric_limits::max()); - break; - case ov::element::i16: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::u32: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::i32: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::u64: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - case ov::element::i64: - lbound = static_cast(std::numeric_limits::lowest()); - ubound = static_cast(std::numeric_limits::max()); - break; - default: - OPENVINO_THROW("Unsupported precision"); + case ov::element::boolean: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::u8: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::i8: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::u16: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::i16: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::u32: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::i32: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::u64: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + case ov::element::i64: + lbound = static_cast(std::numeric_limits::lowest()); + ubound = static_cast(std::numeric_limits::max()); + break; + default: + OPENVINO_THROW("Unsupported precision"); } - using ltype = typename std::conditional< - std::is_floating_point::value, - double, int64_t>::type; - using utype = typename std::conditional< - std::is_floating_point::value, - double, uint64_t>::type; - std::get<0>(_range) = static_cast(std::max(static_cast(std::get<0>(_range)), static_cast(lbound))); - std::get<1>(_range) = static_cast(std::min(static_cast(std::get<1>(_range)), static_cast(ubound))); + using ltype = typename std::conditional::value, double, int64_t>::type; + using utype = typename std::conditional::value, double, uint64_t>::type; + std::get<0>(_range) = + static_cast(std::max(static_cast(std::get<0>(_range)), static_cast(lbound))); + std::get<1>(_range) = + static_cast(std::min(static_cast(std::get<1>(_range)), static_cast(ubound))); } return _range; } struct ConvertContext { - const void *srcPtr; - void *dstPtr; + const void* srcPtr; + void* dstPtr; size_t size; ov::element::Type interimPrc; ov::element::Type dstPrc; bool converted; - template + template std::tuple range() const { Range r; r.fit(interimPrc); @@ -298,20 +282,18 @@ struct ConvertContext { } }; -template +template struct ConvertPrecision; -template +template struct ConvertPrecision> { - void operator()(ConvertContext & ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = 
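All of these ConvertPrecision specializations reduce to the same per-element recipe: clamp into the range fitted for the interim precision, truncate when the interim type is integral, then cast. An illustrative scalar version of the loop body (hypothetical helper, with the parallel_for tiling left out):

#include <algorithm>
#include <cmath>

template <typename SrcT, typename DstT>
DstT convert_one(SrcT v, SrcT lbound, SrcT ubound, bool interimIsIntegral) {
    SrcT clamped = std::max(std::min(v, ubound), lbound);  // saturate first
    if (interimIsIntegral)
        clamped = static_cast<SrcT>(std::trunc(clamped));  // drop the fraction
    return static_cast<DstT>(clamped);
}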
static_cast(ctx.dstPtr); src_t lbound, ubound; std::tie(lbound, ubound) = ctx.range(); - if (std::is_integral::value - || ctx.interimPrc.is_real() - || std::is_integral::value) { + if (std::is_integral::value || ctx.interimPrc.is_real() || std::is_integral::value) { parallel_for(ctx.size, [&](size_t i) { dst[i] = static_cast(std::max(std::min(src[i], ubound), lbound)); }); @@ -325,11 +307,11 @@ struct ConvertPrecision> { } }; -template<> +template <> struct ConvertPrecision> { - void operator()(ConvertContext & ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = static_cast(ctx.dstPtr); if (ctx.interimPrc.is_real()) { parallel_for(ctx.size, [&](size_t i) { @@ -347,11 +329,11 @@ struct ConvertPrecision> { } }; -template<> +template <> struct ConvertPrecision> { - void operator()(ConvertContext & ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = static_cast(ctx.dstPtr); if (ctx.interimPrc.is_real()) { parallel_for(ctx.size, [&](size_t i) { @@ -370,11 +352,11 @@ struct ConvertPrecision> { }; #if defined(OPENVINO_ARCH_X86_64) -template +template struct ConvertPrecision> { - void operator()(ConvertContext & ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = static_cast(ctx.dstPtr); constexpr size_t batch = 64; const size_t iterations = ov::intel_cpu::div_up(ctx.size, batch); @@ -388,16 +370,16 @@ struct ConvertPrecision> { batch_type tmp; const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); - for (size_t j = 0; j < current_batch_size; ++j) // src_t -> fp32 + for (size_t j = 0; j < current_batch_size; ++j) // src_t -> fp32 tmp[j] = static_cast(std::max(std::min(src[offset + j], ubound), lbound)); - jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16 + jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16 }); } else if (ctx.interimPrc.is_real()) { parallel_for(iterations, [&](size_t i) { const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); if (std::is_same::type, float>::value) { // fp32 -> fp16 - jit_convert(reinterpret_cast(src) + offset, dst + offset, current_batch_size); + jit_convert(reinterpret_cast(src) + offset, dst + offset, current_batch_size); } else { batch_type tmp; for (size_t j = 0; j < current_batch_size; ++j) // src_t -> fp32 @@ -410,9 +392,9 @@ struct ConvertPrecision> { batch_type tmp; const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); - for (size_t j = 0; j < current_batch_size; ++j) // src_t -> fp32 + for (size_t j = 0; j < current_batch_size; ++j) // src_t -> fp32 tmp[j] = static_cast(std::trunc(std::max(std::min(src[offset + j], ubound), lbound))); - jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16 + jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16 }); } @@ -420,11 +402,11 @@ struct ConvertPrecision> { } }; -template +template struct ConvertPrecision> { - void operator()(ConvertContext & ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = static_cast(ctx.dstPtr); constexpr 
size_t batch = 64; const size_t iterations = ov::intel_cpu::div_up(ctx.size, batch); @@ -438,8 +420,8 @@ struct ConvertPrecision> { batch_type tmp; const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); - jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 - for (size_t j = 0; j < current_batch_size; ++j) // fp32 -> dst_t + jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 + for (size_t j = 0; j < current_batch_size; ++j) // fp32 -> dst_t dst[offset + j] = static_cast(std::max(std::min(tmp[j], ubound), lbound)); }); } else if (ctx.interimPrc.is_real()) { @@ -447,7 +429,7 @@ struct ConvertPrecision> { const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); if (std::is_same::type, float>::value) { // fp16 -> fp32 - jit_convert(src + offset, reinterpret_cast(dst) + offset, current_batch_size); + jit_convert(src + offset, reinterpret_cast(dst) + offset, current_batch_size); } else { batch_type tmp; jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 @@ -460,8 +442,8 @@ struct ConvertPrecision> { batch_type tmp; const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); - jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 - for (size_t j = 0; j < current_batch_size; ++j) // fp32 -> dst_t + jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 + for (size_t j = 0; j < current_batch_size; ++j) // fp32 -> dst_t dst[offset + j] = static_cast(std::trunc(std::max(std::min(tmp[j], ubound), lbound))); }); } @@ -470,11 +452,11 @@ struct ConvertPrecision> { } }; -template<> +template <> struct ConvertPrecision> { - void operator()(ConvertContext & ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = static_cast(ctx.dstPtr); constexpr size_t batch = 64; const size_t iterations = ov::intel_cpu::div_up(ctx.size, batch); @@ -490,10 +472,10 @@ struct ConvertPrecision> { batch_type tmp; const size_t offset = i * batch; const size_t current_batch_size = std::min(ctx.size - offset, batch); - jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 - for (size_t j = 0; j < current_batch_size; ++j) // truncate fp32 + jit_convert(src + offset, tmp, current_batch_size); // fp16 -> fp32 + for (size_t j = 0; j < current_batch_size; ++j) // truncate fp32 tmp[j] = std::trunc(std::max(std::min(tmp[j], ubound), lbound)); - jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16 + jit_convert(tmp, dst + offset, current_batch_size); // fp32 -> fp16 }); } @@ -502,7 +484,7 @@ struct ConvertPrecision> { }; #endif -} // namespace +} // namespace #define INTEL_CPU_CVT(ST, DT) \ OV_CASE2(ov::element::ST, \ @@ -510,74 +492,72 @@ struct ConvertPrecision> { PrecisionInfo::value_type, \ PrecisionInfo::value_type) -#define INTEL_CPU_CVT_LIST \ - INTEL_CPU_CVT(u8, i8), INTEL_CPU_CVT(u8, u16), INTEL_CPU_CVT(u8, i16), INTEL_CPU_CVT(u8, u32), \ - INTEL_CPU_CVT(u8, i32), INTEL_CPU_CVT(u8, u64), INTEL_CPU_CVT(u8, i64), INTEL_CPU_CVT(u8, f32), \ - INTEL_CPU_CVT(u8, f16), INTEL_CPU_CVT(u8, bf16), INTEL_CPU_CVT(u8, f64), INTEL_CPU_CVT(u8, boolean), \ - INTEL_CPU_CVT(i8, u8), INTEL_CPU_CVT(i8, u16), INTEL_CPU_CVT(i8, i16), INTEL_CPU_CVT(i8, u32), \ - INTEL_CPU_CVT(i8, i32), INTEL_CPU_CVT(i8, u64), INTEL_CPU_CVT(i8, i64), INTEL_CPU_CVT(i8, f32), \ - INTEL_CPU_CVT(i8, f16), 
INTEL_CPU_CVT(i8, bf16), INTEL_CPU_CVT(i8, f64), INTEL_CPU_CVT(i8, boolean), \ - INTEL_CPU_CVT(u16, u8), INTEL_CPU_CVT(u16, i8), INTEL_CPU_CVT(u16, i16), INTEL_CPU_CVT(u16, u32), \ - INTEL_CPU_CVT(u16, i32), INTEL_CPU_CVT(u16, u64), INTEL_CPU_CVT(u16, i64), INTEL_CPU_CVT(u16, f32), \ - INTEL_CPU_CVT(u16, f16), INTEL_CPU_CVT(u16, bf16), INTEL_CPU_CVT(u16, f64), INTEL_CPU_CVT(u16, boolean), \ - INTEL_CPU_CVT(i16, u8), INTEL_CPU_CVT(i16, i8), INTEL_CPU_CVT(i16, u16), INTEL_CPU_CVT(i16, u32), \ - INTEL_CPU_CVT(i16, i32), INTEL_CPU_CVT(i16, u64), INTEL_CPU_CVT(i16, i64), INTEL_CPU_CVT(i16, f32), \ - INTEL_CPU_CVT(i16, f16), INTEL_CPU_CVT(i16, bf16), INTEL_CPU_CVT(i16, f64), INTEL_CPU_CVT(i16, boolean), \ - INTEL_CPU_CVT(u32, u8), INTEL_CPU_CVT(u32, i8), INTEL_CPU_CVT(u32, u16), INTEL_CPU_CVT(u32, i16), \ - INTEL_CPU_CVT(u32, i32), INTEL_CPU_CVT(u32, u64), INTEL_CPU_CVT(u32, i64), INTEL_CPU_CVT(u32, f32), \ - INTEL_CPU_CVT(u32, f16), INTEL_CPU_CVT(u32, bf16), INTEL_CPU_CVT(u32, f64), INTEL_CPU_CVT(u32, boolean), \ - INTEL_CPU_CVT(i32, u8), INTEL_CPU_CVT(i32, i8), INTEL_CPU_CVT(i32, u16), INTEL_CPU_CVT(i32, i16), \ - INTEL_CPU_CVT(i32, u32), INTEL_CPU_CVT(i32, u64), INTEL_CPU_CVT(i32, i64), INTEL_CPU_CVT(i32, f32), \ - INTEL_CPU_CVT(i32, f16), INTEL_CPU_CVT(i32, bf16), INTEL_CPU_CVT(i32, f64), INTEL_CPU_CVT(i32, boolean), \ - INTEL_CPU_CVT(u64, u8), INTEL_CPU_CVT(u64, i8), INTEL_CPU_CVT(u64, u16), INTEL_CPU_CVT(u64, i16), \ - INTEL_CPU_CVT(u64, u32), INTEL_CPU_CVT(u64, i32), INTEL_CPU_CVT(u64, i64), INTEL_CPU_CVT(u64, f32), \ - INTEL_CPU_CVT(u64, f16), INTEL_CPU_CVT(u64, bf16), INTEL_CPU_CVT(u64, f64), INTEL_CPU_CVT(u64, boolean), \ - INTEL_CPU_CVT(i64, u8), INTEL_CPU_CVT(i64, i8), INTEL_CPU_CVT(i64, u16), INTEL_CPU_CVT(i64, i16), \ - INTEL_CPU_CVT(i64, u32), INTEL_CPU_CVT(i64, i32), INTEL_CPU_CVT(i64, u64), INTEL_CPU_CVT(i64, f32), \ - INTEL_CPU_CVT(i64, f16), INTEL_CPU_CVT(i64, bf16), INTEL_CPU_CVT(i64, f64), INTEL_CPU_CVT(i64, boolean), \ - INTEL_CPU_CVT(f32, u8), INTEL_CPU_CVT(f32, i8), INTEL_CPU_CVT(f32, u16), INTEL_CPU_CVT(f32, i16), \ - INTEL_CPU_CVT(f32, u32), INTEL_CPU_CVT(f32, i32), INTEL_CPU_CVT(f32, u64), INTEL_CPU_CVT(f32, i64), \ - INTEL_CPU_CVT(f32, f16), INTEL_CPU_CVT(f32, bf16), INTEL_CPU_CVT(f32, f64), INTEL_CPU_CVT(f32, boolean), \ - INTEL_CPU_CVT(f16, u8), INTEL_CPU_CVT(f16, i8), INTEL_CPU_CVT(f16, u16), INTEL_CPU_CVT(f16, i16), \ - INTEL_CPU_CVT(f16, u32), INTEL_CPU_CVT(f16, i32), INTEL_CPU_CVT(f16, u64), INTEL_CPU_CVT(f16, i64), \ - INTEL_CPU_CVT(f16, f32), INTEL_CPU_CVT(f16, bf16), INTEL_CPU_CVT(f16, f64), INTEL_CPU_CVT(f16, boolean), \ - INTEL_CPU_CVT(bf16, u8), INTEL_CPU_CVT(bf16, i8), INTEL_CPU_CVT(bf16, u16), INTEL_CPU_CVT(bf16, i16), \ - INTEL_CPU_CVT(bf16, u32), INTEL_CPU_CVT(bf16, i32), INTEL_CPU_CVT(bf16, u64), INTEL_CPU_CVT(bf16, i64), \ - INTEL_CPU_CVT(bf16, f32), INTEL_CPU_CVT(bf16, f16), INTEL_CPU_CVT(bf16, f64), INTEL_CPU_CVT(bf16, boolean), \ - INTEL_CPU_CVT(f64, u8), INTEL_CPU_CVT(f64, i8), INTEL_CPU_CVT(f64, u16), INTEL_CPU_CVT(f64, i16), \ - INTEL_CPU_CVT(f64, u32), INTEL_CPU_CVT(f64, i32), INTEL_CPU_CVT(f64, u64), INTEL_CPU_CVT(f64, i64), \ - INTEL_CPU_CVT(f64, f32), INTEL_CPU_CVT(f64, f16), INTEL_CPU_CVT(f64, bf16), INTEL_CPU_CVT(f64, boolean), \ - INTEL_CPU_CVT(boolean, u8), INTEL_CPU_CVT(boolean, i8), INTEL_CPU_CVT(boolean, u16), \ - INTEL_CPU_CVT(boolean, i16), INTEL_CPU_CVT(boolean, u32), INTEL_CPU_CVT(boolean, i32), \ - INTEL_CPU_CVT(boolean, u64), INTEL_CPU_CVT(boolean, i64), INTEL_CPU_CVT(boolean, f32), \ - INTEL_CPU_CVT(boolean, f16), 
INTEL_CPU_CVT(boolean, bf16), INTEL_CPU_CVT(boolean, f64), INTEL_CPU_CVT(u8, u8), \ - INTEL_CPU_CVT(i8, i8), INTEL_CPU_CVT(u16, u16), INTEL_CPU_CVT(i16, i16), INTEL_CPU_CVT(u32, u32), \ - INTEL_CPU_CVT(i32, i32), INTEL_CPU_CVT(u64, u64), INTEL_CPU_CVT(i64, i64), INTEL_CPU_CVT(f32, f32), \ - INTEL_CPU_CVT(f16, f16), INTEL_CPU_CVT(bf16, bf16), INTEL_CPU_CVT(f64, f64), INTEL_CPU_CVT(boolean, boolean) - - -#define INTEL_CPU_CVT_FROM_BIN_LIST \ - INTEL_CPU_CVT(u1, f32), INTEL_CPU_CVT(u1, f16), INTEL_CPU_CVT(u1, bf16), \ - INTEL_CPU_CVT(u1, f64), INTEL_CPU_CVT(u1, i16), INTEL_CPU_CVT(u1, u8), \ - INTEL_CPU_CVT(u1, i8), INTEL_CPU_CVT(u1, u16), INTEL_CPU_CVT(u1, i32), \ - INTEL_CPU_CVT(u1, u32), INTEL_CPU_CVT(u1, i64), INTEL_CPU_CVT(u1, u64), \ - INTEL_CPU_CVT(u1, boolean) +#define INTEL_CPU_CVT_LIST \ + INTEL_CPU_CVT(u8, i8), INTEL_CPU_CVT(u8, u16), INTEL_CPU_CVT(u8, i16), INTEL_CPU_CVT(u8, u32), \ + INTEL_CPU_CVT(u8, i32), INTEL_CPU_CVT(u8, u64), INTEL_CPU_CVT(u8, i64), INTEL_CPU_CVT(u8, f32), \ + INTEL_CPU_CVT(u8, f16), INTEL_CPU_CVT(u8, bf16), INTEL_CPU_CVT(u8, f64), INTEL_CPU_CVT(u8, boolean), \ + INTEL_CPU_CVT(i8, u8), INTEL_CPU_CVT(i8, u16), INTEL_CPU_CVT(i8, i16), INTEL_CPU_CVT(i8, u32), \ + INTEL_CPU_CVT(i8, i32), INTEL_CPU_CVT(i8, u64), INTEL_CPU_CVT(i8, i64), INTEL_CPU_CVT(i8, f32), \ + INTEL_CPU_CVT(i8, f16), INTEL_CPU_CVT(i8, bf16), INTEL_CPU_CVT(i8, f64), INTEL_CPU_CVT(i8, boolean), \ + INTEL_CPU_CVT(u16, u8), INTEL_CPU_CVT(u16, i8), INTEL_CPU_CVT(u16, i16), INTEL_CPU_CVT(u16, u32), \ + INTEL_CPU_CVT(u16, i32), INTEL_CPU_CVT(u16, u64), INTEL_CPU_CVT(u16, i64), INTEL_CPU_CVT(u16, f32), \ + INTEL_CPU_CVT(u16, f16), INTEL_CPU_CVT(u16, bf16), INTEL_CPU_CVT(u16, f64), INTEL_CPU_CVT(u16, boolean), \ + INTEL_CPU_CVT(i16, u8), INTEL_CPU_CVT(i16, i8), INTEL_CPU_CVT(i16, u16), INTEL_CPU_CVT(i16, u32), \ + INTEL_CPU_CVT(i16, i32), INTEL_CPU_CVT(i16, u64), INTEL_CPU_CVT(i16, i64), INTEL_CPU_CVT(i16, f32), \ + INTEL_CPU_CVT(i16, f16), INTEL_CPU_CVT(i16, bf16), INTEL_CPU_CVT(i16, f64), INTEL_CPU_CVT(i16, boolean), \ + INTEL_CPU_CVT(u32, u8), INTEL_CPU_CVT(u32, i8), INTEL_CPU_CVT(u32, u16), INTEL_CPU_CVT(u32, i16), \ + INTEL_CPU_CVT(u32, i32), INTEL_CPU_CVT(u32, u64), INTEL_CPU_CVT(u32, i64), INTEL_CPU_CVT(u32, f32), \ + INTEL_CPU_CVT(u32, f16), INTEL_CPU_CVT(u32, bf16), INTEL_CPU_CVT(u32, f64), INTEL_CPU_CVT(u32, boolean), \ + INTEL_CPU_CVT(i32, u8), INTEL_CPU_CVT(i32, i8), INTEL_CPU_CVT(i32, u16), INTEL_CPU_CVT(i32, i16), \ + INTEL_CPU_CVT(i32, u32), INTEL_CPU_CVT(i32, u64), INTEL_CPU_CVT(i32, i64), INTEL_CPU_CVT(i32, f32), \ + INTEL_CPU_CVT(i32, f16), INTEL_CPU_CVT(i32, bf16), INTEL_CPU_CVT(i32, f64), INTEL_CPU_CVT(i32, boolean), \ + INTEL_CPU_CVT(u64, u8), INTEL_CPU_CVT(u64, i8), INTEL_CPU_CVT(u64, u16), INTEL_CPU_CVT(u64, i16), \ + INTEL_CPU_CVT(u64, u32), INTEL_CPU_CVT(u64, i32), INTEL_CPU_CVT(u64, i64), INTEL_CPU_CVT(u64, f32), \ + INTEL_CPU_CVT(u64, f16), INTEL_CPU_CVT(u64, bf16), INTEL_CPU_CVT(u64, f64), INTEL_CPU_CVT(u64, boolean), \ + INTEL_CPU_CVT(i64, u8), INTEL_CPU_CVT(i64, i8), INTEL_CPU_CVT(i64, u16), INTEL_CPU_CVT(i64, i16), \ + INTEL_CPU_CVT(i64, u32), INTEL_CPU_CVT(i64, i32), INTEL_CPU_CVT(i64, u64), INTEL_CPU_CVT(i64, f32), \ + INTEL_CPU_CVT(i64, f16), INTEL_CPU_CVT(i64, bf16), INTEL_CPU_CVT(i64, f64), INTEL_CPU_CVT(i64, boolean), \ + INTEL_CPU_CVT(f32, u8), INTEL_CPU_CVT(f32, i8), INTEL_CPU_CVT(f32, u16), INTEL_CPU_CVT(f32, i16), \ + INTEL_CPU_CVT(f32, u32), INTEL_CPU_CVT(f32, i32), INTEL_CPU_CVT(f32, u64), INTEL_CPU_CVT(f32, i64), \ + INTEL_CPU_CVT(f32, f16), 
INTEL_CPU_CVT(f32, bf16), INTEL_CPU_CVT(f32, f64), INTEL_CPU_CVT(f32, boolean), \ + INTEL_CPU_CVT(f16, u8), INTEL_CPU_CVT(f16, i8), INTEL_CPU_CVT(f16, u16), INTEL_CPU_CVT(f16, i16), \ + INTEL_CPU_CVT(f16, u32), INTEL_CPU_CVT(f16, i32), INTEL_CPU_CVT(f16, u64), INTEL_CPU_CVT(f16, i64), \ + INTEL_CPU_CVT(f16, f32), INTEL_CPU_CVT(f16, bf16), INTEL_CPU_CVT(f16, f64), INTEL_CPU_CVT(f16, boolean), \ + INTEL_CPU_CVT(bf16, u8), INTEL_CPU_CVT(bf16, i8), INTEL_CPU_CVT(bf16, u16), INTEL_CPU_CVT(bf16, i16), \ + INTEL_CPU_CVT(bf16, u32), INTEL_CPU_CVT(bf16, i32), INTEL_CPU_CVT(bf16, u64), INTEL_CPU_CVT(bf16, i64), \ + INTEL_CPU_CVT(bf16, f32), INTEL_CPU_CVT(bf16, f16), INTEL_CPU_CVT(bf16, f64), INTEL_CPU_CVT(bf16, boolean), \ + INTEL_CPU_CVT(f64, u8), INTEL_CPU_CVT(f64, i8), INTEL_CPU_CVT(f64, u16), INTEL_CPU_CVT(f64, i16), \ + INTEL_CPU_CVT(f64, u32), INTEL_CPU_CVT(f64, i32), INTEL_CPU_CVT(f64, u64), INTEL_CPU_CVT(f64, i64), \ + INTEL_CPU_CVT(f64, f32), INTEL_CPU_CVT(f64, f16), INTEL_CPU_CVT(f64, bf16), INTEL_CPU_CVT(f64, boolean), \ + INTEL_CPU_CVT(boolean, u8), INTEL_CPU_CVT(boolean, i8), INTEL_CPU_CVT(boolean, u16), \ + INTEL_CPU_CVT(boolean, i16), INTEL_CPU_CVT(boolean, u32), INTEL_CPU_CVT(boolean, i32), \ + INTEL_CPU_CVT(boolean, u64), INTEL_CPU_CVT(boolean, i64), INTEL_CPU_CVT(boolean, f32), \ + INTEL_CPU_CVT(boolean, f16), INTEL_CPU_CVT(boolean, bf16), INTEL_CPU_CVT(boolean, f64), INTEL_CPU_CVT(u8, u8), \ + INTEL_CPU_CVT(i8, i8), INTEL_CPU_CVT(u16, u16), INTEL_CPU_CVT(i16, i16), INTEL_CPU_CVT(u32, u32), \ + INTEL_CPU_CVT(i32, i32), INTEL_CPU_CVT(u64, u64), INTEL_CPU_CVT(i64, i64), INTEL_CPU_CVT(f32, f32), \ + INTEL_CPU_CVT(f16, f16), INTEL_CPU_CVT(bf16, bf16), INTEL_CPU_CVT(f64, f64), INTEL_CPU_CVT(boolean, boolean) + +#define INTEL_CPU_CVT_FROM_BIN_LIST \ + INTEL_CPU_CVT(u1, f32), INTEL_CPU_CVT(u1, f16), INTEL_CPU_CVT(u1, bf16), INTEL_CPU_CVT(u1, f64), \ + INTEL_CPU_CVT(u1, i16), INTEL_CPU_CVT(u1, u8), INTEL_CPU_CVT(u1, i8), INTEL_CPU_CVT(u1, u16), \ + INTEL_CPU_CVT(u1, i32), INTEL_CPU_CVT(u1, u32), INTEL_CPU_CVT(u1, i64), INTEL_CPU_CVT(u1, u64), \ + INTEL_CPU_CVT(u1, boolean) struct ConvertFromBinContext { - const void *srcPtr; - void *dstPtr; + const void* srcPtr; + void* dstPtr; size_t size; bool converted; }; -template +template struct ConvertFromBinPrecision; -template +template struct ConvertFromBinPrecision> { - void operator()(ConvertFromBinContext &ctx) { - auto src = static_cast(ctx.srcPtr); - auto dst = static_cast(ctx.dstPtr); + void operator()(ConvertFromBinContext& ctx) { + auto src = static_cast(ctx.srcPtr); + auto dst = static_cast(ctx.dstPtr); const size_t nBits = 8; const size_t nBytes = rnd_up(ctx.size, nBits); parallel_for(nBytes, [&](size_t byteIndex) { @@ -590,16 +570,17 @@ struct ConvertFromBinPrecision> { } }; -#define INTEL_CPU_CVT_FROM_4BIT_LIST \ - INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), INTEL_CPU_CVT(u4, i8), INTEL_CPU_CVT(u4, u8), \ - INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, bf16), INTEL_CPU_CVT(i4, f16), INTEL_CPU_CVT(i4, i8), INTEL_CPU_CVT(i4, u8), \ - INTEL_CPU_CVT(nf4, f32), INTEL_CPU_CVT(nf4, bf16), INTEL_CPU_CVT(nf4, f16), INTEL_CPU_CVT(nf4, i8), INTEL_CPU_CVT(nf4, u8), \ - INTEL_CPU_CVT(f4e2m1, f32), INTEL_CPU_CVT(f4e2m1, bf16), INTEL_CPU_CVT(f4e2m1, f16), INTEL_CPU_CVT(f4e2m1, i8), INTEL_CPU_CVT(f4e2m1, u8) +#define INTEL_CPU_CVT_FROM_4BIT_LIST \ + INTEL_CPU_CVT(u4, f32), INTEL_CPU_CVT(u4, bf16), INTEL_CPU_CVT(u4, f16), INTEL_CPU_CVT(u4, i8), \ + INTEL_CPU_CVT(u4, u8), INTEL_CPU_CVT(i4, f32), INTEL_CPU_CVT(i4, bf16), 
INTEL_CPU_CVT(i4, f16), \ + INTEL_CPU_CVT(i4, i8), INTEL_CPU_CVT(i4, u8), INTEL_CPU_CVT(nf4, f32), INTEL_CPU_CVT(nf4, bf16), \ + INTEL_CPU_CVT(nf4, f16), INTEL_CPU_CVT(nf4, i8), INTEL_CPU_CVT(nf4, u8), INTEL_CPU_CVT(f4e2m1, f32), \ + INTEL_CPU_CVT(f4e2m1, bf16), INTEL_CPU_CVT(f4e2m1, f16), INTEL_CPU_CVT(f4e2m1, i8), INTEL_CPU_CVT(f4e2m1, u8) struct ConvertFrom4BitContext { ov::element::Type_t inType; - const void *srcPtr; - void *dstPtr; + const void* srcPtr; + void* dstPtr; size_t size; bool converted; }; @@ -624,12 +605,12 @@ static int8_t get_u4(const uint8_t& val, bool high) { return high ? (val >> 4) : (val & 0xF); } -template +template struct ConvertFrom4BitPrecision; -template +template struct ConvertFrom4BitPrecision> { - void operator()(ConvertFrom4BitContext &ctx) { + void operator()(ConvertFrom4BitContext& ctx) { auto src = static_cast(ctx.srcPtr); auto dst = static_cast(ctx.dstPtr); if (ctx.inType == ov::element::nf4) { @@ -655,23 +636,23 @@ struct ConvertFrom4BitPrecision> { } }; -#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \ +#define INTEL_CPU_CVT_FROM_BYTE_FP_LIST \ INTEL_CPU_CVT(f8e8m0, f32), INTEL_CPU_CVT(f8e8m0, bf16), INTEL_CPU_CVT(f8e8m0, f16) struct ConvertFromByteFPContext { ov::element::Type_t inType; - const void *srcPtr; - void *dstPtr; + const void* srcPtr; + void* dstPtr; size_t size; bool converted; }; -template +template struct ConvertFromByteFPPrecision; -template +template struct ConvertFromByteFPPrecision> { - void operator()(ConvertFromByteFPContext &ctx) { + void operator()(ConvertFromByteFPContext& ctx) { auto src = static_cast(ctx.srcPtr); auto dst = static_cast(ctx.dstPtr); if (ctx.inType == ov::element::f8e8m0) { @@ -685,12 +666,16 @@ struct ConvertFromByteFPPrecision> { } }; -void cpu_convert(const void *srcPtr, void *dstPtr, ov::element::Type srcPrc, ov::element::Type dstPrc, const size_t size) { +void cpu_convert(const void* srcPtr, + void* dstPtr, + ov::element::Type srcPrc, + ov::element::Type dstPrc, + const size_t size) { cpu_convert(srcPtr, dstPtr, srcPrc, dstPrc, dstPrc, size); } -void cpu_convert(const void *srcPtr, - void *dstPtr, +void cpu_convert(const void* srcPtr, + void* dstPtr, ov::element::Type srcPrc, ov::element::Type interimPrc, ov::element::Type dstPrc, @@ -705,12 +690,12 @@ void cpu_convert(const void *srcPtr, const size_t L2_cache_size = dnnl::utils::get_cache_size(2, true); const size_t totalSize = size * dstPrc.size(); if (srcPrc == element::string) { - auto str_src = reinterpret_cast(srcPtr); - auto str_dst = reinterpret_cast(dstPtr); + auto str_src = reinterpret_cast(srcPtr); + auto str_dst = reinterpret_cast(dstPtr); std::copy(str_src, str_src + size, str_dst); } else if (totalSize >= L2_cache_size) { - auto src = static_cast(srcPtr); - auto dst = static_cast(dstPtr); + auto src = static_cast(srcPtr); + auto dst = static_cast(dstPtr); parallel_nt(0, [&](const size_t ithr, const size_t nthr) { size_t start = 0, end = 0; splitter(totalSize, nthr, ithr, start, end); @@ -728,12 +713,7 @@ void cpu_convert(const void *srcPtr, "> precision to: ", dstPrc, ". 
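get_u4() above extracts a single unsigned nibble; the signed i4 path additionally has to sign-extend the nibble's top bit, which this hunk doesn't show. A sketch of both, assuming the low nibble holds the even-indexed element (hypothetical helpers):

#include <cstdint>

inline uint8_t unpack_u4(uint8_t byte, bool high) {
    return high ? (byte >> 4) : (byte & 0xF);  // two elements per byte
}

inline int8_t unpack_i4(uint8_t byte, bool high) {
    uint8_t n = unpack_u4(byte, high);
    // Two's-complement nibble: 0x8..0xF represent -8..-1.
    return (n & 0x8) ? static_cast<int8_t>(n | 0xF0) : static_cast<int8_t>(n);
}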
Not implemented."); - ConvertFromBinContext ctx { - srcPtr, - dstPtr, - size, - false - }; + ConvertFromBinContext ctx{srcPtr, dstPtr, size, false}; OV_SWITCH(intel_cpu, ConvertFromBinPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BIN_LIST); if (!ctx.converted) OPENVINO_THROW("cpu_convert can't convert from: ", @@ -749,18 +729,15 @@ void cpu_convert(const void *srcPtr, OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc); } else if (srcPrc.bitwidth() == 8u && srcPrc.is_real()) { ConvertFromByteFPContext ctx{srcPrc, srcPtr, dstPtr, size, false}; - OV_SWITCH(intel_cpu, ConvertFromByteFPPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_FROM_BYTE_FP_LIST); + OV_SWITCH(intel_cpu, + ConvertFromByteFPPrecision, + ctx, + std::tie(srcPrc, dstPrc), + INTEL_CPU_CVT_FROM_BYTE_FP_LIST); if (!ctx.converted) OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc); } else { - ConvertContext ctx { - srcPtr, - dstPtr, - size, - interimPrc, - dstPrc, - false - }; + ConvertContext ctx{srcPtr, dstPtr, size, interimPrc, dstPrc, false}; OV_SWITCH(intel_cpu, ConvertPrecision, ctx, std::tie(srcPrc, dstPrc), INTEL_CPU_CVT_LIST); if (!ctx.converted) OPENVINO_THROW("cpu_convert can't convert from: ", srcPrc, " precision to: ", dstPrc); @@ -773,7 +750,7 @@ struct isSupportedContext { template struct isSupported { - void operator()(isSupportedContext &ctx) { + void operator()(isSupportedContext& ctx) { ctx.isSupported = true; } }; @@ -790,5 +767,5 @@ bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc) { #undef INTEL_CPU_CVT #undef INTEL_CPU_CVT_LIST -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/cpu_convert.h b/src/plugins/intel_cpu/src/nodes/common/cpu_convert.h index 8390849ff8adc7..11228dbd1dcfdb 100644 --- a/src/plugins/intel_cpu/src/nodes/common/cpu_convert.h +++ b/src/plugins/intel_cpu/src/nodes/common/cpu_convert.h @@ -22,8 +22,8 @@ namespace intel_cpu { * number of elements in buffers to be converted * @return none. */ -void cpu_convert(const void *srcPtr, - void *dstPtr, +void cpu_convert(const void* srcPtr, + void* dstPtr, ov::element::Type srcPrc, ov::element::Type dstPrc, const size_t size); @@ -45,14 +45,14 @@ void cpu_convert(const void *srcPtr, * number of elements in buffers to be converted * @return none. 
*/ -void cpu_convert(const void *srcPtr, - void *dstPtr, +void cpu_convert(const void* srcPtr, + void* dstPtr, ov::element::Type srcPrc, ov::element::Type interimPrc, ov::element::Type dstPrc, const size_t size); - bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc); +bool is_supported_convert(ov::element::Type srcPrc, ov::element::Type dstPrc); -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/cpu_memcpy.h b/src/plugins/intel_cpu/src/nodes/common/cpu_memcpy.h old mode 100755 new mode 100644 index 95b0267bd4757c..e827d35a11c2ad --- a/src/plugins/intel_cpu/src/nodes/common/cpu_memcpy.h +++ b/src/plugins/intel_cpu/src/nodes/common/cpu_memcpy.h @@ -5,8 +5,9 @@ #pragma once #include -#include "openvino/core/parallel.hpp" + #include "onednn/dnnl.h" +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { @@ -36,8 +37,7 @@ inline void cpu_memcpy(void* dst, const void* src, size_t count) { } inline int cpu_memcpy_s(void* dst, size_t dst_size, const void* src, size_t count) { - if (!src || - count > dst_size || + if (!src || count > dst_size || count > (dst > src ? ((uintptr_t)dst - (uintptr_t)src) : ((uintptr_t)src - (uintptr_t)dst))) { // zero out dest if error detected std::memset(dst, 0, dst_size); @@ -55,8 +55,8 @@ inline int cpu_memcpy_s(void* dst, size_t dst_size, const void* src, size_t coun inline void cpu_parallel_memcpy(void* dst, const void* src, size_t count) { const size_t l2_cache_size = dnnl::utils::get_cache_size(2, true); if (count >= l2_cache_size) { - auto src_int8 = static_cast(src); - auto dst_int8 = static_cast(dst); + auto src_int8 = static_cast(src); + auto dst_int8 = static_cast(dst); parallel_nt(0, [&](const size_t ithr, const size_t nthr) { size_t start = 0, end = 0; splitter(count, nthr, ithr, start, end); @@ -67,5 +67,5 @@ inline void cpu_parallel_memcpy(void* dst, const void* src, size_t count) { } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/defs.h b/src/plugins/intel_cpu/src/nodes/common/defs.h index 6d8574de0939a4..a8a07a2cc8942a 100644 --- a/src/plugins/intel_cpu/src/nodes/common/defs.h +++ b/src/plugins/intel_cpu/src/nodes/common/defs.h @@ -4,10 +4,10 @@ #pragma once -#if defined (HAVE_SSE) || defined (HAVE_AVX2) -# if defined (_WIN32) -# include -# else -# include -# endif +#if defined(HAVE_SSE) || defined(HAVE_AVX2) +# if defined(_WIN32) +# include +# else +# include +# endif #endif diff --git a/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.cpp b/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.cpp index 51aa54c2f50463..695fdbe823ea15 100644 --- a/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.cpp @@ -18,7 +18,9 @@ DnnlExecutor::DnnlExecutor(const dnnl::primitive_desc& pd) { DnnlExecutor::IntermReorder::IntermReorder(const dnnl::memory::desc& descSrc, const dnnl::memory::desc& descDst, - const dnnl::engine& engine) : m_descSrc(descSrc), m_descDst(descDst) { + const dnnl::engine& engine) + : m_descSrc(descSrc), + m_descDst(descDst) { auto reorderPd = dnnl::reorder::primitive_desc(engine, descSrc, engine, descDst); m_reorder = dnnl::reorder(reorderPd); } @@ -36,7 +38,7 @@ void DnnlExecutor::exec(const std::unordered_map& primArgs, d } void DnnlExecutor::reorder_exec(std::unordered_map primArgs, dnnl::stream strm) { - for (auto &inReorder : 
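cpu_parallel_memcpy() above only fans out across threads once the copy reaches the L2 cache size; below that, a single memcpy is cheaper than the thread synchronization. A simplified model of the splitter() partitioning it relies on, assuming an even chunking of the byte range (hypothetical split_range, not dnnl's actual implementation):

#include <cstddef>
#include <cstring>

// Carve [0, count) into nthr contiguous chunks whose sizes differ by at most 1.
inline void split_range(size_t count, size_t nthr, size_t ithr, size_t& start, size_t& end) {
    const size_t chunk = count / nthr;
    const size_t rem = count % nthr;
    start = ithr * chunk + (ithr < rem ? ithr : rem);
    end = start + chunk + (ithr < rem ? 1 : 0);
}

// Worker ithr then copies only its span:
//   std::memcpy(dst + start, src + start, end - start);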
inputReorders) { + for (auto& inReorder : inputReorders) { if (primArgs.count(inReorder.first)) { dnnl::memory memDst(inReorder.second.getDstDesc(), strm.get_engine()); inReorder.second.exec(primArgs[inReorder.first], memDst, strm); @@ -46,17 +48,19 @@ void DnnlExecutor::reorder_exec(std::unordered_map primArgs, } } std::unordered_map outputMem; - for (auto &outReorder : outputReorders) { + for (auto& outReorder : outputReorders) { if (primArgs.count(outReorder.first)) { dnnl::memory memSrc(outReorder.second.getSrcDesc(), strm.get_engine()); outputMem[outReorder.first] = primArgs[outReorder.first]; primArgs[outReorder.first] = memSrc; } else { - OPENVINO_THROW("DnnlExecutor has reorder for output ", outReorder.first, ", but doesn't have destination memory"); + OPENVINO_THROW("DnnlExecutor has reorder for output ", + outReorder.first, + ", but doesn't have destination memory"); } } execPrim.execute(strm, primArgs); - for (auto &outReorder : outputReorders) { + for (auto& outReorder : outputReorders) { outReorder.second.exec(primArgs[outReorder.first], outputMem[outReorder.first], strm); } } @@ -79,4 +83,4 @@ impl_desc_type DnnlExecutor::getImplementationType() const { } } // namespace intel_cpu -} // namespace ov +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.h b/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.h index 3cc6749857816c..32739a38d37028 100644 --- a/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.h +++ b/src/plugins/intel_cpu/src/nodes/common/dnnl_executor.h @@ -6,74 +6,79 @@ #include #include + #include "memory_desc/dnnl_memory_desc.h" namespace ov { namespace intel_cpu { class DnnlExecutor { - protected: - class IntermReorder { - public: - IntermReorder(const dnnl::memory::desc& descSrc, const dnnl::memory::desc& descDst, const dnnl::engine& engine); - void exec(dnnl::memory& memSrc, dnnl::memory& memDst, dnnl::stream strm); - const dnnl::memory::desc& getSrcDesc() const { return m_descSrc; } - const dnnl::memory::desc& getDstDesc() const { return m_descDst; } - - private: - dnnl::reorder m_reorder; - dnnl::memory::desc m_descSrc; - dnnl::memory::desc m_descDst; - }; - +protected: + class IntermReorder { public: - explicit DnnlExecutor(const dnnl::primitive_desc& pd); - void exec(const std::unordered_map& primArgs, dnnl::stream strm); - bool needReordering() const; - virtual ~DnnlExecutor() = default; - dnnl::primitive getExecPrim() const; - const_dnnl_primitive_desc_t getPrimitiveDesc() const; - impl_desc_type getImplementationType() const; - - DnnlMemoryDescPtr getSrcDesc() const { - return src_md; + IntermReorder(const dnnl::memory::desc& descSrc, const dnnl::memory::desc& descDst, const dnnl::engine& engine); + void exec(dnnl::memory& memSrc, dnnl::memory& memDst, dnnl::stream strm); + const dnnl::memory::desc& getSrcDesc() const { + return m_descSrc; } - DnnlMemoryDescPtr getWeightDesc() const { - return wghts_md; - } - DnnlMemoryDescPtr getDstDesc() const { - return dst_md; - } - DnnlMemoryDescPtr getScratchPadDesc() const { - return scrch_md; + const dnnl::memory::desc& getDstDesc() const { + return m_descDst; } - const dnnl::memory::desc& getDnnlSrcDesc() const { - return src_md->getDnnlDesc(); - } - const dnnl::memory::desc& getDnnlWeightDesc() const { - return wghts_md->getDnnlDesc(); - } - const dnnl::memory::desc& getDnnlDstDesc() const { - return dst_md->getDnnlDesc(); - } - const dnnl::memory::desc& getDnnlScratchPadDesc() const { - return scrch_md->getDnnlDesc(); - } + private: + dnnl::reorder m_reorder; + 
dnnl::memory::desc m_descSrc; + dnnl::memory::desc m_descDst; + }; + +public: + explicit DnnlExecutor(const dnnl::primitive_desc& pd); + void exec(const std::unordered_map& primArgs, dnnl::stream strm); + bool needReordering() const; + virtual ~DnnlExecutor() = default; + dnnl::primitive getExecPrim() const; + const_dnnl_primitive_desc_t getPrimitiveDesc() const; + impl_desc_type getImplementationType() const; + + DnnlMemoryDescPtr getSrcDesc() const { + return src_md; + } + DnnlMemoryDescPtr getWeightDesc() const { + return wghts_md; + } + DnnlMemoryDescPtr getDstDesc() const { + return dst_md; + } + DnnlMemoryDescPtr getScratchPadDesc() const { + return scrch_md; + } + + const dnnl::memory::desc& getDnnlSrcDesc() const { + return src_md->getDnnlDesc(); + } + const dnnl::memory::desc& getDnnlWeightDesc() const { + return wghts_md->getDnnlDesc(); + } + const dnnl::memory::desc& getDnnlDstDesc() const { + return dst_md->getDnnlDesc(); + } + const dnnl::memory::desc& getDnnlScratchPadDesc() const { + return scrch_md->getDnnlDesc(); + } - protected: - virtual void reorder_exec(std::unordered_map primArgs, dnnl::stream strm); +protected: + virtual void reorder_exec(std::unordered_map primArgs, dnnl::stream strm); - protected: - dnnl::primitive execPrim; - // key is the port number for the primitive that needs memory reordering - std::unordered_map inputReorders; - std::unordered_map outputReorders; - DnnlMemoryDescPtr src_md; - DnnlMemoryDescPtr wghts_md; - DnnlMemoryDescPtr dst_md; - DnnlMemoryDescPtr scrch_md; +protected: + dnnl::primitive execPrim; + // key is the port number for the primitive that needs memory reordering + std::unordered_map inputReorders; + std::unordered_map outputReorders; + DnnlMemoryDescPtr src_md; + DnnlMemoryDescPtr wghts_md; + DnnlMemoryDescPtr dst_md; + DnnlMemoryDescPtr scrch_md; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/fp16_utils.h b/src/plugins/intel_cpu/src/nodes/common/fp16_utils.h index daedcc4bf23ca4..b6622f7ae54d0b 100644 --- a/src/plugins/intel_cpu/src/nodes/common/fp16_utils.h +++ b/src/plugins/intel_cpu/src/nodes/common/fp16_utils.h @@ -13,7 +13,7 @@ typedef short ie_fp16; // F32: exp_bias:127 SEEEEEEE EMMMMMMM MMMMMMMM MMMMMMMM. 
// F16: exp_bias:15 SEEEEEMM MMMMMMMM #define EXP_MASK_F32 0x7F800000U -#define EXP_MASK_F16 0x7C00U +#define EXP_MASK_F16 0x7C00U // small helper function to represent uint32_t value as float32 inline float asfloat(uint32_t v) { @@ -83,5 +83,5 @@ inline float f16tof32(ie_fp16 x) { return asfloat(u); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/permute_kernel.cpp b/src/plugins/intel_cpu/src/nodes/common/permute_kernel.cpp index 396cebc1ba82e1..60bd675d726e4a 100644 --- a/src/plugins/intel_cpu/src/nodes/common/permute_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/permute_kernel.cpp @@ -6,15 +6,14 @@ #include -#include "dnnl_types.h" -#include "dnnl_extension_utils.h" -#include "cpu_memcpy.h" -#include "utils/bfloat16.hpp" - -#include "cpu/x64/jit_generator.hpp" #include "common/primitive_hashing_utils.hpp" -#include "nodes/executors/transpose.hpp" +#include "cpu/x64/jit_generator.hpp" +#include "cpu_memcpy.h" +#include "dnnl_extension_utils.h" +#include "dnnl_types.h" #include "nodes/executors/common/ref_transpose.hpp" +#include "nodes/executors/transpose.hpp" +#include "utils/bfloat16.hpp" using namespace dnnl; using namespace dnnl::impl; @@ -33,7 +32,9 @@ template struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_permute_kernel_f32) - explicit jit_uni_permute_kernel_f32(jit_permute_config_params jcp_) : jit_uni_permute_kernel(jcp_), jit_generator(jit_name()) {} + explicit jit_uni_permute_kernel_f32(jit_permute_config_params jcp_) + : jit_uni_permute_kernel(jcp_), + jit_generator(jit_name()) {} void create_ker() override { jit_generator::create_kernel(); @@ -51,23 +52,43 @@ struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_ge this->postamble(); } - void load(const Xbyak::Xmm &xmm, const Xbyak::Address &addr) { + void load(const Xbyak::Xmm& xmm, const Xbyak::Address& addr) { switch (jcp.data_size) { - case 16: uni_vmovups(xmm, addr); break; - case 8: uni_vmovsd(xmm, addr); break; - case 4: uni_vmovss(xmm, addr); break; - case 2: uni_vpinsrw(xmm, xmm, addr, 0x0); break; - case 1: uni_vpinsrb(xmm, xmm, addr, 0x0); break; + case 16: + uni_vmovups(xmm, addr); + break; + case 8: + uni_vmovsd(xmm, addr); + break; + case 4: + uni_vmovss(xmm, addr); + break; + case 2: + uni_vpinsrw(xmm, xmm, addr, 0x0); + break; + case 1: + uni_vpinsrb(xmm, xmm, addr, 0x0); + break; } } - void store(const Xbyak::Address &addr, const Xbyak::Xmm &xmm) { + void store(const Xbyak::Address& addr, const Xbyak::Xmm& xmm) { switch (jcp.data_size) { - case 16: uni_vmovups(addr, xmm); break; - case 8: uni_vmovsd(addr, xmm); break; - case 4: uni_vmovss(addr, xmm); break; - case 2: uni_vpextrw(addr, xmm, 0x0); break; - case 1: uni_vpextrb(addr, xmm, 0x0); break; + case 16: + uni_vmovups(addr, xmm); + break; + case 8: + uni_vmovsd(addr, xmm); + break; + case 4: + uni_vmovss(addr, xmm); + break; + case 2: + uni_vpextrw(addr, xmm, 0x0); + break; + case 1: + uni_vpextrb(addr, xmm, 0x0); + break; } } @@ -99,7 +120,8 @@ struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_ge } } - L(tail_loop_label); { + L(tail_loop_label); + { cmp(reg_work_amount, 0); je(exit_label, T_NEAR); @@ -129,7 +151,8 @@ struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_ge } private: - using Vmm = typename conditional3::type; + using Vmm = + typename conditional3::type; uint32_t vlen = 
cpu_isa_traits::vlen; Xbyak::Reg64 reg_src = r8; @@ -144,7 +167,7 @@ struct jit_uni_permute_kernel_f32 : public jit_uni_permute_kernel, public jit_ge Xbyak::Xmm xmm = Xbyak::Xmm(1); }; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 PermuteKernel::PermuteKernel(const PermuteParams& params) : params(params) { jcp = TransposeExecutor::prepareParams(params); @@ -156,7 +179,7 @@ PermuteKernel::PermuteKernel(const PermuteParams& params) : params(params) { } else if (mayiuse(cpu::x64::sse41)) { permute_kernel.reset(new jit_uni_permute_kernel_f32(jcp)); } -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 if (permute_kernel) permute_kernel->create_ker(); @@ -178,7 +201,7 @@ void PermuteKernel::execute(const uint8_t* src_data, uint8_t* dst_data) { return; } - RefTransposeExecutor::referenceExecute(src_data, dst_data, jcp, dst_dims[0]); + RefTransposeExecutor::referenceExecute(src_data, dst_data, jcp, dst_dims[0]); } void PermuteKernel::optimizedExecute(const uint8_t* src_data, uint8_t* dst_data, const int mb) { @@ -190,42 +213,42 @@ void PermuteKernel::optimizedExecute(const uint8_t* src_data, uint8_t* dst_data, dst_dims[0] = mb; switch (jcp.n) { - case 1: - parallel_for(dst_dims[0], [&](int i0) { - auto arg = jit_args_permute(); - - size_t dst_off = i0 * dst_strides[0]; - size_t src_off = i0 * src_strides[0]; - arg.src = &src_data[src_off * jcp.data_size]; - arg.dst = &dst_data[dst_off * jcp.data_size]; - - (*permute_kernel)(&arg); - }); - break; - case 2: - parallel_for2d(dst_dims[0], dst_dims[1], [&](int i0, int i1) { - auto arg = jit_args_permute(); - - size_t dst_off = i0 * dst_strides[0] + i1 * dst_strides[1]; - size_t src_off = i0 * src_strides[0] + i1 * src_strides[1]; - arg.src = &src_data[src_off * jcp.data_size]; - arg.dst = &dst_data[dst_off * jcp.data_size]; - - (*permute_kernel)(&arg); - }); - break; - case 3: - parallel_for3d(dst_dims[0], dst_dims[1], dst_dims[2], [&](int i0, int i1, int i2) { - auto arg = jit_args_permute(); - - size_t dst_off = i0 * dst_strides[0] + i1 * dst_strides[1] + i2 * dst_strides[2]; - size_t src_off = i0 * src_strides[0] + i1 * src_strides[1] + i2 * src_strides[2]; - arg.src = &src_data[src_off * jcp.data_size]; - arg.dst = &dst_data[dst_off * jcp.data_size]; - - (*permute_kernel)(&arg); - }); - break; + case 1: + parallel_for(dst_dims[0], [&](int i0) { + auto arg = jit_args_permute(); + + size_t dst_off = i0 * dst_strides[0]; + size_t src_off = i0 * src_strides[0]; + arg.src = &src_data[src_off * jcp.data_size]; + arg.dst = &dst_data[dst_off * jcp.data_size]; + + (*permute_kernel)(&arg); + }); + break; + case 2: + parallel_for2d(dst_dims[0], dst_dims[1], [&](int i0, int i1) { + auto arg = jit_args_permute(); + + size_t dst_off = i0 * dst_strides[0] + i1 * dst_strides[1]; + size_t src_off = i0 * src_strides[0] + i1 * src_strides[1]; + arg.src = &src_data[src_off * jcp.data_size]; + arg.dst = &dst_data[dst_off * jcp.data_size]; + + (*permute_kernel)(&arg); + }); + break; + case 3: + parallel_for3d(dst_dims[0], dst_dims[1], dst_dims[2], [&](int i0, int i1, int i2) { + auto arg = jit_args_permute(); + + size_t dst_off = i0 * dst_strides[0] + i1 * dst_strides[1] + i2 * dst_strides[2]; + size_t src_off = i0 * src_strides[0] + i1 * src_strides[1] + i2 * src_strides[2]; + arg.src = &src_data[src_off * jcp.data_size]; + arg.dst = &dst_data[dst_off * jcp.data_size]; + + (*permute_kernel)(&arg); + }); + break; } return; } @@ -245,12 +268,10 @@ size_t PermuteParams::hash() const { } bool PermuteParams::operator==(const PermuteParams& 
rhs) const { - return (src_block_dims == rhs.src_block_dims) && - (dst_block_dims == rhs.dst_block_dims) && - (src_block_order == rhs.src_block_order) && - (dst_block_order == rhs.dst_block_order) && (order == rhs.order) && - (data_size == rhs.data_size); + return (src_block_dims == rhs.src_block_dims) && (dst_block_dims == rhs.dst_block_dims) && + (src_block_order == rhs.src_block_order) && (dst_block_order == rhs.dst_block_order) && + (order == rhs.order) && (data_size == rhs.data_size); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/permute_kernel.h b/src/plugins/intel_cpu/src/nodes/common/permute_kernel.h index ac665efb4f0bb6..ba7a89d746d945 100644 --- a/src/plugins/intel_cpu/src/nodes/common/permute_kernel.h +++ b/src/plugins/intel_cpu/src/nodes/common/permute_kernel.h @@ -38,9 +38,9 @@ struct jit_args_permute { }; struct jit_uni_permute_kernel { - void (*ker_)(const jit_args_permute *); + void (*ker_)(const jit_args_permute*); - void operator()(const jit_args_permute *args) { + void operator()(const jit_args_permute* args) { assert(ker_); ker_(args); } @@ -71,5 +71,5 @@ class PermuteKernel { PermuteParams params; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/reorder_prim.cpp b/src/plugins/intel_cpu/src/nodes/common/reorder_prim.cpp index 93e145b25b9e95..dd07a721260aac 100644 --- a/src/plugins/intel_cpu/src/nodes/common/reorder_prim.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/reorder_prim.cpp @@ -4,15 +4,14 @@ #include "reorder_prim.h" -#include "dnnl_extension_utils.h" -#include "dnnl_types.h" - #include -#include "common/primitive_hashing_utils.hpp" -#include "cpu/x64/cpu_isa_traits.hpp" #include #include +#include "common/primitive_hashing_utils.hpp" +#include "cpu/x64/cpu_isa_traits.hpp" +#include "dnnl_extension_utils.h" +#include "dnnl_types.h" #include "utils/general_utils.h" namespace ov { diff --git a/src/plugins/intel_cpu/src/nodes/common/softmax.cpp b/src/plugins/intel_cpu/src/nodes/common/softmax.cpp index 66a6ca9c1b6f53..0fcc87f8978752 100644 --- a/src/plugins/intel_cpu/src/nodes/common/softmax.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/softmax.cpp @@ -4,17 +4,17 @@ #include "softmax.h" -#include "openvino/core/parallel.hpp" -#include "cpu/x64/jit_generator.hpp" -#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" -#include "onednn/dnnl.h" -#include "utils/bfloat16.hpp" -#include "emitters/plugin/x64/jit_bf16_emitters.hpp" - #include #include #include +#include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp" +#include "cpu/x64/jit_generator.hpp" +#include "emitters/plugin/x64/jit_bf16_emitters.hpp" +#include "onednn/dnnl.h" +#include "openvino/core/parallel.hpp" +#include "utils/bfloat16.hpp" + using namespace dnnl; using namespace dnnl::impl::cpu; using namespace dnnl::impl::cpu::x64; @@ -38,11 +38,13 @@ struct jit_softmax_config_params { ov::element::Type dst_dt; }; - struct jit_uni_softmax_kernel { - void (*ker_)(const jit_args_softmax *); + void (*ker_)(const jit_args_softmax*); - void operator()(const jit_args_softmax *args) { assert(ker_); ker_(args); } + void operator()(const jit_args_softmax* args) { + assert(ker_); + ker_(args); + } jit_uni_softmax_kernel() : ker_(nullptr) {} virtual ~jit_uni_softmax_kernel() {} @@ -54,7 +56,10 @@ template struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_generator { 
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_softmax_kernel_f32) - jit_uni_softmax_kernel_f32(jit_softmax_config_params jcp) : jit_uni_softmax_kernel(), jit_generator(jit_name()), jcp_(jcp) {} + jit_uni_softmax_kernel_f32(jit_softmax_config_params jcp) + : jit_uni_softmax_kernel(), + jit_generator(jit_name()), + jcp_(jcp) {} void create_ker() override { jit_generator::create_kernel(); @@ -62,14 +67,14 @@ struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_ge } void generate() override { - exp_injector.reset(new jit_uni_eltwise_injector_f32(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f)); + exp_injector.reset( + new jit_uni_eltwise_injector_f32(this, dnnl::impl::alg_kind::eltwise_exp, 0.f, 0.f, 1.0f)); if (mayiuse(avx512_core)) uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa)); this->preamble(); - mov(reg_src, ptr[reg_params + GET_OFF(src)]); mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); mov(reg_src_stride, ptr[reg_params + GET_OFF(src_stride)]); @@ -86,7 +91,8 @@ struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_ge mov(aux_reg_work_amount, reg_work_amount); mov(aux_reg_src, reg_src); load_vector(vmm_max, ptr[aux_reg_src], jcp_.src_dt); - L(max_loop_label); { + L(max_loop_label); + { cmp(aux_reg_work_amount, 0); jle(max_loop_end_label, T_NEAR); @@ -120,7 +126,8 @@ struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_ge mov(aux_reg_src, reg_src); mov(aux_reg_dst, reg_dst); uni_vpxor(vmm_exp_sum, vmm_exp_sum, vmm_exp_sum); - L(exp_loop_label); { + L(exp_loop_label); + { cmp(aux_reg_work_amount, 0); jle(exp_loop_end_label, T_NEAR); @@ -143,7 +150,8 @@ struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_ge mov(aux_reg_work_amount, reg_work_amount); mov(aux_reg_dst, reg_dst); - L(div_loop_label); { + L(div_loop_label); + { cmp(aux_reg_work_amount, 0); jle(div_loop_end_label, T_NEAR); @@ -196,38 +204,40 @@ struct jit_uni_softmax_kernel_f32 : public jit_uni_softmax_kernel, public jit_ge jit_softmax_config_params jcp_; - inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, ov::element::Type src_dt) { + inline void load_vector(Vmm vmm_src, const Xbyak::Address& op, ov::element::Type src_dt) { switch (src_dt) { - case ov::element::f32: - uni_vmovups(vmm_src, op); - break; - case ov::element::bf16: - vpmovzxwd(vmm_src, op); - uni_vpslld(vmm_src, vmm_src, 16); - break; - default: - assert(!"unknown src_dt"); + case ov::element::f32: + uni_vmovups(vmm_src, op); + break; + case ov::element::bf16: + vpmovzxwd(vmm_src, op); + uni_vpslld(vmm_src, vmm_src, 16); + break; + default: + assert(!"unknown src_dt"); } } - inline void store_vector(const Xbyak::Address &op, Vmm vmm_dst, ov::element::Type dst_dt) { + inline void store_vector(const Xbyak::Address& op, Vmm vmm_dst, ov::element::Type dst_dt) { Xbyak::Ymm ymm_dst = Xbyak::Ymm(vmm_dst.getIdx()); switch (dst_dt) { - case ov::element::f32: - uni_vmovups(op, vmm_dst); - break; - case ov::element::bf16: - uni_vcvtneps2bf16->emit_code({static_cast(vmm_dst.getIdx())}, {static_cast(ymm_dst.getIdx())}); - vmovdqu16(op, ymm_dst); - break; - default: - assert(!"unknown dst_dt"); + case ov::element::f32: + uni_vmovups(op, vmm_dst); + break; + case ov::element::bf16: + uni_vcvtneps2bf16->emit_code({static_cast(vmm_dst.getIdx())}, + {static_cast(ymm_dst.getIdx())}); + vmovdqu16(op, ymm_dst); + break; + default: + assert(!"unknown dst_dt"); } } }; #endif SoftmaxGeneric::SoftmaxGeneric(ov::element::Type inpPrc, ov::element::Type outPrc) - : 
input_prec(inpPrc), output_prec(outPrc) { + : input_prec(inpPrc), + output_prec(outPrc) { if (ov::element::bf16 == output_prec) { if (!mayiuse(avx512_core)) { OPENVINO_THROW("SoftmaxGeneric doesn't support BF16 precision on this target."); @@ -255,27 +265,27 @@ SoftmaxGeneric::SoftmaxGeneric(ov::element::Type inpPrc, ov::element::Type outPr #endif } -template -void SoftmaxGeneric::calculate(const in_data_t *src_data, out_data_t *dst_data, int B, int C, int H, int W) { +template +void SoftmaxGeneric::calculate(const in_data_t* src_data, out_data_t* dst_data, int B, int C, int H, int W) { for (int b = 0; b < B; b++) { int tail_start = 0; if (softmax_kernel) { - int blocks_num = H*W / block_size; + int blocks_num = H * W / block_size; parallel_for(blocks_num, [&](int ib) { auto arg = jit_args_softmax(); arg.src = src_data + b * C * H * W + ib * block_size; arg.dst = dst_data + b * C * H * W + ib * block_size; - arg.src_stride = static_cast((size_t)(H) * W * sizeof(in_data_t)); - arg.dst_stride = static_cast((size_t)(H) * W * sizeof(out_data_t)); + arg.src_stride = static_cast((size_t)(H)*W * sizeof(in_data_t)); + arg.dst_stride = static_cast((size_t)(H)*W * sizeof(out_data_t)); arg.work_amount = static_cast(C); (*softmax_kernel)(&arg); }); - tail_start = (H*W / block_size) * block_size; + tail_start = (H * W / block_size) * block_size; } parallel_for(H * W - tail_start, [&](int i) { @@ -283,7 +293,8 @@ void SoftmaxGeneric::calculate(const in_data_t *src_data, out_data_t *dst_data, float max = src_data[b * C * H * W + offset]; for (int c = 0; c < C; c++) { float val = src_data[b * C * H * W + c * H * W + offset]; - if (val > max) max = val; + if (val > max) + max = val; } float expSum = 0; @@ -299,7 +310,7 @@ void SoftmaxGeneric::calculate(const in_data_t *src_data, out_data_t *dst_data, } } -void SoftmaxGeneric::execute(const uint8_t *src_data, uint8_t *dst_data, int B, int C, int H, int W) { +void SoftmaxGeneric::execute(const uint8_t* src_data, uint8_t* dst_data, int B, int C, int H, int W) { if (ov::element::f32 == input_prec) { auto float_src_data = reinterpret_cast(src_data); if (ov::element::f32 == output_prec) { @@ -327,5 +338,5 @@ void SoftmaxGeneric::execute(const uint8_t *src_data, uint8_t *dst_data, int B, } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/softmax.h b/src/plugins/intel_cpu/src/nodes/common/softmax.h index 2e3d5caa4becee..bb450c2ac5a303 100644 --- a/src/plugins/intel_cpu/src/nodes/common/softmax.h +++ b/src/plugins/intel_cpu/src/nodes/common/softmax.h @@ -4,27 +4,28 @@ #pragma once -#include #include -#include "openvino/core/type/element_type.hpp" +#include + #include "defs.h" #include "openvino/core/parallel.hpp" +#include "openvino/core/type/element_type.hpp" namespace ov { namespace intel_cpu { struct jit_uni_softmax_kernel; -static inline -void softmax_many_batches(const float *src_data, float *dst_data, int B, int C, int H, int W) { +static inline void softmax_many_batches(const float* src_data, float* dst_data, int B, int C, int H, int W) { ov::parallel_for(B * H * W, [&](size_t i) { - const float *psrc = src_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W; - float *pdst = dst_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W; + const float* psrc = src_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W; + float* pdst = dst_data + (i / (H * W)) * C * H * W - (i / (H * W)) * H * W; float max = psrc[i]; for (int c = 0; c < C; c++) { float val = 
psrc[c * H * W + i]; - if (val > max) max = val; + if (val > max) + max = val; } float expSum = 0; @@ -43,9 +44,10 @@ class SoftmaxGeneric { public: SoftmaxGeneric(ov::element::Type inpPrc, ov::element::Type outPrc); - void execute(const uint8_t *src_data, uint8_t *dst_data, int B, int C, int H, int W); + void execute(const uint8_t* src_data, uint8_t* dst_data, int B, int C, int H, int W); + private: - template + template void calculate(const in_data_t* src_data, out_data_t* dst_data, int B, int C, int H, int W); private: @@ -54,5 +56,5 @@ class SoftmaxGeneric { std::shared_ptr softmax_kernel; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp b/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp index 6c62304ab22da7..f482b0876b3f4c 100644 --- a/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp +++ b/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.cpp @@ -4,18 +4,17 @@ #include "tile_broadcast_utils.h" +#include + #include "cpu_convert.h" #include "cpu_memcpy.h" -#include "openvino/core/parallel.hpp" -#include #include "memory_desc/dnnl_blocked_memory_desc.h" - - +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { -VectorDims TileBroadcastCommon::calculateDenseStrides(const VectorDims &dims) { +VectorDims TileBroadcastCommon::calculateDenseStrides(const VectorDims& dims) { VectorDims strides(dims.size(), 1); for (int i = strides.size() - 2; i >= 0; i--) { @@ -25,8 +24,10 @@ VectorDims TileBroadcastCommon::calculateDenseStrides(const VectorDims &dims) { return strides; } -void TileBroadcastCommon::fillOptimizedDimsAndSrcStrides(const VectorDims& srcBlockedDims, const VectorDims& blockedRepeats, - VectorDims& optimizedDims, VectorDims& optimizedSrcStrides) { +void TileBroadcastCommon::fillOptimizedDimsAndSrcStrides(const VectorDims& srcBlockedDims, + const VectorDims& blockedRepeats, + VectorDims& optimizedDims, + VectorDims& optimizedSrcStrides) { optimizedDims.clear(); optimizedSrcStrides.clear(); VectorDims srcBlockedStrides = calculateDenseStrides(srcBlockedDims); @@ -60,10 +61,11 @@ void TileBroadcastCommon::fillOptimizedDimsAndSrcStrides(const VectorDims& srcBl } } -bool TileBroadcastCommon::canBeExecutedInBlockedLayout(VectorDims srcBlockedDims, VectorDims blockedRepeats, - const size_t elemsInBlock) { - if (srcBlockedDims.empty() || blockedRepeats.empty() || elemsInBlock == 0lu || srcBlockedDims[1] == Shape::UNDEFINED_DIM || - (blockedRepeats[1] != 1 && srcBlockedDims[1] % elemsInBlock != 0)) +bool TileBroadcastCommon::canBeExecutedInBlockedLayout(VectorDims srcBlockedDims, + VectorDims blockedRepeats, + const size_t elemsInBlock) { + if (srcBlockedDims.empty() || blockedRepeats.empty() || elemsInBlock == 0lu || + srcBlockedDims[1] == Shape::UNDEFINED_DIM || (blockedRepeats[1] != 1 && srcBlockedDims[1] % elemsInBlock != 0)) return false; srcBlockedDims[1] = div_up(srcBlockedDims[1], elemsInBlock); @@ -90,7 +92,7 @@ bool TileBroadcastCommon::canBeExecutedInNSPCLayout(VectorDims srcBlockedDims, V return optimizedDims.size() <= maxNDims; } -std::vector TileBroadcastCommon::getSupportedConfigs(const Node *node, size_t outSize) { +std::vector TileBroadcastCommon::getSupportedConfigs(const Node* node, size_t outSize) { std::vector supportedPrimitiveDescriptors; auto precision = node->getOriginalInputPrecisionAtPort(0); auto dataType = DnnlExtensionUtils::ElementTypeToDataType(precision); @@ -115,26 +117,31 @@ 
std::vector TileBroadcastCommon::getSupportedConfigs(const Node *node, config.inConfs[0].constant(constMap[0]); config.inConfs[1].inPlace(-1); config.inConfs[1].constant(constMap[1]); - config.inConfs[1].setMemDesc(std::make_shared(ov::element::i32, node->getInputShapeAtPort(1))); + config.inConfs[1].setMemDesc( + std::make_shared(ov::element::i32, node->getInputShapeAtPort(1))); if (config.inConfs.size() == 3) { config.inConfs[2].inPlace(-1); config.inConfs[2].constant(constMap[2]); - config.inConfs[2].setMemDesc(std::make_shared(ov::element::i32, node->getInputShapeAtPort(2))); + config.inConfs[2].setMemDesc( + std::make_shared(ov::element::i32, node->getInputShapeAtPort(2))); } config.outConfs.resize(outSize); auto pushDesc = [&](dnnl::memory::format_tag inFormat, dnnl::memory::format_tag outFormat) { - config.inConfs[0].setMemDesc(std::make_shared(node->getInputShapeAtPort(0), dataType, inFormat)); + config.inConfs[0].setMemDesc( + std::make_shared(node->getInputShapeAtPort(0), dataType, inFormat)); for (size_t i = 0; i < config.outConfs.size(); i++) { config.outConfs[i].inPlace(-1); config.outConfs[i].constant(false); - config.outConfs[i].setMemDesc(std::make_shared(node->getOutputShapeAtPort(0), dataType, outFormat)); + config.outConfs[i].setMemDesc( + std::make_shared(node->getOutputShapeAtPort(0), dataType, outFormat)); } supportedPrimitiveDescriptors.push_back({config, impl_desc_type::ref}); }; - if (!repeats.empty() && inDataShape.getRank() == outDataShapeRank && (outDataShapeRank == 4 || outDataShapeRank == 5)) { + if (!repeats.empty() && inDataShape.getRank() == outDataShapeRank && + (outDataShapeRank == 4 || outDataShapeRank == 5)) { if (canBeExecutedInBlockedLayout(srcDims, repeats, 16)) { if (outDataShapeRank == 4) { pushDesc(dnnl::memory::format_tag::nChw16c, dnnl::memory::format_tag::nChw16c); @@ -165,7 +172,8 @@ std::vector TileBroadcastCommon::getSupportedConfigs(const Node *node, for (size_t i = 0; i < config.outConfs.size(); i++) { config.outConfs[i].inPlace(-1); config.outConfs[i].constant(false); - config.outConfs[i].setMemDesc(std::make_shared(precision, node->getOutputShapeAtPort(i))); + config.outConfs[i].setMemDesc( + std::make_shared(precision, node->getOutputShapeAtPort(i))); } supportedPrimitiveDescriptors.push_back({config, impl_desc_type::ref}); } else { @@ -175,7 +183,9 @@ std::vector TileBroadcastCommon::getSupportedConfigs(const Node *node, return supportedPrimitiveDescriptors; } -bool TileBroadcastCommon::prepareOptimizedParams(const Node *node, VectorDims& srcBlockedDims, VectorDims& dstBlockedDims) { +bool TileBroadcastCommon::prepareOptimizedParams(const Node* node, + VectorDims& srcBlockedDims, + VectorDims& dstBlockedDims) { while (srcBlockedDims.size() < dstBlockedDims.size()) { srcBlockedDims.insert(srcBlockedDims.begin(), 1); } @@ -186,7 +196,8 @@ bool TileBroadcastCommon::prepareOptimizedParams(const Node *node, VectorDims& s blockedRepeats.push_back(1); } // for NSPC layouts - if (node->getBaseMemDescAtInputPort(0)->hasLayoutType(LayoutType::nspc) && one_of(node->getBaseMemDescAtInputPort(0)->getShape().getRank(), 4u, 5u)) { + if (node->getBaseMemDescAtInputPort(0)->hasLayoutType(LayoutType::nspc) && + one_of(node->getBaseMemDescAtInputPort(0)->getShape().getRank(), 4u, 5u)) { blockedRepeats.push_back(blockedRepeats[1]); blockedRepeats.erase(blockedRepeats.begin() + 1); } @@ -205,7 +216,8 @@ bool TileBroadcastCommon::prepareOptimizedParams(const Node *node, VectorDims& s VectorDims optimizedDstStrides = calculateDenseStrides(optimizedDims); - 
size_t dataSize = node->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->getPrecision().size(); + size_t dataSize = + node->getSelectedPrimitiveDescriptor()->getConfig().inConfs[0].getMemDesc()->getPrecision().size(); for (size_t i = 0; i < optimizedDims.size(); i++) { optimizedSrcStrides[i] *= dataSize; optimizedDstStrides[i] *= dataSize; @@ -221,9 +233,9 @@ bool TileBroadcastCommon::prepareOptimizedParams(const Node *node, VectorDims& s // Broadcast 1 element to N continuous elements based on cpu_memcpy // Step 1: Get the binary format of the number N -// Step 2: Use cpu_memcpy to form fragments containing pow(2, k) (ie. 2, 4, 8, ...) elements, based on the given 1 element -// Step 3: Form N continuous elements, who's a combination of those fragments, demonstrated by its binary format -void TileBroadcastCommon::broadcastScalar(const char *srcData, char *dstData, size_t elt_cnt, size_t data_size) { +// Step 2: Use cpu_memcpy to form fragments containing pow(2, k) (ie. 2, 4, 8, ...) elements, based on the given 1 +// element Step 3: Form N continuous elements, who's a combination of those fragments, demonstrated by its binary format +void TileBroadcastCommon::broadcastScalar(const char* srcData, char* dstData, size_t elt_cnt, size_t data_size) { std::vector binary_digits; binary_digits.clear(); @@ -275,32 +287,44 @@ void TileBroadcastCommon::optimizedExecute(const MemoryPtr& srcMemory, const Mem broadcastScalar(srcData, dstData, elt_cnt, data_size); } } else { - parallel_for5d(optimizedParams.dims[0], optimizedParams.dims[1], optimizedParams.dims[2], optimizedParams.dims[3], optimizedParams.dims[4], - [&](int i0, int i1, int i2, int i3, int i4) { - auto srcData2 = srcData + (i0 * optimizedParams.srcStrides[0] + i1 * optimizedParams.srcStrides[1] + - i2 * optimizedParams.srcStrides[2] + i3 * optimizedParams.srcStrides[3] + - i4 * optimizedParams.srcStrides[4]); - auto dstData2 = dstData + (i0 * optimizedParams.dstStrides[0] + i1 * optimizedParams.dstStrides[1] + - i2 * optimizedParams.dstStrides[2] + i3 * optimizedParams.dstStrides[3] + - i4 * optimizedParams.dstStrides[4]); - for (size_t i = 0; i < optimizedParams.dims[5]; i++) { - cpu_memcpy(dstData2 + i * optimizedParams.dstStrides[5], srcData2, optimizedParams.dstStrides[5]); - } - }); + parallel_for5d( + optimizedParams.dims[0], + optimizedParams.dims[1], + optimizedParams.dims[2], + optimizedParams.dims[3], + optimizedParams.dims[4], + [&](int i0, int i1, int i2, int i3, int i4) { + auto srcData2 = srcData + (i0 * optimizedParams.srcStrides[0] + i1 * optimizedParams.srcStrides[1] + + i2 * optimizedParams.srcStrides[2] + i3 * optimizedParams.srcStrides[3] + + i4 * optimizedParams.srcStrides[4]); + auto dstData2 = dstData + (i0 * optimizedParams.dstStrides[0] + i1 * optimizedParams.dstStrides[1] + + i2 * optimizedParams.dstStrides[2] + i3 * optimizedParams.dstStrides[3] + + i4 * optimizedParams.dstStrides[4]); + for (size_t i = 0; i < optimizedParams.dims[5]; i++) { + cpu_memcpy(dstData2 + i * optimizedParams.dstStrides[5], + srcData2, + optimizedParams.dstStrides[5]); + } + }); } } else { - parallel_for5d(optimizedParams.dims[0], optimizedParams.dims[1], optimizedParams.dims[2], optimizedParams.dims[3], optimizedParams.dims[4], - [&](int i0, int i1, int i2, int i3, int i4) { - auto srcData2 = srcData + (i0 * optimizedParams.srcStrides[0] + i1 * optimizedParams.srcStrides[1] + - i2 * optimizedParams.srcStrides[2] + i3 * optimizedParams.srcStrides[3] + - i4 * optimizedParams.srcStrides[4]); - auto dstData2 = dstData 
+ (i0 * optimizedParams.dstStrides[0] + i1 * optimizedParams.dstStrides[1] + + parallel_for5d( + optimizedParams.dims[0], + optimizedParams.dims[1], + optimizedParams.dims[2], + optimizedParams.dims[3], + optimizedParams.dims[4], + [&](int i0, int i1, int i2, int i3, int i4) { + auto srcData2 = srcData + (i0 * optimizedParams.srcStrides[0] + i1 * optimizedParams.srcStrides[1] + + i2 * optimizedParams.srcStrides[2] + i3 * optimizedParams.srcStrides[3] + + i4 * optimizedParams.srcStrides[4]); + auto dstData2 = dstData + (i0 * optimizedParams.dstStrides[0] + i1 * optimizedParams.dstStrides[1] + i2 * optimizedParams.dstStrides[2] + i3 * optimizedParams.dstStrides[3] + i4 * optimizedParams.dstStrides[4]); - cpu_memcpy(dstData2, srcData2, optimizedParams.copySize); - }); + cpu_memcpy(dstData2, srcData2, optimizedParams.copySize); + }); } } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.h b/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.h index 7ae0eacbccd373..6638eba7f88a39 100644 --- a/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.h +++ b/src/plugins/intel_cpu/src/nodes/common/tile_broadcast_utils.h @@ -9,27 +9,28 @@ #include #include - namespace ov { namespace intel_cpu { class TileBroadcastCommon { protected: - static VectorDims calculateDenseStrides(const VectorDims &dims); - std::vector getSupportedConfigs(const Node *node, size_t outSize); - bool prepareOptimizedParams(const Node *node, VectorDims& srcBlockedDims, VectorDims& dstBlockedDims); + static VectorDims calculateDenseStrides(const VectorDims& dims); + std::vector getSupportedConfigs(const Node* node, size_t outSize); + bool prepareOptimizedParams(const Node* node, VectorDims& srcBlockedDims, VectorDims& dstBlockedDims); void optimizedExecute(const MemoryPtr& srcMemory, const MemoryPtr& dstMemory); VectorDims repeats; bool optimizedCase = false; - bool constMap[3] = { false }; + bool constMap[3] = {false}; mutable bool needPrepareParamsVar = false; private: - static void fillOptimizedDimsAndSrcStrides(const VectorDims &srcBlockedDims, const VectorDims &blockedRepeats, - VectorDims &optimizedDims, VectorDims &optimizedSrcStrides); - static void broadcastScalar(const char *srcData, char *dstData, size_t elt_cnt, size_t data_size); + static void fillOptimizedDimsAndSrcStrides(const VectorDims& srcBlockedDims, + const VectorDims& blockedRepeats, + VectorDims& optimizedDims, + VectorDims& optimizedSrcStrides); + static void broadcastScalar(const char* srcData, char* dstData, size_t elt_cnt, size_t data_size); static bool canBeExecutedInBlockedLayout(VectorDims srcDims, VectorDims repeats, const size_t elemsInBlock); static bool canBeExecutedInNSPCLayout(VectorDims srcDims, VectorDims repeats); @@ -42,5 +43,5 @@ class TileBroadcastCommon { } optimizedParams; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/common/uni_simd.h b/src/plugins/intel_cpu/src/nodes/common/uni_simd.h index 7f2cdc7bed4821..dbcec60baa7d4c 100644 --- a/src/plugins/intel_cpu/src/nodes/common/uni_simd.h +++ b/src/plugins/intel_cpu/src/nodes/common/uni_simd.h @@ -5,7 +5,7 @@ #pragma once #if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F) -#include +# include #endif namespace ov { @@ -14,348 +14,350 @@ namespace Cpu { #if defined(HAVE_AVX512F) namespace AVX512F { - static inline __m512 _mm_uni_any_ps() { - return __m512{}; - 
} - - static inline __m512i _mm_uni_any_epi32() { - return __m512i{}; - } - - static inline __m512 _mm_uni_loadu_ps(const float* psrc) { - return _mm512_mask_loadu_ps(_mm_uni_any_ps(), (__mmask16)-1, psrc); - } - - static inline void _mm_uni_storeu_ps(float* pdst, const __m512& vec) { - _mm512_storeu_ps(pdst, vec); - } - - static inline void _mm_uni_storeu_si(void* pdst, const __m512i vec) { - _mm512_storeu_si512(pdst, vec); - } - - static inline __m512 _mm_uni_setzero_ps() { - return _mm512_setzero_ps(); - } - - static inline __m512 _mm_uni_set1_ps(float value) { - return _mm512_set1_ps(value); - } - - static inline __m512 _mm_uni_add_ps(__m512 vec0, __m512 vec1) { - return _mm512_add_ps(vec0, vec1); - } - - static inline __m512 _mm_uni_sub_ps(__m512 vec0, __m512 vec1) { - return _mm512_sub_ps(vec0, vec1); - } - - static inline __m512 _mm_uni_mul_ps(__m512 vec0, __m512 vec1) { - return _mm512_mul_ps(vec0, vec1); - } - - static inline __m512 _mm_uni_div_ps(__m512 vec0, __m512 vec1) { - return _mm512_div_ps(vec0, vec1); - } - - static inline __m512 _mm_uni_sqrt_ps(__m512 vec) { - return _mm512_sqrt_ps(vec); - } - - static inline __m512 _mm_uni_and_ps(__m512 vec0, __m512 vec1) { - return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(vec0), _mm512_castps_si512(vec1))); - } - - static inline __m512 _mm_uni_or_ps(__m512 vec0, __m512 vec1) { - return _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vec0), _mm512_castps_si512(vec1))); - } - - static inline __m512i _mm_uni_set1_epi32(int value) { - return _mm512_mask_set1_epi32(_mm_uni_any_epi32(), (__mmask16)-1, value); - } - - static inline __m512 _mm_uni_blendv_ps(__m512 vec0, __m512 vec1, __m512 vmask) { - return _mm512_mask_blend_ps(_mm512_cmpneq_epi32_mask(_mm512_castps_si512(vmask), _mm_uni_set1_epi32(0)), vec0, vec1); - } - - static inline __m512 _mm_uni_blendv_ps(__m512 vec0, __m512 vec1, __mmask16 vmask) { - return _mm512_mask_blend_ps(vmask, vec0, vec1); - } - - static inline __m512 _mm_uni_min_ps(__m512 vec0, __m512 vec1) { - return _mm512_min_ps(vec0, vec1); - } - - static inline __m512 _mm_uni_max_ps(__m512 vec0, __m512 vec1) { - return _mm512_max_ps(vec0, vec1); - } - - static inline __m512 _mm_uni_floor_ps(__m512 vec) { - return _mm512_floor_ps(vec); - } - - static inline __m512i _mm_uni_cvtps_epi32(__m512 vec) { - return _mm512_cvtps_epi32(vec); - } - - static inline __m512i _mm_uni_add_epi32(__m512i vec0, __m512i vec1) { - return _mm512_add_epi32(vec0, vec1); - } - - static inline __m512i _mm_uni_slli_epi32(__m512i vec, int value) { - return _mm512_sll_epi32(vec, _mm_set1_epi64x(value)); - } - - static inline __m512 _mm_uni_castsi_ps(__m512i vec) { - return _mm512_castsi512_ps(vec); - } - - static inline __m512i _mm_uni_setzero_si() { - return _mm512_setzero_si512(); - } - - static inline __mmask16 _mm_uni_cmpgt_ps(__m512 vec0, __m512 vec1) { - return _mm512_cmp_ps_mask(vec0, vec1, 14); - } - - static inline __mmask16 _mm_uni_cmpgt_i32(__m512i vec0, __m512i vec1) { - return _mm512_cmp_epi32_mask(vec1, vec0, 1); - } - - static inline __m512i _mm_uni_castps_si(__m512 vec) { - return _mm512_castps_si512(vec); - } - - static inline __m512 _mm_uni_cvtepi32_ps(__m512i vec) { - return _mm512_mask_cvtepi32_ps(_mm_uni_any_ps(), (__mmask16)-1, vec); - } +static inline __m512 _mm_uni_any_ps() { + return __m512{}; +} + +static inline __m512i _mm_uni_any_epi32() { + return __m512i{}; +} + +static inline __m512 _mm_uni_loadu_ps(const float* psrc) { + return _mm512_mask_loadu_ps(_mm_uni_any_ps(), (__mmask16)-1, psrc); +} + 
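// ---- editor's annotation (not part of the patch) ---------------------------
// The hunk around this point only re-indents the ISA-uniform wrappers; their
// behaviour is unchanged. The point of sharing the `_mm_uni_*` names across
// the AVX512F / AVX2 / SSE42 namespaces is that a kernel body can be written
// once against whichever namespace matches the build. A minimal sketch under
// that assumption; the function name `uni_add_arrays_avx512` is hypothetical,
// everything it calls is defined in this header (assumes <cstddef> and
// uni_simd.h are included):
static inline void uni_add_arrays_avx512(const float* a, const float* b, float* dst, size_t n) {
    using namespace ov::Cpu::AVX512F;
    constexpr size_t lanes = 16;  // one __m512 holds 16 floats
    size_t i = 0;
    for (; i + lanes <= n; i += lanes) {
        __m512 va = _mm_uni_loadu_ps(a + i);
        __m512 vb = _mm_uni_loadu_ps(b + i);
        _mm_uni_storeu_ps(dst + i, _mm_uni_add_ps(va, vb));
    }
    for (; i < n; ++i)  // scalar tail for the remainder
        dst[i] = a[i] + b[i];
}
// Swapping the namespace for AVX2 (8 lanes, __m256) or SSE42 (4 lanes,
// __m128) would yield the other builds, which is why the wrappers share
// one name.
// -----------------------------------------------------------------------------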
+static inline void _mm_uni_storeu_ps(float* pdst, const __m512& vec) { + _mm512_storeu_ps(pdst, vec); +} + +static inline void _mm_uni_storeu_si(void* pdst, const __m512i vec) { + _mm512_storeu_si512(pdst, vec); +} + +static inline __m512 _mm_uni_setzero_ps() { + return _mm512_setzero_ps(); +} + +static inline __m512 _mm_uni_set1_ps(float value) { + return _mm512_set1_ps(value); +} + +static inline __m512 _mm_uni_add_ps(__m512 vec0, __m512 vec1) { + return _mm512_add_ps(vec0, vec1); +} + +static inline __m512 _mm_uni_sub_ps(__m512 vec0, __m512 vec1) { + return _mm512_sub_ps(vec0, vec1); +} + +static inline __m512 _mm_uni_mul_ps(__m512 vec0, __m512 vec1) { + return _mm512_mul_ps(vec0, vec1); +} + +static inline __m512 _mm_uni_div_ps(__m512 vec0, __m512 vec1) { + return _mm512_div_ps(vec0, vec1); +} + +static inline __m512 _mm_uni_sqrt_ps(__m512 vec) { + return _mm512_sqrt_ps(vec); +} + +static inline __m512 _mm_uni_and_ps(__m512 vec0, __m512 vec1) { + return _mm512_castsi512_ps(_mm512_and_epi32(_mm512_castps_si512(vec0), _mm512_castps_si512(vec1))); +} + +static inline __m512 _mm_uni_or_ps(__m512 vec0, __m512 vec1) { + return _mm512_castsi512_ps(_mm512_or_epi32(_mm512_castps_si512(vec0), _mm512_castps_si512(vec1))); +} + +static inline __m512i _mm_uni_set1_epi32(int value) { + return _mm512_mask_set1_epi32(_mm_uni_any_epi32(), (__mmask16)-1, value); +} + +static inline __m512 _mm_uni_blendv_ps(__m512 vec0, __m512 vec1, __m512 vmask) { + return _mm512_mask_blend_ps(_mm512_cmpneq_epi32_mask(_mm512_castps_si512(vmask), _mm_uni_set1_epi32(0)), + vec0, + vec1); +} + +static inline __m512 _mm_uni_blendv_ps(__m512 vec0, __m512 vec1, __mmask16 vmask) { + return _mm512_mask_blend_ps(vmask, vec0, vec1); +} + +static inline __m512 _mm_uni_min_ps(__m512 vec0, __m512 vec1) { + return _mm512_min_ps(vec0, vec1); +} + +static inline __m512 _mm_uni_max_ps(__m512 vec0, __m512 vec1) { + return _mm512_max_ps(vec0, vec1); +} + +static inline __m512 _mm_uni_floor_ps(__m512 vec) { + return _mm512_floor_ps(vec); +} + +static inline __m512i _mm_uni_cvtps_epi32(__m512 vec) { + return _mm512_cvtps_epi32(vec); +} + +static inline __m512i _mm_uni_add_epi32(__m512i vec0, __m512i vec1) { + return _mm512_add_epi32(vec0, vec1); +} + +static inline __m512i _mm_uni_slli_epi32(__m512i vec, int value) { + return _mm512_sll_epi32(vec, _mm_set1_epi64x(value)); +} + +static inline __m512 _mm_uni_castsi_ps(__m512i vec) { + return _mm512_castsi512_ps(vec); +} + +static inline __m512i _mm_uni_setzero_si() { + return _mm512_setzero_si512(); +} + +static inline __mmask16 _mm_uni_cmpgt_ps(__m512 vec0, __m512 vec1) { + return _mm512_cmp_ps_mask(vec0, vec1, 14); +} + +static inline __mmask16 _mm_uni_cmpgt_i32(__m512i vec0, __m512i vec1) { + return _mm512_cmp_epi32_mask(vec1, vec0, 1); +} + +static inline __m512i _mm_uni_castps_si(__m512 vec) { + return _mm512_castps_si512(vec); +} + +static inline __m512 _mm_uni_cvtepi32_ps(__m512i vec) { + return _mm512_mask_cvtepi32_ps(_mm_uni_any_ps(), (__mmask16)-1, vec); +} } // namespace AVX512F #elif defined(HAVE_AVX2) namespace AVX2 { - static inline __m256 _mm_uni_loadu_ps(const float* psrc) { - return _mm256_loadu_ps(psrc); - } +static inline __m256 _mm_uni_loadu_ps(const float* psrc) { + return _mm256_loadu_ps(psrc); +} - static inline void _mm_uni_storeu_ps(float* pdst, const __m256 vec) { - _mm256_storeu_ps(pdst, vec); - } +static inline void _mm_uni_storeu_ps(float* pdst, const __m256 vec) { + _mm256_storeu_ps(pdst, vec); +} - static inline void _mm_uni_storeu_si(__m256i* pdst, const 
__m256i vec) { - _mm256_storeu_si256(pdst, vec); - } +static inline void _mm_uni_storeu_si(__m256i* pdst, const __m256i vec) { + _mm256_storeu_si256(pdst, vec); +} - static inline __m256 _mm_uni_setzero_ps() { - return _mm256_setzero_ps(); - } +static inline __m256 _mm_uni_setzero_ps() { + return _mm256_setzero_ps(); +} - static inline __m256 _mm_uni_set1_ps(float value) { - return _mm256_set1_ps(value); - } +static inline __m256 _mm_uni_set1_ps(float value) { + return _mm256_set1_ps(value); +} - static inline __m256 _mm_uni_add_ps(__m256 vec0, __m256 vec1) { - return _mm256_add_ps(vec0, vec1); - } +static inline __m256 _mm_uni_add_ps(__m256 vec0, __m256 vec1) { + return _mm256_add_ps(vec0, vec1); +} - static inline __m256 _mm_uni_sub_ps(__m256 vec0, __m256 vec1) { - return _mm256_sub_ps(vec0, vec1); - } +static inline __m256 _mm_uni_sub_ps(__m256 vec0, __m256 vec1) { + return _mm256_sub_ps(vec0, vec1); +} - static inline __m256 _mm_uni_mul_ps(__m256 vec0, __m256 vec1) { - return _mm256_mul_ps(vec0, vec1); - } +static inline __m256 _mm_uni_mul_ps(__m256 vec0, __m256 vec1) { + return _mm256_mul_ps(vec0, vec1); +} - static inline __m256 _mm_uni_div_ps(__m256 vec0, __m256 vec1) { - return _mm256_div_ps(vec0, vec1); - } +static inline __m256 _mm_uni_div_ps(__m256 vec0, __m256 vec1) { + return _mm256_div_ps(vec0, vec1); +} - static inline __m256 _mm_uni_sqrt_ps(__m256 vec) { - return _mm256_sqrt_ps(vec); - } +static inline __m256 _mm_uni_sqrt_ps(__m256 vec) { + return _mm256_sqrt_ps(vec); +} - static inline __m256 _mm_uni_and_ps(__m256 vec0, __m256 vec1) { - return _mm256_and_ps(vec0, vec1); - } +static inline __m256 _mm_uni_and_ps(__m256 vec0, __m256 vec1) { + return _mm256_and_ps(vec0, vec1); +} - static inline __m256 _mm_uni_or_ps(__m256 vec0, __m256 vec1) { - return _mm256_or_ps(vec0, vec1); - } +static inline __m256 _mm_uni_or_ps(__m256 vec0, __m256 vec1) { + return _mm256_or_ps(vec0, vec1); +} - static inline __m256 _mm_uni_blendv_ps(__m256 vec0, __m256 vec1, __m256 vmask) { - return _mm256_blendv_ps(vec0, vec1, vmask); - } +static inline __m256 _mm_uni_blendv_ps(__m256 vec0, __m256 vec1, __m256 vmask) { + return _mm256_blendv_ps(vec0, vec1, vmask); +} - static inline __m256 _mm_uni_min_ps(__m256 vec0, __m256 vec1) { - return _mm256_min_ps(vec0, vec1); - } +static inline __m256 _mm_uni_min_ps(__m256 vec0, __m256 vec1) { + return _mm256_min_ps(vec0, vec1); +} - static inline __m256 _mm_uni_max_ps(__m256 vec0, __m256 vec1) { - return _mm256_max_ps(vec0, vec1); - } +static inline __m256 _mm_uni_max_ps(__m256 vec0, __m256 vec1) { + return _mm256_max_ps(vec0, vec1); +} - static inline __m256 _mm_uni_floor_ps(__m256 vec) { - return _mm256_floor_ps(vec); - } +static inline __m256 _mm_uni_floor_ps(__m256 vec) { + return _mm256_floor_ps(vec); +} - static inline __m256i _mm_uni_cvtps_epi32(__m256 vec) { - return _mm256_cvtps_epi32(vec); - } +static inline __m256i _mm_uni_cvtps_epi32(__m256 vec) { + return _mm256_cvtps_epi32(vec); +} - static inline __m256i _mm_uni_add_epi32(__m256i vec0, __m256i vec1) { - return _mm256_add_epi32(vec0, vec1); - } +static inline __m256i _mm_uni_add_epi32(__m256i vec0, __m256i vec1) { + return _mm256_add_epi32(vec0, vec1); +} - static inline __m256i _mm_uni_set1_epi32(int value) { - return _mm256_set1_epi32(value); - } +static inline __m256i _mm_uni_set1_epi32(int value) { + return _mm256_set1_epi32(value); +} - static inline __m256i _mm_uni_slli_epi32(__m256i vec, int value) { - return _mm256_slli_epi32(vec, value); - } +static inline __m256i 
_mm_uni_slli_epi32(__m256i vec, int value) { + return _mm256_slli_epi32(vec, value); +} - static inline __m256 _mm_uni_castsi_ps(__m256i vec) { - return _mm256_castsi256_ps(vec); - } +static inline __m256 _mm_uni_castsi_ps(__m256i vec) { + return _mm256_castsi256_ps(vec); +} - static inline __m256i _mm_uni_setzero_si() { - return _mm256_setzero_si256(); - } +static inline __m256i _mm_uni_setzero_si() { + return _mm256_setzero_si256(); +} - static inline __m256 _mm_uni_cmpgt_ps(__m256 vec0, __m256 vec1) { - return _mm256_cmp_ps(vec0, vec1, 14); - } +static inline __m256 _mm_uni_cmpgt_ps(__m256 vec0, __m256 vec1) { + return _mm256_cmp_ps(vec0, vec1, 14); +} - static inline __m256 _mm_uni_cmpgt_i32(__m256i vec0, __m256i vec1) { - return _mm256_cvtepi32_ps(_mm256_cmpgt_epi32(vec0, vec1)); - } +static inline __m256 _mm_uni_cmpgt_i32(__m256i vec0, __m256i vec1) { + return _mm256_cvtepi32_ps(_mm256_cmpgt_epi32(vec0, vec1)); +} - static inline __m256i _mm_uni_blendv_epi8(__m256i vec0, __m256i vec1, __m256i vmask) { - return _mm256_blendv_epi8(vec0, vec1, vmask); - } +static inline __m256i _mm_uni_blendv_epi8(__m256i vec0, __m256i vec1, __m256i vmask) { + return _mm256_blendv_epi8(vec0, vec1, vmask); +} - static inline __m256i _mm_uni_castps_si(__m256 vec) { - return _mm256_castps_si256(vec); - } +static inline __m256i _mm_uni_castps_si(__m256 vec) { + return _mm256_castps_si256(vec); +} - static inline __m256 _mm_uni_cvtepi32_ps(__m256i vec) { - return _mm256_cvtepi32_ps(vec); - } +static inline __m256 _mm_uni_cvtepi32_ps(__m256i vec) { + return _mm256_cvtepi32_ps(vec); +} - static inline int _mm_uni_movemask_ps(__m256 vec) { - return _mm256_movemask_ps(vec); - } +static inline int _mm_uni_movemask_ps(__m256 vec) { + return _mm256_movemask_ps(vec); +} } // namespace AVX2 #elif defined(HAVE_SSE42) namespace SSE42 { - static inline __m128 _mm_uni_loadu_ps(const float* psrc) { - return _mm_loadu_ps(psrc); - } +static inline __m128 _mm_uni_loadu_ps(const float* psrc) { + return _mm_loadu_ps(psrc); +} - static inline void _mm_uni_storeu_ps(float* pdst, const __m128 vec) { - _mm_storeu_ps(pdst, vec); - } - - static inline void _mm_uni_storeu_si(__m128i* pdst, const __m128i vec) { - _mm_storeu_si128(pdst, vec); - } - - static inline __m128 _mm_uni_setzero_ps() { - return _mm_setzero_ps(); - } - - static inline __m128 _mm_uni_set1_ps(float value) { - return _mm_set1_ps(value); - } - - static inline __m128 _mm_uni_add_ps(__m128 vec0, __m128 vec1) { - return _mm_add_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_sub_ps(__m128 vec0, __m128 vec1) { - return _mm_sub_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_mul_ps(__m128 vec0, __m128 vec1) { - return _mm_mul_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_div_ps(__m128 vec0, __m128 vec1) { - return _mm_div_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_sqrt_ps(__m128 vec) { - return _mm_sqrt_ps(vec); - } - - static inline __m128 _mm_uni_and_ps(__m128 vec0, __m128 vec1) { - return _mm_and_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_or_ps(__m128 vec0, __m128 vec1) { - return _mm_or_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_blendv_ps(__m128 vec0, __m128 vec1, __m128 vmask) { - return _mm_blendv_ps(vec0, vec1, vmask); - } - - static inline __m128 _mm_uni_min_ps(__m128 vec0, __m128 vec1) { - return _mm_min_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_max_ps(__m128 vec0, __m128 vec1) { - return _mm_max_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_floor_ps(__m128 vec) { - return _mm_floor_ps(vec); - } - 
- static inline __m128i _mm_uni_cvtps_epi32(__m128 vec) { - return _mm_cvtps_epi32(vec); - } - - static inline __m128i _mm_uni_add_epi32(__m128i vec0, __m128i vec1) { - return _mm_add_epi32(vec0, vec1); - } - - static inline __m128i _mm_uni_set1_epi32(int value) { - return _mm_set1_epi32(value); - } - - static inline __m128i _mm_uni_slli_epi32(__m128i vec, int value) { - return _mm_slli_epi32(vec, value); - } - - static inline __m128 _mm_uni_castsi_ps(__m128i vec) { - return _mm_castsi128_ps(vec); - } - - static inline __m128i _mm_uni_setzero_si() { - return _mm_setzero_si128(); - } - - static inline __m128 _mm_uni_cmpgt_ps(__m128 vec0, __m128 vec1) { - return _mm_cmpgt_ps(vec0, vec1); - } - - static inline __m128 _mm_uni_cmpgt_i32(__m128i vec0, __m128i vec1) { - return _mm_cvtepi32_ps(_mm_cmpgt_epi32(vec0, vec1)); - } - - static inline __m128i _mm_uni_blendv_epi8(__m128i vec0, __m128i vec1, __m128i vmask) { - return _mm_blendv_epi8(vec0, vec1, vmask); - } - - static inline __m128i _mm_uni_castps_si(__m128 vec) { - return _mm_castps_si128(vec); - } +static inline void _mm_uni_storeu_ps(float* pdst, const __m128 vec) { + _mm_storeu_ps(pdst, vec); +} + +static inline void _mm_uni_storeu_si(__m128i* pdst, const __m128i vec) { + _mm_storeu_si128(pdst, vec); +} + +static inline __m128 _mm_uni_setzero_ps() { + return _mm_setzero_ps(); +} + +static inline __m128 _mm_uni_set1_ps(float value) { + return _mm_set1_ps(value); +} + +static inline __m128 _mm_uni_add_ps(__m128 vec0, __m128 vec1) { + return _mm_add_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_sub_ps(__m128 vec0, __m128 vec1) { + return _mm_sub_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_mul_ps(__m128 vec0, __m128 vec1) { + return _mm_mul_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_div_ps(__m128 vec0, __m128 vec1) { + return _mm_div_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_sqrt_ps(__m128 vec) { + return _mm_sqrt_ps(vec); +} + +static inline __m128 _mm_uni_and_ps(__m128 vec0, __m128 vec1) { + return _mm_and_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_or_ps(__m128 vec0, __m128 vec1) { + return _mm_or_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_blendv_ps(__m128 vec0, __m128 vec1, __m128 vmask) { + return _mm_blendv_ps(vec0, vec1, vmask); +} + +static inline __m128 _mm_uni_min_ps(__m128 vec0, __m128 vec1) { + return _mm_min_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_max_ps(__m128 vec0, __m128 vec1) { + return _mm_max_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_floor_ps(__m128 vec) { + return _mm_floor_ps(vec); +} + +static inline __m128i _mm_uni_cvtps_epi32(__m128 vec) { + return _mm_cvtps_epi32(vec); +} + +static inline __m128i _mm_uni_add_epi32(__m128i vec0, __m128i vec1) { + return _mm_add_epi32(vec0, vec1); +} + +static inline __m128i _mm_uni_set1_epi32(int value) { + return _mm_set1_epi32(value); +} + +static inline __m128i _mm_uni_slli_epi32(__m128i vec, int value) { + return _mm_slli_epi32(vec, value); +} + +static inline __m128 _mm_uni_castsi_ps(__m128i vec) { + return _mm_castsi128_ps(vec); +} + +static inline __m128i _mm_uni_setzero_si() { + return _mm_setzero_si128(); +} + +static inline __m128 _mm_uni_cmpgt_ps(__m128 vec0, __m128 vec1) { + return _mm_cmpgt_ps(vec0, vec1); +} + +static inline __m128 _mm_uni_cmpgt_i32(__m128i vec0, __m128i vec1) { + return _mm_cvtepi32_ps(_mm_cmpgt_epi32(vec0, vec1)); +} + +static inline __m128i _mm_uni_blendv_epi8(__m128i vec0, __m128i vec1, __m128i vmask) { + return _mm_blendv_epi8(vec0, vec1, vmask); +} + +static inline __m128i 
_mm_uni_castps_si(__m128 vec) { + return _mm_castps_si128(vec); +} - static inline __m128 _mm_uni_cvtepi32_ps(__m128i vec) { - return _mm_cvtepi32_ps(vec); - } - static inline int _mm_uni_movemask_ps(__m128 vec) { - return _mm_movemask_ps(vec); - } +static inline __m128 _mm_uni_cvtepi32_ps(__m128i vec) { + return _mm_cvtepi32_ps(vec); +} +static inline int _mm_uni_movemask_ps(__m128 vec) { + return _mm_movemask_ps(vec); +} } // namespace SSE42 #endif diff --git a/src/plugins/intel_cpu/src/nodes/composite.cpp b/src/plugins/intel_cpu/src/nodes/composite.cpp index a1ceabd6942db1..616d3df6950e9a 100644 --- a/src/plugins/intel_cpu/src/nodes/composite.cpp +++ b/src/plugins/intel_cpu/src/nodes/composite.cpp @@ -4,11 +4,11 @@ #include "composite.h" -#include "nodes/input.h" #include "cpu_memory.h" +#include "nodes/input.h" +#include "shape_inference/shape_inference_internal_dyn.hpp" #include "transformations/cpu_opset/common/op/submodel.hpp" #include "utils/debug_capabilities.h" -#include "shape_inference/shape_inference_internal_dyn.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/nodes/concat.cpp b/src/plugins/intel_cpu/src/nodes/concat.cpp index 635f37b2d05b3a..ef621947d723a7 100644 --- a/src/plugins/intel_cpu/src/nodes/concat.cpp +++ b/src/plugins/intel_cpu/src/nodes/concat.cpp @@ -4,29 +4,29 @@ #include "concat.h" -#include "openvino/op/concat.hpp" +#include +#include +#include +#include +#include #include #include #include -#include "dnnl_extension_utils.h" +#include "common/blocked_desc_creator.h" +#include "common/cpu_memcpy.h" +#include "dnnl_extension_utils.h" #include "onednn/dnnl.h" -#include -#include -#include #include "openvino/core/parallel.hpp" -#include "common/cpu_memcpy.h" -#include "common/blocked_desc_creator.h" -#include -#include +#include "openvino/op/concat.hpp" using namespace dnnl; namespace ov { namespace intel_cpu { namespace node { namespace { - constexpr size_t channelAxis = 1lu; +constexpr size_t channelAxis = 1lu; } bool Concat::isExecutable() const { @@ -86,11 +86,14 @@ void Concat::getSupportedDescriptors() { } } - // we need the first dims before axis to be 1 to avoid the reorder in the edge between the first parent and this concat + // we need the first dims before axis to be 1 to avoid the reorder in the edge between the first parent and this + // concat const auto& childDims = outputShapes[0].getDims(); if (childDims[axis] != Shape::UNDEFINED_DIM && - std::all_of(childDims.begin(), childDims.begin() + axis, [](size_t dim) { return dim == 1; })) + std::all_of(childDims.begin(), childDims.begin() + axis, [](size_t dim) { + return dim == 1; + })) canBeInPlace = true; } @@ -118,11 +121,11 @@ void Concat::initSupportedPrimitiveDescriptors() { const auto& dstShape = getOutputShapeAtPort(0); std::vector tdCreatorTypes = {LayoutType::ncsp, LayoutType::nspc}; - // check if blocked layouts are available the channels size should be evenly divided by the block size to avoid slow oneDNN ref implementation and allow - // inPlace memory usage if possible + // check if blocked layouts are available the channels size should be evenly divided by the block size to avoid slow + // oneDNN ref implementation and allow inPlace memory usage if possible if (dstShape.getRank() > channelAxis) { - for (auto& item : { std::make_pair(8lu, LayoutType::nCsp8c), std::make_pair(16lu, LayoutType::nCsp16c)}) { - const VectorDims &blkDims = dstShape.getDims(); + for (auto& item : {std::make_pair(8lu, LayoutType::nCsp8c), std::make_pair(16lu, 
LayoutType::nCsp16c)}) { + const VectorDims& blkDims = dstShape.getDims(); if (blkDims[channelAxis] == Shape::UNDEFINED_DIM || blkDims[channelAxis] % item.first != 0) continue; @@ -144,7 +147,8 @@ void Concat::initSupportedPrimitiveDescriptors() { auto& creatorsMap = BlockedDescCreator::getCommonCreators(); - auto itrRange = BlockedDescCreator::makeFilteredRange(creatorsMap, static_cast(dstShape.getRank()), tdCreatorTypes); + auto itrRange = + BlockedDescCreator::makeFilteredRange(creatorsMap, static_cast(dstShape.getRank()), tdCreatorTypes); for (auto itr = itrRange.first; itr != itrRange.second; ++itr) { NodeConfig config; @@ -183,12 +187,15 @@ void Concat::initSupportedPrimitiveDescriptors() { } } - if (!canBeInPlace || std::any_of(inputShapes.begin(), inputShapes.end(), [](const Shape& shape) { return shape.hasZeroDims(); })) + if (!canBeInPlace || std::any_of(inputShapes.begin(), inputShapes.end(), [](const Shape& shape) { + return shape.hasZeroDims(); + })) return; // Optimized inplace case for (auto refPdIndex : pdIndexesToReuse) { - auto config = supportedPrimitiveDescriptors[refPdIndex].getConfig();; + auto config = supportedPrimitiveDescriptors[refPdIndex].getConfig(); + ; for (size_t i = 0; i < config.inConfs.size(); i++) { config.inConfs[i].inPlace(0); } @@ -204,12 +211,16 @@ void Concat::selectOptimalPrimitiveDescriptor() { // for that case. for (size_t i = 0; i < getParentEdges().size(); i++) { for (size_t j = i + 1; j < getParentEdges().size(); j++) { - if (getParentEdgeAt(i) == getParentEdgeAt(j)) canBeInPlace = false; + if (getParentEdgeAt(i) == getParentEdgeAt(j)) + canBeInPlace = false; } } std::map formatFrequency; - std::vector supportedLayouts = {LayoutType::ncsp, LayoutType::nspc, LayoutType::nCsp8c, LayoutType::nCsp16c}; + std::vector supportedLayouts = {LayoutType::ncsp, + LayoutType::nspc, + LayoutType::nCsp8c, + LayoutType::nCsp16c}; for (size_t i = 0; i < getParentEdges().size(); i++) { auto parentEdge = getParentEdgeAt(i); auto parent = parentEdge->getParent(); @@ -218,11 +229,11 @@ void Concat::selectOptimalPrimitiveDescriptor() { if (parent_pdesc == nullptr) continue; - const auto &parent_config = parent_pdesc->getConfig(); + const auto& parent_config = parent_pdesc->getConfig(); int outputIndex = parentEdge->getInputNum(); if (outputIndex < 0 || outputIndex >= static_cast(parent_config.outConfs.size())) OPENVINO_THROW("Cannot find index of output node"); - const auto &port_desc = parent_config.outConfs[outputIndex].getMemDesc(); + const auto& port_desc = parent_config.outConfs[outputIndex].getMemDesc(); for (auto& item : supportedLayouts) { if (port_desc->hasLayoutType(item)) { formatFrequency[item] += 1; @@ -232,15 +243,15 @@ void Concat::selectOptimalPrimitiveDescriptor() { for (size_t i = 0; i < getChildEdges().size(); i++) { auto childEdge = getChildEdgeAt(i); auto child = childEdge->getChild(); - const auto *prim_desc = child->getSelectedPrimitiveDescriptor(); + const auto* prim_desc = child->getSelectedPrimitiveDescriptor(); if (prim_desc == nullptr) continue; - const auto &config = prim_desc->getConfig(); + const auto& config = prim_desc->getConfig(); int inputIndex = childEdge->getOutputNum(); if (inputIndex < 0 || inputIndex >= static_cast(config.inConfs.size())) OPENVINO_THROW("Cannot find index of output node"); - const auto &port_desc = config.inConfs[inputIndex].getMemDesc(); + const auto& port_desc = config.inConfs[inputIndex].getMemDesc(); for (auto& item : supportedLayouts) { if (port_desc->hasLayoutType(item)) { formatFrequency[item] += 1; 
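// ---- editor's annotation (not part of the patch) ---------------------------
// The two counting loops above and the hunk below implement a simple majority
// vote: every neighbouring edge whose selected descriptor uses one of the
// supported layouts contributes a vote, and the layout with the most votes
// becomes the conversion target, falling back to plain ncsp. A self-contained
// sketch of the same selection logic; the enum and the helper name are
// hypothetical, introduced only for illustration:

#include <cstddef>
#include <map>
#include <vector>

enum class Layout { ncsp, nspc, nCsp8c, nCsp16c };

// Returns the layout preferred by the largest number of neighbours,
// defaulting to ncsp when there are no votes (mirrors `convertTo` below).
inline Layout pickMostFrequentLayout(const std::vector<Layout>& votes) {
    std::map<Layout, std::size_t> freq;
    for (Layout l : votes)
        freq[l] += 1;
    Layout best = Layout::ncsp;
    std::size_t maxCount = 0;
    for (const auto& entry : freq) {
        if (entry.second > maxCount) {
            maxCount = entry.second;
            best = entry.first;
        }
    }
    return best;
}
// -----------------------------------------------------------------------------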
@@ -249,9 +260,9 @@ void Concat::selectOptimalPrimitiveDescriptor() { } size_t maxCount = 0; - const auto &outDims = getOutputShapeAtPort(0).getDims(); + const auto& outDims = getOutputShapeAtPort(0).getDims(); LayoutType convertTo = LayoutType::ncsp; - for (auto &it : formatFrequency) { + for (auto& it : formatFrequency) { if (it.second > maxCount) { maxCount = it.second; convertTo = it.first; @@ -264,7 +275,7 @@ void Concat::selectOptimalPrimitiveDescriptor() { } } - for (auto& item : { std::make_pair(8lu, LayoutType::nCsp8c), std::make_pair(16lu, LayoutType::nCsp16c) }) { + for (auto& item : {std::make_pair(8lu, LayoutType::nCsp8c), std::make_pair(16lu, LayoutType::nCsp16c)}) { if (convertTo == item.second) { if (outDims[channelAxis] == Shape::UNDEFINED_DIM || outDims[1] % item.first != 0) { convertTo = LayoutType::ncsp; @@ -282,7 +293,8 @@ void Concat::selectOptimalPrimitiveDescriptor() { for (size_t i = 0; i < supportedPrimitiveDescriptors.size(); ++i) { if (supportedPrimitiveDescriptors[i].getConfig().outConfs[0].getMemDesc()->hasLayoutType(convertTo)) { - if (IMPLICATION(supportedPrimitiveDescriptors[i].getImplementationType() == impl_desc_type::unknown, canBeInPlace)) { + if (IMPLICATION(supportedPrimitiveDescriptors[i].getImplementationType() == impl_desc_type::unknown, + canBeInPlace)) { canSelectPrimitive.push_back(i); } } @@ -444,24 +456,26 @@ void Concat::initOptimalPrimitiveDescriptor() { if (selected_pd == nullptr) OPENVINO_THROW("Preferable primitive descriptor is not set."); - if (!isInPlace()) { - Node::initOptimalPrimitiveDescriptor(); + if (!isInPlace()) { + Node::initOptimalPrimitiveDescriptor(); auto config = selected_pd->getConfig(); if (!isConfigDefined(config)) { for (size_t i = 0; i < config.inConfs.size(); i++) { // Concat doesn't support different precision on inputs - config.inConfs[i].setMemDesc(getConsistentInputDesc(config, i)->getMemDesc()->cloneWithNewPrecision(inputPrecision)); + config.inConfs[i].setMemDesc( + getConsistentInputDesc(config, i)->getMemDesc()->cloneWithNewPrecision(inputPrecision)); } for (size_t i = 0; i < config.outConfs.size(); i++) { - config.outConfs[i].setMemDesc(getConsistentOutputDesc(config, i)->getMemDesc()->cloneWithNewPrecision(outputPrecision)); + config.outConfs[i].setMemDesc( + getConsistentOutputDesc(config, i)->getMemDesc()->cloneWithNewPrecision(outputPrecision)); } initDescriptor(config); } } - //block layout may have axis greater than rank, disable ref_concat + // block layout may have axis greater than rank, disable ref_concat auto primDesc = getSelectedPrimitiveDescriptor(); auto memDesc = primDesc->getConfig().outConfs[0].getMemDesc()->as(); auto rank = memDesc->getShape().getRank(); @@ -474,7 +488,9 @@ void Concat::initOptimalPrimitiveDescriptor() { srcPtrs.resize(getParentEdges().size()); } // check if selected Tensor descriptor has nspc layout and concat axis is C - canOptimizeNspc = axis == channelAxis && getSelectedPrimitiveDescriptor()->getConfig().outConfs.front().getMemDesc()->hasLayoutType(LayoutType::nspc); + canOptimizeNspc = + axis == channelAxis && + getSelectedPrimitiveDescriptor()->getConfig().outConfs.front().getMemDesc()->hasLayoutType(LayoutType::nspc); } void Concat::execute(dnnl::stream strm) { @@ -497,7 +513,7 @@ void Concat::execute(dnnl::stream strm) { } else { const auto& dst_memory = getChildEdgeAt(0)->getMemory(); const size_t num_src = getParentEdges().size(); - std::unordered_map mem_ags {{DNNL_ARG_DST, dst_memory.getPrimitive()}}; + std::unordered_map mem_ags{{DNNL_ARG_DST, 
dst_memory.getPrimitive()}}; size_t nonZeroInShapes = 0; for (size_t i = 0; i < num_src; i++) { const auto& srcMem = getParentEdgeAt(i)->getMemory(); @@ -580,7 +596,7 @@ void Concat::execRef() { } if (!hasOuterLoop) { - if (nelemTotal < 64*1024 || parallel_get_max_threads() == 1) { + if (nelemTotal < 64 * 1024 || parallel_get_max_threads() == 1) { for (size_t a = 0; a < srcPtrs.size(); ++a) { const auto inData = srcPtrs[a]; auto outputData = &dstPtr[dstOffset[a]]; @@ -612,63 +628,65 @@ void Concat::execRef() { physDims[i] = outputShape[i]; } const auto L1Size = dnnl::utils::get_cache_size(1, true); - UNUSED(L1Size); // for Windows - parallel_for6d(physDims[0], physDims[1], physDims[2], physDims[3], physDims[4], numSrc, - [&](size_t n0, size_t n1, size_t n2, size_t n3, size_t n4, size_t a) { - // check if zero memory - if (srcPtrs[a] == nullptr) return; - - size_t inOff = inputStrides[a][0] * n0 + inputStrides[a][1] * n1 + inputStrides[a][2] * n2 - + inputStrides[a][3] * n3 + inputStrides[a][4] * n4; - size_t outOff = outputStrides[0] * n0 + outputStrides[1] * n1 + outputStrides[2] * n2 - + outputStrides[3] * n3 + outputStrides[4] * n4; - const uint8_t *i = &srcPtrs[a][inOff]; - uint8_t *o = &dstPtr[dstOffset[a] + outOff]; + UNUSED(L1Size); // for Windows + parallel_for6d(physDims[0], + physDims[1], + physDims[2], + physDims[3], + physDims[4], + numSrc, + [&](size_t n0, size_t n1, size_t n2, size_t n3, size_t n4, size_t a) { + // check if zero memory + if (srcPtrs[a] == nullptr) + return; + + size_t inOff = inputStrides[a][0] * n0 + inputStrides[a][1] * n1 + inputStrides[a][2] * n2 + + inputStrides[a][3] * n3 + inputStrides[a][4] * n4; + size_t outOff = outputStrides[0] * n0 + outputStrides[1] * n1 + outputStrides[2] * n2 + + outputStrides[3] * n3 + outputStrides[4] * n4; + const uint8_t* i = &srcPtrs[a][inOff]; + uint8_t* o = &dstPtr[dstOffset[a] + outOff]; #if defined(__GNUC__) - // Heuristic: - // memcpy works generally faster for data sizes not - // exceeding L1 cache. - if (nelemToCopy[a] > L1Size) { - // The code below performs data copying: o[e] = i[e] - // and uses a workaround to make GNU compilers optimize it - uint8_t *ptro = o; - const uint8_t *ptri = i; - // head part: bytes before 4 byte-align's address - const size_t headPart = sizeof(uint32_t) - - reinterpret_cast(ptro) - % sizeof(uint32_t); - - // main part: bytes in 4 byte-align - const size_t mainPart - = (nelemToCopy[a] - headPart) / sizeof(uint32_t); - // tail part: bytes after 4 byte-align - const size_t tailPart - = (nelemToCopy[a]) - headPart - - (mainPart * sizeof(uint32_t)); - // copy head part - for (size_t e = 0; e < headPart; ++e) { - *ptro = *ptri; - ++ptro; - ++ptri; - } - // copy main part - std::memcpy(ptro, ptri, mainPart * sizeof(uint32_t)); - ptro += mainPart * sizeof(uint32_t); - ptri += mainPart * sizeof(uint32_t); - // copy tail part - for (size_t e = 0; e < tailPart; ++e) { - *ptro = *ptri; - ++ptro; - ++ptri; - } - } else { - std::memcpy(o, i, nelemToCopy[a]); - } + // Heuristic: + // memcpy works generally faster for data sizes not + // exceeding L1 cache. 
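// Illustrative sketch of the head/main/tail copy pattern that the block below
// implements, written as a hypothetical standalone helper (alignedCopy is not a
// function from this patch); assumes <cstring>, <cstdint> and <algorithm>:
static inline void alignedCopy(uint8_t* dst, const uint8_t* src, size_t size) {
    // head: copy byte-by-byte until dst reaches a 4-byte boundary
    size_t head = (sizeof(uint32_t) - reinterpret_cast<uintptr_t>(dst) % sizeof(uint32_t)) % sizeof(uint32_t);
    head = std::min(head, size);
    for (size_t e = 0; e < head; ++e)
        *dst++ = *src++;
    // main: copy whole 4-byte words via memcpy, which GNU compilers optimize well
    const size_t mainBytes = ((size - head) / sizeof(uint32_t)) * sizeof(uint32_t);
    std::memcpy(dst, src, mainBytes);
    dst += mainBytes;
    src += mainBytes;
    // tail: copy the remaining bytes after the last aligned word
    for (size_t e = 0; e < size - head - mainBytes; ++e)
        *dst++ = *src++;
}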
+ if (nelemToCopy[a] > L1Size) { + // The code below performs data copying: o[e] = i[e] + // and uses a workaround to make GNU compilers optimize it + uint8_t* ptro = o; + const uint8_t* ptri = i; + // head part: bytes before 4 byte-align's address + const size_t headPart = + sizeof(uint32_t) - reinterpret_cast(ptro) % sizeof(uint32_t); + + // main part: bytes in 4 byte-align + const size_t mainPart = (nelemToCopy[a] - headPart) / sizeof(uint32_t); + // tail part: bytes after 4 byte-align + const size_t tailPart = (nelemToCopy[a]) - headPart - (mainPart * sizeof(uint32_t)); + // copy head part + for (size_t e = 0; e < headPart; ++e) { + *ptro = *ptri; + ++ptro; + ++ptri; + } + // copy main part + std::memcpy(ptro, ptri, mainPart * sizeof(uint32_t)); + ptro += mainPart * sizeof(uint32_t); + ptri += mainPart * sizeof(uint32_t); + // copy tail part + for (size_t e = 0; e < tailPart; ++e) { + *ptro = *ptri; + ++ptro; + ++ptri; + } + } else { + std::memcpy(o, i, nelemToCopy[a]); + } #else std::memcpy(o, i, nelemToCopy[a]); #endif - }); + }); } } @@ -691,8 +709,10 @@ void Concat::resolveInPlaceEdges(Edge::LOOK look) { " can't use inPlace memory with concatenation on dynamic dimension"); auto edges = getChildEdgesAtPort(inplaceOutIndx); - auto itr = std::find_if(edges.begin(), edges.end(), [](const EdgePtr& edge) { return edge->getStatus() == Edge::Status::Allocated; }); - OPENVINO_ASSERT(itr != edges.end(), " Could not find allocated child edge for concat node: " , getName()); + auto itr = std::find_if(edges.begin(), edges.end(), [](const EdgePtr& edge) { + return edge->getStatus() == Edge::Status::Allocated; + }); + OPENVINO_ASSERT(itr != edges.end(), " Could not find allocated child edge for concat node: ", getName()); auto baseMemBlock = (*itr)->getMemory().getMemoryBlock(); OPENVINO_ASSERT(baseMemBlock != nullptr, " NULL base memory block in concat node: ", getName()); @@ -726,6 +746,6 @@ void Concat::resolveInPlaceEdges(Edge::LOOK look) { } } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/concat.h b/src/plugins/intel_cpu/src/nodes/concat.h index 9ed331bee4f16d..8b75e3839a372d 100644 --- a/src/plugins/intel_cpu/src/nodes/concat.h +++ b/src/plugins/intel_cpu/src/nodes/concat.h @@ -4,8 +4,8 @@ #pragma once -#include "node.h" #include "graph_context.h" +#include "node.h" namespace ov { namespace intel_cpu { @@ -22,7 +22,9 @@ class Concat : public Node { void selectOptimalPrimitiveDescriptor() override; bool created() const override; void execute(dnnl::stream strm) override; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); } + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } void resolveInPlaceEdges(Edge::LOOK look) override; ov::element::Type getRuntimePrecision() const override; @@ -42,9 +44,9 @@ class Concat : public Node { void execNspcSpecCase(); void exec1DCase(); std::vector inputStrides; - std::vector nelemToCopy; // byte moved in each iter + std::vector nelemToCopy; // byte moved in each iter size_t nelemTotal = 0; - std::vector dstOffset; // dst offset for each input + std::vector dstOffset; // dst offset for each input std::vector srcPtrs; bool hasOuterLoop = false; ov::element::Type inputPrecision = ov::element::f32; @@ -54,6 +56,6 @@ class Concat : public Node { dnnl::primitive prim; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // 
namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/conv.cpp b/src/plugins/intel_cpu/src/nodes/conv.cpp index 53d53d093cfabf..4cb2dc9058551f 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.cpp +++ b/src/plugins/intel_cpu/src/nodes/conv.cpp @@ -4,8 +4,11 @@ #include "conv.h" -#include "openvino/op/convolution.hpp" -#include "openvino/op/group_conv.hpp" +#include +#include +#include +#include + #include "common/c_types_map.hpp" #include "common/cpu_convert.h" #include "common/primitive_desc.hpp" @@ -27,17 +30,14 @@ #include "oneapi/dnnl/dnnl_common.hpp" #include "oneapi/dnnl/dnnl_types.h" #include "onednn/dnnl.h" +#include "openvino/op/convolution.hpp" +#include "openvino/op/group_conv.hpp" #include "pooling.h" #include "reorder.h" #include "utils/cpu_utils.hpp" #include "utils/debug_capabilities.h" #include "utils/general_utils.h" -#include -#include -#include -#include - using namespace dnnl; namespace ov { @@ -88,7 +88,7 @@ size_t ConvKey::hash() const { return seed; } -bool ConvKey::operator==(const ConvKey &rhs) const { +bool ConvKey::operator==(const ConvKey& rhs) const { bool retVal = true; if (inp0 != rhs.inp0) { retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc(); @@ -112,11 +112,11 @@ bool ConvKey::operator==(const ConvKey &rhs) const { return retVal; } -} // namespace +} // namespace class Convolution::FusedSubgraph { public: - FusedSubgraph(const std::vector &opList, const Convolution &conv, const GraphContext::CPtr context) { + FusedSubgraph(const std::vector& opList, const Convolution& conv, const GraphContext::CPtr context) { _graph = std::unique_ptr(new Graph()); std::unordered_set nodesSet; @@ -130,16 +130,16 @@ class Convolution::FusedSubgraph { nodesSet.insert(child); }; - //Make inputs - const auto &inpMemDesc1 = conv.getBaseMemDescAtOutputPort(0); + // Make inputs + const auto& inpMemDesc1 = conv.getBaseMemDescAtOutputPort(0); auto inp0 = std::make_shared(inpMemDesc1, "inp0", "Parameter", context); inputs.push_back(inp0); const size_t sumPortNum = conv.getParentEdges().size() - 1; - const auto &inpMemDesc2 = conv.getBaseMemDescAtInputPort(sumPortNum); + const auto& inpMemDesc2 = conv.getBaseMemDescAtInputPort(sumPortNum); auto inp1 = std::make_shared(inpMemDesc2, "inp1", "Parameter", context); inputs.push_back(inp1); - auto itr = std::find_if(opList.begin(), opList.end(), [](const NodePtr &node) { + auto itr = std::find_if(opList.begin(), opList.end(), [](const NodePtr& node) { if (auto eltwise = std::dynamic_pointer_cast(node)) { return eltwise->isSpecialConvolutionAddFusing(); } @@ -153,7 +153,7 @@ class Convolution::FusedSubgraph { addEdge(inp0, sumNode, 0, 0); addEdge(inp1, sumNode, 0, 1); - //Replicate the rest of the subgraph + // Replicate the rest of the subgraph auto parentItr = itr; while (++itr != opList.end()) { auto parentNode = *parentItr; @@ -173,8 +173,8 @@ class Convolution::FusedSubgraph { } } - //Make output - const auto &outMemDesc = conv.getBaseMemDescAtOutputPort(0); + // Make output + const auto& outMemDesc = conv.getBaseMemDescAtOutputPort(0); auto out = std::make_shared(outMemDesc, "out", "Result", context); addEdge(*parentItr, out, 0, 0); outputs.push_back(out); @@ -240,9 +240,20 @@ bool Convolution::isSupportedOperation(const std::shared_ptr& op } Convolution::Convolution(const std::shared_ptr& op, const GraphContext::CPtr context) - : Node(op, context, NgraphShapeInferFactory(op)), withBiases(false), withSum(false), withDWConv(false), - isGrouped(false), dw_conv_oc(0), dw_conv_ih(0), dw_conv_iw(0), 
dw_conv_in_dt(memory::data_type::undef), - groupNum(1lu), IC(1), groupIC(1), groupOC(1), eltwisePrecision(ov::element::f32) { + : Node(op, context, NgraphShapeInferFactory(op)), + withBiases(false), + withSum(false), + withDWConv(false), + isGrouped(false), + dw_conv_oc(0), + dw_conv_ih(0), + dw_conv_iw(0), + dw_conv_in_dt(memory::data_type::undef), + groupNum(1lu), + IC(1), + groupIC(1), + groupOC(1), + eltwisePrecision(ov::element::f32) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); @@ -296,12 +307,12 @@ Convolution::Convolution(const std::shared_ptr& op, const GraphContext } paddingL = groupConvolutionOp->get_pads_begin(); paddingR = groupConvolutionOp->get_pads_end(); - autoPadding = one_of(groupConvolutionOp->get_auto_pad(), ov::op::PadType::SAME_UPPER, ov::op::PadType::SAME_LOWER); + autoPadding = + one_of(groupConvolutionOp->get_auto_pad(), ov::op::PadType::SAME_UPPER, ov::op::PadType::SAME_LOWER); } // Only apply this heuristic logic on FP32 IR. IC=1 ,OC=1 would disable brgconv on avx2. const bool isAvx2FP32 = !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && - dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && - !context->isGraphQuantized(); + dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && !context->isGraphQuantized(); useJitPlanar = ((IC == 1 && groupOC * groupNum == 1) && isAvx2FP32); } @@ -315,7 +326,8 @@ bool Convolution::canBeExecutedInInt8() const { if (!legacyWeightsZeroPoints.empty()) weightsDataType = memory::data_type::s8; - return one_of(inputDataType, memory::data_type::u8, memory::data_type::s8) && weightsDataType == memory::data_type::s8; + return one_of(inputDataType, memory::data_type::u8, memory::data_type::s8) && + weightsDataType == memory::data_type::s8; } ov::element::Type Convolution::fusedEltwisePrecision(const NodePtr& fusingNode) const { @@ -338,62 +350,63 @@ ov::element::Type Convolution::fusedEltwisePrecision(const NodePtr& fusingNode) const std::vector& Convolution::getDefaultImplPriority() { static const std::vector priorities = { - impl_desc_type::unknown, - impl_desc_type::dw_acl, - impl_desc_type::winograd_acl, - impl_desc_type::gemm_acl, - impl_desc_type::acl, - impl_desc_type::brgconv_avx512_dw, - impl_desc_type::brgconv_avx512_amx_1x1, - impl_desc_type::brgconv_avx512_amx, - impl_desc_type::jit_avx512_amx_dw, - impl_desc_type::jit_avx512_amx_1x1, - impl_desc_type::jit_avx512_amx, - impl_desc_type::brgconv_avx512_1x1, - impl_desc_type::brgconv_avx512, - impl_desc_type::jit_avx512_dw, - impl_desc_type::jit_avx512_1x1, - impl_desc_type::jit_avx512, - impl_desc_type::brgconv_avx2_dw, - impl_desc_type::brgconv_avx2_1x1, - impl_desc_type::brgconv_avx2, - impl_desc_type::jit_uni_dw, - impl_desc_type::jit_uni_1x1, - impl_desc_type::jit_uni, - impl_desc_type::jit_avx2_dw, - impl_desc_type::jit_avx2_1x1, - impl_desc_type::jit_avx2, - impl_desc_type::jit_avx_dw, - impl_desc_type::jit_avx_1x1, - impl_desc_type::jit_avx, - impl_desc_type::jit_sse42_dw, - impl_desc_type::jit_sse42_1x1, - impl_desc_type::jit_sse42, - impl_desc_type::gemm_any, - impl_desc_type::gemm_blas, - impl_desc_type::gemm_avx512, - impl_desc_type::gemm_avx2, - impl_desc_type::gemm_avx, - impl_desc_type::gemm_sse42, - impl_desc_type::jit_gemm, - impl_desc_type::ref_any, - impl_desc_type::ref, - }; - if (isBrgConvAvailable()) - return priorities; - - static const std::vector priorities_wo_brgemm = [&] { - std::vectorresult; - std::copy_if(priorities.begin(), priorities.end(), 
std::back_inserter(result), - [](impl_desc_type type) { return !(type & impl_desc_type::brgconv); }); - return result;}(); - return priorities_wo_brgemm; + impl_desc_type::unknown, + impl_desc_type::dw_acl, + impl_desc_type::winograd_acl, + impl_desc_type::gemm_acl, + impl_desc_type::acl, + impl_desc_type::brgconv_avx512_dw, + impl_desc_type::brgconv_avx512_amx_1x1, + impl_desc_type::brgconv_avx512_amx, + impl_desc_type::jit_avx512_amx_dw, + impl_desc_type::jit_avx512_amx_1x1, + impl_desc_type::jit_avx512_amx, + impl_desc_type::brgconv_avx512_1x1, + impl_desc_type::brgconv_avx512, + impl_desc_type::jit_avx512_dw, + impl_desc_type::jit_avx512_1x1, + impl_desc_type::jit_avx512, + impl_desc_type::brgconv_avx2_dw, + impl_desc_type::brgconv_avx2_1x1, + impl_desc_type::brgconv_avx2, + impl_desc_type::jit_uni_dw, + impl_desc_type::jit_uni_1x1, + impl_desc_type::jit_uni, + impl_desc_type::jit_avx2_dw, + impl_desc_type::jit_avx2_1x1, + impl_desc_type::jit_avx2, + impl_desc_type::jit_avx_dw, + impl_desc_type::jit_avx_1x1, + impl_desc_type::jit_avx, + impl_desc_type::jit_sse42_dw, + impl_desc_type::jit_sse42_1x1, + impl_desc_type::jit_sse42, + impl_desc_type::gemm_any, + impl_desc_type::gemm_blas, + impl_desc_type::gemm_avx512, + impl_desc_type::gemm_avx2, + impl_desc_type::gemm_avx, + impl_desc_type::gemm_sse42, + impl_desc_type::jit_gemm, + impl_desc_type::ref_any, + impl_desc_type::ref, + }; + if (isBrgConvAvailable()) + return priorities; + + static const std::vector priorities_wo_brgemm = [&] { + std::vector result; + std::copy_if(priorities.begin(), priorities.end(), std::back_inserter(result), [](impl_desc_type type) { + return !(type & impl_desc_type::brgconv); + }); + return result; + }(); + return priorities_wo_brgemm; } const bool Convolution::isBrgConvAvailable() { - //When avx2 brgconv heuristic case, disable brgconv to WA the regression. - const bool isBrgConvAvailable = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && - !useJitPlanar; + // When avx2 brgconv heuristic case, disable brgconv to WA the regression. + const bool isBrgConvAvailable = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && !useJitPlanar; return isBrgConvAvailable; } @@ -413,7 +426,7 @@ void Convolution::getSupportedDescriptors() { } if (fusedWith[i]->getAlgorithm() == Algorithm::EltwiseAdd) { - auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); + auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); if (eltwiseNode && eltwiseNode->isSpecialConvolutionAddFusing()) { expectedInputEdgesNum++; } @@ -427,17 +440,19 @@ void Convolution::getSupportedDescriptors() { outputDataType = DnnlExtensionUtils::ElementTypeToDataType(getOriginalOutputPrecisionAtPort(0)); eltwisePrecision = DnnlExtensionUtils::DataTypeToElementType(outputDataType); if (!fusedWith.empty()) { - outputDataType = DnnlExtensionUtils::ElementTypeToDataType(fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0)); + outputDataType = DnnlExtensionUtils::ElementTypeToDataType( + fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0)); eltwisePrecision = DnnlExtensionUtils::DataTypeToElementType(outputDataType); } // We need to make sure that convolution output and second input of fused Eltwise operation - // have equal precision sizes since they use the same physical memory. In case precisions are different we upscale to FP32. + // have equal precision sizes since they use the same physical memory. In case precisions are different we upscale + // to FP32. 
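// Illustrative aside: the fused sum post-op accumulates in place, i.e. the second
// input of the fused EltwiseAdd shares the convolution output buffer, so the two
// element types must have the same byte width. A hypothetical predicate (not part
// of the patch) capturing the check performed below:
static bool sumPostOpPrecisionsCompatible(const ov::element::Type& convOut, const ov::element::Type& sumInp) {
    // equal size() => the shared physical buffer can be viewed as either type;
    // otherwise the code below falls back to f32 on both sides
    return convOut.size() == sumInp.size();
}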
if (outputDataType != memory::data_type::f32 && outputDataType != memory::data_type::bf16 && outputDataType != memory::data_type::f16 && withSum) { for (size_t i = 0; i < fusedWith.size(); i++) { if (fusedWith[i]->getAlgorithm() == Algorithm::EltwiseAdd) { - auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); + auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); if (eltwiseNode && eltwiseNode->isSpecialConvolutionAddFusing()) { eltwisePrecision = fusedEltwisePrecision(fusedWith[i]); if (DnnlExtensionUtils::DataTypeToElementType(outputDataType).size() != eltwisePrecision.size()) { @@ -468,7 +483,7 @@ void Convolution::getSupportedDescriptors() { } for (size_t i = 0; i < fusedWith.size(); i++) { - auto *convolutionNode = dynamic_cast(fusedWith[i].get()); + auto* convolutionNode = dynamic_cast(fusedWith[i].get()); if (convolutionNode) { auto& inActivationDims = convolutionNode->inputShapes[0].getStaticDims(); dw_conv_ih = inActivationDims[convolutionNode->inputShapes[0].getRank() - 2]; @@ -477,7 +492,7 @@ void Convolution::getSupportedDescriptors() { auto& outDims = convolutionNode->outputShapes[0].getStaticDims(); dw_conv_oc = outDims[1]; - const auto &dwWeightsDims = convolutionNode->inputShapes[1].getStaticDims(); + const auto& dwWeightsDims = convolutionNode->inputShapes[1].getStaticDims(); dw_conv_kernel.push_back(dwWeightsDims[dwWeightsDims.size() - 1]); dw_conv_kernel.push_back(dwWeightsDims[dwWeightsDims.size() - 2]); dw_conv_strides = convolutionNode->getStride(); @@ -486,7 +501,8 @@ void Convolution::getSupportedDescriptors() { if (i == 0) { dw_conv_in_dt = DnnlExtensionUtils::ElementTypeToDataType(getOriginalOutputPrecisionAtPort(0)); } else { - dw_conv_in_dt = DnnlExtensionUtils::ElementTypeToDataType(fusedWith[i - 1]->getOriginalOutputPrecisionAtPort(0)); + dw_conv_in_dt = DnnlExtensionUtils::ElementTypeToDataType( + fusedWith[i - 1]->getOriginalOutputPrecisionAtPort(0)); } } else { dw_conv_in_dt = memory::data_type::f32; @@ -498,7 +514,7 @@ void Convolution::getSupportedDescriptors() { int src = getInputShapeAtPort(0).getStaticDims()[2 + j]; int dst = getOutputShapeAtPort(0).getStaticDims()[2 + j]; - krn = (krn - 1)*(dilation[j] + 1) + 1; + krn = (krn - 1) * (dilation[j] + 1) + 1; int calc_dst = (src - krn + paddingL[j]) / stride[j] + 1; paddingR[j] = (dst - calc_dst) * stride[j]; } @@ -506,10 +522,14 @@ void Convolution::getSupportedDescriptors() { } MemoryDescPtr in_candidate, out_candidate; - memory::format_tag nspc = ndims == 3 ? memory::format_tag::nwc : (ndims == 4 ? memory::format_tag::nhwc : memory::format_tag::ndhwc); - memory::format_tag ncsp = ndims == 3 ? memory::format_tag::ncw : (ndims == 4 ? memory::format_tag::nchw : memory::format_tag::ncdhw); - memory::format_tag nCsp8c = ndims == 3 ? memory::format_tag::nCw8c : (ndims == 4 ? memory::format_tag::nChw8c : memory::format_tag::nCdhw8c); - memory::format_tag nCsp16c = ndims == 3 ? memory::format_tag::nCw16c : (ndims == 4 ? memory::format_tag::nChw16c : memory::format_tag::nCdhw16c); + memory::format_tag nspc = + ndims == 3 ? memory::format_tag::nwc : (ndims == 4 ? memory::format_tag::nhwc : memory::format_tag::ndhwc); + memory::format_tag ncsp = + ndims == 3 ? memory::format_tag::ncw : (ndims == 4 ? memory::format_tag::nchw : memory::format_tag::ncdhw); + memory::format_tag nCsp8c = ndims == 3 ? memory::format_tag::nCw8c + : (ndims == 4 ? memory::format_tag::nChw8c : memory::format_tag::nCdhw8c); + memory::format_tag nCsp16c = ndims == 3 ? memory::format_tag::nCw16c + : (ndims == 4 ? 
memory::format_tag::nChw16c : memory::format_tag::nCdhw16c); if (canBeExecutedInInt8()) { DEBUG_LOG(getName(), "Creating I8 descriptor"); @@ -524,7 +544,7 @@ void Convolution::getSupportedDescriptors() { in_candidate = std::make_shared(getInputShapeAtPort(0), inputDataType, nspc); out_candidate = std::make_shared(getOutputShapeAtPort(0), outputDataType, nspc); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); return; } @@ -549,7 +569,7 @@ void Convolution::getSupportedDescriptors() { eltwisePrecision = ov::element::f32; for (size_t i = 0; i < fusedWith.size(); i++) { if (fusedWith[i]->getAlgorithm() == Algorithm::EltwiseAdd) { - auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); + auto* eltwiseNode = dynamic_cast(fusedWith[i].get()); if (eltwiseNode && eltwiseNode->isSpecialConvolutionAddFusing()) { eltwisePrecision = fusedEltwisePrecision(fusedWith[i]); // TODO(amalyshe): there might be situation when convolution can be executed in BF16, @@ -581,42 +601,44 @@ void Convolution::getSupportedDescriptors() { #if defined(OPENVINO_ARCH_X86_64) // nspc shows better performance only with brgconv implementation - bool nspcFirst = isBrgConvAvailable() && one_of(inputDataType, memory::data_type::f16, memory::data_type::bf16, memory::data_type::f32); + bool nspcFirst = isBrgConvAvailable() && + one_of(inputDataType, memory::data_type::f16, memory::data_type::bf16, memory::data_type::f32); bool nspcAdded = false; if (nspcFirst) { in_candidate = std::make_shared(inputShape, inputDataType, nspc); out_candidate = std::make_shared(outputShape, outputDataType, nspc); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); nspcAdded = true; } if (IC == 1 && groupOC == 1) { in_candidate = std::make_shared(inputShape, inputDataType, ncsp); out_candidate = std::make_shared(outputShape, outputDataType, ncsp); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); } else if (IC < 4) { in_candidate = std::make_shared(inputShape, inputDataType, ncsp); out_candidate = std::make_shared(outputShape, outputDataType, nCsp16c); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); out_candidate = std::make_shared(outputShape, outputDataType, nCsp8c); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); } else { in_candidate = std::make_shared(inputShape, inputDataType, nCsp16c); out_candidate = std::make_shared(outputShape, outputDataType, nCsp16c); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); in_candidate = std::make_shared(inputShape, inputDataType, nCsp8c); out_candidate = std::make_shared(outputShape, outputDataType, nCsp8c); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); } in_candidate = std::make_shared(inputShape, inputDataType, ncsp); out_candidate = std::make_shared(outputShape, outputDataType, ncsp); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); - if (!nspcAdded && (inputDataType != memory::data_type::bf16 && inputDataType != memory::data_type::f16 && isNspcAvailable())) { + if (!nspcAdded && + (inputDataType != memory::data_type::bf16 && inputDataType != memory::data_type::f16 && isNspcAvailable())) { in_candidate = 
std::make_shared(inputShape, inputDataType, nspc); out_candidate = std::make_shared(outputShape, outputDataType, nspc); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); } #else (void)ncsp; @@ -625,7 +647,7 @@ in_candidate = std::make_shared(inputShape, inputDataType, nspc); out_candidate = std::make_shared(outputShape, outputDataType, nspc); - createDescriptor({ in_candidate }, { out_candidate }); + createDescriptor({in_candidate}, {out_candidate}); #endif } @@ -636,9 +658,12 @@ void Convolution::setPostOps(dnnl::primitive_attr& attr, dnnl::post_ops ops; auto& args = convPostOpsArgs[useLegacyPostOps]; bool isINT8 = canBeExecutedInInt8(); - // Weight dims in NON-Group CONV: [OC, IC, KH, KW], perchannel weight scale applied on OC DIM, weiScaleMaskPerChannel = 1 << 0 - // Weight dims in Group CONV:[Group, OC, IC, KH, KW], perchannel weight scale applied on GROUP and OC DIM, weiScaleMaskPerChannel = ( 1 << 0 | 1<< 1) = 0x03 - DnnlPostOpsComposerLegacy dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8, isGrouped ? 3 : 1 << 0, getDQScales(), withBiases); + // Weight dims in NON-Group CONV: [OC, IC, KH, KW]; per-channel weight scale is applied on the OC dim, so + // weiScaleMaskPerChannel = 1 << 0. + // Weight dims in Group CONV: [Group, OC, IC, KH, KW]; per-channel weight scale is applied on the Group and OC + // dims, so weiScaleMaskPerChannel = (1 << 0 | 1 << 1) = 0x03. + DnnlPostOpsComposerLegacy + dnnlpoc(getEngine(), attr, ops, args, dims, 1, isINT8, isGrouped ? 3 : 1 << 0, getDQScales(), withBiases); DEBUG_LOG(getName(), " useLegacyPostOps=", useLegacyPostOps, " initWeights=", initWeights); @@ -681,14 +705,14 @@ bool hasSubsequentSum = false; bool hasSubsequentFQ = false; for (size_t j = i + 1; j < fusedWith.size(); j++) { - auto &nextNode = fusedWith[j]; + auto& nextNode = fusedWith[j]; - auto *nextEltwiseNode = dynamic_cast(nextNode.get()); + auto* nextEltwiseNode = dynamic_cast(nextNode.get()); if (nextEltwiseNode && nextEltwiseNode->isSpecialConvolutionAddFusing()) { hasSubsequentSum = true; } - auto *nextQuantizeNode = dynamic_cast(nextNode.get()); + auto* nextQuantizeNode = dynamic_cast(nextNode.get()); if (nextQuantizeNode) { hasSubsequentFQ = true; } @@ -781,12 +805,16 @@ void Convolution::initSupportedPrimitiveDescriptors() { const std::vector dwWeightsDims{dw_conv_oc, 1, 1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}; const std::vector dwBiasesDims{dw_conv_oc}; - const auto dwWeightsPrc = DnnlExtensionUtils::ElementTypeToDataType(dw_conv_in_dt == dnnl_u8 ? ov::element::i8 : ov::element::f32); - const auto dwWeightsDesc = std::make_shared(Shape(dwWeightsDims), dwWeightsPrc, memory::format_tag::Goihw8g); + const auto dwWeightsPrc = DnnlExtensionUtils::ElementTypeToDataType( + dw_conv_in_dt == dnnl_u8 ?
ov::element::i8 : ov::element::f32); + const auto dwWeightsDesc = std::make_shared(Shape(dwWeightsDims), + dwWeightsPrc, + memory::format_tag::Goihw8g); inConfs.emplace_back(dwWeightsDesc); const auto dwBiasPrc = memory::data_type::f32; - const auto dwBiasDesc = std::make_shared(Shape(dwBiasesDims), dwBiasPrc, memory::format_tag::x); + const auto dwBiasDesc = + std::make_shared(Shape(dwBiasesDims), dwBiasPrc, memory::format_tag::x); inConfs.emplace_back(dwBiasDesc); } @@ -809,15 +837,21 @@ void Convolution::initSupportedPrimitiveDescriptors() { }; #ifdef CPU_DEBUG_CAPS { - if (!customImplPriorities.empty()) { - DEBUG_LOG("#", getName(), " customImplPriorities [", 0 , "/", customImplPriorities.size(), - "]: ", impl_type_to_string(customImplPriorities[0])); - } + if (!customImplPriorities.empty()) { + DEBUG_LOG("#", + getName(), + " customImplPriorities [", + 0, + "/", + customImplPriorities.size(), + "]: ", + impl_type_to_string(customImplPriorities[0])); + } } #endif for (size_t dIdx = 0; dIdx < descs.size(); dIdx++) { auto& desc = descs[dIdx]; - auto primitive_desc = desc.get(true); //true mean allow empty + auto primitive_desc = desc.get(true); // true mean allow empty if (primitive_desc == nullptr) { continue; } @@ -829,16 +863,25 @@ void Convolution::initSupportedPrimitiveDescriptors() { }; const bool first_match = customImplPriorities.empty(); - DEBUG_LOG("#", getName(), ",descIndex:", dIdx + 1, "/", descs.size(), - ", itpd.impl_info_str(): ", desc.impl_info_str(), - ", parsed imp_type: ", impl_type_to_string(parse_impl_name(desc.impl_info_str())), - ", first_match: ", first_match ? "true" : "false"); - DnnlExtensionUtils::for_each_implementation(desc, - first_match, - [&](impl_desc_type implType) { - return contains(getImplPriority(), implType); - }, - add_supported_desc); + DEBUG_LOG("#", + getName(), + ",descIndex:", + dIdx + 1, + "/", + descs.size(), + ", itpd.impl_info_str(): ", + desc.impl_info_str(), + ", parsed imp_type: ", + impl_type_to_string(parse_impl_name(desc.impl_info_str())), + ", first_match: ", + first_match ? "true" : "false"); + DnnlExtensionUtils::for_each_implementation( + desc, + first_match, + [&](impl_desc_type implType) { + return contains(getImplPriority(), implType); + }, + add_supported_desc); // fallback. 
if none of the primitive types is present in the priority list just add first implementation // @todo this fallback is not necessary if primitive priority list is filled correctly @@ -852,46 +895,48 @@ bool Convolution::created() const { } namespace { -dnnl::convolution_forward::primitive_desc -createDescriptorInternal(const dnnl::engine& engine, - const dnnl::memory::desc& inputDesc, - const dnnl::memory::desc& weightDesc, - const dnnl::memory::desc& biasDesc, - const dnnl::memory::desc& outputDesc, - bool withBiases, - const std::vector& stride, - const std::vector& dilation, - const std::vector& paddingL, - const std::vector& paddingR, - dnnl::algorithm alg, - const dnnl::primitive_attr& attr) { +dnnl::convolution_forward::primitive_desc createDescriptorInternal(const dnnl::engine& engine, + const dnnl::memory::desc& inputDesc, + const dnnl::memory::desc& weightDesc, + const dnnl::memory::desc& biasDesc, + const dnnl::memory::desc& outputDesc, + bool withBiases, + const std::vector& stride, + const std::vector& dilation, + const std::vector& paddingL, + const std::vector& paddingR, + dnnl::algorithm alg, + const dnnl::primitive_attr& attr) { if (withBiases) { - return dnnl::convolution_forward::primitive_desc( - engine, - prop_kind::forward_inference, - alg, - inputDesc, weightDesc, biasDesc, outputDesc, - dnnl::memory::dims(stride.begin(), stride.end()), - dnnl::memory::dims(dilation.begin(), dilation.end()), - dnnl::memory::dims(paddingL.begin(), paddingL.end()), - dnnl::memory::dims(paddingR.begin(), paddingR.end()), - attr, - true); // allow_empty + return dnnl::convolution_forward::primitive_desc(engine, + prop_kind::forward_inference, + alg, + inputDesc, + weightDesc, + biasDesc, + outputDesc, + dnnl::memory::dims(stride.begin(), stride.end()), + dnnl::memory::dims(dilation.begin(), dilation.end()), + dnnl::memory::dims(paddingL.begin(), paddingL.end()), + dnnl::memory::dims(paddingR.begin(), paddingR.end()), + attr, + true); // allow_empty } else { - return dnnl::convolution_forward::primitive_desc( - engine, - prop_kind::forward_inference, - alg, - inputDesc, weightDesc, outputDesc, - dnnl::memory::dims(stride.begin(), stride.end()), - dnnl::memory::dims(dilation.begin(), dilation.end()), - dnnl::memory::dims(paddingL.begin(), paddingL.end()), - dnnl::memory::dims(paddingR.begin(), paddingR.end()), - attr, - true); // allow_empty + return dnnl::convolution_forward::primitive_desc(engine, + prop_kind::forward_inference, + alg, + inputDesc, + weightDesc, + outputDesc, + dnnl::memory::dims(stride.begin(), stride.end()), + dnnl::memory::dims(dilation.begin(), dilation.end()), + dnnl::memory::dims(paddingL.begin(), paddingL.end()), + dnnl::memory::dims(paddingR.begin(), paddingR.end()), + attr, + true); // allow_empty } } -} // namespace +} // namespace static memory::data_type deriveWeightDataType(memory::data_type src_dt) { memory::data_type wdt = src_dt; @@ -916,7 +961,7 @@ void Convolution::createDescriptor(const std::vector& inputDesc, if (outputDesc[0]->isDefined()) { definedOutMemDesc = MemoryDescUtils::convertToDnnlMemoryDesc(outputDesc[0]); } else { - std::vector shapes = { definedInpMemDesc->getShape(), Shape(weightDims) }; + std::vector shapes = {definedInpMemDesc->getShape(), Shape(weightDims)}; auto outDims = shapeInferGeneric(shapes); definedOutMemDesc = MemoryDescUtils::convertToDnnlMemoryDesc(outputDesc[0]->cloneWithNewDims(outDims.front())); } @@ -930,13 +975,14 @@ void Convolution::createDescriptor(const std::vector& inputDesc, dnnl::memory::desc biasDnnlDesc; if 
(withBiases) { - //oneDNN ARM Convolution primitive supports only identical in/out data types + // oneDNN ARM Convolution primitive supports only identical in/out data types #if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64) memory::data_type bdt = outDnnlDesc.get_data_type(); #else memory::data_type bdt = memory::data_type::f32; #endif - biasDnnlDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(expectedBiasDims), bdt, memory::format_tag::any); + biasDnnlDesc = + dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(expectedBiasDims), bdt, memory::format_tag::any); } std::vector algorithms; @@ -948,8 +994,17 @@ void Convolution::createDescriptor(const std::vector& inputDesc, for (const auto alg : algorithms) { for (const auto& attr : attrs) { const auto desc = createDescriptorInternal(getEngine(), - inDnnlDesc, weightDnnlDesc, biasDnnlDesc, outDnnlDesc, withBiases, - stride, dilation, paddingL, paddingR, alg, attr); + inDnnlDesc, + weightDnnlDesc, + biasDnnlDesc, + outDnnlDesc, + withBiases, + stride, + dilation, + paddingL, + paddingR, + alg, + attr); descs.emplace_back(desc); } } @@ -983,7 +1038,8 @@ void Convolution::addLegacyZeroPoints(dnnl::primitive_attr& attr) { if (!legacyWeightsZeroPointsMemPtr) { DnnlBlockedMemoryDesc memoryDesc(ov::element::f32, {legacyWeightsZeroPoints.size()}); - legacyWeightsZeroPointsMemPtr = std::make_shared(getEngine(), memoryDesc, legacyWeightsZeroPoints.data()); + legacyWeightsZeroPointsMemPtr = + std::make_shared(getEngine(), memoryDesc, legacyWeightsZeroPoints.data()); } } @@ -993,7 +1049,8 @@ void Convolution::addLegacyZeroPoints(dnnl::primitive_attr& attr) { if (!legacyOutputCompensationMemPtr) { DnnlBlockedMemoryDesc memoryDesc(ov::element::i32, {legacyOutputCompensation.size()}); - legacyOutputCompensationMemPtr = std::make_shared(getEngine(), memoryDesc, legacyOutputCompensation.data()); + legacyOutputCompensationMemPtr = + std::make_shared(getEngine(), memoryDesc, legacyOutputCompensation.data()); } } } @@ -1004,7 +1061,7 @@ static bool attrContainsPostOp(const dnnl::primitive_attr& attr, const dnnl::imp } // See the src/plugins/intel_cpu/src/docs/convPostOps.md for details -void Convolution::SetPostOpsAndZeroPoints(std::vector &attrs) { +void Convolution::SetPostOpsAndZeroPoints(std::vector& attrs) { attrs.resize(1); auto outputShape = outputStaticShape(); // attr[0] - Legacy post ops + Legacy zero points. @@ -1012,14 +1069,13 @@ void Convolution::SetPostOpsAndZeroPoints(std::vector &att setPostOps(attrs[0], outputShape, true); addLegacyZeroPoints(attrs[0]); - //dw-conv would be fused into conv only on AVX2 platform. no need attr[1]. Avoid extra useless attribute. + // dw-conv would be fused into conv only on AVX2 platform. no need attr[1]. Avoid extra useless attribute. if (attrContainsPostOp(attrs[0], dnnl::impl::primitive_kind::convolution)) { return; } // no matter if brgconv is available, 1 attribute is enough. Avoid duplicated attribute - if (inputZeroPointType == zpType::None && - !attrContainsPostOp(attrs[0], dnnl::impl::primitive_kind::depthwise) && + if (inputZeroPointType == zpType::None && !attrContainsPostOp(attrs[0], dnnl::impl::primitive_kind::depthwise) && !attrContainsPostOp(attrs[0], dnnl::impl::primitive_kind::quantization)) { return; } @@ -1034,10 +1090,11 @@ void Convolution::SetPostOpsAndZeroPoints(std::vector &att } // Try 2 attributes. 
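// Illustrative aside: with two attributes every candidate descriptor is generated
// once per attribute, so even descriptor indices correspond to attrs[0] (legacy
// post ops / legacy zero points) and odd indices to attrs[1]. attrIndexForDesc is
// a hypothetical helper (not from the patch) mirroring the attrId selection that
// initDescriptor() performs below:
static inline size_t attrIndexForDesc(size_t descId, size_t attrCount) {
    return attrCount == 1 ? 0 : descId % 2;
}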
attrs.resize(2); - if (inputZeroPointType == zpType::PerTensor && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) { - //WR to ONEDNN limitation. attr[1] - legacy post ops + stock zero point. + if (inputZeroPointType == zpType::PerTensor && + dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) { + // Workaround for a oneDNN limitation. attr[1] - legacy post ops + stock zero point. //@todo:Unify to use binary postops+stock zero point when limitation is fixed. - //For now, have to adapt to JIT_AMX kernel for performance. + // For now, have to adapt to the JIT_AMX kernel for performance. DEBUG_LOG(getName(), ": set post ops, attr 1, useLegacyPostOps=true"); setPostOps(attrs[1], outputShape, true); } else { @@ -1048,7 +1105,7 @@ } void Convolution::initDescriptor(const NodeConfig& config) { - auto *selectedPD = getSelectedPrimitiveDescriptor(); + auto* selectedPD = getSelectedPrimitiveDescriptor(); if (!selectedPD) { return; @@ -1057,24 +1114,29 @@ // attr[0] for legacy post ops; // attr[1] is mostly for binaryPostops except when having per-tensor zp on AMX. const int descId = descIdx[selectedPrimitiveDescriptorIndex]; - int attrId = attrs.size() == 1 ? 0 : - descId % 2 == 0 ? 0 : 1; + int attrId = attrs.size() == 1 ? 0 : descId % 2 == 0 ? 0 : 1; preferLegacyPostOps = (attrId == 0 || (attrId == 1 && (inputZeroPointType == zpType::PerTensor) && - dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx))); - //attr[0] for legacy zero point. - //attr[1] for stock per-tensor zero point. + dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx))); + // attr[0] for legacy zero point. + // attr[1] for stock per-tensor zero point.
preferLegacyZeroPoint = (attrId == 0); DEBUG_LOG(getName(), - " selectedPrimitiveDescriptorIndex: ", selectedPrimitiveDescriptorIndex, - " DescIdx: ", descId, - " Selected impl type: ", selectedPD->getImplementationType(), - " Desc impl type: ", parse_impl_name(descs[descId].impl_info_str()), - " preferLegacyPostOps: ", preferLegacyPostOps, - " preferLegacyZeroPoint: ", preferLegacyZeroPoint); - - auto updateNodeConfig = [&](const NodeConfig& cfg){ + " selectedPrimitiveDescriptorIndex: ", + selectedPrimitiveDescriptorIndex, + " DescIdx: ", + descId, + " Selected impl type: ", + selectedPD->getImplementationType(), + " Desc impl type: ", + parse_impl_name(descs[descId].impl_info_str()), + " preferLegacyPostOps: ", + preferLegacyPostOps, + " preferLegacyZeroPoint: ", + preferLegacyZeroPoint); + + auto updateNodeConfig = [&](const NodeConfig& cfg) { auto updatedConfig = cfg; for (size_t i = 0; i < descInputNumbers(); i++) { @@ -1097,7 +1159,7 @@ void Convolution::initDescriptor(const NodeConfig& config) { return updatedConfig; }; - if (!canBeExecutedInInt8()) { // strided blobs are suppoted only for FP32 convolutions + if (!canBeExecutedInInt8()) { // strided blobs are supported only for FP32 convolutions descs.clear(); createDescriptor({config.inConfs[0].getMemDesc()}, {config.outConfs[0].getMemDesc()}); @@ -1115,7 +1177,7 @@ void Convolution::initDescriptor(const NodeConfig& config) { selectedPD->setConfig(updatedConfig); } -std::shared_ptr Convolution::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const { +std::shared_ptr Convolution::getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const { if (idx == 1) { // report original plain layout for weight since it needs to be reordered dynamically at runtime return std::make_shared(getOriginalInputPrecisionAtPort(idx), @@ -1151,7 +1213,8 @@ ov::element::Type Convolution::getRuntimePrecision() const { for (size_t i = 0; i < std::min(getParentEdges().size(), inputsNumLimit); i++) { auto parentEdge = getParentEdgeAt(i); if (parentEdge && parentEdge->getStatus() == Edge::Status::Validated) { - inputPrecisions.emplace_back(DnnlExtensionUtils::DataTypeToElementType((parentEdge->getMemoryPtr()->getDataType()))); + inputPrecisions.emplace_back( + DnnlExtensionUtils::DataTypeToElementType((parentEdge->getMemoryPtr()->getDataType()))); } } @@ -1183,8 +1246,9 @@ bool Convolution::isNspcAvailable() const { return false; } } else { - // it was empirically observed that the nspc convolutions perform much slower than the blocked ones if the channels number more than the specific value - size_t spatialRank = ndims - 2; //two means batch dim plus channels dim + // it was empirically observed that the nspc convolutions perform much slower than the blocked ones if the + // number of channels exceeds a certain value + size_t spatialRank = ndims - 2; // two means batch dim plus channels dim bool is1x1 = false; @@ -1195,24 +1259,24 @@ auto paddingRreversItr = paddingR.crbegin(); for (size_t i = 0; i < spatialRank; ++i) { - is1x1 = true - && *(weightDimsReversItr++) == 1 - && *(strideReversItr++) == 1 - && *(paddingLreversItr++) == 0 - && *(paddingRreversItr++) == 0; + is1x1 = true && *(weightDimsReversItr++) == 1 && *(strideReversItr++) == 1 && + *(paddingLreversItr++) == 0 && *(paddingRreversItr++) == 0; } } - // if the activation field size is 1x1 the avx512 1x1 nspc convolution pollutes caches so that the layer after the convolution performs slow + // if the activation field size is 1x1
the avx512 1x1 nspc convolution pollutes caches so that the layer after + // the convolution performs slow if (mayiuse(impl::cpu::x64::avx512_core) && is1x1) { auto end = inpDims.rbegin(); std::advance(end, spatialRank); - if (std::all_of(inpDims.rbegin(), end, [](size_t x) { return dimsEqualStrong(1, x); })) { + if (std::all_of(inpDims.rbegin(), end, [](size_t x) { + return dimsEqualStrong(1, x); + })) { return false; } } - unsigned thresholdNumChannels = 128u; // for avx and below + unsigned thresholdNumChannels = 128u; // for avx and below if (is1x1) { thresholdNumChannels = 2048u; } else if (mayiuse(impl::cpu::x64::avx512_core)) { @@ -1224,7 +1288,8 @@ bool Convolution::isNspcAvailable() const { return false; } if (!mayiuse(impl::cpu::x64::avx)) { - // SSE41 nspc convolutions do not support ic and oc tails yet and the blocked implementation will be much better than gemm + // SSE41 nspc convolutions do not support ic and oc tails yet and the blocked implementation will be much + // better than gemm if ((IC % 8) || (OC % 8)) { return false; } @@ -1251,7 +1316,7 @@ void Convolution::prepareParams() { OPENVINO_THROW("Input memory is undefined."); } - const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor(); + const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor(); if (selected_pd == nullptr) OPENVINO_THROW("Preferable primitive descriptor is not set for node ", getName(), "."); @@ -1324,44 +1389,41 @@ void Convolution::prepareParams() { dnnlBiasDesc = biasDescPtr->getDnnlDesc(); } - return createDescriptorInternal( - engine, - srcDesc, - wghDesc, - dnnlBiasDesc, - dstDesc, - (biasDescPtr != nullptr), - stride, - dilation, - paddingL, - paddingR, - alg, - attr); + return createDescriptorInternal(engine, + srcDesc, + wghDesc, + dnnlBiasDesc, + dstDesc, + (biasDescPtr != nullptr), + stride, + dilation, + paddingL, + paddingR, + alg, + attr); }; - dnnl::primitive_desc prim_desc = createDnnlConvDesc( - engine, - key.inp0->getDnnlDesc(), - wghDescAny, - key.out->getDnnlDesc(), - key.bias, - key.stride, - key.dilation, - key.paddingL, - key.paddingR, - convAlg, - key.attr); + dnnl::primitive_desc prim_desc = createDnnlConvDesc(engine, + key.inp0->getDnnlDesc(), + wghDescAny, + key.out->getDnnlDesc(), + key.bias, + key.stride, + key.dilation, + key.paddingL, + key.paddingR, + convAlg, + key.attr); const bool found = DnnlExtensionUtils::find_implementation(prim_desc, key.implType); if (found) { - return std::make_shared( - prim_desc, - key.inp0->getDnnlDesc(), - key.inp1->getDnnlDesc(), - key.out->getDnnlDesc(), - engine, - key.constWeight); + return std::make_shared(prim_desc, + key.inp0->getDnnlDesc(), + key.inp1->getDnnlDesc(), + key.out->getDnnlDesc(), + engine, + key.constWeight); } // primitive desc with proper implementation type not found, use the first available @@ -1372,40 +1434,37 @@ void Convolution::prepareParams() { key.out->getDataType(), memory::format_tag::any); - auto reorderConvDesc = createDnnlConvDesc( - engine, - inDesc, - wghDescAny, - outDesc, - key.bias, - key.stride, - key.dilation, - key.paddingL, - key.paddingR, - convAlg, - key.attr); + auto reorderConvDesc = createDnnlConvDesc(engine, + inDesc, + wghDescAny, + outDesc, + key.bias, + key.stride, + key.dilation, + key.paddingL, + key.paddingR, + convAlg, + key.attr); // unable to create a primitive desc if (!reorderConvDesc) return nullptr; if (key.attr.get()->post_ops_.count(dnnl::impl::primitive_kind::sum)) { - return std::make_shared( - reorderConvDesc, - key.inp0->getDnnlDesc(), - 
key.inp1->getDnnlDesc(), - key.out->getDnnlDesc(), - engine, - key.constWeight); + return std::make_shared(reorderConvDesc, + key.inp0->getDnnlDesc(), + key.inp1->getDnnlDesc(), + key.out->getDnnlDesc(), + engine, + key.constWeight); } - return std::make_shared( - reorderConvDesc, - key.inp0->getDnnlDesc(), - key.inp1->getDnnlDesc(), - key.out->getDnnlDesc(), - engine, - key.constWeight); + return std::make_shared(reorderConvDesc, + key.inp0->getDnnlDesc(), + key.inp1->getDnnlDesc(), + key.out->getDnnlDesc(), + engine, + key.constWeight); }; auto prevExecPtr = execPtr; @@ -1460,7 +1519,8 @@ Convolution::ConvolutionExecutor::ConvolutionExecutor(const dnnl::primitive_desc const dnnl::memory::desc& weightMemDesc, const dnnl::memory::desc& outMemDesc, const dnnl::engine& engine, - bool constWeight) : DnnlExecutor(pd) { + bool constWeight) + : DnnlExecutor(pd) { if (inMemDesc != getDnnlSrcDesc()) { inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)}); } @@ -1480,7 +1540,8 @@ Convolution::ConvolutionSumExecutor::ConvolutionSumExecutor(const dnnl::primitiv const dnnl::memory::desc& weightMemDesc, const dnnl::memory::desc& outMemDesc, const dnnl::engine& engine, - bool constWeight) : DnnlExecutor(pd) { + bool constWeight) + : DnnlExecutor(pd) { if (inMemDesc != getDnnlSrcDesc()) { inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)}); } @@ -1498,9 +1559,10 @@ Convolution::ConvolutionSumExecutor::ConvolutionSumExecutor(const dnnl::primitiv } } -void Convolution::ConvolutionSumExecutor::reorder_exec(std::unordered_map primArgs, dnnl::stream strm) { +void Convolution::ConvolutionSumExecutor::reorder_exec(std::unordered_map primArgs, + dnnl::stream strm) { auto outputMem = primArgs.at(DNNL_ARG_DST); - for (auto &inReorder : inputReorders) { + for (auto& inReorder : inputReorders) { if (primArgs.count(inReorder.first)) { dnnl::memory memDst(inReorder.second.getDstDesc(), strm.get_engine()); inReorder.second.exec(primArgs[inReorder.first], memDst, strm); @@ -1549,14 +1611,14 @@ void Convolution::executeDynamicImpl(dnnl::stream strm) { } void Convolution::updatePadding() { - //update padding. + // update padding. 
if (isDynamicNode() && autoPadding) { paddingL = shapeInference->get_pads_begin(); paddingR = shapeInference->get_pads_end(); } } -void Convolution::redefineOutputMemory(const std::vector &newOutputShapes) { +void Convolution::redefineOutputMemory(const std::vector& newOutputShapes) { if (withSum) { const size_t sumPortNum = getParentEdges().size() - 1; const auto& sumInpMem = getParentEdgeAt(sumPortNum)->getMemory(); @@ -1570,7 +1632,8 @@ auto inp1 = subgraph->getInput(1); inp1->redefineOutputMemory({sumInpMem.getStaticDims()}); - // here we postpone output memory reallocation due to the fact that it is the same memory with the sum second input + // here we postpone output memory reallocation due to the fact that it is the same memory with the sum + // second input return; } else { withSumBroadcast = false; @@ -1579,12 +1642,10 @@ Node::redefineOutputMemory(newOutputShapes); } -MemoryDescPtr Convolution::getSumMemDesc(const primitive_desc &primitive_desc_it) { +MemoryDescPtr Convolution::getSumMemDesc(const primitive_desc& primitive_desc_it) { if (getOutputShapeAtPort(0).isDynamic()) { - // When we set input shape with ranged dims, sum node input shape maybe mismatch with output shape, we just change - // ranged min value to 1 to meet this case. - // For example: - // Output shape = {1, 160, {128, 256}, {128, 256}} + // When we set the input shape with ranged dims, the sum node input shape may mismatch the output shape, so we + // change the ranged min value to 1 to handle this case. For example: Output shape = {1, 160, {128, 256}, {128, 256}} // Sum input shape = {1, 160, 1, 1} // Update sum shape to {1, 160, {1, 256}, {1, 256}} auto shape = getOutputShapeAtPort(0); @@ -1622,7 +1683,7 @@ MemoryPtr Convolution::getOutputMemory() const { } } -void Convolution::addFusedNode(const NodePtr &fusingNode) { +void Convolution::addFusedNode(const NodePtr& fusingNode) { if (Type::Eltwise == fusingNode->getType()) { if (fusingNode->getAlgorithm() == Algorithm::EltwiseAdd) { auto eltwiseNode = std::dynamic_pointer_cast(fusingNode); @@ -1655,7 +1716,6 @@ void Convolution::appendLegacyZeroPointsArgs() { } } - void Convolution::appendZeroPointsArgs() { if (stockInputZeroPointsMemPtr != nullptr) { primArgs[DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC] = stockInputZeroPointsMemPtr->getPrimitive(); @@ -1673,10 +1733,9 @@ void Convolution::initializeInputZeroPoints(const uint8_t* inputZpData, const si inputZeroPointType = zpType::PerChannel; } // Only enable per-tensor zero point on avx512-amx and avx512-core-vnni, avx2_vnni_2. - // avx2_vnni is not enabled per-tensor z because of perf regression brgconv with per-tensor zpcompared with jit per-channel zp - // If zero point is pertensor, both legacy zp and stock zp - // would be passed into conv node. The conv node would determine how to create - // post-ops attribute and prioritize to choose final onednn kernel. + // avx2_vnni is not enabled for per-tensor zp because of a perf regression of brgconv with per-tensor zp compared + // with jit per-channel zp. If the zero point is per-tensor, both legacy zp and stock zp are passed into the conv + // node. The conv node then determines how to create the post-ops attribute and prioritizes the final onednn kernel.
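// Illustrative sketch of the classification this function applies to the raw u8
// zero-point data: a run of identical values collapses to a per-tensor zero point,
// anything else stays per-channel. classifyZeroPoints is a hypothetical helper,
// not a function from this patch:
static zpType classifyZeroPoints(const uint8_t* inputZpData, size_t inputZpSize) {
    for (size_t i = 1; i < inputZpSize; ++i)
        if (inputZpData[i] != inputZpData[0])
            return zpType::PerChannel;  // values differ across channels
    return inputZpSize ? zpType::PerTensor : zpType::None;
}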
if (inputZeroPointType == zpType::PerTensor && (impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_amx) || impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core_vnni) || impl::cpu::x64::mayiuse(impl::cpu::x64::avx2_vnni_2))) @@ -1694,15 +1753,14 @@ VectorDims Convolution::makeInputDummyShape(const Shape& inpShape) const { const size_t filterStartIndx = weightDims.size() - spatialRank; VectorDims dummyInputShapeVals(inpShape.getRank(), dummyInputDim); - dummyInputShapeVals[1] = IC; //channels + dummyInputShapeVals[1] = IC; // channels for (size_t i = 0; i < spatialRank; i++) { if (weightDims[filterStartIndx + i] > dummyInputShapeVals[2 + i]) { constexpr Dim dummyOutputDim = 16; - dummyInputShapeVals[2 + i] = (dummyOutputDim - 1) * stride[i] - - (paddingL[i] + paddingR[i]) + - weightDims[filterStartIndx + i] + - (weightDims[filterStartIndx + i]- 1) * (dilation[i]); + dummyInputShapeVals[2 + i] = (dummyOutputDim - 1) * stride[i] - (paddingL[i] + paddingR[i]) + + weightDims[filterStartIndx + i] + + (weightDims[filterStartIndx + i] - 1) * (dilation[i]); } } return MemoryDescUtils::makeDummyShape(inpShape, dummyInputShapeVals).getStaticDims(); @@ -1712,12 +1770,12 @@ VectorDims Convolution::outputStaticShape() const { auto& outputShape = getOutputShapeAtPort(0); if (outputShape.isDynamic()) { auto inpDummyShape = makeInputDummyShape(getInputShapeAtPort(0)); - auto outputDims = shapeInferGeneric({ Shape(inpDummyShape), Shape(weightDims) }); + auto outputDims = shapeInferGeneric({Shape(inpDummyShape), Shape(weightDims)}); return Shape(outputDims.front()).getStaticDims(); } return outputShape.getStaticDims(); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/conv.h b/src/plugins/intel_cpu/src/nodes/conv.h index a7cac9bced1241..8da3193e5760cf 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.h +++ b/src/plugins/intel_cpu/src/nodes/conv.h @@ -29,7 +29,7 @@ class Convolution : public Node { return false; } ov::element::Type getRuntimePrecision() const override; - std::shared_ptr getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override; + std::shared_ptr getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const override; dnnl::memory getWeights() const; dnnl::memory getBias() const; @@ -39,23 +39,35 @@ class Convolution : public Node { } bool canBeExecutedInInt8() const override; - size_t getGroupNum() const { return groupNum; } - //OV Legacy input zero point mechanism can support per-channel zero point. - //Hold legacy input zero point. + size_t getGroupNum() const { + return groupNum; + } + // OV Legacy input zero point mechanism can support per-channel zero point. + // Hold legacy input zero point. std::vector legacyInputZeroPoints; - //Hold legacy weight zero point. + // Hold legacy weight zero point. std::vector legacyWeightsZeroPoints; - //Hold legacy pre-calculated output compensation + // Hold legacy pre-calculated output compensation std::vector legacyOutputCompensation; - //Hold stock per-tensor input zero point. Pass to onednn to calculate output compensation. + // Hold stock per-tensor input zero point. Pass to onednn to calculate output compensation. 
std::vector inputZeroPoints; void initializeInputZeroPoints(const uint8_t* inputZpData, const size_t inputZpSize); - const VectorDims &getWeightDims() { return weightDims; } - const std::vector &getStride() { return stride; } - const std::vector &getDilation() { return dilation; } - const std::vector &getPaddingL() { return paddingL; } - const std::vector &getPaddingR() { return paddingR; } + const VectorDims& getWeightDims() { + return weightDims; + } + const std::vector& getStride() { + return stride; + } + const std::vector& getDilation() { + return dilation; + } + const std::vector& getPaddingL() { + return paddingL; + } + const std::vector& getPaddingR() { + return paddingR; + } bool canFuse(const NodePtr& node) const override; bool isDepthWise() const { @@ -64,16 +76,12 @@ class Convolution : public Node { protected: ov::element::Type fusedEltwisePrecision(const NodePtr& fusingNode) const; - void redefineOutputMemory(const std::vector &newOutputShapes) override; - void addFusedNode(const NodePtr &fusingNode) override; + void redefineOutputMemory(const std::vector& newOutputShapes) override; + void addFusedNode(const NodePtr& fusingNode) override; const std::vector& getDefaultImplPriority() override; private: - enum class zpType { - None, - PerTensor, - PerChannel - }; + enum class zpType { None, PerTensor, PerChannel }; class FusedSubgraph; using FusedSubgraphPtr = std::shared_ptr; @@ -81,26 +89,26 @@ class Convolution : public Node { executorPtr execPtr = nullptr; class ConvolutionExecutor : public DnnlExecutor { - public: - ConvolutionExecutor(const dnnl::primitive_desc& pd, - const dnnl::memory::desc& inMemDesc, - const dnnl::memory::desc& weightMemDesc, - const dnnl::memory::desc& outMemDesc, - const dnnl::engine& engine, - bool constWeight); + public: + ConvolutionExecutor(const dnnl::primitive_desc& pd, + const dnnl::memory::desc& inMemDesc, + const dnnl::memory::desc& weightMemDesc, + const dnnl::memory::desc& outMemDesc, + const dnnl::engine& engine, + bool constWeight); }; class ConvolutionSumExecutor : public DnnlExecutor { - public: - ConvolutionSumExecutor(const dnnl::primitive_desc& pd, - const dnnl::memory::desc& inMemDesc, - const dnnl::memory::desc& weightMemDesc, - const dnnl::memory::desc& outMemDesc, - const dnnl::engine& engine, - bool constWeight); - - private: - void reorder_exec(std::unordered_map primArgs, dnnl::stream strm) override; + public: + ConvolutionSumExecutor(const dnnl::primitive_desc& pd, + const dnnl::memory::desc& inMemDesc, + const dnnl::memory::desc& weightMemDesc, + const dnnl::memory::desc& outMemDesc, + const dnnl::engine& engine, + bool constWeight); + + private: + void reorder_exec(std::unordered_map primArgs, dnnl::stream strm) override; }; void prepareParams() override; @@ -108,13 +116,16 @@ class Convolution : public Node { void executeDynamicImpl(dnnl::stream strm) override; void addLegacyZeroPoints(dnnl::primitive_attr& attr); void addZeroPoints(dnnl::primitive_attr& attr); - void setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims, bool useLegacyPostOps, bool initWeights = false); - void SetPostOpsAndZeroPoints(std::vector &attrs); + void setPostOps(dnnl::primitive_attr& attr, + const VectorDims& dims, + bool useLegacyPostOps, + bool initWeights = false); + void SetPostOpsAndZeroPoints(std::vector& attrs); void filterSupportedDescriptors(); bool isNspcAvailable() const; void updatePadding(); - MemoryDescPtr getSumMemDesc(const dnnl::primitive_desc &primitive_desc_it); + MemoryDescPtr getSumMemDesc(const 
dnnl::primitive_desc& primitive_desc_it); MemoryPtr getOutputMemory() const; VectorDims makeInputDummyShape(const Shape& inpShape) const; VectorDims outputStaticShape() const; @@ -131,7 +142,7 @@ class Convolution : public Node { zpType inputZeroPointType = zpType::None; // maps each supportedPrimitiveDescriptor to corresponding desc from descs std::vector descIdx; - VectorDims expectedBiasDims {}; + VectorDims expectedBiasDims{}; std::vector stride; std::vector dilation; @@ -179,6 +190,6 @@ class Convolution : public Node { #endif }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/convert.cpp b/src/plugins/intel_cpu/src/nodes/convert.cpp index 1987c9cc83d5f2..d01a56aac1b86d 100644 --- a/src/plugins/intel_cpu/src/nodes/convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/convert.cpp @@ -26,7 +26,8 @@ bool Convert::isSupportedOperation(const std::shared_ptr& op, st auto srcPrc = op->get_input_element_type(0); auto dstPrc = op->get_output_element_type(0); if (!CommonConvertExecutor::isSupported(srcPrc, dstPrc)) { - errorMessage = "cpu_convert can't convert from: " + srcPrc.to_string() + " precision to: " + dstPrc.to_string(); + errorMessage = + "cpu_convert can't convert from: " + srcPrc.to_string() + " precision to: " + dstPrc.to_string(); return false; } } catch (...) { @@ -36,7 +37,7 @@ bool Convert::isSupportedOperation(const std::shared_ptr& op, st } Convert::Convert(const std::shared_ptr& op, const GraphContext::CPtr context) - : Node(op, context, PassThroughShapeInferFactory()) { + : Node(op, context, PassThroughShapeInferFactory()) { std::string errorMessage; if (isSupportedOperation(op, errorMessage)) { errorPrefix = "Convert node with name '" + getName() + "'"; @@ -48,8 +49,11 @@ Convert::Convert(const std::shared_ptr& op, const GraphContext::CPtr c convertParams.origPrc = convert->get_destination_type(); } -Convert::Convert(const Shape &shape, const ov::element::Type &inPrc, const ov::element::Type &outPrc, - const std::string &nodeName, const GraphContext::CPtr context) +Convert::Convert(const Shape& shape, + const ov::element::Type& inPrc, + const ov::element::Type& outPrc, + const std::string& nodeName, + const GraphContext::CPtr context) : Node("Convert", {shape}, {shape}, {inPrc}, {outPrc}, nodeName, context) { convertParams.origPrc = outPrc; @@ -74,7 +78,7 @@ void Convert::getSupportedDescriptors() { OPENVINO_THROW(errorPrefix, " has incorrect number of output edges"); } -bool Convert::isSupportedDesc(const MemoryDesc &desc) { +bool Convert::isSupportedDesc(const MemoryDesc& desc) { bool isSupported = desc.getType() & MemoryDescType::Blocked; if (desc.getType() == MemoryDescType::DnnlBlocked) isSupported &= desc.as()->hasEmptyExtraData(); @@ -101,13 +105,16 @@ void Convert::initSupportedPrimitiveDescriptors() { MemoryDescPtr dstMemoryDesc = config.outConfs[0].getMemDesc(); convertParams.srcPrc = srcMemoryDesc->getPrecision(); convertParams.dstPrc = dstMemoryDesc->getPrecision(); - auto factory = std::make_shared(convertParams, srcMemoryDesc, dstMemoryDesc, - std::make_shared(context, getImplPriority())); + auto factory = + std::make_shared(convertParams, + srcMemoryDesc, + dstMemoryDesc, + std::make_shared(context, getImplPriority())); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown, factory); }; - // if input and output pointers are not null and not contain extra data, then the inp/output tensor descriptors were set using 
setDescs method, so
-    // they should be used as the actual descriptors.
+    // if the input and output pointers are not null and do not contain extra data, then the input/output tensor
+    // descriptors were set using the setDescs method, so they should be used as the actual descriptors.
    if (canInitExternalDesc) {
        dataIn.setMemDesc(input);
        config.inConfs.push_back(dataIn);
@@ -142,8 +149,10 @@ void Convert::initSupportedPrimitiveDescriptors() {
                                          : BlockedDescCreator::makeFilteredRange(creators, insShape.getRank());

        for (auto itr = range.first; itr != range.second; ++itr) {
-            config.inConfs[0].setMemDesc(std::make_shared(itr->second->createDesc(insPrecision, insShape)));
-            config.outConfs[0].setMemDesc(std::make_shared(itr->second->createDesc(outPrecision, outputShape)));
+            config.inConfs[0].setMemDesc(
+                std::make_shared(itr->second->createDesc(insPrecision, insShape)));
+            config.outConfs[0].setMemDesc(
+                std::make_shared(itr->second->createDesc(outPrecision, outputShape)));

            supportedPrimitiveDescriptorsBuilder(config);
        }
@@ -159,10 +168,8 @@ void Convert::prepareParams() {
    auto selectedPD = getSelectedPrimitiveDescriptor();
    MemoryDescPtr srcDesc = getSrcMemoryAtPort(0)->getDescPtr();
    MemoryDescPtr dstDesc = getDstMemoryAtPort(0)->getDescPtr();
-    execPtr = selectedPD->getExecutorFactoryAs()->makeExecutor(convertParams,
-                                                               srcDesc,
-                                                               dstDesc,
-                                                               {});
+    execPtr =
+        selectedPD->getExecutorFactoryAs()->makeExecutor(convertParams, srcDesc, dstDesc, {});

    selectedPD->setImplementationType(execPtr->implType());
}
@@ -189,6 +196,6 @@ bool Convert::created() const {
    return getType() == Type::Convert;
}

-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/convert.h b/src/plugins/intel_cpu/src/nodes/convert.h
index 2a257bd1d31cd8..3bc911d118fd7a 100644
--- a/src/plugins/intel_cpu/src/nodes/convert.h
+++ b/src/plugins/intel_cpu/src/nodes/convert.h
@@ -14,8 +14,11 @@ namespace node {
 class Convert : public Node {
 public:
    Convert(const std::shared_ptr& op, const GraphContext::CPtr context);
-    Convert(const Shape &shape, const ov::element::Type &inPrc, const ov::element::Type &outPrc,
-            const std::string &nodeName, const GraphContext::CPtr context);
+    Convert(const Shape& shape,
+            const ov::element::Type& inPrc,
+            const ov::element::Type& outPrc,
+            const std::string& nodeName,
+            const GraphContext::CPtr context);

    void getSupportedDescriptors() override;
    void initSupportedPrimitiveDescriptors() override;
@@ -28,22 +31,28 @@ class Convert : public Node {
    }

    // This is the interface extension designed to provide inp and output tensor descriptors without the CNNLayer.
-    // In that case the Convert node is instantiated with default CNNLayer and inp/out tensor descriptors are set via this method.
-    // This is useful if the Convert node is added to the graph as an auxiliary operation at the Graph
+    // In that case the Convert node is instantiated with a default CNNLayer and the inp/out tensor descriptors are
+    // set via this method. This is useful if the Convert node is added to the graph as an auxiliary operation at the Graph
    // initialization stage.
    void setDescs(const MemoryDesc& input, const MemoryDesc& output) {
        this->input = input.clone();
        this->output = output.clone();
    }

-    const MemoryDesc& getInput() const { return *input; }
-    const MemoryDesc& getOutput() const { return *output; }
+    const MemoryDesc& getInput() const {
+        return *input;
+    }
+    const MemoryDesc& getOutput() const {
+        return *output;
+    }

-    bool needPrepareParams() const override { return inputShapesModified(); }
+    bool needPrepareParams() const override {
+        return inputShapesModified();
+    }

    static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept;
-    static bool isSupportedDesc(const MemoryDesc &desc);
+    static bool isSupportedDesc(const MemoryDesc& desc);

 private:
    MemoryDescPtr input;
@@ -55,6 +64,6 @@ class Convert : public Node {
    std::string errorPrefix;
 };

-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
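For context, a minimal usage sketch of the setDescs() extension point described in the comment above. This is illustrative only and not part of the patch; `ctx`, `inDesc`, and `outDesc` are assumed to be provided by the surrounding graph-initialization code:

    // Hypothetical: insert an auxiliary Convert node while building the Graph.
    Shape shape(VectorDims{1, 3, 224, 224});
    auto convert = std::make_shared<node::Convert>(shape,
                                                   ov::element::f32,  // input precision
                                                   ov::element::f16,  // output precision
                                                   "aux_convert",
                                                   ctx);
    convert->setDescs(inDesc, outDesc);  // these become the actual descriptors
    convert->initSupportedPrimitiveDescriptors();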
diff --git a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.cpp b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.cpp
index 0b467fe452e061..2869d782cdb445 100644
--- a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.cpp
+++ b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.cpp
@@ -2,18 +2,20 @@
 // SPDX-License-Identifier: Apache-2.0
 //

+#include "openvino/op/ctc_greedy_decoder.hpp"
+
 #include
 #include

-#include "openvino/op/ctc_greedy_decoder.hpp"
-#include "openvino/core/parallel.hpp"
 #include "ctc_greedy_decoder.h"
+#include "openvino/core/parallel.hpp"

 namespace ov {
 namespace intel_cpu {
 namespace node {

-bool CTCGreedyDecoder::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept {
+bool CTCGreedyDecoder::isSupportedOperation(const std::shared_ptr& op,
+                                            std::string& errorMessage) noexcept {
    try {
        const auto greedyDecOp = ov::as_type_ptr(op);
        if (!greedyDecOp) {
@@ -61,8 +63,7 @@ void CTCGreedyDecoder::initSupportedPrimitiveDescriptors() {
    if (!one_of(seqLenPrecision, ov::element::f32, ov::element::bf16, ov::element::f16))
        OPENVINO_THROW(errorPrefix, "has unsupported 'sequence_length' input precision: ", seqLenPrecision);

-    addSupportedPrimDesc({{LayoutType::ncsp, ov::element::f32},
-                          {LayoutType::ncsp, ov::element::f32}},
+    addSupportedPrimDesc({{LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::f32}},
                         {{LayoutType::ncsp, ov::element::f32}},
                         impl_desc_type::ref_any);
 }
@@ -141,7 +142,7 @@ void CTCGreedyDecoder::execute(dnnl::stream strm) {
                }
                tStart = 0lu;
            }
-    }; // thread body
+    };  // thread body

    parallel_nt(0, threadBody);

@@ -151,8 +152,7 @@ void CTCGreedyDecoder::execute(dnnl::stream strm) {
        const size_t sequenceLength = sequenceLengths[b];
        float* shiftedOut = outputSequences + b * T;
        for (size_t t = 0; t < sequenceLength; ++t) {
-            if (*shiftedOut < blankIndex &&
-                !(mergeRepeated && *shiftedOut == prevClassIdx)) {
+            if (*shiftedOut < blankIndex && !(mergeRepeated && *shiftedOut == prevClassIdx)) {
                outputSequences[outputIndex++] = *shiftedOut;
            }
            prevClassIdx = *shiftedOut;
@@ -174,6 +174,6 @@ bool CTCGreedyDecoder::needPrepareParams() const {
    return false;
 }

-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.h b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.h
index 1f3179edb904d2..a552ff7db3c566 100644
--- a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.h
+++ b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder.h @@ 
-14,7 +14,7 @@ class CTCGreedyDecoder : public Node { public: CTCGreedyDecoder(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -22,6 +22,7 @@ class CTCGreedyDecoder : public Node { bool needPrepareParams() const override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + private: const size_t DATA_INDEX = 0lu; const size_t SEQUENCE_LENGTH_INDEX = 1lu; @@ -30,6 +31,6 @@ class CTCGreedyDecoder : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.cpp b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.cpp index 63db3968094c3a..3eb02f2583e551 100644 --- a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.cpp +++ b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.cpp @@ -2,18 +2,20 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "ctc_greedy_decoder_seq_len.h" + +#include #include #include -#include #include "openvino/core/parallel.hpp" -#include "ctc_greedy_decoder_seq_len.h" namespace ov { namespace intel_cpu { namespace node { -bool CTCGreedyDecoderSeqLen::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool CTCGreedyDecoderSeqLen::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto greedyDecOp = ov::as_type_ptr(op); if (!greedyDecOp) { @@ -67,33 +69,35 @@ void CTCGreedyDecoderSeqLen::initSupportedPrimitiveDescriptors() { inDataConf.emplace_back(LayoutType::ncsp, ov::element::i32); addSupportedPrimDesc(inDataConf, - {{LayoutType::ncsp, ov::element::i32}, - {LayoutType::ncsp, ov::element::i32}}, + {{LayoutType::ncsp, ov::element::i32}, {LayoutType::ncsp, ov::element::i32}}, impl_desc_type::ref_any); } void CTCGreedyDecoderSeqLen::execute(dnnl::stream strm) { const float* probabilities = getSrcDataAtPortAs(DATA_INDEX); const int* sequenceLengths = getSrcDataAtPortAs(SEQUENCE_LENGTH_INDEX); - int* decodedClasses = getDstDataAtPortAs(DECODED_CLASSES_INDEX); + int* decodedClasses = getDstDataAtPortAs(DECODED_CLASSES_INDEX); int* decodedClassesLength = getDstDataAtPortAs(DECODED_CLASSES_LENGTH_INDEX); - const size_t B = getParentEdgeAt(DATA_INDEX)->getMemory().getStaticDims()[0];; - const size_t T = getParentEdgeAt(DATA_INDEX)->getMemory().getStaticDims()[1];; - const int C = getParentEdgeAt(DATA_INDEX)->getMemory().getStaticDims()[2];; + const size_t B = getParentEdgeAt(DATA_INDEX)->getMemory().getStaticDims()[0]; + ; + const size_t T = getParentEdgeAt(DATA_INDEX)->getMemory().getStaticDims()[1]; + ; + const int C = getParentEdgeAt(DATA_INDEX)->getMemory().getStaticDims()[2]; + ; const size_t TC = T * C; int blankIndex = C - 1; if (inputShapes.size() > BLANK_INDEX) - blankIndex = (getSrcDataAtPortAs(BLANK_INDEX))[0]; + blankIndex = (getSrcDataAtPortAs(BLANK_INDEX))[0]; size_t workAmount = 0; for (size_t b = 0; b < B; b++) { if (sequenceLengths[b] > static_cast(T)) { - std::string errorMsg = errorPrefix - + ". 
Sequence length " + std::to_string(sequenceLengths[b]) - + " cannot be greater than according decoded classes dimension size " - + std::to_string(getChildEdgeAt(DECODED_CLASSES_INDEX)->getMemory().getStaticDims()[1]); + std::string errorMsg = + errorPrefix + ". Sequence length " + std::to_string(sequenceLengths[b]) + + " cannot be greater than according decoded classes dimension size " + + std::to_string(getChildEdgeAt(DECODED_CLASSES_INDEX)->getMemory().getStaticDims()[1]); OPENVINO_THROW(errorMsg); } workAmount += sequenceLengths[b]; @@ -142,7 +146,7 @@ void CTCGreedyDecoderSeqLen::execute(dnnl::stream strm) { } tStart = 0lu; } - }; // thread body + }; // thread body parallel_nt(0, threadBody); @@ -153,8 +157,7 @@ void CTCGreedyDecoderSeqLen::execute(dnnl::stream strm) { int* shiftedOut = decodedClasses + b * T; for (size_t t = 0; t < actualSeqLen; ++t) { - if (*shiftedOut != blankIndex && - !(mergeRepeated && *shiftedOut == prevClassIdx)) { + if (*shiftedOut != blankIndex && !(mergeRepeated && *shiftedOut == prevClassIdx)) { decodedClasses[outputIndex++] = *shiftedOut; } prevClassIdx = *shiftedOut; @@ -177,6 +180,6 @@ bool CTCGreedyDecoderSeqLen::needPrepareParams() const { return false; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.h b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.h index 4e7d14fd23556a..95ab8ef84b07eb 100644 --- a/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.h +++ b/src/plugins/intel_cpu/src/nodes/ctc_greedy_decoder_seq_len.h @@ -14,7 +14,7 @@ class CTCGreedyDecoderSeqLen : public Node { public: CTCGreedyDecoderSeqLen(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -34,6 +34,6 @@ class CTCGreedyDecoderSeqLen : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp b/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp index 6d09b0aea7e934..0ed3d95503eb62 100644 --- a/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp +++ b/src/plugins/intel_cpu/src/nodes/ctc_loss.cpp @@ -2,11 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/ctc_loss.hpp" + #include -#include "openvino/op/ctc_loss.hpp" -#include "openvino/core/parallel.hpp" #include "ctc_loss.h" +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { @@ -53,9 +54,7 @@ void CTCLoss::initSupportedPrimitiveDescriptors() { for (size_t i = 1; i < inputShapes.size(); ++i) inDataConf.emplace_back(LayoutType::ncsp, ov::element::i32); - addSupportedPrimDesc(inDataConf, - {{LayoutType::ncsp, ov::element::f32}}, - impl_desc_type::ref_any); + addSupportedPrimDesc(inDataConf, {{LayoutType::ncsp, ov::element::f32}}, impl_desc_type::ref_any); } void CTCLoss::executeDynamicImpl(dnnl::stream strm) { @@ -71,7 +70,7 @@ void CTCLoss::execute(dnnl::stream strm) { const int* labelsLength = getSrcDataAtPortAs(3); float* dstData = getDstDataAtPortAs(0); - const auto &inDims = getParentEdgeAt(0)->getMemory().getStaticDims(); + const auto& inDims = getParentEdgeAt(0)->getMemory().getStaticDims(); const size_t batchNum = 
inDims[0];
    const size_t maxTime = inDims[1];
    const size_t classesNum = inDims[2];
@@ -96,11 +95,11 @@ void CTCLoss::execute(dnnl::stream strm) {
        for (size_t b = start; b < end; b++) {
            if (logitsLength[b] < 0 || labelsLength[b] < 0 || logitsLength[b] > static_cast(maxTime) ||
                labelsLength[b] > logitsLength[b]) {
-                errorMsgB[ithr] = errorPrefix + ". Logit length cannot be greater than max sequence length. "
-                    + "Label length cannot be greater than a logit length"
-                    + " and both cannot be negative.\nMaxSeqLen: "
-                    + std::to_string(maxTime) + "; Logit len: " + std::to_string(logitsLength[b])
-                    + "; Label len: " + std::to_string(labelsLength[b]);
+                errorMsgB[ithr] = errorPrefix + ". Logit length cannot be greater than max sequence length. " +
+                                  "Label length cannot be greater than a logit length" +
+                                  " and both cannot be negative.\nMaxSeqLen: " + std::to_string(maxTime) +
+                                  "; Logit len: " + std::to_string(logitsLength[b]) +
+                                  "; Label len: " + std::to_string(labelsLength[b]);
                returnCode = -1;
                return;
            }
@@ -151,8 +150,8 @@ void CTCLoss::execute(dnnl::stream strm) {
            for (size_t ll = 0; ll < actualLogitLen; ll++) {
                logProbabilities[ll].resize(decodedTargetLen);
            }
-        } // for batch
-    }; // threadBody_1
+        }  // for batch
+    };     // threadBody_1
    parallel_nt(threads_num, threadBody_1);
    if (returnCode != 0) {
@@ -211,7 +210,7 @@ void CTCLoss::execute(dnnl::stream strm) {
            }
            sT = 0lu;
        } // for batch
-    }; // threadBody_2
+    };    // threadBody_2

    parallel_nt(0, threadBody_2);

@@ -236,8 +235,8 @@ void CTCLoss::execute(dnnl::stream strm) {
        if (start >= end)
            return;

-        // As per Connectionist Temporal Classification - Labeling Unsegmented Sequence Data with Recurrent Neural Networks:
-        // Graves et al., 2016, paragraph 4.1 (10)
+        // As per Connectionist Temporal Classification - Labelling Unsegmented Sequence Data with Recurrent Neural
+        // Networks: Graves et al., 2006, paragraph 4.1 (10)
        for (size_t b = start; b < end; b++) {
            auto& targetD = targetDB[b];
            auto& logProbabilities = logProbabilitiesB[b];
@@ -250,21 +249,19 @@ void CTCLoss::execute(dnnl::stream strm) {
            for (int t = actualLogitLen - 2; t >= 0; t--) {
                const int t_1 = t + 1;
                for (int s = std::max(0, decodedTargetLen - (2 * (actualLogitLen - t)));
-                     s < std::min(decodedTargetLen, 2 * (t_1)); s++) {
+                     s < std::min(decodedTargetLen, 2 * (t_1));
+                     s++) {
                    if (ctcMergeRepeated || targetD[s] == blankIndex) {
-                        logBwd[s][t] = sumLogs(logBwd[s][t],
-                                               logBwd[s][t_1] + logProbabilities[t_1][s]);
+                        logBwd[s][t] = sumLogs(logBwd[s][t], logBwd[s][t_1] + logProbabilities[t_1][s]);
                    }

                    if (s + 1 < decodedTargetLen) {
-                        logBwd[s][t] = sumLogs(logBwd[s][t],
-                                               logBwd[s + 1][t_1] + logProbabilities[t_1][s + 1]);
+                        logBwd[s][t] = sumLogs(logBwd[s][t], logBwd[s + 1][t_1] + logProbabilities[t_1][s + 1]);
                    }

                    if (s + 2 < decodedTargetLen) {
                        if (targetD[s] != blankIndex && (!ctcMergeRepeated || (targetD[s] != targetD[s + 2]))) {
-                            logBwd[s][t] = sumLogs(logBwd[s][t],
-                                                   logBwd[s + 2][t_1] + logProbabilities[t_1][s + 2]);
+                            logBwd[s][t] = sumLogs(logBwd[s][t], logBwd[s + 2][t_1] + logProbabilities[t_1][s + 2]);
                        }
                    }
                }
@@ -274,8 +271,8 @@ void CTCLoss::execute(dnnl::stream strm) {
            logBwd[1][0] += logProbabilities[0][(decodedTargetLen > 1) ? 1 : 0];

            dstData[b] = -sumLogs(logBwd[0][0], logBwd[1][0]);
-        } // for batch
-    }; // threadBody_3
+        }  // for batch
+    };     // threadBody_3

    parallel_nt(0, threadBody_3);
 }
@@ -284,6 +281,6 @@ bool CTCLoss::created() const {
    return getType() == Type::CTCLoss;
 }

-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
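For orientation, the backward loop above is, in log space, the CTC backward recursion of the cited paper, with sumLogs playing the role of log-sum-exp (written $\oplus$ here). A sketch in the code's notation, where logProbabilities[t][s] holds $\log y^{t}_{l'_s}$ over the blank-extended target $l'$ (this is my reading of the loop, not text from the patch):

$$\log\beta_t(s) = \bigl(\log\beta_{t+1}(s) + \log y^{t+1}_{l'_s}\bigr) \oplus \bigl(\log\beta_{t+1}(s+1) + \log y^{t+1}_{l'_{s+1}}\bigr) \oplus \bigl(\log\beta_{t+1}(s+2) + \log y^{t+1}_{l'_{s+2}}\bigr),$$

where the first term is kept only when repeats are merged or $l'_s$ is the blank, and the last term only when $l'_s$ is not the blank and, in merge-repeated mode, $l'_s \neq l'_{s+2}$.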
diff --git a/src/plugins/intel_cpu/src/nodes/ctc_loss.h b/src/plugins/intel_cpu/src/nodes/ctc_loss.h
index a07d8f0fc59479..d1a66df3b92b89 100644
--- a/src/plugins/intel_cpu/src/nodes/ctc_loss.h
+++ b/src/plugins/intel_cpu/src/nodes/ctc_loss.h
@@ -14,7 +14,7 @@ class CTCLoss : public Node {
 public:
    CTCLoss(const std::shared_ptr& op, const GraphContext::CPtr context);

-    void getSupportedDescriptors() override {};
+    void getSupportedDescriptors() override{};
    void initSupportedPrimitiveDescriptors() override;
    void execute(dnnl::stream strm) override;
    bool created() const override;
@@ -22,7 +22,9 @@ class CTCLoss : public Node {
    static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept;

    void executeDynamicImpl(dnnl::stream strm) override;
-    bool needPrepareParams() const override { return false; };
+    bool needPrepareParams() const override {
+        return false;
+    };

 private:
    bool ctcMergeRepeated;
@@ -32,6 +34,6 @@ class CTCLoss : public Node {
    std::string errorPrefix;
 };

-}   // namespace node
-}   // namespace intel_cpu
-}   // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/cum_sum.cpp b/src/plugins/intel_cpu/src/nodes/cum_sum.cpp
index e411283e661585..43e69e29916430 100644
--- a/src/plugins/intel_cpu/src/nodes/cum_sum.cpp
+++ b/src/plugins/intel_cpu/src/nodes/cum_sum.cpp
@@ -3,15 +3,16 @@
 //

 #include "cum_sum.h"
+
+#include
+#include
+
 #include "openvino/core/parallel.hpp"
 #include "openvino/core/type/float16.hpp"
 #include "openvino/opsets/opset1.hpp"
 #include "openvino/opsets/opset3.hpp"
 #include "utils/bfloat16.hpp"

-#include
-#include
-
 namespace ov {
 namespace intel_cpu {
 namespace node {
@@ -38,10 +39,11 @@ CumSum::CumSum(const std::shared_ptr& op, const GraphContext::CPtr con
    errorPrefix = "CumSum layer with name '" + op->get_friendly_name() + "' ";

-    if ((getOriginalInputsNumber() != numOfInputs && getOriginalInputsNumber() != (numOfInputs - 1)) || getOriginalOutputsNumber() != 1)
+    if ((getOriginalInputsNumber() != numOfInputs && getOriginalInputsNumber() != (numOfInputs - 1)) ||
+        getOriginalOutputsNumber() != 1)
        OPENVINO_THROW(errorPrefix, " has incorrect number of input/output edges!");

-    const auto &dataShape = getInputShapeAtPort(CUM_SUM_DATA);
+    const auto& dataShape = getInputShapeAtPort(CUM_SUM_DATA);
    numOfDims = dataShape.getRank();
    if (numOfDims < 1) {
        OPENVINO_THROW(errorPrefix, " doesn't support 'data' input tensor with rank: ", numOfDims);
@@ -70,13 +72,19 @@ void CumSum::initSupportedPrimitiveDescriptors() {
    dataPrecision = getOriginalInputPrecisionAtPort(CUM_SUM_DATA);
    if (!one_of(dataPrecision,
-                ov::element::i8, ov::element::u8,
-                ov::element::i16, ov::element::i32, ov::element::i64, ov::element::u64,
-                ov::element::bf16, ov::element::f16, ov::element::f32))
+                ov::element::i8,
+                ov::element::u8,
+                ov::element::i16,
+                ov::element::i32,
+                ov::element::i64,
+                ov::element::u64,
+                ov::element::bf16,
+                ov::element::f16,
+                ov::element::f32))
        OPENVINO_THROW(errorPrefix, " has unsupported 'data' input precision: ", dataPrecision.get_type_name());

    if (inputShapes.size() == numOfInputs) { -
const auto &axisTensorPrec = getOriginalInputPrecisionAtPort(AXIS); + const auto& axisTensorPrec = getOriginalInputPrecisionAtPort(AXIS); if (axisTensorPrec != ov::element::i32 && axisTensorPrec != ov::element::i64) OPENVINO_THROW(errorPrefix, " has unsupported 'axis' input precision: ", axisTensorPrec.get_type_name()); } @@ -87,16 +95,17 @@ void CumSum::initSupportedPrimitiveDescriptors() { for (size_t i = 1; i < inputShapes.size(); ++i) inDataConf.emplace_back(LayoutType::ncsp, ov::element::i32); - addSupportedPrimDesc(inDataConf, - {{LayoutType::ncsp, dataPrecision}}, - impl_desc_type::ref_any); + addSupportedPrimDesc(inDataConf, {{LayoutType::ncsp, dataPrecision}}, impl_desc_type::ref_any); } void CumSum::execute(dnnl::stream strm) { if (inputShapes.size() == numOfInputs) axis = getAxis(getParentEdgeAt(AXIS)->getMemory(), getParentEdgeAt(CUM_SUM_DATA)->getMemory()); - OV_SWITCH(intel_cpu, CumSumExecute, this, dataPrecision, + OV_SWITCH(intel_cpu, + CumSumExecute, + this, + dataPrecision, OV_CASE(ov::element::i8, int8_t), OV_CASE(ov::element::u8, uint8_t), OV_CASE(ov::element::i16, int16_t), @@ -110,9 +119,10 @@ void CumSum::execute(dnnl::stream strm) { template void CumSum::exec() { - const auto *input = getSrcDataAtPortAs(CUM_SUM_DATA); - auto *output = getDstDataAtPortAs(0); - const VectorDims strides = getParentEdgeAt(CUM_SUM_DATA)->getMemory().getDescWithType()->getStrides(); + const auto* input = getSrcDataAtPortAs(CUM_SUM_DATA); + auto* output = getDstDataAtPortAs(0); + const VectorDims strides = + getParentEdgeAt(CUM_SUM_DATA)->getMemory().getDescWithType()->getStrides(); if (reverse) { if (exclusive) { @@ -130,16 +140,17 @@ void CumSum::exec() { } template -void CumSum::cumSum(const dataType *input, dataType *output, const VectorDims &strides) { +void CumSum::cumSum(const dataType* input, dataType* output, const VectorDims& strides) { VectorDims iterationRange(numOfDims - 1); size_t j = 0; - const auto &shape = getParentEdgeAt(CUM_SUM_DATA)->getMemory().getStaticDims(); + const auto& shape = getParentEdgeAt(CUM_SUM_DATA)->getMemory().getStaticDims(); for (size_t i = 0; i < shape.size(); i++) { if (i == axis) continue; iterationRange[j++] = shape[i]; } - size_t work_amount_dst = std::accumulate(iterationRange.begin(), iterationRange.end(), size_t(1), std::multiplies()); + size_t work_amount_dst = + std::accumulate(iterationRange.begin(), iterationRange.end(), size_t(1), std::multiplies()); parallel_nt(0, [&](const int ithr, const int nthr) { size_t start = 0, end = 0; VectorDims counters(numOfDims - 1, 0); @@ -159,32 +170,32 @@ void CumSum::cumSum(const dataType *input, dataType *output, const VectorDims &s size_t startOffset = getStartOffset(forStartOffset, strides); - const dataType *inputStart = input + startOffset; - dataType *outputStart = output + startOffset; + const dataType* inputStart = input + startOffset; + dataType* outputStart = output + startOffset; size_t offset = strides[axis]; if (reverse) { if (exclusive) { - outputStart[offset*(shape[axis] - 1)] = 0; + outputStart[offset * (shape[axis] - 1)] = 0; for (int64_t i = shape[axis] - 2; i >= 0; i--) { - outputStart[i*offset] = inputStart[(i+1)*offset] + outputStart[(i+1)*offset]; + outputStart[i * offset] = inputStart[(i + 1) * offset] + outputStart[(i + 1) * offset]; } } else { - outputStart[offset*(shape[axis] - 1)] = inputStart[offset * (shape[axis] - 1)]; + outputStart[offset * (shape[axis] - 1)] = inputStart[offset * (shape[axis] - 1)]; for (int64_t i = shape[axis] - 2; i >= 0; i--) { - outputStart[i*offset] = 
inputStart[i*offset] + outputStart[(i+1)*offset]; + outputStart[i * offset] = inputStart[i * offset] + outputStart[(i + 1) * offset]; } } } else { if (exclusive) { outputStart[0] = 0; for (size_t i = 1; i < shape[axis]; i++) { - outputStart[i*offset] = inputStart[(i-1)*offset] + outputStart[(i-1)*offset]; + outputStart[i * offset] = inputStart[(i - 1) * offset] + outputStart[(i - 1) * offset]; } } else { outputStart[0] = inputStart[0]; for (size_t i = 1; i < shape[axis]; i++) { - outputStart[i*offset] = inputStart[i*offset] + outputStart[(i-1)*offset]; + outputStart[i * offset] = inputStart[i * offset] + outputStart[(i - 1) * offset]; } } } @@ -219,7 +230,8 @@ inline void CumSum::parallelItStep(std::vector& counters, const std::vec } } -inline size_t CumSum::getStartOffset(const std::vector &forStartOffset, const std::vector& strides) const { +inline size_t CumSum::getStartOffset(const std::vector& forStartOffset, + const std::vector& strides) const { size_t startOffset = 0; for (size_t idx = 0; idx < forStartOffset.size(); ++idx) { startOffset += forStartOffset[idx] * strides[idx]; @@ -232,19 +244,19 @@ size_t CumSum::getAxis(const IMemory& _axis, const IMemory& _data) const { const int64_t dataShapeSize = static_cast(_data.getShape().getRank()); int64_t axisValueFromBlob = 0; switch (axisPrecision) { - case ov::element::i32 : { - const auto *axisPtr = _axis.getDataAs(); - axisValueFromBlob = static_cast(axisPtr[0]); - break; - } - case ov::element::i64 : { - const auto *axisPtr = _axis.getDataAs(); - axisValueFromBlob = axisPtr[0]; - break; - } - default : { - OPENVINO_THROW(errorPrefix, " doesn't support 'axis' input with precision: ", axisPrecision.get_type_name()); - } + case ov::element::i32: { + const auto* axisPtr = _axis.getDataAs(); + axisValueFromBlob = static_cast(axisPtr[0]); + break; + } + case ov::element::i64: { + const auto* axisPtr = _axis.getDataAs(); + axisValueFromBlob = axisPtr[0]; + break; + } + default: { + OPENVINO_THROW(errorPrefix, " doesn't support 'axis' input with precision: ", axisPrecision.get_type_name()); + } } if (axisValueFromBlob < -dataShapeSize || axisValueFromBlob > dataShapeSize - 1) OPENVINO_THROW(errorPrefix, " has axis with a value out of range: ", axisValueFromBlob); @@ -263,6 +275,6 @@ void CumSum::executeDynamicImpl(dnnl::stream strm) { execute(strm); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/cum_sum.h b/src/plugins/intel_cpu/src/nodes/cum_sum.h index b0aad351d55f93..139c7205e81fcc 100644 --- a/src/plugins/intel_cpu/src/nodes/cum_sum.h +++ b/src/plugins/intel_cpu/src/nodes/cum_sum.h @@ -14,7 +14,7 @@ class CumSum : public Node { public: CumSum(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -29,13 +29,13 @@ class CumSum : public Node { void exec(); template - void cumSum(const dataType *input, dataType *output, const std::vector &strides); + void cumSum(const dataType* input, dataType* output, const std::vector& strides); void parallelItInit(size_t start, std::vector& counters, const std::vector& iterationRange); inline void parallelItStep(std::vector& counters, const std::vector& iterationRange); - inline size_t getStartOffset(const std::vector &forStartOffset, const std::vector& 
strides) const; + inline size_t getStartOffset(const std::vector& forStartOffset, const std::vector& strides) const; size_t getAxis(const IMemory& _axis, const IMemory& _data) const; @@ -48,7 +48,7 @@ class CumSum : public Node { ov::element::Type dataPrecision; std::string errorPrefix; - template + template struct CumSumExecute { void operator()(CumSum* node) { node->exec(); @@ -56,6 +56,6 @@ class CumSum : public Node { }; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/deconv.cpp b/src/plugins/intel_cpu/src/nodes/deconv.cpp index 537f9111b0ceec..f30e3481afbb3d 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/deconv.cpp @@ -4,16 +4,15 @@ #include "deconv.h" -#include "dnnl_extension_utils.h" #include #include -#include "common/primitive_hashing_utils.hpp" #include #include -#include "cpu/x64/cpu_isa_traits.hpp" -#include "shape_inference/shape_inference.hpp" +#include "common/primitive_hashing_utils.hpp" +#include "cpu/x64/cpu_isa_traits.hpp" +#include "dnnl_extension_utils.h" #include "eltwise.h" #include "fake_quantize.h" #include "input.h" @@ -21,16 +20,16 @@ #include "openvino/core/parallel.hpp" #include "openvino/opsets/opset1.hpp" #include "openvino/runtime/make_tensor.hpp" -#include "utils/general_utils.h" +#include "shape_inference/shape_inference.hpp" #include "utils/cpu_utils.hpp" +#include "utils/general_utils.h" #if defined(OV_CPU_WITH_ACL) -#include "executors/acl/acl_utils.hpp" -#include "utils/debug_capabilities.h" +# include "executors/acl/acl_utils.hpp" +# include "utils/debug_capabilities.h" #endif #include - #include #include @@ -40,8 +39,8 @@ namespace ov { namespace intel_cpu { namespace node { -using DefaultDeconvDescs = std::pair; +using DefaultDeconvDescs = + std::pair; using Int8DeconvDesc = dnnl::deconvolution_forward::primitive_desc; namespace { @@ -92,7 +91,7 @@ size_t DeconvKey::hash() const { return seed; } -bool DeconvKey::operator==(const DeconvKey &rhs) const { +bool DeconvKey::operator==(const DeconvKey& rhs) const { bool retVal = true; if (inp0 != rhs.inp0) { retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc(); @@ -122,8 +121,8 @@ bool DeconvKey::operator==(const DeconvKey &rhs) const { } /** - * Deconvolution shape inference factory. It defines the input mask depending on the existence of the `output_shape` input. - * Since in case it exists, plugin should pass the input data to the shape inference function. + * Deconvolution shape inference factory. It defines the input mask depending on the existence of the `output_shape` + * input. Since in case it exists, plugin should pass the input data to the shape inference function. * */ class DeconfolutionShapeInferFactory : public ShapeInferFactory { @@ -134,16 +133,19 @@ class DeconfolutionShapeInferFactory : public ShapeInferFactory { const auto port_mask = (m_op->get_input_size() > 2) ? 
PortMask(2) : EMPTY_PORT_MASK; return make_shape_inference(m_op, port_mask); } + private: std::shared_ptr m_op; }; -} // namespace +} // namespace -bool Deconvolution::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool Deconvolution::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { if (std::dynamic_pointer_cast(op) == nullptr && - std::dynamic_pointer_cast(op) == nullptr) { - errorMessage = "Only opset1 ConvolutionBackpropData and GroupConvolutionBackpropData operations are supported"; + std::dynamic_pointer_cast(op) == nullptr) { + errorMessage = + "Only opset1 ConvolutionBackpropData and GroupConvolutionBackpropData operations are supported"; return false; } size_t ndims = op->get_input_partial_shape(0).rank().get_length(); @@ -151,7 +153,8 @@ bool Deconvolution::isSupportedOperation(const std::shared_ptr& errorMessage = "Only 3D, 4D and 5D blobs are supported as input"; return false; } - if (op->get_input_partial_shape(1).is_dynamic() || (op->get_input_size() > 2 && op->get_input_partial_shape(2).is_dynamic())) { + if (op->get_input_partial_shape(1).is_dynamic() || + (op->get_input_size() > 2 && op->get_input_partial_shape(2).is_dynamic())) { errorMessage = "Doesn't support dynamic shapes for 'weights' and 'output_shape' inputs"; return false; } @@ -161,8 +164,8 @@ bool Deconvolution::isSupportedOperation(const std::shared_ptr& return true; } -Deconvolution::Deconvolution(const std::shared_ptr& op, - const GraphContext::CPtr context) : Node(op, context, DeconfolutionShapeInferFactory(op)) { +Deconvolution::Deconvolution(const std::shared_ptr& op, const GraphContext::CPtr context) + : Node(op, context, DeconfolutionShapeInferFactory(op)) { std::string errorMessage; errorPrefix = "Deconvolution node with name '" + getName() + "' "; if (!isSupportedOperation(op, errorMessage)) @@ -175,7 +178,7 @@ Deconvolution::Deconvolution(const std::shared_ptr& op, IC = weightDims[0]; OC = weightDims[1]; - expectedBiasDims = {OC}; + expectedBiasDims = {OC}; groupNum = 1; withGroups = false; @@ -198,7 +201,7 @@ Deconvolution::Deconvolution(const std::shared_ptr& op, groupNum = weightDims[0]; IC = groupNum * weightDims[1]; OC = groupNum * weightDims[2]; - expectedBiasDims = {OC}; + expectedBiasDims = {OC}; withGroups = groupNum > 1; isDW = withGroups && groupNum == OC && groupNum == IC; @@ -228,8 +231,11 @@ Deconvolution::Deconvolution(const std::shared_ptr& op, lastOutputSpatialDims = ov::as_type(op->get_input_node_ptr(2))->cast_vector(); if (externOutShape && isDynamicNode()) { const auto spDimsNum = getInputShapeAtPort(0).getRank() - 2; - if (getInputShapeAtPort(2).getStaticDims()[0] != spDimsNum || (isConstOutShape && lastOutputSpatialDims.size() != spDimsNum)) { - OPENVINO_THROW(errorPrefix, "'output_shape' input has incorrect number of elements. Expected = ", spDimsNum); + if (getInputShapeAtPort(2).getStaticDims()[0] != spDimsNum || + (isConstOutShape && lastOutputSpatialDims.size() != spDimsNum)) { + OPENVINO_THROW(errorPrefix, + "'output_shape' input has incorrect number of elements. Expected = ", + spDimsNum); } } @@ -239,8 +245,10 @@ Deconvolution::Deconvolution(const std::shared_ptr& op, for (size_t i = 0; i < spatialRank; ++i) is1x1 = is1x1 && *(weightDimsReversItr++) == 1; // 1x1 deconv has some test case failed. 
The cause is upstream ONEDNN unsupported brgemm implementation cases are
-    // enabled in forked ONEDNNN https://github.com/openvinotoolkit/oneDNN/blob/117e287000b48a34a7218fcaa274a91571141728/src/common/convolution.cpp#L138.
-    // Some test cases on 1x1 kernel failed on accuracy check, current WA is disabling brgemm deconv implementation for such cases.
+    // enabled in the forked oneDNN:
+    // https://github.com/openvinotoolkit/oneDNN/blob/117e287000b48a34a7218fcaa274a91571141728/src/common/convolution.cpp#L138.
+    // Some test cases on 1x1 kernels failed the accuracy check; the current WA is to disable the brgemm deconv
+    // implementation for such cases.
    if (is1x1 && deconvAttrs.paddingL != deconvAttrs.paddingR) {
        // case1: Specify asymmetric padding explicitly
        asymmetricPaddingAnd1x1 = true;
@@ -289,7 +297,9 @@ bool Deconvolution::canBeExecutedInInt8() const {
        return false;
    if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core)) {
        const auto& inMaxDims = getOutputShapeAtPort(0).getMaxDims();
-        if (std::any_of(inMaxDims.begin(), inMaxDims.end(), [](Dim dim) { return dim == Shape::UNDEFINED_DIM; })) {
+        if (std::any_of(inMaxDims.begin(), inMaxDims.end(), [](Dim dim) {
+                return dim == Shape::UNDEFINED_DIM;
+            })) {
            return false;
        }
        // heuristicConst = 2^26
@@ -309,7 +319,8 @@ bool Deconvolution::canBeExecutedInInt8() const {
    // not supported in oneDNN
    int channelBlock = impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core) ? 16
-                       : impl::cpu::x64::mayiuse(impl::cpu::x64::avx2) ? 8 : 4;
+                       : impl::cpu::x64::mayiuse(impl::cpu::x64::avx2)     ? 8
+                                                                           : 4;
    if (withGroups && !isDW && (IC % channelBlock != 0 || OC % channelBlock != 0))
        return false;
    if (!impl::cpu::x64::mayiuse(impl::cpu::x64::avx512_core) && deconvAttrs.stride.back() > 3)
        return false;
@@ -330,16 +341,18 @@ bool Deconvolution::canBeExecutedInInt8() const {
 bool Deconvolution::canFuse(const NodePtr& node) const {
    if (canBeExecutedInInt8())
        return canFuseSimpleOperation(node);
-    // Upstream ONEDNN conv_backward_data primitive can't support any post-ops, fork onednn added depthwise support in conv_backward_data JIT implementation.
-    // ONEDNN deconv primitive can support most of post-ops, but the post-ops implementation details are different.
-    // So current deconv implementation list in onednn has 2 kinds of implements:
+    // The upstream oneDNN conv_backward_data primitive can't support any post-ops; the forked oneDNN added depthwise
+    // support in the conv_backward_data JIT implementation. The oneDNN deconv primitive supports most post-ops, but
+    // the post-ops implementation details are different. So the current deconv implementation list in oneDNN has two
+    // kinds of implementations:
    //  1. deconv implementation with JIT post-ops supported in the kernel (such as brgdeconv)
-    //  2. forked conv_data_backwards implementation with JIT depthwise post-ops + reference implementation for other post ops.
-    // Considering that some deconv fallback on the JIT implementation, we limit the post ops fusing to avoid regressions.
-    // Regression with stylegan2 int8 model pattern:
-    // none-quantzied deconv(with none-const weight) + FQ pattern fall back on JIT because of onednn limitation. (fall back ticket MFDNN-11577).
-    // If FQ is fused, it runs with the ref post-ops implementation.
-    // @todo: if onednn can ensure all the deconv run with the brgemm implementation, we can unify the fuse criteria between int8 and fp32 use cases.
+    //  2. the forked conv_data_backwards implementation with JIT depthwise post-ops plus a reference implementation
+    //     for other post-ops.
+    // Considering that some deconv cases fall back on the JIT implementation, we limit the post-ops fusing to avoid
+    // regressions. Regression with the stylegan2 int8 model pattern: the non-quantized deconv (with non-const weight) + FQ
+    // pattern falls back on JIT because of a oneDNN limitation (fallback ticket MFDNN-11577). If FQ is fused,
+    // it runs with the ref post-ops implementation.
+    // @todo: if oneDNN can ensure that all deconv cases run with the brgemm implementation, we can unify the fuse
+    // criteria between int8 and fp32 use cases.
    return (fusedWith.empty() && node->canBePerformedAsScaleShift(this));
 }
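To make the two kinds of implementations above concrete, this is roughly how a fused post-op reaches a oneDNN kernel at the API level. An illustrative sketch using generic oneDNN v3 calls, not code from this patch (the plugin builds the equivalent attributes in setPostOps() via DnnlPostOpsComposerLegacy):

    dnnl::post_ops ops;
    // e.g. a fused ReLU; implementations with JIT post-op support run this inside the kernel
    ops.append_eltwise(dnnl::algorithm::eltwise_relu, /*alpha=*/0.0f, /*beta=*/0.0f);
    dnnl::primitive_attr attr;
    attr.set_post_ops(ops);
    // attr is then passed when creating the deconvolution primitive descriptor,
    // as createDescriptorInternal() does later in this file; implementations that
    // cannot execute the chain in-kernel fall back to reference post-ops code.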
@@ -360,8 +373,10 @@ std::pair Deconvolution::makeDummyInOutShape() {
            const auto& maxDims = shape.getMaxDims();
            const auto& dims = shape.getDims();
            for (size_t i = 0; i < dims.size() - 2; ++i) {
-                lastOutputSpatialDims[i] = dims[i + 2] == Shape::UNDEFINED_DIM ? std::min(maxDims[i + 2],
-                                                                                          std::max(minDims[i + 2], static_cast(64))) : dims[i + 2];
+                lastOutputSpatialDims[i] =
+                    dims[i + 2] == Shape::UNDEFINED_DIM
+                        ? std::min(maxDims[i + 2], std::max(minDims[i + 2], static_cast(64)))
+                        : dims[i + 2];
            }
        }
@@ -380,14 +395,18 @@ std::pair Deconvolution::makeDummyInOutShape() {
            for (size_t i = 0; i < origInDims.size() - 2; i++) {
                if (origInDims[i + 2] == Shape::UNDEFINED_DIM &&
                    (origInMinDims[i + 2] != 0 || origInMaxDims[i + 2] != Shape::UNDEFINED_DIM)) {
-                    // if input shape is dynamic and bounded, paddings should be computed basing on the following limitations:
+                    // if the input shape is dynamic and bounded, paddings should be computed based on the following
+                    // limitations:
                    // 1. paddings must not be negative
-                    // 2. the result padding must have such a value to keep the dummy dimensions inside the predefined interval
+                    // 2. the result padding must take a value that keeps the dummy dimensions inside the
+                    //    predefined interval
-                    auto c1 = lastOutputSpatialDims[i] - deconvAttrs.outputPadding[i] - 1 -
-                              (deconvAttrs.dilation[i] + 1) * static_cast(weightDims[wghOffset + 2 + i] - 1);
+                    auto c1 =
+                        lastOutputSpatialDims[i] - deconvAttrs.outputPadding[i] - 1 -
+                        (deconvAttrs.dilation[i] + 1) * static_cast(weightDims[wghOffset + 2 + i] - 1);

                    if (origInMaxDims[i + 2] != Shape::UNDEFINED_DIM) {
-                        auto upper_bound = deconvAttrs.stride[i] * static_cast(origInMaxDims[i + 2] - 1) - c1;
+                        auto upper_bound =
+                            deconvAttrs.stride[i] * static_cast(origInMaxDims[i + 2] - 1) - c1;
                        if (upper_bound < 0) {
                            OPENVINO_THROW(errorPrefix, ": paddings for dummy shapes can't be computed");
                        }
@@ -403,9 +422,11 @@ std::pair Deconvolution::makeDummyInOutShape() {
            for (size_t i = 0; i < inputDims.size() - 2; i++) {
                if (origInDims[2 + i] == Shape::UNDEFINED_DIM) {
-                    inputDims[2 + i] = (lastOutputSpatialDims[i] - (deconvAttrs.dilation[i] + 1) *
-                                        (weightDims[wghOffset + 2 + i] - 1) - 1 + paddings[i] - deconvAttrs.outputPadding[i]) /
-                                       deconvAttrs.stride[i] + 1;
+                    inputDims[2 + i] = (lastOutputSpatialDims[i] -
+                                        (deconvAttrs.dilation[i] + 1) * (weightDims[wghOffset + 2 + i] - 1) - 1 +
+                                        paddings[i] - deconvAttrs.outputPadding[i]) /
+                                           deconvAttrs.stride[i] +
+                                       1;
                }
            }
        }
@@ -456,12 +477,14 @@ void Deconvolution::getSupportedDescriptors() {
    if (!descs.empty())
        return;
    isInt8 = canBeExecutedInInt8();
-    deconvAttrs.withBiasesParam = withBiases = externOutShape ? getOriginalInputsNumber() == 4 : getOriginalInputsNumber() == 3;
+    deconvAttrs.withBiasesParam = withBiases =
+        externOutShape ? 
getOriginalInputsNumber() == 4 : getOriginalInputsNumber() == 3; ov::element::Type inPrecision = getOriginalInputPrecisionAtPort(0); ov::element::Type outPrecision = getOriginalOutputPrecisionAtPort(0); if (isInt8) { - // TODO: We have to extend jit_avx512_core_x8s8s32x_deconv_fwd_kernel from oneDNN to support BF16 output data type + // TODO: We have to extend jit_avx512_core_x8s8s32x_deconv_fwd_kernel from oneDNN to support BF16 output data + // type if (ov::element::bf16 == inPrecision) inPrecision = ov::element::f32; if (ov::element::bf16 == outPrecision) @@ -475,11 +498,12 @@ void Deconvolution::getSupportedDescriptors() { auto inputDataType = DnnlExtensionUtils::ElementTypeToDataType(inPrecision); outputDataType = DnnlExtensionUtils::ElementTypeToDataType(outPrecision); if (inputDataType == memory::data_type::bf16 || outputDataType == memory::data_type::bf16) - inputDataType = outputDataType = memory::data_type::bf16; + inputDataType = outputDataType = memory::data_type::bf16; if (inputDataType == memory::data_type::f16 || outputDataType == memory::data_type::f16) - inputDataType = outputDataType = memory::data_type::f16; + inputDataType = outputDataType = memory::data_type::f16; if (!fusedWith.empty()) { - outputDataType = DnnlExtensionUtils::ElementTypeToDataType(fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0)); + outputDataType = DnnlExtensionUtils::ElementTypeToDataType( + fusedWith[fusedWith.size() - 1]->getOriginalOutputPrecisionAtPort(0)); } if (getParentEdges().size() != (withBiases ? (biasPort + 1) : biasPort)) { OPENVINO_THROW(errorPrefix, " has incorrect number of input edges"); @@ -489,7 +513,7 @@ void Deconvolution::getSupportedDescriptors() { } VectorDims inDims, outDims; std::tie(inDims, outDims) = makeDummyInOutShape(); - inShape = Shape(inDims); + inShape = Shape(inDims); outShape = Shape(outDims); initPaddingR(inShape, outShape); @@ -505,17 +529,18 @@ void Deconvolution::getSupportedDescriptors() { config.outConfs.resize(getOriginalOutputsNumber()); // ACL use same precision for all inputs config.inConfs[0].setMemDesc( - creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(0))); + creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(0))); config.inConfs[1].setMemDesc( - creatorsMap.at(weights_format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(1))); + creatorsMap.at(weights_format) + ->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(1))); for (size_t i = 2; i < getParentEdges().size(); ++i) { config.inConfs[i].setMemDesc( - creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(i))); + creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(i))); } for (size_t i = 0; i < config.outConfs.size(); ++i) { config.outConfs[i].setMemDesc( - creatorsMap.at(format)->createSharedDesc(getOriginalOutputPrecisionAtPort(0), getOutputShapeAtPort(i))); + creatorsMap.at(format)->createSharedDesc(getOriginalOutputPrecisionAtPort(0), getOutputShapeAtPort(i))); } std::vector srcMemoryDescs; @@ -532,7 +557,8 @@ void Deconvolution::getSupportedDescriptors() { return AclDeconvExecutorBuilder::customIsSupported(deconvAttrs, srcMemoryDescs, dstMemoryDescs); }; useACL = checkDesc(LayoutType::nspc) || checkDesc(LayoutType::ncsp); - if (useACL) return; + if (useACL) + return; #endif dnnlCompatibleWeiDims = getWeightDims(); // Construct the ONEDNN 
deconv OP weight shape. @@ -547,26 +573,30 @@ void Deconvolution::getSupportedDescriptors() { auto format = rank == 5 ? dnnl::memory::format_tag::ndhwc : rank == 4 ? dnnl::memory::format_tag::nhwc : dnnl::memory::format_tag::nwc; - MemoryDescPtr in_candidate = std::make_shared(getInputShapeAtPort(0), inputDataType, format); - MemoryDescPtr out_candidate = std::make_shared(getOutputShapeAtPort(0), outputDataType, format); + MemoryDescPtr in_candidate = + std::make_shared(getInputShapeAtPort(0), inputDataType, format); + MemoryDescPtr out_candidate = + std::make_shared(getOutputShapeAtPort(0), outputDataType, format); createDescriptor({in_candidate}, {out_candidate}); } else { for (auto format : getAvailableFormatsForDims(getInputShapeAtPort(0))) { - MemoryDescPtr in_candidate = std::make_shared(getInputShapeAtPort(0), inputDataType, format); - MemoryDescPtr out_candidate = std::make_shared(getOutputShapeAtPort(0), outputDataType, format); + MemoryDescPtr in_candidate = + std::make_shared(getInputShapeAtPort(0), inputDataType, format); + MemoryDescPtr out_candidate = + std::make_shared(getOutputShapeAtPort(0), outputDataType, format); createDescriptor({in_candidate}, {out_candidate}); } } } -void Deconvolution::initPaddingR(const Shape &inShape, const Shape &outShape) { +void Deconvolution::initPaddingR(const Shape& inShape, const Shape& outShape) { for (size_t i = 0; i < deconvAttrs.paddingR.size(); i++) { int with_group = getAlgorithm() == Algorithm::DeconvolutionGrouped ? 1 : 0; const auto& weightDims = getWeightDims(); int krn = weightDims[with_group + 2 + i]; int src = outShape.getStaticDims()[2 + i]; int dst = inShape.getStaticDims()[2 + i]; - krn = (krn - 1)*(deconvAttrs.dilation[i] + 1) + 1; + krn = (krn - 1) * (deconvAttrs.dilation[i] + 1) + 1; deconvAttrs.paddingR[i] = (dst - 1) * deconvAttrs.stride[i] - (src - krn + deconvAttrs.paddingL[i]); } } @@ -584,11 +614,22 @@ void Deconvolution::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dim // For deconv OP, Deconv_OC = IC, Deconv_IC = OC. // Openvino per-channel weight scales are applied on IC/Deconv_OC dimension. // So for deconvolution, - // Weight dims in NON-Group deconv: [Deconv_OC, Deconv_IC, KH, KW], perchannel weight scale is applied on Deconv_OC DIM + // Weight dims in NON-Group deconv: [Deconv_OC, Deconv_IC, KH, KW], perchannel weight scale is applied on Deconv_OC + // DIM // weiScaleMaskPerChannel = 1 << 0 - // Weight dims in Group deconv: [Group, Deconv_OC, Deconv_IC, KH, KW], perchannel weight scale is applied on GROUP and Deconv_OC, + // Weight dims in Group deconv: [Group, Deconv_OC, Deconv_IC, KH, KW], perchannel weight scale is applied on + // GROUP and Deconv_OC, // weiScaleMaskPerChannel = ( 1 << 0 | 1 << 1) = 0x03 - DnnlPostOpsComposerLegacy dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, 1, isInt8, withGroups ? 3 : 1 << 0, getDQScales(), withBiases); + DnnlPostOpsComposerLegacy dnnlpoc(getEngine(), + attr, + ops, + postOpsArgs, + dims, + 1, + isInt8, + withGroups ? 
3 : 1 << 0, + getDQScales(), + withBiases); for (size_t i = 0; i < fusedWith.size(); ++i) { auto& node = fusedWith[i]; @@ -633,7 +674,7 @@ bool Deconvolution::needShapeInfer() const { return false; } -VectorDims Deconvolution::shapeInferInternal(const VectorDims &inDims, std::vector outSpDims) const { +VectorDims Deconvolution::shapeInferInternal(const VectorDims& inDims, std::vector outSpDims) const { std::vector> inputShapesRefs{std::ref(inDims), std::ref(getWeightDims())}; std::unordered_map inputValues; VectorDims outSpDimsVecShape; @@ -678,7 +719,7 @@ void Deconvolution::execute(dnnl::stream strm) { for (size_t i = 0; i < getOriginalOutputsNumber(); i++) { dstMemory.push_back(getDstMemoryAtPort(i)); } - //TODO: need to pass post ops data + // TODO: need to pass post ops data execPtrDeconvACL->exec(srcMemory, dstMemory, nullptr); return; } @@ -696,43 +737,50 @@ void Deconvolution::execute(dnnl::stream strm) { namespace { dnnl::primitive_desc createDescriptorInternal(const dnnl::memory::desc& in_candidate, - const dnnl::memory::desc& wgh_candidate, - const dnnl::memory::desc& bias_candidate, - const dnnl::memory::desc& out_candidate, - const bool with_bias, - const std::vector& stride, - const std::vector& dilation, - const ov::CoordinateDiff& paddingL, - const ov::CoordinateDiff& paddingR, - const dnnl::primitive_attr& attr, - const dnnl::engine& engine) { - auto convertDims = [] (const std::vector& orig_dims) { + const dnnl::memory::desc& wgh_candidate, + const dnnl::memory::desc& bias_candidate, + const dnnl::memory::desc& out_candidate, + const bool with_bias, + const std::vector& stride, + const std::vector& dilation, + const ov::CoordinateDiff& paddingL, + const ov::CoordinateDiff& paddingR, + const dnnl::primitive_attr& attr, + const dnnl::engine& engine) { + auto convertDims = [](const std::vector& orig_dims) { return memory::dims(orig_dims.begin(), orig_dims.end()); }; if (with_bias) { - return dnnl::deconvolution_forward::primitive_desc( - engine, - prop_kind::forward_inference, - dnnl::algorithm::deconvolution_direct, - in_candidate, wgh_candidate, bias_candidate, out_candidate, - convertDims(stride), convertDims(dilation), - convertDims(paddingL), convertDims(paddingR), - attr); + return dnnl::deconvolution_forward::primitive_desc(engine, + prop_kind::forward_inference, + dnnl::algorithm::deconvolution_direct, + in_candidate, + wgh_candidate, + bias_candidate, + out_candidate, + convertDims(stride), + convertDims(dilation), + convertDims(paddingL), + convertDims(paddingR), + attr); } else { - return dnnl::deconvolution_forward::primitive_desc( - engine, - prop_kind::forward_inference, - dnnl::algorithm::deconvolution_direct, - in_candidate, wgh_candidate, out_candidate, - convertDims(stride), convertDims(dilation), - convertDims(paddingL), convertDims(paddingR), - attr); + return dnnl::deconvolution_forward::primitive_desc(engine, + prop_kind::forward_inference, + dnnl::algorithm::deconvolution_direct, + in_candidate, + wgh_candidate, + out_candidate, + convertDims(stride), + convertDims(dilation), + convertDims(paddingL), + convertDims(paddingR), + attr); } } -} // namespace +} // namespace -Node::AttrPtr Deconvolution::makePrimitiveAttr(const VectorDims &dims) { +Node::AttrPtr Deconvolution::makePrimitiveAttr(const VectorDims& dims) { auto attr = std::make_shared(dnnl::primitive_attr()); setPostOps(*attr, dims); @@ -747,81 +795,61 @@ Node::AttrPtr Deconvolution::initPrimitiveAttr() { const std::vector& Deconvolution::getDefaultImplPriority() { static const std::vector 
priorities { impl_desc_type::unknown, - // Undef impl type is used to express use-cases there real type is unkown during compilation - // Undef has higher priority than defined types in order to force primitive selection logic to make decision based on other properties - impl_desc_type::undef, - impl_desc_type::brgconv_avx512_amx_1x1, - impl_desc_type::brgconv_avx512_amx, - impl_desc_type::jit_avx512_amx_dw, - impl_desc_type::jit_avx512_amx_1x1, - impl_desc_type::jit_avx512_amx, - impl_desc_type::brgconv_avx512_1x1, - impl_desc_type::brgconv_avx512, - impl_desc_type::jit_avx512_dw, - impl_desc_type::jit_avx512_1x1, - impl_desc_type::jit_avx512, - impl_desc_type::brgconv_avx2_1x1, - impl_desc_type::brgconv_avx2, - impl_desc_type::jit_uni_dw, - impl_desc_type::jit_uni_1x1, - impl_desc_type::jit_uni, - impl_desc_type::jit_avx2_dw, - impl_desc_type::jit_avx2_1x1, - impl_desc_type::jit_avx2, - impl_desc_type::jit_avx_dw, - impl_desc_type::jit_avx_1x1, - impl_desc_type::jit_avx, - impl_desc_type::jit_sse42_dw, - impl_desc_type::jit_sse42_1x1, - impl_desc_type::jit_sse42, + // Undef impl type is used to express use-cases there real type is unkown during compilation + // Undef has higher priority than defined types in order to force primitive selection logic to make decision + // based on other properties + impl_desc_type::undef, impl_desc_type::brgconv_avx512_amx_1x1, impl_desc_type::brgconv_avx512_amx, + impl_desc_type::jit_avx512_amx_dw, impl_desc_type::jit_avx512_amx_1x1, impl_desc_type::jit_avx512_amx, + impl_desc_type::brgconv_avx512_1x1, impl_desc_type::brgconv_avx512, impl_desc_type::jit_avx512_dw, + impl_desc_type::jit_avx512_1x1, impl_desc_type::jit_avx512, impl_desc_type::brgconv_avx2_1x1, + impl_desc_type::brgconv_avx2, impl_desc_type::jit_uni_dw, impl_desc_type::jit_uni_1x1, + impl_desc_type::jit_uni, impl_desc_type::jit_avx2_dw, impl_desc_type::jit_avx2_1x1, + impl_desc_type::jit_avx2, impl_desc_type::jit_avx_dw, impl_desc_type::jit_avx_1x1, impl_desc_type::jit_avx, + impl_desc_type::jit_sse42_dw, impl_desc_type::jit_sse42_1x1, impl_desc_type::jit_sse42, #if defined(OPENVINO_ARCH_ARM64) - impl_desc_type::jit_asimd, + impl_desc_type::jit_asimd, #endif - impl_desc_type::gemm_any, - impl_desc_type::gemm_blas, - impl_desc_type::gemm_avx512, - impl_desc_type::gemm_avx2, - impl_desc_type::gemm_avx, - impl_desc_type::gemm_sse42, - impl_desc_type::gemm_acl, - impl_desc_type::acl, - impl_desc_type::jit_gemm, - impl_desc_type::ref_any, - impl_desc_type::ref, + impl_desc_type::gemm_any, impl_desc_type::gemm_blas, impl_desc_type::gemm_avx512, impl_desc_type::gemm_avx2, + impl_desc_type::gemm_avx, impl_desc_type::gemm_sse42, impl_desc_type::gemm_acl, impl_desc_type::acl, + impl_desc_type::jit_gemm, impl_desc_type::ref_any, impl_desc_type::ref, }; if (!asymmetricPaddingAnd1x1) return priorities; static const std::vector priorities_wo_brgemm = [&] { - std::vectorresult; - std::copy_if(priorities.begin(), priorities.end(), std::back_inserter(result), - [](impl_desc_type type) { return !(type & impl_desc_type::brgconv); }); - return result;}(); + std::vector result; + std::copy_if(priorities.begin(), priorities.end(), std::back_inserter(result), [](impl_desc_type type) { + return !(type & impl_desc_type::brgconv); + }); + return result; + }(); return priorities_wo_brgemm; } bool Deconvolution::isImplicit1x1PaddingAsymmetric(const VectorDims& inputDims) { - auto isZero = [](std::ptrdiff_t i) { return i == 0; }; + auto isZero = [](std::ptrdiff_t i) { + return i == 0; + }; size_t spatialRank = 
getInputShapeAtPort(0).getRank() - 2; - if (is1x1 && std::all_of(deconvAttrs.paddingR.begin(), deconvAttrs.paddingR.end(), isZero) - && std::all_of(deconvAttrs.paddingL.begin(), deconvAttrs.paddingL.end(), isZero) - && std::all_of(deconvAttrs.outputPadding.begin(), deconvAttrs.outputPadding.end(), isZero) - ) { - auto calPaddingEnd = [](int64_t i, int64_t o, int64_t s) -> int64_t { - // Accoriding to https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html, - // output[i] = (input[i] -1) * stride[i] - 2 x padding[i] + dilation[i] x (kernel_size[i] - 1) + output_padding[i] + 1. - // When kernel_size[i] = 1, output_padding = 0, output[i] = (input[i] -1) * stride[i] - 2 x padding[i] + 1. - // implicit padding end = 2 x padding[i] = (input[i] -1) * stride[i] + 1 - output[i] - return (i - 1) * s + 1 - o;}; - for (size_t i = 0; i < spatialRank; i++) { - int64_t inputDim = static_cast(inputDims[i + 2]); - int64_t outputDim = static_cast(lastOutputSpatialDims[i]); - int64_t stride = static_cast(deconvAttrs.stride[i]); - if (calPaddingEnd(inputDim, outputDim, stride) > 0) { - return true; - } + if (is1x1 && std::all_of(deconvAttrs.paddingR.begin(), deconvAttrs.paddingR.end(), isZero) && + std::all_of(deconvAttrs.paddingL.begin(), deconvAttrs.paddingL.end(), isZero) && + std::all_of(deconvAttrs.outputPadding.begin(), deconvAttrs.outputPadding.end(), isZero)) { + auto calPaddingEnd = [](int64_t i, int64_t o, int64_t s) -> int64_t { + // According to https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html, + // output[i] = (input[i] -1) * stride[i] - 2 x padding[i] + dilation[i] x (kernel_size[i] - 1) + + // output_padding[i] + 1. When kernel_size[i] = 1, output_padding = 0, output[i] = (input[i] -1) * stride[i] + // - 2 x padding[i] + 1. Implicit padding end = 2 x padding[i] = (input[i] -1) * stride[i] + 1 - output[i] + return (i - 1) * s + 1 - o; + }; + for (size_t i = 0; i < spatialRank; i++) { + int64_t inputDim = static_cast(inputDims[i + 2]); + int64_t outputDim = static_cast(lastOutputSpatialDims[i]); + int64_t stride = static_cast(deconvAttrs.stride[i]); + if (calPaddingEnd(inputDim, outputDim, stride) > 0) { + return true; } + } } return false; } @@ -854,8 +882,10 @@ void Deconvolution::prepareParams() { dstMemoryDescs.push_back(getChildEdgeAt(i)->getMemory().getDescWithType()); } - execPtrDeconvACL = selected_pd->getExecutorFactoryAs()->makeExecutor(deconvAttrs, srcMemoryDescs, - dstMemoryDescs, *attr); + execPtrDeconvACL = selected_pd->getExecutorFactoryAs()->makeExecutor(deconvAttrs, + srcMemoryDescs, + dstMemoryDescs, + *attr); selected_pd->setImplementationType(execPtrDeconvACL->getImplType()); return; } @@ -891,7 +921,7 @@ void Deconvolution::prepareParams() { OPENVINO_THROW("Bias memory memory is undefined."); biasDesc = biasMemPtr->getDescWithType(); } - bool is1x1PaddingAsymmetric = false; + bool is1x1PaddingAsymmetric = false; if (externOutShape && (!isConstOutShape || isDynamicNode())) { // Check implicit asymmetric padding case for dynamic case and runtime output shape. is1x1PaddingAsymmetric = isImplicit1x1PaddingAsymmetric(getSrcMemoryAtPort(0)->getShape().getStaticDims()); @@ -917,34 +947,41 @@ void Deconvolution::prepareParams() { dnnl::memory::desc dnnlBiasDesc; const auto& weiDims = key.inp1->getShape().getStaticDims(); const auto srcDataType = key.inp0->getDataType(); - const auto weiDataType = (one_of(srcDataType, memory::data_type::s8, memory::data_type::u8)) ?
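// A quick worked example of the calPaddingEnd() formula above (an illustrative sketch, not part of the
// patch): with kernel_size = 1 and zero explicit/output padding, output = (input - 1) * stride + 1, so the
// implicit padding end equals (input - 1) * stride + 1 - output. For input = 5, stride = 2 and a requested
// output of 8 this gives (5 - 1) * 2 + 1 - 8 = 1 > 0, so isImplicit1x1PaddingAsymmetric() reports true and
// the brgconv implementations are filtered out of the priority list above.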
- memory::data_type::s8 : srcDataType; + const auto weiDataType = + (one_of(srcDataType, memory::data_type::s8, memory::data_type::u8)) ? memory::data_type::s8 : srcDataType; auto wghDescAny = - dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(weiDims), - weiDataType, - memory::format_tag::any); + dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(weiDims), weiDataType, memory::format_tag::any); if (key.bias) dnnlBiasDesc = key.bias->getDnnlDesc(); - desc = createDescriptorInternal(key.inp0->getDnnlDesc(), wghDescAny, dnnlBiasDesc, key.out->getDnnlDesc(), - key.bias != nullptr, key.stride, key.dilation, key.paddingL, key.paddingR, key.attr, engine); + desc = createDescriptorInternal(key.inp0->getDnnlDesc(), + wghDescAny, + dnnlBiasDesc, + key.out->getDnnlDesc(), + key.bias != nullptr, + key.stride, + key.dilation, + key.paddingL, + key.paddingR, + key.attr, + engine); primitive_desc_iterator itpd = desc; executorPtr execPtr = nullptr; while (static_cast(itpd)) { impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str()); - //Skip the brgemm implemenation for asymmetric padding case because of the accuracy issue. + // Skip the brgemm implementation for the asymmetric padding case because of the accuracy issue. if (key.isImplicit1x1PaddingAsymmetric && (impl_type & impl_desc_type::brgconv)) continue; if (impl_type == key.implType) { auto prim_desc = deconvolution_forward::primitive_desc(itpd.get()); execPtr = std::make_shared(prim_desc, - key.inp0->getDnnlDesc(), - key.inp1->getDnnlDesc(), - key.out->getDnnlDesc(), - engine, - key.constWeight); + key.inp0->getDnnlDesc(), + key.inp1->getDnnlDesc(), + key.out->getDnnlDesc(), + engine, + key.constWeight); break; } @@ -954,16 +991,27 @@ void Deconvolution::prepareParams() { } if (!execPtr) { - auto inDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp0->getShape().getStaticDims()), - key.inp0->getDataType(), - memory::format_tag::any); - auto outDesc = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.out->getShape().getStaticDims()), - key.out->getDataType(), - memory::format_tag::any); + auto inDesc = + dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.inp0->getShape().getStaticDims()), + key.inp0->getDataType(), + memory::format_tag::any); + auto outDesc = + dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(key.out->getShape().getStaticDims()), + key.out->getDataType(), + memory::format_tag::any); dnnl::primitive_desc anyDeconvDesc; - anyDeconvDesc = createDescriptorInternal(inDesc, wghDescAny, dnnlBiasDesc, outDesc, key.bias != nullptr, - key.stride, key.dilation, key.paddingL, key.paddingR, key.attr, engine); + anyDeconvDesc = createDescriptorInternal(inDesc, + wghDescAny, + dnnlBiasDesc, + outDesc, + key.bias != nullptr, + key.stride, + key.dilation, + key.paddingL, + key.paddingR, + key.attr, + engine); if (anyDeconvDesc) { auto prim_desc = deconvolution_forward::primitive_desc(anyDeconvDesc.get()); execPtr = std::make_shared(prim_desc, @@ -983,13 +1031,12 @@ void Deconvolution::prepareParams() { auto cache = context->getParamsCache(); auto result = cache->getOrCreate(key, builder); - execPtr = result.first; if (!execPtr) OPENVINO_THROW("Primitive descriptor was not found for node ", getName(), "."); primArgs[DNNL_ARG_SRC] = srcMemPtr->getPrimitive(); - primArgs[DNNL_ARG_DST]= dstMemPtr->getPrimitive(); + primArgs[DNNL_ARG_DST] = dstMemPtr->getPrimitive(); if (weightIsConst) { // const weight preparation/reordering needs to be done once at next execution // when the input
weight data is guaranteed to be ready (considering possible const-folding @@ -1017,8 +1064,8 @@ void Deconvolution::prepareParams() { #endif } -void Deconvolution::createDescriptor(const std::vector &inputDesc, - const std::vector &outputDesc) { +void Deconvolution::createDescriptor(const std::vector& inputDesc, + const std::vector& outputDesc) { auto inDesc = inputDesc[0]->isDefined() ? inputDesc[0] : inputDesc[0]->cloneWithNewDims(inShape.getStaticDims()); auto dnnlInDesc = MemoryDescUtils::convertToDnnlBlockedMemoryDesc(*inDesc); const auto& in_candidate = dnnlInDesc.getDnnlDesc(); @@ -1039,26 +1086,38 @@ void Deconvolution::createDescriptor(const std::vector &inputDesc AttrPtr attr = initPrimitiveAttr(); if (withBiases) { memory::data_type bdt = memory::data_type::f32; - bias_candidate = dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(expectedBiasDims), bdt, memory::format_tag::any); + bias_candidate = + dnnl::memory::desc(DnnlExtensionUtils::convertToDnnlDims(expectedBiasDims), bdt, memory::format_tag::any); } - dnnl::memory::desc wgh_candidate(DnnlExtensionUtils::convertToDnnlDims(dnnlCompatibleWeiDims), isInt8 ? memory::data_type::s8 : dnnlInDesc.getDataType(), - memory::format_tag::any); - descs.emplace_back(createDescriptorInternal(in_candidate, wgh_candidate, bias_candidate, - out_candidate, withBiases, deconvAttrs.stride, deconvAttrs.dilation, - deconvAttrs.paddingL, deconvAttrs.paddingR, *attr, getEngine())); + dnnl::memory::desc wgh_candidate(DnnlExtensionUtils::convertToDnnlDims(dnnlCompatibleWeiDims), + isInt8 ? memory::data_type::s8 : dnnlInDesc.getDataType(), + memory::format_tag::any); + descs.emplace_back(createDescriptorInternal(in_candidate, + wgh_candidate, + bias_candidate, + out_candidate, + withBiases, + deconvAttrs.stride, + deconvAttrs.dilation, + deconvAttrs.paddingL, + deconvAttrs.paddingR, + *attr, + getEngine())); } -std::shared_ptr Deconvolution::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const { +std::shared_ptr Deconvolution::getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const { if (idx == 2 && !withBiases) { - //Expected dest shape; + // Expected dest shape; return std::make_shared(ov::element::i32, Shape(getInputShapeAtPort(2).getStaticDims())); } else if (idx > 0) { // weight and bias are exposed with the planar layout. 
// we need to store 'weight' input as edge, - // because at this moment we can't simple replace internal blob with input, since we need to save weight data as is, but with different order - return std::make_shared(getOriginalInputPrecisionAtPort(idx), Shape(getInputShapeAtPort(idx).getStaticDims())); + // because at this moment we can't simply replace internal blob with input, since we need to save weight data as + // is, but with different order + return std::make_shared(getOriginalInputPrecisionAtPort(idx), + Shape(getInputShapeAtPort(idx).getStaticDims())); } - //idx =0 case + // idx = 0 case auto desc = prim_desc.src_desc(idx); if (getInputShapeAtPort(idx).isDynamic()) { return DnnlExtensionUtils::makeUndefinedDesc(desc, getInputShapeAtPort(idx)); @@ -1066,8 +1125,8 @@ std::shared_ptr Deconvolution::getSrcMemDesc(const dnnl::primitive_d return DnnlExtensionUtils::makeDescriptor(desc); } -std::shared_ptr Deconvolution::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const { - auto desc = prim_desc.dst_desc(idx); +std::shared_ptr Deconvolution::getDstMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const { + auto desc = prim_desc.dst_desc(idx); if (getOutputShapeAtPort(idx).isDynamic()) { return DnnlExtensionUtils::makeUndefinedDesc(desc, getOutputShapeAtPort(idx)); } @@ -1081,7 +1140,8 @@ ov::element::Type Deconvolution::getRuntimePrecision() const { for (size_t i = 0; i < std::min(getParentEdges().size(), inputsNumLimit); i++) { auto parentEdge = getParentEdgeAt(i); if (parentEdge && parentEdge->getStatus() == Edge::Status::Validated) { - inputPrecisions.emplace_back(DnnlExtensionUtils::DataTypeToElementType((parentEdge->getMemoryPtr()->getDataType()))); + inputPrecisions.emplace_back( + DnnlExtensionUtils::DataTypeToElementType((parentEdge->getMemoryPtr()->getDataType()))); } } @@ -1089,11 +1149,12 @@ ov::element::Type Deconvolution::getRuntimePrecision() const { } Deconvolution::DeconvDNNLExecutor::DeconvDNNLExecutor(const dnnl::deconvolution_forward::primitive_desc& pd, - const dnnl::memory::desc& inMemDesc, - const dnnl::memory::desc& weightMemDesc, - const dnnl::memory::desc& outMemDesc, - const dnnl::engine& engine, - bool constWeight) : DnnlExecutor(pd) { + const dnnl::memory::desc& inMemDesc, + const dnnl::memory::desc& weightMemDesc, + const dnnl::memory::desc& outMemDesc, + const dnnl::engine& engine, + bool constWeight) + : DnnlExecutor(pd) { if (inMemDesc != getDnnlSrcDesc()) { inputReorders.insert({DNNL_ARG_SRC, IntermReorder(inMemDesc, getDnnlSrcDesc(), engine)}); } @@ -1111,7 +1172,7 @@ std::vector Deconvolution::readOutputSpatialDims() const { if (getParentEdges().size() < 3) { OPENVINO_THROW("Can't get output spatial dims. Inputs number = ", getParentEdges().size()); } - const auto &shapeMemPtr = getSrcMemoryAtPort(2); + const auto& shapeMemPtr = getSrcMemoryAtPort(2); if (!shapeMemPtr || !shapeMemPtr->isDefined()) { OPENVINO_THROW("'output_shape' input memory is undefined."); } @@ -1119,20 +1180,20 @@ std::vector Deconvolution::readOutputSpatialDims() const { if (shapeMemPtr->getStaticDims()[0] != spDimsNum) { OPENVINO_THROW("Can't read output spatial dims, beause 'output_shape' input has incorrect number of elements"); } - const int32_t *outShapePtr = shapeMemPtr->getDataAs(); + const int32_t* outShapePtr = shapeMemPtr->getDataAs(); std::vector outSpDims(outShapePtr, outShapePtr + shapeMemPtr->getStaticDims()[0]); return outSpDims; } bool Deconvolution::canFuseBias() const { - //ONEDNN deconvolution_fwd_t primitive can support bias fusing.
but has different implementations. - //For the brgdeconv implementation in the deconv list, bias is implemented via JIT kernel. - //For the fall back ref implementation entry(previous conv_backward_data), bias is implemented via reference post-ops. - //It is difficult to recognize whether the deconv will run with brg or fall back to backwards data implementation on the fusing - //transformation stage. In the end, all the deconv should run with brg implement. - //And in model zoo only limited deconv has bias or other post-ops in IR. - //Based on above, enable the bias fusing for all deconv implementations. - return (externOutShape ? getParentEdges().size() == 3 : getParentEdges().size() == 2); + // ONEDNN deconvolution_fwd_t primitive can support bias fusing, but has different implementations. + // For the brgdeconv implementation in the deconv list, bias is implemented via JIT kernel. + // For the fall back ref implementation entry (previous conv_backward_data), bias is implemented via reference + // post-ops. It is difficult to recognize whether the deconv will run with brg or fall back to backwards data + // implementation on the fusing transformation stage. In the end, all the deconv should run with the brg implementation. And + // in the model zoo only limited deconv has bias or other post-ops in IR. Based on the above, enable the bias fusing for all + // deconv implementations. + return (externOutShape ? getParentEdges().size() == 3 : getParentEdges().size() == 2); } void Deconvolution::initSupportedPrimitiveDescriptors() { @@ -1143,7 +1204,7 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { VectorDims inDims, outDims; std::tie(inDims, outDims) = makeDummyInOutShape(); - auto tmpInShape = Shape(inDims); + auto tmpInShape = Shape(inDims); auto tmpOutShape = Shape(outDims); initPaddingR(tmpInShape, tmpOutShape); @@ -1154,18 +1215,19 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { config.outConfs.resize(getOriginalOutputsNumber()); config.inConfs[0].setMemDesc( - creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(0))); + creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(0))); config.inConfs[1].setMemDesc( - creatorsMap.at(weights_format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(1))); + creatorsMap.at(weights_format) + ->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(1))); for (size_t i = 2; i < getParentEdges().size(); ++i) { config.inConfs[i].setMemDesc( - creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(i))); + creatorsMap.at(format)->createSharedDesc(getOriginalInputPrecisionAtPort(0), getInputShapeAtPort(i))); } for (size_t i = 0; i < config.outConfs.size(); ++i) { config.outConfs[i].setMemDesc( - creatorsMap.at(format)->createSharedDesc(getOriginalOutputPrecisionAtPort(0), getOutputShapeAtPort(i))); + creatorsMap.at(format)->createSharedDesc(getOriginalOutputPrecisionAtPort(0), getOutputShapeAtPort(i))); } std::vector srcMemoryDescs; @@ -1179,8 +1241,11 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()->clone()); } - auto factory = std::make_shared(deconvAttrs, srcMemoryDescs, dstMemoryDescs, - std::make_shared(context, getImplPriority())); + auto factory = + std::make_shared(deconvAttrs, + srcMemoryDescs, + dstMemoryDescs, + std::make_shared(context, getImplPriority()));
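// For reference (an illustrative note, not part of the patch): in the canFuseBias() check above, a
// Deconvolution built from {data, weights} has two parent edges, while one built from
// {data, weights, output_shape} (externOutShape == true) has three; in both cases no bias edge has
// been fused yet, so the edge-count comparison simply confirms that the bias input slot is still free.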
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::gemm_acl, factory); }; @@ -1188,7 +1253,6 @@ void Deconvolution::initSupportedPrimitiveDescriptors() { pushDesc(LayoutType::ncsp); } - -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/deconv.h b/src/plugins/intel_cpu/src/nodes/deconv.h index d94bcd8bcaca13..1c3e1fe8978918 100644 --- a/src/plugins/intel_cpu/src/nodes/deconv.h +++ b/src/plugins/intel_cpu/src/nodes/deconv.h @@ -29,27 +29,32 @@ class Deconvolution : public Node { return static_cast(getParentEdges().size()); } - std::shared_ptr getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override; - std::shared_ptr getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override; + std::shared_ptr getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const override; + std::shared_ptr getDstMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const override; ov::element::Type getRuntimePrecision() const override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; bool canFuse(const NodePtr& node) const override; - const VectorDims& getWeightDims() const { return getInputShapeAtPort(1).getStaticDims(); } - const std::vector& getStride() const { return deconvAttrs.stride; } + const VectorDims& getWeightDims() const { + return getInputShapeAtPort(1).getStaticDims(); + } + const std::vector& getStride() const { + return deconvAttrs.stride; + } void prepareParams() override; void execute(dnnl::stream strm) override; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); } + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } bool needShapeInfer() const override; bool canFuseBias() const; bool canBeExecutedInInt8() const override; const std::vector& getDefaultImplPriority() override; - protected: AttrPtr initPrimitiveAttr() override; AttrPtr makePrimitiveAttr(const VectorDims& dims); @@ -60,13 +65,13 @@ class Deconvolution : public Node { using executorPtr = std::shared_ptr; executorPtr execPtr = nullptr; class DeconvDNNLExecutor : public DnnlExecutor { - public: - DeconvDNNLExecutor(const dnnl::deconvolution_forward::primitive_desc& pd, - const dnnl::memory::desc& inMemDesc, - const dnnl::memory::desc& weightMemDesc, - const dnnl::memory::desc& outMemDesc, - const dnnl::engine& engine, - bool constWeight); + public: + DeconvDNNLExecutor(const dnnl::deconvolution_forward::primitive_desc& pd, + const dnnl::memory::desc& inMemDesc, + const dnnl::memory::desc& weightMemDesc, + const dnnl::memory::desc& outMemDesc, + const dnnl::engine& engine, + bool constWeight); }; bool isImplicit1x1PaddingAsymmetric(const VectorDims& inputDims); @@ -79,8 +84,8 @@ class Deconvolution : public Node { size_t IC = 0; size_t OC = 0; std::vector lastOutputSpatialDims; - VectorDims dnnlCompatibleWeiDims {}; - VectorDims expectedBiasDims {}; + VectorDims dnnlCompatibleWeiDims{}; + VectorDims expectedBiasDims{}; bool useACL = false; DeconvAttrs deconvAttrs; @@ -93,9 +98,9 @@ class Deconvolution : public Node { MemoryPtr dnnlCompatibleWeights = nullptr; std::shared_ptr attr; - void setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims); - VectorDims shapeInferInternal(const VectorDims &inDims, std::vector outSpDims) const; - void initPaddingR(const Shape &inShape, const Shape &outShape); + void setPostOps(dnnl::primitive_attr& attr, 
const VectorDims& dims); + VectorDims shapeInferInternal(const VectorDims& inDims, std::vector outSpDims) const; + void initPaddingR(const Shape& inShape, const Shape& outShape); std::vector readOutputSpatialDims() const; std::pair makeDummyInOutShape(); bool withBiases = false; @@ -110,6 +115,6 @@ class Deconvolution : public Node { bool isConstOutShape = false; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/def_conv.cpp b/src/plugins/intel_cpu/src/nodes/def_conv.cpp index 0167a18673c444..7c5427d0def045 100644 --- a/src/plugins/intel_cpu/src/nodes/def_conv.cpp +++ b/src/plugins/intel_cpu/src/nodes/def_conv.cpp @@ -4,21 +4,20 @@ #include "def_conv.h" -#include +#include +#include +#include #include #include -#include -#include "openvino/core/parallel.hpp" -#include "memory_desc/dnnl_blocked_memory_desc.h" #include "common/primitive_hashing_utils.hpp" -#include "openvino/util/pp.hpp" - -#include "dnnl_types.h" -#include "dnnl_extension_utils.h" #include "cpu/x64/jit_generator.hpp" -#include +#include "dnnl_extension_utils.h" +#include "dnnl_types.h" +#include "memory_desc/dnnl_blocked_memory_desc.h" +#include "openvino/core/parallel.hpp" +#include "openvino/util/pp.hpp" using namespace dnnl; using namespace dnnl::impl; @@ -30,7 +29,7 @@ namespace ov { namespace intel_cpu { namespace node { #if defined(OPENVINO_ARCH_X86_64) -#define GET_OFF(field) offsetof(jit_def_conv_call_args, field) +# define GET_OFF(field) offsetof(jit_def_conv_call_args, field) template struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_generator { @@ -38,7 +37,9 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ constexpr static int sampledPointsPerPixel = DeformableConvolution::sampledPointsPerPixel; - explicit jit_uni_def_conv_kernel_f32(const jit_def_conv_params& jcp) : jit_uni_def_conv_kernel(jcp), jit_generator(jit_name()) {} + explicit jit_uni_def_conv_kernel_f32(const jit_def_conv_params& jcp) + : jit_uni_def_conv_kernel(jcp), + jit_generator(jit_name()) {} void create_ker() override { jit_generator::create_kernel(); @@ -72,8 +73,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ } private: - using Vmm = typename conditional3::type; + using Vmm = + typename conditional3::type; const int vlen = cpu_isa_traits::vlen; using Ymm = const Xbyak::Ymm; @@ -113,18 +114,29 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Xbyak::Opmask ktail_mask = Xbyak::Opmask(2); - inline Xbyak::Address table_val(int index) - { return ptr[reg_table + index * vlen]; } + inline Xbyak::Address table_val(int index) { + return ptr[reg_table + index * vlen]; + } - inline Vmm get_vmm_ker(int idx) { return Vmm(idx + 0); } - inline Vmm get_vmm_src(int idx) { return Vmm(idx + 1); } - inline Vmm get_vmm_acc(int idx) { return Vmm(idx + jcp_.ur_w + 1); } - inline Ymm get_ymm_acc(int idx) { return Ymm(idx + jcp_.ur_w + 1); } - inline Xmm get_xmm_acc(int idx) { return Xmm(idx + jcp_.ur_w + 1); } + inline Vmm get_vmm_ker(int idx) { + return Vmm(idx + 0); + } + inline Vmm get_vmm_src(int idx) { + return Vmm(idx + 1); + } + inline Vmm get_vmm_acc(int idx) { + return Vmm(idx + jcp_.ur_w + 1); + } + inline Ymm get_ymm_acc(int idx) { + return Ymm(idx + jcp_.ur_w + 1); + } + inline Xmm get_xmm_acc(int idx) { + return Xmm(idx + jcp_.ur_w + 1); + } Xbyak::Label l_table; - inline void 
checkZeroWei(const Xbyak::Xmm &x1, Label &nullifyLabel) { + inline void checkZeroWei(const Xbyak::Xmm& x1, Label& nullifyLabel) { ptest(x1, x1); jz(nullifyLabel); } @@ -135,13 +147,16 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ mov(reg_ow_pos, 0); - L(ow_loop_main); { + L(ow_loop_main); + { cmp(reg_ow_pos, jcp_.ow - jcp_.ur_w); jg(ow_tail, T_NEAR); oc_loop(jcp_.ur_w); - add(reg_sampled_wei, jcp_.ur_w * jcp_.kh * jcp_.kw * sampledPointsPerPixel * jcp_.typesize_sampled_wei); // type = float - add(reg_sampled_offs, jcp_.ur_w * jcp_.kh * jcp_.kw * sampledPointsPerPixel * jcp_.typesize_sampled_offsets); // type = int + add(reg_sampled_wei, + jcp_.ur_w * jcp_.kh * jcp_.kw * sampledPointsPerPixel * jcp_.typesize_sampled_wei); // type = float + add(reg_sampled_offs, + jcp_.ur_w * jcp_.kh * jcp_.kw * sampledPointsPerPixel * jcp_.typesize_sampled_offsets); // type = int add(reg_output, jcp_.ur_w * jcp_.oc * jcp_.typesize_out); @@ -149,7 +164,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ jmp(ow_loop_main, T_NEAR); } - L(ow_tail); { + L(ow_tail); + { if (jcp_.ow % jcp_.ur_w != 0) oc_loop(jcp_.ow % jcp_.ur_w); } @@ -191,7 +207,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ for (int ic = 0; ic < ic_step; ic++) { for (int ow = 0; ow < ow_step; ow++) { Vmm vmm_src = get_vmm_src(ow); - size_t inp_off = (size_t) ow * jcp_.kh * jcp_.kw * jcp_.ic + kh * jcp_.kw * jcp_.ic + kw * jcp_.ic + ic; + size_t inp_off = + (size_t)ow * jcp_.kh * jcp_.kw * jcp_.ic + kh * jcp_.kw * jcp_.ic + kw * jcp_.ic + ic; uni_vbroadcastss(vmm_src, ptr[aux2_reg_input_buffer + inp_off * jcp_.typesize_in]); } @@ -199,10 +216,10 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ for (int r = 0; r < repeats; r++) { for (int ocb = 0; ocb < oc_blocks_step; ocb++) { Vmm vmm_ker = get_vmm_ker(0); - size_t ker_off = (size_t) ocb * jcp_.nb_ic * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block + - kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block + - kw * jcp_.ic_block * jcp_.oc_block + - ic * jcp_.oc_block + r * jcp_.oc_block / 2; + size_t ker_off = + (size_t)ocb * jcp_.nb_ic * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block + + kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block + kw * jcp_.ic_block * jcp_.oc_block + + ic * jcp_.oc_block + r * jcp_.oc_block / 2; uni_vmovups(vmm_ker, ptr[aux2_reg_kernel + ker_off * jcp_.typesize_in]); for (int ow = 0; ow < ow_step; ow++) { @@ -248,7 +265,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ init_accums(ow_step, oc_blocks_step, oc_step); - L(ic_main_loop); { + L(ic_main_loop); + { cmp(reg_ic_iter, jcp_.ic_block); jl(ic_tail, T_NEAR); @@ -259,7 +277,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ jmp(ic_main_loop, T_NEAR); } - L(ic_tail); { + L(ic_tail); + { if (jcp_.ic % jcp_.ic_block != 0) { apply_filter(ow_step, oc_blocks_step, oc_step, jcp_.ic % jcp_.ic_block); } @@ -283,7 +302,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ xor_(reg_dg_iter, reg_dg_iter); const int ic_per_def_group = jcp_.ic / jcp_.dg; - L(dg_loop); { + L(dg_loop); + { cmp(reg_dg_iter, jcp_.dg); jge(dg_loop_end, T_NEAR); @@ -326,7 +346,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Xmm xmm_w4 = Xmm(5); Xmm xmm_v1 = Xmm(2); - Xmm xmm_v2 = Xmm(3);; + Xmm xmm_v2 = Xmm(3); + ; Xmm xmm_v3 = Xmm(6); Xmm xmm_v4 = Xmm(7); @@ -341,7 +362,8 @@ struct 
jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Vmm vmm_v4 = Vmm(xmm_v4.getIdx()); // offsets computation - size_t ind_off_hh = sampledPointsPerPixel * (((size_t) kh * jcp_.kw + kw) + ow * (jcp_.kh * jcp_.kw)); + size_t ind_off_hh = + sampledPointsPerPixel * (((size_t)kh * jcp_.kw + kw) + ow * (jcp_.kh * jcp_.kw)); size_t ind_off_hl = ind_off_hh + 1; size_t ind_off_lh = ind_off_hl + 1; size_t ind_off_ll = ind_off_lh + 1; @@ -366,12 +388,16 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ jl(ic_loop_tail, T_NEAR); // check zero markers - uni_vbroadcastss(xmm_v1, dword[aux_reg_sampled_wei + ind_off_ll * jcp_.typesize_sampled_wei]); - uni_vbroadcastss(xmm_v2, dword[aux_reg_sampled_wei + ind_off_hl * jcp_.typesize_sampled_wei]); - uni_vbroadcastss(xmm_v3, dword[aux_reg_sampled_wei + ind_off_lh * jcp_.typesize_sampled_wei]); - uni_vbroadcastss(xmm_v4, dword[aux_reg_sampled_wei + ind_off_hh * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v1, + dword[aux_reg_sampled_wei + ind_off_ll * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v2, + dword[aux_reg_sampled_wei + ind_off_hl * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v3, + dword[aux_reg_sampled_wei + ind_off_lh * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v4, + dword[aux_reg_sampled_wei + ind_off_hh * jcp_.typesize_sampled_wei]); - size_t input_buffer_off = (size_t) kh * jcp_.kw * jcp_.ic + kw * jcp_.ic; + size_t input_buffer_off = (size_t)kh * jcp_.kw * jcp_.ic + kw * jcp_.ic; uni_vpmovsxdq(xmm_v1_off, xmm_v1_off); uni_vmovq(reg_tmp_64, xmm_v1_off); @@ -382,9 +408,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulps(vmm_v1, vmm_v1, vmm_w1); jmp(nullify_v1_end, T_NEAR); L(nullify_v1); - { - uni_vpxor(vmm_v1, vmm_v1, vmm_v1); - } + { uni_vpxor(vmm_v1, vmm_v1, vmm_v1); } L(nullify_v1_end); uni_vpmovsxdq(xmm_v2_off, xmm_v2_off); @@ -396,9 +420,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulps(vmm_v2, vmm_v2, vmm_w2); jmp(nullify_v2_end, T_NEAR); L(nullify_v2); - { - uni_vpxor(vmm_v2, vmm_v2, vmm_v2); - } + { uni_vpxor(vmm_v2, vmm_v2, vmm_v2); } L(nullify_v2_end); uni_vpmovsxdq(xmm_v3_off, xmm_v3_off); @@ -410,9 +432,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulps(vmm_v3, vmm_v3, vmm_w3); jmp(nullify_v3_end, T_NEAR); L(nullify_v3); - { - uni_vpxor(vmm_v3, vmm_v3, vmm_v3); - } + { uni_vpxor(vmm_v3, vmm_v3, vmm_v3); } L(nullify_v3_end); uni_vpmovsxdq(xmm_v4_off, xmm_v4_off); @@ -424,9 +444,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulps(vmm_v4, vmm_v4, vmm_w4); jmp(nullify_v4_end, T_NEAR); L(nullify_v4); - { - uni_vpxor(vmm_v4, vmm_v4, vmm_v4); - } + { uni_vpxor(vmm_v4, vmm_v4, vmm_v4); } L(nullify_v4_end); uni_vaddps(vmm_v1, vmm_v1, vmm_v2); @@ -446,12 +464,16 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ jl(loop_end, T_NEAR); // check zero markers - uni_vbroadcastss(xmm_v1, dword[aux_reg_sampled_wei + ind_off_ll * jcp_.typesize_sampled_wei]); - uni_vbroadcastss(xmm_v2, dword[aux_reg_sampled_wei + ind_off_hl * jcp_.typesize_sampled_wei]); - uni_vbroadcastss(xmm_v3, dword[aux_reg_sampled_wei + ind_off_lh * jcp_.typesize_sampled_wei]); - uni_vbroadcastss(xmm_v4, dword[aux_reg_sampled_wei + ind_off_hh * jcp_.typesize_sampled_wei]); - - size_t input_buffer_off = (size_t) kh * jcp_.kw * jcp_.ic + kw * jcp_.ic; + uni_vbroadcastss(xmm_v1, + 
dword[aux_reg_sampled_wei + ind_off_ll * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v2, + dword[aux_reg_sampled_wei + ind_off_hl * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v3, + dword[aux_reg_sampled_wei + ind_off_lh * jcp_.typesize_sampled_wei]); + uni_vbroadcastss(xmm_v4, + dword[aux_reg_sampled_wei + ind_off_hh * jcp_.typesize_sampled_wei]); + + size_t input_buffer_off = (size_t)kh * jcp_.kw * jcp_.ic + kw * jcp_.ic; uni_vpmovsxdq(xmm_v1_off, xmm_v1_off); uni_vmovq(reg_tmp_64, xmm_v1_off); imul(reg_tmp_64, reg_tmp_64, jcp_.ic * jcp_.typesize_in); @@ -461,9 +483,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulss(xmm_v1, xmm_v1, xmm_w1); jmp(nullify_v1_end_tail, T_NEAR); L(nullify_v1_tail); - { - uni_vpxor(xmm_v1, xmm_v1, xmm_v1); - } + { uni_vpxor(xmm_v1, xmm_v1, xmm_v1); } L(nullify_v1_end_tail); uni_vpmovsxdq(xmm_v2_off, xmm_v2_off); @@ -475,9 +495,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulss(xmm_v2, xmm_v2, xmm_w2); jmp(nullify_v2_end_tail, T_NEAR); L(nullify_v2_tail); - { - uni_vpxor(xmm_v2, xmm_v2, xmm_v2); - } + { uni_vpxor(xmm_v2, xmm_v2, xmm_v2); } L(nullify_v2_end_tail); uni_vpmovsxdq(xmm_v3_off, xmm_v3_off); @@ -489,9 +507,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulss(xmm_v3, xmm_v3, xmm_w3); jmp(nullify_v3_end_tail, T_NEAR); L(nullify_v3_tail); - { - uni_vpxor(xmm_v3, xmm_v3, xmm_v3); - } + { uni_vpxor(xmm_v3, xmm_v3, xmm_v3); } L(nullify_v3_end_tail); uni_vpmovsxdq(xmm_v4_off, xmm_v4_off); @@ -503,9 +519,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ uni_vmulss(xmm_v4, xmm_v4, xmm_w4); jmp(nullify_v4_end_tail, T_NEAR); L(nullify_v4_tail); - { - uni_vpxor(xmm_v4, xmm_v4, xmm_v4); - } + { uni_vpxor(xmm_v4, xmm_v4, xmm_v4); } L(nullify_v4_end_tail); uni_vaddss(xmm_v1, xmm_v1, xmm_v2); @@ -524,8 +538,10 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ } } - add(aux_reg_sampled_wei, sampledPointsPerPixel * jcp_.kh * jcp_.kw * jcp_.oh * jcp_.ow * jcp_.typesize_sampled_wei); - add(aux_reg_sampled_offs, sampledPointsPerPixel * jcp_.kh * jcp_.kw * jcp_.oh * jcp_.ow * jcp_.typesize_sampled_offsets); + add(aux_reg_sampled_wei, + sampledPointsPerPixel * jcp_.kh * jcp_.kw * jcp_.oh * jcp_.ow * jcp_.typesize_sampled_wei); + add(aux_reg_sampled_offs, + sampledPointsPerPixel * jcp_.kh * jcp_.kw * jcp_.oh * jcp_.ow * jcp_.typesize_sampled_offsets); add(aux_reg_input, ic_per_def_group * jcp_.typesize_in); add(aux2_reg_input_buffer, ic_per_def_group * jcp_.typesize_in); inc(reg_dg_iter); @@ -542,7 +558,7 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ if (jcp_.with_bias) { for (int r = 0; r < repeats; r++) { for (int ocb = 0; ocb < oc_blocks_step; ocb++) { - size_t bias_off = (size_t) ocb * jcp_.oc_block + r * jcp_.oc_block / 2; + size_t bias_off = (size_t)ocb * jcp_.oc_block + r * jcp_.oc_block / 2; uni_vmovups(Vmm(0), ptr[aux_reg_bias + bias_off * jcp_.typesize_bia]); for (int ow = 0; ow < ow_step; ow++) { @@ -560,7 +576,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ } for (int r = 0; r < repeats; r++) { - int tail_size = isa == cpu::x64::sse41 ? std::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; + int tail_size = + isa == cpu::x64::sse41 ? std::min(jcp_.oc_block / 2, oc_step - r * jcp_.oc_block / 2) : oc_step; bool is_scalar_store = isa == cpu::x64::sse41 ? 
tail_size < jcp_.oc_block / 2 : tail_size < jcp_.oc_block; if (is_scalar_store) { for (int ow = 0; ow < ow_step; ow++) { @@ -568,11 +585,11 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ Xmm xmm_dst = get_xmm_acc(r * jcp_.ur_w * jcp_.nb_oc_blocking + ow); if (isa == avx512_core) { - size_t out_off = (size_t) ow * jcp_.oc; + size_t out_off = (size_t)ow * jcp_.oc; uni_vmovups(ptr[aux_reg_output + out_off * jcp_.typesize_out], vmm_dst | ktail_mask); } else { for (int oc = 0; oc < tail_size; oc++) { - size_t out_off = (size_t) ow * jcp_.oc + oc + r * (jcp_.oc_block / 2); + size_t out_off = (size_t)ow * jcp_.oc + oc + r * (jcp_.oc_block / 2); uni_vmovq(reg_tmp_64, xmm_dst); mov(ptr[aux_reg_output + out_off * jcp_.typesize_out], reg_tmp_32); @@ -593,7 +610,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ for (int ocb = 0; ocb < oc_blocks_step; ocb++) { for (int ow = 0; ow < ow_step; ow++) { Vmm vmm_acc = get_vmm_acc(r * jcp_.ur_w * jcp_.nb_oc_blocking + ocb * ow_step + ow); - size_t out_off = (size_t) ow * jcp_.oc * jcp_.ngroups + ocb * jcp_.oc_block + r * (jcp_.oc_block / 2); + size_t out_off = + (size_t)ow * jcp_.oc * jcp_.ngroups + ocb * jcp_.oc_block + r * (jcp_.oc_block / 2); uni_vmovups(ptr[aux_reg_output + out_off * jcp_.typesize_out], vmm_acc); } } @@ -629,14 +647,17 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ mov(aux_reg_bias, reg_bias); mov(reg_oc_work, jcp_.oc); - L(oc_unrolled_loop); { + L(oc_unrolled_loop); + { cmp(reg_oc_work, jcp_.nb_oc_blocking * jcp_.oc_block); jl(oc_main_loop, T_NEAR); ic_loop(ow_step, jcp_.nb_oc_blocking, jcp_.oc_block); store_output(ow_step, jcp_.nb_oc_blocking, jcp_.oc_block); - add(aux_reg_kernel, jcp_.nb_oc_blocking * jcp_.nb_ic * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block * jcp_.typesize_in); + add(aux_reg_kernel, + jcp_.nb_oc_blocking * jcp_.nb_ic * jcp_.kh * jcp_.kw * jcp_.ic_block * jcp_.oc_block * + jcp_.typesize_in); add(aux_reg_output, jcp_.nb_oc_blocking * jcp_.oc_block * jcp_.typesize_out); add(aux_reg_bias, jcp_.nb_oc_blocking * jcp_.oc_block * jcp_.typesize_bia); sub(reg_oc_work, jcp_.nb_oc_blocking * jcp_.oc_block); @@ -644,7 +665,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ jmp(oc_unrolled_loop, T_NEAR); } - L(oc_main_loop); { + L(oc_main_loop); + { cmp(reg_oc_work, jcp_.oc_block); jl(oc_tail, T_NEAR); @@ -659,7 +681,8 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ jmp(oc_main_loop, T_NEAR); } - L(oc_tail); { + L(oc_tail); + { if (jcp_.oc % jcp_.oc_block != 0) { ic_loop(ow_step, 1, jcp_.oc % jcp_.oc_block); store_output(ow_step, 1, jcp_.oc % jcp_.oc_block); @@ -672,11 +695,12 @@ struct jit_uni_def_conv_kernel_f32 : public jit_uni_def_conv_kernel, public jit_ } }; #endif -bool DeformableConvolution::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool DeformableConvolution::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { if (!one_of(op->get_type_info(), - ov::op::v1::DeformableConvolution::get_type_info_static(), - ov::op::v8::DeformableConvolution::get_type_info_static())) { + ov::op::v1::DeformableConvolution::get_type_info_static(), + ov::op::v8::DeformableConvolution::get_type_info_static())) { errorMessage = "Node is not an instance of DeformableConvolution form the operation set v1 or v8."; return false; } @@ -721,16 +745,16 @@ size_t DefConvKey::hash() const 
{ return seed; } -bool DefConvKey::operator==(const DefConvKey &rhs) const { +bool DefConvKey::operator==(const DefConvKey& rhs) const { bool retVal = true; for (size_t i = 0; i < descVector.size(); i++) { if (descVector[i] != rhs.descVector[i]) { retVal = retVal && descVector[i] && rhs.descVector[i] && - descVector[i]->getBlockDims() == rhs.descVector[i]->getBlockDims() && - descVector[i]->getStrides() == rhs.descVector[i]->getStrides() && - descVector[i]->getOrder() == rhs.descVector[i]->getOrder() && - descVector[i]->getOffsetPaddingToData() == rhs.descVector[i]->getOffsetPaddingToData() && - descVector[i]->getOffsetPadding() == rhs.descVector[i]->getOffsetPadding(); + descVector[i]->getBlockDims() == rhs.descVector[i]->getBlockDims() && + descVector[i]->getStrides() == rhs.descVector[i]->getStrides() && + descVector[i]->getOrder() == rhs.descVector[i]->getOrder() && + descVector[i]->getOffsetPaddingToData() == rhs.descVector[i]->getOffsetPaddingToData() && + descVector[i]->getOffsetPadding() == rhs.descVector[i]->getOffsetPadding(); } } @@ -742,7 +766,7 @@ bool DefConvKey::operator==(const DefConvKey &rhs) const { return retVal; } -} // namespace +} // namespace DeformableConvolution::DeformableConvolution(const std::shared_ptr& op, const GraphContext::CPtr context) : Node(op, context, NgraphShapeInferFactory(op)) { @@ -825,13 +849,14 @@ void DeformableConvolution::initSupportedPrimitiveDescriptors() { impl_desc_type impl_type; const int simd_w = mayiuse(cpu::x64::avx512_core) ? 16 : 8; - auto &weiDims = getInputShapeAtPort(WEI_ID).getDims(); + auto& weiDims = getInputShapeAtPort(WEI_ID).getDims(); if (weiDims[1] == Shape::UNDEFINED_DIM || weiDims[0] == Shape::UNDEFINED_DIM || // 1. strict fallback, until devising of multigroup handling in common case defConvAttr.group != 1 || // 2. common fallback, except specific n_group / n_channel combinations - (defConvAttr.group != 1 && ((weiDims[1] % simd_w != 0) // in_channels_per_gr !% simd_w - || ((weiDims[0] / defConvAttr.group) % simd_w != 0)))) { // out_channels_per_gr !% simd_w + (defConvAttr.group != 1 && + ((weiDims[1] % simd_w != 0) // in_channels_per_gr !% simd_w + || ((weiDims[0] / defConvAttr.group) % simd_w != 0)))) { // out_channels_per_gr !% simd_w enforceRef = true; } else { enforceRef = false; @@ -854,41 +879,48 @@ void DeformableConvolution::initSupportedPrimitiveDescriptors() { auto dataFormat = memory::format_tag::nhwc; auto offFormat = memory::format_tag::nchw; auto weiFormat = mayiuse(avx512_core) ? 
memory::format_tag::OIhw16i16o : memory::format_tag::OIhw8i8o; - config.inConfs[DATA_ID].setMemDesc(std::make_shared(getInputShapeAtPort(DATA_ID), - memory::data_type::f32, dataFormat)); - config.inConfs[OFF_ID].setMemDesc(std::make_shared(getInputShapeAtPort(OFF_ID), - memory::data_type::f32, offFormat)); + config.inConfs[DATA_ID].setMemDesc( + std::make_shared(getInputShapeAtPort(DATA_ID), memory::data_type::f32, dataFormat)); + config.inConfs[OFF_ID].setMemDesc( + std::make_shared(getInputShapeAtPort(OFF_ID), memory::data_type::f32, offFormat)); - config.inConfs[WEI_ID].setMemDesc(std::make_shared(getInputShapeAtPort(WEI_ID), - memory::data_type::f32, weiFormat)); + config.inConfs[WEI_ID].setMemDesc( + std::make_shared(getInputShapeAtPort(WEI_ID), memory::data_type::f32, weiFormat)); if (inputsNumber > 3) { config.inConfs[MOD_ID].setMemDesc(std::make_shared(getInputShapeAtPort(MOD_ID), - memory::data_type::f32, memory::format_tag::nchw)); + memory::data_type::f32, + memory::format_tag::nchw)); } - config.outConfs[0].setMemDesc(std::make_shared(getOutputShapeAtPort(DATA_ID), - memory::data_type::f32, dataFormat)); + config.outConfs[0].setMemDesc( + std::make_shared(getOutputShapeAtPort(DATA_ID), memory::data_type::f32, dataFormat)); supportedPrimitiveDescriptors.push_back({config, impl_type}); } else { // reference implementation - config.inConfs[DATA_ID].setMemDesc(std::make_shared(getInputShapeAtPort(DATA_ID), memory::data_type::f32, + config.inConfs[DATA_ID].setMemDesc(std::make_shared(getInputShapeAtPort(DATA_ID), + memory::data_type::f32, memory::format_tag::nchw)); - config.inConfs[OFF_ID].setMemDesc(std::make_shared(getInputShapeAtPort(OFF_ID), memory::data_type::f32, + config.inConfs[OFF_ID].setMemDesc(std::make_shared(getInputShapeAtPort(OFF_ID), + memory::data_type::f32, memory::format_tag::nchw)); - config.inConfs[WEI_ID].setMemDesc(std::make_shared(getInputShapeAtPort(WEI_ID), memory::data_type::f32, + config.inConfs[WEI_ID].setMemDesc(std::make_shared(getInputShapeAtPort(WEI_ID), + memory::data_type::f32, memory::format_tag::oihw)); if (inputsNumber > 3) { - config.inConfs[MOD_ID].setMemDesc(std::make_shared(getInputShapeAtPort(MOD_ID), memory::data_type::f32, + config.inConfs[MOD_ID].setMemDesc(std::make_shared(getInputShapeAtPort(MOD_ID), + memory::data_type::f32, memory::format_tag::nchw)); } - config.outConfs[0].setMemDesc(std::make_shared(getOutputShapeAtPort(DATA_ID), memory::data_type::f32, + config.outConfs[0].setMemDesc(std::make_shared(getOutputShapeAtPort(DATA_ID), + memory::data_type::f32, memory::format_tag::nchw)); supportedPrimitiveDescriptors.push_back({config, impl_type}); } } -void DeformableConvolution::DefConvExecutor::prepareSamplingWeights( - const float* offsets, const float* modulation, bool enforceRef) { +void DeformableConvolution::DefConvExecutor::prepareSamplingWeights(const float* offsets, + const float* modulation, + bool enforceRef) { const int MB = jcp.mb; const int OH = jcp.oh; const int OW = jcp.ow; @@ -918,45 +950,45 @@ void DeformableConvolution::DefConvExecutor::prepareSamplingWeights( const int h_in = oh * KSH - padT; const int w_in = ow * KSW - padL; - const float *data_offset_ptr = offsets + mb * offStrides[0] + (dg * 2 * KH * KW) * offStrides[1]; - const float *modulation_offset_ptr = nullptr; + const float* data_offset_ptr = offsets + mb * offStrides[0] + (dg * 2 * KH * KW) * offStrides[1]; + const float* modulation_offset_ptr = nullptr; if (modulation != nullptr) { modulation_offset_ptr = modulation + mb * modStrides[0] + (dg * 
ker_size) * modStrides[1]; } for (int kh = 0; kh < KH; kh++) { for (int kw = 0; kw < KW; kw++) { - const size_t data_offset_h_index = 2 * ((size_t) kh * KW + kw) * offStrides[1] + oh * offStrides[2] + ow * offStrides[3]; - const size_t data_offset_w_index = (2 * ((size_t) kh * KW + kw) + 1) * offStrides[1] + oh * offStrides[2] + ow * offStrides[3]; + const size_t data_offset_h_index = + 2 * ((size_t)kh * KW + kw) * offStrides[1] + oh * offStrides[2] + ow * offStrides[3]; + const size_t data_offset_w_index = + (2 * ((size_t)kh * KW + kw) + 1) * offStrides[1] + oh * offStrides[2] + ow * offStrides[3]; const float offset_h = data_offset_ptr[data_offset_h_index]; const float offset_w = data_offset_ptr[data_offset_w_index]; float map_h = h_in + kh * (KDH + 1) + offset_h; float map_w = w_in + kw * (KDW + 1) + offset_w; bool skip_compute; if (with_bi_pad) { - skip_compute = !(static_cast(map_w) > -1 && - static_cast(map_w) < IW && - static_cast(map_h) > -1 && - static_cast(map_h) < IH); + skip_compute = !(static_cast(map_w) > -1 && static_cast(map_w) < IW && + static_cast(map_h) > -1 && static_cast(map_h) < IH); } else { - skip_compute = !(map_w >= 0 && map_w < IW && - map_h >= 0 && map_h < IH); + skip_compute = !(map_w >= 0 && map_w < IW && map_h >= 0 && map_h < IH); } if (!skip_compute) { // modulations precomp. float modulation_scalar = 1.0f; if (modulation_offset_ptr != nullptr) { - size_t modulation_index = (kh * KW + kw) * modStrides[1] + oh * modStrides[2] + ow * modStrides[3]; + size_t modulation_index = + (kh * KW + kw) * modStrides[1] + oh * modStrides[2] + ow * modStrides[3]; modulation_scalar = modulation_offset_ptr[modulation_index]; } // interpolation precomp. const int cur_h_end = IH; const int cur_w_end = IW; - int h_low = with_bi_pad ? static_cast(floorf(map_h)) : - std::max(static_cast(floorf(map_h)), 0); - int w_low = with_bi_pad ? static_cast(floorf(map_w)) : - std::max(static_cast(floorf(map_w)), 0); + int h_low = + with_bi_pad ? static_cast(floorf(map_h)) : std::max(static_cast(floorf(map_h)), 0); + int w_low = + with_bi_pad ? static_cast(floorf(map_w)) : std::max(static_cast(floorf(map_w)), 0); int h_high = with_bi_pad ? h_low + 1 : std::min(static_cast(ceilf(map_h)), cur_h_end - 1); int w_high = with_bi_pad ? 
w_low + 1 : std::min(static_cast(ceilf(map_w)), cur_w_end - 1); @@ -976,7 +1008,7 @@ void DeformableConvolution::DefConvExecutor::prepareSamplingWeights( const int h_off_low = h_ind_low * (srcStrides[2] / srcStrides[3]); const int h_off_high = h_ind_high * (srcStrides[2] / srcStrides[3]); - const int w_off_low = w_ind_low; + const int w_off_low = w_ind_low; const int w_off_high = w_ind_high; pSampledCoordsVector[sampledCoordIndex] = h_off_high + w_off_high; pSampledCoordsVector[sampledCoordIndex + 1] = h_off_high + w_off_low; @@ -984,7 +1016,7 @@ void DeformableConvolution::DefConvExecutor::prepareSamplingWeights( pSampledCoordsVector[sampledCoordIndex + 3] = h_off_low + w_off_low; float w22 = hh * hw * modulation_scalar, w21 = hh * lw * modulation_scalar, - w12 = lh * hw * modulation_scalar, w11 = lh * lw * modulation_scalar; + w12 = lh * hw * modulation_scalar, w11 = lh * lw * modulation_scalar; pInterpWeightsVector[sampledCoordIndex] = w11; pInterpWeightsVector[sampledCoordIndex + 1] = w12; @@ -1007,15 +1039,16 @@ void DeformableConvolution::DefConvExecutor::prepareSamplingWeights( }); } -DeformableConvolution::DefConvExecutor::DefConvExecutor(const DefConvAttr &defConvAttr, - const std::vector> &descVector) { +DeformableConvolution::DefConvExecutor::DefConvExecutor( + const DefConvAttr& defConvAttr, + const std::vector>& descVector) { if (descVector.size() != 4 && descVector.size() != 5) { OPENVINO_THROW("Deformable Convolution executor got incorrect desc's count (", descVector.size(), ")"); } bool withModulation = descVector.size() == 5; - auto &srcDesc = descVector[DATA_ID]; - auto &dstDesc = descVector[descVector.size() - 1]; + auto& srcDesc = descVector[DATA_ID]; + auto& dstDesc = descVector[descVector.size() - 1]; srcStrides = std::vector(srcDesc->getStrides().size()); offStrides = descVector[OFF_ID]->getStrides(); weiStrides = descVector[WEI_ID]->getStrides(); @@ -1085,9 +1118,10 @@ DeformableConvolution::DefConvExecutor::DefConvExecutor(const DefConvAttr &defCo jcp.nthr = dnnl_get_max_threads(); } -DeformableConvolution::DefConvJitExecutor::DefConvJitExecutor(const DefConvAttr &defConvAttr, - const std::vector> &descVector) : - DefConvExecutor(defConvAttr, descVector) { +DeformableConvolution::DefConvJitExecutor::DefConvJitExecutor( + const DefConvAttr& defConvAttr, + const std::vector>& descVector) + : DefConvExecutor(defConvAttr, descVector) { #if defined(OPENVINO_ARCH_X86_64) if (mayiuse(cpu::x64::avx512_core)) { def_conv_kernel.reset(new jit_uni_def_conv_kernel_f32(jcp)); @@ -1106,9 +1140,13 @@ DeformableConvolution::DefConvJitExecutor::DefConvJitExecutor(const DefConvAttr #endif } -void DeformableConvolution::DefConvRefExecutor::exec(const float* src, const float* offsets, - const float* weights, const float* modulation, float* dst, - int *pSampledCoordsVector, float *pInterpWeightsVector) { +void DeformableConvolution::DefConvRefExecutor::exec(const float* src, + const float* offsets, + const float* weights, + const float* modulation, + float* dst, + int* pSampledCoordsVector, + float* pInterpWeightsVector) { this->pSampledCoordsVector = pSampledCoordsVector; this->pInterpWeightsVector = pInterpWeightsVector; prepareSamplingWeights(offsets, modulation, true); @@ -1133,17 +1171,18 @@ void DeformableConvolution::DefConvRefExecutor::exec(const float* src, const flo auto compKer = [OV_CAPTURE_CPY_AND_THIS](int g, int mb, int oc, int oh, int ow) { float d = 0; for (int ic = 0; ic < IC; ic++) { - const float *data_im_ptr = src + mb * srcStrides[0] + (g * IC + ic) * 
srcStrides[1]; + const float* data_im_ptr = src + mb * srcStrides[0] + (g * IC + ic) * srcStrides[1]; const int deformable_group_index = (IC * g + ic) / channel_per_deformable_group; - int sampledCoordIndex = (mb * DGHW + deformable_group_index * HW + oh * OW + ow) * ker_size * sampledPointsPerPixel; - size_t weiIndex = (size_t) g * group_wei_stride + oc * weiStrides[0] + ic * weiStrides[1]; + int sampledCoordIndex = + (mb * DGHW + deformable_group_index * HW + oh * OW + ow) * ker_size * sampledPointsPerPixel; + size_t weiIndex = (size_t)g * group_wei_stride + oc * weiStrides[0] + ic * weiStrides[1]; for (size_t kh_off = 0; kh_off < KH * weiStrides[2]; kh_off += weiStrides[2]) { for (size_t kw_off = 0; kw_off < KW * weiStrides[3]; kw_off += weiStrides[3]) { // check if current addendum marked as equal zero if (pSampledCoordsVector[sampledCoordIndex] != -1) { const int v11 = pSampledCoordsVector[sampledCoordIndex]; const int v12 = pSampledCoordsVector[sampledCoordIndex + 1]; - const int v21 = pSampledCoordsVector[sampledCoordIndex + 2]; + const int v21 = pSampledCoordsVector[sampledCoordIndex + 2]; const int v22 = pSampledCoordsVector[sampledCoordIndex + 3]; float val = 0; @@ -1174,8 +1213,9 @@ void DeformableConvolution::DefConvRefExecutor::exec(const float* src, const flo }; parallel_nd(G, MB, OC, OH, OW, [&](dnnl_dim_t g, dnnl_dim_t mb, dnnl_dim_t oc, dnnl_dim_t oh, dnnl_dim_t ow) { - dst[mb * dstStrides[0] + (g * OC + oc) * dstStrides[1] + oh * dstStrides[2] + ow * dstStrides[3]] = compKer(g, mb, oc, oh, ow); - }); + dst[mb * dstStrides[0] + (g * OC + oc) * dstStrides[1] + oh * dstStrides[2] + ow * dstStrides[3]] = + compKer(g, mb, oc, oh, ow); + }); } void DeformableConvolution::prepareParams() { @@ -1208,22 +1248,17 @@ void DeformableConvolution::prepareParams() { updatePadding(); - std::vector> descVector { + std::vector> descVector{ getParentEdgeAt(DATA_ID)->getMemory().getDescWithType(), getParentEdgeAt(OFF_ID)->getMemory().getDescWithType(), - getParentEdgeAt(WEI_ID)->getMemory().getDescWithType() - }; + getParentEdgeAt(WEI_ID)->getMemory().getDescWithType()}; if (withModulation) { descVector.push_back(getParentEdgeAt(MOD_ID)->getMemory().getDescWithType()); } descVector.push_back(getChildEdgeAt(0)->getMemory().getDescWithType()); - DefConvKey key = { - descVector, - defConvAttr, - getSelectedPrimitiveDescriptor()->getImplementationType() - }; + DefConvKey key = {descVector, defConvAttr, getSelectedPrimitiveDescriptor()->getImplementationType()}; const int MB = getParentEdgeAt(DATA_ID)->getMemory().getStaticDims()[0]; const int OH = getChildEdgeAt(0)->getMemory().getStaticDims()[2]; @@ -1241,7 +1276,7 @@ void DeformableConvolution::prepareParams() { execPtr = nullptr; auto cache = context->getParamsCache(); - auto result = cache->getOrCreate(key, [] (const DefConvKey& key) -> std::shared_ptr { + auto result = cache->getOrCreate(key, [](const DefConvKey& key) -> std::shared_ptr { if (key.implType == impl_desc_type::ref) { return std::make_shared(key.defConvAttr, key.descVector); } @@ -1258,9 +1293,13 @@ void DeformableConvolution::executeDynamicImpl(dnnl::stream strm) { execute(strm); } -void DeformableConvolution::DefConvJitExecutor::exec(const float* src, const float* offsets, - const float* weights, const float* modulation, float* dst, - int *pSampledCoordsVector, float *pInterpWeightsVector) { +void DeformableConvolution::DefConvJitExecutor::exec(const float* src, + const float* offsets, + const float* weights, + const float* modulation, + float* dst, + int* 
pSampledCoordsVector, + float* pInterpWeightsVector) { this->pSampledCoordsVector = pSampledCoordsVector; this->pInterpWeightsVector = pInterpWeightsVector; prepareSamplingWeights(offsets, modulation, false); @@ -1276,9 +1315,11 @@ void DeformableConvolution::DefConvJitExecutor::exec(const float* src, const flo const size_t _oc = g * jcp.nb_oc; const size_t _ic = g * jcp.nb_ic; - par_conv.src = &src[n * srcStrides[0] + _ic*jcp.ic_block * srcStrides[1]]; - par_conv.sampledWei = &(pInterpWeightsVector[(n * jcp.dg * jcp.oh + oh) * jcp.kh * jcp.kw * jcp.ow * sampledPointsPerPixel]); - par_conv.sampledCoords = &(pSampledCoordsVector[(n * jcp.dg * jcp.oh + oh) * jcp.kh * jcp.kw * jcp.ow * sampledPointsPerPixel]); + par_conv.src = &src[n * srcStrides[0] + _ic * jcp.ic_block * srcStrides[1]]; + par_conv.sampledWei = + &(pInterpWeightsVector[(n * jcp.dg * jcp.oh + oh) * jcp.kh * jcp.kw * jcp.ow * sampledPointsPerPixel]); + par_conv.sampledCoords = + &(pSampledCoordsVector[(n * jcp.dg * jcp.oh + oh) * jcp.kh * jcp.kw * jcp.ow * sampledPointsPerPixel]); par_conv.filt = &weights[g * jcp.nb_oc * jcp.nb_ic * jcp.kh * jcp.kw * jcp.ic_block * jcp.oc_block]; par_conv.dst = &dst[n * dstStrides[0] + _oc * jcp.oc_block * dstStrides[1] + oh * dstStrides[2]]; par_conv.buf = input_buffer_ptr + ithr * jcp.ur_w * jcp.kh * jcp.kw * jcp.ic; @@ -1292,20 +1333,20 @@ void DeformableConvolution::DefConvJitExecutor::exec(const float* src, const flo void DeformableConvolution::execute(dnnl::stream strm) { const size_t inputsNumber = getOriginalInputsNumber(); - auto &srcMemory0 = getParentEdgeAt(0)->getMemory(); - auto &srcMemory1 = getParentEdgeAt(1)->getMemory(); - auto &srcMemory2 = getParentEdgeAt(2)->getMemory(); - auto &dstMemory = getChildEdgeAt(0)->getMemory(); + auto& srcMemory0 = getParentEdgeAt(0)->getMemory(); + auto& srcMemory1 = getParentEdgeAt(1)->getMemory(); + auto& srcMemory2 = getParentEdgeAt(2)->getMemory(); + auto& dstMemory = getChildEdgeAt(0)->getMemory(); - const auto *src = srcMemory0.getDataAs(); - const auto *offsets = srcMemory1.getDataAs(); - const auto *weights = srcMemory2.getDataAs(); + const auto* src = srcMemory0.getDataAs(); + const auto* offsets = srcMemory1.getDataAs(); + const auto* weights = srcMemory2.getDataAs(); float* modulation = nullptr; if (inputsNumber > 3) { modulation = getSrcDataAtPortAs(3); } - float *dst = dstMemory.getDataAs(); + float* dst = dstMemory.getDataAs(); auto selectedPrimitiveDescriptor = getSelectedPrimitiveDescriptor(); if (!selectedPrimitiveDescriptor) @@ -1333,6 +1374,6 @@ ov::element::Type DeformableConvolution::getRuntimePrecision() const { return getMaxPrecision(getInputPrecisions()); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/def_conv.h b/src/plugins/intel_cpu/src/nodes/def_conv.h index 127fd00eb2bf00..ed5800a19a0e84 100644 --- a/src/plugins/intel_cpu/src/nodes/def_conv.h +++ b/src/plugins/intel_cpu/src/nodes/def_conv.h @@ -5,6 +5,7 @@ #pragma once #include + #include #include #include @@ -43,20 +44,20 @@ struct jit_def_conv_params { }; struct jit_def_conv_call_args { - const void *src; - const void *sampledWei; - const void *sampledCoords; - const void *filt; - const void *bias; - const void *dst; - const void *buf; + const void* src; + const void* sampledWei; + const void* sampledCoords; + const void* filt; + const void* bias; + const void* dst; + const void* buf; size_t oh_pos; }; struct jit_uni_def_conv_kernel { - 
void (*ker_)(const jit_def_conv_call_args *); + void (*ker_)(const jit_def_conv_call_args*); - void operator()(const jit_def_conv_call_args *args) { + void operator()(const jit_def_conv_call_args* args) { assert(ker_); ker_(args); } @@ -109,53 +110,66 @@ class DeformableConvolution : public Node { static constexpr size_t MOD_ID = 3; std::string errorPrefix; class DefConvExecutor { - public: - DefConvExecutor(const DefConvAttr &defConvAttr, - const std::vector> &descVector); - - virtual void exec(const float* src, const float* offsets, - const float* weights, const float* modulation, float* dst, - int *pSampledCoordsVector, float *pInterpWeightsVector) = 0; - virtual ~DefConvExecutor() = default; - - protected: - void prepareSamplingWeights(const float* offsets, const float* modulation = nullptr, bool enforceRef = false); - jit_def_conv_params jcp = {}; - VectorDims srcStrides; - VectorDims offStrides; - VectorDims weiStrides; - VectorDims modStrides; - VectorDims dstStrides; - int *pSampledCoordsVector; - float *pInterpWeightsVector; + public: + DefConvExecutor(const DefConvAttr& defConvAttr, + const std::vector>& descVector); + + virtual void exec(const float* src, + const float* offsets, + const float* weights, + const float* modulation, + float* dst, + int* pSampledCoordsVector, + float* pInterpWeightsVector) = 0; + virtual ~DefConvExecutor() = default; + + protected: + void prepareSamplingWeights(const float* offsets, const float* modulation = nullptr, bool enforceRef = false); + jit_def_conv_params jcp = {}; + VectorDims srcStrides; + VectorDims offStrides; + VectorDims weiStrides; + VectorDims modStrides; + VectorDims dstStrides; + int* pSampledCoordsVector; + float* pInterpWeightsVector; }; class DefConvRefExecutor : public DefConvExecutor { - public: - DefConvRefExecutor(const DefConvAttr &defConvAttr, - const std::vector> &descVector) : - DefConvExecutor(defConvAttr, descVector) {} - - void exec(const float* src, const float* offsets, - const float* weights, const float* modulation, float* dst, - int *pSampledCoordsVector, float *pInterpWeightsVector) override; + public: + DefConvRefExecutor(const DefConvAttr& defConvAttr, + const std::vector>& descVector) + : DefConvExecutor(defConvAttr, descVector) {} + + void exec(const float* src, + const float* offsets, + const float* weights, + const float* modulation, + float* dst, + int* pSampledCoordsVector, + float* pInterpWeightsVector) override; }; class DefConvJitExecutor : public DefConvExecutor { - std::shared_ptr def_conv_kernel = nullptr; - public: - DefConvJitExecutor(const DefConvAttr &defConvAttr, - const std::vector> &descVector); - - void exec(const float* src, const float* offsets, - const float* weights, const float* modulation, float* dst, - int *pSampledCoordsVector, float *pInterpWeightsVector) override; + std::shared_ptr def_conv_kernel = nullptr; + + public: + DefConvJitExecutor(const DefConvAttr& defConvAttr, + const std::vector>& descVector); + + void exec(const float* src, + const float* offsets, + const float* weights, + const float* modulation, + float* dst, + int* pSampledCoordsVector, + float* pInterpWeightsVector) override; }; std::shared_ptr execPtr = nullptr; bool autoPadding = false; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/depth_to_space.cpp b/src/plugins/intel_cpu/src/nodes/depth_to_space.cpp index eb3789068adca1..a8629ce2592d76 100644 --- 
a/src/plugins/intel_cpu/src/nodes/depth_to_space.cpp +++ b/src/plugins/intel_cpu/src/nodes/depth_to_space.cpp @@ -4,16 +4,15 @@ #include "depth_to_space.h" -#include "dnnl_extension_utils.h" -#include "utils/general_utils.h" - #include -#include "common/primitive_hashing_utils.hpp" -#include "cpu/x64/jit_generator.hpp" -#include "openvino/opsets/opset1.hpp" #include #include "common/blocked_desc_creator.h" +#include "common/primitive_hashing_utils.hpp" +#include "cpu/x64/jit_generator.hpp" +#include "dnnl_extension_utils.h" +#include "openvino/opsets/opset1.hpp" +#include "utils/general_utils.h" #define THROW_ERROR(...) OPENVINO_THROW("DepthToSpace layer with name '", getName(), "' ", __VA_ARGS__) @@ -40,9 +39,8 @@ size_t DepthToSpace::DepthToSpaceAttrs::hash() const { } bool DepthToSpace::DepthToSpaceAttrs::operator==(const DepthToSpaceAttrs& rhs) const { - bool result = layoutType == rhs.layoutType && mode == rhs.mode && - blockSize == rhs.blockSize && blockStep == rhs.blockStep && - dataSize == rhs.dataSize && nSpatialDims == rhs.nSpatialDims && + bool result = layoutType == rhs.layoutType && mode == rhs.mode && blockSize == rhs.blockSize && + blockStep == rhs.blockStep && dataSize == rhs.dataSize && nSpatialDims == rhs.nSpatialDims && srcBlockedDims == rhs.srcBlockedDims; return result; @@ -56,7 +54,9 @@ bool DepthToSpace::isSupportedOperation(const std::shared_ptr& o return false; } const auto mode = depthToSpace->get_mode(); - if (!one_of(mode, ov::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, ov::op::v0::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST)) { + if (!one_of(mode, + ov::op::v0::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, + ov::op::v0::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST)) { errorMessage = "Does not support mode: " + ov::as_string(mode); return false; } @@ -138,7 +138,8 @@ void DepthToSpace::initSupportedPrimitiveDescriptors() { if (inputDataShape.getRank() > 2) { const auto& srcDims = inputDataShape.getDims(); auto canUseBlocked = [OV_CAPTURE_CPY_AND_THIS](const size_t block) { - return srcDims[1] != Shape::UNDEFINED_DIM && srcDims[1] % block == 0 && (srcDims[1] / block) % attrs.blockStep == 0 && + return srcDims[1] != Shape::UNDEFINED_DIM && srcDims[1] % block == 0 && + (srcDims[1] / block) % attrs.blockStep == 0 && (attrs.mode == Mode::DEPTH_FIRST ? block % attrs.blockStep == 0 : true); }; @@ -172,9 +173,10 @@ void DepthToSpace::createPrimitive() { const auto& memoryDesc = srcMemPtr->getDesc(); attrs.dataSize = memoryDesc.getPrecision().size(); attrs.nSpatialDims = memoryDesc.getShape().getRank() - 2; - attrs.layoutType = memoryDesc.hasLayoutType(LayoutType::nCsp16c) ? LayoutType::nCsp16c : - memoryDesc.hasLayoutType(LayoutType::nCsp8c) ? LayoutType::nCsp8c : - memoryDesc.hasLayoutType(LayoutType::nspc) ? LayoutType::nspc : LayoutType::ncsp; + attrs.layoutType = memoryDesc.hasLayoutType(LayoutType::nCsp16c) ? LayoutType::nCsp16c + : memoryDesc.hasLayoutType(LayoutType::nCsp8c) ? LayoutType::nCsp8c + : memoryDesc.hasLayoutType(LayoutType::nspc) ? 
LayoutType::nspc + : LayoutType::ncsp; if (inputShapesDefined()) { if (needPrepareParams()) @@ -205,7 +207,8 @@ DepthToSpace::DepthToSpaceExecutor::DepthToSpaceExecutor(const DepthToSpaceAttrs const bool isBlocked = one_of(attrs.layoutType, LayoutType::nCsp16c, LayoutType::nCsp8c); const bool isChannelsFirst = attrs.layoutType == LayoutType::nspc; const size_t nDims = attrs.srcBlockedDims.size(); - const size_t reshapedRank = nDims + attrs.nSpatialDims + static_cast(isBlocked && attrs.mode == Mode::DEPTH_FIRST); + const size_t reshapedRank = + nDims + attrs.nSpatialDims + static_cast(isBlocked && attrs.mode == Mode::DEPTH_FIRST); const size_t lastIdx = reshapedRank - 1; size_t firstSpatialOrder = 2; @@ -219,21 +222,24 @@ DepthToSpace::DepthToSpaceExecutor::DepthToSpaceExecutor(const DepthToSpaceAttrs params.src_block_dims[0] = attrs.srcBlockedDims[0]; // reshaping of src dimensions and creating the permutation order for each layout: - // new shape: mode = blocks_first [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2, ..., DK] - // mode = depth_first [N, C / (block_size ^ K), block_size, block_size, ..., block_size, D1, D2, ..., DK] + // new shape: mode = blocks_first [N, block_size, block_size, ..., block_size, C / (block_size ^ K), D1, D2, ..., + // DK] + // mode = depth_first [N, C / (block_size ^ K), block_size, block_size, ..., block_size, D1, D2, ..., + // DK] // order : mode = blocks_first : [0, K + 1, K + 2, 1, K + 3, 2, K + 4, 3, ..., K + (K + 1), K] // mode = depth_first : [0, 1, K + 2, 2, K + 3, 3, K + 4, 4, ..., K + (K + 1), K + 1] // where `k` is number of spatial dimensions - auto reshapeAndSetPermOrder = [&](const size_t idx1, const size_t idx2, const size_t shift, const VectorDims& dims) { - for (size_t i = 0; i < attrs.nSpatialDims; i++) { - params.order[i * 2 + shift] = i + idx1; - params.order[i * 2 + shift + 1] = i + idx2; + auto reshapeAndSetPermOrder = + [&](const size_t idx1, const size_t idx2, const size_t shift, const VectorDims& dims) { + for (size_t i = 0; i < attrs.nSpatialDims; i++) { + params.order[i * 2 + shift] = i + idx1; + params.order[i * 2 + shift + 1] = i + idx2; - params.src_block_dims[params.order[i * 2 + shift]] = dims[i + shift]; - params.src_block_dims[params.order[i * 2 + shift + 1]] = attrs.blockSize; - } - }; + params.src_block_dims[params.order[i * 2 + shift]] = dims[i + shift]; + params.src_block_dims[params.order[i * 2 + shift + 1]] = attrs.blockSize; + } + }; if (isBlocked) { size_t orderShiftForBlocks, orderShiftForDims; @@ -314,6 +320,6 @@ bool DepthToSpace::created() const { return getType() == Type::DepthToSpace; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/depth_to_space.h b/src/plugins/intel_cpu/src/nodes/depth_to_space.h index 2eda39f60394af..21eca73f97318c 100644 --- a/src/plugins/intel_cpu/src/nodes/depth_to_space.h +++ b/src/plugins/intel_cpu/src/nodes/depth_to_space.h @@ -54,6 +54,6 @@ class DepthToSpace : public Node { executorPtr execPtr = nullptr; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/detection_output.cpp b/src/plugins/intel_cpu/src/nodes/detection_output.cpp index 99702780b83034..9cf52e7042c6ba 100644 --- a/src/plugins/intel_cpu/src/nodes/detection_output.cpp +++ b/src/plugins/intel_cpu/src/nodes/detection_output.cpp @@ -16,8 +16,7 
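
The comment block above specifies the reshaped dims and permutation order; a worked example makes it concrete. For a 4D NCHW input, K = 2, BLOCKS_FIRST reshapes to [N, bs, bs, C/bs^2, H, W] and permutes with order [0, 3, 4, 1, 5, 2], yielding [N, C/bs^2, H, bs, W, bs], which collapses to [N, C/bs^2, H*bs, W*bs]. A short sketch deriving that order from the comment's formula:

```cpp
#include <cstdio>
#include <vector>

// Derives the BLOCKS_FIRST permutation order [0, K+1, K+2, 1, K+3, 2, ...]
// from the comment above. For K == 2 it prints "0 3 4 1 5 2".
int main() {
    const size_t K = 2;  // number of spatial dimensions (H and W)
    std::vector<size_t> order;
    order.push_back(0);      // batch dimension stays first
    order.push_back(K + 1);  // C / (block_size ^ K)
    for (size_t i = 1; i <= K; ++i) {
        order.push_back(K + 1 + i);  // spatial dim D_i
        order.push_back(i);          // block_size factor merged into D_i
    }
    for (size_t v : order)
        std::printf("%zu ", v);
    std::printf("\n");
    return 0;
}
```
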
@@ namespace node { namespace { template -bool SortScorePairDescend(const std::pair& pair1, - const std::pair& pair2) { +bool SortScorePairDescend(const std::pair& pair1, const std::pair& pair2) { return (pair1.first > pair2.first) || (pair1.first == pair2.first && pair1.second < pair2.second); } @@ -27,9 +26,10 @@ bool SortScorePairDescend>(const std::pair pair2.first) || (pair1.first == pair2.first && pair1.second.second < pair2.second.second); } -} // namespace +} // namespace -bool DetectionOutput::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool DetectionOutput::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto doOp = ov::as_type_ptr(op); if (!doOp) { @@ -58,7 +58,7 @@ DetectionOutput::DetectionOutput(const std::shared_ptr& op, const Grap errorPrefix = "DetectionOutput node with name '" + getName() + "' "; if (getOriginalInputsNumber() != 3 && getOriginalInputsNumber() != 5) - OPENVINO_THROW(errorPrefix, "has incorrect number of input edges."); + OPENVINO_THROW(errorPrefix, "has incorrect number of input edges."); if (getOriginalOutputsNumber() != 1) OPENVINO_THROW(errorPrefix, "has incorrect number of output edges."); @@ -93,7 +93,7 @@ DetectionOutput::DetectionOutput(const std::shared_ptr& op, const Grap void DetectionOutput::prepareParams() { const auto& idPriorDims = getParentEdgeAt(ID_PRIOR)->getMemory().getShape().getStaticDims(); - const auto &idConfDims = getParentEdgeAt(ID_CONF)->getMemory().getShape().getStaticDims(); + const auto& idConfDims = getParentEdgeAt(ID_CONF)->getMemory().getShape().getStaticDims(); priorsNum = static_cast(idPriorDims.back() / priorSize); isPriorsPerImg = idPriorDims.front() != 1; classesNum = static_cast(idConfDims.back() / priorsNum); @@ -130,9 +130,8 @@ void DetectionOutput::prepareParams() { // --> g_topk(vector<>(all detections) --> indices per class)) // MXNet: max conf for prior within img, filter(indices) --> topk_img(buffer) --> nms_cls(indices) // --> g_topk(vector<>(all detections) --> indices per class)) - isSparsityWorthwhile = - (confidenceThreshold > sparsityThreshold) && - ((classesNum * priorsNum * sizeof(float) * 2) > static_cast(cacheSizeL3)); + isSparsityWorthwhile = (confidenceThreshold > sparsityThreshold) && + ((classesNum * priorsNum * sizeof(float) * 2) > static_cast(cacheSizeL3)); confInfoLen = (!decreaseClassId && isSparsityWorthwhile) ? 
(2 * priorsNum + 1) : priorsNum; reorderedConf.resize(imgNum * classesNum * confInfoLen); @@ -149,17 +148,17 @@ void DetectionOutput::initSupportedPrimitiveDescriptors() { for (size_t i = 0; i < inputShapes.size(); ++i) inDataConf.emplace_back(LayoutType::ncsp, ov::element::f32); - addSupportedPrimDesc(inDataConf, - {{LayoutType::ncsp, ov::element::f32}}, - impl_desc_type::ref_any); + addSupportedPrimDesc(inDataConf, {{LayoutType::ncsp, ov::element::f32}}, impl_desc_type::ref_any); } struct ConfidenceComparatorDO { explicit ConfidenceComparatorDO(const float* confDataIn) : confData(confDataIn) {} bool operator()(int idx1, int idx2) { - if (confData[idx1] > confData[idx2]) return true; - if (confData[idx1] < confData[idx2]) return false; + if (confData[idx1] > confData[idx2]) + return true; + if (confData[idx1] < confData[idx2]) + return false; return idx1 < idx2; } @@ -171,31 +170,29 @@ void DetectionOutput::executeDynamicImpl(dnnl::stream strm) { } void DetectionOutput::execute(dnnl::stream strm) { - float *dstData = getDstDataAtPortAs(0); + float* dstData = getDstDataAtPortAs(0); - const float *locData = getSrcDataAtPortAs(ID_LOC); - const float *confData = getSrcDataAtPortAs(ID_CONF); - const float *priorData = getSrcDataAtPortAs(ID_PRIOR); - const float *ARMConfData = inputShapes.size() > 3 ? - getSrcDataAtPortAs(ID_ARM_CONF) : nullptr; - const float *ARMLocData = inputShapes.size() > 4 ? - getSrcDataAtPortAs(ID_ARM_LOC) : nullptr; + const float* locData = getSrcDataAtPortAs(ID_LOC); + const float* confData = getSrcDataAtPortAs(ID_CONF); + const float* priorData = getSrcDataAtPortAs(ID_PRIOR); + const float* ARMConfData = inputShapes.size() > 3 ? getSrcDataAtPortAs(ID_ARM_CONF) : nullptr; + const float* ARMLocData = inputShapes.size() > 4 ? getSrcDataAtPortAs(ID_ARM_LOC) : nullptr; - float *reorderedConfData = reorderedConf.data(); - int *reorderedConfDataIndices = reinterpret_cast(reorderedConf.data()); + float* reorderedConfData = reorderedConf.data(); + int* reorderedConfDataIndices = reinterpret_cast(reorderedConf.data()); - float *decodedBboxesData = decodedBboxes.data(); - float *bboxSizesData = bboxSizes.data(); - int *indicesData = indices.data(); - int *indicesBufData = indicesBuffer.data(); - int *detectionsData = detectionsCount.data(); + float* decodedBboxesData = decodedBboxes.data(); + float* bboxSizesData = bboxSizes.data(); + int* indicesData = indices.data(); + int* indicesBufData = indicesBuffer.data(); + int* detectionsData = detectionsCount.data(); memset(detectionsData, 0, imgNum * classesNum * sizeof(int)); int priorsBatch = isPriorsPerImg ? imgNum : 1; - int *numPriorsActualdata = numPriorsActual.data(); + int* numPriorsActualdata = numPriorsActual.data(); for (int n = 0; n < priorsBatch; ++n) { - const float *ppriors = priorData; + const float* ppriors = priorData; ppriors += varianceEncodedInTarget ? 
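
`ConfidenceComparatorDO` above orders indices by descending confidence with the index itself as a deterministic tie-breaker, and `DetectionOutput::topk` later feeds it to `std::partial_sort_copy`. A self-contained sketch of the same selection:

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Top-k index selection in the style of ConfidenceComparatorDO + topk():
// partial_sort_copy only orders the k best entries, which is cheaper than
// fully sorting all candidates.
int main() {
    const std::vector<float> conf = {0.1f, 0.9f, 0.4f, 0.9f, 0.2f};
    const std::vector<int> in = {0, 1, 2, 3, 4};
    std::vector<int> out(3);  // keep the top-3 indices

    std::partial_sort_copy(in.begin(), in.end(), out.begin(), out.end(),
                           [&](int a, int b) {
                               if (conf[a] > conf[b]) return true;
                               if (conf[a] < conf[b]) return false;
                               return a < b;  // deterministic tie-break
                           });
    for (int idx : out)
        std::printf("%d ", idx);  // prints: 1 3 2
    return 0;
}
```
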
(n * priorsNum * priorSize) : (2 * n * priorsNum * priorSize); getActualPriorNum(ppriors, numPriorsActualdata, n); } @@ -204,21 +201,32 @@ void DetectionOutput::execute(dnnl::stream strm) { if (!isSparsityWorthwhile) { confReorderDense(confData, ARMConfData, reorderedConfData); - } else { // sparsity + } else { // sparsity if (!decreaseClassId) { - confReorderAndFilterSparsityCF(confData, ARMConfData, reorderedConfData, indicesData, indicesBufData, detectionsData); + confReorderAndFilterSparsityCF(confData, + ARMConfData, + reorderedConfData, + indicesData, + indicesBufData, + detectionsData); } else { - confReorderAndFilterSparsityMX(confData, ARMConfData, reorderedConfData, indicesData, indicesBufData, detectionsData); + confReorderAndFilterSparsityMX(confData, + ARMConfData, + reorderedConfData, + indicesData, + indicesBufData, + detectionsData); } } - int *confInfoV = confInfoForPrior.data(); + int* confInfoV = confInfoForPrior.data(); for (int n = 0; n < imgNum; ++n) { - const float *ppriors = priorData; - const float *priorVariances = priorData + priorsNum * priorSize; + const float* ppriors = priorData; + const float* priorVariances = priorData + priorsNum * priorSize; if (isPriorsPerImg) { - int priorSizePerImg = varianceEncodedInTarget ? (n * priorsNum * priorSize) : (2 * n * priorsNum * priorSize); + int priorSizePerImg = + varianceEncodedInTarget ? (n * priorsNum * priorSize) : (2 * n * priorsNum * priorSize); ppriors += priorSizePerImg; priorVariances += varianceEncodedInTarget ? 0 : priorSizePerImg; } @@ -226,17 +234,50 @@ void DetectionOutput::execute(dnnl::stream strm) { if (isShareLoc) { int locShift = n * priorsNum; int coordShift = locShift * 4; - const float *ploc = locData + coordShift; - float *pboxes = decodedBboxesData + coordShift; - float *psizes = bboxSizesData + locShift; - int *confInfoVB = confInfoV + locShift; + const float* ploc = locData + coordShift; + float* pboxes = decodedBboxesData + coordShift; + float* psizes = bboxSizesData + locShift; + int* confInfoVB = confInfoV + locShift; if (withAddBoxPred) { - const float *pARMLoc = ARMLocData + coordShift; - decodeBBoxes(ppriors, pARMLoc, priorVariances, pboxes, psizes, numPriorsActualdata, n, coordOffset, priorSize, true, nullptr, confInfoVB); - decodeBBoxes(pboxes, ploc, priorVariances, pboxes, psizes, numPriorsActualdata, n, 0, 4, false, nullptr, confInfoVB); + const float* pARMLoc = ARMLocData + coordShift; + decodeBBoxes(ppriors, + pARMLoc, + priorVariances, + pboxes, + psizes, + numPriorsActualdata, + n, + coordOffset, + priorSize, + true, + nullptr, + confInfoVB); + decodeBBoxes(pboxes, + ploc, + priorVariances, + pboxes, + psizes, + numPriorsActualdata, + n, + 0, + 4, + false, + nullptr, + confInfoVB); } else { - decodeBBoxes(ppriors, ploc, priorVariances, pboxes, psizes, numPriorsActualdata, n, coordOffset, priorSize, true, nullptr, confInfoVB); + decodeBBoxes(ppriors, + ploc, + priorVariances, + pboxes, + psizes, + numPriorsActualdata, + n, + coordOffset, + priorSize, + true, + nullptr, + confInfoVB); } } else { for (int c = 0; c < locNumForClasses; ++c) { @@ -245,16 +286,46 @@ void DetectionOutput::execute(dnnl::stream strm) { } int locShift = n * priorsNum * locNumForClasses; int coordShift = locShift * 4; - const float *ploc = locData + coordShift + c * 4; - float *pboxes = decodedBboxesData + coordShift + c * 4 * priorsNum; - float *psizes = bboxSizesData + locShift + c * priorsNum; - int *confInfoHBC = reorderedConfDataIndices + n * confInfoLen * classesNum + c*confInfoLen; + const float* ploc 
= locData + coordShift + c * 4; + float* pboxes = decodedBboxesData + coordShift + c * 4 * priorsNum; + float* psizes = bboxSizesData + locShift + c * priorsNum; + int* confInfoHBC = reorderedConfDataIndices + n * confInfoLen * classesNum + c * confInfoLen; if (withAddBoxPred) { - const float *pARMLoc = ARMLocData + n * 4 * locNumForClasses * priorsNum + c * 4; - decodeBBoxes(ppriors, pARMLoc, priorVariances, pboxes, psizes, numPriorsActualdata, n, coordOffset, priorSize, true, confInfoHBC); - decodeBBoxes(pboxes, ploc, priorVariances, pboxes, psizes, numPriorsActualdata, n, 0, 4, false, confInfoHBC); + const float* pARMLoc = ARMLocData + n * 4 * locNumForClasses * priorsNum + c * 4; + decodeBBoxes(ppriors, + pARMLoc, + priorVariances, + pboxes, + psizes, + numPriorsActualdata, + n, + coordOffset, + priorSize, + true, + confInfoHBC); + decodeBBoxes(pboxes, + ploc, + priorVariances, + pboxes, + psizes, + numPriorsActualdata, + n, + 0, + 4, + false, + confInfoHBC); } else { - decodeBBoxes(ppriors, ploc, priorVariances, pboxes, psizes, numPriorsActualdata, n, coordOffset, priorSize, true, confInfoHBC); + decodeBBoxes(ppriors, + ploc, + priorVariances, + pboxes, + psizes, + numPriorsActualdata, + n, + coordOffset, + priorSize, + true, + confInfoHBC); } } } @@ -267,16 +338,16 @@ void DetectionOutput::execute(dnnl::stream strm) { parallel_for(classesNum, [&](int c) { if (c != backgroundClassId) { // Ignore background class const int off = n * priorsNum * classesNum + c * priorsNum; - const float *pconfReorder = reorderedConfData + off; - int *pindices = indicesData + off; - int *pbuffer = indicesBufData + off; - int *pdetections = detectionsData + n * classesNum + c; + const float* pconfReorder = reorderedConfData + off; + int* pindices = indicesData + off; + int* pbuffer = indicesBufData + off; + int* pdetections = detectionsData + n * classesNum + c; if (!isSparsityWorthwhile) confFilterCF(pconfReorder, pindices, pbuffer, pdetections, n); - const float *pboxes; - const float *psizes; + const float* pboxes; + const float* psizes; if (isShareLoc) { pboxes = decodedBboxesData + n * 4 * priorsNum; psizes = bboxSizesData + n * priorsNum; @@ -291,23 +362,23 @@ void DetectionOutput::execute(dnnl::stream strm) { } else { // MXNet style const int offImg = n * priorsNum * classesNum; - const float *pconf = confData + offImg; - float *pconfReorder = reorderedConfData + offImg; - int *pbuffer = indicesBufData + offImg; - int *pindices = indicesData + offImg; - int *pdetections = detectionsData + n * classesNum; + const float* pconf = confData + offImg; + float* pconfReorder = reorderedConfData + offImg; + int* pbuffer = indicesBufData + offImg; + int* pindices = indicesData + offImg; + int* pdetections = detectionsData + n * classesNum; if (!isSparsityWorthwhile) confFilterMX(pconf, ARMConfData, pconfReorder, pindices, pbuffer, pdetections, n); - const float *pboxes = decodedBboxesData + n * 4 * locNumForClasses * priorsNum; - const float *psizes = bboxSizesData + n * locNumForClasses * priorsNum; + const float* pboxes = decodedBboxesData + n * 4 * locNumForClasses * priorsNum; + const float* psizes = bboxSizesData + n * locNumForClasses * priorsNum; NMSMX(pbuffer, pdetections, pindices, pboxes, psizes); } int detectionsTotal = 0; - detectionsTotal = parallel_sum(classesNum, detectionsTotal, [&](size_t c)->int { + detectionsTotal = parallel_sum(classesNum, detectionsTotal, [&](size_t c) -> int { return detectionsData[n * classesNum + c]; }); @@ -318,9 +389,9 @@ void DetectionOutput::execute(dnnl::stream 
strm) { std::mutex mtx; parallel_for(classesNum, [&](int c) { const int detections = detectionsData[n * classesNum + c]; - int *pindices = indicesData + n * classesNum * priorsNum + c * priorsNum; + int* pindices = indicesData + n * classesNum * priorsNum + c * priorsNum; - float *pconf = reorderedConfData + n * classesNum * confInfoLen + c * confInfoLen; + float* pconf = reorderedConfData + n * classesNum * confInfoLen + c * confInfoLen; for (int i = 0; i < detections; ++i) { int pr = pindices[i]; @@ -330,7 +401,8 @@ void DetectionOutput::execute(dnnl::stream strm) { } }); - std::sort(confIndicesClassMap.begin(), confIndicesClassMap.end(), + std::sort(confIndicesClassMap.begin(), + confIndicesClassMap.end(), SortScorePairDescend>); confIndicesClassMap.resize(keepTopK); @@ -340,7 +412,7 @@ void DetectionOutput::execute(dnnl::stream strm) { for (size_t j = 0; j < confIndicesClassMap.size(); ++j) { const int cls = confIndicesClassMap[j].second.first; const int pr = confIndicesClassMap[j].second.second; - int *pindices = indicesData + n * classesNum * priorsNum + cls * priorsNum; + int* pindices = indicesData + n * classesNum * priorsNum + cls * priorsNum; pindices[detectionsData[n * classesNum + cls]] = pr; detectionsData[n * classesNum + cls]++; } @@ -351,7 +423,11 @@ void DetectionOutput::execute(dnnl::stream strm) { generateOutput(reorderedConfData, indicesData, detectionsData, decodedBboxesData, dstData); } -inline void DetectionOutput::confFilterCF(const float* pconf, int* pindices, int* pbuffer, int* detectionsData, const int& n) { +inline void DetectionOutput::confFilterCF(const float* pconf, + int* pindices, + int* pbuffer, + int* detectionsData, + const int& n) { // in: reorderedConf // out: pindices count int count = 0; @@ -371,21 +447,27 @@ inline void DetectionOutput::confFilterCF(const float* pconf, int* pindices, int // MX filter is per image filter, max output is prior num(select max for all class within this prior) // NMS is per class, keep topk is per image, final output is per class -inline void DetectionOutput::confFilterMX(const float* confData, const float* ARMConfData, float* reorderedConfData, - int* indicesData, int* indicesBufData, int* detectionsData, const int& n) { +inline void DetectionOutput::confFilterMX(const float* confData, + const float* ARMConfData, + float* reorderedConfData, + int* indicesData, + int* indicesBufData, + int* detectionsData, + const int& n) { std::mutex mtx; parallel_for(numPriorsActual[n], [&](size_t p) { // in: origin conf // out: pindices, detectionCount // intentionally code branch from higher level if (withAddBoxPred) { - const bool isARMPrior = ARMConfData[n*priorsNum*2 + p * 2 + 1] < objScore; + const bool isARMPrior = ARMConfData[n * priorsNum * 2 + p * 2 + 1] < objScore; float maxConf = -1; int maxCIdx = 0; for (int c = 1; c < classesNum; ++c) { float conf = confData[p * classesNum + c]; if (isARMPrior) - conf = (c == backgroundClassId) ? 1.0f : 0.0f; // still need refresh conf due to read from origin conf + conf = + (c == backgroundClassId) ? 1.0f : 0.0f; // still need refresh conf due to read from origin conf if (conf >= confidenceThreshold && conf > maxConf) { maxConf = conf; maxCIdx = c; @@ -394,7 +476,7 @@ inline void DetectionOutput::confFilterMX(const float* confData, const float* AR if (maxCIdx > 0) { // include this prior mtx.lock(); - indicesData[detectionsData[0]] = maxCIdx*priorsNum + p; // de-refer to get prior and class id. 
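
The keep-top-K stage above pools every surviving (confidence, (class, prior)) triple for an image, sorts them with `SortScorePairDescend<std::pair<int, int>>`, and truncates to `keepTopK`. A compact sketch of that cross-class ranking, with illustrative names:

```cpp
#include <algorithm>
#include <utility>
#include <vector>

// Cross-class keepTopK in the style of the code above: rank all detections
// of one image together by confidence, breaking ties on the prior index,
// then keep only the first keepTopK entries.
using Entry = std::pair<float, std::pair<int, int>>;  // {conf, {class, prior}}

std::vector<Entry> keepTop(std::vector<Entry> all, size_t keepTopK) {
    std::sort(all.begin(), all.end(), [](const Entry& a, const Entry& b) {
        return a.first > b.first ||
               (a.first == b.first && a.second.second < b.second.second);
    });
    if (all.size() > keepTopK)
        all.resize(keepTopK);
    return all;
}
```
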
+ indicesData[detectionsData[0]] = maxCIdx * priorsNum + p; // de-refer to get prior and class id. detectionsData[0]++; mtx.unlock(); } @@ -411,7 +493,7 @@ inline void DetectionOutput::confFilterMX(const float* confData, const float* AR if (maxCIdx > 0) { // include this prior and class with max conf mtx.lock(); - indicesData[detectionsData[0]] = maxCIdx*priorsNum + p; // de-refer to get prior and class id. + indicesData[detectionsData[0]] = maxCIdx * priorsNum + p; // de-refer to get prior and class id. detectionsData[0]++; mtx.unlock(); } @@ -423,14 +505,14 @@ inline void DetectionOutput::confFilterMX(const float* confData, const float* AR int count = detectionsData[0]; int k = (topK == -1 ? count : (std::min)(topK, count)); - const float *pconf = reorderedConfData; + const float* pconf = reorderedConfData; // int *indices = indicesData; // int *pbuffer = indicesBufData; topk(indicesData, indicesBufData, pconf, count, k); detectionsData[0] = k; } -inline void DetectionOutput::getActualPriorNum(const float *priorData, int* numPriorsActual, int n) { +inline void DetectionOutput::getActualPriorNum(const float* priorData, int* numPriorsActual, int n) { numPriorsActual[n] = priorsNum; if (!normalized) { int num = 0; @@ -444,16 +526,20 @@ inline void DetectionOutput::getActualPriorNum(const float *priorData, int* numP } } -inline void DetectionOutput::confReorderDense(const float *confData, const float *ARMConfData, float *reorderedConfData) { +inline void DetectionOutput::confReorderDense(const float* confData, + const float* ARMConfData, + float* reorderedConfData) { if (withAddBoxPred) { parallel_for2d(imgNum, priorsNum, [&](size_t n, size_t p) { if (ARMConfData[n * priorsNum * 2 + p * 2 + 1] < objScore) { for (int c = 0; c < classesNum; ++c) { - reorderedConfData[n * priorsNum * classesNum + c * priorsNum + p] = c == backgroundClassId ? 1.0f : 0.0f; + reorderedConfData[n * priorsNum * classesNum + c * priorsNum + p] = + c == backgroundClassId ? 
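
`confFilterMX` above implements the MXNet-style rule: each prior contributes at most one detection, namely its highest-scoring non-background class, and only when that score clears the confidence threshold. The kept entry is encoded as `maxCIdx * priorsNum + p` so class and prior can be recovered later. A sketch of the per-prior decision:

```cpp
// Per-prior argmax over classes, mirroring the confFilterMX logic above.
// conf points at the classesNum scores of a single prior; returning 0
// (the background class) means the prior is dropped.
int bestClassForPrior(const float* conf, int classesNum, float confidenceThreshold) {
    float maxConf = -1.0f;
    int maxCIdx = 0;
    for (int c = 1; c < classesNum; ++c) {  // class 0 is background
        if (conf[c] >= confidenceThreshold && conf[c] > maxConf) {
            maxConf = conf[c];
            maxCIdx = c;
        }
    }
    return maxCIdx;  // encode as maxCIdx * priorsNum + p to keep both ids
}
```
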
1.0f : 0.0f; } } else { for (int c = 0; c < classesNum; ++c) { - reorderedConfData[n * priorsNum * classesNum + c * priorsNum + p] = confData[n * priorsNum * classesNum + p * classesNum + c]; + reorderedConfData[n * priorsNum * classesNum + c * priorsNum + p] = + confData[n * priorsNum * classesNum + p * classesNum + c]; } } }); @@ -463,20 +549,23 @@ inline void DetectionOutput::confReorderDense(const float *confData, const float parallel_for2d(imgNum, classesNum, [&](size_t n, size_t c) { const int offset = n * priorsNum * classesNum; for (int p = 0; p < priorsNum; ++p) { - reorderedConfData[offset + c * priorsNum + p] = - confData[offset + p * classesNum + c]; + reorderedConfData[offset + c * priorsNum + p] = confData[offset + p * classesNum + c]; } }); } -inline void DetectionOutput::confReorderAndFilterSparsityCF(const float* confData, const float* ARMConfData, float* reorderedConfData, - int* indicesData, int* indicesBufData, int* detectionsData) { +inline void DetectionOutput::confReorderAndFilterSparsityCF(const float* confData, + const float* ARMConfData, + float* reorderedConfData, + int* indicesData, + int* indicesBufData, + int* detectionsData) { int* reorderedConfDataIndices = reinterpret_cast(reorderedConfData); for (int n = 0; n < imgNum; ++n) { const int off = n * priorsNum * classesNum; const int offV = n * priorsNum; // vertical info - const int offH = n * confInfoLen * classesNum; // horizontal info + const int offH = n * confInfoLen * classesNum; // horizontal info // reset count parallel_for(classesNum, [&](size_t c) { const int countIdx = offH + c * confInfoLen + priorsNum; @@ -506,7 +595,7 @@ inline void DetectionOutput::confReorderAndFilterSparsityCF(const float* confDat // vertical info for isShareLoc(flag to decode for each prior) if (!priorStatusSet && isShareLoc) { - confInfoForPrior[offV + p] = 1; // 1 for decode + confInfoForPrior[offV + p] = 1; // 1 for decode } } } @@ -542,9 +631,9 @@ inline void DetectionOutput::confReorderAndFilterSparsityCF(const float* confDat const int count = reorderedConfDataIndices[countIdx]; const int k = (topK == -1 ? 
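
`confReorderDense` above is, at its core, a per-image transpose: scores arrive prior-major ([prior][class]) and are rewritten class-major ([class][prior]) so that each class's confidences are contiguous for the later top-k and NMS passes. The essential loop, without the ARM gating and parallelism:

```cpp
// Dense confidence reorder for one image: [prior][class] -> [class][prior].
// The real code runs this under parallel_for2d and substitutes fixed
// background/zero scores when the ARM objectness gate rejects a prior.
void reorderConf(const float* conf, float* reordered, int priorsNum, int classesNum) {
    for (int c = 0; c < classesNum; ++c)
        for (int p = 0; p < priorsNum; ++p)
            reordered[c * priorsNum + p] = conf[p * classesNum + c];
}
```
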
count : (std::min)(topK, count)); - int *reorderedConfIndices = reorderedConfDataIndices + countIdx + 1; - int *pbuffer = indicesBufData + off + c * priorsNum; - const float *pconf = reorderedConfData + offH + c * confInfoLen; + int* reorderedConfIndices = reorderedConfDataIndices + countIdx + 1; + int* pbuffer = indicesBufData + off + c * priorsNum; + const float* pconf = reorderedConfData + offH + c * confInfoLen; topk(reorderedConfIndices, pbuffer, pconf, count, k); detectionsData[n * classesNum + c] = k; @@ -552,8 +641,12 @@ inline void DetectionOutput::confReorderAndFilterSparsityCF(const float* confDat } } -inline void DetectionOutput::confReorderAndFilterSparsityMX(const float* confData, const float* ARMConfData, float* reorderedConfData, - int* indicesData, int* indicesBufData, int* detectionsData) { +inline void DetectionOutput::confReorderAndFilterSparsityMX(const float* confData, + const float* ARMConfData, + float* reorderedConfData, + int* indicesData, + int* indicesBufData, + int* detectionsData) { for (int n = 0; n < imgNum; ++n) { const int off = n * priorsNum * classesNum; const int offV = n * priorsNum; // vertical info @@ -579,7 +672,7 @@ inline void DetectionOutput::confReorderAndFilterSparsityMX(const float* confDat // vertical info for isShareLoc(flag to decode for each prior) if (!priorStatusSet && isShareLoc) { - confInfoForPrior[offV + p] = 1; // 1 for decode + confInfoForPrior[offV + p] = 1; // 1 for decode } // vertical info for MXNet style(max conf for each prior) if (c != 0) { @@ -593,7 +686,8 @@ inline void DetectionOutput::confReorderAndFilterSparsityMX(const float* confDat // MXNet statistic, indices and detectionCount is for each image if (maxCIdx > 0) { mtx.lock(); - indicesData[off + detectionsData[n * classesNum]] = maxCIdx * priorsNum + p; // de-refer to get prior and class id. + indicesData[off + detectionsData[n * classesNum]] = + maxCIdx * priorsNum + p; // de-refer to get prior and class id. detectionsData[n * classesNum]++; mtx.unlock(); } @@ -604,27 +698,27 @@ inline void DetectionOutput::confReorderAndFilterSparsityMX(const float* confDat const int count = detectionsData[n * classesNum]; const int k = (topK == -1 ? 
count : (std::min)(topK, count)); - const float *pconf = reorderedConfData + off; - int *indices = indicesData + off; - int *pbuffer = indicesBufData + off; + const float* pconf = reorderedConfData + off; + int* indices = indicesData + off; + int* pbuffer = indicesBufData + off; topk(indices, pbuffer, pconf, count, k); detectionsData[n * classesNum] = k; } } // apply locData(offset) to priordata, generate decodedBox -inline void DetectionOutput::decodeBBoxes(const float *priorData, - const float *locData, - const float *varianceData, - float *decodedBboxes, - float *decodedBboxSizes, - int* numPriorsActual, - int n, - const int& offs, - const int& priorSize, - bool decodeType, - const int *confInfoH, - const int *confInfoV) { +inline void DetectionOutput::decodeBBoxes(const float* priorData, + const float* locData, + const float* varianceData, + float* decodedBboxes, + float* decodedBboxSizes, + int* numPriorsActual, + int n, + const int& offs, + const int& priorSize, + bool decodeType, + const int* confInfoH, + const int* confInfoV) { int prNum = numPriorsActual[n]; if (!decodeType) { prNum = priorsNum; @@ -672,8 +766,8 @@ inline void DetectionOutput::decodeBBoxes(const float *priorData, newYMax = priorYMax + varianceData[p * 4 + 3] * locYMax; } } else if (codeType == CodeType::CENTER_SIZE) { - float priorWidth = priorXMax - priorXMin; - float priorHeight = priorYMax - priorYMin; + float priorWidth = priorXMax - priorXMin; + float priorHeight = priorYMax - priorYMin; float priorCenterX = (priorXMin + priorXMax) / 2.0f; float priorCenterY = (priorYMin + priorYMax) / 2.0f; @@ -682,21 +776,21 @@ inline void DetectionOutput::decodeBBoxes(const float *priorData, if (varianceEncodedInTarget) { // variance is encoded in target, we simply need to restore the offset predictions. - decodeBboxCenterX = locXMin * priorWidth + priorCenterX; + decodeBboxCenterX = locXMin * priorWidth + priorCenterX; decodeBboxCenterY = locYMin * priorHeight + priorCenterY; - decodeBboxWidth = std::exp(locXMax) * priorWidth; + decodeBboxWidth = std::exp(locXMax) * priorWidth; decodeBboxHeight = std::exp(locYMax) * priorHeight; } else { // variance is encoded in bbox, we need to scale the offset accordingly. 
- decodeBboxCenterX = varianceData[p*4 + 0] * locXMin * priorWidth + priorCenterX; - decodeBboxCenterY = varianceData[p*4 + 1] * locYMin * priorHeight + priorCenterY; - decodeBboxWidth = std::exp(varianceData[p*4 + 2] * locXMax) * priorWidth; - decodeBboxHeight = std::exp(varianceData[p*4 + 3] * locYMax) * priorHeight; + decodeBboxCenterX = varianceData[p * 4 + 0] * locXMin * priorWidth + priorCenterX; + decodeBboxCenterY = varianceData[p * 4 + 1] * locYMin * priorHeight + priorCenterY; + decodeBboxWidth = std::exp(varianceData[p * 4 + 2] * locXMax) * priorWidth; + decodeBboxHeight = std::exp(varianceData[p * 4 + 3] * locYMax) * priorHeight; } - newXMin = decodeBboxCenterX - decodeBboxWidth / 2.0f; + newXMin = decodeBboxCenterX - decodeBboxWidth / 2.0f; newYMin = decodeBboxCenterY - decodeBboxHeight / 2.0f; - newXMax = decodeBboxCenterX + decodeBboxWidth / 2.0f; + newXMax = decodeBboxCenterX + decodeBboxWidth / 2.0f; newYMax = decodeBboxCenterY + decodeBboxHeight / 2.0f; } @@ -707,25 +801,20 @@ inline void DetectionOutput::decodeBBoxes(const float *priorData, newYMax = (std::max)(0.0f, (std::min)(1.0f, newYMax)); } - decodedBboxes[p*4 + 0] = newXMin; - decodedBboxes[p*4 + 1] = newYMin; - decodedBboxes[p*4 + 2] = newXMax; - decodedBboxes[p*4 + 3] = newYMax; + decodedBboxes[p * 4 + 0] = newXMin; + decodedBboxes[p * 4 + 1] = newYMin; + decodedBboxes[p * 4 + 2] = newXMax; + decodedBboxes[p * 4 + 3] = newYMax; decodedBboxSizes[p] = (newXMax - newXMin) * (newYMax - newYMin); }); } -inline void DetectionOutput::topk(const int *indicesIn, int *indicesOut, const float *conf, int n, int k) { - std::partial_sort_copy(indicesIn, indicesIn + n, - indicesOut, indicesOut + k, - ConfidenceComparatorDO(conf)); +inline void DetectionOutput::topk(const int* indicesIn, int* indicesOut, const float* conf, int n, int k) { + std::partial_sort_copy(indicesIn, indicesIn + n, indicesOut, indicesOut + k, ConfidenceComparatorDO(conf)); } -static inline float JaccardOverlap(const float *decodedBbox, - const float *bboxSizes, - const int idx1, - const int idx2) { +static inline float JaccardOverlap(const float* decodedBbox, const float* bboxSizes, const int idx1, const int idx2) { const float xmin1 = decodedBbox[idx1 * 4 + 0]; const float ymin1 = decodedBbox[idx1 * 4 + 1]; const float xmax1 = decodedBbox[idx1 * 4 + 2]; @@ -745,7 +834,7 @@ static inline float JaccardOverlap(const float *decodedBbox, float intersectXMax = (std::min)(xmax1, xmax2); float intersectYMax = (std::min)(ymax1, ymax2); - float intersectWidth = intersectXMax - intersectXMin; + float intersectWidth = intersectXMax - intersectXMin; float intersectHeight = intersectYMax - intersectYMin; if (intersectWidth <= 0 || intersectHeight <= 0) { @@ -760,10 +849,10 @@ static inline float JaccardOverlap(const float *decodedBbox, } inline void DetectionOutput::NMSCF(int* indicesIn, - int& detections, - int* indicesOut, - const float* bboxes, - const float* boxSizes) { + int& detections, + int* indicesOut, + const float* bboxes, + const float* boxSizes) { // nms for this class int countIn = detections; detections = 0; @@ -787,10 +876,10 @@ inline void DetectionOutput::NMSCF(int* indicesIn, } inline void DetectionOutput::NMSMX(int* indicesIn, - int* detections, - int* indicesOut, - const float* bboxes, - const float* sizes) { + int* detections, + int* indicesOut, + const float* bboxes, + const float* sizes) { // Input is candidate for image, output is candidate for each class within image int countIn = detections[0]; detections[0] = 0; @@ -801,8 +890,8 @@ inline 
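
The CENTER_SIZE branch above applies variance-scaled deltas to the prior's center and exponentiated deltas to its extent. A self-contained restatement of those formulas for the variance-in-bbox case:

```cpp
#include <cmath>

// CENTER_SIZE decoding as in the hunk above, for the case where the
// variance is encoded in the bbox: the center moves by var * loc * prior
// size, and width/height scale by exp(var * loc).
void decodeCenterSize(const float prior[4],  // {xmin, ymin, xmax, ymax}
                      const float loc[4],    // {dx, dy, dw, dh}
                      const float var[4],
                      float out[4]) {
    const float pw = prior[2] - prior[0];
    const float ph = prior[3] - prior[1];
    const float pcx = (prior[0] + prior[2]) / 2.0f;
    const float pcy = (prior[1] + prior[3]) / 2.0f;

    const float cx = var[0] * loc[0] * pw + pcx;
    const float cy = var[1] * loc[1] * ph + pcy;
    const float w = std::exp(var[2] * loc[2]) * pw;
    const float h = std::exp(var[3] * loc[3]) * ph;

    out[0] = cx - w / 2.0f;
    out[1] = cy - h / 2.0f;
    out[2] = cx + w / 2.0f;
    out[3] = cy + h / 2.0f;
}
```
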
void DetectionOutput::NMSMX(int* indicesIn, const int prior = idx % priorsNum; // nms within this class - int &ndetection = detections[cls]; - int *pindices = indicesOut + cls * priorsNum; + int& ndetection = detections[cls]; + int* pindices = indicesOut + cls * priorsNum; bool keep = true; for (int k = 0; k < ndetection; ++k) { @@ -825,8 +914,11 @@ inline void DetectionOutput::NMSMX(int* indicesIn, } } -inline void DetectionOutput::generateOutput(float* reorderedConfData, int* indicesData, int* detectionsData, float* decodedBboxesData, - float* dstData) { +inline void DetectionOutput::generateOutput(float* reorderedConfData, + int* indicesData, + int* detectionsData, + float* decodedBboxesData, + float* dstData) { const auto& outDims = getChildEdgeAt(0)->getMemory().getStaticDims(); const int numResults = outDims[2]; const int DETECTION_SIZE = outDims[3]; @@ -850,26 +942,22 @@ inline void DetectionOutput::generateOutput(float* reorderedConfData, int* indic // set final detection result to output blob int count = 0; for (int n = 0; n < imgNum; ++n) { - const float *pconf = reorderedConfData + n * confInfoLen * classesNum; - const float *pboxes = decodedBboxesData + n * priorsNum * 4 * locNumForClasses; - const int *pindices = indicesData + n * classesNum * priorsNum; + const float* pconf = reorderedConfData + n * confInfoLen * classesNum; + const float* pboxes = decodedBboxesData + n * priorsNum * 4 * locNumForClasses; + const int* pindices = indicesData + n * classesNum * priorsNum; for (int c = 0; c < classesNum; ++c) { for (int i = 0; i < detectionsData[n * classesNum + c]; ++i) { int prIdx = pindices[c * priorsNum + i]; dstData[count * DETECTION_SIZE + 0] = static_cast(n); - dstData[count * DETECTION_SIZE + 1] = static_cast(decreaseClassId ? c-1 : c); + dstData[count * DETECTION_SIZE + 1] = static_cast(decreaseClassId ? c - 1 : c); dstData[count * DETECTION_SIZE + 2] = pconf[c * confInfoLen + prIdx]; - float xmin = isShareLoc ? pboxes[prIdx * 4 + 0] : - pboxes[c * 4 * priorsNum + prIdx * 4 + 0]; - float ymin = isShareLoc ? pboxes[prIdx * 4 + 1] : - pboxes[c * 4 * priorsNum + prIdx * 4 + 1]; - float xmax = isShareLoc ? pboxes[prIdx * 4 + 2] : - pboxes[c * 4 * priorsNum + prIdx * 4 + 2]; - float ymax = isShareLoc ? pboxes[prIdx * 4 + 3] : - pboxes[c * 4 * priorsNum + prIdx * 4 + 3]; + float xmin = isShareLoc ? pboxes[prIdx * 4 + 0] : pboxes[c * 4 * priorsNum + prIdx * 4 + 0]; + float ymin = isShareLoc ? pboxes[prIdx * 4 + 1] : pboxes[c * 4 * priorsNum + prIdx * 4 + 1]; + float xmax = isShareLoc ? pboxes[prIdx * 4 + 2] : pboxes[c * 4 * priorsNum + prIdx * 4 + 2]; + float ymax = isShareLoc ? 
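
`JaccardOverlap` plus the keep-loops in `NMSCF`/`NMSMX` above form a standard greedy NMS: walk candidates in confidence order and keep one only if its IoU with every already-kept box stays at or below the threshold. A condensed sketch:

```cpp
#include <algorithm>
#include <vector>

// IoU of two {xmin, ymin, xmax, ymax} boxes, as in JaccardOverlap above.
static float iou(const float* a, const float* b) {
    const float iw = std::min(a[2], b[2]) - std::max(a[0], b[0]);
    const float ih = std::min(a[3], b[3]) - std::max(a[1], b[1]);
    if (iw <= 0.0f || ih <= 0.0f)
        return 0.0f;
    const float inter = iw * ih;
    const float areaA = (a[2] - a[0]) * (a[3] - a[1]);
    const float areaB = (b[2] - b[0]) * (b[3] - b[1]);
    return inter / (areaA + areaB - inter);
}

// Greedy NMS over confidence-sorted candidate indices (4 floats per box).
std::vector<int> greedyNMS(const float* boxes, const std::vector<int>& order,
                           float nmsThreshold) {
    std::vector<int> kept;
    for (int cand : order) {
        bool keep = true;
        for (int k : kept) {
            if (iou(&boxes[cand * 4], &boxes[k * 4]) > nmsThreshold) {
                keep = false;
                break;
            }
        }
        if (keep)
            kept.push_back(cand);
    }
    return kept;
}
```
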
pboxes[prIdx * 4 + 3] : pboxes[c * 4 * priorsNum + prIdx * 4 + 3]; if (clipAfterNMS) { xmin = (std::max)(0.0f, (std::min)(1.0f, xmin)); @@ -898,6 +986,6 @@ bool DetectionOutput::created() const { return getType() == Type::DetectionOutput; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/detection_output.h b/src/plugins/intel_cpu/src/nodes/detection_output.h index 418898f011f313..1a42bfa9b2980a 100644 --- a/src/plugins/intel_cpu/src/nodes/detection_output.h +++ b/src/plugins/intel_cpu/src/nodes/detection_output.h @@ -15,7 +15,7 @@ class DetectionOutput : public Node { public: DetectionOutput(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -47,8 +47,8 @@ class DetectionOutput : public Node { float sparsityThreshold = 0.03f; int topK = 0; float NMSThreshold = 0.0f; - bool clipBeforeNMS = false; - bool clipAfterNMS = false; + bool clipBeforeNMS = false; + bool clipAfterNMS = false; int backgroundClassId = 0; bool decreaseClassId = false; int keepTopK = 0; @@ -75,28 +75,52 @@ class DetectionOutput : public Node { inline void confFilterCF(const float* pConf, int* pindices, int* pbuffer, int* detectionsData, const int& n); - inline void confFilterMX(const float* confData, const float* ARMConfData, float* reorderedConfData, - int* indicesData, int* indicesBufData, int* detectionsData, const int& n); - - inline void confReorderAndFilterSparsityCF(const float* confData, const float* ARMConfData, float* reorderedConfData, - int* indicesData, int* indicesBufData, int* detectionsData); - - inline void confReorderAndFilterSparsityMX(const float* confData, const float* ARMConfData, float* reorderedConfData, - int* indicesData, int* indicesBufData, int* detectionsData); - - inline void decodeBBoxes(const float* prior_data, const float* loc_data, const float* variance_data, - float* decoded_bboxes, float* decoded_bbox_sizes, int* num_priors_actual, int n, const int& offs, const int& pr_size, - bool decodeType = true, const int* conf_info_h = nullptr, const int* conf_info_v = nullptr); // decodeType is false after ARM - - inline void NMSCF(int* indicesIn, int& detections, int* indicesOut, - const float* bboxes, const float* boxSizes); - - inline void NMSMX(int* indicesIn, int* detections, int* indicesOut, - const float* bboxes, const float* sizes); + inline void confFilterMX(const float* confData, + const float* ARMConfData, + float* reorderedConfData, + int* indicesData, + int* indicesBufData, + int* detectionsData, + const int& n); + + inline void confReorderAndFilterSparsityCF(const float* confData, + const float* ARMConfData, + float* reorderedConfData, + int* indicesData, + int* indicesBufData, + int* detectionsData); + + inline void confReorderAndFilterSparsityMX(const float* confData, + const float* ARMConfData, + float* reorderedConfData, + int* indicesData, + int* indicesBufData, + int* detectionsData); + + inline void decodeBBoxes(const float* prior_data, + const float* loc_data, + const float* variance_data, + float* decoded_bboxes, + float* decoded_bbox_sizes, + int* num_priors_actual, + int n, + const int& offs, + const int& pr_size, + bool decodeType = true, + const int* conf_info_h = nullptr, + const int* conf_info_v = nullptr); // 
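
`generateOutput` above emits one fixed-width row per detection, `DETECTION_SIZE` (here 7) floats wide. A sketch of that row layout; the trailing-sentinel convention (image id of -1 after the last detection) is the usual DetectionOutput contract and is assumed here rather than shown in the hunk:

```cpp
// Assumed layout of one DetectionOutput row (7 floats, matching the
// dstData[count * DETECTION_SIZE + ...] writes above).
struct DetectionRow {
    float image_id;    // batch index n; -1 conventionally ends the list
    float class_id;    // c, shifted by -1 when decreaseClassId is set
    float confidence;  // pconf[c * confInfoLen + prIdx]
    float xmin, ymin, xmax, ymax;  // normalized, clipped if clipAfterNMS
};
static_assert(sizeof(DetectionRow) == 7 * sizeof(float), "row is 7 floats");
```
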
decodeType is false after ARM + + inline void NMSCF(int* indicesIn, int& detections, int* indicesOut, const float* bboxes, const float* boxSizes); + + inline void NMSMX(int* indicesIn, int* detections, int* indicesOut, const float* bboxes, const float* sizes); inline void topk(const int* indicesIn, int* indicesOut, const float* conf, int n, int k); - inline void generateOutput(float* reorderedConfData, int* indicesData, int* detectionsData, float* decodedBboxesData, float* dstData); + inline void generateOutput(float* reorderedConfData, + int* indicesData, + int* detectionsData, + float* decodedBboxesData, + float* dstData); std::vector decodedBboxes; std::vector indicesBuffer; @@ -110,6 +134,6 @@ class DetectionOutput : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/dft.cpp b/src/plugins/intel_cpu/src/nodes/dft.cpp index 76ecbbb36617f5..5fa8053d7024d7 100644 --- a/src/plugins/intel_cpu/src/nodes/dft.cpp +++ b/src/plugins/intel_cpu/src/nodes/dft.cpp @@ -4,17 +4,17 @@ #include "dft.h" +#include +#include #include #include -#include -#include "dnnl_extension_utils.h" -#include "openvino/core/parallel.hpp" +#include "common/cpu_memcpy.h" +#include "dnnl_extension_utils.h" #include "onednn/dnnl.h" +#include "openvino/core/parallel.hpp" #include "utils/general_utils.h" -#include "common/cpu_memcpy.h" #include "utils/ngraph_utils.hpp" -#include using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64; @@ -104,10 +104,10 @@ void DFT::initSupportedPrimitiveDescriptors() { } } - std::vector inDataConfigurators({{LayoutType::ncsp, ov::element::f32}, - {LayoutType::ncsp, ov::element::i32}}); + std::vector inDataConfigurators( + {{LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::i32}}); if (inputShapes.size() > SIGNAL_SIZE_INDEX) - inDataConfigurators.push_back({LayoutType::ncsp, ov::element::i32}); + inDataConfigurators.push_back({LayoutType::ncsp, ov::element::i32}); addSupportedPrimDesc(inDataConfigurators, {{LayoutType::ncsp, ov::element::f32}}, impl_desc_type::ref_any); } @@ -172,8 +172,12 @@ size_t calculateOffsetFromStrides(const std::vector& coords, const std:: return offset; } -void gatherToBufferND(float* buffer, const float* data, size_t axis, const std::vector& dimIndexes, - const std::vector& shape, const std::vector& strides) { +void gatherToBufferND(float* buffer, + const float* data, + size_t axis, + const std::vector& dimIndexes, + const std::vector& shape, + const std::vector& strides) { size_t numberOfComplex = shape[axis]; size_t offset = calculateOffsetFromStrides(dimIndexes, strides); @@ -184,8 +188,12 @@ void gatherToBufferND(float* buffer, const float* data, size_t axis, const std:: } } -void applyBufferND(const float* buffer, float* output, size_t axis, const std::vector& dimIndexes, - const std::vector& shape, const std::vector& strides) { +void applyBufferND(const float* buffer, + float* output, + size_t axis, + const std::vector& dimIndexes, + const std::vector& shape, + const std::vector& strides) { size_t numberOfComplex = shape[axis]; size_t offset = calculateOffsetFromStrides(dimIndexes, strides); @@ -196,8 +204,12 @@ void applyBufferND(const float* buffer, float* output, size_t axis, const std::v } } -void copyDataToOutputWithSignalSize(const float* input, const std::vector& inputShape, const std::vector& inputStrides, - float* output, const std::vector& outputShape, const 
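
`gatherToBufferND` above copies one line of interleaved complex values along `axis` into a contiguous scratch buffer, using the usual coordinates-dot-strides offset. A sketch under the assumption that strides are expressed in float elements and each complex value occupies two adjacent floats:

```cpp
#include <cstddef>
#include <vector>

// Offset of a point given its coordinates and per-dimension strides.
size_t offsetFromStrides(const std::vector<size_t>& coords,
                         const std::vector<size_t>& strides) {
    size_t offset = 0;
    for (size_t i = 0; i < coords.size(); ++i)
        offset += coords[i] * strides[i];
    return offset;
}

// Gather shape[axis] complex (re, im) pairs along `axis` into `buffer`.
void gatherLine(float* buffer, const float* data, size_t axis,
                const std::vector<size_t>& coords,
                const std::vector<size_t>& shape,
                const std::vector<size_t>& strides) {
    size_t offset = offsetFromStrides(coords, strides);
    for (size_t i = 0; i < shape[axis]; ++i) {
        buffer[i * 2] = data[offset];          // real part
        buffer[i * 2 + 1] = data[offset + 1];  // imaginary part
        offset += strides[axis];
    }
}
```
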
std::vector& outputStrides) { +void copyDataToOutputWithSignalSize(const float* input, + const std::vector& inputShape, + const std::vector& inputStrides, + float* output, + const std::vector& outputShape, + const std::vector& outputStrides) { auto totalInput = std::accumulate(inputShape.begin(), inputShape.end(), size_t(1), std::multiplies()); auto totalOutput = std::accumulate(outputShape.begin(), outputShape.end(), size_t(1), std::multiplies()); std::fill_n(output, totalOutput, 0.f); @@ -221,7 +233,10 @@ void copyDataToOutputWithSignalSize(const float* input, const std::vector inputStridesRange(inputStrides.begin(), inputStrides.begin() + iterationRange.size()); const std::vector outputStridesRange(outputStrides.begin(), outputStrides.begin() + iterationRange.size()); - const size_t blockSize = std::accumulate(inputShape.begin() + lastChangedDim + 1, inputShape.end(), size_t(1), std::multiplies()); + const size_t blockSize = std::accumulate(inputShape.begin() + lastChangedDim + 1, + inputShape.end(), + size_t(1), + std::multiplies()); const size_t blockSizeBytes = blockSize * sizeof(float); std::vector iterationCounter(iterationRange.size(), 0); do { @@ -231,7 +246,7 @@ void copyDataToOutputWithSignalSize(const float* input, const std::vectorgetMemory().getStaticDims(); @@ -269,7 +284,8 @@ void DFT::execute(dnnl::stream strm) { if (inputShape != outputShape) { copyDataToOutputWithSignalSize(src, inputShape, inputStrides, dst, outputShape, outputStrides); } else { - auto totalElements = std::accumulate(inputShape.begin(), inputShape.end(), size_t(1), std::multiplies()); + auto totalElements = + std::accumulate(inputShape.begin(), inputShape.end(), size_t(1), std::multiplies()); cpu_memcpy(dst, src, totalElements * sizeof(float)); } @@ -315,17 +331,32 @@ void DFT::dftNd(float* output, std::vector gatheredData(outputLen * 2); auto parallelIterationCounter = iterationCounter; parallelIterationCounter[parallelDimIndex] = dim; - gatherToBufferND(gatheredData.data(), output, currentAxis, parallelIterationCounter, outputShape, outputStrides); + gatherToBufferND(gatheredData.data(), + output, + currentAxis, + parallelIterationCounter, + outputShape, + outputStrides); const float* resultBufPtr; fft(gatheredData.data(), gatheredData.data() + outputLen, outputLen, inverse, false, &resultBufPtr); - applyBufferND(resultBufPtr, output, currentAxis, parallelIterationCounter, outputShape, outputStrides); + applyBufferND(resultBufPtr, + output, + currentAxis, + parallelIterationCounter, + outputShape, + outputStrides); }); iterationCounter[parallelDimIndex] = iterationRange[parallelDimIndex] - 1; } while (nextIterationStep(iterationCounter, iterationRange, currentAxis)); } else { std::vector gatheredData(outputLen); do { - gatherToBufferND(gatheredData.data(), output, currentAxis, iterationCounter, outputShape, outputStrides); + gatherToBufferND(gatheredData.data(), + output, + currentAxis, + iterationCounter, + outputShape, + outputStrides); naiveDFT(gatheredData.data(), outputLen, inverse); applyBufferND(gatheredData.data(), output, currentAxis, iterationCounter, outputShape, outputStrides); } while (nextIterationStep(iterationCounter, iterationRange, currentAxis)); @@ -585,6 +616,6 @@ void DFT::createJITKernels(bool hasDFT, bool hasFFT) { } #endif } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/dft.h b/src/plugins/intel_cpu/src/nodes/dft.h index 
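
`dftNd` above falls back to `naiveDFT` when the fast path does not apply. The reference definition is the textbook O(N^2) transform, X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N), with the exponent sign flipped and a 1/N scale for the inverse. A sketch over the same interleaved layout, writing to a separate output buffer for clarity where the node transforms in place:

```cpp
#include <cmath>
#include <cstddef>

// Textbook O(N^2) DFT over interleaved (re, im) floats; the sign of the
// exponent flips and the result is scaled by 1/N for the inverse transform.
void naiveDFT(const float* in, float* out, size_t N, bool inverse) {
    const double PI = 3.14159265358979323846;
    const double sign = inverse ? 1.0 : -1.0;
    const double scale = inverse ? 1.0 / static_cast<double>(N) : 1.0;
    for (size_t k = 0; k < N; ++k) {
        double re = 0.0, im = 0.0;
        for (size_t n = 0; n < N; ++n) {
            const double angle = sign * 2.0 * PI * k * n / N;
            const double c = std::cos(angle), s = std::sin(angle);
            // (a + bi) * (c + si) = (ac - bs) + (as + bc)i
            re += in[2 * n] * c - in[2 * n + 1] * s;
            im += in[2 * n] * s + in[2 * n + 1] * c;
        }
        out[2 * k] = static_cast<float>(re * scale);
        out[2 * k + 1] = static_cast<float>(im * scale);
    }
}
```
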
82b6ea3b33a618..eef5e2ea529066 100644 --- a/src/plugins/intel_cpu/src/nodes/dft.h +++ b/src/plugins/intel_cpu/src/nodes/dft.h @@ -63,6 +63,6 @@ class DFT : public Node { bool lastInverse; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp index 54cf435009059d..5daefa01eddfab 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp @@ -3,6 +3,18 @@ // #include "eltwise.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + #include "common/cpu_convert.h" #include "common/float16.hpp" #include "common/primitive_hashing_utils.hpp" @@ -10,6 +22,10 @@ #include "cpu/ref_eltwise.hpp" #include "cpu_types.h" #include "dnnl_extension_utils.h" +#include "emitters/plugin/x64/jit_bf16_emitters.hpp" +#include "emitters/plugin/x64/jit_dnnl_emitters.hpp" +#include "emitters/plugin/x64/jit_eltwise_emitters.hpp" +#include "emitters/plugin/x64/jit_emitter.hpp" #include "fake_quantize.h" #include "input.h" #include "memory_desc/dnnl_blocked_memory_desc.h" @@ -17,13 +33,13 @@ #include "onednn/dnnl.h" #include "openvino/core/except.hpp" #include "openvino/core/parallel.hpp" -#include "openvino/opsets/opset1.hpp" #include "openvino/op/bitwise_and.hpp" #include "openvino/op/bitwise_left_shift.hpp" #include "openvino/op/bitwise_not.hpp" #include "openvino/op/bitwise_or.hpp" #include "openvino/op/bitwise_right_shift.hpp" #include "openvino/op/bitwise_xor.hpp" +#include "openvino/opsets/opset1.hpp" #include "pooling.h" #include "selective_build.h" #include "shape_inference/custom/eltwise.hpp" @@ -35,27 +51,10 @@ #include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "emitters/plugin/x64/jit_emitter.hpp" -#include "emitters/plugin/x64/jit_eltwise_emitters.hpp" -#include "emitters/plugin/x64/jit_dnnl_emitters.hpp" -#include "emitters/plugin/x64/jit_bf16_emitters.hpp" - #if defined(OPENVINO_ARCH_ARM64) -#include "cpu/aarch64/cpu_isa_traits.hpp" -#include "kernels/aarch64/jit_uni_eltwise_generic.hpp" -#include "executors/aarch64/jit_eltwise.hpp" +# include "cpu/aarch64/cpu_isa_traits.hpp" +# include "executors/aarch64/jit_eltwise.hpp" +# include "kernels/aarch64/jit_uni_eltwise_generic.hpp" #endif using namespace dnnl::impl::utils; @@ -92,60 +91,72 @@ bool jitIsSupported(const Node* node, beta, gamma); } -} // namespace +} // namespace #endif #if defined(OPENVINO_ARCH_X86_64) -template +template struct SupportedPrecisions { - void operator()(std::set> &precisions) { + void operator()(std::set>& precisions) { precisions = T::get_supported_precisions(); } }; struct EltwiseEmitterContext { std::shared_ptr emitter; - jit_generator *host; + jit_generator* host; cpu_isa_t host_isa; const EltwiseData& opData; ov::element::Type exec_prc; }; -template +template struct EltwiseEmitter { - void operator()(EltwiseEmitterContext & ctx) { + void operator()(EltwiseEmitterContext& ctx) { ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, ctx.exec_prc); } }; -template<> +template <> struct EltwiseEmitter { - void operator()(EltwiseEmitterContext & ctx) { + void operator()(EltwiseEmitterContext& ctx) { auto algKind = static_cast(ctx.opData.onednnAlgorithm); - ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, algKind, - 
ctx.opData.alpha, ctx.opData.beta, ctx.exec_prc); + ctx.emitter = std::make_shared(ctx.host, + ctx.host_isa, + algKind, + ctx.opData.alpha, + ctx.opData.beta, + ctx.exec_prc); } }; -template<> +template <> struct EltwiseEmitter { - void operator()(EltwiseEmitterContext & ctx) { - ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, ctx.opData.alpha, - ctx.opData.beta, ctx.opData.gamma, ctx.exec_prc); + void operator()(EltwiseEmitterContext& ctx) { + ctx.emitter = std::make_shared(ctx.host, + ctx.host_isa, + ctx.opData.alpha, + ctx.opData.beta, + ctx.opData.gamma, + ctx.exec_prc); } }; -template<> +template <> struct EltwiseEmitter { - void operator()(EltwiseEmitterContext & ctx) { - ctx.emitter = std::make_shared(ctx.host, ctx.host_isa, ctx.exec_prc, ctx.opData.alpha, ctx.opData.beta); + void operator()(EltwiseEmitterContext& ctx) { + ctx.emitter = std::make_shared(ctx.host, + ctx.host_isa, + ctx.exec_prc, + ctx.opData.alpha, + ctx.opData.beta); } }; static void set_intersection(const std::set>& precisions1, - const std::set>& precisions2, - std::set>& intersection) { + const std::set>& precisions2, + std::set>& intersection) { std::map intersection_types; for (auto it1 = precisions1.begin(); it1 != precisions1.end(); ++it1) { @@ -195,15 +206,8 @@ ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_nu supported_precision_intersection = prcs_intersect; } - static const element::Type exec_precisions_priority[] = { - element::u8, - element::i8, - element::u16, - element::i16, - element::bf16, - element::i32, - element::f32 - }; + static const element::Type exec_precisions_priority[] = + {element::u8, element::i8, element::u16, element::i16, element::bf16, element::i32, element::f32}; for (const auto prc : exec_precisions_priority) { if (std::any_of(supported_precision_intersection.begin(), @@ -234,59 +238,62 @@ ov::element::Type eltwise_precision_helper::get_precision(const size_t inputs_nu std::set> eltwise_precision_helper::get_supported_precisions(const Algorithm& algo) { std::set> precisions; - OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo, - OV_CASE(Algorithm::EltwiseRelu, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseGeluErf, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseGeluTanh, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseElu, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseTanh, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSigmoid, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseAbs, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSwish, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseHswish, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseMish, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseHsigmoid, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfToEven, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), - OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), - OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter), - OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), - OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), - OV_CASE(Algorithm::EltwiseFloor, jit_floor_emitter), - OV_CASE(Algorithm::EltwiseCeiling, jit_ceiling_emitter), - OV_CASE(Algorithm::EltwiseFloorMod, jit_floor_mod_emitter), - 
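
The precision-selection code above intersects the precision sets supported by each fused op and then walks a fixed priority list (u8, i8, u16, i16, bf16, i32, f32), taking the first surviving entry. A simplified sketch that flattens the per-input tuples of the real code into plain sets, with strings standing in for `ov::element::Type`:

```cpp
#include <algorithm>
#include <iterator>
#include <set>
#include <string>
#include <vector>

// Pick an execution precision: intersect what every op supports, then take
// the highest-priority common entry. Assumes perOp is non-empty.
std::string pickExecPrecision(const std::vector<std::set<std::string>>& perOp) {
    std::set<std::string> common = perOp.front();
    for (size_t i = 1; i < perOp.size(); ++i) {
        std::set<std::string> tmp;
        std::set_intersection(common.begin(), common.end(),
                              perOp[i].begin(), perOp[i].end(),
                              std::inserter(tmp, tmp.begin()));
        common = std::move(tmp);
    }
    static const char* const priority[] = {"u8", "i8", "u16", "i16",
                                           "bf16", "i32", "f32"};
    for (const char* prc : priority)
        if (common.count(prc))
            return prc;
    return "f32";  // conservative fallback when nothing intersects
}
```
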
OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter), - OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter), - OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter), - OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), - OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter), - OV_CASE(Algorithm::EltwisePowerDynamic, jit_power_dynamic_emitter), - OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), - OV_CASE(Algorithm::EltwiseNotEqual, jit_not_equal_emitter), - OV_CASE(Algorithm::EltwiseGreater, jit_greater_emitter), - OV_CASE(Algorithm::EltwiseGreaterEqual, jit_greater_equal_emitter), - OV_CASE(Algorithm::EltwiseLess, jit_less_emitter), - OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter), - OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter), - OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter), - OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter), - OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter), - OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), - OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), - OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter), - OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter), - OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter), - OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter), - OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter), - OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), - OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter), - OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter), - OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter), - OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter)); + OV_SWITCH(intel_cpu, + SupportedPrecisions, + precisions, + algo, + OV_CASE(Algorithm::EltwiseRelu, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseGeluErf, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseGeluTanh, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseElu, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseTanh, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSigmoid, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseAbs, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSwish, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseHswish, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseMish, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseHsigmoid, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfToEven, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), + OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), + OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter), + OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), + OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), + OV_CASE(Algorithm::EltwiseFloor, jit_floor_emitter), + OV_CASE(Algorithm::EltwiseCeiling, jit_ceiling_emitter), + OV_CASE(Algorithm::EltwiseFloorMod, jit_floor_mod_emitter), + OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter), + OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter), + OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter), + OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), + OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter), + 
OV_CASE(Algorithm::EltwisePowerDynamic, jit_power_dynamic_emitter), + OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), + OV_CASE(Algorithm::EltwiseNotEqual, jit_not_equal_emitter), + OV_CASE(Algorithm::EltwiseGreater, jit_greater_emitter), + OV_CASE(Algorithm::EltwiseGreaterEqual, jit_greater_equal_emitter), + OV_CASE(Algorithm::EltwiseLess, jit_less_emitter), + OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter), + OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter), + OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter), + OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter), + OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter), + OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), + OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), + OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter), + OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter), + OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter), + OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter), + OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter), + OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), + OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter), + OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter), + OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter), + OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter)); if (precisions.empty()) OPENVINO_THROW("Unsupported operation type for Eltwise emitter"); @@ -302,7 +309,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener const std::vector& eltwise_data, const std::vector& ops_list, const dnnl::post_ops& post_ops) - : jit_uni_eltwise_kernel(jep), jit_generator(jit_name()), eltwise_data_(eltwise_data), ops_list_(ops_list), post_ops_(post_ops) {} + : jit_uni_eltwise_kernel(jep), + jit_generator(jit_name()), + eltwise_data_(eltwise_data), + ops_list_(ops_list), + post_ops_(post_ops) {} void create_ker() override { jit_generator::create_kernel(); @@ -322,14 +333,18 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener if (!p->entry_[i].is_quantization()) { OPENVINO_THROW("Eltwise jitter error. 
Unsupported post op detected"); } - quantization_injectors.push_back(std::make_shared>( - this, p->entry_[i], vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias)); + quantization_injectors.push_back(std::make_shared>(this, + p->entry_[i], + vmm_d_weights, + vmm_d_bias, + reg_d_weights, + reg_d_bias)); } if (mayiuse(avx512_core) || mayiuse(avx2_vnni_2)) uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa)); - const auto &jep = jep_; + const auto& jep = jep_; this->preamble(); @@ -435,7 +450,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener for (size_t j = 0; j < min_src_size / vec_step; j++) { for (size_t i = 0; i < jep.inputs_number; i++) { if (jep.src_size[i] != 1) - load_vector(get_vmm_reg(i), ptr[get_src_reg(i) + j * vec_step * jep.src_prc[i].size()], jep.src_prc[i], exec_prc, false); + load_vector(get_vmm_reg(i), + ptr[get_src_reg(i) + j * vec_step * jep.src_prc[i].size()], + jep.src_prc[i], + exec_prc, + false); } compute_eltwise_op(); @@ -449,7 +468,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener for (size_t j = tail_start; j < min_src_size; j++) { for (size_t i = 0; i < jep.inputs_number; i++) { if (jep.src_size[i] != 1) - load_scalar(get_xmm_reg(i), ptr[get_src_reg(i) + j * jep.src_prc[i].size()], jep.src_prc[i], exec_prc); + load_scalar(get_xmm_reg(i), + ptr[get_src_reg(i) + j * jep.src_prc[i].size()], + jep.src_prc[i], + exec_prc); } compute_eltwise_op(); @@ -571,7 +593,7 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener } Reg64 reg_post_op_ptrs = rax; - Reg64 start_to_offsets = reg_post_op_ptrs; // rax + Reg64 start_to_offsets = reg_post_op_ptrs; // rax Reg64 reg_dst = rbx; Reg64 reg_work_amount = rdx; @@ -606,67 +628,64 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener const dnnl::post_ops& post_ops_; std::shared_ptr create_eltwise_emitter(const EltwiseData& data, ov::element::Type exec_prec) { - EltwiseEmitterContext ctx = { - nullptr, - this, - isa, - data, - exec_prec - }; - - OV_SWITCH(intel_cpu, EltwiseEmitter, ctx, data.algo, - OV_CASE(Algorithm::EltwiseRelu, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseGeluErf, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseGeluTanh, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseElu, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseTanh, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSigmoid, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseAbs, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseSwish, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseHswish, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseMish, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseHsigmoid, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfToEven, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, jit_dnnl_aux_emitter), - OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), - OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), - OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter), - OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), - OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), - OV_CASE(Algorithm::EltwiseFloor, jit_floor_emitter), - OV_CASE(Algorithm::EltwiseCeiling, jit_ceiling_emitter), - OV_CASE(Algorithm::EltwiseFloorMod, 
jit_floor_mod_emitter), - OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter), - OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter), - OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter), - OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), - OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter), - OV_CASE(Algorithm::EltwisePowerDynamic, jit_power_dynamic_emitter), - OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), - OV_CASE(Algorithm::EltwiseNotEqual, jit_not_equal_emitter), - OV_CASE(Algorithm::EltwiseGreater, jit_greater_emitter), - OV_CASE(Algorithm::EltwiseGreaterEqual, jit_greater_equal_emitter), - OV_CASE(Algorithm::EltwiseLess, jit_less_emitter), - OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter), - OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter), - OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter), - OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter), - OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter), - OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), - OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), - OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter), - OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter), - OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter), - OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter), - OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter), - OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), - OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter), - OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter), - OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter), - OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter)); + EltwiseEmitterContext ctx = {nullptr, this, isa, data, exec_prec}; + + OV_SWITCH(intel_cpu, + EltwiseEmitter, + ctx, + data.algo, + OV_CASE(Algorithm::EltwiseRelu, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseGeluErf, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseGeluTanh, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseElu, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseTanh, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSigmoid, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseAbs, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseSwish, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseHswish, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseMish, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseHsigmoid, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfToEven, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseRoundHalfAwayFromZero, jit_dnnl_aux_emitter), + OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter), + OV_CASE(Algorithm::EltwiseMulAdd, jit_mul_add_emitter), + OV_CASE(Algorithm::EltwiseSubtract, jit_subtract_emitter), + OV_CASE(Algorithm::EltwiseMultiply, jit_multiply_emitter), + OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter), + OV_CASE(Algorithm::EltwiseFloor, jit_floor_emitter), + OV_CASE(Algorithm::EltwiseCeiling, jit_ceiling_emitter), + OV_CASE(Algorithm::EltwiseFloorMod, jit_floor_mod_emitter), + OV_CASE(Algorithm::EltwiseMod, jit_mod_emitter), + OV_CASE(Algorithm::EltwiseMaximum, jit_maximum_emitter), + OV_CASE(Algorithm::EltwiseMinimum, jit_minimum_emitter), + OV_CASE(Algorithm::EltwiseExp, jit_exp_emitter), + 
OV_CASE(Algorithm::EltwiseSquaredDifference, jit_squared_difference_emitter), + OV_CASE(Algorithm::EltwisePowerDynamic, jit_power_dynamic_emitter), + OV_CASE(Algorithm::EltwiseEqual, jit_equal_emitter), + OV_CASE(Algorithm::EltwiseNotEqual, jit_not_equal_emitter), + OV_CASE(Algorithm::EltwiseGreater, jit_greater_emitter), + OV_CASE(Algorithm::EltwiseGreaterEqual, jit_greater_equal_emitter), + OV_CASE(Algorithm::EltwiseLess, jit_less_emitter), + OV_CASE(Algorithm::EltwiseLessEqual, jit_less_equal_emitter), + OV_CASE(Algorithm::EltwiseLogicalAnd, jit_logical_and_emitter), + OV_CASE(Algorithm::EltwiseLogicalOr, jit_logical_or_emitter), + OV_CASE(Algorithm::EltwiseLogicalXor, jit_logical_xor_emitter), + OV_CASE(Algorithm::EltwiseLogicalNot, jit_logical_not_emitter), + OV_CASE(Algorithm::EltwisePowerStatic, jit_power_static_emitter), + OV_CASE(Algorithm::EltwisePrelu, jit_prelu_emitter), + OV_CASE(Algorithm::EltwiseErf, jit_erf_emitter), + OV_CASE(Algorithm::EltwiseSoftSign, jit_soft_sign_emitter), + OV_CASE(Algorithm::EltwiseIsFinite, jit_is_finite_emitter), + OV_CASE(Algorithm::EltwiseIsInf, jit_is_inf_emitter), + OV_CASE(Algorithm::EltwiseIsNaN, jit_is_nan_emitter), + OV_CASE(Algorithm::EltwiseSelect, jit_select_emitter), + OV_CASE(Algorithm::EltwiseBitwiseAnd, jit_bitwise_and_emitter), + OV_CASE(Algorithm::EltwiseBitwiseNot, jit_bitwise_not_emitter), + OV_CASE(Algorithm::EltwiseBitwiseOr, jit_bitwise_or_emitter), + OV_CASE(Algorithm::EltwiseBitwiseXor, jit_bitwise_xor_emitter)); if (!ctx.emitter) OPENVINO_THROW("Unsupported operation type for Eltwise emitter"); @@ -714,17 +733,31 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener bool do_rounding = do_dequantization || jep_.dst_prc == ov::element::f32 || i != ops_list_.size() - 1; int s_idx = vmm_dst.getIdx(); - size_t ptrs_table_off = quantization_post_op_idx * quantization_injectors[quantization_post_op_idx]->memoryStep(); - - quantization_injectors[quantization_post_op_idx]->init_crop_ptrs(reg_post_op_ptrs + ptrs_table_off, reg_oc_off); - quantization_injectors[quantization_post_op_idx]->compute_crop(s_idx, s_idx + 1, offset, is_scalar, jep_.oc_size == 1); - - quantization_injectors[quantization_post_op_idx]->init_input_scale_shift_ptrs(reg_post_op_ptrs + ptrs_table_off, reg_oc_off); - quantization_injectors[quantization_post_op_idx]->compute_input_scale_shift(s_idx, s_idx + 1, offset, do_rounding, - is_scalar, jep_.oc_size == 1); - - quantization_injectors[quantization_post_op_idx]->init_output_scale_shift_ptrs(reg_post_op_ptrs + ptrs_table_off, reg_oc_off); - quantization_injectors[quantization_post_op_idx]->compute_output_scale_shift(s_idx, s_idx + 1, offset, is_scalar, jep_.oc_size == 1); + size_t ptrs_table_off = + quantization_post_op_idx * quantization_injectors[quantization_post_op_idx]->memoryStep(); + + quantization_injectors[quantization_post_op_idx]->init_crop_ptrs(reg_post_op_ptrs + ptrs_table_off, + reg_oc_off); + quantization_injectors[quantization_post_op_idx]->compute_crop(s_idx, + s_idx + 1, + offset, + is_scalar, + jep_.oc_size == 1); + + quantization_injectors[quantization_post_op_idx]->init_input_scale_shift_ptrs( + reg_post_op_ptrs + ptrs_table_off, + reg_oc_off); + quantization_injectors[quantization_post_op_idx] + ->compute_input_scale_shift(s_idx, s_idx + 1, offset, do_rounding, is_scalar, jep_.oc_size == 1); + + quantization_injectors[quantization_post_op_idx]->init_output_scale_shift_ptrs( + reg_post_op_ptrs + ptrs_table_off, + reg_oc_off); + 
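// [Editor's note] Sequencing of the quantization-injector calls being re-wrapped here:
// each quantization post op owns one slot in the runtime pointer table, addressed as
// ptrs_table_off = quantization_post_op_idx * memoryStep(). The injector is then replayed
// over the same vmm slot [s_idx, s_idx + 1) in three phases: crop (clamp), input
// scale/shift (with optional rounding), and finally the output scale/shift just below.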
quantization_injectors[quantization_post_op_idx]->compute_output_scale_shift(s_idx, + s_idx + 1, + offset, + is_scalar, + jep_.oc_size == 1); quantization_post_op_idx++; } else { @@ -733,7 +766,11 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener } } - inline void load_vector(Vmm vmm_src, const Xbyak::Address &op, ov::element::Type src_prc, ov::element::Type dst_prc, bool broadcast) { + inline void load_vector(Vmm vmm_src, + const Xbyak::Address& op, + ov::element::Type src_prc, + ov::element::Type dst_prc, + bool broadcast) { Xmm xmm_src = Xmm(vmm_src.getIdx()); if (src_prc == dst_prc) { @@ -751,120 +788,126 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener uni_vbroadcastss(vmm_src, xmm_src); } else { switch (src_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(vmm_src, op); - break; - case ov::element::bf16: - vpmovzxwd(vmm_src, op); - uni_vpslld(vmm_src, vmm_src, 16); - break; - case ov::element::f16: - vcvtph2ps(vmm_src, op); - break; - case ov::element::u16: - uni_vpmovzxwd(vmm_src, op); - break; - case ov::element::i16: - uni_vpmovsxwd(vmm_src, op); - break; - case ov::element::i8: - uni_vpmovsxbd(vmm_src, op); - break; - case ov::element::u8: - uni_vpmovzxbd(vmm_src, op); - break; - default: - OPENVINO_THROW("unknown src_prc"); - } - - switch (dst_prc) { - case ov::element::f32: - if (!src_prc.is_real()) - uni_vcvtdq2ps(vmm_src, vmm_src); - break; - case ov::element::i32: - if (src_prc.is_real()) - uni_vcvtps2dq(vmm_src, vmm_src); - break; - default: - OPENVINO_THROW("unknown dst_prc"); - } - } - } - - inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, ov::element::Type src_prc, ov::element::Type dst_prc) { - if (src_prc == dst_prc) { - switch (src_prc.size()) { - case 4: - uni_vmovss(xmm_src, op); - break; - case 1: - mov(reg_tmp_8, op); - movzx(reg_tmp_32, reg_tmp_8); - uni_vmovd(xmm_src, reg_tmp_32); - break; - default: - OPENVINO_THROW("unknown prc"); - } - return; - } - - switch (src_prc) { case ov::element::f32: case ov::element::i32: - uni_vmovss(xmm_src, op); + uni_vmovups(vmm_src, op); break; case ov::element::bf16: - if (isa == x64::avx2_vnni_2) { - vbcstnebf162ps(xmm_src, op); - } else { - uni_vpinsrw(xmm_src, xmm_src, op, 0); - uni_vpslld(xmm_src, xmm_src, 16); - } + vpmovzxwd(vmm_src, op); + uni_vpslld(vmm_src, vmm_src, 16); break; case ov::element::f16: - if (isa == x64::avx2_vnni_2) { - vbcstnesh2ps(xmm_src, op); - } else { - vcvtph2ps(xmm_src, op); - } - break; - case ov::element::i16: - uni_vpinsrw(xmm_src, xmm_src, op, 0); - uni_vpmovsxwd(xmm_src, op); + vcvtph2ps(vmm_src, op); break; case ov::element::u16: - uni_vpinsrw(xmm_src, xmm_src, op, 0); - uni_vpmovzxwd(xmm_src, op); + uni_vpmovzxwd(vmm_src, op); + break; + case ov::element::i16: + uni_vpmovsxwd(vmm_src, op); break; case ov::element::i8: - movsx(reg_tmp_32, op); - uni_vmovq(xmm_src, reg_tmp_64); + uni_vpmovsxbd(vmm_src, op); break; case ov::element::u8: - movzx(reg_tmp_32, op); - uni_vmovq(xmm_src, reg_tmp_64); + uni_vpmovzxbd(vmm_src, op); break; default: OPENVINO_THROW("unknown src_prc"); - } + } - switch (dst_prc) { + switch (dst_prc) { case ov::element::f32: if (!src_prc.is_real()) - uni_vcvtdq2ps(xmm_src, xmm_src); + uni_vcvtdq2ps(vmm_src, vmm_src); break; case ov::element::i32: if (src_prc.is_real()) - uni_vcvtps2dq(xmm_src, xmm_src); + uni_vcvtps2dq(vmm_src, vmm_src); break; default: OPENVINO_THROW("unknown dst_prc"); + } + } + } + + inline void load_scalar(Xmm xmm_src, + const Xbyak::Address& op, + 
ov::element::Type src_prc, + ov::element::Type dst_prc) { + if (src_prc == dst_prc) { + switch (src_prc.size()) { + case 4: + uni_vmovss(xmm_src, op); + break; + case 1: + mov(reg_tmp_8, op); + movzx(reg_tmp_32, reg_tmp_8); + uni_vmovd(xmm_src, reg_tmp_32); + break; + default: + OPENVINO_THROW("unknown prc"); + } + return; + } + + switch (src_prc) { + case ov::element::f32: + case ov::element::i32: + uni_vmovss(xmm_src, op); + break; + case ov::element::bf16: + if (isa == x64::avx2_vnni_2) { + vbcstnebf162ps(xmm_src, op); + } else { + uni_vpinsrw(xmm_src, xmm_src, op, 0); + uni_vpslld(xmm_src, xmm_src, 16); + } + break; + case ov::element::f16: + if (isa == x64::avx2_vnni_2) { + vbcstnesh2ps(xmm_src, op); + } else { + vcvtph2ps(xmm_src, op); + } + break; + case ov::element::i16: + uni_vpinsrw(xmm_src, xmm_src, op, 0); + uni_vpmovsxwd(xmm_src, op); + break; + case ov::element::u16: + uni_vpinsrw(xmm_src, xmm_src, op, 0); + uni_vpmovzxwd(xmm_src, op); + break; + case ov::element::i8: + movsx(reg_tmp_32, op); + uni_vmovq(xmm_src, reg_tmp_64); + break; + case ov::element::u8: + movzx(reg_tmp_32, op); + uni_vmovq(xmm_src, reg_tmp_64); + break; + default: + OPENVINO_THROW("unknown src_prc"); + } + + switch (dst_prc) { + case ov::element::f32: + if (!src_prc.is_real()) + uni_vcvtdq2ps(xmm_src, xmm_src); + break; + case ov::element::i32: + if (src_prc.is_real()) + uni_vcvtps2dq(xmm_src, xmm_src); + break; + default: + OPENVINO_THROW("unknown dst_prc"); } } - inline void store_vector(const Xbyak::Address &op, Vmm vmm_dst, ov::element::Type src_prc, ov::element::Type dst_prc) { + inline void store_vector(const Xbyak::Address& op, + Vmm vmm_dst, + ov::element::Type src_prc, + ov::element::Type dst_prc) { Xmm xmm_dst = Xmm(vmm_dst.getIdx()); Ymm ymm_dst = Ymm(vmm_dst.getIdx()); @@ -874,170 +917,173 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener } switch (src_prc) { - case ov::element::f32: - if (!dst_prc.is_real()) - uni_vcvtps2dq(vmm_dst, vmm_dst); - break; - case ov::element::i32: - if (dst_prc.is_real()) - uni_vcvtdq2ps(vmm_dst, vmm_dst); - break; - default: - OPENVINO_THROW("unknown src_prc"); + case ov::element::f32: + if (!dst_prc.is_real()) + uni_vcvtps2dq(vmm_dst, vmm_dst); + break; + case ov::element::i32: + if (dst_prc.is_real()) + uni_vcvtdq2ps(vmm_dst, vmm_dst); + break; + default: + OPENVINO_THROW("unknown src_prc"); } switch (dst_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(op, vmm_dst); - break; - case ov::element::bf16: - if (isa == x64::avx512_core) { - uni_vcvtneps2bf16->emit_code({static_cast(vmm_dst.getIdx())}, - {static_cast(ymm_dst.getIdx())}); - vmovdqu16(op, ymm_dst); - } else { - uni_vcvtneps2bf16->emit_code({static_cast(vmm_dst.getIdx())}, - {static_cast(xmm_dst.getIdx())}); + case ov::element::f32: + case ov::element::i32: + uni_vmovups(op, vmm_dst); + break; + case ov::element::bf16: + if (isa == x64::avx512_core) { + uni_vcvtneps2bf16->emit_code({static_cast(vmm_dst.getIdx())}, + {static_cast(ymm_dst.getIdx())}); + vmovdqu16(op, ymm_dst); + } else { + uni_vcvtneps2bf16->emit_code({static_cast(vmm_dst.getIdx())}, + {static_cast(xmm_dst.getIdx())}); + uni_vmovdqu(op, xmm_dst); + } + break; + case ov::element::f16: + vcvtps2ph(op, vmm_dst, 0x4); + break; + case ov::element::i16: + if (isa == x64::avx512_core) { + vpmovsdw(op, vmm_dst); + } else { + uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); + if (isa != x64::sse41) { + vpermq(ymm_dst, ymm_dst, 0x08); uni_vmovdqu(op, xmm_dst); - } - break; - case ov::element::f16: - 
vcvtps2ph(op, vmm_dst, 0x4); - break; - case ov::element::i16: - if (isa == x64::avx512_core) { - vpmovsdw(op, vmm_dst); } else { - uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41) { - vpermq(ymm_dst, ymm_dst, 0x08); - uni_vmovdqu(op, xmm_dst); - } else { - movq(op, xmm_dst); - } + movq(op, xmm_dst); } - break; - case ov::element::u16: - if (isa == x64::avx512_core) { - vpmaxsd(vmm_dst, vmm_zero, vmm_dst); - vpmovusdw(op, vmm_dst); - } else { - uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41) { - vpermq(ymm_dst, ymm_dst, 0x08); - uni_vmovdqu(op, xmm_dst); - } else { - movq(op, xmm_dst); - } - } - break; - case ov::element::i8: - if (isa == x64::avx512_core) { - vpmovsdb(op, vmm_dst); - } else { - uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41) - vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41) - vmovq(op, xmm_dst); - else - movd(op, xmm_dst); - } - break; - case ov::element::u8: - if (isa == x64::avx512_core) { - vpmaxsd(vmm_dst, vmm_zero, vmm_dst); - vpmovusdb(op, vmm_dst); + } + break; + case ov::element::u16: + if (isa == x64::avx512_core) { + vpmaxsd(vmm_dst, vmm_zero, vmm_dst); + vpmovusdw(op, vmm_dst); + } else { + uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); + if (isa != x64::sse41) { + vpermq(ymm_dst, ymm_dst, 0x08); + uni_vmovdqu(op, xmm_dst); } else { - uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41) - vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst); - if (isa != x64::sse41) - vmovq(op, xmm_dst); - else - movd(op, xmm_dst); + movq(op, xmm_dst); } - break; - default: - OPENVINO_THROW("unknown dst_prc"); + } + break; + case ov::element::i8: + if (isa == x64::avx512_core) { + vpmovsdb(op, vmm_dst); + } else { + uni_vpackssdw(vmm_dst, vmm_dst, vmm_dst); + if (isa != x64::sse41) + vpermq(ymm_dst, ymm_dst, 0x08); + uni_vpacksswb(vmm_dst, vmm_dst, vmm_dst); + if (isa != x64::sse41) + vmovq(op, xmm_dst); + else + movd(op, xmm_dst); + } + break; + case ov::element::u8: + if (isa == x64::avx512_core) { + vpmaxsd(vmm_dst, vmm_zero, vmm_dst); + vpmovusdb(op, vmm_dst); + } else { + uni_vpackusdw(vmm_dst, vmm_dst, vmm_dst); + if (isa != x64::sse41) + vpermq(ymm_dst, ymm_dst, 0x08); + uni_vpackuswb(vmm_dst, vmm_dst, vmm_dst); + if (isa != x64::sse41) + vmovq(op, xmm_dst); + else + movd(op, xmm_dst); + } + break; + default: + OPENVINO_THROW("unknown dst_prc"); } } - inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, ov::element::Type src_prc, ov::element::Type dst_prc) { + inline void store_scalar(const Xbyak::Address& op, + Xmm xmm_dst, + ov::element::Type src_prc, + ov::element::Type dst_prc) { if (src_prc == dst_prc) { switch (src_prc.size()) { - case 4: - uni_vmovss(op, xmm_dst); - break; - case 1: - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - break; - default: - OPENVINO_THROW("unknown prc"); + case 4: + uni_vmovss(op, xmm_dst); + break; + case 1: + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + break; + default: + OPENVINO_THROW("unknown prc"); } return; } switch (src_prc) { - case ov::element::f32: - if (!dst_prc.is_real()) - uni_vcvtps2dq(xmm_dst, xmm_dst); - break; - case ov::element::i32: - if (dst_prc.is_real()) - uni_vcvtdq2ps(xmm_dst, xmm_dst); - break; - default: - OPENVINO_THROW("unknown src_prc"); + case ov::element::f32: + if (!dst_prc.is_real()) + uni_vcvtps2dq(xmm_dst, xmm_dst); + break; + case ov::element::i32: + if (dst_prc.is_real()) + uni_vcvtdq2ps(xmm_dst, xmm_dst); + break; + default: + 
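// [Editor's note] On the pre-AVX-512 store paths above: uni_vpackssdw/uni_vpacksswb (and
// the unsigned vpackusdw/vpackuswb counterparts) narrow with saturation, so an i32 lane is
// clamped into the i16/u16 or i8/u8 range without explicit min/max code; vpermq(ymm, ymm,
// 0x08) merely gathers the two valid 64-bit chunks after the in-lane packs, while the
// AVX-512 branches replace the whole dance with single vpmovsdb/vpmovusdb-style stores.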
OPENVINO_THROW("unknown src_prc"); } switch (dst_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovss(op, xmm_dst); - break; - case ov::element::bf16: - uni_vpsrld(xmm_dst, xmm_dst, 16); - uni_vpextrw(op, xmm_dst, 0x0); - break; - case ov::element::f16: - vcvtps2ph(xmm_dst, xmm_dst, 0x4); - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_16); - break; - case ov::element::i16: - uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_16); - break; - case ov::element::u16: - uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_16); - break; - case ov::element::i8: - uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); - uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - break; - case ov::element::u8: - uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); - uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); - movq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - break; - default: - OPENVINO_THROW("unknown dst_prc"); + case ov::element::f32: + case ov::element::i32: + uni_vmovss(op, xmm_dst); + break; + case ov::element::bf16: + uni_vpsrld(xmm_dst, xmm_dst, 16); + uni_vpextrw(op, xmm_dst, 0x0); + break; + case ov::element::f16: + vcvtps2ph(xmm_dst, xmm_dst, 0x4); + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_16); + break; + case ov::element::i16: + uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_16); + break; + case ov::element::u16: + uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_16); + break; + case ov::element::i8: + uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + break; + case ov::element::u8: + uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); + movq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + break; + default: + OPENVINO_THROW("unknown dst_prc"); } } }; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 Eltwise::BroadcastingPolicy Eltwise::determineBroadcastingPolicy(const std::shared_ptr& op) { const auto const1 = ov::as_type_ptr(op->get_input_node_shared_ptr(0)); @@ -1297,7 +1343,6 @@ const std::map& Eltwise::getIn return initializers; } - namespace { struct EltwiseKey { @@ -1353,12 +1398,8 @@ struct EltwiseKey { return false; } - bool result = eltwise_data == rhs.eltwise_data && - ops_list == rhs.ops_list && - inpPrc == rhs.inpPrc && - outPrc == rhs.outPrc && - *postOps.get() == *rhs.postOps.get() && - implType == rhs.implType; + bool result = eltwise_data == rhs.eltwise_data && ops_list == rhs.ops_list && inpPrc == rhs.inpPrc && + outPrc == rhs.outPrc && *postOps.get() == *rhs.postOps.get() && implType == rhs.implType; if (result) { if (implType == EltwiseImplType::optimizedShapeAgnostic) { @@ -1370,8 +1411,7 @@ struct EltwiseKey { return false; } } else { - result = result && outOrder == rhs.outOrder && - outBlkDims == rhs.outBlkDims; + result = result && outOrder == rhs.outOrder && outBlkDims == rhs.outBlkDims; for (size_t i = 0; i < inpDims.size() && result; ++i) { result = result && (inpDims[i] == rhs.inpDims[i]); } @@ -1426,7 +1466,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { auto collapseLastOffsets = [](std::vector& dims, int dimsToCollapse) { for (size_t i = dims.size() - 2; i > dims.size() - dimsToCollapse - 2; i--) { if (dims[dims.size() - 1] > 0 || dims[i] > 0) - dims[dims.size() - 1] = std::max(dims[dims.size() - 1], static_cast(1)) * 
std::max(dims[i], static_cast(1)); + dims[dims.size() - 1] = std::max(dims[dims.size() - 1], static_cast(1)) * + std::max(dims[i], static_cast(1)); else dims[dims.size() - 1] *= dims[i]; } @@ -1442,8 +1483,10 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { auto isFusedWith = [&](Type type_) { auto start_itr = ops_list.begin(); - std::advance(start_itr, 1); // apply offset since the first op in the list is the op itself - return any_of(start_itr, ops_list.end(), [=](Type type) { return type == type_; }); + std::advance(start_itr, 1); // apply offset since the first op in the list is the op itself + return any_of(start_itr, ops_list.end(), [=](Type type) { + return type == type_; + }); }; if (inpDims.empty()) { @@ -1493,7 +1536,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { int oc_dim_idx = i + (jep.input_size - outOrder.size()); jep.oc_offsets[oc_dim_idx] = offset_oc; offset_oc *= jep.dims[oc_dim_idx]; - if (oc_dim_idx + 1 != static_cast(jep.input_size)) { // since in nspc case we can safely collapse the last axis + if (oc_dim_idx + 1 != + static_cast(jep.input_size)) { // since in nspc case we can safely collapse the last axis lastUnchangedAxis = oc_dim_idx; } } @@ -1514,7 +1558,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { int collapsedDims = 0; bool hasDifferentDims = false; - while (!useRuntimePtrs && currentJitWorkAmount < minimalJitWorkAmount && currentJitWorkAmount < fullWorkAmount) { + while (!useRuntimePtrs && currentJitWorkAmount < minimalJitWorkAmount && + currentJitWorkAmount < fullWorkAmount) { if (collapsedDims >= maxCollapsedDims) break; @@ -1595,8 +1640,9 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { jep.work_amount = jep.dst_size = jep.dims.back(); jep.oc_size = oc_size; - std::transform(jep.oc_offsets.begin(), jep.oc_offsets.end(), jep.oc_offsets.begin(), - [](size_t& offset) { return offset * sizeof(float);}); + std::transform(jep.oc_offsets.begin(), jep.oc_offsets.end(), jep.oc_offsets.begin(), [](size_t& offset) { + return offset * sizeof(float); + }); #if defined(OPENVINO_ARCH_X86_64) if (mayiuse(x64::avx512_core)) { @@ -1608,7 +1654,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { } else { OPENVINO_THROW("Can't create jit eltwise kernel"); } -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 #if defined(OPENVINO_ARCH_ARM64) if (mayiuse(aarch64::asimd)) { @@ -1616,28 +1662,28 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { } else { OPENVINO_THROW("Can't create jit eltwise kernel"); } -#endif // OPENVINO_ARCH_ARM64 +#endif // OPENVINO_ARCH_ARM64 if (_pKernel) _pKernel->create_ker(); } - void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override { + void exec(const jit_eltwise_call_args_ptrs& args_ptrs, const VectorDims& dims_out) override { if (!_pKernel) OPENVINO_THROW("Can't execute, kernel for eltwise node is not compiled"); if (_pKernel->jep_.input_size == optimalTensorRank) { // execute Optimized 6D auto d6_loop = [&](size_t i0, size_t i1, size_t i2, size_t i3, size_t i4) { - auto args = jit_eltwise_call_args_indexes(); - args.indexes[0] = i0; - args.indexes[1] = i1; - args.indexes[2] = i2; - args.indexes[3] = i3; - args.indexes[4] = i4; + auto args = jit_eltwise_call_args_indexes(); + args.indexes[0] = i0; + args.indexes[1] = i1; + args.indexes[2] = i2; + args.indexes[3] = i3; + args.indexes[4] = i4; - (*_pKernel)(&args_ptrs, &args); - }; + (*_pKernel)(&args_ptrs, &args); + }; 
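// [Editor's note] In this optimized-6D path the JIT kernel consumes one innermost row per
// invocation, so only the five outer dimensions are sliced across threads:
// parallel_nt_static pins m_threads_num workers, for_5d hands each of them a sub-range of
// the dims_out[0..4] iteration space, and d6_loop forwards the per-iteration indexes to
// the kernel through jit_eltwise_call_args_indexes.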
parallel_nt_static(m_threads_num, [&](const int ithr, const int nthr) { for_5d(ithr, nthr, dims_out[0], dims_out[1], dims_out[2], dims_out[3], dims_out[4], d6_loop); @@ -1693,13 +1739,14 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { /* enabled only for float at float16_t at the moment * can be extended in the future */ -template +template class EltwiseRefBaseExecutor : public Eltwise::IEltwiseExecutor { public: EltwiseRefBaseExecutor(const EltwiseData& opData, const VectorDims& outBlkDims, const std::vector& inpDims) - : _opData(std::move(opData)), _inpDims(inpDims) { + : _opData(std::move(opData)), + _inpDims(inpDims) { if (inpDims.empty()) { OPENVINO_THROW("Can not make Eltwise executor from empty input dims array"); } else if (inpDims.front().empty()) { @@ -1750,18 +1797,18 @@ class EltwiseRefBaseExecutor : public Eltwise::IEltwiseExecutor { protected: void init_ptr(const jit_eltwise_call_args_ptrs& args_ptrs, - const VectorDims& dims_out, - std::vector& counters, - const size_t iwork, - std::vector& src_f, - T*& dst_ptr_f) { + const VectorDims& dims_out, + std::vector& counters, + const size_t iwork, + std::vector& src_f, + T*& dst_ptr_f) { size_t tmp = iwork; for (ptrdiff_t j = dims_out.size() - 1; j >= 0; j--) { counters[j] = tmp % dims_out[j]; tmp /= dims_out[j]; } - size_t index_in[MAX_ELTWISE_INPUTS] = { 0 }; + size_t index_in[MAX_ELTWISE_INPUTS] = {0}; for (size_t i = 0; i < _inputNum; i++) { index_in[i] = 0; for (size_t j = 0; j < counters.size(); j++) { @@ -1776,7 +1823,7 @@ class EltwiseRefBaseExecutor : public Eltwise::IEltwiseExecutor { } index_out /= sizeof(T); - //std::vector src_f(_inputNum); + // std::vector src_f(_inputNum); for (size_t i = 0; i < _inputNum; i++) { src_f[i] = (reinterpret_cast(args_ptrs.src_ptr[i]) + index_in[i])[0]; } @@ -1795,19 +1842,15 @@ class EltwiseRefBaseExecutor : public Eltwise::IEltwiseExecutor { /* enabled only for float at float16_t at the moment * can be extended in the future */ -template::value || - std::is_same::value> - ::type * = nullptr> +template ::value || + std::is_same::value>::type* = nullptr> class EltwiseRefExecutor : public EltwiseRefBaseExecutor { public: - EltwiseRefExecutor(const EltwiseData& opData, - const VectorDims& outBlkDims, - std::vector inpDims) : EltwiseRefBaseExecutor(opData, outBlkDims, inpDims) { - } + EltwiseRefExecutor(const EltwiseData& opData, const VectorDims& outBlkDims, std::vector inpDims) + : EltwiseRefBaseExecutor(opData, outBlkDims, inpDims) {} - void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override { + void exec(const jit_eltwise_call_args_ptrs& args_ptrs, const VectorDims& dims_out) override { if (this->_opData.algo == Algorithm::EltwiseLog) { const T* src_ptr_f = reinterpret_cast(args_ptrs.src_ptr[0]); T* dst_ptr_f = reinterpret_cast(args_ptrs.dst_ptr); @@ -1857,8 +1900,11 @@ class EltwiseRefExecutor : public EltwiseRefBaseExecutor { std::shared_ptr ref_eltwise_injector = nullptr; if (this->_opData.onednnAlgorithm != dnnl::algorithm::undef) { - ref_eltwise_injector = std::make_shared( - static_cast(this->_opData.onednnAlgorithm), this->_opData.alpha, this->_opData.beta, 1.f); + ref_eltwise_injector = + std::make_shared(static_cast(this->_opData.onednnAlgorithm), + this->_opData.alpha, + this->_opData.beta, + 1.f); } parallel_nt(0, [&](const int ithr, const int nthr) { @@ -1873,86 +1919,144 @@ class EltwiseRefExecutor : public EltwiseRefBaseExecutor { this->init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f); switch 
(this->_opData.algo) { - case Algorithm::EltwiseRelu: - case Algorithm::EltwiseGeluErf: - case Algorithm::EltwiseGeluTanh: - case Algorithm::EltwiseElu: - case Algorithm::EltwiseTanh: - case Algorithm::EltwiseSigmoid: - case Algorithm::EltwiseAbs: - case Algorithm::EltwiseSqrt: - case Algorithm::EltwiseSoftRelu: - case Algorithm::EltwiseClamp: - case Algorithm::EltwiseSwish: - case Algorithm::EltwiseHswish: - case Algorithm::EltwiseMish: - case Algorithm::EltwiseHsigmoid: - case Algorithm::EltwiseRoundHalfToEven: - case Algorithm::EltwiseRoundHalfAwayFromZero: - *dst_ptr_f = ref_eltwise_injector->compute_scalar(src_f[0]); - break; - case Algorithm::EltwiseAdd: *dst_ptr_f = src_f[0] + src_f[1]; break; - case Algorithm::EltwiseMulAdd: *dst_ptr_f = src_f[0] * src_f[1] + src_f[2]; break; - case Algorithm::EltwiseSubtract: *dst_ptr_f = src_f[0] - src_f[1]; break; - case Algorithm::EltwiseMultiply: *dst_ptr_f = src_f[0] * src_f[1]; break; - case Algorithm::EltwiseDivide: *dst_ptr_f = src_f[0] / src_f[1]; break; - case Algorithm::EltwiseCeiling: *dst_ptr_f = ceilf(src_f[0]); break; - case Algorithm::EltwiseFloor: *dst_ptr_f = floorf(src_f[0]); break; - case Algorithm::EltwiseFloorMod: *dst_ptr_f = src_f[0] - floorf(src_f[0] / src_f[1]) * src_f[1]; break; - case Algorithm::EltwiseMod: *dst_ptr_f = src_f[0] - truncf(src_f[0] / src_f[1]) * src_f[1]; break; - case Algorithm::EltwiseMaximum: *dst_ptr_f = std::max(src_f[0], src_f[1]); break; - case Algorithm::EltwiseMinimum: *dst_ptr_f = std::min(src_f[0], src_f[1]); break; - case Algorithm::EltwiseExp: *dst_ptr_f = expf(src_f[0]); break; - case Algorithm::EltwiseSquaredDifference: *dst_ptr_f = powf((src_f[0] - src_f[1]), 2.f); break; - case Algorithm::EltwisePowerDynamic: *dst_ptr_f = powf(src_f[0], src_f[1]); break; - case Algorithm::EltwiseEqual: *dst_ptr_f = src_f[0] == src_f[1]; break; - case Algorithm::EltwiseNotEqual: *dst_ptr_f = src_f[0] != src_f[1]; break; - case Algorithm::EltwiseGreater: *dst_ptr_f = src_f[0] > src_f[1]; break; - case Algorithm::EltwiseGreaterEqual: *dst_ptr_f = src_f[0] >= src_f[1]; break; - case Algorithm::EltwiseLess: *dst_ptr_f = src_f[0] < src_f[1]; break; - case Algorithm::EltwiseLessEqual: *dst_ptr_f = src_f[0] <= src_f[1]; break; - case Algorithm::EltwiseLogicalAnd: *dst_ptr_f = src_f[0] && src_f[1]; break; - case Algorithm::EltwiseLogicalOr: *dst_ptr_f = src_f[0] || src_f[1]; break; - case Algorithm::EltwiseLogicalXor: *dst_ptr_f = (src_f[0] || src_f[1]) - (src_f[0] && src_f[1]); break; - case Algorithm::EltwiseLogicalNot: *dst_ptr_f = !src_f[0]; break; - case Algorithm::EltwisePrelu: *dst_ptr_f = src_f[0] > 0 ? src_f[0] : static_cast(src_f[0] * src_f[1]); break; - case Algorithm::EltwiseErf: *dst_ptr_f = std::erf(src_f[0]); break; - case Algorithm::EltwiseSoftSign: *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); break; - // @todo implement proper isinfinite for non-float precisions - case Algorithm::EltwiseIsFinite: *dst_ptr_f = std::isfinite(static_cast(src_f[0])); break; - case Algorithm::EltwiseIsInf: - *dst_ptr_f = (this->_opData.alpha && (src_f[0] == -std::numeric_limits::infinity())) || - (this->_opData.beta && (src_f[0] == std::numeric_limits::infinity())); - break; - case Algorithm::EltwiseIsNaN: *dst_ptr_f = std::isnan(src_f[0]); break; - case Algorithm::EltwiseSelect: *dst_ptr_f = src_f[0] ? 
src_f[1] : src_f[2]; break; - default: OPENVINO_THROW("Unsupported operation type for Eltwise executor"); + case Algorithm::EltwiseRelu: + case Algorithm::EltwiseGeluErf: + case Algorithm::EltwiseGeluTanh: + case Algorithm::EltwiseElu: + case Algorithm::EltwiseTanh: + case Algorithm::EltwiseSigmoid: + case Algorithm::EltwiseAbs: + case Algorithm::EltwiseSqrt: + case Algorithm::EltwiseSoftRelu: + case Algorithm::EltwiseClamp: + case Algorithm::EltwiseSwish: + case Algorithm::EltwiseHswish: + case Algorithm::EltwiseMish: + case Algorithm::EltwiseHsigmoid: + case Algorithm::EltwiseRoundHalfToEven: + case Algorithm::EltwiseRoundHalfAwayFromZero: + *dst_ptr_f = ref_eltwise_injector->compute_scalar(src_f[0]); + break; + case Algorithm::EltwiseAdd: + *dst_ptr_f = src_f[0] + src_f[1]; + break; + case Algorithm::EltwiseMulAdd: + *dst_ptr_f = src_f[0] * src_f[1] + src_f[2]; + break; + case Algorithm::EltwiseSubtract: + *dst_ptr_f = src_f[0] - src_f[1]; + break; + case Algorithm::EltwiseMultiply: + *dst_ptr_f = src_f[0] * src_f[1]; + break; + case Algorithm::EltwiseDivide: + *dst_ptr_f = src_f[0] / src_f[1]; + break; + case Algorithm::EltwiseCeiling: + *dst_ptr_f = ceilf(src_f[0]); + break; + case Algorithm::EltwiseFloor: + *dst_ptr_f = floorf(src_f[0]); + break; + case Algorithm::EltwiseFloorMod: + *dst_ptr_f = src_f[0] - floorf(src_f[0] / src_f[1]) * src_f[1]; + break; + case Algorithm::EltwiseMod: + *dst_ptr_f = src_f[0] - truncf(src_f[0] / src_f[1]) * src_f[1]; + break; + case Algorithm::EltwiseMaximum: + *dst_ptr_f = std::max(src_f[0], src_f[1]); + break; + case Algorithm::EltwiseMinimum: + *dst_ptr_f = std::min(src_f[0], src_f[1]); + break; + case Algorithm::EltwiseExp: + *dst_ptr_f = expf(src_f[0]); + break; + case Algorithm::EltwiseSquaredDifference: + *dst_ptr_f = powf((src_f[0] - src_f[1]), 2.f); + break; + case Algorithm::EltwisePowerDynamic: + *dst_ptr_f = powf(src_f[0], src_f[1]); + break; + case Algorithm::EltwiseEqual: + *dst_ptr_f = src_f[0] == src_f[1]; + break; + case Algorithm::EltwiseNotEqual: + *dst_ptr_f = src_f[0] != src_f[1]; + break; + case Algorithm::EltwiseGreater: + *dst_ptr_f = src_f[0] > src_f[1]; + break; + case Algorithm::EltwiseGreaterEqual: + *dst_ptr_f = src_f[0] >= src_f[1]; + break; + case Algorithm::EltwiseLess: + *dst_ptr_f = src_f[0] < src_f[1]; + break; + case Algorithm::EltwiseLessEqual: + *dst_ptr_f = src_f[0] <= src_f[1]; + break; + case Algorithm::EltwiseLogicalAnd: + *dst_ptr_f = src_f[0] && src_f[1]; + break; + case Algorithm::EltwiseLogicalOr: + *dst_ptr_f = src_f[0] || src_f[1]; + break; + case Algorithm::EltwiseLogicalXor: + *dst_ptr_f = (src_f[0] || src_f[1]) - (src_f[0] && src_f[1]); + break; + case Algorithm::EltwiseLogicalNot: + *dst_ptr_f = !src_f[0]; + break; + case Algorithm::EltwisePrelu: + *dst_ptr_f = src_f[0] > 0 ? src_f[0] : static_cast(src_f[0] * src_f[1]); + break; + case Algorithm::EltwiseErf: + *dst_ptr_f = std::erf(src_f[0]); + break; + case Algorithm::EltwiseSoftSign: + *dst_ptr_f = src_f[0] / (1 + std::fabs(src_f[0])); + break; + // @todo implement proper isinfinite for non-float precisions + case Algorithm::EltwiseIsFinite: + *dst_ptr_f = std::isfinite(static_cast(src_f[0])); + break; + case Algorithm::EltwiseIsInf: + *dst_ptr_f = (this->_opData.alpha && (src_f[0] == -std::numeric_limits::infinity())) || + (this->_opData.beta && (src_f[0] == std::numeric_limits::infinity())); + break; + case Algorithm::EltwiseIsNaN: + *dst_ptr_f = std::isnan(src_f[0]); + break; + case Algorithm::EltwiseSelect: + *dst_ptr_f = src_f[0] ? 
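// [Editor's note] Worth keeping in mind while scanning the reformatted switch above:
// EltwiseFloorMod is floored modulo (result takes the divisor's sign) while EltwiseMod is
// truncated modulo (result takes the dividend's sign). For src = {-7, 3}:
//   floorf(-7.f / 3.f) = -3  ->  -7 - (-3) * 3 =  2   (FloorMod)
//   truncf(-7.f / 3.f) = -2  ->  -7 - (-2) * 3 = -1   (Mod)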
src_f[1] : src_f[2]; + break; + default: + OPENVINO_THROW("Unsupported operation type for Eltwise executor"); } } }); } }; -template::value || - std::is_same::value || - std::is_same::value || - std::is_same::value || - std::is_same::value> - ::type * = nullptr> +template ::value || std::is_same::value || + std::is_same::value || std::is_same::value || + std::is_same::value>::type* = nullptr> class BitwiseRefExecutor : public EltwiseRefBaseExecutor { public: - BitwiseRefExecutor(const EltwiseData& opData, - const VectorDims& outBlkDims, - const std::vector& inpDims) : EltwiseRefBaseExecutor(opData, outBlkDims, inpDims) { - } + BitwiseRefExecutor(const EltwiseData& opData, const VectorDims& outBlkDims, const std::vector& inpDims) + : EltwiseRefBaseExecutor(opData, outBlkDims, inpDims) {} - void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) override { + void exec(const jit_eltwise_call_args_ptrs& args_ptrs, const VectorDims& dims_out) override { std::shared_ptr ref_eltwise_injector = nullptr; if (this->_opData.onednnAlgorithm != dnnl::algorithm::undef) { - ref_eltwise_injector = std::make_shared( - static_cast(this->_opData.onednnAlgorithm), this->_opData.alpha, this->_opData.beta, 1.f); + ref_eltwise_injector = + std::make_shared(static_cast(this->_opData.onednnAlgorithm), + this->_opData.alpha, + this->_opData.beta, + 1.f); } parallel_nt(0, [&](const int ithr, const int nthr) { @@ -1967,81 +2071,79 @@ class BitwiseRefExecutor : public EltwiseRefBaseExecutor { this->init_ptr(args_ptrs, dims_out, counters, iwork, src_f, dst_ptr_f); switch (this->_opData.algo) { - case Algorithm::EltwiseBitwiseAnd: { - *dst_ptr_f = src_f[0] & src_f[1]; - break; - } - case Algorithm::EltwiseBitwiseNot: { - *dst_ptr_f = ~src_f[0]; - break; - } - case Algorithm::EltwiseBitwiseOr: { - *dst_ptr_f = src_f[0] | src_f[1]; - break; - } - case Algorithm::EltwiseBitwiseXor: { - *dst_ptr_f = src_f[0] ^ src_f[1]; - break; - } - case Algorithm::EltwiseBitwiseLeftShift: { - *dst_ptr_f = src_f[0] << src_f[1]; - break; - } - case Algorithm::EltwiseBitwiseRightShift: { - *dst_ptr_f = src_f[0] >> src_f[1]; - break; - } - default: - OPENVINO_THROW("Unsupported operation type for Eltwise executor"); + case Algorithm::EltwiseBitwiseAnd: { + *dst_ptr_f = src_f[0] & src_f[1]; + break; + } + case Algorithm::EltwiseBitwiseNot: { + *dst_ptr_f = ~src_f[0]; + break; + } + case Algorithm::EltwiseBitwiseOr: { + *dst_ptr_f = src_f[0] | src_f[1]; + break; + } + case Algorithm::EltwiseBitwiseXor: { + *dst_ptr_f = src_f[0] ^ src_f[1]; + break; + } + case Algorithm::EltwiseBitwiseLeftShift: { + *dst_ptr_f = src_f[0] << src_f[1]; + break; + } + case Algorithm::EltwiseBitwiseRightShift: { + *dst_ptr_f = src_f[0] >> src_f[1]; + break; + } + default: + OPENVINO_THROW("Unsupported operation type for Eltwise executor"); } } }); } }; -} // namespace +} // namespace static Eltwise::executorPtr buildRefExecutor(const EltwiseKey& key) { switch (key.outPrc) { - case ov::element::f16: - return std::make_shared>(key.eltwise_data.front(), - key.outBlkDims, - key.inpDims); - case ov::element::i8: - return std::make_shared::value_type>>( - key.eltwise_data.front(), - key.outBlkDims, - key.inpDims); - - case ov::element::u8: - return std::make_shared::value_type>>( - key.eltwise_data.front(), - key.outBlkDims, - key.inpDims); - - case ov::element::i16: - return std::make_shared::value_type>>( - key.eltwise_data.front(), - key.outBlkDims, - key.inpDims); - - case ov::element::u16: - return std::make_shared::value_type>>( - 
key.eltwise_data.front(),
-                key.outBlkDims,
-                key.inpDims);
+    case ov::element::f16:
+        return std::make_shared>(key.eltwise_data.front(),
+                                  key.outBlkDims,
+                                  key.inpDims);
+    case ov::element::i8:
+        return std::make_shared::value_type>>(
+            key.eltwise_data.front(),
+            key.outBlkDims,
+            key.inpDims);
+
+    case ov::element::u8:
+        return std::make_shared::value_type>>(
+            key.eltwise_data.front(),
+            key.outBlkDims,
+            key.inpDims);
+
+    case ov::element::i16:
+        return std::make_shared::value_type>>(
+            key.eltwise_data.front(),
+            key.outBlkDims,
+            key.inpDims);
+
+    case ov::element::u16:
+        return std::make_shared::value_type>>(
+            key.eltwise_data.front(),
+            key.outBlkDims,
+            key.inpDims);
-        case ov::element::i32:
-            return std::make_shared::value_type>>(
-                key.eltwise_data.front(),
-                key.outBlkDims,
-                key.inpDims);
+    case ov::element::i32:
+        return std::make_shared::value_type>>(
+            key.eltwise_data.front(),
+            key.outBlkDims,
+            key.inpDims);
-        default:
-            // use float reference executor for any other precision for now
-            return std::make_shared>(key.eltwise_data.front(),
-                                      key.outBlkDims,
-                                      key.inpDims);
+    default:
+        // use float reference executor for any other precision for now
+        return std::make_shared>(key.eltwise_data.front(), key.outBlkDims, key.inpDims);
     }
 }

@@ -2064,7 +2166,7 @@ static Eltwise::executorPtr buildExecutor(const EltwiseKey& key) {
 bool Eltwise::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept {
     try {
         if (getInitializers().find(op->get_type_info()) == getInitializers().end()) {
-            errorMessage = "Doesn't support Eltwise algorithm: " + std::string(op->get_type_name());
+            errorMessage = "Doesn't support Eltwise algorithm: " + std::string(op->get_type_name());
             return false;
         }
         if (const auto binOp = ov::as_type_ptr(op)) {
@@ -2087,8 +2189,9 @@ bool Eltwise::isSupportedOperation(const std::shared_ptr& op, st
     return true;
 }
 
-Eltwise::Eltwise(const std::shared_ptr& op, const GraphContext::CPtr context) :
-    Node(op, context, EltwiseShapeInferFactory()), broadcastingPolicy(Undefined) {
+Eltwise::Eltwise(const std::shared_ptr& op, const GraphContext::CPtr context)
+    : Node(op, context, EltwiseShapeInferFactory()),
+      broadcastingPolicy(Undefined) {
     std::string errorMessage;
     if (!isSupportedOperation(op, errorMessage)) {
         OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage);
@@ -2098,67 +2201,68 @@ Eltwise::Eltwise(const std::shared_ptr& op, const GraphContext::CPtr c
 size_t Eltwise::getOpInputsNum() const {
     switch (getAlgorithm()) {
-        case Algorithm::EltwiseIsFinite:
-        case Algorithm::EltwiseIsInf:
-        case Algorithm::EltwiseIsNaN:
-        case Algorithm::EltwiseRelu:
-        case Algorithm::EltwiseGeluErf:
-        case Algorithm::EltwiseGeluTanh:
-        case Algorithm::EltwiseCeiling:
-        case Algorithm::EltwiseFloor:
-        case Algorithm::EltwiseElu:
-        case Algorithm::EltwiseTanh:
-        case Algorithm::EltwiseSigmoid:
-        case Algorithm::EltwiseAbs:
-        case Algorithm::EltwiseSqrt:
-        case Algorithm::EltwiseSoftRelu:
-        case Algorithm::EltwiseExp:
-        case Algorithm::EltwiseClamp:
-        case Algorithm::EltwiseErf:
-        case Algorithm::EltwiseLogicalNot:
-        case Algorithm::EltwisePowerStatic:
-        case Algorithm::EltwiseSwish:
-        case Algorithm::EltwiseHswish:
-        case Algorithm::EltwiseMish:
-        case Algorithm::EltwiseHsigmoid:
-        case Algorithm::EltwiseRoundHalfToEven:
-        case Algorithm::EltwiseRoundHalfAwayFromZero:
-        case Algorithm::EltwiseSoftSign:
-        case Algorithm::EltwiseLog:
-            return 1;
-        case Algorithm::EltwiseAdd:
-        case Algorithm::EltwiseSubtract:
-        case Algorithm::EltwiseMultiply:
-        case Algorithm::EltwiseDivide:
- case Algorithm::EltwiseFloorMod: - case Algorithm::EltwiseMod: - case Algorithm::EltwiseMaximum: - case Algorithm::EltwiseMinimum: - case Algorithm::EltwiseSquaredDifference: - case Algorithm::EltwisePowerDynamic: - case Algorithm::EltwiseEqual: - case Algorithm::EltwiseNotEqual: - case Algorithm::EltwiseGreater: - case Algorithm::EltwiseGreaterEqual: - case Algorithm::EltwiseLess: - case Algorithm::EltwiseLessEqual: - case Algorithm::EltwiseLogicalAnd: - case Algorithm::EltwiseLogicalOr: - case Algorithm::EltwiseLogicalXor: - case Algorithm::EltwiseBitwiseAnd: - case Algorithm::EltwiseBitwiseOr: - case Algorithm::EltwiseBitwiseXor: - case Algorithm::EltwiseBitwiseLeftShift: - case Algorithm::EltwiseBitwiseRightShift: - return 2; - case Algorithm::EltwiseBitwiseNot: - return 1; - case Algorithm::EltwisePrelu: - return 2; - case Algorithm::EltwiseMulAdd: - case Algorithm::EltwiseSelect: - return 3; - default: OPENVINO_THROW("Unsupported operation for Eltwise node with name `", getName(), "`."); + case Algorithm::EltwiseIsFinite: + case Algorithm::EltwiseIsInf: + case Algorithm::EltwiseIsNaN: + case Algorithm::EltwiseRelu: + case Algorithm::EltwiseGeluErf: + case Algorithm::EltwiseGeluTanh: + case Algorithm::EltwiseCeiling: + case Algorithm::EltwiseFloor: + case Algorithm::EltwiseElu: + case Algorithm::EltwiseTanh: + case Algorithm::EltwiseSigmoid: + case Algorithm::EltwiseAbs: + case Algorithm::EltwiseSqrt: + case Algorithm::EltwiseSoftRelu: + case Algorithm::EltwiseExp: + case Algorithm::EltwiseClamp: + case Algorithm::EltwiseErf: + case Algorithm::EltwiseLogicalNot: + case Algorithm::EltwisePowerStatic: + case Algorithm::EltwiseSwish: + case Algorithm::EltwiseHswish: + case Algorithm::EltwiseMish: + case Algorithm::EltwiseHsigmoid: + case Algorithm::EltwiseRoundHalfToEven: + case Algorithm::EltwiseRoundHalfAwayFromZero: + case Algorithm::EltwiseSoftSign: + case Algorithm::EltwiseLog: + return 1; + case Algorithm::EltwiseAdd: + case Algorithm::EltwiseSubtract: + case Algorithm::EltwiseMultiply: + case Algorithm::EltwiseDivide: + case Algorithm::EltwiseFloorMod: + case Algorithm::EltwiseMod: + case Algorithm::EltwiseMaximum: + case Algorithm::EltwiseMinimum: + case Algorithm::EltwiseSquaredDifference: + case Algorithm::EltwisePowerDynamic: + case Algorithm::EltwiseEqual: + case Algorithm::EltwiseNotEqual: + case Algorithm::EltwiseGreater: + case Algorithm::EltwiseGreaterEqual: + case Algorithm::EltwiseLess: + case Algorithm::EltwiseLessEqual: + case Algorithm::EltwiseLogicalAnd: + case Algorithm::EltwiseLogicalOr: + case Algorithm::EltwiseLogicalXor: + case Algorithm::EltwiseBitwiseAnd: + case Algorithm::EltwiseBitwiseOr: + case Algorithm::EltwiseBitwiseXor: + case Algorithm::EltwiseBitwiseLeftShift: + case Algorithm::EltwiseBitwiseRightShift: + return 2; + case Algorithm::EltwiseBitwiseNot: + return 1; + case Algorithm::EltwisePrelu: + return 2; + case Algorithm::EltwiseMulAdd: + case Algorithm::EltwiseSelect: + return 3; + default: + OPENVINO_THROW("Unsupported operation for Eltwise node with name `", getName(), "`."); } } @@ -2183,40 +2287,37 @@ void Eltwise::getSupportedDescriptors() { void Eltwise::initSupportedPrimitiveDescriptors() { const auto isBitwise = [](const Algorithm& algorithm) { - return one_of( - algorithm, - Algorithm::EltwiseBitwiseAnd, - Algorithm::EltwiseBitwiseNot, - Algorithm::EltwiseBitwiseOr, - Algorithm::EltwiseBitwiseXor, - Algorithm::EltwiseBitwiseLeftShift, - Algorithm::EltwiseBitwiseRightShift); + return one_of(algorithm, + Algorithm::EltwiseBitwiseAnd, + 
Algorithm::EltwiseBitwiseNot, + Algorithm::EltwiseBitwiseOr, + Algorithm::EltwiseBitwiseXor, + Algorithm::EltwiseBitwiseLeftShift, + Algorithm::EltwiseBitwiseRightShift); }; - std::vector supportedPrecisions = isBitwise(algorithm) ? - std::vector { - ov::element::u8, - ov::element::i8, - ov::element::u16, - ov::element::i16, - ov::element::i32 - } : std::vector { - ov::element::f32, - ov::element::u8, - ov::element::i8, - ov::element::u16, - ov::element::i16, - ov::element::bf16, - ov::element::f16, - ov::element::i32 - }; + std::vector supportedPrecisions = isBitwise(algorithm) + ? std::vector{ov::element::u8, + ov::element::i8, + ov::element::u16, + ov::element::i16, + ov::element::i32} + : std::vector{ov::element::f32, + ov::element::u8, + ov::element::i8, + ov::element::u16, + ov::element::i16, + ov::element::bf16, + ov::element::f16, + ov::element::i32}; if (!supportedPrimitiveDescriptors.empty()) return; - // if dim rank is greater than the maximum possible, we should use the reference execution -#if defined (OPENVINO_ARCH_ARM64) - bool canUseOptimizedImpl = mayiuse(dnnl::impl::cpu::aarch64::asimd) && (getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK); + // if dim rank is greater than the maximum possible, we should use the reference execution +#if defined(OPENVINO_ARCH_ARM64) + bool canUseOptimizedImpl = + mayiuse(dnnl::impl::cpu::aarch64::asimd) && (getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK); bool canUseOptimizedShapeAgnosticImpl = isDynamicNode() && canUseOptimizedImpl; #else bool canUseOptimizedImpl = mayiuse(x64::sse41) && getInputShapeAtPort(0).getRank() <= MAX_ELTWISE_DIM_RANK; @@ -2261,7 +2362,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() { ")"); std::vector inputPrecisions; - for (const auto &prec : getOriginalInputPrecisions()) { + for (const auto& prec : getOriginalInputPrecisions()) { inputPrecisions.push_back(prec); } @@ -2288,31 +2389,32 @@ void Eltwise::initSupportedPrimitiveDescriptors() { } #ifndef OPENVINO_ARCH_ARM64 - implType = canUseOptimizedShapeAgnosticImpl ? EltwiseImplType::optimizedShapeAgnostic : - canUseOptimizedImpl ? EltwiseImplType::optimized : EltwiseImplType::reference; + implType = canUseOptimizedShapeAgnosticImpl ? EltwiseImplType::optimizedShapeAgnostic + : canUseOptimizedImpl ? EltwiseImplType::optimized + : EltwiseImplType::reference; if (!hasHardwareSupport(ov::element::bf16)) { bool hasBF16 = false; - for (auto &inPrc : inputPrecisions) + for (auto& inPrc : inputPrecisions) if (inPrc == ov::element::bf16) hasBF16 = true; if (outputPrecision == ov::element::bf16 || hasBF16) OPENVINO_THROW("Eltwise node with name `", getName(), "` doesn't support BF16 precision on this target."); } -#if defined(OV_CPU_WITH_ACL) +# if defined(OV_CPU_WITH_ACL) const bool useJit = false; -#endif +# endif #elif defined(OPENVINO_ARCH_ARM64) - const bool useJit = canUseOptimizedImpl && - jitIsSupported(this, getAlpha(), getBeta(), getGamma()); + const bool useJit = canUseOptimizedImpl && jitIsSupported(this, getAlpha(), getBeta(), getGamma()); if (!useJit) { canUseOptimizedImpl = false; } - implType = (useJit && canUseOptimizedImpl) ? - (canUseOptimizedShapeAgnosticImpl ? EltwiseImplType::optimizedShapeAgnostic : EltwiseImplType::optimized) : - EltwiseImplType::reference; + implType = + (useJit && canUseOptimizedImpl) + ? (canUseOptimizedShapeAgnosticImpl ? 
EltwiseImplType::optimizedShapeAgnostic : EltwiseImplType::optimized) + : EltwiseImplType::reference; #else OPENVINO_THROW("Unknow CPU architecture"); #endif @@ -2330,66 +2432,74 @@ void Eltwise::initSupportedPrimitiveDescriptors() { const bool useAcl = !useJit; if (useAcl) { - // Use original output precision as a reference point since some eltwise algorithms have non-float inputs (i.e. EltwiseSelect) - ov::element::Type forcedPrec = getOriginalOutputPrecisionAtPort(0) == ov::element::f16 ? ov::element::f16 : ov::element::f32; - // ACL implementation supports only identical precisions on inputs/outputs so they are aligned it to highest one - if (AclEltwiseExecutor::isEltwiseAlgorithmSupported(getAlgorithm())) { - for (size_t i = 0; i < getParentEdges().size(); i++) { - if (!getParentEdgeAt(i)->getParent()->isConstant()) { - if (getOriginalInputPrecisionAtPort(i).size() > forcedPrec.size()) { - forcedPrec = getOriginalInputPrecisionAtPort(i); + // Use original output precision as a reference point since some eltwise algorithms have non-float inputs (i.e. + // EltwiseSelect) + ov::element::Type forcedPrec = + getOriginalOutputPrecisionAtPort(0) == ov::element::f16 ? ov::element::f16 : ov::element::f32; + // ACL implementation supports only identical precisions on inputs/outputs so they are aligned it to highest one + if (AclEltwiseExecutor::isEltwiseAlgorithmSupported(getAlgorithm())) { + for (size_t i = 0; i < getParentEdges().size(); i++) { + if (!getParentEdgeAt(i)->getParent()->isConstant()) { + if (getOriginalInputPrecisionAtPort(i).size() > forcedPrec.size()) { + forcedPrec = getOriginalInputPrecisionAtPort(i); + } } } + if (!forcedPrec.is_real()) { + forcedPrec = ov::element::f32; + } } - if (!forcedPrec.is_real()) { - forcedPrec = ov::element::f32; - } - } - for (size_t i = 0; i < inputPrecisions.size(); i++) { - inputPrecisions[i] = filterPrecision(inputPrecisions[i], forcedPrec); - } - outputPrecision = filterPrecision(outputPrecision, forcedPrec); - } else { -#endif -#if defined(OV_CPU_WITH_SHL) - if (ShlEltwiseExecutor::isEltwiseAlgorithmSupported(getAlgorithm())) { - // SHL implementation supports only identical precisions on inputs/outputs and only FP32 for now - const ov::element::Type forcedPrec = ov::element::f32; for (size_t i = 0; i < inputPrecisions.size(); i++) { - inputPrecisions[i] = forcedPrec; + inputPrecisions[i] = filterPrecision(inputPrecisions[i], forcedPrec); } - outputPrecision = forcedPrec; + outputPrecision = filterPrecision(outputPrecision, forcedPrec); } else { #endif - auto filterPrecision = [&](const ov::element::Type& prc) { - if (implType == EltwiseImplType::reference) { - if (isBitwise(algorithm)) { - if (std::find(supportedPrecisions.begin(), supportedPrecisions.end(), prc) == supportedPrecisions.end()) { - OPENVINO_THROW("Eltwise node with name `", getName(), "` doesn't support ", prc, " precision."); - } - return prc; - } - return ov::element::f32; - } else if (std::find(supportedPrecisions.begin(), supportedPrecisions.end(), prc) == supportedPrecisions.end()) { - if (prc == ov::element::u32 || prc == ov::element::i64 || prc == ov::element::u64) { - return ov::element::i32; - } else if (prc == ov::element::f64) { - return ov::element::f32; - } else { - OPENVINO_THROW("Eltwise node with name `", getName(), "` doesn't support ", prc, " precision."); +#if defined(OV_CPU_WITH_SHL) + if (ShlEltwiseExecutor::isEltwiseAlgorithmSupported(getAlgorithm())) { + // SHL implementation supports only identical precisions on inputs/outputs and only FP32 for now + 
const ov::element::Type forcedPrec = ov::element::f32; + for (size_t i = 0; i < inputPrecisions.size(); i++) { + inputPrecisions[i] = forcedPrec; } + outputPrecision = forcedPrec; } else { - return prc; - } - }; +#endif + auto filterPrecision = [&](const ov::element::Type& prc) { + if (implType == EltwiseImplType::reference) { + if (isBitwise(algorithm)) { + if (std::find(supportedPrecisions.begin(), supportedPrecisions.end(), prc) == + supportedPrecisions.end()) { + OPENVINO_THROW("Eltwise node with name `", + getName(), + "` doesn't support ", + prc, + " precision."); + } + return prc; + } + return ov::element::f32; + } else if (std::find(supportedPrecisions.begin(), supportedPrecisions.end(), prc) == + supportedPrecisions.end()) { + if (prc == ov::element::u32 || prc == ov::element::i64 || prc == ov::element::u64) { + return ov::element::i32; + } else if (prc == ov::element::f64) { + return ov::element::f32; + } else { + OPENVINO_THROW("Eltwise node with name `", getName(), "` doesn't support ", prc, " precision."); + } + } else { + return prc; + } + }; - for (size_t i = 0; i < inputPrecisions.size(); i++) { - inputPrecisions[i] = filterPrecision(inputPrecisions[i]); - } - outputPrecision = filterPrecision(outputPrecision); + for (size_t i = 0; i < inputPrecisions.size(); i++) { + inputPrecisions[i] = filterPrecision(inputPrecisions[i]); + } + outputPrecision = filterPrecision(outputPrecision); #if defined(OV_CPU_WITH_SHL) - } + } #endif #if defined(OV_CPU_WITH_ACL) } @@ -2398,22 +2508,19 @@ void Eltwise::initSupportedPrimitiveDescriptors() { // TODO: delete after new LPT (ngraph based) is merged // WA is needed to handle bug in LPT that produces wrong precision after average pooling (I8/U8 instead of FP32) if ((getAlgorithm() == Algorithm::EltwiseMulAdd || getAlgorithm() == Algorithm::EltwisePowerStatic) && - (inputPrecisions[0] == ov::element::u8 || inputPrecisions[0] == ov::element::i8)) { + (inputPrecisions[0] == ov::element::u8 || inputPrecisions[0] == ov::element::i8)) { auto parentNode = getParentEdgeAt(0)->getParent(); if (getParentEdgeAt(0)->getParent()->getAlgorithm() == Algorithm::PoolingAvg) { inputPrecisions[0] = ov::element::f32; } } - enum LayoutType { - Planar, - ChannelsFirst, - Blocked - }; + enum LayoutType { Planar, ChannelsFirst, Blocked }; - auto initDesc = [&] (LayoutType lt, const bool useEltwiseExecutor = false, const bool useJit = false) -> NodeDesc { - auto createMemoryDesc = [lt](const Shape &shape, ov::element::Type prc, size_t offset) -> std::shared_ptr { - const auto &dims = shape.getDims(); + auto initDesc = [&](LayoutType lt, const bool useEltwiseExecutor = false, const bool useJit = false) -> NodeDesc { + auto createMemoryDesc = + [lt](const Shape& shape, ov::element::Type prc, size_t offset) -> std::shared_ptr { + const auto& dims = shape.getDims(); if (lt == ChannelsFirst && shape.getRank() != 1) { auto ndims = shape.getRank(); VectorDims order(ndims); @@ -2429,10 +2536,11 @@ void Eltwise::initSupportedPrimitiveDescriptors() { } return std::make_shared(prc, shape, blocks, order, offset); - // TODO: need investigate - // bad accuracy for shape {1, 1, 4, 11}, {2, 5, 1, 1} - // same for disabled collapse dims - } else if (lt == Blocked && shape.getRank() != 1 && (shape.getMinDims()[1] != Shape::UNDEFINED_DIM && shape.getMinDims()[1] > 1)) { + // TODO: need investigate + // bad accuracy for shape {1, 1, 4, 11}, {2, 5, 1, 1} + // same for disabled collapse dims + } else if (lt == Blocked && shape.getRank() != 1 && + (shape.getMinDims()[1] != 
Shape::UNDEFINED_DIM && shape.getMinDims()[1] > 1)) { size_t blockSize = dnnl::impl::cpu::x64::mayiuse(x64::avx512_core) ? 16 : 8; VectorDims blocks = dims; VectorDims order(blocks.size()); @@ -2463,9 +2571,9 @@ void Eltwise::initSupportedPrimitiveDescriptors() { portConfig.inPlace((!i && canBeInPlace() && inputPrecisions[i] == outputPrecision) ? 0 : -1); portConfig.constant(false); - const auto &srcShape = getInputShapeAtPort(i); + const auto& srcShape = getInputShapeAtPort(i); if (!isDynamicNode() && srcShape.getDims()[0] == 1) { - inputMask.reset(0); // accepts any stride on the batch axis + inputMask.reset(0); // accepts any stride on the batch axis } portConfig.setMemDesc(createMemoryDesc(srcShape, inputPrecisions[i], offset), inputMask); @@ -2476,10 +2584,10 @@ void Eltwise::initSupportedPrimitiveDescriptors() { portConfig.inPlace(-1); portConfig.constant(false); - const auto &dstShape = getOutputShapeAtPort(0); + const auto& dstShape = getOutputShapeAtPort(0); BlockedMemoryDesc::CmpMask outputMask = BlockedMemoryDesc::SKIP_OFFSET_MASK; if (!isDynamicNode() && dstShape.getDims()[0] == 1) { - outputMask.reset(0); // accepts any stride on the batch axis + outputMask.reset(0); // accepts any stride on the batch axis } portConfig.setMemDesc(createMemoryDesc(dstShape, outputPrecision, offset), outputMask); @@ -2487,13 +2595,13 @@ void Eltwise::initSupportedPrimitiveDescriptors() { if (useEltwiseExecutor || useJit) { impl_desc_type impl_type; - #if defined (OPENVINO_ARCH_ARM64) +#if defined(OPENVINO_ARCH_ARM64) if (useJit) { impl_type = impl_desc_type::jit_asimd; } - #else +#else impl_type = impl_desc_type::undef; - #endif +#endif std::vector srcMemoryDescs; for (size_t i = 0; i < config.inConfs.size(); i++) { @@ -2504,20 +2612,23 @@ void Eltwise::initSupportedPrimitiveDescriptors() { dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()); } - auto factory = std::make_shared(eltwiseAttrs, srcMemoryDescs, dstMemoryDescs, - std::make_shared(context, getImplPriority())); + auto factory = + std::make_shared(eltwiseAttrs, + srcMemoryDescs, + dstMemoryDescs, + std::make_shared(context, getImplPriority())); return {config, impl_type, !factory->isEmpty() ? 
factory : nullptr}; } else { impl_desc_type impl_type = impl_desc_type::ref; if (canUseOptimizedImpl) { - #if defined (OPENVINO_ARCH_ARM64) +#if defined(OPENVINO_ARCH_ARM64) if (mayiuse(dnnl::impl::cpu::aarch64::asimd)) { impl_type = impl_desc_type::jit_asimd; } else { OPENVINO_THROW("not supported architecture"); } - #else +#else if (mayiuse(x64::avx512_core)) { impl_type = impl_desc_type::jit_avx512; } else if (mayiuse(x64::avx2)) { @@ -2525,7 +2636,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() { } else if (mayiuse(x64::sse41)) { impl_type = impl_desc_type::jit_sse42; } - #endif +#endif } return {config, impl_type}; @@ -2534,10 +2645,11 @@ void Eltwise::initSupportedPrimitiveDescriptors() { bool isChannelsFirstApplicable = one_of(getOutputShapeAtPort(0).getRank(), 1u, 2u, 3u, 4u, 5u); for (size_t i = 0; i < getParentEdges().size(); i++) { - isChannelsFirstApplicable = isChannelsFirstApplicable && one_of(getInputShapeAtPort(i).getRank(), 1u, 2u, 3u, 4u, 5u); - isChannelsFirstApplicable = isChannelsFirstApplicable && implication(getInputShapeAtPort(i).getRank() != 1, - getOutputShapeAtPort(0).getRank() == - getInputShapeAtPort(i).getRank()); + isChannelsFirstApplicable = + isChannelsFirstApplicable && one_of(getInputShapeAtPort(i).getRank(), 1u, 2u, 3u, 4u, 5u); + isChannelsFirstApplicable = isChannelsFirstApplicable && + implication(getInputShapeAtPort(i).getRank() != 1, + getOutputShapeAtPort(0).getRank() == getInputShapeAtPort(i).getRank()); } #if defined(OPENVINO_ARCH_ARM64) @@ -2547,13 +2659,14 @@ void Eltwise::initSupportedPrimitiveDescriptors() { #endif for (size_t i = 0; i < getParentEdges().size(); i++) { - const auto &inShape = getInputShapeAtPort(i); + const auto& inShape = getInputShapeAtPort(i); isBlockedApplicable = isBlockedApplicable && one_of(inShape.getRank(), 1u, 3u, 4u, 5u); - isBlockedApplicable = isBlockedApplicable && implication(inShape.getRank() != 1, - getOutputShapeAtPort(0).getRank() == - inShape.getRank()); + isBlockedApplicable = + isBlockedApplicable && + implication(inShape.getRank() != 1, getOutputShapeAtPort(0).getRank() == inShape.getRank()); if (isDynamicNode() && inShape.getRank() != 1) - isBlockedApplicable = isBlockedApplicable && inShape.getMinDims()[1] != Shape::UNDEFINED_DIM && inShape.getMinDims()[1] > 1; + isBlockedApplicable = + isBlockedApplicable && inShape.getMinDims()[1] != Shape::UNDEFINED_DIM && inShape.getMinDims()[1] > 1; } inputNum = getParentEdges().size(); @@ -2561,28 +2674,29 @@ void Eltwise::initSupportedPrimitiveDescriptors() { #if defined(OV_CPU_WITH_ACL) if (useAcl || useJit) { - eltwiseAttrs = {algorithm, alpha, beta, gamma}; + eltwiseAttrs = {algorithm, alpha, beta, gamma}; - auto addDesc = [&initDesc, &useJit](std::vector& supportedPrimitiveDescriptors, const LayoutType layoutType) { - auto nodeDesc = initDesc(layoutType, !useJit, useJit); - if (nodeDesc.getExecutorFactory()) - supportedPrimitiveDescriptors.emplace_back(nodeDesc); - }; + auto addDesc = [&initDesc, &useJit](std::vector& supportedPrimitiveDescriptors, + const LayoutType layoutType) { + auto nodeDesc = initDesc(layoutType, !useJit, useJit); + if (nodeDesc.getExecutorFactory()) + supportedPrimitiveDescriptors.emplace_back(nodeDesc); + }; - // @todo should be handled in scope of selectPreferPrimitiveDescriptor - if (context->getConfig().modelType == Config::ModelType::CNN) { - if (isChannelsFirstApplicable) - addDesc(supportedPrimitiveDescriptors, ChannelsFirst); - addDesc(supportedPrimitiveDescriptors, Planar); - } else { - 
addDesc(supportedPrimitiveDescriptors, Planar); - if (isChannelsFirstApplicable) - addDesc(supportedPrimitiveDescriptors, ChannelsFirst); - } + // @todo should be handled in scope of selectPreferPrimitiveDescriptor + if (context->getConfig().modelType == Config::ModelType::CNN) { + if (isChannelsFirstApplicable) + addDesc(supportedPrimitiveDescriptors, ChannelsFirst); + addDesc(supportedPrimitiveDescriptors, Planar); + } else { + addDesc(supportedPrimitiveDescriptors, Planar); + if (isChannelsFirstApplicable) + addDesc(supportedPrimitiveDescriptors, ChannelsFirst); + } - canUseEltwiseExecPtr = !supportedPrimitiveDescriptors.empty() && !useJit; - if (!supportedPrimitiveDescriptors.empty()) - return; + canUseEltwiseExecPtr = !supportedPrimitiveDescriptors.empty() && !useJit; + if (!supportedPrimitiveDescriptors.empty()) + return; } #endif @@ -2652,15 +2766,18 @@ void Eltwise::prepareParams() { dstMemoryDescs.push_back(getDstMemoryAtPort(0)->getDescPtr()); auto selectedPD = getSelectedPrimitiveDescriptor(); - eltwiseExecPtr = selectedPD->getExecutorFactoryAs()->makeExecutor(eltwiseAttrs, srcMemoryDescs, dstMemoryDescs, {}); + eltwiseExecPtr = selectedPD->getExecutorFactoryAs()->makeExecutor(eltwiseAttrs, + srcMemoryDescs, + dstMemoryDescs, + {}); selectedPD->setImplementationType(eltwiseExecPtr->getImplType()); return; } auto outBlockingDesc = getChildEdgeAt(0)->getMemory().getDescWithType(); - const auto &outOrder = outBlockingDesc->getOrder(); - const auto ¤tOutBlkDims = outBlockingDesc->getBlockDims(); + const auto& outOrder = outBlockingDesc->getOrder(); + const auto& currentOutBlkDims = outBlockingDesc->getBlockDims(); size_t input_size = std::max(static_cast(EltwiseJitExecutor::optimalTensorRank), currentOutBlkDims.size()); @@ -2679,13 +2796,16 @@ void Eltwise::prepareParams() { size_t inRank = currentInBlkDims[i].size(); // WA to normalize blocked and planar layouts - const auto &inOrder = inBlockingDesc->getOrder(); + const auto& inOrder = inBlockingDesc->getOrder(); size_t startOff = outOrder.size() != outBlockingDesc->getShape().getRank() && - outOrder[outOrder.size() - 1] != inOrder[inOrder.size() - 1] ? 1 : 0; + outOrder[outOrder.size() - 1] != inOrder[inOrder.size() - 1] + ? 
1 + : 0; // WA to handle nspc layout with 1D tensors if (1 == inRank) { - if (outRank > 2 && 1 == outOrder.back()) startOff = 1; + if (outRank > 2 && 1 == outOrder.back()) + startOff = 1; } for (size_t j = 0; j < inRank; j++) { @@ -2718,14 +2838,18 @@ void Eltwise::prepareParams() { if (!canSkipSearchInCache) { EltwiseData thisOp{getAlgorithm(), getOneDnnAlgorithm(), getAlpha(), getBeta(), getGamma()}; - EltwiseKey key = {{thisOp}, {getType()}, currentOutBlkDims, outOrder, dims_in, inpPrc, outPrc, dnnl::post_ops(), implType}; + EltwiseKey key = + {{thisOp}, {getType()}, currentOutBlkDims, outOrder, dims_in, inpPrc, outPrc, dnnl::post_ops(), implType}; fqDataPtrs.clear(); - for (const auto &node : fusedWith) { + for (const auto& node : fusedWith) { key.ops_list.push_back(node->getType()); if (node->getType() == Type::Eltwise) { if (auto eltwise = std::dynamic_pointer_cast(node)) { - key.eltwise_data.push_back({eltwise->getAlgorithm(), eltwise->getOneDnnAlgorithm(), eltwise->getAlpha(), - eltwise->getBeta(), eltwise->getGamma()}); + key.eltwise_data.push_back({eltwise->getAlgorithm(), + eltwise->getOneDnnAlgorithm(), + eltwise->getAlpha(), + eltwise->getBeta(), + eltwise->getGamma()}); } } else if (node->getType() == Type::FakeQuantize) { node->appendPostOps(key.postOps, {}, fqDataPtrs); @@ -2745,9 +2869,9 @@ void Eltwise::prepareParams() { // update execParams for shape agnostic kernel if (implType == EltwiseImplType::optimizedShapeAgnostic) { - auto &outDims = execParams.outDims; - auto &inOffsets = execParams.inOffsets; - auto &outOffsets = execParams.outOffsets; + auto& outDims = execParams.outDims; + auto& inOffsets = execParams.inOffsets; + auto& outOffsets = execParams.outOffsets; // outDims recalculation outDims.resize(dims_in[0].size(), 1); @@ -2805,7 +2929,8 @@ void Eltwise::selectOptimalPrimitiveDescriptor() { void Eltwise::execute(dnnl::stream strm) { if (execPtr) { jit_eltwise_call_args_ptrs args_ptrs = {}; - VectorDims dims_out = implType == EltwiseImplType::optimizedShapeAgnostic ? execParams.outDims : execPtr->getOutDims(); + VectorDims dims_out = + implType == EltwiseImplType::optimizedShapeAgnostic ? 
execParams.outDims : execPtr->getOutDims(); for (size_t i = 0; i < memPtrs.size() - 1; i++) args_ptrs.src_ptr[i] = memPtrs[i]->getDataAs() + start_offset_in[i]; args_ptrs.dst_ptr = memPtrs.back()->getDataAs() + start_offset_out; @@ -2873,15 +2998,14 @@ void Eltwise::fuseInto(NodePtr& parentNode) { getAlgorithm() == Algorithm::EltwiseAdd && dimsEqualWeak(getInputShapeAtPort(0).getDims(), getInputShapeAtPort(1).getDims()) && !getParentEdgeAt(0)->getParent()->isConstant() && !getParentEdgeAt(1)->getParent()->isConstant(); - if ((scales.empty() && shifts.empty()) && - !specialConvolutionAddFusing && + if ((scales.empty() && shifts.empty()) && !specialConvolutionAddFusing && canBePerformedAsScaleShift(parentNode.get())) { std::tie(scales, shifts) = getScalesAndShifts(parentNode.get()); } Node::fuseInto(parentNode); } -void Eltwise::appendMemory(const std::vector &data, MemoryPtr &memPtr, std::vector& postOpsMem) { +void Eltwise::appendMemory(const std::vector& data, MemoryPtr& memPtr, std::vector& postOpsMem) { if (!memPtr) { DnnlBlockedMemoryDesc memoryDesc(ov::element::f32, {data.size()}); memPtr = std::make_shared(getEngine(), memoryDesc, data.data()); @@ -2889,12 +3013,15 @@ void Eltwise::appendMemory(const std::vector &data, MemoryPtr &memPtr, st } } -void Eltwise::appendMemory(const std::vector &data, MemoryPtr &memPtr, std::vector& postOpsMem) { +void Eltwise::appendMemory(const std::vector& data, MemoryPtr& memPtr, std::vector& postOpsMem) { postOpsMem.push_back(data.data()); } template -void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem, const int channelAxis) { +void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis) { const std::string errorPrefix = "Appending Eltwise node with name '" + getName() + "' "; if (getOneDnnAlgorithm() != dnnl::algorithm::undef) { @@ -2920,7 +3047,8 @@ void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDim case dnnl::algorithm::eltwise_round_half_away_from_zero: ops.append_eltwise(getOneDnnAlgorithm(), getAlpha(), getBeta()); break; - default: OPENVINO_THROW(errorPrefix, "as post operation is not supported"); + default: + OPENVINO_THROW(errorPrefix, "as post operation is not supported"); } } else { // per-tensor EltwisePowerStatic can be implemented with more well-supported eltwise postOps @@ -2938,7 +3066,8 @@ void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDim const auto chIdx = postOpDims.size() > 1 ? 
channelAxis : 0; channelSize = postOpDims[chIdx]; } - // since legacy depthwise post ops mechanism requires broadcasted data we need to reinitilize it in case of changed shape + // since legacy depthwise post ops mechanism requires broadcasted data we need to reinitilize it in case of + // changed shape if (depthwiseData.empty() || depthwiseDataSize != 2 * channelSize) { depthwiseData.clear(); depthwiseMemory.reset(); @@ -2995,7 +3124,10 @@ void Eltwise::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDim } } -void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::unordered_map& postOpsMem, const int channelAxis) { +void Eltwise::appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::unordered_map& postOpsMem, + const int channelAxis) { std::vector postOpsMemPtrs; appendPostOpsImpl(ops, postOpDims, postOpsMemPtrs, channelAxis); @@ -3006,11 +3138,17 @@ void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, s } } -void Eltwise::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem, const int channelAxis) { +void Eltwise::appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis) { appendPostOpsImpl(ops, postOpDims, postOpsMem, channelAxis); } -bool Eltwise::appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, bool isLastPostOp, dnnl::memory::data_type outDataType, bool allowBinary) { +bool Eltwise::appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, + bool isLastPostOp, + dnnl::memory::data_type outDataType, + bool allowBinary) { const std::string errorPrefix = "Appending Eltwise node with name '" + getName() + "' as binary post op "; if (getOneDnnAlgorithm() != dnnl::algorithm::undef) { @@ -3039,7 +3177,8 @@ bool Eltwise::appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, bool isLastP // call dnnlpoc's specialized API to generate optimized postOps sequence dnnlpoc.appendLinear({getAlpha()}, {getBeta()}, isLastPostOp); break; - default: OPENVINO_THROW(errorPrefix, "as post operation is not supported"); + default: + OPENVINO_THROW(errorPrefix, "as post operation is not supported"); } } else { switch (getAlgorithm()) { @@ -3054,9 +3193,9 @@ bool Eltwise::appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, bool isLastP case Algorithm::EltwisePowerStatic: if (beta != 1.0f && gamma != 0.0f) { return dnnlpoc.appendLinear(scales, shifts, isLastPostOp, allowBinary); - } else if (beta != 1.0f) {// Multiply if has scales + } else if (beta != 1.0f) { // Multiply if has scales return dnnlpoc.appendScale(scales, isLastPostOp, allowBinary); - } else if (gamma != 0.0f) {// Add only if has shifts + } else if (gamma != 0.0f) { // Add only if has shifts return dnnlpoc.appendShift(shifts, allowBinary); } break; @@ -3103,16 +3242,17 @@ bool Eltwise::canFuseParent(const NodePtr& parentNode) const { bool Eltwise::canFuse(const NodePtr& node) const { auto isIntegerComputeSupported = [](const Node* node) { - if (!one_of(node->getAlgorithm(), Algorithm::EltwiseAdd, - Algorithm::EltwiseMultiply, - Algorithm::EltwiseMulAdd, - Algorithm::EltwiseSubtract, - Algorithm::EltwiseDivide, - Algorithm::EltwiseSquaredDifference)) { + if (!one_of(node->getAlgorithm(), + Algorithm::EltwiseAdd, + Algorithm::EltwiseMultiply, + Algorithm::EltwiseMulAdd, + Algorithm::EltwiseSubtract, + Algorithm::EltwiseDivide, + Algorithm::EltwiseSquaredDifference)) { return false; } - for (const auto &originalInputPrecision : node->getOriginalInputPrecisions()) 
{ + for (const auto& originalInputPrecision : node->getOriginalInputPrecisions()) { if (originalInputPrecision != ov::element::i32) { return false; } @@ -3121,7 +3261,7 @@ bool Eltwise::canFuse(const NodePtr& node) const { return true; }; -#if defined (OPENVINO_ARCH_ARM64) +#if defined(OPENVINO_ARCH_ARM64) if (!mayiuse(dnnl::impl::cpu::aarch64::asimd) || (getInputShapeAtPort(0).getRank() > MAX_ELTWISE_DIM_RANK)) return false; @@ -3129,10 +3269,8 @@ bool Eltwise::canFuse(const NodePtr& node) const { return false; } const auto eltwise = dynamic_cast(node.get()); - if ((eltwise == nullptr) || (!jitIsSupported(eltwise, - eltwise->getAlpha(), - eltwise->getBeta(), - eltwise->getGamma()))) { + if ((eltwise == nullptr) || + (!jitIsSupported(eltwise, eltwise->getAlpha(), eltwise->getBeta(), eltwise->getGamma()))) { return false; } #else @@ -3170,29 +3308,30 @@ bool Eltwise::canFuse(const NodePtr& node) const { return false; if (node->getType() == Type::Eltwise) { - // [WA] Since execution precision change from I32 to FP32 for arithmetic operations may lead to incorrect results - // we disable fusing cases which may lead to invalid precision conversions inside the kernel - // [TODO] We need to rewrite support for different precisions at all to avoid implicit conversions to FP32 - // (all should be handled via explicit convert operations) + // [WA] Since execution precision change from I32 to FP32 for arithmetic operations may lead to incorrect + // results we disable fusing cases which may lead to invalid precision conversions inside the kernel [TODO] We + // need to rewrite support for different precisions at all to avoid implicit conversions to FP32 (all should be + // handled via explicit convert operations) bool isIntegerFusingNode = isIntegerComputeSupported(node.get()); - if ((isIntegerNode && !isIntegerFusingNode) || - (!isIntegerNode && isIntegerFusingNode)) { + if ((isIntegerNode && !isIntegerFusingNode) || (!isIntegerNode && isIntegerFusingNode)) { return false; } if (node->getParentEdgeAt(0)->getParent().get() != this) { - // Eltwise jitter doesn't respect commutative property, so fusing is disabled in case it applied not for 0-th port. - if (one_of(node->getAlgorithm(), Algorithm::EltwiseSubtract, - Algorithm::EltwiseDivide, - Algorithm::EltwiseFloorMod, - Algorithm::EltwiseMod, - Algorithm::EltwisePowerDynamic, - Algorithm::EltwiseGreater, - Algorithm::EltwiseGreaterEqual, - Algorithm::EltwiseLess, - Algorithm::EltwiseLessEqual, - Algorithm::EltwiseMulAdd, - Algorithm::EltwiseSelect)) { + // Eltwise jitter doesn't respect commutative property, so fusing is disabled in case it applied not for + // 0-th port. 
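The port-0 restriction introduced here is easiest to see in isolation: the JIT kernel substitutes the fused producer only as argument 0, so an operation whose result depends on argument order cannot be fused through any other port. A toy model of the check, with an invented Alg enum and helper names (none of this is plugin API):

#include <iostream>

// Illustrative subset of the algorithms named in the one_of() list below.
enum class Alg { Add, Multiply, Subtract, Divide, Select };

// Ops whose result changes when the arguments swap (or whose ports have
// distinct roles, like Select) may only feed the kernel through port 0.
bool isCommutativeForFusing(Alg a) {
    switch (a) {
    case Alg::Subtract:
    case Alg::Divide:
    case Alg::Select:
        return false;
    default:
        return true;
    }
}

bool canFuseOnPort(Alg a, int portOfThisNode) {
    return portOfThisNode == 0 || isCommutativeForFusing(a);
}

int main() {
    std::cout << canFuseOnPort(Alg::Add, 1) << "\n";       // 1: a + b == b + a
    std::cout << canFuseOnPort(Alg::Subtract, 1) << "\n";  // 0: a - b != b - a
}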
+ if (one_of(node->getAlgorithm(), + Algorithm::EltwiseSubtract, + Algorithm::EltwiseDivide, + Algorithm::EltwiseFloorMod, + Algorithm::EltwiseMod, + Algorithm::EltwisePowerDynamic, + Algorithm::EltwiseGreater, + Algorithm::EltwiseGreaterEqual, + Algorithm::EltwiseLess, + Algorithm::EltwiseLessEqual, + Algorithm::EltwiseMulAdd, + Algorithm::EltwiseSelect)) { return false; } @@ -3205,7 +3344,8 @@ bool Eltwise::canFuse(const NodePtr& node) const { } } - // We can use optimized execution with fusions only in cases when dim rank is less or equal to the maximum possible + // We can use optimized execution with fusions only in cases when dim rank is less or equal to the maximum + // possible if (node->getInputShapeAtPort(0).getRank() > MAX_ELTWISE_DIM_RANK) return false; @@ -3224,13 +3364,15 @@ ov::element::Type Eltwise::getRuntimePrecision() const { // Don't take bias precision into account for (size_t i = 0; i < getParentEdges().size(); i++) { auto parentEdge = getParentEdgeAt(i); - if (parentEdge && parentEdge->getStatus() == Edge::Status::Validated && !parentEdge->getParent()->isConstant()) { - inputPrecisions.emplace_back(DnnlExtensionUtils::DataTypeToElementType((parentEdge->getMemoryPtr()->getDataType()))); + if (parentEdge && parentEdge->getStatus() == Edge::Status::Validated && + !parentEdge->getParent()->isConstant()) { + inputPrecisions.emplace_back( + DnnlExtensionUtils::DataTypeToElementType((parentEdge->getMemoryPtr()->getDataType()))); } } return getMaxPrecision(inputPrecisions); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.h b/src/plugins/intel_cpu/src/nodes/eltwise.h index 6013ce732ee5fc..d0ca94e08824c8 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.h +++ b/src/plugins/intel_cpu/src/nodes/eltwise.h @@ -5,17 +5,18 @@ #pragma once #include + +#include #include #include -#include #include "dnnl_postops_composer_legacy.h" -#include "nodes/executors/eltwise.hpp" #include "executors/eltwise_list.hpp" +#include "nodes/executors/eltwise.hpp" #include "nodes/kernels/jit_eltwise_call_args_ptrs.hpp" #if defined(OPENVINO_ARCH_ARM64) -#include "kernels/aarch64/jit_uni_eltwise_generic.hpp" +# include "kernels/aarch64/jit_uni_eltwise_generic.hpp" #endif namespace ov { @@ -68,18 +69,14 @@ struct jit_uni_eltwise_kernel { #endif -enum class EltwiseImplType { - reference = 0, - optimized = 1, - optimizedShapeAgnostic = 2 -}; +enum class EltwiseImplType { reference = 0, optimized = 1, optimizedShapeAgnostic = 2 }; class Eltwise : public Node { public: class IEltwiseExecutor { public: IEltwiseExecutor() = default; - virtual void exec(const jit_eltwise_call_args_ptrs &args_ptrs, const VectorDims &dims_out) = 0; + virtual void exec(const jit_eltwise_call_args_ptrs& args_ptrs, const VectorDims& dims_out) = 0; virtual size_t getBatchDimIdx() const = 0; virtual const VectorDims& getOutDims() const = 0; virtual ~IEltwiseExecutor() = default; @@ -98,22 +95,45 @@ class Eltwise : public Node { bool canBeInPlace() const override; bool canFuseParent(const NodePtr& parentNode) const; bool canFuse(const NodePtr& node) const override; - void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::unordered_map& postOpsMem, const int channelAxis = 1) override; - void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem, const int channelAxis = 1) override; - bool appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, 
bool isLastPostOp, dnnl::memory::data_type outDataType, bool allowBinary = true); + void appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::unordered_map& postOpsMem, + const int channelAxis = 1) override; + void appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis = 1) override; + bool appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, + bool isLastPostOp, + dnnl::memory::data_type outDataType, + bool allowBinary = true); void fuseInto(NodePtr& parentNode) override; ov::element::Type getRuntimePrecision() const override; - float getAlpha() const { return alpha; } - float getBeta() const { return beta; } - float getGamma() const { return gamma; } - const std::vector& getScales() const { return scales; } - const std::vector& getShifts() const { return shifts; } + float getAlpha() const { + return alpha; + } + float getBeta() const { + return beta; + } + float getGamma() const { + return gamma; + } + const std::vector& getScales() const { + return scales; + } + const std::vector& getShifts() const { + return shifts; + } - dnnl::algorithm getOneDnnAlgorithm() const { return onednnAlgorithm; } + dnnl::algorithm getOneDnnAlgorithm() const { + return onednnAlgorithm; + } bool isWithBroadcast(); - bool isSpecialConvolutionAddFusing() const { return specialConvolutionAddFusing; } + bool isSpecialConvolutionAddFusing() const { + return specialConvolutionAddFusing; + } bool needPrepareParams() const override; void prepareParams() override; @@ -127,7 +147,9 @@ class Eltwise : public Node { Undefined, }; - BroadcastingPolicy getBroadcastingPolicy() const { return broadcastingPolicy; } + BroadcastingPolicy getBroadcastingPolicy() const { + return broadcastingPolicy; + } static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; @@ -181,10 +203,13 @@ class Eltwise : public Node { size_t getOpInputsNum() const; template - void appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem, const int channelAxis = 1); + void appendPostOpsImpl(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis = 1); - void appendMemory(const std::vector &data, MemoryPtr &memPtr, std::vector& postOpsMem); - void appendMemory(const std::vector &data, MemoryPtr &memPtr, std::vector& postOpsMem); + void appendMemory(const std::vector& data, MemoryPtr& memPtr, std::vector& postOpsMem); + void appendMemory(const std::vector& data, MemoryPtr& memPtr, std::vector& postOpsMem); bool canUseEltwiseExecPtr = false; EltwiseAttrs eltwiseAttrs; @@ -201,6 +226,6 @@ class eltwise_precision_helper { static std::set> get_supported_precisions(const Algorithm& algo); }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_bag.cpp b/src/plugins/intel_cpu/src/nodes/embedding_bag.cpp index 8b144e90c865bc..2dcb93f9fc6c1b 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_bag.cpp +++ b/src/plugins/intel_cpu/src/nodes/embedding_bag.cpp @@ -18,10 +18,10 @@ namespace intel_cpu { namespace node { EmbeddingBag::EmbeddingBag(const std::shared_ptr& op, - size_t requiredInputNum, - size_t indicesIdx, - size_t perSampleWeightsIdx, - size_t defaultIndexIdx) + size_t requiredInputNum, + size_t indicesIdx, + size_t perSampleWeightsIdx, + size_t defaultIndexIdx) : INDICES_IDX(indicesIdx), 
PER_SAMPLE_WEIGHTS_IDX(perSampleWeightsIdx), DEFAULT_INDEX_IDX(defaultIndexIdx) { @@ -47,9 +47,9 @@ void EmbeddingBag::prepareParams(const VectorDims& indexStaticShape) { template void EmbeddingBag::processData(const T* srcData, - const T* weightsData, - const VectorDims& inDataDims, - const MemoryPtr& outMemory) { + const T* weightsData, + const VectorDims& inDataDims, + const MemoryPtr& outMemory) { std::string msgPrefix = std::string("Node EmbeddingBag with name '") + _layerName + "' "; initFromInputs(); @@ -127,10 +127,10 @@ void EmbeddingBag::processData(const T* srcData, } void EmbeddingBag::execute(const uint8_t* srcData, - const uint8_t* weightsData, - const ov::element::Type& srcPrc, - const VectorDims& inDims, - const MemoryPtr& outMemory) { + const uint8_t* weightsData, + const ov::element::Type& srcPrc, + const VectorDims& inDims, + const MemoryPtr& outMemory) { switch (srcPrc) { case ov::element::f32: { return processData::value_type>( @@ -157,8 +157,7 @@ void EmbeddingBag::execute(const uint8_t* srcData, outMemory); } default: { - OPENVINO_THROW("EmbeddingBag layer does not support precision '" + std::string(srcPrc.get_type_name()) + - "'"); + OPENVINO_THROW("EmbeddingBag layer does not support precision '" + std::string(srcPrc.get_type_name()) + "'"); } } } diff --git a/src/plugins/intel_cpu/src/nodes/embedding_bag.h b/src/plugins/intel_cpu/src/nodes/embedding_bag.h index 28c8666233fa1a..d804ea06c2b317 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_bag.h +++ b/src/plugins/intel_cpu/src/nodes/embedding_bag.h @@ -13,32 +13,32 @@ namespace node { class EmbeddingBag { public: enum class Reduction { SUM, MEAN }; - EmbeddingBag( - const std::shared_ptr&, - size_t requiredInputsNum, - size_t indicesIdx, - size_t perSampleWeightsIdx, - size_t defaultIndexIdx); - - void execute(const uint8_t* srcData, const uint8_t* weightsData, const ov::element::Type &srcPrc, - const VectorDims& inDims, const MemoryPtr& outMemory); + EmbeddingBag(const std::shared_ptr&, + size_t requiredInputsNum, + size_t indicesIdx, + size_t perSampleWeightsIdx, + size_t defaultIndexIdx); + + void execute(const uint8_t* srcData, + const uint8_t* weightsData, + const ov::element::Type& srcPrc, + const VectorDims& inDims, + const MemoryPtr& outMemory); ~EmbeddingBag() = default; protected: virtual void initFromInputs() = 0; - virtual void getIndices( - size_t embIndex, - const int*& indicesRef, - size_t& size, - int& weightsIdx, - bool& withWeights) = 0; + virtual void getIndices(size_t embIndex, + const int*& indicesRef, + size_t& size, + int& weightsIdx, + bool& withWeights) = 0; void prepareParams(const VectorDims& indexStaticShape); - template - void processData(const T* srcData, const T* weightsData, - const VectorDims& inDataDims, const MemoryPtr& outMemory); + template + void processData(const T* srcData, const T* weightsData, const VectorDims& inDataDims, const MemoryPtr& outMemory); const size_t EMB_TABLE_IDX = 0lu; const size_t INDICES_IDX; @@ -51,6 +51,6 @@ class EmbeddingBag { std::string _layerName; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.cpp b/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.cpp index b5fbaee982808d..8da557a823a948 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.cpp +++ b/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.cpp @@ -2,24 +2,27 @@ // SPDX-License-Identifier: Apache-2.0 // 
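Before the per-variant diffs below, it is worth restating what EmbeddingBag::processData computes: for every output bag, gather the indexed rows of the embedding table, optionally scale each row by its per-sample weight, accumulate, and divide by the bag size for MEAN reduction. A simplified, self-contained sketch under those assumptions (row-major float table, int indices; default-index handling, empty-bag rules, and the precision dispatch the real node performs are omitted):

#include <cstddef>
#include <iostream>
#include <vector>

enum class Reduction { SUM, MEAN };

// Per-bag gather-and-reduce: table is row-major [rows x width]; bags lists
// the table rows contributing to each output bag; weights, when present,
// supply one scale factor per gathered index, in gathering order.
std::vector<float> embeddingBag(const std::vector<float>& table, size_t width,
                                const std::vector<std::vector<int>>& bags,
                                const std::vector<float>* perSampleWeights,  // nullptr => unweighted
                                Reduction reduction) {
    std::vector<float> out(bags.size() * width, 0.f);
    size_t w = 0;  // running index into perSampleWeights
    for (size_t bag = 0; bag < bags.size(); ++bag) {
        for (int row : bags[bag]) {
            const float scale = perSampleWeights ? (*perSampleWeights)[w++] : 1.f;
            for (size_t c = 0; c < width; ++c)
                out[bag * width + c] += scale * table[static_cast<size_t>(row) * width + c];
        }
        if (reduction == Reduction::MEAN && !bags[bag].empty())
            for (size_t c = 0; c < width; ++c)
                out[bag * width + c] /= static_cast<float>(bags[bag].size());
    }
    return out;
}

int main() {
    std::vector<float> table = {1, 1, 2, 2, 3, 3};  // three rows of width 2
    auto out = embeddingBag(table, 2, {{0, 2}, {1}}, nullptr, Reduction::MEAN);
    std::cout << out[0] << " " << out[2] << "\n";  // 2 (mean of rows 0 and 2), 2 (row 1)
}

The Offsets, Packed, and SegmentsSum subclasses differ only in how they produce the per-bag index lists through their getIndices() overrides.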
+#include "embedding_bag_offsets.h" + #include -#include #include -#include "embedding_bag_offsets.h" -#include "openvino/op/embeddingbag_offsets_sum.hpp" -#include "openvino/op/embeddingbag_offsets.hpp" +#include +#include "openvino/op/embeddingbag_offsets.hpp" +#include "openvino/op/embeddingbag_offsets_sum.hpp" namespace ov { namespace intel_cpu { namespace node { -bool EmbeddingBagOffset::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool EmbeddingBagOffset::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto embBagOffsetSumOp = ov::as_type_ptr(op); const auto embBagOffsetOp = ov::as_type_ptr(op); if (!embBagOffsetSumOp && !embBagOffsetOp) { - errorMessage = "Node is not an instance of the v3::EmbeddingBagOffsetsSum or v15::EmbeddingBagOffsets operation."; + errorMessage = + "Node is not an instance of the v3::EmbeddingBagOffsetsSum or v15::EmbeddingBagOffsets operation."; return false; } } catch (...) { @@ -46,7 +49,8 @@ EmbeddingBagOffset::EmbeddingBagOffset(const std::shared_ptr& op, cons _reduction = Reduction::MEAN; break; default: - THROW_CPU_NODE_ERR("EmbeddingBagOffsets does not support reduction mode: ", ov::as_string(offsets_op->get_reduction())); + THROW_CPU_NODE_ERR("EmbeddingBagOffsets does not support reduction mode: ", + ov::as_string(offsets_op->get_reduction())); } } if (getInputShapeAtPort(INDICES_IDX).getRank() != 1ul) @@ -61,8 +65,10 @@ void EmbeddingBagOffset::initSupportedPrimitiveDescriptors() { return; std::string logPrefix = std::string("Layer EmbeddingBag with name '") + _layerName + "' "; - static const std::set supportedPrecisions = - {ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32}; + static const std::set supportedPrecisions = {ov::element::f32, + ov::element::i8, + ov::element::u8, + ov::element::i32}; auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX); if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16)) @@ -71,8 +77,10 @@ void EmbeddingBagOffset::initSupportedPrimitiveDescriptors() { if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end()) OPENVINO_THROW(logPrefix, "has unsupported precision: ", inDataPrecision.get_type_name()); } else { - static const std::set defaultSupportedPrecisions = - {ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32}; + static const std::set defaultSupportedPrecisions = {ov::element::f32, + ov::element::i8, + ov::element::u8, + ov::element::i32}; if (defaultSupportedPrecisions.find(inDataPrecision) == defaultSupportedPrecisions.end()) OPENVINO_THROW(logPrefix, "has unsupported precision: ", inDataPrecision.get_type_name()); } @@ -103,7 +111,11 @@ void EmbeddingBagOffset::initFromInputs() { } } -void EmbeddingBagOffset::getIndices(size_t embIndex, const int*& indices, size_t& size, int& weightsIdx, bool& withWeight) { +void EmbeddingBagOffset::getIndices(size_t embIndex, + const int*& indices, + size_t& size, + int& weightsIdx, + bool& withWeight) { if (static_cast(embIndex) >= _offsetsLen) { OPENVINO_THROW("Invalid embedding bag index."); } @@ -145,20 +157,23 @@ bool EmbeddingBagOffset::isExecutable() const { } void EmbeddingBagOffset::execute(dnnl::stream strm) { - const auto *srcData = getSrcDataAtPortAs(0); + const auto* srcData = getSrcDataAtPortAs(0); const uint8_t* weightsData = nullptr; if (_withWeights) weightsData = getSrcDataAtPortAs(PER_SAMPLE_WEIGHTS_IDX); - const auto &inputMem = getParentEdgeAt(0)->getMemory(); - 
EmbeddingBag::execute(srcData, weightsData, inputMem.getDesc().getPrecision(), - inputMem.getStaticDims(), getDstMemoryAtPort(0)); + const auto& inputMem = getParentEdgeAt(0)->getMemory(); + EmbeddingBag::execute(srcData, + weightsData, + inputMem.getDesc().getPrecision(), + inputMem.getStaticDims(), + getDstMemoryAtPort(0)); } bool EmbeddingBagOffset::created() const { return getType() == Type::EmbeddingBagOffsets; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.h b/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.h index a31b518e7891a9..f8a28152a26642 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.h +++ b/src/plugins/intel_cpu/src/nodes/embedding_bag_offsets.h @@ -15,7 +15,7 @@ class EmbeddingBagOffset : public Node, public EmbeddingBag { public: EmbeddingBagOffset(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -41,6 +41,6 @@ class EmbeddingBagOffset : public Node, public EmbeddingBag { size_t _offsetsLen = 0; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.cpp b/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.cpp index fd2e0b6141f1fc..c1a06835a67af3 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.cpp +++ b/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.cpp @@ -2,23 +2,27 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "embedding_bag_packed.h" + #include -#include #include -#include "embedding_bag_packed.h" -#include "openvino/op/embeddingbag_packedsum.hpp" +#include + #include "openvino/op/embeddingbag_packed.hpp" +#include "openvino/op/embeddingbag_packedsum.hpp" namespace ov { namespace intel_cpu { namespace node { -bool EmbeddingBagPacked::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool EmbeddingBagPacked::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto embBagPackedSumOp = ov::as_type_ptr(op); const auto embBagPackedOp = ov::as_type_ptr(op); if (!embBagPackedSumOp && !embBagPackedOp) { - errorMessage = "Node is not an instance of the v3::EmbeddingBagPackedSum or v15::EmbeddingBagPacked operations."; + errorMessage = + "Node is not an instance of the v3::EmbeddingBagPackedSum or v15::EmbeddingBagPacked operations."; return false; } } catch (...) 
{ @@ -45,7 +49,8 @@ EmbeddingBagPacked::EmbeddingBagPacked(const std::shared_ptr& op, cons _reduction = Reduction::MEAN; break; default: - THROW_CPU_NODE_ERR("EmbeddingBagPacked does not support reduction mode: ", ov::as_string(packed_op->get_reduction())); + THROW_CPU_NODE_ERR("EmbeddingBagPacked does not support reduction mode: ", + ov::as_string(packed_op->get_reduction())); } } if (getInputShapeAtPort(INDICES_IDX).getRank() != 2ul) @@ -57,8 +62,10 @@ void EmbeddingBagPacked::initSupportedPrimitiveDescriptors() { return; std::string logPrefix = std::string("Layer EmbeddingBag with name '") + _layerName + "' "; - static const std::set supportedPrecisions = - {ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32}; + static const std::set supportedPrecisions = {ov::element::f32, + ov::element::i8, + ov::element::u8, + ov::element::i32}; auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX); if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16)) @@ -67,14 +74,16 @@ void EmbeddingBagPacked::initSupportedPrimitiveDescriptors() { if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end()) OPENVINO_THROW(logPrefix, "has unsupported precision: ", inDataPrecision.get_type_name()); } else { - static const std::set defaultSupportedPrecisions = - {ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32}; + static const std::set defaultSupportedPrecisions = {ov::element::f32, + ov::element::i8, + ov::element::u8, + ov::element::i32}; if (defaultSupportedPrecisions.find(inDataPrecision) == defaultSupportedPrecisions.end()) OPENVINO_THROW(logPrefix, "has unsupported precision: ", inDataPrecision.get_type_name()); } - std::vector inDataConfigurators({{LayoutType::ncsp, inDataPrecision}, - {LayoutType::ncsp, ov::element::i32}}); + std::vector inDataConfigurators( + {{LayoutType::ncsp, inDataPrecision}, {LayoutType::ncsp, ov::element::i32}}); if (inputShapes.size() > PER_SAMPLE_WEIGHTS_IDX) inDataConfigurators.push_back({LayoutType::ncsp, inDataPrecision}); @@ -91,7 +100,11 @@ void EmbeddingBagPacked::initFromInputs() { _indices = getSrcDataAtPortAs(INDICES_IDX); } -void EmbeddingBagPacked::getIndices(size_t embIndex, const int*& indices, size_t& size, int& weightsIdx, bool& withWeight) { +void EmbeddingBagPacked::getIndices(size_t embIndex, + const int*& indices, + size_t& size, + int& weightsIdx, + bool& withWeight) { if (static_cast(embIndex) >= _batch * _indicesPerBag) OPENVINO_THROW("Invalid embedding bag index."); @@ -112,20 +125,23 @@ bool EmbeddingBagPacked::isExecutable() const { } void EmbeddingBagPacked::execute(dnnl::stream strm) { - const auto *srcData = getSrcDataAtPortAs(0); + const auto* srcData = getSrcDataAtPortAs(0); const uint8_t* weightsData = nullptr; if (_withWeights) weightsData = getSrcDataAtPortAs(PER_SAMPLE_WEIGHTS_IDX); - const auto &inputMem = getParentEdgeAt(0)->getMemory(); - EmbeddingBag::execute(srcData, weightsData, inputMem.getDesc().getPrecision(), - inputMem.getStaticDims(), getDstMemoryAtPort(0)); + const auto& inputMem = getParentEdgeAt(0)->getMemory(); + EmbeddingBag::execute(srcData, + weightsData, + inputMem.getDesc().getPrecision(), + inputMem.getStaticDims(), + getDstMemoryAtPort(0)); } bool EmbeddingBagPacked::created() const { return getType() == Type::EmbeddingBagPacked; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.h 
b/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.h index 6a9d33fe3afccb..a018d1b48929e1 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.h +++ b/src/plugins/intel_cpu/src/nodes/embedding_bag_packed.h @@ -15,7 +15,7 @@ class EmbeddingBagPacked : public Node, public EmbeddingBag { public: EmbeddingBagPacked(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -36,6 +36,6 @@ class EmbeddingBagPacked : public Node, public EmbeddingBag { size_t _indicesPerBag = 0; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.cpp b/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.cpp index 2a012c6b941831..8bd91799834bad 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.cpp +++ b/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.cpp @@ -2,17 +2,20 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "embedding_segments_sum.h" + #include -#include #include -#include "embedding_segments_sum.h" +#include + #include "openvino/opsets/opset3.hpp" namespace ov { namespace intel_cpu { namespace node { -bool EmbeddingSegmentsSum::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool EmbeddingSegmentsSum::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto embBagSegSumOp = ov::as_type_ptr(op); if (!embBagSegSumOp) { @@ -46,8 +49,10 @@ void EmbeddingSegmentsSum::initSupportedPrimitiveDescriptors() { return; std::string logPrefix = std::string("Layer EmbeddingBag with name '") + _layerName + "' "; - static const std::set supportedPrecisions = - {ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32}; + static const std::set supportedPrecisions = {ov::element::f32, + ov::element::i8, + ov::element::u8, + ov::element::i32}; auto inDataPrecision = getOriginalInputPrecisionAtPort(EMB_TABLE_IDX); if (one_of(inDataPrecision, ov::element::bf16, ov::element::f16)) @@ -56,8 +61,10 @@ void EmbeddingSegmentsSum::initSupportedPrimitiveDescriptors() { if (supportedPrecisions.find(inDataPrecision) == supportedPrecisions.end()) OPENVINO_THROW(logPrefix, "has unsupported precision: ", inDataPrecision.get_type_name()); } else { - static const std::set defaultSupportedPrecisions = - {ov::element::f32, ov::element::i8, ov::element::u8, ov::element::i32}; + static const std::set defaultSupportedPrecisions = {ov::element::f32, + ov::element::i8, + ov::element::u8, + ov::element::i32}; if (defaultSupportedPrecisions.find(inDataPrecision) == defaultSupportedPrecisions.end()) OPENVINO_THROW(logPrefix, "has unsupported precision: ", inDataPrecision.get_type_name()); } @@ -90,7 +97,11 @@ void EmbeddingSegmentsSum::initFromInputs() { } } -void EmbeddingSegmentsSum::getIndices(size_t embIndex, const int*& indices, size_t& size, int& weightsIdx, bool& withWeight) { +void EmbeddingSegmentsSum::getIndices(size_t embIndex, + const int*& indices, + size_t& size, + int& weightsIdx, + bool& withWeight) { if (embIndex >= static_cast(lastNumSegments_)) OPENVINO_THROW("Invalid embedding bag index."); @@ -143,20 +154,23 @@ bool EmbeddingSegmentsSum::isExecutable() const { } void EmbeddingSegmentsSum::execute(dnnl::stream strm) { 
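For the segments variant, getIndices() effectively inverts the segment_ids array: index i contributes to bag segment_ids[i], and a bag with no hits falls back to the default index when one is supplied. A sketch of that mapping (bagsFromSegments is an invented helper; the node computes the same relation incrementally rather than materializing the lists):

#include <cstddef>
#include <iostream>
#include <vector>

// Group indices by their segment id; empty segments get the default row,
// or stay empty when no default index is provided (-1 here).
std::vector<std::vector<int>> bagsFromSegments(const std::vector<int>& indices,
                                               const std::vector<int>& segmentIds,
                                               int numSegments,
                                               int defaultIndex) {
    std::vector<std::vector<int>> bags(numSegments);
    for (size_t i = 0; i < indices.size(); ++i)
        bags[segmentIds[i]].push_back(indices[i]);
    if (defaultIndex >= 0)
        for (auto& bag : bags)
            if (bag.empty())
                bag.push_back(defaultIndex);
    return bags;
}

int main() {
    // indices {4,1,7} with segment ids {0,0,2} over 3 segments, default row 0:
    auto bags = bagsFromSegments({4, 1, 7}, {0, 0, 2}, 3, 0);
    for (const auto& bag : bags) {
        for (int idx : bag) std::cout << idx << ' ';
        std::cout << '\n';  // prints "4 1", then "0", then "7"
    }
}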
- const auto *srcData = getSrcDataAtPortAs(0); + const auto* srcData = getSrcDataAtPortAs(0); const uint8_t* weightsData = nullptr; if (_withWeights) weightsData = getSrcDataAtPortAs(PER_SAMPLE_WEIGHTS_IDX); - const auto &inputMem = getParentEdgeAt(0)->getMemory(); - EmbeddingBag::execute(srcData, weightsData, inputMem.getDesc().getPrecision(), - inputMem.getStaticDims(), getDstMemoryAtPort(0)); + const auto& inputMem = getParentEdgeAt(0)->getMemory(); + EmbeddingBag::execute(srcData, + weightsData, + inputMem.getDesc().getPrecision(), + inputMem.getStaticDims(), + getDstMemoryAtPort(0)); } bool EmbeddingSegmentsSum::created() const { return getType() == Type::EmbeddingSegmentsSum; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.h b/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.h index bb312b4dd47246..984b9de68690b2 100644 --- a/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.h +++ b/src/plugins/intel_cpu/src/nodes/embedding_segments_sum.h @@ -15,7 +15,7 @@ class EmbeddingSegmentsSum : public Node, public EmbeddingBag { public: EmbeddingSegmentsSum(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -45,6 +45,6 @@ class EmbeddingSegmentsSum : public Node, public EmbeddingBag { size_t indicesSize_ = 0; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.cpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.cpp index 2bba0f5e73c0fe..de65176fb72235 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.cpp @@ -3,6 +3,7 @@ // #include "ref_convert.hpp" + #include "nodes/common/cpu_convert.h" namespace ov { @@ -13,9 +14,9 @@ bool CommonConvertExecutor::isSupported(ov::element::Type srcPrc, ov::element::T } bool CommonConvertExecutor::init(const ConvertParams& convertParams, - const MemoryDescPtr& srcDesc, - const MemoryDescPtr& dstDesc, - const dnnl::primitive_attr& attr) { + const MemoryDescPtr& srcDesc, + const MemoryDescPtr& dstDesc, + const dnnl::primitive_attr& attr) { commonConvertParams = convertParams; return true; } @@ -32,5 +33,5 @@ void CommonConvertExecutor::exec(const std::vector& src, const std:: commonConvertParams.size); } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.hpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.hpp index 337d377f3b3339..4bc3a709d2bcd2 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_convert.hpp @@ -15,9 +15,11 @@ class CommonConvertExecutor : public ConvertExecutor { bool init(const ConvertParams& convertParams, const MemoryDescPtr& srcDesc, const MemoryDescPtr& dstDesc, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst) override; - 
impl_desc_type implType() const override { return implDescType; };
+    impl_desc_type implType() const override {
+        return implDescType;
+    };
     static bool isSupported(ov::element::Type srcPrc, ov::element::Type dstPrc);
 
 protected:
@@ -26,7 +28,6 @@ class CommonConvertExecutor : public ConvertExecutor {
     const ExecutorContext::CPtr convertContext;
 };
 
-
 class CommonConvertExecutorBuilder : public ConvertExecutorBuilder {
 public:
     ~CommonConvertExecutorBuilder() = default;
@@ -40,5 +41,5 @@ class CommonConvertExecutorBuilder : public ConvertExecutorBuilder {
     }
 };
 
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.cpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.cpp
index 0e1d43b48f6224..dd0cea3d238a4e 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.cpp
@@ -3,6 +3,7 @@
 //
 
 #include "ref_opt_transpose.hpp"
+
 #include "openvino/core/parallel.hpp"
 
 namespace ov {
@@ -26,21 +27,15 @@ void transpose_to_0312(const int MB, const MemoryCPtr& srcMemPtr, MemoryPtr& dst
 
     parallel_for3d(MB, DIM1, DIM2, [&](const int n, const int dim1, const int dim2) {
         for (int dim3 = 0; dim3 < DIM3; ++dim3) {
-            const int src_off = n * DIM1 * DIM2 * DIM3 +
-                                dim1 * DIM2 * DIM3 +
-                                dim2 * DIM3 +
-                                dim3;
-            const int dst_off = n * DIM1 * DIM2 * DIM3 +
-                                dim3 * DIM1 * DIM2 +
-                                dim1 * DIM2 +
-                                dim2;
+            const int src_off = n * DIM1 * DIM2 * DIM3 + dim1 * DIM2 * DIM3 + dim2 * DIM3 + dim3;
+            const int dst_off = n * DIM1 * DIM2 * DIM3 + dim3 * DIM1 * DIM2 + dim1 * DIM2 + dim2;
 
             dst_data[dst_off] = src_data[src_off];
         }
     });
 }
 
-template<typename T>
+template <typename T>
 void transpose_to_04123(const int MB, const MemoryCPtr& srcMemPtr, MemoryPtr& dstMemPtr) {
     const auto src_data = srcMemPtr->getDataAs<const T>();
     auto dst_data = dstMemPtr->getDataAs<T>();
@@ -52,23 +47,17 @@ void transpose_to_04123(const int MB, const MemoryCPtr& srcMemPtr, MemoryPtr& ds
 
     parallel_for4d(MB, DIM1, DIM2, DIM3, [&](const int n, const int dim1, const int dim2, const int dim3) {
         for (int dim4 = 0; dim4 < DIM4; ++dim4) {
-            const int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 +
-                                dim1 * DIM2 * DIM3 * DIM4 +
-                                dim2 * DIM3 * DIM4 +
-                                dim3 * DIM4 +
-                                dim4;
-            const int dst_off = n * DIM1 * DIM2 * DIM3 * DIM4 +
-                                dim4 * DIM1 * DIM2 * DIM3 +
-                                dim1 * DIM2 * DIM3 +
-                                dim2 * DIM3 +
-                                dim3;
+            const int src_off =
+                n * DIM1 * DIM2 * DIM3 * DIM4 + dim1 * DIM2 * DIM3 * DIM4 + dim2 * DIM3 * DIM4 + dim3 * DIM4 + dim4;
+            const int dst_off =
+                n * DIM1 * DIM2 * DIM3 * DIM4 + dim4 * DIM1 * DIM2 * DIM3 + dim1 * DIM2 * DIM3 + dim2 * DIM3 + dim3;
 
             dst_data[dst_off] = src_data[src_off];
         }
     });
 }
 
-template<typename T>
+template <typename T>
 void transpose_to_051234(const int MB, const MemoryCPtr& srcMemPtr, MemoryPtr& dstMemPtr) {
     const auto src_data = srcMemPtr->getDataAs<const T>();
     auto dst_data = dstMemPtr->getDataAs<T>();
@@ -79,61 +68,61 @@ void transpose_to_051234(const int MB, const MemoryCPtr& srcMemPtr, MemoryPtr& d
     const int DIM4 = srcMemPtr->getStaticDims()[4];
     const int DIM5 = srcMemPtr->getStaticDims()[5];
 
-    parallel_for5d(MB, DIM1, DIM2, DIM3, DIM4, [&](const int n, const int dim1, const int dim2, const int dim3, const int dim4) {
-        for (int dim5 = 0; dim5 < DIM5; ++dim5) {
-            const int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 * DIM5 +
-                                dim1 * DIM2 * DIM3 * DIM4 * DIM5 +
-                                dim2 * DIM3 * DIM4 * DIM5 +
-                                dim3 * DIM4 * DIM5 +
-                                dim4 * DIM5 +
-                                dim5;
-            const int dst_off = n * DIM5 * DIM1 * DIM2 * DIM3 * DIM4 +
-                                dim5 * DIM1 * DIM2 * DIM3 * DIM4 +
-                                dim1 * DIM2 * DIM3 * DIM4 +
-                                dim2 * DIM3 * DIM4 +
-                                dim3 * DIM4 +
-                                dim4;
-
-            dst_data[dst_off] = src_data[src_off];
-        }
-    });
+    parallel_for5d(MB,
+                   DIM1,
+                   DIM2,
+                   DIM3,
+                   DIM4,
+                   [&](const int n, const int dim1, const int dim2, const int dim3, const int dim4) {
+                       for (int dim5 = 0; dim5 < DIM5; ++dim5) {
+                           const int src_off = n * DIM1 * DIM2 * DIM3 * DIM4 * DIM5 + dim1 * DIM2 * DIM3 * DIM4 * DIM5 +
+                                               dim2 * DIM3 * DIM4 * DIM5 + dim3 * DIM4 * DIM5 + dim4 * DIM5 + dim5;
+                           const int dst_off = n * DIM5 * DIM1 * DIM2 * DIM3 * DIM4 + dim5 * DIM1 * DIM2 * DIM3 * DIM4 +
+                                               dim1 * DIM2 * DIM3 * DIM4 + dim2 * DIM3 * DIM4 + dim3 * DIM4 + dim4;
+
+                           dst_data[dst_off] = src_data[src_off];
+                       }
+                   });
 }
 
-template<typename T>
+template <typename T>
 struct TransposeOptimizedEmitter {
     void operator()(TransposeContext& ctx) {
         switch (ctx.srcMemPtr->getStaticDims().size()) {
-            case 4:
-                transpose_to_0312<T>(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr);
-                break;
-            case 5:
-                transpose_to_04123<T>(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr);
-                break;
-            case 6:
-                transpose_to_051234<T>(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr);
-                break;
-            default:
-                OPENVINO_THROW("Transpose supports optimized execution with only 4D, 5D and 6D shapes");
+        case 4:
+            transpose_to_0312<T>(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr);
+            break;
+        case 5:
+            transpose_to_04123<T>(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr);
+            break;
+        case 6:
+            transpose_to_051234<T>(ctx.MB, ctx.srcMemPtr, ctx.dstMemPtr);
+            break;
+        default:
+            OPENVINO_THROW("Transpose supports optimized execution with only 4D, 5D and 6D shapes");
         }
     }
 };
-} // namespace
+}  // namespace
 
 void RefOptimizedTransposeExecutor::exec(const std::vector<MemoryCPtr>& src, const std::vector<MemoryPtr>& dst) {
     const size_t dataSize = src[0]->getDesc().getPrecision().size();
     const int MB = src[0]->getStaticDims()[0];
     TransposeContext ctx = {src[0], dst[0], MB};
-    OV_SWITCH(intel_cpu, TransposeOptimizedEmitter, ctx, dataSize,
+    OV_SWITCH(intel_cpu,
+              TransposeOptimizedEmitter,
+              ctx,
+              dataSize,
              OV_CASE(1u, element_type_traits<ov::element::u8>::value_type),
              OV_CASE(2u, element_type_traits<ov::element::u16>::value_type),
              OV_CASE(4u, element_type_traits<ov::element::i32>::value_type));
 }
 
-bool RefOptimizedTransposeExecutor::init(const TransposeParams &transposeParams,
-                                         const std::vector<MemoryDescPtr> &srcDescs,
-                                         const std::vector<MemoryDescPtr> &dstDescs,
-                                         const dnnl::primitive_attr &attr) {
+bool RefOptimizedTransposeExecutor::init(const TransposeParams& transposeParams,
+                                         const std::vector<MemoryDescPtr>& srcDescs,
+                                         const std::vector<MemoryDescPtr>& dstDescs,
+                                         const dnnl::primitive_attr& attr) {
     return true;
 }
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.hpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.hpp
index be420bfb009e5a..65da099caa0f33 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_opt_transpose.hpp
@@ -13,12 +13,14 @@ class RefOptimizedTransposeExecutor : public TransposeExecutor {
 public:
     using TransposeExecutor::TransposeExecutor;
 
-    bool init(const TransposeParams &transposeParams,
-              const std::vector<MemoryDescPtr> &srcDescs,
-              const std::vector<MemoryDescPtr> &dstDescs,
-              const dnnl::primitive_attr &attr) override;
-    void exec(const std::vector<MemoryCPtr> &src, const std::vector<MemoryPtr> &dst) override;
-    impl_desc_type implType() const override { return impl_desc_type::ref; }
+    bool init(const TransposeParams& transposeParams,
+              const std::vector<MemoryDescPtr>& srcDescs,
+              const std::vector<MemoryDescPtr>& dstDescs,
+              const dnnl::primitive_attr& attr) override;
+    void exec(const std::vector<MemoryCPtr>& src, const std::vector<MemoryPtr>& dst) override;
+    impl_desc_type implType() const override {
+        return impl_desc_type::ref;
+    }
 };
 
 class RefOptimizedTransposeExecutorBuilder : public TransposeExecutorBuilder {
@@ -27,12 +29,13 @@ class RefOptimizedTransposeExecutorBuilder : public TransposeExecutorBuilder {
                              const std::vector<MemoryDescPtr>& srcDescs,
                              const std::vector<MemoryDescPtr>& dstDescs) const override {
         static const std::vector<std::vector<size_t>> optimizedOrders = {
-                std::vector<size_t>{0, 3, 1, 2},
-                std::vector<size_t>{0, 4, 1, 2, 3},
-                std::vector<size_t>{0, 5, 1, 2, 3, 4},
+            std::vector<size_t>{0, 3, 1, 2},
+            std::vector<size_t>{0, 4, 1, 2, 3},
+            std::vector<size_t>{0, 5, 1, 2, 3, 4},
         };
         if (srcDescs[0]->hasLayoutType(LayoutType::ncsp) &&
-            std::find(optimizedOrders.begin(), optimizedOrders.end(), transposeParams.permuteParams.order) != optimizedOrders.end()) {
+            std::find(optimizedOrders.begin(), optimizedOrders.end(), transposeParams.permuteParams.order) !=
+                optimizedOrders.end()) {
             return true;
         }
         DEBUG_LOG("RefOptimizedTransposeExecutor is not supported, because passed order is not optimized");
@@ -44,5 +47,5 @@ class RefOptimizedTransposeExecutorBuilder : public TransposeExecutorBuilder {
     }
 };
 
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.cpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.cpp
index 8db8798ef8eaff..1716f008027fe9 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.cpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.cpp
@@ -3,8 +3,9 @@
 //
 
 #include "ref_transpose.hpp"
-#include "openvino/core/parallel.hpp"
+
 #include "nodes/common/cpu_memcpy.h"
+#include "openvino/core/parallel.hpp"
 
 namespace ov {
 namespace intel_cpu {
@@ -27,7 +28,10 @@ static inline void parallel_step(size_t nDims, const VectorDims& dims, VectorDim
     }
 }
 
-void RefTransposeExecutor::referenceExecute(const uint8_t* src_data, uint8_t* dst_data, jit_permute_config_params jcp, const int mb) {
+void RefTransposeExecutor::referenceExecute(const uint8_t* src_data,
+                                            uint8_t* dst_data,
+                                            jit_permute_config_params jcp,
+                                            const int mb) {
     VectorDims dst_dims = jcp.dst_block_dims;
     const VectorDims dst_strides = jcp.dst_strides;
     const VectorDims src_strides = jcp.src_strides;
@@ -70,13 +74,13 @@ void RefTransposeExecutor::exec(const std::vector& src, const std::v
     referenceExecute(src_data, dst_data, jcp, MB);
 }
 
-bool RefTransposeExecutor::init(const TransposeParams &transposeParams,
-                                const std::vector<MemoryDescPtr> &srcDescs,
-                                const std::vector<MemoryDescPtr> &dstDescs,
-                                const dnnl::primitive_attr &attr) {
+bool RefTransposeExecutor::init(const TransposeParams& transposeParams,
+                                const std::vector<MemoryDescPtr>& srcDescs,
+                                const std::vector<MemoryDescPtr>& dstDescs,
+                                const dnnl::primitive_attr& attr) {
     jcp = TransposeExecutor::prepareParams(transposeParams.permuteParams);
     return true;
 }
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.hpp b/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.hpp
index 206d610368a9df..00c1602c0bd119 100644
--- a/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.hpp
+++ b/src/plugins/intel_cpu/src/nodes/executors/common/ref_transpose.hpp
@@ -11,13 +11,19 @@ namespace intel_cpu {
 class RefTransposeExecutor : public TransposeExecutor {
 public:
     using TransposeExecutor::TransposeExecutor;
-    static void referenceExecute(const uint8_t* src_data, uint8_t* dst_data, jit_permute_config_params
jcp, const int mb); - bool init(const TransposeParams &transposeParams, - const std::vector &srcDescs, - const std::vector &dstDescs, - const dnnl::primitive_attr &attr) override; - void exec(const std::vector &src, const std::vector &dst) override; - impl_desc_type implType() const override { return impl_desc_type::ref; } + static void referenceExecute(const uint8_t* src_data, + uint8_t* dst_data, + jit_permute_config_params jcp, + const int mb); + bool init(const TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) override; + void exec(const std::vector& src, const std::vector& dst) override; + impl_desc_type implType() const override { + return impl_desc_type::ref; + } + private: jit_permute_config_params jcp; }; @@ -35,5 +41,5 @@ class RefTransposeExecutorBuilder : public TransposeExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/convert.cpp b/src/plugins/intel_cpu/src/nodes/executors/convert.cpp index c8d7ce8addaf22..32141d53b10ee5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convert.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convert.cpp @@ -4,4 +4,5 @@ #include "convert.hpp" -ov::intel_cpu::ConvertExecutor::ConvertExecutor(const ov::intel_cpu::ExecutorContext::CPtr context) : convertContext(context) {} \ No newline at end of file +ov::intel_cpu::ConvertExecutor::ConvertExecutor(const ov::intel_cpu::ExecutorContext::CPtr context) + : convertContext(context) {} \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/convert.hpp b/src/plugins/intel_cpu/src/nodes/executors/convert.hpp index ce766663a0b653..dcb0bdde2ce219 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convert.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convert.hpp @@ -5,8 +5,8 @@ #pragma once #include "cpu_memory.h" -#include "onednn/iml_type_mapper.h" #include "executor.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -24,8 +24,9 @@ class ConvertExecutor : public Executor { virtual bool init(const ConvertParams& convertParams, const MemoryDescPtr& srcDesc, const MemoryDescPtr& dstDesc, - const dnnl::primitive_attr &attr) = 0; + const dnnl::primitive_attr& attr) = 0; virtual ~ConvertExecutor() = default; + protected: ConvertParams convertParams; const ExecutorContext::CPtr convertContext; @@ -45,5 +46,5 @@ class ConvertExecutorBuilder { using ConvertExecutorBuilderPtr = std::shared_ptr; using ConvertExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/convert_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/convert_list.cpp index 504c310ca15124..5375bd21166cc4 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convert_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convert_list.cpp @@ -9,9 +9,8 @@ namespace intel_cpu { const std::vector& getConvertExecutorsList() { static std::vector descs = { - OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared()) - }; + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) + OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared())}; return descs; } @@ -45,5 +44,5 @@ ConvertExecutorPtr ConvertExecutorFactory::makeExecutor(const ConvertParams& con OPENVINO_THROW("Supported 
executor is not found"); } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/convert_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/convert_list.hpp index a7ed05ceb634e4..9ea47f916d859f 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/convert_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/convert_list.hpp @@ -4,17 +4,15 @@ #pragma once -#include "executor.hpp" - #include "convert.hpp" +#include "executor.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_convert.hpp" +# include "acl/acl_convert.hpp" #endif +#include "common/primitive_cache.hpp" #include "common/ref_convert.hpp" - #include "onednn/iml_type_mapper.h" -#include "common/primitive_cache.hpp" namespace ov { namespace intel_cpu { @@ -31,7 +29,8 @@ class ConvertExecutorFactory : public ExecutorFactoryLegacy { ConvertExecutorFactory(const ConvertParams& convertParams, const MemoryDescPtr& srcDesc, const MemoryDescPtr& dstDesc, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getConvertExecutorsList()) { if (desc.builder->isSupported(convertParams, srcDesc, dstDesc)) { supportedDescs.push_back(desc); @@ -43,7 +42,7 @@ class ConvertExecutorFactory : public ExecutorFactoryLegacy { virtual ConvertExecutorPtr makeExecutor(const ConvertParams& convertParams, const MemoryDescPtr& srcDesc, const MemoryDescPtr& dstDesc, - const dnnl::primitive_attr &attr); + const dnnl::primitive_attr& attr); private: std::vector supportedDescs; @@ -53,5 +52,5 @@ class ConvertExecutorFactory : public ExecutorFactoryLegacy { using ConvertExecutorFactoryPtr = std::shared_ptr; using ConvertExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp index 26ae6ace59631b..222779a00ee18f 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/debug_messages.hpp @@ -4,25 +4,25 @@ #pragma once -#define UNSUPPORTED_SPARSE_WEIGHTS " sparse weights are not supported" +#define UNSUPPORTED_SPARSE_WEIGHTS " sparse weights are not supported" #define UNSUPPORTED_WEIGHTS_DECOMPRESSION " weights decompression is not supported" -#define UNSUPPORTED_POST_OPS " post ops are not supported" -#define UNSUPPORTED_NUMBER_OF_POSTOPS " the number of post ops is not supported" -#define UNSUPPORTED_TYPE_OF_POSTOPS " the type of post ops is not supported" -#define UNSUPPORTED_SRC_PRECISIONS " unsupported src precisions" -#define UNSUPPORTED_WEI_PRECISIONS " unsupported wei precisions" -#define UNSUPPORTED_DST_PRECISIONS " unsupported dst precisions" -#define UNSUPPORTED_ISA " unsupported isa" -#define UNSUPPORTED_SRC_RANK " unsupported src rank" -#define UNSUPPORTED_WEI_RANK " unsupported wei rank" -#define UNSUPPORTED_DST_RANK " unsupported dst rank" -#define UNSUPPORTED_DST_STRIDES " unsupported dst strides" -#define HEURISTICS_MISMATCH " heuristics mismatch" +#define UNSUPPORTED_POST_OPS " post ops are not supported" +#define UNSUPPORTED_NUMBER_OF_POSTOPS " the number of post ops is not supported" +#define UNSUPPORTED_TYPE_OF_POSTOPS " the type of post ops is not supported" +#define UNSUPPORTED_SRC_PRECISIONS " unsupported src 
precisions" +#define UNSUPPORTED_WEI_PRECISIONS " unsupported wei precisions" +#define UNSUPPORTED_DST_PRECISIONS " unsupported dst precisions" +#define UNSUPPORTED_ISA " unsupported isa" +#define UNSUPPORTED_SRC_RANK " unsupported src rank" +#define UNSUPPORTED_WEI_RANK " unsupported wei rank" +#define UNSUPPORTED_DST_RANK " unsupported dst rank" +#define UNSUPPORTED_DST_STRIDES " unsupported dst strides" +#define HEURISTICS_MISMATCH " heuristics mismatch" -#define VERIFY(condition, ...) \ - do { \ - if (!(condition)) { \ +#define VERIFY(condition, ...) \ + do { \ + if (!(condition)) { \ DEBUG_LOG(__VA_ARGS__); \ - return false; \ - } \ + return false; \ + } \ } while (0) diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv.cpp b/src/plugins/intel_cpu/src/nodes/executors/deconv.cpp index 23e0910bd0c82c..e485815e950af4 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv.cpp @@ -5,8 +5,5 @@ #include "deconv.hpp" namespace ov { -namespace intel_cpu { - - -} // namespace intel_cpu -} // namespace ov +namespace intel_cpu {} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv.hpp b/src/plugins/intel_cpu/src/nodes/executors/deconv.hpp index c632cc0cf99ad1..11920c0ab35b49 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv.hpp @@ -34,11 +34,11 @@ class DeconvExecutor { virtual bool init(const DeconvAttrs& deconvAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) = 0; + const dnnl::primitive_attr& attr) = 0; virtual void exec(const std::vector& src, const std::vector& dst, - const void *post_ops_data_) = 0; + const void* post_ops_data_) = 0; virtual ~DeconvExecutor() = default; virtual impl_desc_type getImplType() const = 0; @@ -53,12 +53,14 @@ using DeconvExecutorCPtr = std::shared_ptr; class DeconvExecutorBuilder { public: ~DeconvExecutorBuilder() = default; - virtual bool isSupported(const DeconvAttrs& convAttrs, const std::vector& srcDescs, const std::vector& dstDescs) const = 0; + virtual bool isSupported(const DeconvAttrs& convAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const = 0; virtual DeconvExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const = 0; }; using DeconvExecutorBuilderPtr = std::shared_ptr; using DeconvExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp index f5b897c2d1b6e1..c093057e47413f 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.cpp @@ -9,11 +9,10 @@ namespace intel_cpu { const std::vector& getDeconvExecutorsList() { static std::vector descs = { - OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - }; + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared())}; return descs; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp index 4c63a565aac2e0..fd114094303808 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/deconv_list.hpp @@ -4,15 +4,14 @@ #pragma once 
-#include "executor.hpp" - #include "deconv.hpp" +#include "executor.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_deconv.hpp" +# include "acl/acl_deconv.hpp" #endif -#include "onednn/iml_type_mapper.h" #include "common/primitive_cache.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -29,7 +28,8 @@ class DeconvExecutorFactory : public ExecutorFactoryLegacy { DeconvExecutorFactory(const DeconvAttrs& deconvAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getDeconvExecutorsList()) { if (desc.builder->isSupported(deconvAttrs, srcDescs, dstDescs)) { supportedDescs.push_back(desc); @@ -41,7 +41,7 @@ class DeconvExecutorFactory : public ExecutorFactoryLegacy { virtual DeconvExecutorPtr makeExecutor(const DeconvAttrs& deconvAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const dnnl::primitive_attr& attr) { auto build = [&](const DeconvExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(deconvAttrs, srcDescs, dstDescs, attr)) { @@ -75,5 +75,5 @@ class DeconvExecutorFactory : public ExecutorFactoryLegacy { using DeconvExecutorFactoryPtr = std::shared_ptr; using DeconvExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_aliases.hpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_aliases.hpp index a611e94f617e44..27fa7dd38d7a99 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_aliases.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_aliases.hpp @@ -4,8 +4,8 @@ #pragma once -#include #include +#include namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp index 1d078feaa6549b..db5c8bed2e43e1 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected.hpp @@ -8,10 +8,10 @@ #include #include "cpu_memory.h" +#include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/executors/dnnl/dnnl_aliases.hpp" #include "nodes/executors/dnnl/dnnl_utils.hpp" #include "nodes/executors/executor.hpp" -#include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/executors/memory_arguments.hpp" #include "post_ops.hpp" @@ -123,7 +123,8 @@ class DnnlFCExecutor : public Executor { if (currentPrimitive && currentPrimitive->weightsDesc()->isCompatible(*newPrimMemDesc)) return; - originalMemDesc = Primitive::makeTransposedWeightDescriptor(originalMemDesc, newPrimMemDesc, m_attrs.weightsNonTransposed); + originalMemDesc = + Primitive::makeTransposedWeightDescriptor(originalMemDesc, newPrimMemDesc, m_attrs.weightsNonTransposed); const auto weiMemory = utils::prepareWeightsMemory(originalMemDesc, newPrimMemDesc, memory, m_context, true); m_primArgs[DNNL_ARG_WEIGHTS] = weiMemory->getPrimitive(); @@ -143,9 +144,7 @@ class DnnlFCExecutor : public Executor { m_primArgs[DNNL_ARG_SCRATCHPAD] = m_scratchPadMemory->getPrimitive(); } - void updateMemory(const PrimitivePtr currentPrimitive, - const PrimitivePtr newPrimitive, - const MemoryArgs& memory) { + void updateMemory(const PrimitivePtr currentPrimitive, 
const PrimitivePtr newPrimitive, const MemoryArgs& memory) { const auto& srcDesc = MemoryDescUtils::convertToDnnlMemoryDesc(memory.at(ARG_SRC)->getDescPtr()); const auto& weiDesc = MemoryDescUtils::convertToDnnlMemoryDesc(memory.at(ARG_WEI)->getDescPtr()); const auto& dstDesc = MemoryDescUtils::convertToDnnlMemoryDesc(memory.at(ARG_DST)->getDescPtr()); diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp index 780dbb6f2f3f11..52434a1eeb8461 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_fullyconnected_primitive.cpp @@ -74,9 +74,8 @@ bool DnnlFCPrimitive::Key::operator==(const Key& rhs) const { result = result && dst && rhs.dst && dst->getDnnlDesc() == rhs.dst->getDnnlDesc(); } - result = result && *attr.get() == *rhs.attr.get() && - sparseWeights == rhs.sparseWeights && - modelType == rhs.modelType; + result = + result && *attr.get() == *rhs.attr.get() && sparseWeights == rhs.sparseWeights && modelType == rhs.modelType; return result; } @@ -158,7 +157,8 @@ static bool useDynamicQuantizationImpl(size_t dqGroupSize, if (srcDesc->getPrecision() != ov::element::f32) return false; - MemoryCPtr zpPtr = memory.count(ARG_WEI | ARG_ATTR_ZERO_POINTS) ? memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS) : nullptr; + MemoryCPtr zpPtr = + memory.count(ARG_WEI | ARG_ATTR_ZERO_POINTS) ? memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS) : nullptr; // For dynamic quantization, VNNI accumulation requires weight to be unsigned. // To support dynamic quantization with weights symmetrically quantized as i8/i4 // w/o zero-point, we will transform weight to u8/u4 weight with zp 128/8. @@ -220,14 +220,8 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs, one_of(srcDesc->getPrecision(), ov::element::u8, ov::element::i8) && weiDesc->getPrecision() == ov::element::i8; auto outputDataType = DnnlExtensionUtils::ElementTypeToDataType(dstDesc->getPrecision()); - DnnlPostOpsComposer dnnlpoc(postOps, - context->getEngine(), - dims, - dims.size() - 1, - isINT8, - 1 << 0, - memory, - outputDataType); + DnnlPostOpsComposer + dnnlpoc(postOps, context->getEngine(), dims, dims.size() - 1, isINT8, 1 << 0, memory, outputDataType); if (memory.count(ARG_WEI | ARG_ATTR_SCALES)) { auto dstPrc = memory.at(ARG_WEI | ARG_ATTR_SCALES)->getPrecision(); @@ -239,7 +233,9 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const FCAttrs& attrs, if (memory.count(ARG_WEI | ARG_ATTR_ZERO_POINTS)) { auto dstPrc = useDynamicQuantization ? 
ov::element::u8 : ov::element::f32; - dnnlpoc.appendDecompressionZeroPoints(memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS), !attrs.weightsNonTransposed, dstPrc); + dnnlpoc.appendDecompressionZeroPoints(memory.at(ARG_WEI | ARG_ATTR_ZERO_POINTS), + !attrs.weightsNonTransposed, + dstPrc); } if (useDynamicQuantization) { diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_matmul_primitive.cpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_matmul_primitive.cpp index 40c365ee5f4da5..86b22607111833 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_matmul_primitive.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_matmul_primitive.cpp @@ -129,14 +129,8 @@ static DnnlPrimitiveAttrs createPrimitiveAttrs(const MatMulAttrs& attrs, one_of(srcDesc->getPrecision(), ov::element::u8, ov::element::i8) && weiDesc->getPrecision() == ov::element::i8; auto outputDataType = DnnlExtensionUtils::ElementTypeToDataType(dstDesc->getPrecision()); - DnnlPostOpsComposer dnnlpoc(postOps, - context->getEngine(), - dims, - dims.size() - 1, - isINT8, - 1 << 0, - memory, - outputDataType); + DnnlPostOpsComposer + dnnlpoc(postOps, context->getEngine(), dims, dims.size() - 1, isINT8, 1 << 0, memory, outputDataType); return dnnlpoc.compose(); } @@ -185,8 +179,7 @@ static dnnl::matmul::primitive_desc createDescriptorInternal(const dnnl::memory: wdt = memory::data_type::s8; } - const dnnl::memory::desc weightsDesc = - dnnl::memory::desc(weiDims, wdt, memory::format_tag::any); + const dnnl::memory::desc weightsDesc = dnnl::memory::desc(weiDims, wdt, memory::format_tag::any); return dnnl::matmul::primitive_desc(engine, inputsDesc, weightsDesc, newBiasDesc, outputsDesc, attr); } @@ -335,7 +328,8 @@ DnnlMatMulPrimitive::DnnlMatMulPrimitive(const Key& key, m_prim(primitive(m_primDesc)) {} void DnnlMatMulPrimitive::execute(const dnnl_primitive_args& primArgs) const { - std::cout << "Executing MM primitive" << "\n"; + std::cout << "Executing MM primitive" + << "\n"; m_prim.execute(m_stream, primArgs); } diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp index d76e1984bd87d9..6a1b128be307ce 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp @@ -12,8 +12,7 @@ namespace ov { namespace intel_cpu { struct DnnlShapeAgnosticData { - DnnlShapeAgnosticData(DnnlPrimitiveAttrs primAttrs) - : primAttrs(std::move(primAttrs)) {} + DnnlShapeAgnosticData(DnnlPrimitiveAttrs primAttrs) : primAttrs(std::move(primAttrs)) {} DnnlPrimitiveAttrs primAttrs; }; diff --git a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp index fa273ac3d6c3ff..f23fd317d3546d 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/dnnl/dnnl_utils.cpp @@ -8,8 +8,8 @@ #include #include "cpu_memory.h" -#include "memory_desc/dnnl_memory_desc.h" #include "memory_desc/cpu_memory_desc_utils.h" +#include "memory_desc/dnnl_memory_desc.h" #include "nodes/executors/executor.hpp" #include "nodes/reorder.h" #include "utils/cpu_utils.hpp" @@ -79,9 +79,9 @@ MemoryPtr prepareWeightsMemory(const DnnlMemoryDescPtr srcWeightDesc, auto globalWeightCache = context->getWeightsCache(); MemoryPtr ptr; - if (globalWeightCache && - dnnl::memory::format_kind::blocked == 
dstWeightDesc->getDnnlDesc().get_format_kind()) { - ptr = *globalWeightCache->findOrCreate(DnnlExtensionUtils::computeWeightsStringHash(weightsMem, dstWeightDesc), create); + if (globalWeightCache && dnnl::memory::format_kind::blocked == dstWeightDesc->getDnnlDesc().get_format_kind()) { + ptr = *globalWeightCache->findOrCreate(DnnlExtensionUtils::computeWeightsStringHash(weightsMem, dstWeightDesc), + create); } else { ptr = create(); } diff --git a/src/plugins/intel_cpu/src/nodes/executors/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/executors/eltwise.cpp index 12bce382424e5c..8e7c470984b4f2 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/eltwise.cpp @@ -9,5 +9,5 @@ namespace intel_cpu { EltwiseExecutor::EltwiseExecutor(const ExecutorContext::CPtr context) : context(context) {} -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/eltwise.hpp b/src/plugins/intel_cpu/src/nodes/executors/eltwise.hpp index 4b1271c49d5df0..b33c0eca10dae7 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/eltwise.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/eltwise.hpp @@ -5,8 +5,8 @@ #pragma once #include "cpu_memory.h" -#include "onednn/iml_type_mapper.h" #include "executor.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -19,10 +19,7 @@ struct EltwiseData { float gamma; bool operator==(const EltwiseData& rhs) const noexcept { - return algo == rhs.algo && - onednnAlgorithm == rhs.onednnAlgorithm && - alpha == rhs.alpha && - beta == rhs.beta && + return algo == rhs.algo && onednnAlgorithm == rhs.onednnAlgorithm && alpha == rhs.alpha && beta == rhs.beta && gamma == rhs.gamma; } }; @@ -34,24 +31,21 @@ struct EltwiseAttrs { float gamma; EltwiseAttrs() : algorithm(Algorithm::Default), alpha(0), beta(0), gamma(0) {} - EltwiseAttrs(Algorithm algorithm, float alpha, float beta, float gamma) : algorithm(algorithm), alpha(alpha), beta(beta), gamma(gamma) {} + EltwiseAttrs(Algorithm algorithm, float alpha, float beta, float gamma) + : algorithm(algorithm), + alpha(alpha), + beta(beta), + gamma(gamma) {} bool operator==(const EltwiseAttrs& rhs) const { bool retVal = true; - retVal = algorithm == rhs.algorithm && - alpha == rhs.alpha && - beta == rhs.beta && - gamma == rhs.gamma; + retVal = algorithm == rhs.algorithm && alpha == rhs.alpha && beta == rhs.beta && gamma == rhs.gamma; return retVal; } }; -enum class EltwisePostOpType { - Undefined, - Eltwise, - Dnnl -}; +enum class EltwisePostOpType { Undefined, Eltwise, Dnnl }; class EltwisePostOp { public: @@ -72,17 +66,20 @@ class EltwisePostOp { EltwisePostOpType type = EltwisePostOpType::Undefined; - bool operator==(const EltwisePostOp &rhs) const { - if (type != rhs.type) { return false; } + bool operator==(const EltwisePostOp& rhs) const { + if (type != rhs.type) { + return false; + } bool ret = true; switch (type) { - case EltwisePostOpType::Eltwise: - ret = eltwise == rhs.eltwise; - break; - case EltwisePostOpType::Dnnl: - ret = dnnlPostOps == rhs.dnnlPostOps; - break; - default: assert(!"unsupported eltwise post operation type"); + case EltwisePostOpType::Eltwise: + ret = eltwise == rhs.eltwise; + break; + case EltwisePostOpType::Dnnl: + ret = dnnlPostOps == rhs.dnnlPostOps; + break; + default: + assert(!"unsupported eltwise post operation type"); } return ret; } @@ -96,7 +93,9 @@ class EltwiseExecutor { const 
std::vector& dstDescs, const std::vector& postOps) = 0; - virtual void exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) = 0; + virtual void exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) = 0; virtual ~EltwiseExecutor() = default; virtual impl_desc_type getImplType() const = 0; @@ -121,5 +120,5 @@ class EltwiseExecutorBuilder { using EltwiseExecutorBuilderPtr = std::shared_ptr; using EltwiseExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.cpp index 1bd6647310d387..5b9479bdf502b6 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.cpp @@ -10,11 +10,10 @@ namespace intel_cpu { const std::vector& getEltwiseExecutorsList() { static std::vector descs = { OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - OV_CPU_INSTANCE_SHL(ExecutorType::Shl, std::make_shared()) - }; + OV_CPU_INSTANCE_SHL(ExecutorType::Shl, std::make_shared())}; return descs; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.hpp index 618e3499dc10a7..ac5c27c0ad36dc 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/eltwise_list.hpp @@ -4,19 +4,18 @@ #pragma once -#include "executor.hpp" - #include "eltwise.hpp" +#include "executor.hpp" #if defined(OV_CPU_WITH_ACL) -#include "aarch64/jit_eltwise.hpp" -#include "acl/acl_eltwise.hpp" +# include "aarch64/jit_eltwise.hpp" +# include "acl/acl_eltwise.hpp" #endif #if defined(OV_CPU_WITH_SHL) -#include "shl/shl_eltwise.hpp" +# include "shl/shl_eltwise.hpp" #endif -#include "onednn/iml_type_mapper.h" #include "common/primitive_cache.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -31,9 +30,10 @@ const std::vector& getEltwiseExecutorsList(); class EltwiseExecutorFactory : public ExecutorFactoryLegacy { public: EltwiseExecutorFactory(const EltwiseAttrs& eltwiseAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getEltwiseExecutorsList()) { if (desc.builder->isSupported(eltwiseAttrs, srcDescs, dstDescs)) { supportedDescs.push_back(desc); @@ -43,9 +43,9 @@ class EltwiseExecutorFactory : public ExecutorFactoryLegacy { ~EltwiseExecutorFactory() = default; virtual EltwiseExecutorPtr makeExecutor(const EltwiseAttrs& eltwiseAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const std::vector& postOps) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const std::vector& postOps) { auto build = [&](const EltwiseExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(eltwiseAttrs, srcDescs, dstDescs, postOps)) { @@ -84,5 +84,5 @@ class EltwiseExecutorFactory : public ExecutorFactoryLegacy { using EltwiseExecutorFactoryPtr = std::shared_ptr; using EltwiseExecutorFactoryCPtr = std::shared_ptr; -} // 
namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/executor.cpp b/src/plugins/intel_cpu/src/nodes/executors/executor.cpp index 236f51c6d16149..399dab3d5499b9 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/executor.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/executor.cpp @@ -2,15 +2,17 @@ // SPDX-License-Identifier: Apache-2.0 // -#include - #include "executor.hpp" +#include + namespace ov { namespace intel_cpu { std::string ExecutorTypeToString(const ExecutorType type) { -#define CASE(_type) case ExecutorType::_type: return #_type; +#define CASE(_type) \ + case ExecutorType::_type: \ + return #_type; switch (type) { CASE(Undefined); CASE(Graph); @@ -27,7 +29,10 @@ std::string ExecutorTypeToString(const ExecutorType type) { } ExecutorType ExecutorTypeFromString(const std::string& typeStr) { -#define CASE(_type) if (typeStr == #_type) { return ExecutorType::_type; } +#define CASE(_type) \ + if (typeStr == #_type) { \ + return ExecutorType::_type; \ + } CASE(Undefined); CASE(Graph); CASE(Common); @@ -41,5 +46,5 @@ ExecutorType ExecutorTypeFromString(const std::string& typeStr) { return ExecutorType::Undefined; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/executor.hpp b/src/plugins/intel_cpu/src/nodes/executors/executor.hpp index 2016e8f5820dee..16a419c95d5efc 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/executor.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/executor.hpp @@ -4,15 +4,15 @@ #pragma once -#include "openvino/core/except.hpp" -#include "openvino/core/visibility.hpp" #include #include "cache/multi_cache.h" #include "cpu_memory.h" #include "graph_context.h" -#include "onednn/iml_type_mapper.h" #include "memory_arguments.hpp" +#include "onednn/iml_type_mapper.h" +#include "openvino/core/except.hpp" +#include "openvino/core/visibility.hpp" namespace ov { namespace intel_cpu { @@ -24,25 +24,25 @@ namespace intel_cpu { #endif #if defined(OV_CPU_WITH_ACL) -# if defined(OPENVINO_ARCH_ARM) -# define OV_CPU_INSTANCE_ACL32(...) {__VA_ARGS__}, -# else -# define OV_CPU_INSTANCE_ACL32(...) -# endif -# if defined(OPENVINO_ARCH_ARM64) -# define OV_CPU_INSTANCE_ACL64(...) {__VA_ARGS__}, -# else -# define OV_CPU_INSTANCE_ACL64(...) -# endif -# if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64) -# define OV_CPU_INSTANCE_ACL(...) {__VA_ARGS__}, -# else -# define OV_CPU_INSTANCE_ACL(...) -# endif +# if defined(OPENVINO_ARCH_ARM) +# define OV_CPU_INSTANCE_ACL32(...) {__VA_ARGS__}, +# else +# define OV_CPU_INSTANCE_ACL32(...) +# endif +# if defined(OPENVINO_ARCH_ARM64) +# define OV_CPU_INSTANCE_ACL64(...) {__VA_ARGS__}, +# else +# define OV_CPU_INSTANCE_ACL64(...) +# endif +# if defined(OPENVINO_ARCH_ARM) || defined(OPENVINO_ARCH_ARM64) +# define OV_CPU_INSTANCE_ACL(...) {__VA_ARGS__}, +# else +# define OV_CPU_INSTANCE_ACL(...) +# endif #else -# define OV_CPU_INSTANCE_ACL32(...) -# define OV_CPU_INSTANCE_ACL64(...) -# define OV_CPU_INSTANCE_ACL(...) +# define OV_CPU_INSTANCE_ACL32(...) +# define OV_CPU_INSTANCE_ACL64(...) +# define OV_CPU_INSTANCE_ACL(...) #endif #if defined(OV_CPU_WITH_DNNL) @@ -72,28 +72,11 @@ namespace intel_cpu { #define OV_CPU_INSTANCE_COMMON(...) 
{__VA_ARGS__}, // @todo another option is to determine shape relation by executor type -enum class ShapeTolerance { - Agnostic, - Dependant -}; +enum class ShapeTolerance { Agnostic, Dependant }; -enum class ExecutorType { - Undefined, - Graph, - Common, - jit_x64, - Dnnl, - Acl, - Mlas, - jit_aarch64, - Shl -}; +enum class ExecutorType { Undefined, Graph, Common, jit_x64, Dnnl, Acl, Mlas, jit_aarch64, Shl }; -enum class OperationType { - FullyConnected, - MatMul, - Convolution -}; +enum class OperationType { FullyConnected, MatMul, Convolution }; std::string ExecutorTypeToString(const ExecutorType type); ExecutorType ExecutorTypeFromString(const std::string& typeStr); diff --git a/src/plugins/intel_cpu/src/nodes/executors/executor_config.hpp b/src/plugins/intel_cpu/src/nodes/executors/executor_config.hpp index d08c4ad8127325..cd9bcaf7a119f7 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/executor_config.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/executor_config.hpp @@ -4,8 +4,8 @@ #pragma once -#include "post_ops.hpp" #include "memory_arguments.hpp" +#include "post_ops.hpp" namespace ov { namespace intel_cpu { diff --git a/src/plugins/intel_cpu/src/nodes/executors/executor_implementation.hpp b/src/plugins/intel_cpu/src/nodes/executors/executor_implementation.hpp index 3459d1fe35e19e..07a58b0fa6cfa7 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/executor_implementation.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/executor_implementation.hpp @@ -19,22 +19,22 @@ template class ExecutorImplementation { public: using SupportsPredicate = std::function&)>; - using RequiresFallbackPredicate = std::function>(const executor::Config&)>; + using RequiresFallbackPredicate = + std::function>(const executor::Config&)>; using AcceptsShapePredicate = std::function; using CreateFunction = std::function; - ExecutorImplementation( - const char* name, - const ExecutorType type, - const OperationType operationType, - const ShapeTolerance shapeRelation, - SupportsPredicate supports, - RequiresFallbackPredicate requiresFallback, - AcceptsShapePredicate acceptsShape, - CreateFunction create) + ExecutorImplementation(const char* name, + const ExecutorType type, + const OperationType operationType, + const ShapeTolerance shapeRelation, + SupportsPredicate supports, + RequiresFallbackPredicate requiresFallback, + AcceptsShapePredicate acceptsShape, + CreateFunction create) : m_name(name), m_type(type), m_operationType(operationType), diff --git a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp index 10f472ddcd7283..42101ce3fca257 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/fullyconnected_implementations.cpp @@ -11,14 +11,14 @@ #include "memory_desc/cpu_memory_desc.h" #include "nodes/executors/convolution_config.hpp" #include "nodes/executors/dnnl/dnnl_convolution_primitive.hpp" -#include "nodes/executors/dnnl/dnnl_fullyconnected_primitive.hpp" #include "nodes/executors/dnnl/dnnl_fullyconnected.hpp" +#include "nodes/executors/dnnl/dnnl_fullyconnected_primitive.hpp" #include "nodes/executors/dnnl/dnnl_matmul_primitive.hpp" #include "nodes/executors/dnnl/dnnl_shape_agnostic_data.hpp" #include "nodes/executors/executor.hpp" #include "nodes/executors/executor_implementation.hpp" -#include "nodes/executors/implementations.hpp" #include "nodes/executors/fullyconnected_config.hpp" +#include 
"nodes/executors/implementations.hpp" #include "nodes/executors/memory_arguments.hpp" #include "nodes/executors/mlas/mlas_gemm.hpp" #include "nodes/executors/precision_matcher.hpp" @@ -30,7 +30,7 @@ #include "utils/debug_capabilities.h" #if defined(OV_CPU_WITH_ACL) -#include "nodes/executors/acl/acl_fullyconnected.hpp" +# include "nodes/executors/acl/acl_fullyconnected.hpp" #endif #if defined(OV_CPU_WITH_SHL) @@ -50,7 +50,7 @@ using LayoutConfig = std::vector; static const LayoutConfig dnnlFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}; static const LayoutConfig aclFCLayoutConfig{LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp, LayoutType::ncsp}; -template +template struct Require { bool operator()() { return dnnl::impl::cpu::x64::mayiuse(ISA); @@ -144,10 +144,10 @@ static bool fullyMatchConfiguration(const MemoryDescArgs& currentDescriptors, continue; if (desc->getPrecision() != type) - return false; // type mismatch + return false; // type mismatch if (!desc->hasLayoutType(layoutConfig[i])) - return false; // layout mismatch + return false; // layout mismatch } return true; @@ -207,6 +207,8 @@ OV_CPU_MAYBE_UNUSED_FUNCTION static inline bool noPostOps(const FCConfig& config return config.postOps.empty(); } +// to keep OV_CPU_INSTANCE macros aligned +// clang-format off template <> const std::vector>& getImplementations() { static const std::vector> fullyconnectedImplementations { @@ -492,5 +494,7 @@ const std::vector>& getImplementations() { return fullyconnectedImplementations; } +// clang-format on + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/graph_emitter.hpp b/src/plugins/intel_cpu/src/nodes/executors/graph_emitter.hpp index 784ed8bc778840..347ac4c981f4f1 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/graph_emitter.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/graph_emitter.hpp @@ -82,9 +82,7 @@ class GraphEmitter { return memoryDescs; } - static executor::Config createConfig(const MemoryArgs& memory, - const Attrs& attrs, - const PostOps& postOps) { + static executor::Config createConfig(const MemoryArgs& memory, const Attrs& attrs, const PostOps& postOps) { return executor::Config{memoryDescsFromMemory(memory), attrs, postOps}; } @@ -104,11 +102,11 @@ class GraphEmitter { const auto& graphExecutor = graphEmitter.createGraph(fallbackConfig.descs, fallbackConfig.attrs, fallbackConfig.postOps, context) - .ensureAttrsMatch() - .ensureSrcDescsMatch() - .ensureDstDescsMatch() - .ensurePostOpsMatch() - .emit(); + .ensureAttrsMatch() + .ensureSrcDescsMatch() + .ensureDstDescsMatch() + .ensurePostOpsMatch() + .emit(); (void)graphExecutor; OPENVINO_THROW("Fallback logic is not implemented yet"); // return graphExecutor; diff --git a/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp b/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp index cd029283a09c50..bee82af305c9d2 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/implementation_utils.hpp @@ -5,6 +5,7 @@ #pragma once #include + #include "cpu_types.h" #include "memory_desc/cpu_memory_desc.h" #include "nodes/executors/memory_arguments.hpp" @@ -13,80 +14,80 @@ namespace ov { namespace intel_cpu { -template +template ov::element::Type memoryDescType(const Config& config) { return config.descs.at(idx)->getPrecision(); } -template +template ov::element::Type srcType(const Config& config) { return memoryDescType(config); } 
-template +template ov::element::Type weiType(const Config& config) { return memoryDescType(config); } -template +template ov::element::Type biaType(const Config& config) { return memoryDescType(config); } -template +template ov::element::Type dstType(const Config& config) { return memoryDescType(config); } -template +template ov::element::Type dims(const Config& config) { return config.descs.at(idx)->getShape().getDims(); } -template +template const VectorDims& srcDims(const Config& config) { return dims(config); } -template +template const VectorDims& weiDims(const Config& config) { return dims(config); } -template +template size_t rank(const Config& config) { return config.descs.at(idx)->getShape().getRank(); } -template +template size_t srcRank(const Config& config) { return rank(config); } -template +template size_t weiRank(const Config& config) { return rank(config); } -template +template size_t memSize(const Config& config) { return config.descs.at(idx)->getCurrentMemSize(); } -template +template size_t srcMemSize(const Config& config) { return memSize(config); } -template +template size_t weiMemSize(const Config& config) { return memSize(config); } -template +template size_t postOpsNumbers(const Config& config) { return config.postOps.size(); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/interpolate.cpp b/src/plugins/intel_cpu/src/nodes/executors/interpolate.cpp index d0a006b1bea0fa..cb830a36f03cb1 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/interpolate.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/interpolate.cpp @@ -3,18 +3,19 @@ // #include "interpolate.hpp" -#include "openvino/core/parallel.hpp" -#include "nodes/common/cpu_memcpy.h" + #include "emitters/plugin/x64/jit_load_store_emitters.hpp" +#include "nodes/common/cpu_memcpy.h" +#include "openvino/core/parallel.hpp" using namespace ov::intel_cpu; bool ov::intel_cpu::InterpolateExecutor::init(const InterpolateAttrs& interpolateAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { - const auto &srcDims = srcDescs[0]->getShape().getStaticDims(); - const auto &dstDims = dstDescs[0]->getShape().getStaticDims(); + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { + const auto& srcDims = srcDescs[0]->getShape().getStaticDims(); + const auto& dstDims = dstDescs[0]->getShape().getStaticDims(); interpAttrs = interpolateAttrs; srcDimPad5d = to5Dim(getPaddedInputShape(srcDims, interpolateAttrs.padBegin, interpolateAttrs.padEnd)); dstDim5d = to5Dim(dstDims); @@ -24,38 +25,49 @@ bool ov::intel_cpu::InterpolateExecutor::init(const InterpolateAttrs& interpolat spatialDimSize = getSpatialDimsNum(dataRank); switch (interpAttrs.mode) { - case InterpolateMode::nearest: { - buildTblNN(srcDimPad5d, dstDim5d, interpAttrs.dataScales, interpolateAttrs.layout, interpolateAttrs.nearestMode); - break; - } - case InterpolateMode::linear_onnx: { - buildTblLinearOnnx(srcDimPad5d, dstDim5d, interpAttrs.dataScales, interpolateAttrs.layout); - break; - } - case InterpolateMode::linear: { - static constexpr int LINEAR_KERNEL = 2; - buildTblLinear(srcDimPad5d, dstDim5d, interpAttrs.dataScales, LINEAR_KERNEL, interpolateAttrs.antialias); - break; - } - case InterpolateMode::cubic: { - buildTblCubic(srcDimPad5d, dstDim5d, interpAttrs.dataScales, interpolateAttrs.cubeCoeff, interpolateAttrs.layout); - break; - } - default: { - 
OPENVINO_THROW("Interpolate executor does not support interpolate mode: ", interpAttrs.mode); - break; - } + case InterpolateMode::nearest: { + buildTblNN(srcDimPad5d, + dstDim5d, + interpAttrs.dataScales, + interpolateAttrs.layout, + interpolateAttrs.nearestMode); + break; + } + case InterpolateMode::linear_onnx: { + buildTblLinearOnnx(srcDimPad5d, dstDim5d, interpAttrs.dataScales, interpolateAttrs.layout); + break; + } + case InterpolateMode::linear: { + static constexpr int LINEAR_KERNEL = 2; + buildTblLinear(srcDimPad5d, dstDim5d, interpAttrs.dataScales, LINEAR_KERNEL, interpolateAttrs.antialias); + break; + } + case InterpolateMode::cubic: { + buildTblCubic(srcDimPad5d, + dstDim5d, + interpAttrs.dataScales, + interpolateAttrs.cubeCoeff, + interpolateAttrs.layout); + break; + } + default: { + OPENVINO_THROW("Interpolate executor does not support interpolate mode: ", interpAttrs.mode); + break; + } } return true; } // ===================================================================================================================== // index layout: // d_0............d_OD-1, h_0..............h_OH-1, w_0................w_OW-1 -void ov::intel_cpu::InterpolateExecutor::buildTblNN(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, - const std::vector& dataScales, InterpolateLayoutType layout, InterpolateNearestMode nearestMode) { +void ov::intel_cpu::InterpolateExecutor::buildTblNN(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout, + InterpolateNearestMode nearestMode) { const int dimSize = dataRank; float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f; - float fy = dataScales[dimSize - 2]; + float fy = dataScales[dimSize - 2]; float fx = dataScales[dimSize - 1]; size_t ID = srcDimPad5d[2], IH = srcDimPad5d[3], IW = srcDimPad5d[4]; size_t OD = dstDim5d[2], OH = dstDim5d[3], OW = dstDim5d[4]; @@ -84,80 +96,91 @@ void ov::intel_cpu::InterpolateExecutor::buildTblNN(const VectorDims& srcDimPad5 // scale is float(outShape) / float(inShape) // strictly consistent with onnx calc manner(div scale, not multiply inverse), given this is done offline // the slight precison diff can produce obvious wrong value due to "nearest round" behavior for NN mode -float ov::intel_cpu::InterpolateExecutor::coordTransToInput(int outCoord, float scale, int inShape, int outShape) const { +float ov::intel_cpu::InterpolateExecutor::coordTransToInput(int outCoord, + float scale, + int inShape, + int outShape) const { if (scale == 1.0f || (inShape == outShape)) { return outCoord; } switch (interpAttrs.coordTransMode) { - case InterpolateCoordTransMode::half_pixel: { + case InterpolateCoordTransMode::half_pixel: { + return (outCoord + 0.5f) / scale - 0.5f; + break; + } + case InterpolateCoordTransMode::pytorch_half_pixel: { + if (outShape > 1) return (outCoord + 0.5f) / scale - 0.5f; - break; - } - case InterpolateCoordTransMode::pytorch_half_pixel: { - if (outShape > 1) - return (outCoord + 0.5f) / scale - 0.5f; - else - return 0; - break; - } - case InterpolateCoordTransMode::asymmetric: { - return static_cast(outCoord) / scale; - break; - } - case InterpolateCoordTransMode::tf_half_pixel_for_nn: { - return (outCoord + 0.5f) / scale; - break; - } - case InterpolateCoordTransMode::align_corners: { - if (outShape > 1) - return outCoord * (static_cast(inShape - 1) / static_cast(outShape - 1)); - else - return 0; - break; - } - default: { - OPENVINO_THROW("errorPrefix", " does not support specified coordinate transformation mode"); - 
break; - } + else + return 0; + break; + } + case InterpolateCoordTransMode::asymmetric: { + return static_cast(outCoord) / scale; + break; + } + case InterpolateCoordTransMode::tf_half_pixel_for_nn: { + return (outCoord + 0.5f) / scale; + break; + } + case InterpolateCoordTransMode::align_corners: { + if (outShape > 1) + return outCoord * (static_cast(inShape - 1) / static_cast(outShape - 1)); + else + return 0; + break; + } + default: { + OPENVINO_THROW("errorPrefix", " does not support specified coordinate transformation mode"); + break; + } } } -int ov::intel_cpu::InterpolateExecutor::nearestRound(float originCoord, bool isDownsample, InterpolateNearestMode nearestMode) const { +int ov::intel_cpu::InterpolateExecutor::nearestRound(float originCoord, + bool isDownsample, + InterpolateNearestMode nearestMode) const { switch (nearestMode) { - case InterpolateNearestMode::round_prefer_floor: { - if (originCoord == (static_cast(originCoord) + 0.5f)) - return static_cast(std::floor(originCoord)); - else - return static_cast(std::round(originCoord)); - break; - } - case InterpolateNearestMode::round_prefer_ceil: { - return static_cast(std::round(originCoord)); - break; - } - case InterpolateNearestMode::floor: { + case InterpolateNearestMode::round_prefer_floor: { + if (originCoord == (static_cast(originCoord) + 0.5f)) return static_cast(std::floor(originCoord)); - break; - } - case InterpolateNearestMode::ceil: { + else + return static_cast(std::round(originCoord)); + break; + } + case InterpolateNearestMode::round_prefer_ceil: { + return static_cast(std::round(originCoord)); + break; + } + case InterpolateNearestMode::floor: { + return static_cast(std::floor(originCoord)); + break; + } + case InterpolateNearestMode::ceil: { + return static_cast(std::ceil(originCoord)); + break; + } + case InterpolateNearestMode::simple: { + if (isDownsample) return static_cast(std::ceil(originCoord)); - break; - } - case InterpolateNearestMode::simple: { - if (isDownsample) - return static_cast(std::ceil(originCoord)); - else - return static_cast(originCoord); - } - default: { - OPENVINO_THROW("errorPrefix", " does not support specified nearest round mode"); - break; - } + else + return static_cast(originCoord); + } + default: { + OPENVINO_THROW("errorPrefix", " does not support specified nearest round mode"); + break; + } } } -void ov::intel_cpu::InterpolateExecutor::linearOnnxCF(int outCoord, float scale, int inShape, int outShape, - int& index0, int& index1, float& weight0, float& weight1) { +void ov::intel_cpu::InterpolateExecutor::linearOnnxCF(int outCoord, + float scale, + int inShape, + int outShape, + int& index0, + int& index1, + float& weight0, + float& weight1) { float inCoord = coordTransToInput(outCoord, scale, inShape, outShape); inCoord = std::max(0.0f, std::min(inCoord, static_cast(inShape - 1))); index0 = std::min(static_cast(inCoord), inShape - 1); @@ -171,8 +194,10 @@ void ov::intel_cpu::InterpolateExecutor::linearOnnxCF(int outCoord, float scale, } } -void ov::intel_cpu::InterpolateExecutor::buildTblLinearOnnx(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, - const std::vector& dataScales, InterpolateLayoutType layout) { +void ov::intel_cpu::InterpolateExecutor::buildTblLinearOnnx(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout) { int dimSize = dataRank; float fz = (spatialDimSize > 2) ? dataScales[dimSize - 3] : 1.f; float fy = (spatialDimSize > 1) ? 
dataScales[dimSize - 2] : 1.f; @@ -231,7 +256,7 @@ void ov::intel_cpu::InterpolateExecutor::buildTblLinearOnnx(const VectorDims& sr indexPtr[1][idxOzOyOx] = (izF * IH * IW + iyT * IW + ixR) * scale; weightPtr[0][idxOzOyOx] = weightL; weightPtr[1][idxOzOyOx] = weightR; - if (spatialDimSize > 1) { + if (spatialDimSize > 1) { indexPtr[2][idxOzOyOx] = (izF * IH * IW + iyB * IW + ixL) * scale; indexPtr[3][idxOzOyOx] = (izF * IH * IW + iyB * IW + ixR) * scale; weightPtr[2][idxOzOyOx] = weightT; @@ -284,8 +309,11 @@ void ov::intel_cpu::InterpolateExecutor::buildTblLinearOnnx(const VectorDims& sr // wd .........wd, wh............wh, ww.............ww, id...........id, ih............ih, iw..............iw // | | // wh0.....wh_diameter ih0.....ih_diameter -void ov::intel_cpu::InterpolateExecutor::buildTblLinear(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, - const std::vector& dataScales, int kernel_width, bool antialias) { +void ov::intel_cpu::InterpolateExecutor::buildTblLinear(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + int kernel_width, + bool antialias) { int dimSize = dataRank; float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f; float fy = dataScales[dimSize - 2]; @@ -309,15 +337,15 @@ void ov::intel_cpu::InterpolateExecutor::buildTblLinear(const VectorDims& srcDim int sizeOH = OH * diaOH; int sizeOW = OW * diaOW; indexTable.resize((sizeOD + sizeOH + sizeOW) * 2); - float *weightTable = reinterpret_cast(&indexTable[0]); - float *weightOD = static_cast(&weightTable[0]); - float *weightOH = static_cast(&weightTable[sizeOD]); - float *weightOW = static_cast(&weightTable[sizeOD + sizeOH]); + float* weightTable = reinterpret_cast(&indexTable[0]); + float* weightOD = static_cast(&weightTable[0]); + float* weightOH = static_cast(&weightTable[sizeOD]); + float* weightOW = static_cast(&weightTable[sizeOD + sizeOH]); - int *idxTable = static_cast(&indexTable[sizeOD + sizeOH + sizeOW]); - int *idxOD = static_cast(&idxTable[0]); - int *idxOH = static_cast(&idxTable[sizeOD]); - int *idxOW = static_cast(&idxTable[sizeOD + sizeOH]); + int* idxTable = static_cast(&indexTable[sizeOD + sizeOH + sizeOW]); + int* idxOD = static_cast(&idxTable[0]); + int* idxOH = static_cast(&idxTable[sizeOD]); + int* idxOW = static_cast(&idxTable[sizeOD + sizeOH]); for (int oz = 0; oz < static_cast(OD); oz++) { float iz = coordTransToInput(oz, fz, ID, OD); @@ -375,8 +403,11 @@ std::vector ov::intel_cpu::InterpolateExecutor::getCubicCoeffs(float mant // table layout: // OW OW OW OW OW OH OH OH OH OH // x_idx x_weight0 x_weight1 x_weight2 x_weight3 y_idx y_weight0 y_weight1 y_weight2 y_weight3 -void ov::intel_cpu::InterpolateExecutor::buildTblCubic(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - float cubicCoeff, InterpolateLayoutType layout) { +void ov::intel_cpu::InterpolateExecutor::buildTblCubic(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + float cubicCoeff, + InterpolateLayoutType layout) { int dimSize = dataRank; float fy = dataScales[dimSize - 2]; float fx = dataScales[dimSize - 1]; @@ -394,9 +425,9 @@ void ov::intel_cpu::InterpolateExecutor::buildTblCubic(const VectorDims& srcDimP } int tblAdvance = 0; - int *xOrigin = static_cast(&indexTable[tblAdvance]); + int* xOrigin = static_cast(&indexTable[tblAdvance]); tblAdvance += OW; - float *xFactor = reinterpret_cast(&indexTable[tblAdvance]); + float* xFactor = reinterpret_cast(&indexTable[tblAdvance]); for (int ox 
= 0; ox < OW; ox++) { float ix = coordTransToInput(ox, fx, IW, OW); int ix_r = static_cast(std::floor(ix)); @@ -410,9 +441,9 @@ void ov::intel_cpu::InterpolateExecutor::buildTblCubic(const VectorDims& srcDimP } tblAdvance += CUBIC_GRID_LEN * OW; - int *yOrigin = static_cast(&indexTable[tblAdvance]); + int* yOrigin = static_cast(&indexTable[tblAdvance]); tblAdvance += OH; - float *yFactor = reinterpret_cast(&indexTable[tblAdvance]); + float* yFactor = reinterpret_cast(&indexTable[tblAdvance]); for (int oy = 0; oy < OH; oy++) { float iy = coordTransToInput(oy, fy, IH, OH); int iy_r = static_cast(std::floor(iy)); @@ -427,9 +458,9 @@ void ov::intel_cpu::InterpolateExecutor::buildTblCubic(const VectorDims& srcDimP if (layout == InterpolateLayoutType::planar) { tblAdvance += CUBIC_GRID_LEN * OH; - int *sequenceOH = static_cast(&indexTable[tblAdvance]); + int* sequenceOH = static_cast(&indexTable[tblAdvance]); tblAdvance += OH * OW; - int *sequenceOW = static_cast(&indexTable[tblAdvance]); + int* sequenceOW = static_cast(&indexTable[tblAdvance]); for (int h = 0; h < OH; ++h) { int offset = h * OW; for (int w = 0; w < OW; ++w) { @@ -447,16 +478,17 @@ inline VectorDims getBlockND(const VectorDims& shape) { int shapeRank = shape.size(); VectorDims blockND(shapeRank + 1, 1); for (int i = shapeRank - 1; i >= 0; i--) { - blockND[i] = shape[i] * blockND[i+1]; + blockND[i] = shape[i] * blockND[i + 1]; } return blockND; } -const uint8_t* ov::intel_cpu::InterpolateExecutor::padPreprocess(const std::vector& src, const std::vector& dst) { - const uint8_t *src_data_origin = src[0]->getDataAs(); +const uint8_t* ov::intel_cpu::InterpolateExecutor::padPreprocess(const std::vector& src, + const std::vector& dst) { + const uint8_t* src_data_origin = src[0]->getDataAs(); - const auto &srcDim = src[0]->getStaticDims(); - const auto &dstDim = dst[0]->getStaticDims(); + const auto& srcDim = src[0]->getStaticDims(); + const auto& dstDim = dst[0]->getStaticDims(); size_t dimSize = srcDim.size(); auto srcDimPad = getSrcDimPad5d(); @@ -465,7 +497,7 @@ const uint8_t* ov::intel_cpu::InterpolateExecutor::padPreprocess(const std::vect const auto dstDim5d = to5Dim(dstDim); const auto srcDataSize = src[0]->getDesc().getPrecision().size(); - const uint8_t *src_data = nullptr; + const uint8_t* src_data = nullptr; std::vector srcPadded; if (interpAttrs.hasPad) { int padB0 = (dimSize > 2) ? 
interpAttrs.padBegin[0] : 0; @@ -479,23 +511,32 @@ const uint8_t* ov::intel_cpu::InterpolateExecutor::padPreprocess(const std::vect if (interpAttrs.layout == InterpolateLayoutType::planar) { srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); + uint8_t* src_data_pad = static_cast(&srcPadded[0]); parallel_for4d(srcDim5d[0], srcDim5d[1], srcDim5d[2], srcDim5d[3], [&](int n, int c, int d, int h) { - const uint8_t *src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + inShapeBlock[3] * d + inShapeBlock[4] * h) * srcDataSize; - uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) + - inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * srcDataSize; + const uint8_t* src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + + inShapeBlock[3] * d + inShapeBlock[4] * h) * + srcDataSize; + uint8_t* srcPad = + src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) + + inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * + srcDataSize; cpu_memcpy(srcPad, src, srcDim5d[4] * srcDataSize); }); src_data = src_data_pad; } else if (interpAttrs.layout == InterpolateLayoutType::by_channel) { srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); + uint8_t* src_data_pad = static_cast(&srcPadded[0]); parallel_for4d(srcDim5d[0], srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int d, int h, int w) { - const uint8_t *src = src_data_origin + (inShapeBlock[1] * n + - (inShapeBlock[3] * d + inShapeBlock[4] * h + inShapeBlock[5] * w) * srcDim5d[1]) * srcDataSize; - uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + (inShapePadBlock[3] * (d + padB2) + - inShapePadBlock[4] * (h + padB3) + - inShapePadBlock[5] * (w + padB4)) * srcDimPad5d[1] + padB1) * srcDataSize; + const uint8_t* src = src_data_origin + + (inShapeBlock[1] * n + + (inShapeBlock[3] * d + inShapeBlock[4] * h + inShapeBlock[5] * w) * srcDim5d[1]) * + srcDataSize; + uint8_t* srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + + (inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + + inShapePadBlock[5] * (w + padB4)) * + srcDimPad5d[1] + + padB1) * + srcDataSize; cpu_memcpy(srcPad, src, srcDim5d[1] * srcDataSize); }); src_data = src_data_pad; @@ -504,23 +545,30 @@ const uint8_t* ov::intel_cpu::InterpolateExecutor::padPreprocess(const std::vect size_t CB = div_up(srcDimPad5d[1], blkSize); size_t eltsTotal = srcDimPad5d[0] * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize; srcPadded.resize(eltsTotal * srcDataSize, 0x0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); + uint8_t* src_data_pad = static_cast(&srcPadded[0]); if ((srcDim5d[0] != srcDimPad5d[0]) || (srcDim5d[1] != srcDimPad5d[1])) { OPENVINO_THROW("Interpolate layer with name does not support padding on batch and channel dimensions"); } - parallel_for5d(srcDim5d[0], CB, srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int cb, int d, int h, int w) { - const uint8_t *src = src_data_origin + (n * CB * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize - + (cb * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize - + (d * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize - + (h * srcDim5d[4] * blkSize) * srcDataSize - + (w * blkSize) * srcDataSize; - uint8_t *srcPad = src_data_pad + (n * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * 
blkSize) * srcDataSize - + (cb * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize - + ((d + padB2) * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize - + ((h + padB3) * srcDimPad5d[4] * blkSize) * srcDataSize - + ((w + padB4) * blkSize) * srcDataSize; - cpu_memcpy(srcPad, src, blkSize * srcDataSize); - }); + parallel_for5d( + srcDim5d[0], + CB, + srcDim5d[2], + srcDim5d[3], + srcDim5d[4], + [&](int n, int cb, int d, int h, int w) { + const uint8_t* src = src_data_origin + + (n * CB * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (cb * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (d * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (h * srcDim5d[4] * blkSize) * srcDataSize + (w * blkSize) * srcDataSize; + uint8_t* srcPad = + src_data_pad + + (n * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + (cb * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + ((d + padB2) * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + ((h + padB3) * srcDimPad5d[4] * blkSize) * srcDataSize + ((w + padB4) * blkSize) * srcDataSize; + cpu_memcpy(srcPad, src, blkSize * srcDataSize); + }); src_data = src_data_pad; } } else { diff --git a/src/plugins/intel_cpu/src/nodes/executors/interpolate.hpp b/src/plugins/intel_cpu/src/nodes/executors/interpolate.hpp index 15df4eed5f0471..041589c0ab9f6a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/interpolate.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/interpolate.hpp @@ -11,41 +11,15 @@ namespace ov { namespace intel_cpu { -enum InterpolateLayoutType { - planar, - block, - by_channel -}; +enum InterpolateLayoutType { planar, block, by_channel }; -enum InterpolateMode { - nearest, - linear, - linear_onnx, - cubic, - bilinear_pillow, - bicubic_pillow -}; +enum InterpolateMode { nearest, linear, linear_onnx, cubic, bilinear_pillow, bicubic_pillow }; -enum InterpolateCoordTransMode { - half_pixel, - pytorch_half_pixel, - asymmetric, - tf_half_pixel_for_nn, - align_corners -}; +enum InterpolateCoordTransMode { half_pixel, pytorch_half_pixel, asymmetric, tf_half_pixel_for_nn, align_corners }; -enum class InterpolateNearestMode { - round_prefer_floor, - round_prefer_ceil, - floor, - ceil, - simple -}; +enum class InterpolateNearestMode { round_prefer_floor, round_prefer_ceil, floor, ceil, simple }; -enum class InterpolateShapeCalcMode { - sizes, - scales -}; +enum class InterpolateShapeCalcMode { sizes, scales }; struct InterpolateAttrs { InterpolateShapeCalcMode shapeCalcMode = InterpolateShapeCalcMode::sizes; @@ -63,9 +37,9 @@ struct InterpolateAttrs { bool hasPad = false; }; -inline VectorDims getPaddedInputShape(const VectorDims &srcDims, - const std::vector &padBegin, - const std::vector &padEnd) { +inline VectorDims getPaddedInputShape(const VectorDims& srcDims, + const std::vector& padBegin, + const std::vector& padEnd) { VectorDims paddedShape; int dataRank = srcDims.size(); for (int i = 0; i < dataRank; i++) { @@ -80,16 +54,16 @@ inline int clipCoord(int pos, int length) { inline size_t getSpatialDimsNum(const Dim rank) { switch (rank) { - case 1: - case 3: - return 1; - case 2: - case 4: - return 2; - case 5: - return 3; - default: - OPENVINO_THROW("Can't define number spatial"); + case 1: + case 3: + return 1; + case 2: + case 4: + return 2; + case 5: + return 3; + default: + OPENVINO_THROW("Can't define number spatial"); } } @@ -133,27 +107,49 @@ class InterpolateExecutor { virtual bool init(const 
InterpolateAttrs& interpolateAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr); - virtual void exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) = 0; + const dnnl::primitive_attr& attr); + virtual void exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) = 0; virtual impl_desc_type getImplType() const = 0; virtual ~InterpolateExecutor() = default; - VectorDims getSrcDimPad5d() const { return srcDimPad5d; } + VectorDims getSrcDimPad5d() const { + return srcDimPad5d; + } const uint8_t* padPreprocess(const std::vector& src, const std::vector& dst); private: - void buildTblNN(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - InterpolateLayoutType layout, InterpolateNearestMode nearestMode); - void buildTblLinearOnnx(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, + void buildTblNN(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout, + InterpolateNearestMode nearestMode); + void buildTblLinearOnnx(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, InterpolateLayoutType layout); - void buildTblLinear(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, int kernel_width, + void buildTblLinear(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + int kernel_width, bool antialias); - void buildTblCubic(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, float cubicCoeff, + void buildTblCubic(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + float cubicCoeff, InterpolateLayoutType layout); float coordTransToInput(int outCoord, float scale, int inShape, int outShape) const; int nearestRound(float origin, bool isDownsample, InterpolateNearestMode nearestMode) const; - void linearOnnxCF(int outCoord, float scale, int inShape, int outShape, int& index0, int& index1, float& weight0, float& weight1); + void linearOnnxCF(int outCoord, + float scale, + int inShape, + int outShape, + int& index0, + int& index1, + float& weight0, + float& weight1); std::vector getCubicCoeffs(float mantissa, float a); protected: @@ -180,5 +176,5 @@ class InterpolateExecutorBuilder { using InterpolateExecutorBuilderPtr = std::shared_ptr; using InterpolateExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.cpp index 2362b644583763..21ae249757bf9c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.cpp @@ -9,11 +9,10 @@ namespace intel_cpu { const std::vector& getInterpolateExecutorsList() { static std::vector descs = { - OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - }; + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared())}; return descs; } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.hpp 
b/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.hpp index 2ed16ea04b1852..a0c1fc240731fb 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/interpolate_list.hpp @@ -5,14 +5,13 @@ #pragma once #include "executor.hpp" - #include "interpolate.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_interpolate.hpp" +# include "acl/acl_interpolate.hpp" #endif -#include "onednn/iml_type_mapper.h" #include "common/primitive_cache.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -27,9 +26,10 @@ const std::vector& getInterpolateExecutorsList(); class InterpolateExecutorFactory : public ExecutorFactoryLegacy { public: InterpolateExecutorFactory(const InterpolateAttrs& InterpolateAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getInterpolateExecutorsList()) { if (desc.builder->isSupported(InterpolateAttrs, srcDescs, dstDescs)) { supportedDescs.push_back(desc); @@ -39,9 +39,9 @@ class InterpolateExecutorFactory : public ExecutorFactoryLegacy { ~InterpolateExecutorFactory() = default; virtual InterpolateExecutorPtr makeExecutor(const InterpolateAttrs& interpolateAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { auto build = [&](const InterpolateExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(interpolateAttrs, srcDescs, dstDescs, attr)) { @@ -52,7 +52,6 @@ class InterpolateExecutorFactory : public ExecutorFactoryLegacy { return ptr; }; - if (chosenDesc) { if (auto executor = build(chosenDesc)) { return executor; @@ -81,5 +80,5 @@ class InterpolateExecutorFactory : public ExecutorFactoryLegacy { using InterpolateExecutorFactoryPtr = std::shared_ptr; using InterpolateExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/memory_arguments.hpp b/src/plugins/intel_cpu/src/nodes/executors/memory_arguments.hpp index 7150226d27c601..05c3cf0d5df259 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/memory_arguments.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/memory_arguments.hpp @@ -12,7 +12,7 @@ namespace ov { namespace intel_cpu { using MemoryDescArgs = std::unordered_map; -using MemoryArgs = std::unordered_map; +using MemoryArgs = std::unordered_map; // basic inputs #define ARG_SRC_0 1 diff --git a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp index 8fd945b773f262..7e50c8086789a0 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_gemm.cpp @@ -104,8 +104,7 @@ MlasGemmExecutor::MlasGemmExecutor(const FCAttrs& attrs, m_memoryArgs(memory), packedWeights(prepareWeightMemory(memory.at(ARG_WEI), context, !attrs.weightsNonTransposed)), N(batchDim(memory.at(ARG_WEI)->getStaticDims())), - K(memory.at(ARG_WEI)->getStaticDims().back()) -{} + K(memory.at(ARG_WEI)->getStaticDims().back()) {} bool MlasGemmExecutor::update(const MemoryArgs& memory) { const auto& dstDesc 
= memory.at(ARG_DST)->getDescPtr(); diff --git a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.cpp b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.cpp index 678fe5a5c22176..2b8b71bfbced0b 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.cpp @@ -3,9 +3,10 @@ // #include "mlas_transpose.hpp" -#include "openvino/core/parallel.hpp" -#include "nodes/common/cpu_memcpy.h" + #include "mlas.h" +#include "nodes/common/cpu_memcpy.h" +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { @@ -24,7 +25,12 @@ struct has_mlas_transpose : std::true_type {}; template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisOutwards( - const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop, int64_t writes_per_writer_per_loop) { + const T* input_data, + T* output_data, + int64_t num_loops, + int64_t num_writers, + int64_t writes_per_loop, + int64_t writes_per_writer_per_loop) { const T* end; for (int64_t l = 0; l < num_loops; ++l) { T* output_for_first_writer = output_data; @@ -44,9 +50,17 @@ typename std::enable_if::value, void>::type SimpleTranspo template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisOutwards( - const T* input_data, T* output_data, int64_t num_loops, int64_t num_writers, int64_t writes_per_loop, int64_t writes_per_writer_per_loop) { + const T* input_data, + T* output_data, + int64_t num_loops, + int64_t num_writers, + int64_t writes_per_loop, + int64_t writes_per_writer_per_loop) { for (int64_t l = 0; l < num_loops; ++l) { - MlasTranspose(input_data, output_data, static_cast(writes_per_writer_per_loop), static_cast(num_writers)); + MlasTranspose(input_data, + output_data, + static_cast(writes_per_writer_per_loop), + static_cast(num_writers)); input_data += writes_per_loop; output_data += writes_per_loop; } @@ -54,7 +68,12 @@ typename std::enable_if::value, void>::type SimpleTranspos template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisInwards( - const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop, int64_t reads_per_reader_per_loop) { + const T* input_data, + T* output_data, + int64_t num_loops, + int64_t num_readers, + int64_t reads_per_loop, + int64_t reads_per_reader_per_loop) { T* end; for (int64_t l = 0; l < num_loops; ++l) { const T* input_for_first_reader = input_data; @@ -74,9 +93,17 @@ typename std::enable_if::value, void>::type SimpleTranspo template typename std::enable_if::value, void>::type SimpleTransposeSingleAxisInwards( - const T* input_data, T* output_data, int64_t num_loops, int64_t num_readers, int64_t reads_per_loop, int64_t reads_per_reader_per_loop) { + const T* input_data, + T* output_data, + int64_t num_loops, + int64_t num_readers, + int64_t reads_per_loop, + int64_t reads_per_reader_per_loop) { for (int64_t l = 0; l < num_loops; ++l) { - MlasTranspose(input_data, output_data, static_cast(num_readers), static_cast(reads_per_reader_per_loop)); + MlasTranspose(input_data, + output_data, + static_cast(num_readers), + static_cast(reads_per_reader_per_loop)); input_data += reads_per_loop; output_data += reads_per_loop; } @@ -148,7 +175,10 @@ bool MlasTransposeExecutor::IsTransposeMovingSingleAxis(VectorDims permutations, return single_axis_moved; } -void MlasTransposeExecutor::TransposeSingleAxisOutwards(const MemoryCPtr& input, const MemoryPtr& output, size_t from, 
size_t to) { +void MlasTransposeExecutor::TransposeSingleAxisOutwards(const MemoryCPtr& input, + const MemoryPtr& output, + size_t from, + size_t to) { const auto& input_shape = input->getShape(); const auto& input_dims = input_shape.getDims(); const auto element_size = input->getDesc().getPrecision().size(); @@ -165,52 +195,68 @@ void MlasTransposeExecutor::TransposeSingleAxisOutwards(const MemoryCPtr& input, const size_t bytes_per_write = static_cast(block_size) * element_size; switch (bytes_per_write) { - case (sizeof(uint8_t)): { - SimpleTransposeSingleAxisOutwards(input_data, output_data, num_loops, num_writers, writes_per_loop, - writes_per_writer_per_loop); - break; - } - case (sizeof(uint16_t)): { - SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), num_loops, num_writers, - writes_per_loop, writes_per_writer_per_loop); - break; - } - case (sizeof(uint32_t)): { - SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), num_loops, num_writers, - writes_per_loop, writes_per_writer_per_loop); - break; - } - case (sizeof(uint64_t)): { - SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), num_loops, num_writers, - writes_per_loop, writes_per_writer_per_loop); - break; - } - default: { - // we need to use memcpy for each block - for (int64_t l = 0; l < num_loops; ++l) { - uint8_t* output_for_first_writer = output_data; + case (sizeof(uint8_t)): { + SimpleTransposeSingleAxisOutwards(input_data, + output_data, + num_loops, + num_writers, + writes_per_loop, + writes_per_writer_per_loop); + break; + } + case (sizeof(uint16_t)): { + SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), + num_loops, + num_writers, + writes_per_loop, + writes_per_writer_per_loop); + break; + } + case (sizeof(uint32_t)): { + SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), + num_loops, + num_writers, + writes_per_loop, + writes_per_writer_per_loop); + break; + } + case (sizeof(uint64_t)): { + SimpleTransposeSingleAxisOutwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), + num_loops, + num_writers, + writes_per_loop, + writes_per_writer_per_loop); + break; + } + default: { + // we need to use memcpy for each block + for (int64_t l = 0; l < num_loops; ++l) { + uint8_t* output_for_first_writer = output_data; - for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { - uint8_t* output_for_current_writer = output_for_first_writer; + for (auto wwpl = 0; wwpl < writes_per_writer_per_loop; ++wwpl) { + uint8_t* output_for_current_writer = output_for_first_writer; - for (uint64_t w = 0; w < num_writers; ++w) { - memcpy(output_for_current_writer, input_data, bytes_per_write); - // skip to output position for next writer - output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write); - input_data += bytes_per_write; - } - output_for_first_writer += bytes_per_write; + for (uint64_t w = 0; w < num_writers; ++w) { + memcpy(output_for_current_writer, input_data, bytes_per_write); + // skip to output position for next writer + output_for_current_writer += (writes_per_writer_per_loop * bytes_per_write); + input_data += bytes_per_write; } - output_data += writes_per_loop * bytes_per_write; + output_for_first_writer += bytes_per_write; } + output_data += writes_per_loop * bytes_per_write; } } + } } -void MlasTransposeExecutor::TransposeSingleAxisInwards(const MemoryCPtr& 
input, const MemoryPtr& output, size_t from, size_t to) { +void MlasTransposeExecutor::TransposeSingleAxisInwards(const MemoryCPtr& input, + const MemoryPtr& output, + size_t from, + size_t to) { const auto& input_shape = input->getShape(); const auto& input_dims = input_shape.getDims(); @@ -227,61 +273,74 @@ void MlasTransposeExecutor::TransposeSingleAxisInwards(const MemoryCPtr& input, const size_t bytes_per_read = static_cast(block_size) * element_size; switch (bytes_per_read) { - case (sizeof(uint8_t)): { - SimpleTransposeSingleAxisInwards(input_data, output_data, num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); - break; - } - case (sizeof(uint16_t)): { - SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); - break; - } - case (sizeof(uint32_t)): { - SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); - break; - } - case (sizeof(uint64_t)): { - SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), - reinterpret_cast(output_data), num_loops, num_readers, reads_per_loop, - reads_per_reader_per_loop); - break; - } - default: { - // we need to use memcpy for each block - for (int64_t l = 0; l < num_loops; ++l) { - const uint8_t* input_for_first_reader = input_data; - for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) { - const uint8_t* input_for_current_reader = input_for_first_reader; - for (uint64_t r = 0; r < num_readers; ++r) { - memcpy(output_data, input_for_current_reader, bytes_per_read); - output_data += bytes_per_read; - // skip to input position for next reader - input_for_current_reader += (reads_per_reader_per_loop * bytes_per_read); - } - input_for_first_reader += bytes_per_read; + case (sizeof(uint8_t)): { + SimpleTransposeSingleAxisInwards(input_data, + output_data, + num_loops, + num_readers, + reads_per_loop, + reads_per_reader_per_loop); + break; + } + case (sizeof(uint16_t)): { + SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), + num_loops, + num_readers, + reads_per_loop, + reads_per_reader_per_loop); + break; + } + case (sizeof(uint32_t)): { + SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), + num_loops, + num_readers, + reads_per_loop, + reads_per_reader_per_loop); + break; + } + case (sizeof(uint64_t)): { + SimpleTransposeSingleAxisInwards(reinterpret_cast(input_data), + reinterpret_cast(output_data), + num_loops, + num_readers, + reads_per_loop, + reads_per_reader_per_loop); + break; + } + default: { + // we need to use memcpy for each block + for (int64_t l = 0; l < num_loops; ++l) { + const uint8_t* input_for_first_reader = input_data; + for (auto rrpl = 0; rrpl < reads_per_reader_per_loop; ++rrpl) { + const uint8_t* input_for_current_reader = input_for_first_reader; + for (uint64_t r = 0; r < num_readers; ++r) { + memcpy(output_data, input_for_current_reader, bytes_per_read); + output_data += bytes_per_read; + // skip to input position for next reader + input_for_current_reader += (reads_per_reader_per_loop * bytes_per_read); } - input_data += reads_per_loop * bytes_per_read; + input_for_first_reader += bytes_per_read; } + input_data += reads_per_loop * bytes_per_read; } } + } } void MlasTransposeExecutor::exec(const std::vector& src, const std::vector& dst) { if (from > to) { - 
TransposeSingleAxisOutwards(src[0], dst[0], from, to); + TransposeSingleAxisOutwards(src[0], dst[0], from, to); } else { - TransposeSingleAxisInwards(src[0], dst[0], from, to); + TransposeSingleAxisInwards(src[0], dst[0], from, to); } } -bool MlasTransposeExecutor::init(const TransposeParams &transposeParams, - const std::vector &srcDescs, - const std::vector &dstDescs, - const dnnl::primitive_attr &attr) { +bool MlasTransposeExecutor::init(const TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { if (!IsTransposeMovingSingleAxis(transposeParams.permuteParams.order, from, to)) { DEBUG_LOG("MLAS Transpose executor supports moving single axis only"); return false; @@ -292,8 +351,7 @@ bool MlasTransposeExecutor::init(const TransposeParams &transposeParams, bool MlasTransposeExecutorBuilder::isSupported(const TransposeParams& transposeParams, const std::vector& srcDescs, const std::vector& dstDescs) const { - if (!srcDescs[0]->hasLayoutType(LayoutType::ncsp) || - !dstDescs[0]->hasLayoutType(LayoutType::ncsp)) { + if (!srcDescs[0]->hasLayoutType(LayoutType::ncsp) || !dstDescs[0]->hasLayoutType(LayoutType::ncsp)) { DEBUG_LOG("MLAS Transpose executor supports NCHW layout only"); return false; } @@ -308,5 +366,5 @@ TransposeExecutorPtr MlasTransposeExecutorBuilder::makeExecutor(const ExecutorCo return std::make_shared(context); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.hpp b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.hpp index d7e0307414aac9..8f7cd1bf8c22bd 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mlas/mlas_transpose.hpp @@ -11,13 +11,16 @@ namespace intel_cpu { class MlasTransposeExecutor : public TransposeExecutor { public: using TransposeExecutor::TransposeExecutor; - bool init(const TransposeParams &transposeParams, - const std::vector &srcDescs, - const std::vector &dstDescs, - const dnnl::primitive_attr &attr) override; - void exec(const std::vector &src, const std::vector &dst) override; + bool init(const TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) override; + void exec(const std::vector& src, const std::vector& dst) override; + + impl_desc_type implType() const override { + return impl_desc_type::mlas; + } - impl_desc_type implType() const override { return impl_desc_type::mlas; } private: static int64_t calcShapeSize(const Shape& shape, size_t start, size_t end); static bool IsTransposeMovingSingleAxis(VectorDims permutations, size_t& from, size_t& to); @@ -37,5 +40,5 @@ class MlasTransposeExecutorBuilder : public TransposeExecutorBuilder { TransposeExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const override; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/mvn.cpp b/src/plugins/intel_cpu/src/nodes/executors/mvn.cpp index 9b522ed9887344..eec9d2a8947975 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mvn.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mvn.cpp @@ -11,26 +11,34 @@ MVNExecutor::MVNExecutor(const ExecutorContext::CPtr context) : context(context) VectorDims MVNExecutor::transformTo5DCase(const VectorDims& shape, bool initAcrossChannels) { switch (shape.size()) 
{ - // for 1 and 2 rank, if initAcrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure. - // otherwise there are not enough data in spatial dimension to process in one kernel. - case 1 : // C - if (initAcrossChannels) { - return VectorDims({1, 1, 1, 1, shape[0]}); - } else { - return VectorDims({1, shape[0], 1, 1, 1}); - } - case 2 : // NC - if (initAcrossChannels) { - return VectorDims({1, shape[0], 1, shape[1], 1}); - } else { - return VectorDims({shape[0], shape[1], 1, 1, 1}); - } - case 3 : { return VectorDims({shape[0], shape[1], 1, shape[2], 1}); } - case 4 : { return VectorDims({shape[0], shape[1], 1, shape[2], shape[3]}); } - case 5 : { return VectorDims({shape[0], shape[1], shape[2], shape[3], shape[4]}); } - default : { OPENVINO_THROW("MVN executor doesn't support planar layout with rank: ", shape.size()); } + // for 1 and 2 rank, if initAcrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure. + // otherwise there are not enough data in spatial dimension to process in one kernel. + case 1: // C + if (initAcrossChannels) { + return VectorDims({1, 1, 1, 1, shape[0]}); + } else { + return VectorDims({1, shape[0], 1, 1, 1}); + } + case 2: // NC + if (initAcrossChannels) { + return VectorDims({1, shape[0], 1, shape[1], 1}); + } else { + return VectorDims({shape[0], shape[1], 1, 1, 1}); + } + case 3: { + return VectorDims({shape[0], shape[1], 1, shape[2], 1}); + } + case 4: { + return VectorDims({shape[0], shape[1], 1, shape[2], shape[3]}); + } + case 5: { + return VectorDims({shape[0], shape[1], shape[2], shape[3], shape[4]}); + } + default: { + OPENVINO_THROW("MVN executor doesn't support planar layout with rank: ", shape.size()); + } } } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/mvn.hpp b/src/plugins/intel_cpu/src/nodes/executors/mvn.hpp index 759115a4b4b794..da51b5d1ef67e9 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mvn.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mvn.hpp @@ -5,29 +5,22 @@ #pragma once #include "cpu_memory.h" -#include "onednn/iml_type_mapper.h" #include "executor.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { -enum MVNLayoutType { - mvn_planar, - mvn_block, - mvn_by_channel -}; +enum MVNLayoutType { mvn_planar, mvn_block, mvn_by_channel }; // Defines way to add epsilon: inside sqrt or outside. 
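// For reference, a minimal self-contained sketch of the distinction the two modes
// below encode (illustrative only, not part of this patch; the helper name and
// signature are assumptions made for the example):
#include <cmath>
inline float mvn_normalize_sketch(float x, float mean, float variance, float eps, bool inside_sqrt) {
    // INSIDE_SQRT:  y = (x - mean) / sqrt(variance + eps)   -- eps stabilizes the variance under the root
    // OUTSIDE_SQRT: y = (x - mean) / (sqrt(variance) + eps) -- eps is added to the standard deviation
    const float denom = inside_sqrt ? std::sqrt(variance + eps) : std::sqrt(variance) + eps;
    return (x - mean) / denom;
}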
-enum MVNEpsMode { - INSIDE_SQRT, - OUTSIDE_SQRT -}; +enum MVNEpsMode { INSIDE_SQRT, OUTSIDE_SQRT }; struct MVNAttrs { MVNLayoutType layout = mvn_planar; bool initAcrossChannels_ = false; bool execAcrossChannels_ = false; - bool normalizeVariance_ = false; + bool normalizeVariance_ = false; float epsValue_ = 0.0f; MVNEpsMode epsMode_ = INSIDE_SQRT; ov::element::Type src_prc; @@ -40,9 +33,11 @@ class MVNExecutor { virtual bool init(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) = 0; + const dnnl::primitive_attr& attr) = 0; - virtual void exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) = 0; + virtual void exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) = 0; virtual ~MVNExecutor() = default; virtual impl_desc_type getImplType() const = 0; @@ -60,12 +55,14 @@ using MVNExecutorCPtr = std::shared_ptr; class MVNExecutorBuilder { public: ~MVNExecutorBuilder() = default; - virtual bool isSupported(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs) const = 0; + virtual bool isSupported(const MVNAttrs& mvnAttrs, + const std::vector& srcDescs, + const std::vector& dstDescs) const = 0; virtual MVNExecutorPtr makeExecutor(const ExecutorContext::CPtr context) const = 0; }; using MVNExecutorBuilderPtr = std::shared_ptr; using MVNExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/mvn_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/mvn_list.cpp index c27751b7a2d2b4..99a55d79f58177 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mvn_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mvn_list.cpp @@ -9,11 +9,10 @@ namespace intel_cpu { const std::vector& getMVNExecutorsList() { static std::vector descs = { - OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - }; + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared())}; return descs; } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/mvn_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/mvn_list.hpp index 3a8d3cc61fe585..82f8e868ac2d81 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/mvn_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/mvn_list.hpp @@ -5,14 +5,13 @@ #pragma once #include "executor.hpp" - #include "mvn.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_mvn.hpp" +# include "acl/acl_mvn.hpp" #endif -#include "onednn/iml_type_mapper.h" #include "common/primitive_cache.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -29,7 +28,8 @@ class MVNExecutorFactory : public ExecutorFactoryLegacy { MVNExecutorFactory(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getMVNExecutorsList()) { if (desc.builder->isSupported(mvnAttrs, srcDescs, dstDescs)) { supportedDescs.push_back(desc); @@ -41,7 +41,7 @@ class MVNExecutorFactory : public ExecutorFactoryLegacy { virtual MVNExecutorPtr makeExecutor(const MVNAttrs& mvnAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const 
dnnl::primitive_attr& attr) { auto build = [&](const MVNExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(mvnAttrs, srcDescs, dstDescs, attr)) { @@ -80,5 +80,5 @@ class MVNExecutorFactory : public ExecutorFactoryLegacy { using MVNExecutorFactoryPtr = std::shared_ptr; using MVNExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/pooling.cpp b/src/plugins/intel_cpu/src/nodes/executors/pooling.cpp index e15d1a4ef15b8d..95448640e3b125 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/pooling.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/pooling.cpp @@ -9,5 +9,5 @@ namespace intel_cpu { PoolingExecutor::PoolingExecutor(const ExecutorContext::CPtr context) : context(context) {} -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/pooling.hpp b/src/plugins/intel_cpu/src/nodes/executors/pooling.hpp index 5ea358c68afc8e..e826d3a37250db 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/pooling.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/pooling.hpp @@ -5,8 +5,8 @@ #pragma once #include "cpu_memory.h" -#include "onednn/iml_type_mapper.h" #include "executor.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -44,9 +44,11 @@ class PoolingExecutor { virtual bool init(const PoolingAttrs& poolingAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) = 0; + const dnnl::primitive_attr& attr) = 0; - virtual void exec(const std::vector& src, const std::vector& dst, std::unordered_map postOpsArgs) = 0; + virtual void exec(const std::vector& src, + const std::vector& dst, + std::unordered_map postOpsArgs) = 0; virtual ~PoolingExecutor() = default; virtual impl_desc_type getImplType() const = 0; @@ -71,5 +73,5 @@ class PoolingExecutorBuilder { using PoolingExecutorBuilderPtr = std::shared_ptr; using PoolingExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/pooling_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/pooling_list.cpp index 4b130f37bfff57..d0ee9f7da574c6 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/pooling_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/pooling_list.cpp @@ -9,11 +9,10 @@ namespace intel_cpu { const std::vector& getPoolingExecutorsList() { static std::vector descs = { - OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - }; + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared())}; return descs; } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/pooling_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/pooling_list.hpp index d6ce5489105b19..1c051ae7d2959d 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/pooling_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/pooling_list.hpp @@ -5,10 +5,9 @@ #pragma once #include "executor.hpp" - #include "pooling.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_pooling.hpp" +# include "acl/acl_pooling.hpp" #endif namespace ov { @@ -24,9 +23,10 
@@ const std::vector& getPoolingExecutorsList(); class PoolingExecutorFactory : public ExecutorFactoryLegacy { public: PoolingExecutorFactory(const PoolingAttrs& poolingAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getPoolingExecutorsList()) { if (desc.builder->isSupported(poolingAttrs, srcDescs, dstDescs)) { supportedDescs.push_back(desc); @@ -36,9 +36,9 @@ class PoolingExecutorFactory : public ExecutorFactoryLegacy { ~PoolingExecutorFactory() = default; virtual PoolingExecutorPtr makeExecutor(const PoolingAttrs& poolingAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { auto build = [&](const PoolingExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(poolingAttrs, srcDescs, dstDescs, attr)) { @@ -49,7 +49,6 @@ class PoolingExecutorFactory : public ExecutorFactoryLegacy { return ptr; }; - if (chosenDesc) { if (auto executor = build(chosenDesc)) { return executor; @@ -74,5 +73,5 @@ class PoolingExecutorFactory : public ExecutorFactoryLegacy { using PoolingExecutorFactoryPtr = std::shared_ptr; using PoolingExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/precision_matcher.cpp b/src/plugins/intel_cpu/src/nodes/executors/precision_matcher.cpp index 95044a9e205595..ced50dd2ec3dd5 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/precision_matcher.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/precision_matcher.cpp @@ -15,9 +15,12 @@ namespace intel_cpu { bool match(const InOutTypeMask& patterns, const InOutTypes& values) { assert(patterns.size() == values.size()); - return std::equal(values.begin(), values.end(), patterns.begin(), [](const ov::element::Type value, const TypeMask pattern) { - return pattern & value; - }); + return std::equal(values.begin(), + values.end(), + patterns.begin(), + [](const ov::element::Type value, const TypeMask pattern) { + return pattern & value; + }); return true; } diff --git a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp index 73aac151843b08..36aab4f8fddc77 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.cpp @@ -14,7 +14,9 @@ namespace ov { namespace intel_cpu { -InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, const TypeMapping& mapping, const MappingNotation& notation) { +InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, + const TypeMapping& mapping, + const MappingNotation& notation) { InOutTypes types; std::transform(notation.begin(), notation.end(), std::back_inserter(types), [&descriptors](int id) { return descriptors.at(id)->getPrecision(); diff --git a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp index 374b584dd0ffb5..20e613eea2c236 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp +++ 
b/src/plugins/intel_cpu/src/nodes/executors/precision_translation.hpp @@ -18,24 +18,21 @@ namespace intel_cpu { template struct use { - ov::element::Type operator()(const std::vector& types, - size_t idx) const { + ov::element::Type operator()(const std::vector& types, size_t idx) const { assert(bypassId < types.size()); return types[bypassId]; } }; struct bypass { - ov::element::Type operator()(const std::vector& types, - size_t idx) const { + ov::element::Type operator()(const std::vector& types, size_t idx) const { return types[idx]; } }; template struct just { - ov::element::Type operator()(const std::vector& types, - size_t idx) const { + ov::element::Type operator()(const std::vector& types, size_t idx) const { // ignore everything (void)types; (void)idx; @@ -45,8 +42,7 @@ struct just { template <> struct just { - ov::element::Type operator()(const std::vector& types, - size_t idx) const { + ov::element::Type operator()(const std::vector& types, size_t idx) const { // ignore everything (void)types; (void)idx; @@ -58,11 +54,9 @@ using policy = std::function - PortsTranslation(Policies... policies) : - m_policies{policies...} {} + PortsTranslation(Policies... policies) : m_policies{policies...} {} - std::vector operator()( - const std::vector& types) const { + std::vector operator()(const std::vector& types) const { assert(types.size() == m_policies.size()); std::vector result; @@ -73,6 +67,7 @@ struct PortsTranslation { return result; } + private: std::vector m_policies; }; @@ -88,9 +83,7 @@ class TypeMappingEntry { public: using EnabledPredicate = std::function; - TypeMappingEntry(InOutTypeMask mask, - TypeTranslationFunction translation, - EnabledPredicate enabled = {}) + TypeMappingEntry(InOutTypeMask mask, TypeTranslationFunction translation, EnabledPredicate enabled = {}) : m_mask(std::move(mask)), m_translation(std::move(translation)), m_enabled(std::move(enabled)) {} @@ -121,7 +114,9 @@ using TypeMapping = std::vector; using MappingNotation = std::vector; using pt = PortsTranslation; -InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, const TypeMapping& mapping, const MappingNotation& notation); +InOutTypes getTypeConfiguration(const MemoryDescArgs& descriptors, + const TypeMapping& mapping, + const MappingNotation& notation); } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/printers.cpp b/src/plugins/intel_cpu/src/nodes/executors/printers.cpp index ac52b25a069541..1bce932225827d 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/printers.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/printers.cpp @@ -4,25 +4,27 @@ #ifdef CPU_DEBUG_CAPS -#include -#include "printers.hpp" -#include "post_ops.hpp" -#include "fullyconnected_config.hpp" +# include "printers.hpp" + +# include + +# include "fullyconnected_config.hpp" +# include "post_ops.hpp" namespace ov { namespace intel_cpu { -std::ostream & operator<<(std::ostream & os, const FCAttrs& attrs) { +std::ostream& operator<<(std::ostream& os, const FCAttrs& attrs) { // @todo print Attrs return os; } -std::ostream & operator<<(std::ostream & os, const PostOps& postOps) { +std::ostream& operator<<(std::ostream& os, const PostOps& postOps) { // @todo print PostOps return os; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov -#endif // CPU_DEBUG_CAPS +#endif // CPU_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/nodes/executors/printers.hpp b/src/plugins/intel_cpu/src/nodes/executors/printers.hpp index 
d37ab633ba8036..7a96550b3f225c 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/printers.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/printers.hpp @@ -3,25 +3,27 @@ // #ifdef CPU_DEBUG_CAPS -#pragma once +# pragma once -#include -#include "executor_config.hpp" +# include + +# include "executor_config.hpp" namespace ov { namespace intel_cpu { namespace executor { -template struct Config; +template +struct Config; } struct FCAttrs; -std::ostream & operator<<(std::ostream & os, const FCAttrs& attrs); -std::ostream & operator<<(std::ostream & os, const PostOps& postOps); +std::ostream& operator<<(std::ostream& os, const FCAttrs& attrs); +std::ostream& operator<<(std::ostream& os, const PostOps& postOps); -template -std::ostream & operator<<(std::ostream & os, const executor::Config& config) { +template +std::ostream& operator<<(std::ostream& os, const executor::Config& config) { for (const auto& desc : config.descs) { const auto id = desc.first; const auto descPtr = desc.second; @@ -34,7 +36,7 @@ std::ostream & operator<<(std::ostream & os, const executor::Config& conf return os; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov -#endif // CPU_DEBUG_CAPS +#endif // CPU_DEBUG_CAPS diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp b/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp index 8e091f0282eb5d..6039813d8fdd28 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce.cpp @@ -9,5 +9,5 @@ namespace intel_cpu { ReduceExecutor::ReduceExecutor(const ExecutorContext::CPtr context) : context(context) {} -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp b/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp index 8aa6e8f0aaa4ac..21b730a197df3a 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce.hpp @@ -5,9 +5,9 @@ #pragma once #include "cpu_memory.h" -#include "onednn/iml_type_mapper.h" #include "dnnl_scratch_pad.h" #include "executor.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -24,9 +24,11 @@ class ReduceExecutor { virtual bool init(const ReduceAttrs& reduceAttrs, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) = 0; + const dnnl::primitive_attr& attr) = 0; - virtual void exec(const std::vector& src, const std::vector& dst, const void *post_ops_data_) = 0; + virtual void exec(const std::vector& src, + const std::vector& dst, + const void* post_ops_data_) = 0; virtual ~ReduceExecutor() = default; virtual impl_desc_type getImplType() const = 0; @@ -51,5 +53,5 @@ class ReduceExecutorBuilder { using ReduceExecutorBuilderPtr = std::shared_ptr; using ReduceExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp index aec5c7eb905865..e6f035879a2cc6 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.cpp @@ -9,11 +9,10 @@ namespace intel_cpu { const std::vector& getReduceExecutorsList() { static std::vector descs = { - 
OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - }; + OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared())}; return descs; } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp index ea2543a495e64c..faffdebc947c02 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/reduce_list.hpp @@ -5,14 +5,13 @@ #pragma once #include "executor.hpp" - #include "reduce.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_reduce.hpp" +# include "acl/acl_reduce.hpp" #endif -#include "onednn/iml_type_mapper.h" #include "common/primitive_cache.hpp" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -27,9 +26,10 @@ const std::vector& getReduceExecutorsList(); class ReduceExecutorFactory : public ExecutorFactoryLegacy { public: ReduceExecutorFactory(const ReduceAttrs& reduceAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { for (auto& desc : getReduceExecutorsList()) { if (desc.builder->isSupported(reduceAttrs, srcDescs, dstDescs)) { supportedDescs.push_back(desc); @@ -39,9 +39,9 @@ class ReduceExecutorFactory : public ExecutorFactoryLegacy { ~ReduceExecutorFactory() = default; virtual ReduceExecutorPtr makeExecutor(const ReduceAttrs& reduceAttrs, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { auto build = [&](const ReduceExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(reduceAttrs, srcDescs, dstDescs, attr)) { @@ -52,7 +52,6 @@ class ReduceExecutorFactory : public ExecutorFactoryLegacy { return ptr; }; - if (chosenDesc) { if (auto executor = build(chosenDesc)) { return executor; @@ -81,5 +80,5 @@ class ReduceExecutorFactory : public ExecutorFactoryLegacy { using ReduceExecutorFactoryPtr = std::shared_ptr; using ReduceExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/transpose.cpp b/src/plugins/intel_cpu/src/nodes/executors/transpose.cpp index 57e2e028827a62..b63e32e39ebf8d 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/transpose.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/transpose.cpp @@ -2,9 +2,11 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "transpose.hpp" + #include + #include "openvino/core/parallel.hpp" -#include "transpose.hpp" namespace ov { namespace intel_cpu { @@ -33,27 +35,27 @@ jit_permute_config_params TransposeExecutor::prepareParams(const PermuteParams& } for (int i = tmp_order.size() - 1; i >= 0; i--) { - int pos = std::distance(std::find( - src_block_order.rbegin(), src_block_order.rend(), tmp_order[i]), src_block_order.rend() - 1); + int pos = std::distance(std::find(src_block_order.rbegin(), src_block_order.rend(), tmp_order[i]), + src_block_order.rend() - 1); if (pos != -1) { new_src_block_strides[i] = src_block_strides[pos]; src_block_order.erase(src_block_order.begin() + pos); 
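// note: the parallel vectors are kept in sync here -- the matched source dim's
// order entry was erased above and its stride is erased next, so the same source
// dimension cannot be matched again by a later output dimension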
src_block_strides.erase(src_block_strides.begin() + pos); mask[i] = 0; } else { - new_src_block_strides[i] = new_src_block_strides[tmp_order.size() - 1] * params.dst_block_dims[tmp_order.size() - 1]; + new_src_block_strides[i] = + new_src_block_strides[tmp_order.size() - 1] * params.dst_block_dims[tmp_order.size() - 1]; mask[i] = 1; mask[tmp_order.size() - 1] = 1; } } if (!src_block_order.empty()) { int pos = std::distance(tmp_order.begin(), std::find(tmp_order.begin(), tmp_order.end(), src_block_order[0])); - new_src_block_strides.insert(new_src_block_strides.begin() + pos, - src_block_strides[0]); - new_dst_block_strides.insert(new_dst_block_strides.begin() + pos, - new_dst_block_strides[pos] * params.src_block_dims[params.src_block_dims.size() - 1]); - new_dst_block_order.insert(new_dst_block_order.begin() + pos, - new_dst_block_order[pos]); + new_src_block_strides.insert(new_src_block_strides.begin() + pos, src_block_strides[0]); + new_dst_block_strides.insert( + new_dst_block_strides.begin() + pos, + new_dst_block_strides[pos] * params.src_block_dims[params.src_block_dims.size() - 1]); + new_dst_block_order.insert(new_dst_block_order.begin() + pos, new_dst_block_order[pos]); new_dst_block_dims.insert(new_dst_block_dims.begin() + pos + 1, params.src_block_dims[params.src_block_dims.size() - 1]); new_dst_block_dims[pos] = div_up(new_dst_block_dims[pos], new_dst_block_dims[pos + 1]); @@ -107,12 +109,12 @@ jit_permute_config_params TransposeExecutor::prepareParams(const PermuteParams& } int max_threads = parallel_get_max_threads(); - const int n_max = 3; // max count dims for parallel + const int n_max = 3; // max count dims for parallel int n = 0; int work_amount = sorted_dst_dims[0]; for (size_t i = 1; i < sorted_dst_dims.size() && n < n_max; i++) { n++; - if (work_amount >= 4 * max_threads) { // 4 * max_threads is a specially selected value for best performance + if (work_amount >= 4 * max_threads) { // 4 * max_threads is a specially selected value for best performance break; } work_amount *= sorted_dst_dims[i]; @@ -128,5 +130,5 @@ jit_permute_config_params TransposeExecutor::prepareParams(const PermuteParams& return jcp; } -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/executors/transpose.hpp b/src/plugins/intel_cpu/src/nodes/executors/transpose.hpp index 15f2d5085cd5ad..99e0b0a2742a78 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/transpose.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/transpose.hpp @@ -5,9 +5,9 @@ #pragma once #include "cpu_memory.h" -#include "onednn/iml_type_mapper.h" #include "executor.hpp" #include "nodes/common/permute_kernel.h" +#include "onednn/iml_type_mapper.h" namespace ov { namespace intel_cpu { @@ -23,8 +23,9 @@ class TransposeExecutor : public Executor { virtual bool init(const TransposeParams& transposeParams, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) = 0; + const dnnl::primitive_attr& attr) = 0; virtual ~TransposeExecutor() = default; + protected: PermuteParams permuteParams; const ExecutorContext::CPtr context; @@ -44,5 +45,5 @@ class TransposeExecutorBuilder { using TransposeExecutorBuilderPtr = std::shared_ptr; using TransposeExecutorBuilderCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/transpose_list.cpp 
b/src/plugins/intel_cpu/src/nodes/executors/transpose_list.cpp index 31db070d04ffe3..f0e72f4bec1ae2 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/transpose_list.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/transpose_list.cpp @@ -9,20 +9,19 @@ namespace intel_cpu { const std::vector& getTransposeExecutorsList() { static const std::vector descs = { - OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared()) + OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared()) OV_CPU_INSTANCE_ACL(ExecutorType::Acl, std::make_shared()) - OV_CPU_INSTANCE_MLAS_ARM64(ExecutorType::Mlas, std::make_shared()) - OV_CPU_INSTANCE_X64(ExecutorType::jit_x64, std::make_shared()) - OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared()) - }; + OV_CPU_INSTANCE_MLAS_ARM64(ExecutorType::Mlas, std::make_shared()) + OV_CPU_INSTANCE_X64(ExecutorType::jit_x64, std::make_shared()) + OV_CPU_INSTANCE_COMMON(ExecutorType::Common, std::make_shared())}; return descs; } TransposeExecutorPtr TransposeExecutorFactory::makeExecutor(const TransposeParams& transposeParams, - const std::vector& srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr) { + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { auto build = [&](const TransposeExecutorDesc* desc) { auto executor = desc->builder->makeExecutor(context); if (executor->init(transposeParams, srcDescs, dstDescs, attr)) { @@ -48,5 +47,5 @@ TransposeExecutorPtr TransposeExecutorFactory::makeExecutor(const TransposeParam OPENVINO_THROW("Supported executor is not found"); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/transpose_list.hpp b/src/plugins/intel_cpu/src/nodes/executors/transpose_list.hpp index 90141a6194592e..c81769fd1d0539 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/transpose_list.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/transpose_list.hpp @@ -5,19 +5,17 @@ #pragma once #include "executor.hpp" - #include "transpose.hpp" #if defined(OV_CPU_WITH_ACL) -#include "acl/acl_transpose.hpp" +# include "acl/acl_transpose.hpp" #endif +#include "common/primitive_cache.hpp" #include "common/ref_opt_transpose.hpp" #include "common/ref_transpose.hpp" #include "mlas/mlas_transpose.hpp" -#include "x64/jit_transpose.hpp" - #include "onednn/iml_type_mapper.h" -#include "common/primitive_cache.hpp" +#include "x64/jit_transpose.hpp" namespace ov { namespace intel_cpu { @@ -31,22 +29,23 @@ const std::vector& getTransposeExecutorsList(); class TransposeExecutorFactory : public ExecutorFactoryLegacy { public: -TransposeExecutorFactory(const TransposeParams& transposeParams, - const std::vector& srcDescs, - const std::vector& dstDescs, - const ExecutorContext::CPtr context) : ExecutorFactoryLegacy(context) { - for (auto& desc : getTransposeExecutorsList()) { - if (desc.builder->isSupported(transposeParams, srcDescs, dstDescs)) { - supportedDescs.push_back(desc); + TransposeExecutorFactory(const TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const ExecutorContext::CPtr context) + : ExecutorFactoryLegacy(context) { + for (auto& desc : getTransposeExecutorsList()) { + if (desc.builder->isSupported(transposeParams, srcDescs, dstDescs)) { + supportedDescs.push_back(desc); + } } } -} -~TransposeExecutorFactory() = default; -virtual TransposeExecutorPtr makeExecutor(const TransposeParams& transposeParams, - const std::vector& 
srcDescs, - const std::vector& dstDescs, - const dnnl::primitive_attr &attr); + ~TransposeExecutorFactory() = default; + virtual TransposeExecutorPtr makeExecutor(const TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr); private: std::vector supportedDescs; @@ -56,5 +55,5 @@ virtual TransposeExecutorPtr makeExecutor(const TransposeParams& transposeParams using TransposeExecutorFactoryPtr = std::shared_ptr; using TransposeExecutorFactoryCPtr = std::shared_ptr; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/type_mask.hpp b/src/plugins/intel_cpu/src/nodes/executors/type_mask.hpp index d492bd6b6f368a..ef9fdac7f19208 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/type_mask.hpp +++ b/src/plugins/intel_cpu/src/nodes/executors/type_mask.hpp @@ -14,29 +14,29 @@ namespace intel_cpu { struct TypeMask { enum Value : uint64_t { _undefined = 1 << 0, - _dynamic = 1 << 1, - _boolean = 1 << 2, - _bf16 = 1 << 3, - _f16 = 1 << 4, - _f32 = 1 << 5, - _f64 = 1 << 6, - _i4 = 1 << 7, - _i8 = 1 << 8, - _i16 = 1 << 9, - _i32 = 1 << 10, - _i64 = 1 << 11, - _u1 = 1 << 12, - _u4 = 1 << 13, - _u8 = 1 << 14, - _u16 = 1 << 15, - _u32 = 1 << 16, - _u64 = 1 << 17, - _nf4 = 1 << 18, - _f8e4m3 = 1 << 19, - _f8e5m2 = 1 << 20, - _string = 1 << 21, - _f4e2m1 = 1 << 22, - _f8e8m0 = 1 << 23, + _dynamic = 1 << 1, + _boolean = 1 << 2, + _bf16 = 1 << 3, + _f16 = 1 << 4, + _f32 = 1 << 5, + _f64 = 1 << 6, + _i4 = 1 << 7, + _i8 = 1 << 8, + _i16 = 1 << 9, + _i32 = 1 << 10, + _i64 = 1 << 11, + _u1 = 1 << 12, + _u4 = 1 << 13, + _u8 = 1 << 14, + _u16 = 1 << 15, + _u32 = 1 << 16, + _u64 = 1 << 17, + _nf4 = 1 << 18, + _f8e4m3 = 1 << 19, + _f8e5m2 = 1 << 20, + _string = 1 << 21, + _f4e2m1 = 1 << 22, + _f8e8m0 = 1 << 23, }; TypeMask(const ov::element::Type precision) : value(generateMask(precision)), precision(precision) {} diff --git a/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.cpp b/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.cpp index bfcc7ad4ae672a..79c578aaacda61 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.cpp +++ b/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.cpp @@ -3,6 +3,7 @@ // #include "jit_transpose.hpp" + #include "cpu/x64/cpu_isa_traits.hpp" using namespace dnnl::impl::cpu; @@ -21,9 +22,10 @@ void JitTransposeExecutor::exec(const std::vector& src, const std::v pKernel->execute(srcData, dstData, MB); } -bool JitTransposeExecutor::init(const TransposeParams &transposeParams, - const std::vector &srcDescs, - const std::vector &dstDescs, const dnnl::primitive_attr &attr) { +bool JitTransposeExecutor::init(const TransposeParams& transposeParams, + const std::vector& srcDescs, + const std::vector& dstDescs, + const dnnl::primitive_attr& attr) { pKernel = std::make_shared(transposeParams.permuteParams); return true; } @@ -35,9 +37,9 @@ bool JitTransposeExecutorBuilder::isSupported(const TransposeParams& transposePa if (mayiuse(x64::sse41)) { return true; } -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 return false; } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.hpp b/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.hpp index d37ac9e5db5ef5..fd6d54257f1489 100644 --- a/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.hpp 
+++ b/src/plugins/intel_cpu/src/nodes/executors/x64/jit_transpose.hpp @@ -16,9 +16,12 @@ class JitTransposeExecutor : public TransposeExecutor { bool init(const TransposeParams& transposeParams, const std::vector& srcDescs, const std::vector& dstDescs, - const dnnl::primitive_attr &attr) override; + const dnnl::primitive_attr& attr) override; void exec(const std::vector& src, const std::vector& dst) override; - impl_desc_type implType() const override { return impl_desc_type::jit; } + impl_desc_type implType() const override { + return impl_desc_type::jit; + } + private: std::shared_ptr pKernel; }; @@ -33,5 +36,5 @@ class JitTransposeExecutorBuilder : public TransposeExecutorBuilder { } }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp b/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp index 441e013af2cbbf..dc58aabe26635d 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.cpp @@ -2,12 +2,13 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/experimental_detectron_detection_output.hpp" + #include #include -#include "openvino/op/experimental_detectron_detection_output.hpp" -#include "openvino/core/parallel.hpp" #include "experimental_detectron_detection_output.h" +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { @@ -36,13 +37,19 @@ struct Indexer { } }; -static -void refine_boxes(const float* boxes, const float* deltas, const float* weights, const float* scores, - float* refined_boxes, float* refined_boxes_areas, float* refined_scores, - const int rois_num, const int classes_num, - const float img_H, const float img_W, - const float max_delta_log_wh, - float coordinates_offset) { +static void refine_boxes(const float* boxes, + const float* deltas, + const float* weights, + const float* scores, + float* refined_boxes, + float* refined_boxes_areas, + float* refined_scores, + const int rois_num, + const int classes_num, + const float img_H, + const float img_W, + const float max_delta_log_wh, + float coordinates_offset) { Indexer box_idx({rois_num, 4}); Indexer delta_idx({rois_num, classes_num, 4}); Indexer score_idx({rois_num, classes_num}); @@ -114,21 +121,22 @@ static bool SortScorePairDescend(const std::pair>& pa return (pair1.first > pair2.first) || ((pair1.first == pair2.first) && (pair1.second.second < pair2.second.second)); } - struct ConfidenceComparator { explicit ConfidenceComparator(const float* conf_data) : _conf_data(conf_data) {} bool operator()(int idx1, int idx2) { - if (_conf_data[idx1] > _conf_data[idx2]) return true; - if (_conf_data[idx1] < _conf_data[idx2]) return false; + if (_conf_data[idx1] > _conf_data[idx2]) + return true; + if (_conf_data[idx1] < _conf_data[idx2]) + return false; return idx1 < idx2; } const float* _conf_data; }; -static inline float JaccardOverlap(const float *decoded_bbox, - const float *bbox_sizes, +static inline float JaccardOverlap(const float* decoded_bbox, + const float* bbox_sizes, const int idx1, const int idx2, const float coordinates_offset = 1) { @@ -151,7 +159,7 @@ static inline float JaccardOverlap(const float *decoded_bbox, float intersect_xmax = (std::min)(xmax1, xmax2); float intersect_ymax = (std::min)(ymax1, ymax2); - float intersect_width = intersect_xmax - intersect_xmin + coordinates_offset; + float intersect_width = 
intersect_xmax - intersect_xmin + coordinates_offset; float intersect_height = intersect_ymax - intersect_ymin + coordinates_offset; if (intersect_width <= 0 || intersect_height <= 0) { @@ -165,7 +173,6 @@ static inline float JaccardOverlap(const float *decoded_bbox, return intersect_size / (bbox1_size + bbox2_size - intersect_size); } - static void nms_cf(const float* conf_data, const float* bboxes, const float* sizes, @@ -187,8 +194,10 @@ static void nms_cf(const float* conf_data, int num_output_scores = (pre_nms_topn == -1 ? count : (std::min)(pre_nms_topn, count)); - std::partial_sort_copy(indices, indices + count, - buffer, buffer + num_output_scores, + std::partial_sort_copy(indices, + indices + count, + buffer, + buffer + num_output_scores, ConfidenceComparator(conf_data)); detections = 0; @@ -221,11 +230,13 @@ bool ExperimentalDetectronDetectionOutput::needPrepareParams() const { return false; } -bool ExperimentalDetectronDetectionOutput::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool ExperimentalDetectronDetectionOutput::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto doOp = ov::as_type_ptr(op); if (!doOp) { - errorMessage = "Node is not an instance of the ExperimentalDetectronDetectionOutput from the operations set v6."; + errorMessage = + "Node is not an instance of the ExperimentalDetectronDetectionOutput from the operations set v6."; return false; } } catch (...) { @@ -294,10 +305,17 @@ void ExperimentalDetectronDetectionOutput::execute(dnnl::stream strm) { Indexer refined_box_idx({classes_num_, rois_num, 4}); Indexer refined_score_idx({classes_num_, rois_num}); - refine_boxes(boxes, deltas, &deltas_weights_[0], scores, - &refined_boxes[0], &refined_boxes_areas[0], &refined_scores[0], - rois_num, classes_num_, - img_H, img_W, + refine_boxes(boxes, + deltas, + &deltas_weights_[0], + scores, + &refined_boxes[0], + &refined_boxes_areas[0], + &refined_scores[0], + rois_num, + classes_num_, + img_H, + img_W, max_delta_log_wh_, 1.0f); @@ -353,7 +371,7 @@ void ExperimentalDetectronDetectionOutput::execute(dnnl::stream strm) { memset(output_classes, 0, max_detections_per_image_ * sizeof(output_classes[0])); int i = 0; - for (const auto & detection : conf_index_class_map) { + for (const auto& detection : conf_index_class_map) { float score = detection.first; int cls = detection.second.first; int idx = detection.second.second; @@ -371,6 +389,6 @@ bool ExperimentalDetectronDetectionOutput::created() const { return getType() == Type::ExperimentalDetectronDetectionOutput; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.h b/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.h index 2f76f1004face5..206f807585de7d 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.h +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_detection_output.h @@ -14,25 +14,27 @@ class ExperimentalDetectronDetectionOutput : public Node { public: ExperimentalDetectronDetectionOutput(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; bool needShapeInfer() const 
override; bool needPrepareParams() const override; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); } + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: - const int INPUT_ROIS {0}; - const int INPUT_DELTAS {1}; - const int INPUT_SCORES {2}; - const int INPUT_IM_INFO {3}; + const int INPUT_ROIS{0}; + const int INPUT_DELTAS{1}; + const int INPUT_SCORES{2}; + const int INPUT_IM_INFO{3}; - const int OUTPUT_BOXES {0}; - const int OUTPUT_CLASSES {1}; - const int OUTPUT_SCORES {2}; + const int OUTPUT_BOXES{0}; + const int OUTPUT_CLASSES{1}; + const int OUTPUT_SCORES{2}; float score_threshold_; float nms_threshold_; @@ -44,6 +46,6 @@ class ExperimentalDetectronDetectionOutput : public Node { std::vector deltas_weights_; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.cpp b/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.cpp index 33f17c8d95f093..778e796aacc11a 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.cpp +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.cpp @@ -2,22 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include #include +#include #include -#include #include -#include +#include #if defined(HAVE_AVX2) -#include +# include #endif -#include "openvino/op/experimental_detectron_generate_proposals.hpp" -#include "openvino/core/parallel.hpp" #include "common/cpu_memcpy.h" #include "experimental_detectron_generate_proposals_single_image.h" +#include "openvino/core/parallel.hpp" +#include "openvino/op/experimental_detectron_generate_proposals.hpp" namespace ov { namespace intel_cpu { @@ -29,20 +29,29 @@ struct Indexer4d { int dim23_; int dim123_; - explicit Indexer4d(int dim0, int dim1, int dim2, int dim3): - dim3_(dim3), dim23_(dim2 * dim3), dim123_(dim1 * dim2 * dim3) { + explicit Indexer4d(int dim0, int dim1, int dim2, int dim3) + : dim3_(dim3), + dim23_(dim2 * dim3), + dim123_(dim1 * dim2 * dim3) { (void)dim0; } int operator()(int i, int j, int k, int n) const { - return i * dim123_ + j * dim23_ + k * dim3_ + n; + return i * dim123_ + j * dim23_ + k * dim3_ + n; } }; -void refine_anchors(const float* deltas, const float* scores, const float* anchors, - float* proposals, const int anchors_num, const int bottom_H, - const int bottom_W, const float img_H, const float img_W, - const float min_box_H, const float min_box_W, +void refine_anchors(const float* deltas, + const float* scores, + const float* anchors, + float* proposals, + const int anchors_num, + const int bottom_H, + const int bottom_W, + const float img_H, + const float img_W, + const float min_box_H, + const float min_box_W, const float max_delta_log_wh, float coordinates_offset) { Indexer4d delta_idx(anchors_num, 4, bottom_H, bottom_W); @@ -108,17 +117,22 @@ void refine_anchors(const float* deltas, const float* scores, const float* ancho void unpack_boxes(const float* p_proposals, float* unpacked_boxes, int pre_nms_topn) { parallel_for(pre_nms_topn, [&](size_t i) { - unpacked_boxes[0*pre_nms_topn + i] = p_proposals[5*i + 0]; - unpacked_boxes[1*pre_nms_topn + i] = p_proposals[5*i + 1]; - unpacked_boxes[2*pre_nms_topn + i] = p_proposals[5*i 
+ 2]; - unpacked_boxes[3*pre_nms_topn + i] = p_proposals[5*i + 3]; - unpacked_boxes[4*pre_nms_topn + i] = p_proposals[5*i + 4]; + unpacked_boxes[0 * pre_nms_topn + i] = p_proposals[5 * i + 0]; + unpacked_boxes[1 * pre_nms_topn + i] = p_proposals[5 * i + 1]; + unpacked_boxes[2 * pre_nms_topn + i] = p_proposals[5 * i + 2]; + unpacked_boxes[3 * pre_nms_topn + i] = p_proposals[5 * i + 3]; + unpacked_boxes[4 * pre_nms_topn + i] = p_proposals[5 * i + 4]; }); } -void nms_cpu(const int num_boxes, int is_dead[], - const float* boxes, int index_out[], int* const num_out, - const int base_index, const float nms_thresh, const int max_num_out, +void nms_cpu(const int num_boxes, + int is_dead[], + const float* boxes, + int index_out[], + int* const num_out, + const int base_index, + const float nms_thresh, + const int max_num_out, float coordinates_offset) { const int num_proposals = num_boxes; int count = 0; @@ -131,9 +145,9 @@ void nms_cpu(const int num_boxes, int is_dead[], std::memset(is_dead, 0, num_boxes * sizeof(int)); #if defined(HAVE_AVX2) - __m256 vc_fone = _mm256_set1_ps(coordinates_offset); + __m256 vc_fone = _mm256_set1_ps(coordinates_offset); __m256i vc_ione = _mm256_set1_epi32(1); - __m256 vc_zero = _mm256_set1_ps(0.0f); + __m256 vc_zero = _mm256_set1_ps(0.0f); __m256 vc_nms_thresh = _mm256_set1_ps(nms_thresh); #endif @@ -154,13 +168,13 @@ void nms_cpu(const int num_boxes, int is_dead[], __m256 vx1i = _mm256_set1_ps(x1[box]); __m256 vy1i = _mm256_set1_ps(y1[box]); - __m256 vA_width = _mm256_sub_ps(vx1i, vx0i); + __m256 vA_width = _mm256_sub_ps(vx1i, vx0i); __m256 vA_height = _mm256_sub_ps(vy1i, vy0i); - __m256 vA_area = _mm256_mul_ps(_mm256_add_ps(vA_width, vc_fone), _mm256_add_ps(vA_height, vc_fone)); + __m256 vA_area = _mm256_mul_ps(_mm256_add_ps(vA_width, vc_fone), _mm256_add_ps(vA_height, vc_fone)); for (; tail <= num_boxes - 8; tail += 8) { - __m256i *pdst = reinterpret_cast<__m256i*>(is_dead + tail); - __m256i vdst = _mm256_loadu_si256(pdst); + __m256i* pdst = reinterpret_cast<__m256i*>(is_dead + tail); + __m256i vdst = _mm256_loadu_si256(pdst); __m256 vx0j = _mm256_loadu_ps(x0 + tail); __m256 vy0j = _mm256_loadu_ps(y0 + tail); @@ -172,13 +186,13 @@ void nms_cpu(const int num_boxes, int is_dead[], __m256 vx1 = _mm256_min_ps(vx1i, vx1j); __m256 vy1 = _mm256_min_ps(vy1i, vy1j); - __m256 vwidth = _mm256_add_ps(_mm256_sub_ps(vx1, vx0), vc_fone); + __m256 vwidth = _mm256_add_ps(_mm256_sub_ps(vx1, vx0), vc_fone); __m256 vheight = _mm256_add_ps(_mm256_sub_ps(vy1, vy0), vc_fone); __m256 varea = _mm256_mul_ps(_mm256_max_ps(vc_zero, vwidth), _mm256_max_ps(vc_zero, vheight)); - __m256 vB_width = _mm256_sub_ps(vx1j, vx0j); + __m256 vB_width = _mm256_sub_ps(vx1j, vx0j); __m256 vB_height = _mm256_sub_ps(vy1j, vy0j); - __m256 vB_area = _mm256_mul_ps(_mm256_add_ps(vB_width, vc_fone), _mm256_add_ps(vB_height, vc_fone)); + __m256 vB_area = _mm256_mul_ps(_mm256_add_ps(vB_width, vc_fone), _mm256_add_ps(vB_height, vc_fone)); __m256 vdivisor = _mm256_sub_ps(_mm256_add_ps(vA_area, vB_area), varea); __m256 vintersection_area = _mm256_div_ps(varea, vdivisor); @@ -219,9 +233,9 @@ void nms_cpu(const int num_boxes, int is_dead[], const float y1 = std::min(y1i, y1j); // intersection area - const float width = std::max(0.0f, x1 - x0 + coordinates_offset); - const float height = std::max(0.0f, y1 - y0 + coordinates_offset); - const float area = width * height; + const float width = std::max(0.0f, x1 - x0 + coordinates_offset); + const float height = std::max(0.0f, y1 - y0 + coordinates_offset); + const float area 
= width * height; // area of A, B const float A_area = (x1i - x0i + coordinates_offset) * (y1i - y0i + coordinates_offset); @@ -239,14 +253,18 @@ void nms_cpu(const int num_boxes, int is_dead[], *num_out = count; } -void fill_output_blobs(const float* proposals, const int* roi_indices, - float* rois, float* scores, - const int num_proposals, const int num_rois, const int post_nms_topn) { - const float *src_x0 = proposals + 0 * num_proposals; - const float *src_y0 = proposals + 1 * num_proposals; - const float *src_x1 = proposals + 2 * num_proposals; - const float *src_y1 = proposals + 3 * num_proposals; - const float *src_score = proposals + 4 * num_proposals; +void fill_output_blobs(const float* proposals, + const int* roi_indices, + float* rois, + float* scores, + const int num_proposals, + const int num_rois, + const int post_nms_topn) { + const float* src_x0 = proposals + 0 * num_proposals; + const float* src_y0 = proposals + 1 * num_proposals; + const float* src_x1 = proposals + 2 * num_proposals; + const float* src_y1 = proposals + 3 * num_proposals; + const float* src_score = proposals + 4 * num_proposals; parallel_for(num_rois, [&](size_t i) { int index = roi_indices[i]; @@ -269,10 +287,11 @@ void fill_output_blobs(const float* proposals, const int* roi_indices, } // namespace -bool ExperimentalDetectronGenerateProposalsSingleImage::isSupportedOperation - (const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool ExperimentalDetectronGenerateProposalsSingleImage::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { - const auto proposalOp = ov::as_type_ptr(op); + const auto proposalOp = + ov::as_type_ptr(op); if (!proposalOp) { errorMessage = "Node is not an instance of the Proposal from the operations set v0."; return false; @@ -313,8 +332,7 @@ void ExperimentalDetectronGenerateProposalsSingleImage::initSupportedPrimitiveDe {LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::f32}}, - {{LayoutType::ncsp, ov::element::f32}, - {LayoutType::ncsp, ov::element::f32}}, + {{LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::f32}}, impl_desc_type::ref_any); } @@ -325,13 +343,13 @@ void ExperimentalDetectronGenerateProposalsSingleImage::execute(dnnl::stream str } size_t anchor_dims_size = 1; - const auto &anchorDims = getParentEdgeAt(INPUT_ANCHORS)->getMemory().getStaticDims(); + const auto& anchorDims = getParentEdgeAt(INPUT_ANCHORS)->getMemory().getStaticDims(); for (size_t i = 0; i < anchorDims.size(); i++) { anchor_dims_size *= anchorDims[i]; } size_t deltas_dims_size = 1; - const auto &deltaDims = getParentEdgeAt(INPUT_DELTAS)->getMemory().getStaticDims(); + const auto& deltaDims = getParentEdgeAt(INPUT_DELTAS)->getMemory().getStaticDims(); for (size_t i = 0; i < deltaDims.size(); i++) { deltas_dims_size *= deltaDims[i]; } @@ -339,7 +357,7 @@ void ExperimentalDetectronGenerateProposalsSingleImage::execute(dnnl::stream str OPENVINO_THROW("'Anchors' blob size for ONNXProposal is incompatible with 'deltas' blob size!"); size_t score_dims_size = 1; - const auto &scoreDims = getParentEdgeAt(INPUT_SCORES)->getMemory().getStaticDims(); + const auto& scoreDims = getParentEdgeAt(INPUT_SCORES)->getMemory().getStaticDims(); for (size_t i = 0; i < scoreDims.size(); i++) { score_dims_size *= scoreDims[i]; } @@ -347,13 +365,13 @@ void ExperimentalDetectronGenerateProposalsSingleImage::execute(dnnl::stream str OPENVINO_THROW("'Deltas' blob size for ONNXProposal is 
incompatible with 'scores' blob size!"); // Prepare memory - const float *p_deltas_item = getSrcDataAtPortAs(INPUT_DELTAS); - const float *p_scores_item = getSrcDataAtPortAs(INPUT_SCORES); - const float *p_anchors_item = getSrcDataAtPortAs(INPUT_ANCHORS); - const float *p_img_info_cpu = getSrcDataAtPortAs(INPUT_IM_INFO); + const float* p_deltas_item = getSrcDataAtPortAs(INPUT_DELTAS); + const float* p_scores_item = getSrcDataAtPortAs(INPUT_SCORES); + const float* p_anchors_item = getSrcDataAtPortAs(INPUT_ANCHORS); + const float* p_img_info_cpu = getSrcDataAtPortAs(INPUT_IM_INFO); - float *p_roi_item = getDstDataAtPortAs(OUTPUT_ROIS); - float *p_roi_score_item = getDstDataAtPortAs(OUTPUT_SCORES); + float* p_roi_item = getDstDataAtPortAs(OUTPUT_ROIS); + float* p_roi_score_item = getDstDataAtPortAs(OUTPUT_SCORES); const int anchors_num = scoreDims[0]; @@ -398,24 +416,45 @@ void ExperimentalDetectronGenerateProposalsSingleImage::execute(dnnl::stream str // Execute int batch_size = 1; // inputs[INPUT_DELTAS]->getTensorDesc().getDims()[0]; for (int n = 0; n < batch_size; ++n) { - refine_anchors(p_deltas_item, p_scores_item, p_anchors_item, - reinterpret_cast(&proposals_[0]), anchors_num, bottom_H, - bottom_W, img_H, img_W, - min_box_H, min_box_W, + refine_anchors(p_deltas_item, + p_scores_item, + p_anchors_item, + reinterpret_cast(&proposals_[0]), + anchors_num, + bottom_H, + bottom_W, + img_H, + img_W, + min_box_H, + min_box_W, static_cast(std::log(1000. / 16.)), 1.0f); - std::partial_sort(proposals_.begin(), proposals_.begin() + pre_nms_topn, proposals_.end(), - [](const ProposalBox &struct1, const ProposalBox &struct2) { + std::partial_sort(proposals_.begin(), + proposals_.begin() + pre_nms_topn, + proposals_.end(), + [](const ProposalBox& struct1, const ProposalBox& struct2) { return (struct1.score > struct2.score); }); - unpack_boxes(reinterpret_cast(&proposals_[0]), &unpacked_boxes[0], pre_nms_topn); - nms_cpu(pre_nms_topn, &is_dead[0], &unpacked_boxes[0], &roi_indices_[0], &num_rois, 0, - nms_thresh_, post_nms_topn_, coordinates_offset); - fill_output_blobs(&unpacked_boxes[0], &roi_indices_[0], p_roi_item, p_roi_score_item, - pre_nms_topn, num_rois, post_nms_topn_); + unpack_boxes(reinterpret_cast(&proposals_[0]), &unpacked_boxes[0], pre_nms_topn); + nms_cpu(pre_nms_topn, + &is_dead[0], + &unpacked_boxes[0], + &roi_indices_[0], + &num_rois, + 0, + nms_thresh_, + post_nms_topn_, + coordinates_offset); + fill_output_blobs(&unpacked_boxes[0], + &roi_indices_[0], + p_roi_item, + p_roi_score_item, + pre_nms_topn, + num_rois, + post_nms_topn_); } - } catch (const std::exception &e) { + } catch (const std::exception& e) { std::string errorMsg = e.what(); OPENVINO_THROW(errorMsg); } @@ -433,6 +472,6 @@ bool ExperimentalDetectronGenerateProposalsSingleImage::needPrepareParams() cons return false; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.h b/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.h index 41aaf63f637e76..d747813e10b258 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.h +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_generate_proposals_single_image.h @@ -13,16 +13,18 @@ namespace node { class ExperimentalDetectronGenerateProposalsSingleImage : public Node { public: 
ExperimentalDetectronGenerateProposalsSingleImage(const std::shared_ptr& op, - const GraphContext::CPtr context); + const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; bool needShapeInfer() const override; bool needPrepareParams() const override; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); } + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: @@ -32,12 +34,12 @@ class ExperimentalDetectronGenerateProposalsSingleImage : public Node { // Outputs: // top_rois, shape [max_rois, 4] - const int INPUT_IM_INFO {0}; - const int INPUT_ANCHORS {1}; - const int INPUT_DELTAS {2}; - const int INPUT_SCORES {3}; - const int OUTPUT_ROIS {0}; - const int OUTPUT_SCORES {1}; + const int INPUT_IM_INFO{0}; + const int INPUT_ANCHORS{1}; + const int INPUT_DELTAS{2}; + const int INPUT_SCORES{3}; + const int OUTPUT_ROIS{0}; + const int OUTPUT_SCORES{1}; float min_size_; int pre_nms_topn_; @@ -48,6 +50,6 @@ class ExperimentalDetectronGenerateProposalsSingleImage : public Node { std::vector roi_indices_; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.cpp b/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.cpp index eead95def0a8fb..f7df0e533778ed 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.cpp +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.cpp @@ -2,20 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include "experimental_detectron_priorgridgenerator.h" #include +#include + #include "openvino/core/parallel.hpp" -#include "experimental_detectron_priorgridgenerator.h" namespace ov { namespace intel_cpu { namespace node { bool ExperimentalDetectronPriorGridGenerator::isSupportedOperation(const std::shared_ptr& op, - std::string& errorMessage) noexcept { + std::string& errorMessage) noexcept { try { - const auto priorGridGen = std::dynamic_pointer_cast(op); + const auto priorGridGen = + std::dynamic_pointer_cast(op); if (!priorGridGen) { errorMessage = "Only opset6 ExperimentalDetectronPriorGridGenerator operation is supported"; return false; @@ -39,7 +41,7 @@ ExperimentalDetectronPriorGridGenerator::ExperimentalDetectronPriorGridGenerator if (getOriginalInputsNumber() != 3 || getOriginalOutputsNumber() != 1) OPENVINO_THROW(errorPrefix, " has incorrect number of input/output edges!"); - const auto &attr = priorGridGen->get_attrs(); + const auto& attr = priorGridGen->get_attrs(); grid_w_ = attr.w; grid_h_ = attr.h; stride_h_ = attr.stride_y; @@ -64,11 +66,15 @@ void ExperimentalDetectronPriorGridGenerator::execute(dnnl::stream strm) { // Execute const int layer_width = grid_w_ ? grid_w_ : getParentEdgeAt(INPUT_FEATUREMAP)->getMemory().getStaticDims()[3]; const int layer_height = grid_h_ ? grid_h_ : getParentEdgeAt(INPUT_FEATUREMAP)->getMemory().getStaticDims()[2]; - const float step_w = stride_w_ ? stride_w_ : static_cast(getParentEdgeAt(INPUT_IMAGE)->getMemory().getStaticDims()[3]) / layer_width; - const float step_h = stride_h_ ? 
stride_h_ : static_cast(getParentEdgeAt(INPUT_IMAGE)->getMemory().getStaticDims()[2]) / layer_height; + const float step_w = + stride_w_ ? stride_w_ + : static_cast(getParentEdgeAt(INPUT_IMAGE)->getMemory().getStaticDims()[3]) / layer_width; + const float step_h = + stride_h_ ? stride_h_ + : static_cast(getParentEdgeAt(INPUT_IMAGE)->getMemory().getStaticDims()[2]) / layer_height; - const auto *bottom_data_0 = getSrcDataAtPortAs(0); - auto *top_data_0 = getDstDataAtPortAs(OUTPUT_ROIS); + const auto* bottom_data_0 = getSrcDataAtPortAs(0); + auto* top_data_0 = getDstDataAtPortAs(OUTPUT_ROIS); for (int h = 0; h < layer_height; ++h) { for (int w = 0; w < layer_width; ++w) { @@ -91,6 +97,6 @@ bool ExperimentalDetectronPriorGridGenerator::needPrepareParams() const { return false; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.h b/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.h index cf52b4e5c9b934..47c2c16dc558b9 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.h +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_priorgridgenerator.h @@ -14,13 +14,15 @@ class ExperimentalDetectronPriorGridGenerator : public Node { public: ExperimentalDetectronPriorGridGenerator(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; bool needPrepareParams() const override; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); } + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: @@ -31,11 +33,11 @@ class ExperimentalDetectronPriorGridGenerator : public Node { // Outputs: // priors_grid, shape [m, 4] - const int INPUT_PRIORS {0}; - const int INPUT_FEATUREMAP {1}; - const int INPUT_IMAGE {2}; + const int INPUT_PRIORS{0}; + const int INPUT_FEATUREMAP{1}; + const int INPUT_IMAGE{2}; - const int OUTPUT_ROIS {0}; + const int OUTPUT_ROIS{0}; int grid_w_; int grid_h_; @@ -45,6 +47,6 @@ class ExperimentalDetectronPriorGridGenerator : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.cpp b/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.cpp index c92e3c2594d4a9..05f2202537f986 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.cpp +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.cpp @@ -2,14 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "experimental_detectron_roifeatureextractor.h" + +#include +#include #include #include -#include -#include -#include "openvino/core/parallel.hpp" #include "common/cpu_memcpy.h" -#include "experimental_detectron_roifeatureextractor.h" +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { @@ -30,31 +31,28 @@ struct PreCalc { }; template -void pre_calc_for_bilinear_interpolate( - const int height, - const int width, - const int pooled_height, - 
const int pooled_width, - const int iy_upper, - const int ix_upper, - T roi_start_h, - T roi_start_w, - T bin_size_h, - T bin_size_w, - int roi_bin_grid_h, - int roi_bin_grid_w, - std::vector>& pre_calc) { +void pre_calc_for_bilinear_interpolate(const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int iy_upper, + const int ix_upper, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + std::vector>& pre_calc) { int pre_calc_index = 0; for (int ph = 0; ph < pooled_height; ph++) { for (int pw = 0; pw < pooled_width; pw++) { for (int iy = 0; iy < iy_upper; iy++) { const T yy = roi_start_h + ph * bin_size_h + - static_cast(iy + .5f) * bin_size_h / - static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + static_cast(iy + .5f) * bin_size_h / static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 for (int ix = 0; ix < ix_upper; ix++) { const T xx = roi_start_w + pw * bin_size_w + - static_cast(ix + .5f) * bin_size_w / - static_cast(roi_bin_grid_w); + static_cast(ix + .5f) * bin_size_w / static_cast(roi_bin_grid_w); T x = xx; T y = yy; @@ -126,19 +124,18 @@ void pre_calc_for_bilinear_interpolate( } template -void ROIAlignForward_cpu_kernel( - const int nthreads, - const T* bottom_data, - const T& spatial_scale, - const int channels, - const int height, - const int width, - const int pooled_height, - const int pooled_width, - const int sampling_ratio, - const T* bottom_rois, - const bool aligned, - T* top_data) { +void ROIAlignForward_cpu_kernel(const int nthreads, + const T* bottom_data, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* bottom_rois, + const bool aligned, + T* top_data) { int roi_cols = 4; int n_rois = nthreads / channels / pooled_width / pooled_height; @@ -168,38 +165,33 @@ void ROIAlignForward_cpu_kernel( T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); // We use roi_bin_grid to sample the grid and mimic integral - int roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceil(roi_height / pooled_height)); // e.g., = 2 - int roi_bin_grid_w = - (sampling_ratio > 0) ? sampling_ratio : static_cast(ceil(roi_width / pooled_width)); + int roi_bin_grid_h = + (sampling_ratio > 0) ? sampling_ratio : static_cast(ceil(roi_height / pooled_height)); // e.g., = 2 + int roi_bin_grid_w = (sampling_ratio > 0) ? sampling_ratio : static_cast(ceil(roi_width / pooled_width)); // We do average (integral) pooling inside a bin const T count = static_cast(roi_bin_grid_h * roi_bin_grid_w); // e.g. 
= 4 // we want to precalculate indices and weights shared by all channels, // this is the key point of optimization - std::vector> pre_calc( - roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); - pre_calc_for_bilinear_interpolate( - height, - width, - pooled_height, - pooled_width, - roi_bin_grid_h, - roi_bin_grid_w, - roi_start_h, - roi_start_w, - bin_size_h, - bin_size_w, - roi_bin_grid_h, - roi_bin_grid_w, - pre_calc); + std::vector> pre_calc(roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + pre_calc_for_bilinear_interpolate(height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + pre_calc); for (int c = 0; c < channels; c++) { int index_n_c = index_n + c * pooled_width * pooled_height; - const T* offset_bottom_data = - bottom_data + (roi_batch_ind * channels + c) * height * width; + const T* offset_bottom_data = bottom_data + (roi_batch_ind * channels + c) * height * width; int pre_calc_index = 0; for (int ph = 0; ph < pooled_height; ph++) { @@ -210,10 +202,8 @@ void ROIAlignForward_cpu_kernel( for (int iy = 0; iy < roi_bin_grid_h; iy++) { for (int ix = 0; ix < roi_bin_grid_w; ix++) { PreCalc pc = pre_calc[pre_calc_index]; - output_val += pc.w1 * offset_bottom_data[pc.pos1] + - pc.w2 * offset_bottom_data[pc.pos2] + - pc.w3 * offset_bottom_data[pc.pos3] + - pc.w4 * offset_bottom_data[pc.pos4]; + output_val += pc.w1 * offset_bottom_data[pc.pos1] + pc.w2 * offset_bottom_data[pc.pos2] + + pc.w3 * offset_bottom_data[pc.pos3] + pc.w4 * offset_bottom_data[pc.pos4]; pre_calc_index += 1; } @@ -222,14 +212,12 @@ void ROIAlignForward_cpu_kernel( top_data[index] = output_val; } // for pw - } // for ph - } // for c + } // for ph + } // for c }); } - -void redistribute_rois(const float* rois, int* level_ids, - const int num_rois, const int levels_num) { +void redistribute_rois(const float* rois, int* level_ids, const int num_rois, const int levels_num) { const float canonical_scale = 224.0f; const int canonical_level = 2; @@ -252,11 +240,11 @@ void redistribute_rois(const float* rois, int* level_ids, } } - -void reord(const float* src_data, const int* ranks, const int n, const int step, float* dst_data, - int* dst_mapping) { +void reord(const float* src_data, const int* ranks, const int n, const int step, float* dst_data, int* dst_mapping) { std::iota(dst_mapping, dst_mapping + n, 0); - std::sort(dst_mapping, dst_mapping + n, [&ranks](size_t i1, size_t i2) {return ranks[i1] < ranks[i2];}); + std::sort(dst_mapping, dst_mapping + n, [&ranks](size_t i1, size_t i2) { + return ranks[i1] < ranks[i2]; + }); for (int i = 0; i < n; ++i) { const int j = dst_mapping[i]; assert(0 <= j && j < n); @@ -277,12 +265,13 @@ void split_points(const std::vector& ids, std::vector& rois_per_level, rois_per_level.insert(rois_per_level.begin(), 0); } -} // namespace +} // namespace bool ExperimentalDetectronROIFeatureExtractor::isSupportedOperation(const std::shared_ptr& op, - std::string& errorMessage) noexcept { + std::string& errorMessage) noexcept { try { - const auto roiFeatureExtractor = std::dynamic_pointer_cast(op); + const auto roiFeatureExtractor = + std::dynamic_pointer_cast(op); if (!roiFeatureExtractor) { errorMessage = "Only opset6 ExperimentalDetectronROIFeatureExtractor operation is supported"; return false; @@ -301,8 +290,9 @@ ExperimentalDetectronROIFeatureExtractor::ExperimentalDetectronROIFeatureExtract OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); } - 
const auto roiFeatureExtractor = std::dynamic_pointer_cast(op); - const auto &attr = roiFeatureExtractor->get_attrs(); + const auto roiFeatureExtractor = + std::dynamic_pointer_cast(op); + const auto& attr = roiFeatureExtractor->get_attrs(); output_dim_ = attr.output_size; pyramid_scales_ = attr.pyramid_scales; sampling_ratio_ = attr.sampling_ratio; @@ -321,8 +311,7 @@ void ExperimentalDetectronROIFeatureExtractor::initSupportedPrimitiveDescriptors inDataConf.emplace_back(LayoutType::ncsp, ov::element::f32); addSupportedPrimDesc(inDataConf, - {{LayoutType::ncsp, ov::element::f32}, - {LayoutType::ncsp, ov::element::f32}}, + {{LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::f32}}, impl_desc_type::ref_any); } @@ -332,15 +321,15 @@ void ExperimentalDetectronROIFeatureExtractor::execute(dnnl::stream strm) { const int channels_num = getParentEdgeAt(INPUT_FEATURES_START)->getMemory().getStaticDims()[1]; const int feaxels_per_roi = pooled_height_ * pooled_width_ * channels_num; - auto *input_rois = getSrcDataAtPortAs(INPUT_ROIS); - auto *output_rois_features = getDstDataAtPortAs(OUTPUT_ROI_FEATURES); - float *output_rois = nullptr; + auto* input_rois = getSrcDataAtPortAs(INPUT_ROIS); + auto* output_rois_features = getDstDataAtPortAs(OUTPUT_ROI_FEATURES); + float* output_rois = nullptr; if (OUTPUT_ROIS < outputShapes.size()) { output_rois = getDstDataAtPortAs(OUTPUT_ROIS); } std::vector level_ids(num_rois, 0); - redistribute_rois(input_rois, reinterpret_cast(&level_ids[0]), num_rois, levels_num); + redistribute_rois(input_rois, reinterpret_cast(&level_ids[0]), num_rois, levels_num); std::vector reordered_rois(4 * num_rois, 0); std::vector original_rois_mapping(num_rois, 0); @@ -354,7 +343,7 @@ void ExperimentalDetectronROIFeatureExtractor::execute(dnnl::stream strm) { const int level_rois_offset = rois_per_level[i]; const int level_rois_num = rois_per_level[i + 1] - level_rois_offset; if (level_rois_num > 0) { - auto *featuremap = getSrcDataAtPortAs(INPUT_FEATURES_START + i); + auto* featuremap = getSrcDataAtPortAs(INPUT_FEATURES_START + i); const int featuremap_height = getParentEdgeAt(INPUT_FEATURES_START + i)->getMemory().getStaticDims()[2]; const int featuremap_width = getParentEdgeAt(INPUT_FEATURES_START + i)->getMemory().getStaticDims()[3]; ROIAlignForward_cpu_kernel(feaxels_per_roi * level_rois_num, @@ -373,8 +362,12 @@ void ExperimentalDetectronROIFeatureExtractor::execute(dnnl::stream strm) { } std::vector dummy_mapping(num_rois, 0); - reord(&output_rois_features_temp[0], &original_rois_mapping[0], num_rois, feaxels_per_roi, - output_rois_features, &dummy_mapping[0]); + reord(&output_rois_features_temp[0], + &original_rois_mapping[0], + num_rois, + feaxels_per_roi, + output_rois_features, + &dummy_mapping[0]); if (output_rois != nullptr) { cpu_memcpy(output_rois, input_rois, 4 * num_rois * sizeof(float)); } @@ -384,6 +377,6 @@ bool ExperimentalDetectronROIFeatureExtractor::created() const { return getType() == Type::ExperimentalDetectronROIFeatureExtractor; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.h b/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.h index 94bfdfd224d0c5..374fd62c61b776 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.h +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_roifeatureextractor.h @@ 
-14,22 +14,26 @@ class ExperimentalDetectronROIFeatureExtractor : public Node { public: ExperimentalDetectronROIFeatureExtractor(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; - bool needPrepareParams() const override { return false; }; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); }; + bool needPrepareParams() const override { + return false; + }; + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + }; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; private: - const int INPUT_ROIS {0}; - const int INPUT_FEATURES_START {1}; + const int INPUT_ROIS{0}; + const int INPUT_FEATURES_START{1}; - const int OUTPUT_ROI_FEATURES {0}; - const size_t OUTPUT_ROIS {1}; + const int OUTPUT_ROI_FEATURES{0}; + const size_t OUTPUT_ROIS{1}; int output_dim_ = 0; int pooled_height_ = 0; @@ -39,6 +43,6 @@ class ExperimentalDetectronROIFeatureExtractor : public Node { bool aligned_ = false; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.cpp b/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.cpp index 46b60fcdb83efd..f09d96ac7a7f7e 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.cpp +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.cpp @@ -2,20 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "experimental_detectron_topkrois.h" + +#include +#include #include #include -#include -#include -#include "openvino/core/parallel.hpp" #include "common/cpu_memcpy.h" -#include "experimental_detectron_topkrois.h" +#include "openvino/core/parallel.hpp" namespace ov { namespace intel_cpu { namespace node { -bool ExperimentalDetectronTopKROIs::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool ExperimentalDetectronTopKROIs::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { const auto topKROI = std::dynamic_pointer_cast(op); if (!topKROI) { @@ -56,8 +58,7 @@ void ExperimentalDetectronTopKROIs::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; - addSupportedPrimDesc({{LayoutType::ncsp, ov::element::f32}, - {LayoutType::ncsp, ov::element::f32}}, + addSupportedPrimDesc({{LayoutType::ncsp, ov::element::f32}, {LayoutType::ncsp, ov::element::f32}}, {{LayoutType::ncsp, ov::element::f32}}, impl_desc_type::ref_any); } @@ -66,14 +67,16 @@ void ExperimentalDetectronTopKROIs::execute(dnnl::stream strm) { const int input_rois_num = getParentEdgeAt(INPUT_ROIS)->getMemory().getStaticDims()[0]; const int top_rois_num = (std::min)(max_rois_num_, input_rois_num); - auto *input_rois = getSrcDataAtPortAs(INPUT_ROIS); - auto *input_probs = getSrcDataAtPortAs(INPUT_PROBS); - auto *output_rois = getDstDataAtPortAs(OUTPUT_ROIS); + auto* input_rois = getSrcDataAtPortAs(INPUT_ROIS); + auto* input_probs = getSrcDataAtPortAs(INPUT_PROBS); + auto* output_rois = getDstDataAtPortAs(OUTPUT_ROIS); std::vector idx(input_rois_num); iota(idx.begin(), idx.end(), 0); // FIXME. partial_sort is enough here. 
- sort(idx.begin(), idx.end(), [&input_probs](size_t i1, size_t i2) {return input_probs[i1] > input_probs[i2];}); + sort(idx.begin(), idx.end(), [&input_probs](size_t i1, size_t i2) { + return input_probs[i1] > input_probs[i2]; + }); for (int i = 0; i < top_rois_num; ++i) { cpu_memcpy(output_rois + 4 * i, input_rois + 4 * idx[i], 4 * sizeof(float)); @@ -84,6 +87,6 @@ bool ExperimentalDetectronTopKROIs::created() const { return getType() == Type::ExperimentalDetectronTopKROIs; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.h b/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.h index 5c2db1fa2303ea..3fe134948d5e45 100644 --- a/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.h +++ b/src/plugins/intel_cpu/src/nodes/experimental_detectron_topkrois.h @@ -14,14 +14,20 @@ class ExperimentalDetectronTopKROIs : public Node { public: ExperimentalDetectronTopKROIs(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; - bool needShapeInfer() const override { return false; }; - bool needPrepareParams() const override { return false; }; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); }; + bool needShapeInfer() const override { + return false; + }; + bool needPrepareParams() const override { + return false; + }; + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + }; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; @@ -32,15 +38,15 @@ class ExperimentalDetectronTopKROIs : public Node { // Outputs: // top_rois, shape [max_rois, 4] - const int INPUT_ROIS {0}; - const int INPUT_PROBS {1}; + const int INPUT_ROIS{0}; + const int INPUT_PROBS{1}; - const int OUTPUT_ROIS {0}; + const int OUTPUT_ROIS{0}; int max_rois_num_; std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/extract_image_patches.cpp b/src/plugins/intel_cpu/src/nodes/extract_image_patches.cpp index 8b5d0b510614e1..51ae2123bbd382 100644 --- a/src/plugins/intel_cpu/src/nodes/extract_image_patches.cpp +++ b/src/plugins/intel_cpu/src/nodes/extract_image_patches.cpp @@ -3,15 +3,16 @@ // #include "extract_image_patches.h" -#include "common/primitive_hashing_utils.hpp" -#include "cpu/x64/jit_generator.hpp" -#include "openvino/core/parallel.hpp" -#include "openvino/opsets/opset3.hpp" #include #include #include +#include "common/primitive_hashing_utils.hpp" +#include "cpu/x64/jit_generator.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/opsets/opset3.hpp" + using namespace dnnl::impl::cpu; using namespace dnnl::impl::cpu::x64; using namespace dnnl::impl::utils; @@ -21,13 +22,15 @@ namespace ov { namespace intel_cpu { namespace node { #if defined(OPENVINO_ARCH_X86_64) -#define GET_OFF(field) offsetof(jit_extract_image_patches_args, field) +# define GET_OFF(field) offsetof(jit_extract_image_patches_args, field) template struct jit_extract_image_patches_kernel : public jit_uni_extract_image_patches_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_extract_image_patches_kernel) - 
explicit jit_extract_image_patches_kernel(jit_extract_image_patches_params jpp) : jit_uni_extract_image_patches_kernel(jpp), jit_generator(jit_name()) {} + explicit jit_extract_image_patches_kernel(jit_extract_image_patches_params jpp) + : jit_uni_extract_image_patches_kernel(jpp), + jit_generator(jit_name()) {} void create_ker() override { jit_generator::create_kernel(); @@ -92,35 +95,47 @@ struct jit_extract_image_patches_kernel : public jit_uni_extract_image_patches_k Vmm vmm = Vmm(0); Xmm xmm = Xmm(0); - Vmm vmm_zero = Vmm(1); // reserved for pad + Vmm vmm_zero = Vmm(1); // reserved for pad Xbyak::Xmm xmm_aux = Xbyak::Xmm(2); Vmm vmm_gather_index = Vmm(3); Vmm vmm_gather_mask = Vmm(4); Opmask k_mask = Xbyak::Opmask(1); Xbyak::Label gather_index_table; - inline void load_scalar(Vmm vmm_arg, const Xbyak::Address &op) { + inline void load_scalar(Vmm vmm_arg, const Xbyak::Address& op) { Xbyak::Xmm xmm_src = Xmm(vmm_arg.getIdx()); switch (jpp.dtype_size) { - case 4: uni_vmovss(vmm_arg, op); break; - case 2: uni_vpinsrw(xmm_src, xmm_src, op, 0x0); break; - case 1: uni_vpinsrb(xmm_src, xmm_src, op, 0x0); break; - default: - OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); + case 4: + uni_vmovss(vmm_arg, op); + break; + case 2: + uni_vpinsrw(xmm_src, xmm_src, op, 0x0); + break; + case 1: + uni_vpinsrb(xmm_src, xmm_src, op, 0x0); + break; + default: + OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); } } - inline void store_scalar(const Xbyak::Address &op, Vmm vmm_arg) { + inline void store_scalar(const Xbyak::Address& op, Vmm vmm_arg) { Xbyak::Xmm xmm_dst = Xmm(vmm_arg.getIdx()); switch (jpp.dtype_size) { - case 4: uni_vmovss(op, vmm_arg); break; - case 2: uni_vpextrw(op, xmm_dst, 0x0); break; - case 1: uni_vpextrb(op, xmm_dst, 0x0); break; - default: - OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); + case 4: + uni_vmovss(op, vmm_arg); + break; + case 2: + uni_vpextrw(op, xmm_dst, 0x0); + break; + case 1: + uni_vpextrb(op, xmm_dst, 0x0); + break; + default: + OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); } } - inline void pad_with_zeros(reg64_t ®_num_pads_arg, reg64_t ®_dst_arg) { + inline void pad_with_zeros(reg64_t& reg_num_pads_arg, reg64_t& reg_dst_arg) { Xbyak::Label main, tail, exit; L(main); { @@ -143,57 +158,67 @@ struct jit_extract_image_patches_kernel : public jit_uni_extract_image_patches_k L(exit); } - inline void custom_uni_vgatherdps(const Vmm &vmm_arg, reg64_t &mem_base, const Vmm &mem_offset, Vmm &vmm_mask) { + inline void custom_uni_vgatherdps(const Vmm& vmm_arg, reg64_t& mem_base, const Vmm& mem_offset, Vmm& vmm_mask) { switch (isa) { - case x64::avx2: - uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); - vgatherdps(vmm_arg, ptr[mem_base + mem_offset], vmm_mask); - break; - case x64::avx512_core: - kxnord(k_mask, k_mask, k_mask); - vgatherdps(vmm_arg | k_mask, ptr[mem_base + mem_offset]); - break; - case x64::sse41: - emulate_gather(vmm_arg, mem_base); - break; - default: - OPENVINO_THROW("Got unsupported instruction set."); + case x64::avx2: + uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); + vgatherdps(vmm_arg, ptr[mem_base + mem_offset], vmm_mask); + break; + case x64::avx512_core: + kxnord(k_mask, k_mask, k_mask); + vgatherdps(vmm_arg | k_mask, ptr[mem_base + mem_offset]); + break; + case x64::sse41: + emulate_gather(vmm_arg, mem_base); + break; + default: + OPENVINO_THROW("Got unsupported instruction set."); } } - inline void gather_src2vmm(const 
Vmm &vmm_arg, reg64_t &mem_base) { + inline void gather_src2vmm(const Vmm& vmm_arg, reg64_t& mem_base) { switch (jpp.dtype_size) { - case 4: custom_uni_vgatherdps(vmm, mem_base, vmm_gather_index, vmm_gather_mask); break; - case 2: - case 1: emulate_gather(vmm_arg, mem_base); break; - default: - OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); + case 4: + custom_uni_vgatherdps(vmm, mem_base, vmm_gather_index, vmm_gather_mask); + break; + case 2: + case 1: + emulate_gather(vmm_arg, mem_base); + break; + default: + OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); } } - inline void emulate_gather(const Xbyak::Xmm &xmm_arg, reg64_t &mem_base, int xmm_offset = 0) { - const int xmm_size = 16; // bytes + inline void emulate_gather(const Xbyak::Xmm& xmm_arg, reg64_t& mem_base, int xmm_offset = 0) { + const int xmm_size = 16; // bytes const int xmm_block_size = xmm_size / jpp.dtype_size; const int offset = xmm_offset * jpp.SW * jpp.dtype_size * xmm_block_size; for (int i = 0; i < xmm_block_size; i++) { Xbyak::Address addr = ptr[mem_base + i * jpp.SW * jpp.dtype_size + offset]; switch (jpp.dtype_size) { - case 4: uni_vpinsrd(xmm_arg, xmm_arg, addr, i); break; - case 2: uni_vpinsrw(xmm_arg, xmm_arg, addr, i); break; - case 1: uni_vpinsrb(xmm_arg, xmm_arg, addr, i); break; - default: - OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); + case 4: + uni_vpinsrd(xmm_arg, xmm_arg, addr, i); + break; + case 2: + uni_vpinsrw(xmm_arg, xmm_arg, addr, i); + break; + case 1: + uni_vpinsrb(xmm_arg, xmm_arg, addr, i); + break; + default: + OPENVINO_THROW("The data type of size '", jpp.dtype_size, "' is not supported."); } } } - inline void emulate_gather(const Xbyak::Ymm &ymm_arg, reg64_t &mem_base) { + inline void emulate_gather(const Xbyak::Ymm& ymm_arg, reg64_t& mem_base) { Xbyak::Xmm low_xmm = Xbyak::Xmm(ymm_arg.getIdx()); emulate_gather(low_xmm, mem_base, 0); emulate_gather(xmm_aux, mem_base, 1); vinserti128(ymm_arg, ymm_arg, xmm_aux, 1); } - inline void emulate_gather(const Xbyak::Zmm &zmm_arg, reg64_t &mem_base) { + inline void emulate_gather(const Xbyak::Zmm& zmm_arg, reg64_t& mem_base) { Xbyak::Xmm low_xmm = Xbyak::Xmm(zmm_arg.getIdx()); emulate_gather(low_xmm, mem_base, 0); for (int i = 1; i < 4; i++) { @@ -270,9 +295,10 @@ struct jit_extract_image_patches_kernel : public jit_uni_extract_image_patches_k dd(i * jpp.SW * jpp.dtype_size); } }; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 -bool ExtractImagePatches::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool ExtractImagePatches::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { auto extImgPatcher = ov::as_type_ptr(op); if (!extImgPatcher) { @@ -284,7 +310,10 @@ bool ExtractImagePatches::isSupportedOperation(const std::shared_ptrget_sizes().size(), extImgPatcher->get_strides().size(), extImgPatcher->get_rates().size())) { + if (!everyone_is(2u, + extImgPatcher->get_sizes().size(), + extImgPatcher->get_strides().size(), + extImgPatcher->get_rates().size())) { errorMessage = "Doesn't support 'sizes', 'strides', 'rates', attributes with rank != 2"; return false; } @@ -323,7 +352,7 @@ size_t ExtractImagePatchesKey::hash() const { bool ExtractImagePatchesKey::operator==(const ExtractImagePatchesKey& rhs) const { bool result = inDims == rhs.inDims && outDims == rhs.outDims && kSizes == rhs.kSizes && strides == rhs.strides && - rates == rhs.rates && padType == 
rhs.padType && prcSize == rhs.prcSize; + rates == rhs.rates && padType == rhs.padType && prcSize == rhs.prcSize; return result; } } // namespace @@ -362,7 +391,7 @@ ExtractImagePatches::ExtractImagePatches(const std::shared_ptr& op, co OPENVINO_THROW(errorPrefix, "has unsupported pad type: ", extImgPatcher->get_auto_pad()); } - _ksizes = extImgPatcher->get_sizes();; + _ksizes = extImgPatcher->get_sizes(); _strides = extImgPatcher->get_strides(); _rates = extImgPatcher->get_rates(); if (_ksizes.size() != 2 || _strides.size() != 2 || _rates.size() != 2) @@ -416,9 +446,7 @@ void ExtractImagePatches::initSupportedPrimitiveDescriptors() { if (_supported_precisions_sizes.find(precision.size()) == _supported_precisions_sizes.end()) OPENVINO_THROW(errorPrefix, "has unsupported precision: ", precision.get_type_name()); - addSupportedPrimDesc({{LayoutType::ncsp, precision}}, - {{LayoutType::ncsp, precision}}, - impl_desc_type::ref_any); + addSupportedPrimDesc({{LayoutType::ncsp, precision}}, {{LayoutType::ncsp, precision}}, impl_desc_type::ref_any); } void ExtractImagePatches::execute(dnnl::stream strm) { @@ -437,12 +465,17 @@ void ExtractImagePatches::executeDynamicImpl(dnnl::stream strm) { execute(strm); } -void ExtractImagePatches::ExtractImagePatchesRefExecutor::executeReference( - void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) const { +void ExtractImagePatches::ExtractImagePatchesRefExecutor::executeReference(void* src, + void* dst, + const VectorDims& istrides, + const VectorDims& ostrides) const { const char* src_data = reinterpret_cast(src); char* dst_data = reinterpret_cast(dst); - const std::vector ostrides_partial = { ostrides[0], jpp.KW * IC * ostrides[1], IC * ostrides[1], ostrides[1] }; + const std::vector ostrides_partial = {ostrides[0], + jpp.KW * IC * ostrides[1], + IC * ostrides[1], + ostrides[1]}; parallel_for4d(OB, jpp.KH, jpp.KW, IC, [&](const size_t ob, const size_t kh, const size_t kw, const size_t ic) { const int64_t iw_start = static_cast(kw * RW) - PL; @@ -450,12 +483,17 @@ void ExtractImagePatches::ExtractImagePatchesRefExecutor::executeReference( const size_t ih_lpad = ih_start >= 0 ? 0 : std::ceil(-1.f * ih_start / jpp.SH); const size_t iw_lpad = iw_start >= 0 ? 0 : std::ceil(-1.f * iw_start / jpp.SW); - const size_t ih_hpad = std::ceil((IH - 1.f * ih_start) / jpp.SH) > jpp.OH ? jpp.OH : std::ceil((IH + -1.f * ih_start) / jpp.SH); - const size_t iw_hpad = std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW) > jpp.OW ? jpp.OW : std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW); + const size_t ih_hpad = + std::ceil((IH - 1.f * ih_start) / jpp.SH) > jpp.OH ? jpp.OH : std::ceil((IH + -1.f * ih_start) / jpp.SH); + const size_t iw_hpad = std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW) > jpp.OW + ? 
jpp.OW
+                                   : std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW);
 
-        char* my_dst_ptr = dst_data +
-            (ob * ostrides_partial[0] + kh * ostrides_partial[1] + kw * ostrides_partial[2] + ic * ostrides_partial[3]) * jpp.dtype_size;
-        const char* my_src_ptr = src_data + (ob * istrides[0] + ic * istrides[1] + ih_start * istrides[2] + iw_start) * jpp.dtype_size;
+        char* my_dst_ptr = dst_data + (ob * ostrides_partial[0] + kh * ostrides_partial[1] + kw * ostrides_partial[2] +
+                                       ic * ostrides_partial[3]) *
+                                          jpp.dtype_size;
+        const char* my_src_ptr =
+            src_data + (ob * istrides[0] + ic * istrides[1] + ih_start * istrides[2] + iw_start) * jpp.dtype_size;
 
         size_t num_bytes_to_set = ih_lpad * jpp.OW * jpp.dtype_size;
         memset(my_dst_ptr, 0, num_bytes_to_set);
@@ -463,14 +501,15 @@
         const char* src_ptr_h_stop = my_src_ptr + ih_hpad * jpp.SH * jpp.IW * jpp.dtype_size;
         for (const char* src_h_ptr = my_src_ptr + ih_lpad * jpp.SH * jpp.IW * jpp.dtype_size;
-             src_h_ptr < src_ptr_h_stop; src_h_ptr += jpp.SH * jpp.IW * jpp.dtype_size) {
+             src_h_ptr < src_ptr_h_stop;
+             src_h_ptr += jpp.SH * jpp.IW * jpp.dtype_size) {
             num_bytes_to_set = iw_lpad * jpp.dtype_size;
             memset(my_dst_ptr, 0, num_bytes_to_set);
             my_dst_ptr += num_bytes_to_set;
 
             const char* src_ptr_w_stop = src_h_ptr + iw_hpad * jpp.SW * jpp.dtype_size;
-            for (const char* src_w_ptr = src_h_ptr + iw_lpad * jpp.SW * jpp.dtype_size;
-                 src_w_ptr < src_ptr_w_stop; src_w_ptr += jpp.SW * jpp.dtype_size) {
+            for (const char* src_w_ptr = src_h_ptr + iw_lpad * jpp.SW * jpp.dtype_size; src_w_ptr < src_ptr_w_stop;
+                 src_w_ptr += jpp.SW * jpp.dtype_size) {
                 num_bytes_to_set = jpp.dtype_size;
                 memcpy(my_dst_ptr, src_w_ptr, num_bytes_to_set);
                 my_dst_ptr += num_bytes_to_set;
@@ -484,25 +523,35 @@
     });
 }
 
-void ExtractImagePatches::ExtractImagePatchesJitExecutor::executeOptimizedGeneric(
-    void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) const {
+void ExtractImagePatches::ExtractImagePatchesJitExecutor::executeOptimizedGeneric(void* src,
+                                                                                  void* dst,
+                                                                                  const VectorDims& istrides,
+                                                                                  const VectorDims& ostrides) const {
 #if defined(OPENVINO_ARCH_X86_64)
     const char* src_data = reinterpret_cast<const char*>(src);
     char* dst_data = reinterpret_cast<char*>(dst);
     const auto& jpp = pKernel->jpp;
 
-    const std::vector<size_t> ostrides_partial = { ostrides[0], jpp.KW * IC * ostrides[1], IC * ostrides[1], ostrides[1] };
+    const std::vector<size_t> ostrides_partial = {ostrides[0],
+                                                  jpp.KW * IC * ostrides[1],
+                                                  IC * ostrides[1],
+                                                  ostrides[1]};
 
     parallel_for4d(OB, jpp.KH, jpp.KW, IC, [&](const size_t ob, const size_t kh, const size_t kw, const size_t ic) {
         const int64_t ih_start = kh * RH - PT;
         const int64_t iw_start = kw * RW - PL;
         const size_t ih_lpad = ih_start >= 0 ? 0 : std::ceil(-1.f * ih_start / jpp.SH);
         const size_t iw_lpad = iw_start >= 0 ? 0 : std::ceil(-1.f * iw_start / jpp.SW);
-        const size_t ih_hpad = std::ceil((IH - 1.f * ih_start) / jpp.SH) > jpp.OH ? jpp.OH : std::ceil((IH - 1.f * ih_start) / jpp.SH);
-        const size_t iw_hpad = std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW) > jpp.OW ? jpp.OW : std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW);
+        const size_t ih_hpad =
+            std::ceil((IH - 1.f * ih_start) / jpp.SH) > jpp.OH ? jpp.OH : std::ceil((IH - 1.f * ih_start) / jpp.SH);
+        const size_t iw_hpad = std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW) > jpp.OW
+                                   ?
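// ---- editorial aside (the diff resumes below) ----
// A standalone scalar sketch of the lpad/hpad bookkeeping used by both
// executors: for one kernel tap starting at (possibly negative) input offset
// i_start with stride S, lpad counts the leading output positions that fall
// entirely in the zero padding, while hpad bounds the positions whose reads
// still start inside an input extent of IW. The helper mirrors the formulas
// in the diff; the names are illustrative, not part of the patch.
#include <cmath>
#include <cstddef>
#include <cstdint>

struct AxisPads {
    size_t lpad;  // outputs [0, lpad) are zero-filled
    size_t hpad;  // outputs [lpad, hpad) copy real data; [hpad, OW) are zero-filled
};

inline AxisPads axis_pads(int64_t i_start, size_t IW, size_t S, size_t OW) {
    const size_t lpad = i_start >= 0 ? 0 : static_cast<size_t>(std::ceil(-1.f * i_start / S));
    const float in_range = std::ceil((IW - 1.f * i_start) / S);
    const size_t hpad = in_range > OW ? OW : static_cast<size_t>(in_range);
    return {lpad, hpad};
}
// ---- end of aside ----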
jpp.OW + : std::ceil((jpp.IW - 1.f * iw_start) / jpp.SW); - size_t dst_offset = ob * ostrides_partial[0] + kh * ostrides_partial[1] + kw * ostrides_partial[2] + ic * ostrides_partial[3]; - size_t src_offset = ob * istrides[0] + ic * istrides[1] + ih_start * istrides[2] + iw_start + ih_lpad * jpp.SH * jpp.IW; + size_t dst_offset = + ob * ostrides_partial[0] + kh * ostrides_partial[1] + kw * ostrides_partial[2] + ic * ostrides_partial[3]; + size_t src_offset = + ob * istrides[0] + ic * istrides[1] + ih_start * istrides[2] + iw_start + ih_lpad * jpp.SH * jpp.IW; auto args = jit_extract_image_patches_args(); args.src = src_data + src_offset * jpp.dtype_size; @@ -513,7 +562,7 @@ void ExtractImagePatches::ExtractImagePatchesJitExecutor::executeOptimizedGeneri args.w_hi_pad = iw_hpad; (*pKernel)(&args); }); -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 } jit_extract_image_patches_params ExtractImagePatches::ExtractImagePatchesExecutor::fillJpp( @@ -584,14 +633,13 @@ jit_extract_image_patches_params ExtractImagePatches::ExtractImagePatchesExecuto return jpp; } -ExtractImagePatches::ExtractImagePatchesJitExecutor::ExtractImagePatchesJitExecutor( - const VectorDims& inDims, - const VectorDims& outDims, - const VectorDims& kSizes, - const VectorDims& strides, - const VectorDims& rates, - const ExtImgPatcherPadType& padType, - const size_t prcSize) { +ExtractImagePatches::ExtractImagePatchesJitExecutor::ExtractImagePatchesJitExecutor(const VectorDims& inDims, + const VectorDims& outDims, + const VectorDims& kSizes, + const VectorDims& strides, + const VectorDims& rates, + const ExtImgPatcherPadType& padType, + const size_t prcSize) { #if defined(OPENVINO_ARCH_X86_64) auto jpp = fillJpp(inDims, outDims, kSizes, strides, rates, padType, prcSize); if (mayiuse(x64::avx512_core)) { @@ -606,27 +654,31 @@ ExtractImagePatches::ExtractImagePatchesJitExecutor::ExtractImagePatchesJitExecu if (pKernel) pKernel->create_ker(); -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 } -void ExtractImagePatches::ExtractImagePatchesJitExecutor::exec( - void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) { +void ExtractImagePatches::ExtractImagePatchesJitExecutor::exec(void* src, + void* dst, + const VectorDims& istrides, + const VectorDims& ostrides) { if (!pKernel) OPENVINO_THROW("Can't execute, kernel for extract image patches node is not compiled"); executeOptimizedGeneric(src, dst, istrides, ostrides); } -ExtractImagePatches::ExtractImagePatchesRefExecutor::ExtractImagePatchesRefExecutor( - const VectorDims& inDims, - const VectorDims& outDims, - const VectorDims& kSizes, - const VectorDims& strides, - const VectorDims& rates, - const ExtImgPatcherPadType& padType, - const size_t prcSize) : jpp(fillJpp(inDims, outDims, kSizes, strides, rates, padType, prcSize)) {} - -void ExtractImagePatches::ExtractImagePatchesRefExecutor::exec( - void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) { +ExtractImagePatches::ExtractImagePatchesRefExecutor::ExtractImagePatchesRefExecutor(const VectorDims& inDims, + const VectorDims& outDims, + const VectorDims& kSizes, + const VectorDims& strides, + const VectorDims& rates, + const ExtImgPatcherPadType& padType, + const size_t prcSize) + : jpp(fillJpp(inDims, outDims, kSizes, strides, rates, padType, prcSize)) {} + +void ExtractImagePatches::ExtractImagePatchesRefExecutor::exec(void* src, + void* dst, + const VectorDims& istrides, + const VectorDims& ostrides) { executeReference(src, dst, istrides, 
ostrides); } @@ -636,6 +688,6 @@ bool ExtractImagePatches::created() const { return getType() == Type::ExtractImagePatches; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/extract_image_patches.h b/src/plugins/intel_cpu/src/nodes/extract_image_patches.h index 15220fd51a4285..1844b5cafeeb07 100644 --- a/src/plugins/intel_cpu/src/nodes/extract_image_patches.h +++ b/src/plugins/intel_cpu/src/nodes/extract_image_patches.h @@ -30,8 +30,11 @@ struct jit_extract_image_patches_args { }; struct jit_uni_extract_image_patches_kernel { - void (*ker_)(const jit_extract_image_patches_args *); - void operator()(const jit_extract_image_patches_args *args) { assert(ker_); ker_(args); } + void (*ker_)(const jit_extract_image_patches_args*); + void operator()(const jit_extract_image_patches_args* args) { + assert(ker_); + ker_(args); + } jit_extract_image_patches_params jpp; virtual void create_ker() = 0; explicit jit_uni_extract_image_patches_kernel(jit_extract_image_patches_params jpp) : ker_(nullptr), jpp(jpp) {} @@ -42,7 +45,7 @@ class ExtractImagePatches : public Node { public: ExtractImagePatches(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -51,11 +54,7 @@ class ExtractImagePatches : public Node { void prepareParams() override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; - enum class ExtImgPatcherPadType { - VALID, - SAME_LOWER, - SAME_UPPER - }; + enum class ExtImgPatcherPadType { VALID, SAME_LOWER, SAME_UPPER }; private: std::vector _ksizes; @@ -69,14 +68,13 @@ class ExtractImagePatches : public Node { struct ExtractImagePatchesExecutor { ExtractImagePatchesExecutor() = default; virtual void exec(void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) = 0; - jit_extract_image_patches_params fillJpp( - const VectorDims& inDims, - const VectorDims& outDims, - const VectorDims& kSizes, - const VectorDims& strides, - const VectorDims& rates, - const ExtImgPatcherPadType& padType, - const size_t prcSize); + jit_extract_image_patches_params fillJpp(const VectorDims& inDims, + const VectorDims& outDims, + const VectorDims& kSizes, + const VectorDims& strides, + const VectorDims& rates, + const ExtImgPatcherPadType& padType, + const size_t prcSize); virtual ~ExtractImagePatchesExecutor() = default; protected: @@ -93,30 +91,31 @@ class ExtractImagePatches : public Node { executorPtr execPtr = nullptr; struct ExtractImagePatchesJitExecutor : public ExtractImagePatchesExecutor { - ExtractImagePatchesJitExecutor( - const VectorDims& inDims, - const VectorDims& outDims, - const VectorDims& kSizes, - const VectorDims& strides, - const VectorDims& rates, - const ExtImgPatcherPadType& padType, - const size_t prcSize); + ExtractImagePatchesJitExecutor(const VectorDims& inDims, + const VectorDims& outDims, + const VectorDims& kSizes, + const VectorDims& strides, + const VectorDims& rates, + const ExtImgPatcherPadType& padType, + const size_t prcSize); void exec(void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) override; - void executeOptimizedGeneric(void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) const; + void 
executeOptimizedGeneric(void* src, + void* dst, + const VectorDims& istrides, + const VectorDims& ostrides) const; private: std::unique_ptr pKernel; }; struct ExtractImagePatchesRefExecutor : public ExtractImagePatchesExecutor { - ExtractImagePatchesRefExecutor( - const VectorDims& inDims, - const VectorDims& outDims, - const VectorDims& kSizes, - const VectorDims& strides, - const VectorDims& rates, - const ExtImgPatcherPadType& padType, - const size_t prcSize); + ExtractImagePatchesRefExecutor(const VectorDims& inDims, + const VectorDims& outDims, + const VectorDims& kSizes, + const VectorDims& strides, + const VectorDims& rates, + const ExtImgPatcherPadType& padType, + const size_t prcSize); void exec(void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) override; void executeReference(void* src, void* dst, const VectorDims& istrides, const VectorDims& ostrides) const; @@ -125,6 +124,6 @@ class ExtractImagePatches : public Node { }; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/eye.cpp b/src/plugins/intel_cpu/src/nodes/eye.cpp index 19e466ad68751a..deb47abdba2dee 100644 --- a/src/plugins/intel_cpu/src/nodes/eye.cpp +++ b/src/plugins/intel_cpu/src/nodes/eye.cpp @@ -3,9 +3,11 @@ // #include "eye.h" -#include "openvino/op/eye.hpp" + #include + #include "openvino/core/parallel.hpp" +#include "openvino/op/eye.hpp" #include "shape_inference/shape_inference.hpp" #include "utils/bfloat16.hpp" @@ -36,20 +38,21 @@ class EyeShapeInferFactory : public ShapeInferFactory { return (m_op->get_input_size() == 4) ? make_shape_inference(m_op) : make_shape_inference(m_op, PortMask(Eye::ROWS_NUM, Eye::COLS_NUM)); } + private: std::shared_ptr m_op; }; -} // namespace +} // namespace -Eye::Eye(const std::shared_ptr& op, const GraphContext::CPtr context) : Node(op, context, EyeShapeInferFactory(op)) { +Eye::Eye(const std::shared_ptr& op, const GraphContext::CPtr context) + : Node(op, context, EyeShapeInferFactory(op)) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { - OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); + OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); } outType = op->get_output_element_type(0); withBatchShape = (op->get_input_size() == 4); - if (!one_of(outType, ov::element::f32, ov::element::bf16, - ov::element::i32, ov::element::i8, ov::element::u8)) { + if (!one_of(outType, ov::element::f32, ov::element::bf16, ov::element::i32, ov::element::i8, ov::element::u8)) { THROW_ERROR(errorPrefix, "doesn't support demanded output precision"); } } @@ -61,16 +64,19 @@ void Eye::getSupportedDescriptors() { THROW_ERROR(errorPrefix, "has incorrect number of output edges: ", getChildEdges().size()); } -template +template struct Eye::EyeExecute { - void operator()(Eye *node) { + void operator()(Eye* node) { node->executeSpecified(); } }; void Eye::execute(dnnl::stream strm) { auto outputPrec = getChildEdgeAt(0)->getMemory().getDesc().getPrecision(); - OV_SWITCH(intel_cpu, EyeExecute, this, outputPrec, + OV_SWITCH(intel_cpu, + EyeExecute, + this, + outputPrec, OV_CASE(ov::element::f32, float), OV_CASE(ov::element::bf16, bfloat16_t), OV_CASE(ov::element::i32, int), @@ -99,9 +105,9 @@ void Eye::executeSpecified() { const size_t colNum = getColNum(); const int64_t shift = getDiagIndex(); auto outPtr = getDstMemoryAtPort(0); - if (!outPtr || !outPtr ->isDefined()) + if (!outPtr || !outPtr->isDefined()) THROW_ERROR(errorPrefix, 
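// ---- editorial aside (the diff resumes below) ----
// A minimal scalar model of the diagonal fill that Eye::executeSpecified()
// (continuing below) performs for one row-major [rowNum x colNum] plane: a
// positive shift moves the diagonal right, a negative one moves it down, and
// consecutive ones are colNum + 1 elements apart. Sketch only, not part of
// the patch.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

inline std::vector<float> eye_plane(int64_t rowNum, int64_t colNum, int64_t shift) {
    std::vector<float> plane(static_cast<size_t>(rowNum * colNum), 0.f);
    const int64_t absShift = shift < 0 ? -shift : shift;
    const int64_t countByColumns = std::max(colNum - absShift, int64_t(0));
    const int64_t countByRows = std::max(rowNum - absShift, int64_t(0));
    const int64_t ones = shift > 0 ? std::min(countByColumns, rowNum) : std::min(countByRows, colNum);
    const int64_t dataShift = shift >= 0 ? shift : -shift * colNum;
    for (int64_t j = 0; j < ones; j++) {
        plane[static_cast<size_t>(dataShift + j * (colNum + 1))] = 1.f;  // step one row down, one column right
    }
    return plane;
}
// ---- end of aside ----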
"Destination memory is undefined."); - T *dst = outPtr->getDataAs(); + T* dst = outPtr->getDataAs(); const size_t batchVolume = getBatchVolume(getBatchShape()); const size_t spatialCount = colNum * rowNum; @@ -111,8 +117,8 @@ void Eye::executeSpecified() { const int64_t countByColumns = std::max(int64_t(colNum) - std::abs(shift), int64_t(0)); const int64_t countByRows = std::max(int64_t(rowNum) - std::abs(shift), int64_t(0)); - const size_t onesPerBatchNum = - static_cast(shift > 0 ? std::min(countByColumns, int64_t(rowNum)) : std::min(countByRows, int64_t(colNum))); + const size_t onesPerBatchNum = static_cast(shift > 0 ? std::min(countByColumns, int64_t(rowNum)) + : std::min(countByRows, int64_t(colNum))); const size_t dataShift = static_cast(shift >= 0 ? shift : -shift * colNum); if (spatialSize >= l2CacheSize) { @@ -121,7 +127,8 @@ void Eye::executeSpecified() { splitter(elementsCount, nthr, ithr, start, end); memset(dst + start, 0, (end - start) * sizeof(T)); }); - if (onesPerBatchNum == 0) return; + if (onesPerBatchNum == 0) + return; for (size_t bShift = 0; bShift < batchVolume * spatialCount; bShift += spatialCount) { parallel_nt(0, [&](const size_t ithr, const size_t nthr) { size_t start = 0, end = 0; @@ -136,7 +143,8 @@ void Eye::executeSpecified() { size_t start = 0, end = 0; splitter(batchVolume, nthr, ithr, start, end); memset(dst + start * spatialCount, 0, (end - start) * spatialSize); - if (onesPerBatchNum == 0) return; + if (onesPerBatchNum == 0) + return; for (size_t spShift = start * spatialCount; spShift < end * spatialCount; spShift += spatialCount) { for (size_t j = 0; j < onesPerBatchNum; j++) { dst[dataShift + j * (colNum + 1) + spShift] = static_cast(1); @@ -149,6 +157,6 @@ void Eye::executeSpecified() { bool Eye::created() const { return getType() == Type::Eye; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/eye.h b/src/plugins/intel_cpu/src/nodes/eye.h index 7978c45d8a05d1..fc2b42a18bdbe9 100644 --- a/src/plugins/intel_cpu/src/nodes/eye.h +++ b/src/plugins/intel_cpu/src/nodes/eye.h @@ -5,9 +5,11 @@ #pragma once #include -#include + #include +#include #include + #include "dnnl_extension_utils.h" namespace ov { @@ -28,9 +30,15 @@ class Eye : public Node { void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; - bool needPrepareParams() const override {return false;}; - bool needShapeInfer() const override {return true;}; - void executeDynamicImpl(dnnl::stream strm) override { execute(strm); } + bool needPrepareParams() const override { + return false; + }; + bool needShapeInfer() const override { + return true; + }; + void executeDynamicImpl(dnnl::stream strm) override { + execute(strm); + } static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; @@ -39,13 +47,13 @@ class Eye : public Node { ov::element::Type outType = ov::element::Type_t::undefined; template void executeSpecified(); - template + template struct EyeExecute; inline const size_t getRowNum() const { auto rowMem = getSrcMemoryAtPort(ROWS_NUM); if (rowMem == nullptr) OPENVINO_THROW(errorPrefix, " doesn't contain row_count data"); - const int *rowPtr = rowMem->getDataAs(); + const int* rowPtr = rowMem->getDataAs(); return rowPtr[0]; } @@ -53,7 +61,7 @@ class Eye : public Node { auto colMem = getSrcMemoryAtPort(COLS_NUM); if (colMem == nullptr) 
OPENVINO_THROW(errorPrefix, " doesn't contain col_count data"); - const int *colPtr = colMem->getDataAs(); + const int* colPtr = colMem->getDataAs(); return colPtr[0]; } @@ -61,28 +69,29 @@ class Eye : public Node { auto diagIndMem = getSrcMemoryAtPort(DIAGONAL_INDEX); if (diagIndMem == nullptr) OPENVINO_THROW(errorPrefix, " doesn't contain diag_index data"); - const int *diagIndexPtr = diagIndMem->getDataAs(); + const int* diagIndexPtr = diagIndMem->getDataAs(); return diagIndexPtr[0]; } inline const std::vector getBatchShape() const { if (withBatchShape) { - const int batchShapeSize = static_cast(getSrcMemoryAtPort(BATCH_SHAPE)->getShape().getElementsCount()); + const int batchShapeSize = + static_cast(getSrcMemoryAtPort(BATCH_SHAPE)->getShape().getElementsCount()); std::vector batchShape(batchShapeSize); - const int *batchShapePtr = getSrcDataAtPortAs(BATCH_SHAPE); + const int* batchShapePtr = getSrcDataAtPortAs(BATCH_SHAPE); batchShape.assign(batchShapePtr, batchShapePtr + batchShapeSize); return batchShape; } else { - return std::vector {}; + return std::vector{}; } } - inline const size_t getBatchVolume(const std::vector &batchShape) { + inline const size_t getBatchVolume(const std::vector& batchShape) { return std::accumulate(begin(batchShape), end(batchShape), 1, std::multiplies()); } bool withBatchShape = false; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/fake_quantize.cpp b/src/plugins/intel_cpu/src/nodes/fake_quantize.cpp index f12ab40cf5643b..9951c5176f0ad1 100644 --- a/src/plugins/intel_cpu/src/nodes/fake_quantize.cpp +++ b/src/plugins/intel_cpu/src/nodes/fake_quantize.cpp @@ -4,28 +4,27 @@ #include "fake_quantize.h" -#include -#include #include +#include + #include -#include #include - -#include "dnnl_types.h" -#include "dnnl_extension_utils.h" -#include "cpu/x64/jit_generator.hpp" #include +#include +#include +#include +#include -#include "openvino/core/parallel.hpp" -#include "utils/general_utils.h" -#include "utils/cpu_utils.hpp" -#include -#include "memory_desc/dnnl_blocked_memory_desc.h" #include "common/cpu_memcpy.h" #include "common/primitive_hashing_utils.hpp" -#include - +#include "cpu/x64/jit_generator.hpp" +#include "dnnl_extension_utils.h" +#include "dnnl_types.h" +#include "memory_desc/dnnl_blocked_memory_desc.h" +#include "openvino/core/parallel.hpp" #include "openvino/opsets/opset1.hpp" +#include "utils/cpu_utils.hpp" +#include "utils/general_utils.h" #include "utils/ngraph_utils.hpp" // Quantization ranges validation is switched off by default in order to avoid regressions on user side @@ -45,13 +44,15 @@ namespace ov { namespace intel_cpu { namespace node { #if defined(OPENVINO_ARCH_X86_64) -#define GET_OFF(field) offsetof(jit_quantize_call_args, field) +# define GET_OFF(field) offsetof(jit_quantize_call_args, field) template struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_binarization_kernel) - explicit jit_uni_binarization_kernel(const jit_quantize_params& jqp) : jit_uni_quantize_kernel(jqp), jit_generator(jit_name()) {} + explicit jit_uni_binarization_kernel(const jit_quantize_params& jqp) + : jit_uni_quantize_kernel(jqp), + jit_generator(jit_name()) {} void create_ker() override { jit_generator::create_kernel(); @@ -77,7 +78,8 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_ Label tail_label; Label 
exit_label; - L(unrolled_loop_label); { + L(unrolled_loop_label); + { int step = isa == cpu::x64::sse41 ? nbits / 2 : isa == cpu::x64::avx2 ? nbits : 2 * nbits; const int ur_ch = isa == cpu::x64::sse41 ? nbits : isa == cpu::x64::avx2 ? nbits / 2 : nbits / 4; const int unrolled_loop_step = ur_ch * step; @@ -87,9 +89,9 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_ xor_(reg_bin_32, reg_bin_32); for (int ch = 0; ch < ur_ch; ch++) { - uni_vmovups(vmm_src(0), ptr[reg_from + ch*step*sizeof(float)]); - uni_vmovups(vmm_wei(0), ptr[reg_thresholds + ch*step*sizeof(float)]); - uni_vmovups(vmm_mask(0), ptr[reg_output_mask + ch*step*sizeof(float)]); + uni_vmovups(vmm_src(0), ptr[reg_from + ch * step * sizeof(float)]); + uni_vmovups(vmm_wei(0), ptr[reg_thresholds + ch * step * sizeof(float)]); + uni_vmovups(vmm_mask(0), ptr[reg_output_mask + ch * step * sizeof(float)]); if (isa == avx512_core) { vcmpps(k_mask0, vmm_src(0), vmm_wei(0), _cmp_gt_os); vptestmd(k_mask1, vmm_mask(0), vmm_mask(0)); @@ -105,16 +107,17 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_ } mov(ptr[reg_to], reg_bin_32); - add(reg_from, unrolled_loop_step*sizeof(float)); - add(reg_thresholds, unrolled_loop_step*sizeof(float)); - add(reg_output_mask, unrolled_loop_step*sizeof(float)); + add(reg_from, unrolled_loop_step * sizeof(float)); + add(reg_thresholds, unrolled_loop_step * sizeof(float)); + add(reg_output_mask, unrolled_loop_step * sizeof(float)); add(reg_to, sizeof(uint32_t)); sub(reg_work_amount, unrolled_loop_step); jmp(unrolled_loop_label, T_NEAR); } - L(main_loop_label); { + L(main_loop_label); + { int repeats = isa == cpu::x64::sse41 ? 2 : 1; int step = isa == cpu::x64::sse41 ? nbits / 2 : isa == cpu::x64::avx2 ? nbits : nbits * 2; const int main_loop_step = step * repeats; @@ -124,9 +127,9 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_ xor_(reg_bin_32, reg_bin_32); for (int i = 0; i < repeats; i++) { - uni_vmovups(vmm_src(0), ptr[reg_from + i*step*sizeof(float)]); - uni_vmovups(vmm_wei(0), ptr[reg_thresholds + i*step*sizeof(float)]); - uni_vmovups(vmm_mask(0), ptr[reg_output_mask + i*step*sizeof(float)]); + uni_vmovups(vmm_src(0), ptr[reg_from + i * step * sizeof(float)]); + uni_vmovups(vmm_wei(0), ptr[reg_thresholds + i * step * sizeof(float)]); + uni_vmovups(vmm_mask(0), ptr[reg_output_mask + i * step * sizeof(float)]); if (isa == avx512_core) { vcmpps(k_mask0, vmm_src(0), vmm_wei(0), _cmp_gt_os); vptestmd(k_mask1, vmm_mask(0), vmm_mask(0)); @@ -145,16 +148,17 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_ else mov(ptr[reg_to], reg_bin_8); - add(reg_from, main_loop_step*sizeof(float)); - add(reg_thresholds, main_loop_step*sizeof(float)); - add(reg_output_mask, main_loop_step*sizeof(float)); + add(reg_from, main_loop_step * sizeof(float)); + add(reg_thresholds, main_loop_step * sizeof(float)); + add(reg_output_mask, main_loop_step * sizeof(float)); add(reg_to, isa == avx512_core ? 
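// ---- editorial aside (the diff resumes below) ----
// Scalar shape of the bit packing in the binarization loops above, assuming
// the usual FakeQuantize binarization rule (the reference path later in this
// patch follows the same one): each channel contributes one bit, set when
// (src > threshold) agrees with the per-channel output mask, packed LSB-first
// with nbits channels per stored word. Illustration only.
#include <cstddef>
#include <cstdint>

inline uint32_t binarize_block(const float* src, const float* thr, const uint32_t* out_mask, size_t nbits) {
    uint32_t packed = 0;
    for (size_t shift = 0; shift < nbits; shift++) {
        const uint32_t res = src[shift] > thr[shift] ? 0xffffffff : 0x00000000;
        packed |= static_cast<uint32_t>(res == out_mask[shift]) << shift;
    }
    return packed;
}
// ---- end of aside ----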
sizeof(uint16_t) : sizeof(uint8_t)); sub(reg_work_amount, main_loop_step); jmp(main_loop_label, T_NEAR); } - L(tail_label); { + L(tail_label); + { if (tail_size != 0) { xor_(reg_bin_32, reg_bin_32); mov(reg_mask, 1); @@ -188,15 +192,27 @@ struct jit_uni_binarization_kernel : public jit_uni_quantize_kernel, public jit_ } private: - using Vmm = typename conditional3::type; + using Vmm = + typename conditional3::type; - inline Vmm vmm_src(int idx) { return Vmm(idx); } - inline Xmm xmm_src(int idx) { return Xmm(idx); } - inline Vmm vmm_wei(int idx) { return Vmm(idx + 4); } - inline Vmm vmm_mask(int idx) { return Vmm(idx + 5); } - inline Xmm xmm_wei(int idx) { return Xmm(idx + 4); } - inline Xmm xmm_mask(int idx) { return Xmm(idx + 5); } + inline Vmm vmm_src(int idx) { + return Vmm(idx); + } + inline Xmm xmm_src(int idx) { + return Xmm(idx); + } + inline Vmm vmm_wei(int idx) { + return Vmm(idx + 4); + } + inline Vmm vmm_mask(int idx) { + return Vmm(idx + 5); + } + inline Xmm xmm_wei(int idx) { + return Xmm(idx + 4); + } + inline Xmm xmm_mask(int idx) { + return Xmm(idx + 5); + } Reg64 param = abi_param1; Reg64 reg_from = r8; @@ -219,7 +235,9 @@ template struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_quantization_kernel) - explicit jit_uni_quantization_kernel(const jit_quantize_params& jqp) : jit_uni_quantize_kernel(jqp), jit_generator(jit_name()) {} + explicit jit_uni_quantization_kernel(const jit_quantize_params& jqp) + : jit_uni_quantize_kernel(jqp), + jit_generator(jit_name()) {} void create_ker() override { jit_generator::create_kernel(); @@ -237,37 +255,78 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ else compute_generic(); - this->postamble(); } private: - using Vmm = typename conditional3::type; - - inline Vmm vmm_val(int idx) { return Vmm(idx + 0); } - inline Vmm vmm_crop_low(int idx) { return Vmm(idx + 2); } - inline Vmm vmm_crop_high(int idx) { return Vmm(idx + 4); } - inline Vmm vmm_input_scale(int idx) { return Vmm(idx + 6); } - inline Vmm vmm_input_shift(int idx) { return Vmm(idx + 8); } - inline Vmm vmm_output_scale(int idx) { return Vmm(idx + 10); } - inline Vmm vmm_output_shift(int idx) { return Vmm(idx + 12); } - - inline Ymm ymm_val(int idx) { return Ymm(idx + 0); } - inline Ymm ymm_crop_low(int idx) { return Ymm(idx + 2); } - inline Ymm ymm_crop_high(int idx) { return Ymm(idx + 4); } - inline Ymm ymm_input_scale(int idx) { return Ymm(idx + 6); } - inline Ymm ymm_input_shift(int idx) { return Ymm(idx + 8); } - inline Ymm ymm_output_scale(int idx) { return Ymm(idx + 10); } - inline Ymm ymm_output_shift(int idx) { return Ymm(idx + 12); } - - inline Xmm xmm_val(int idx) { return Xmm(idx + 0); } - inline Xmm xmm_crop_low(int idx) { return Xmm(idx + 2); } - inline Xmm xmm_crop_high(int idx) { return Xmm(idx + 4); } - inline Xmm xmm_input_scale(int idx) { return Xmm(idx + 6); } - inline Xmm xmm_input_shift(int idx) { return Xmm(idx + 8); } - inline Xmm xmm_output_scale(int idx) { return Xmm(idx + 10); } - inline Xmm xmm_output_shift(int idx) { return Xmm(idx + 12); } + using Vmm = + typename conditional3::type; + + inline Vmm vmm_val(int idx) { + return Vmm(idx + 0); + } + inline Vmm vmm_crop_low(int idx) { + return Vmm(idx + 2); + } + inline Vmm vmm_crop_high(int idx) { + return Vmm(idx + 4); + } + inline Vmm vmm_input_scale(int idx) { + return Vmm(idx + 6); + } + inline Vmm vmm_input_shift(int idx) { + return Vmm(idx + 8); + } + inline Vmm vmm_output_scale(int 
idx) { + return Vmm(idx + 10); + } + inline Vmm vmm_output_shift(int idx) { + return Vmm(idx + 12); + } + + inline Ymm ymm_val(int idx) { + return Ymm(idx + 0); + } + inline Ymm ymm_crop_low(int idx) { + return Ymm(idx + 2); + } + inline Ymm ymm_crop_high(int idx) { + return Ymm(idx + 4); + } + inline Ymm ymm_input_scale(int idx) { + return Ymm(idx + 6); + } + inline Ymm ymm_input_shift(int idx) { + return Ymm(idx + 8); + } + inline Ymm ymm_output_scale(int idx) { + return Ymm(idx + 10); + } + inline Ymm ymm_output_shift(int idx) { + return Ymm(idx + 12); + } + + inline Xmm xmm_val(int idx) { + return Xmm(idx + 0); + } + inline Xmm xmm_crop_low(int idx) { + return Xmm(idx + 2); + } + inline Xmm xmm_crop_high(int idx) { + return Xmm(idx + 4); + } + inline Xmm xmm_input_scale(int idx) { + return Xmm(idx + 6); + } + inline Xmm xmm_input_shift(int idx) { + return Xmm(idx + 8); + } + inline Xmm xmm_output_scale(int idx) { + return Xmm(idx + 10); + } + inline Xmm xmm_output_shift(int idx) { + return Xmm(idx + 12); + } Vmm vmm_zero = Vmm(14); @@ -296,24 +355,34 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ bool do_dequantization = true; inline void load_broadcasted_vectors_only(size_t idx) { - const auto &broadcasted = jqp_.broadcasted; - if (broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) uni_vbroadcastss(vmm_crop_low(idx), ptr[reg_crop_low]); - if (broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) uni_vbroadcastss(vmm_crop_high(idx), ptr[reg_crop_high]); - if (broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) uni_vbroadcastss(vmm_input_scale(idx), ptr[reg_input_scale]); - if (broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) uni_vbroadcastss(vmm_input_shift(idx), ptr[reg_input_shift]); + const auto& broadcasted = jqp_.broadcasted; + if (broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) + uni_vbroadcastss(vmm_crop_low(idx), ptr[reg_crop_low]); + if (broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) + uni_vbroadcastss(vmm_crop_high(idx), ptr[reg_crop_high]); + if (broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) + uni_vbroadcastss(vmm_input_scale(idx), ptr[reg_input_scale]); + if (broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) + uni_vbroadcastss(vmm_input_shift(idx), ptr[reg_input_shift]); if (do_dequantization) { - if (broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) uni_vbroadcastss(vmm_output_scale(idx), ptr[reg_output_scale]); - if (broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)]) uni_vbroadcastss(vmm_output_shift(idx), ptr[reg_output_shift]); + if (broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) + uni_vbroadcastss(vmm_output_scale(idx), ptr[reg_output_scale]); + if (broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)]) + uni_vbroadcastss(vmm_output_shift(idx), ptr[reg_output_shift]); } } template inline void load_not_broadcasted_vectors_only(size_t idx, size_t offset) { - const auto &broadcasted = jqp_.broadcasted; - if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) uni_vmovups(T(vmm_crop_low(idx).getIdx()), ptr[reg_crop_low + offset]); - if (!broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) uni_vmovups(T(vmm_crop_high(idx).getIdx()), ptr[reg_crop_high + offset]); - if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) uni_vmovups(T(vmm_input_scale(idx).getIdx()), ptr[reg_input_scale + offset]); - if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) uni_vmovups(T(vmm_input_shift(idx).getIdx()), ptr[reg_input_shift + 
offset]); + const auto& broadcasted = jqp_.broadcasted; + if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) + uni_vmovups(T(vmm_crop_low(idx).getIdx()), ptr[reg_crop_low + offset]); + if (!broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) + uni_vmovups(T(vmm_crop_high(idx).getIdx()), ptr[reg_crop_high + offset]); + if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) + uni_vmovups(T(vmm_input_scale(idx).getIdx()), ptr[reg_input_scale + offset]); + if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) + uni_vmovups(T(vmm_input_shift(idx).getIdx()), ptr[reg_input_shift + offset]); if (do_dequantization) { if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) uni_vmovups(T(vmm_output_scale(idx).getIdx()), ptr[reg_output_scale + offset]); @@ -323,14 +392,20 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ } inline void increase_ptrs_if_not_broadcasted(size_t offset) { - const auto &broadcasted = jqp_.broadcasted; - if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) add(reg_crop_low, offset); - if (!broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) add(reg_crop_high, offset); - if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) add(reg_input_scale, offset); - if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) add(reg_input_shift, offset); + const auto& broadcasted = jqp_.broadcasted; + if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) + add(reg_crop_low, offset); + if (!broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) + add(reg_crop_high, offset); + if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) + add(reg_input_scale, offset); + if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) + add(reg_input_shift, offset); if (do_dequantization) { - if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) add(reg_output_scale, offset); - if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)]) add(reg_output_shift, offset); + if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) + add(reg_output_scale, offset); + if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)]) + add(reg_output_shift, offset); } } @@ -373,7 +448,8 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vbroadcastss(vmm_output_shift(0), ptr[reg_output_shift]); } - L(main_loop_label); { + L(main_loop_label); + { cmp(reg_work_amount, simd_w); jl(tail_blk4_label, T_NEAR); @@ -383,8 +459,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(vmm_val(i), vmm_val(i), vmm_crop_high(0)); uni_vmaxps(vmm_val(i), vmm_val(i), vmm_crop_low(0)); uni_vfmadd213ps(vmm_val(i), vmm_input_scale(0), vmm_input_shift(0)); - if (do_rounding) uni_vroundps(vmm_val(i), vmm_val(i), 0); - if (do_dequantization) uni_vfmadd213ps(vmm_val(i), vmm_output_scale(0), vmm_output_shift(0)); + if (do_rounding) + uni_vroundps(vmm_val(i), vmm_val(i), 0); + if (do_dequantization) + uni_vfmadd213ps(vmm_val(i), vmm_output_scale(0), vmm_output_shift(0)); store_vector(ptr[reg_to + i * (simd_w / 2) * dst_type_size], vmm_val(i), jqp_.dst_prc); } @@ -396,7 +474,8 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ jmp(main_loop_label, T_NEAR); } - L(tail_blk4_label); { + L(tail_blk4_label); + { cmp(reg_work_amount, tail_simd_w); jl(tail_blk4_exit_label, T_NEAR); @@ -405,8 +484,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(xmm_val(0), 
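// ---- editorial aside (the diff resumes below) ----
// The scalar rule behind the broadcast helpers above: each FakeQuantize
// parameter row (crop low/high, input scale/shift, output scale/shift) is
// either a single value reused for every channel (broadcasted) or a
// per-channel table, which is why the kernel either vbroadcast-s one scalar
// or loads a full vector and advances the pointer. Hypothetical helper, for
// illustration only.
#include <cstddef>
#include <vector>

inline float fq_param(const std::vector<float>& row, bool broadcasted, size_t channel) {
    return broadcasted ? row[0] : row[channel];
}
// ---- end of aside ----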
xmm_val(0), xmm_crop_high(0)); uni_vmaxps(xmm_val(0), xmm_val(0), xmm_crop_low(0)); uni_vfmadd213ps(xmm_val(0), xmm_input_scale(0), xmm_input_shift(0)); - if (do_rounding) uni_vroundps(xmm_val(0), xmm_val(0), 0); - if (do_dequantization) uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); + if (do_rounding) + uni_vroundps(xmm_val(0), xmm_val(0), 0); + if (do_dequantization) + uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); store_vector(ptr[reg_to], xmm_val(0), jqp_.dst_prc); @@ -420,7 +501,8 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ mov(aux_reg_from, reg_from); mov(aux_reg_to, reg_to); - L(tail_loop_label); { + L(tail_loop_label); + { cmp(reg_work_amount, 0); jle(exit_label, T_NEAR); @@ -429,8 +511,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(xmm_val(0), xmm_val(0), xmm_crop_high(0)); uni_vmaxps(xmm_val(0), xmm_val(0), xmm_crop_low(0)); uni_vfmadd213ps(xmm_val(0), xmm_input_scale(0), xmm_input_shift(0)); - if (do_rounding) uni_vroundps(xmm_val(0), xmm_val(0), 0); - if (do_dequantization) uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); + if (do_rounding) + uni_vroundps(xmm_val(0), xmm_val(0), 0); + if (do_dequantization) + uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); store_scalar(ptr[aux_reg_to], xmm_val(0), jqp_.dst_prc); @@ -496,7 +580,8 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ load_not_broadcasted_vectors_only(i, i * (simd_w / 2) * sizeof(float)); } - L(main_loop_label); { + L(main_loop_label); + { cmp(reg_work_amount, 0); jle(exit_label, T_NEAR); @@ -506,8 +591,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(vmm_val(i), vmm_val(i), vmm_crop_high(i)); uni_vmaxps(vmm_val(i), vmm_val(i), vmm_crop_low(i)); uni_vfmadd213ps(vmm_val(i), vmm_input_scale(i), vmm_input_shift(i)); - if (do_rounding) uni_vroundps(vmm_val(i), vmm_val(i), 0); - if (do_dequantization) uni_vfmadd213ps(vmm_val(i), vmm_output_scale(i), vmm_output_shift(i)); + if (do_rounding) + uni_vroundps(vmm_val(i), vmm_val(i), 0); + if (do_dequantization) + uni_vfmadd213ps(vmm_val(i), vmm_output_scale(i), vmm_output_shift(i)); store_vector(ptr[reg_to + i * (simd_w / 2) * dst_type_size], vmm_val(i), jqp_.dst_prc); } @@ -531,7 +618,8 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ load_not_broadcasted_vectors_only(0, 0); - L(tail_blk8_loop_label); { + L(tail_blk8_loop_label); + { cmp(reg_work_amount, 0); jle(tail_blk8_exit_label, T_NEAR); @@ -540,8 +628,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(ymm_val(0), ymm_val(0), ymm_crop_high(0)); uni_vmaxps(ymm_val(0), ymm_val(0), ymm_crop_low(0)); uni_vfmadd213ps(ymm_val(0), ymm_input_scale(0), ymm_input_shift(0)); - if (do_rounding) uni_vroundps(ymm_val(0), ymm_val(0), 0); - if (do_dequantization) uni_vfmadd213ps(ymm_val(0), ymm_output_scale(0), ymm_output_shift(0)); + if (do_rounding) + uni_vroundps(ymm_val(0), ymm_val(0), 0); + if (do_dequantization) + uni_vfmadd213ps(ymm_val(0), ymm_output_scale(0), ymm_output_shift(0)); store_vector(ptr[aux_reg_to], ymm_val(0), jqp_.dst_prc); @@ -571,7 +661,8 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ load_not_broadcasted_vectors_only(0, 0); - L(tail_blk4_loop_label); { + L(tail_blk4_loop_label); + { cmp(reg_work_amount, 0); jle(tail_blk4_exit_label, 
T_NEAR); @@ -580,8 +671,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(xmm_val(0), xmm_val(0), xmm_crop_high(0)); uni_vmaxps(xmm_val(0), xmm_val(0), xmm_crop_low(0)); uni_vfmadd213ps(xmm_val(0), xmm_input_scale(0), xmm_input_shift(0)); - if (do_rounding) uni_vroundps(xmm_val(0), xmm_val(0), 0); - if (do_dequantization) uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); + if (do_rounding) + uni_vroundps(xmm_val(0), xmm_val(0), 0); + if (do_dequantization) + uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); store_vector(ptr[aux_reg_to], xmm_val(0), jqp_.dst_prc); @@ -608,13 +701,14 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ mov(aux_reg_from, reg_from); mov(reg_work_amount, ptr[param + GET_OFF(work_amount)]); - L(tail_loop_label); { + L(tail_loop_label); + { cmp(reg_work_amount, 0); jle(exit_label, T_NEAR); Label end_unroll; auto tail_unroll = [&](size_t iter) { - const auto &broadcasted = jqp_.broadcasted; + const auto& broadcasted = jqp_.broadcasted; for (size_t i = 0; i < iter; i++) { if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) uni_vmovss(xmm_crop_low(0), ptr[reg_crop_low + i * wei_type_size]); @@ -636,8 +730,10 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ uni_vminps(xmm_val(0), xmm_val(0), xmm_crop_high(0)); uni_vmaxps(xmm_val(0), xmm_val(0), xmm_crop_low(0)); uni_vfmadd213ps(xmm_val(0), xmm_input_scale(0), xmm_input_shift(0)); - if (do_rounding) uni_vroundps(xmm_val(0), xmm_val(0), 0); - if (do_dequantization) uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); + if (do_rounding) + uni_vroundps(xmm_val(0), xmm_val(0), 0); + if (do_dequantization) + uni_vfmadd213ps(xmm_val(0), xmm_output_scale(0), xmm_output_shift(0)); store_scalar(ptr[aux_reg_to + i * dst_type_size], xmm_val(0), jqp_.dst_prc); } @@ -667,20 +763,20 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ L(exit_label); } - inline void load_vector(Zmm zmm_src, const Xbyak::Address &op, ov::element::Type src_prc) { + inline void load_vector(Zmm zmm_src, const Xbyak::Address& op, ov::element::Type src_prc) { switch (src_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(zmm_src, op); - break; - case ov::element::i8: - uni_vpmovsxbd(zmm_src, op); - break; - case ov::element::u8: - uni_vpmovzxbd(zmm_src, op); - break; - default: - assert(!"unknown src_prc"); + case ov::element::f32: + case ov::element::i32: + uni_vmovups(zmm_src, op); + break; + case ov::element::i8: + uni_vpmovsxbd(zmm_src, op); + break; + case ov::element::u8: + uni_vpmovzxbd(zmm_src, op); + break; + default: + assert(!"unknown src_prc"); } if (src_prc != ov::element::f32) { @@ -688,20 +784,20 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ } } - inline void load_vector(Ymm ymm_src, const Xbyak::Address &op, ov::element::Type src_prc) { + inline void load_vector(Ymm ymm_src, const Xbyak::Address& op, ov::element::Type src_prc) { switch (src_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(ymm_src, op); - break; - case ov::element::i8: - uni_vpmovsxbd(ymm_src, op); - break; - case ov::element::u8: - uni_vpmovzxbd(ymm_src, op); - break; - default: - assert(!"unknown src_prc"); + case ov::element::f32: + case ov::element::i32: + uni_vmovups(ymm_src, op); + break; + case ov::element::i8: + uni_vpmovsxbd(ymm_src, op); + break; + case ov::element::u8: + 
uni_vpmovzxbd(ymm_src, op); + break; + default: + assert(!"unknown src_prc"); } if (src_prc != ov::element::f32) { @@ -709,20 +805,20 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ } } - inline void load_vector(Xmm xmm_src, const Xbyak::Address &op, ov::element::Type src_prc) { + inline void load_vector(Xmm xmm_src, const Xbyak::Address& op, ov::element::Type src_prc) { switch (src_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(xmm_src, op); - break; - case ov::element::i8: - uni_vpmovsxbd(xmm_src, op); - break; - case ov::element::u8: - uni_vpmovzxbd(xmm_src, op); - break; - default: - assert(!"unknown src_prc"); + case ov::element::f32: + case ov::element::i32: + uni_vmovups(xmm_src, op); + break; + case ov::element::i8: + uni_vpmovsxbd(xmm_src, op); + break; + case ov::element::u8: + uni_vpmovzxbd(xmm_src, op); + break; + default: + assert(!"unknown src_prc"); } if (src_prc != ov::element::f32) { @@ -730,22 +826,22 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ } } - inline void load_scalar(Xmm xmm_src, const Xbyak::Address &op, ov::element::Type src_prc) { + inline void load_scalar(Xmm xmm_src, const Xbyak::Address& op, ov::element::Type src_prc) { switch (src_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovss(xmm_src, op); - break; - case ov::element::i8: - movsx(reg_tmp_32, op); - uni_vmovq(xmm_src, reg_tmp_64); - break; - case ov::element::u8: - movzx(reg_tmp_32, op); - uni_vmovq(xmm_src, reg_tmp_64); - break; - default: - assert(!"unknown src_prc"); + case ov::element::f32: + case ov::element::i32: + uni_vmovss(xmm_src, op); + break; + case ov::element::i8: + movsx(reg_tmp_32, op); + uni_vmovq(xmm_src, reg_tmp_64); + break; + case ov::element::u8: + movzx(reg_tmp_32, op); + uni_vmovq(xmm_src, reg_tmp_64); + break; + default: + assert(!"unknown src_prc"); } if (src_prc != ov::element::f32) { @@ -753,29 +849,29 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ } } - inline void store_vector(const Xbyak::Address &op, Zmm zmm_dst, ov::element::Type dst_prc) { + inline void store_vector(const Xbyak::Address& op, Zmm zmm_dst, ov::element::Type dst_prc) { if (dst_prc != ov::element::f32) { uni_vcvtps2dq(zmm_dst, zmm_dst); } switch (dst_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(op, zmm_dst); - break; - case ov::element::i8: - vpmovsdb(op, zmm_dst); - break; - case ov::element::u8: - vpmaxsd(zmm_dst, zmm_dst, vmm_zero); - vpmovusdb(op, zmm_dst); - break; - default: - assert(!"unknown dst_prc"); - } - } - - inline void store_vector(const Xbyak::Address &op, Ymm ymm_dst, ov::element::Type dst_prc) { + case ov::element::f32: + case ov::element::i32: + uni_vmovups(op, zmm_dst); + break; + case ov::element::i8: + vpmovsdb(op, zmm_dst); + break; + case ov::element::u8: + vpmaxsd(zmm_dst, zmm_dst, vmm_zero); + vpmovusdb(op, zmm_dst); + break; + default: + assert(!"unknown dst_prc"); + } + } + + inline void store_vector(const Xbyak::Address& op, Ymm ymm_dst, ov::element::Type dst_prc) { Xmm xmm_dst = Xmm(ymm_dst.getIdx()); if (dst_prc != ov::element::f32) { @@ -783,82 +879,82 @@ struct jit_uni_quantization_kernel : public jit_uni_quantize_kernel, public jit_ } switch (dst_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(op, ymm_dst); - break; - case ov::element::i8: - uni_vpackssdw(ymm_dst, ymm_dst, ymm_dst); + case ov::element::f32: + case ov::element::i32: + uni_vmovups(op, ymm_dst); + break; + 
case ov::element::i8: + uni_vpackssdw(ymm_dst, ymm_dst, ymm_dst); - vpermq(ymm_dst, ymm_dst, 0x08); + vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpacksswb(ymm_dst, ymm_dst, ymm_dst); + uni_vpacksswb(ymm_dst, ymm_dst, ymm_dst); - vmovq(op, xmm_dst); - break; - case ov::element::u8: - uni_vpackusdw(ymm_dst, ymm_dst, ymm_dst); + vmovq(op, xmm_dst); + break; + case ov::element::u8: + uni_vpackusdw(ymm_dst, ymm_dst, ymm_dst); - vpermq(ymm_dst, ymm_dst, 0x08); + vpermq(ymm_dst, ymm_dst, 0x08); - uni_vpackuswb(ymm_dst, ymm_dst, ymm_dst); + uni_vpackuswb(ymm_dst, ymm_dst, ymm_dst); - vmovq(op, xmm_dst); - break; - default: - assert(!"unknown dst_prc"); + vmovq(op, xmm_dst); + break; + default: + assert(!"unknown dst_prc"); } } - inline void store_vector(const Xbyak::Address &op, Xmm xmm_dst, ov::element::Type dst_prc) { + inline void store_vector(const Xbyak::Address& op, Xmm xmm_dst, ov::element::Type dst_prc) { if (dst_prc != ov::element::f32) { uni_vcvtps2dq(xmm_dst, xmm_dst); } switch (dst_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovups(op, xmm_dst); - break; - case ov::element::i8: - uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); - uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); - uni_vmovd(op, xmm_dst); - break; - case ov::element::u8: - uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); - uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); - uni_vmovd(op, xmm_dst); - break; - default: - assert(!"unknown dst_prc"); - } - } - - inline void store_scalar(const Xbyak::Address &op, Xmm xmm_dst, ov::element::Type dst_prc) { + case ov::element::f32: + case ov::element::i32: + uni_vmovups(op, xmm_dst); + break; + case ov::element::i8: + uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); + uni_vmovd(op, xmm_dst); + break; + case ov::element::u8: + uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); + uni_vmovd(op, xmm_dst); + break; + default: + assert(!"unknown dst_prc"); + } + } + + inline void store_scalar(const Xbyak::Address& op, Xmm xmm_dst, ov::element::Type dst_prc) { if (dst_prc != ov::element::f32) { uni_vcvtps2dq(xmm_dst, xmm_dst); } switch (dst_prc) { - case ov::element::f32: - case ov::element::i32: - uni_vmovss(op, xmm_dst); - break; - case ov::element::i8: - uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); - uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); - uni_vmovq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - break; - case ov::element::u8: - uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); - uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); - uni_vmovq(reg_tmp_64, xmm_dst); - mov(op, reg_tmp_8); - break; - default: - assert(!"unknown dst_prc"); + case ov::element::f32: + case ov::element::i32: + uni_vmovss(op, xmm_dst); + break; + case ov::element::i8: + uni_vpackssdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpacksswb(xmm_dst, xmm_dst, xmm_dst); + uni_vmovq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + break; + case ov::element::u8: + uni_vpackusdw(xmm_dst, xmm_dst, xmm_dst); + uni_vpackuswb(xmm_dst, xmm_dst, xmm_dst); + uni_vmovq(reg_tmp_64, xmm_dst); + mov(op, reg_tmp_8); + break; + default: + assert(!"unknown dst_prc"); } } }; @@ -877,7 +973,8 @@ bool FakeQuantize::isSupportedOperation(const std::shared_ptr& o } for (size_t i = 1; i < fq->get_input_size(); i++) { if (fq->get_input_partial_shape(i).rank().get_length() > 5) { - errorMessage = "Doesn't support 'range' input with rank: " + std::to_string(fq->get_input_partial_shape(i).rank().get_length()); + errorMessage = "Doesn't support 'range' input with rank: " + + 
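// ---- editorial aside (the diff resumes below) ----
// A scalar analogue of the store_vector()/store_scalar() paths above. The
// pack pairs (vpackssdw + vpacksswb, vpackusdw + vpackuswb) narrow the i32
// results with signed or unsigned saturation; uni_vcvtps2dq rounds with the
// MXCSR default (assumed round-to-nearest-even here). Sketch only, not part
// of the patch.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline int8_t store_as_i8(float v) {
    const int32_t i = static_cast<int32_t>(std::nearbyint(v));
    return static_cast<int8_t>(std::min(127, std::max(-128, i)));  // signed saturation
}

inline uint8_t store_as_u8(float v) {
    const int32_t i = static_cast<int32_t>(std::nearbyint(v));
    return static_cast<uint8_t>(std::min(255, std::max(0, i)));  // unsigned saturation
}
// ---- end of aside ----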
std::to_string(fq->get_input_partial_shape(i).rank().get_length()); return false; } } @@ -935,7 +1032,7 @@ struct FakeQuantKey { seed = hash_combine(seed, jqp.wei_prc.hash()); seed = hash_combine(seed, jqp.dst_prc.hash()); seed = hash_combine(seed, jqp.op_type); - if (jqp.op_type == Algorithm::FQBinarization) { + if (jqp.op_type == Algorithm::FQBinarization) { seed = hash_combine(seed, jqp.c); } else { seed = hash_combine(seed, jqp.broadcasted); @@ -959,8 +1056,8 @@ struct FakeQuantKey { }; } // namespace -FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphContext::CPtr context) : - Node(op, context, PassThroughShapeInferFactory()) { +FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphContext::CPtr context) + : Node(op, context, PassThroughShapeInferFactory()) { std::string errorMessage; if (isSupportedOperation(op, errorMessage)) { algorithm = Algorithm::FQCommon; @@ -1032,16 +1129,20 @@ FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphConte OPENVINO_THROW(errorPrefix, "has different quantization axis size on 'data' and 'range' inputs"); } - const auto inputLowNode = std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(1)); + const auto inputLowNode = + std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(1)); auto inputLowData = inputLowNode->cast_vector(); - const auto inputHighNode = std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(2)); + const auto inputHighNode = + std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(2)); auto inputHighData = inputHighNode->cast_vector(); - const auto outputLowNode = std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(3)); + const auto outputLowNode = + std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(3)); auto outputLowData = outputLowNode->cast_vector(); - const auto outputHighNode = std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(4)); + const auto outputHighNode = + std::dynamic_pointer_cast(fq->get_input_node_shared_ptr(4)); auto outputHighData = outputHighNode->cast_vector(); binarization = levels == 2; @@ -1092,7 +1193,7 @@ FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphConte } } } else { - auto allElementsAreEqual = [&](const std::vector &data, size_t size) { + auto allElementsAreEqual = [&](const std::vector& data, size_t size) { if (size == 0) return true; @@ -1146,9 +1247,21 @@ FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphConte broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)] = outputScaleSize == 1; broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)] = outputShiftSize == 1; - if (everyone_is(1u, cropLowSize, cropHighSize, inputScaleSize, inputShiftSize, outputScaleSize, outputShiftSize)) + if (everyone_is(1u, + cropLowSize, + cropHighSize, + inputScaleSize, + inputShiftSize, + outputScaleSize, + outputShiftSize)) broadcastingPolicy = PerTensor; - else if (one_of(1u, cropLowSize, cropHighSize, inputScaleSize, inputShiftSize, outputScaleSize, outputShiftSize)) + else if (one_of(1u, + cropLowSize, + cropHighSize, + inputScaleSize, + inputShiftSize, + outputScaleSize, + outputShiftSize)) broadcastingPolicy = Mixed; else broadcastingPolicy = PerChannel; @@ -1224,7 +1337,10 @@ FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphConte bool isFakeQuantization = true; bool isFakeQuantizationWithScale = true; - for (size_t i = 0; i < std::max(inputLowAxisSize, std::max(outputLowAxisSize, std::max(inputHighAxisSize, outputHighAxisSize))); i++) { + for (size_t i = 0; + i < std::max(inputLowAxisSize, + 
std::max(outputLowAxisSize, std::max(inputHighAxisSize, outputHighAxisSize))); + i++) { float il = inputLowData[isInputLowBroadcasted ? 0 : i]; float ol = outputLowData[isOutputLowBroadcasted ? 0 : i]; float ih = inputHighData[isInputHighBroadcasted ? 0 : i]; @@ -1236,7 +1352,10 @@ FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphConte } if (isFakeQuantizationWithScale) { - for (size_t i = 0; i < std::max(inputLowAxisSize, std::max(outputLowAxisSize, std::max(inputHighAxisSize, outputHighAxisSize))); i++) { + for (size_t i = 0; + i < std::max(inputLowAxisSize, + std::max(outputLowAxisSize, std::max(inputHighAxisSize, outputHighAxisSize))); + i++) { float il = inputLowData[isInputLowBroadcasted ? 0 : i]; float ol = outputLowData[isOutputLowBroadcasted ? 0 : i]; float ih = inputHighData[isInputHighBroadcasted ? 0 : i]; @@ -1255,22 +1374,22 @@ FakeQuantize::FakeQuantize(const std::shared_ptr& op, const GraphConte std::vector FakeQuantize::getDataFormats() const { // Special case for first FQ in the network - const auto &dims = getInputShapeAtPort(0).getDims(); + const auto& dims = getInputShapeAtPort(0).getDims(); if (dims[getAxis()] == 3) { - return { LayoutType::ncsp }; + return {LayoutType::ncsp}; } else { if (isBinarization()) { - return { LayoutType::nspc }; + return {LayoutType::nspc}; } else { if (one_of(dims.size(), 4u, 5u)) { if (getAxis() == 1) { auto blkFormat = mayiuse(cpu::x64::avx512_core) ? LayoutType::nCsp16c : LayoutType::nCsp8c; - return { blkFormat, LayoutType::nspc, LayoutType::ncsp }; + return {blkFormat, LayoutType::nspc, LayoutType::ncsp}; } else { - return { LayoutType::ncsp }; + return {LayoutType::ncsp}; } } else { - return { LayoutType::ncsp }; + return {LayoutType::ncsp}; } } } @@ -1284,10 +1403,12 @@ void FakeQuantize::init() { inputPrecision = getOriginalInputPrecisionAtPort(0); outputPrecision = getOriginalOutputPrecisionAtPort(0); - if (inputPrecision != ov::element::f32 && inputPrecision != ov::element::u8 && inputPrecision != ov::element::i8) + if (inputPrecision != ov::element::f32 && inputPrecision != ov::element::u8 && + inputPrecision != ov::element::i8) inputPrecision = ov::element::f32; - if (outputPrecision != ov::element::f32 && outputPrecision != ov::element::u8 && outputPrecision != ov::element::i8) + if (outputPrecision != ov::element::f32 && outputPrecision != ov::element::u8 && + outputPrecision != ov::element::i8) outputPrecision = ov::element::f32; } } @@ -1381,7 +1502,8 @@ bool FakeQuantize::needPrepareParams() const { if (!selectedPrimitiveDescriptor) OPENVINO_THROW("CPU quantize node with name '", getName(), "' doesn't have primitive descriptors."); - if (internalBlobMemory.empty() || (selectedPrimitiveDescriptor->getImplementationType() != impl_desc_type::ref && inputShapesModified())) { + if (internalBlobMemory.empty() || + (selectedPrimitiveDescriptor->getImplementationType() != impl_desc_type::ref && inputShapesModified())) { return true; } @@ -1389,7 +1511,8 @@ bool FakeQuantize::needPrepareParams() const { const auto newPaddedSize = rnd_up(axisSize, 16); const auto currPaddedSize = rnd_up(currentAxisSize, 16); - return newPaddedSize != currPaddedSize || ((isInputLowBroadcasted || isOutputHighBroadcasted) && axisSize != currentAxisSize); + return newPaddedSize != currPaddedSize || + ((isInputLowBroadcasted || isOutputHighBroadcasted) && axisSize != currentAxisSize); } return false; } @@ -1401,26 +1524,33 @@ void FakeQuantize::prepareParams() { OPENVINO_ASSERT(newPaddedSize != 0); if (internalBlobMemory.empty() || 
newPaddedSize != rnd_up(currentAxisSize, 16) || - ((isInputLowBroadcasted || isOutputHighBroadcasted) && axisSize != currentAxisSize)) { - DnnlBlockedMemoryDesc weightsDataDesc(Shape(VectorDims{newPaddedSize}), memory::data_type::f32, memory::format_tag::x); + ((isInputLowBroadcasted || isOutputHighBroadcasted) && axisSize != currentAxisSize)) { + DnnlBlockedMemoryDesc weightsDataDesc(Shape(VectorDims{newPaddedSize}), + memory::data_type::f32, + memory::format_tag::x); constexpr size_t numBinFqIntBlob = 2; bool needUpdThr = false, needUpdMask = false; if (isInputLowBroadcasted && axisSize != currentAxisSize) { binarizationThresholds.resize(newPaddedSize); - std::fill(binarizationThresholds.begin() + 1, binarizationThresholds.begin() + axisSize, binarizationThresholds[0]); + std::fill(binarizationThresholds.begin() + 1, + binarizationThresholds.begin() + axisSize, + binarizationThresholds[0]); std::fill(binarizationThresholds.begin() + axisSize, binarizationThresholds.end(), 0.f); needUpdThr = true; } if (isOutputHighBroadcasted && axisSize != currentAxisSize) { binarizationOutputMask.resize(newPaddedSize); - std::fill(binarizationOutputMask.begin() + 1, binarizationOutputMask.begin() + axisSize, binarizationOutputMask[0]); + std::fill(binarizationOutputMask.begin() + 1, + binarizationOutputMask.begin() + axisSize, + binarizationOutputMask[0]); std::fill(binarizationOutputMask.begin() + axisSize, binarizationOutputMask.end(), 0); needUpdMask = true; } if (internalBlobMemory.empty() || needUpdThr) { - auto binarizationThresholdsDataMem = std::make_shared(getEngine(), weightsDataDesc, getBinarizationTresholdsPtr()); + auto binarizationThresholdsDataMem = + std::make_shared(getEngine(), weightsDataDesc, getBinarizationTresholdsPtr()); if (internalBlobMemory.empty()) { internalBlobMemory.push_back(binarizationThresholdsDataMem); } else { @@ -1429,7 +1559,8 @@ void FakeQuantize::prepareParams() { } if (internalBlobMemory.size() == (numBinFqIntBlob - 1) || needUpdMask) { - auto binarizationMaskDataMem = std::make_shared(getEngine(), weightsDataDesc, getBinarizationOutputMaskPtr()); + auto binarizationMaskDataMem = + std::make_shared(getEngine(), weightsDataDesc, getBinarizationOutputMaskPtr()); if (internalBlobMemory.size() == (numBinFqIntBlob - 1)) { internalBlobMemory.push_back(binarizationMaskDataMem); } else { @@ -1449,31 +1580,39 @@ void FakeQuantize::createPrimitive() { if (selectedPrimitiveDescriptor->getImplementationType() != impl_desc_type::ref) { const auto& config = getSelectedPrimitiveDescriptor()->getConfig(); - //Form FakeQuanKey + // Form FakeQuanKey FakeQuantKey key = {}; key.jqp.src_prc = config.inConfs[0].getMemDesc()->getPrecision(); key.jqp.wei_prc = ov::element::f32; key.jqp.dst_prc = config.outConfs[0].getMemDesc()->getPrecision(); - const auto &srcMemory = getParentEdgeAt(0)->getMemory(); - const auto &srcDesc = srcMemory.getDesc(); + const auto& srcMemory = getParentEdgeAt(0)->getMemory(); + const auto& srcDesc = srcMemory.getDesc(); key.jqp.is_planar = srcDesc.hasLayoutType(LayoutType::ncsp) && one_of(srcDesc.getShape().getRank(), 3u, 4u, 5u); key.jqp.op_type = getAlgorithm(); if (isBinarization()) { - const auto &inDims = srcMemory.getStaticDims(); + const auto& inDims = srcMemory.getStaticDims(); key.jqp.c = inDims.size() > 1 ? inDims[1] : 1; } else { - // in case of blocked layout we need to extend vectors to prevent read from unallocated memory - size_t paddedSize = srcDesc.hasLayoutType(LayoutType::nCsp16c) ? 16 : srcDesc.hasLayoutType(LayoutType::nCsp8c) ? 
8 : 1; + // in case of blocked layout we need to extend vectors to prevent read from unallocated memory + size_t paddedSize = srcDesc.hasLayoutType(LayoutType::nCsp16c) ? 16 + : srcDesc.hasLayoutType(LayoutType::nCsp8c) ? 8 + : 1; if (paddedSize != 1) { - if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) cropLow.resize(rnd_up(cropLow.size(), paddedSize)); - if (!broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) cropHigh.resize(rnd_up(cropHigh.size(), paddedSize)); - if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) inputScale.resize(rnd_up(inputScale.size(), paddedSize)); - if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) inputShift.resize(rnd_up(inputShift.size(), paddedSize)); - if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) outputScale.resize(rnd_up(outputScale.size(), paddedSize)); - if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)]) outputShift.resize(rnd_up(outputShift.size(), paddedSize)); + if (!broadcasted[static_cast(FQ_add_input_type::CROP_LOW)]) + cropLow.resize(rnd_up(cropLow.size(), paddedSize)); + if (!broadcasted[static_cast(FQ_add_input_type::CROP_HIGH)]) + cropHigh.resize(rnd_up(cropHigh.size(), paddedSize)); + if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SCALE)]) + inputScale.resize(rnd_up(inputScale.size(), paddedSize)); + if (!broadcasted[static_cast(FQ_add_input_type::INPUT_SHIFT)]) + inputShift.resize(rnd_up(inputShift.size(), paddedSize)); + if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SCALE)]) + outputScale.resize(rnd_up(outputScale.size(), paddedSize)); + if (!broadcasted[static_cast(FQ_add_input_type::OUTPUT_SHIFT)]) + outputShift.resize(rnd_up(outputShift.size(), paddedSize)); } key.jqp.broadcasted = broadcasted; @@ -1530,11 +1669,10 @@ void FakeQuantize::executeReference() { parallel_nd(N, CB, D, H, W, [&](dim_t n, dim_t cb, dim_t d, dim_t h, dim_t w) { uint8_t bin_val = 0x00; for (int c = cb * nbits, shift = 0; c < std::min(static_cast(C), (cb + 1) * nbits); c++, shift++) { - size_t src_off = srcDims.size() == 4 ? - n * s_str[0] + c * s_str[1] + h * s_str[2] + w * s_str[3] : - srcDims.size() == 5 ? - n * s_str[0] + c * s_str[1] + d * s_str[2] + h * s_str[3] + w * s_str[4] : - n * s_str[0] + c * s_str[1]; + size_t src_off = srcDims.size() == 4 ? n * s_str[0] + c * s_str[1] + h * s_str[2] + w * s_str[3] + : srcDims.size() == 5 + ? n * s_str[0] + c * s_str[1] + d * s_str[2] + h * s_str[3] + w * s_str[4] + : n * s_str[0] + c * s_str[1]; float val = src[src_off]; float thr = thresholds[c]; @@ -1546,11 +1684,10 @@ void FakeQuantize::executeReference() { bin_val |= (bit << shift); } - size_t dst_off = dstDims.size() == 4 ? - n * d_str[0] + (cb * nbits) * d_str[1] + h * d_str[2] + w * d_str[3] : - dstDims.size() == 5 ? - n * d_str[0] + (cb * nbits) * d_str[1] + d * d_str[2] + h * d_str[3] + w * d_str[4] : - n * d_str[0] + (cb * nbits) * d_str[1]; + size_t dst_off = dstDims.size() == 4 ? n * d_str[0] + (cb * nbits) * d_str[1] + h * d_str[2] + w * d_str[3] + : dstDims.size() == 5 + ? n * d_str[0] + (cb * nbits) * d_str[1] + d * d_str[2] + h * d_str[3] + w * d_str[4] + : n * d_str[0] + (cb * nbits) * d_str[1]; dst[dst_off / nbits] = bin_val; }); @@ -1558,46 +1695,44 @@ void FakeQuantize::executeReference() { auto dst = dstMemory->getDataAs(); parallel_nd(N, C, D, H, W, [&](dim_t n, dim_t c, dim_t d, dim_t h, dim_t w) { - size_t src_off = srcDims.size() == 5 ? - n * s_str[0] + c * s_str[1] + d * s_str[2] + h * s_str[3] + w * s_str[4] : - srcDims.size() == 4 ? 
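// ---- editor's aside: illustrative sketch, not part of the patch ----
// The rank-dependent ternary chains around here compute a flat element offset
// from an index tuple and a stride vector; the loop form below (hypothetical
// helper) shows the arithmetic the unrolled branches implement:
#include <cstddef>
#include <vector>

static inline size_t flat_offset_sketch(const std::vector<size_t>& idx, const std::vector<size_t>& strides) {
    size_t off = 0;
    for (size_t d = 0; d < idx.size(); ++d)
        off += idx[d] * strides[d];  // n * s[0] + c * s[1] + ...
    return off;
}
// The hand-unrolled ternaries avoid a loop inside the hot parallel_nd body at
// the cost of one branch chain per supported tensor rank.
// ---- end aside ----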
- n * s_str[0] + c * s_str[1] + h * s_str[2] + w * s_str[3] : - srcDims.size() == 3 ? - n * s_str[0] + c * s_str[1] + h * s_str[2] : - srcDims.size() == 2 ? - n * s_str[0] + c * s_str[1] : - n * s_str[0]; + size_t src_off = srcDims.size() == 5 + ? n * s_str[0] + c * s_str[1] + d * s_str[2] + h * s_str[3] + w * s_str[4] + : srcDims.size() == 4 ? n * s_str[0] + c * s_str[1] + h * s_str[2] + w * s_str[3] + : srcDims.size() == 3 ? n * s_str[0] + c * s_str[1] + h * s_str[2] + : srcDims.size() == 2 ? n * s_str[0] + c * s_str[1] + : n * s_str[0]; float src_val = src[src_off]; int wei_idx = getAxis() == 0 ? n : c; float cl = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_LOW)] ? cropLow[0] : cropLow[wei_idx]; float ch = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? cropHigh[0] : cropHigh[wei_idx]; - float isc = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? inputScale[0] : inputScale[wei_idx]; - float ish = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? inputShift[0] : inputShift[wei_idx]; - float osc = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? outputScale[0] : outputScale[wei_idx]; - float osh = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? outputShift[0] : outputShift[wei_idx]; + float isc = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? inputScale[0] : inputScale[wei_idx]; + float ish = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? inputShift[0] : inputShift[wei_idx]; + float osc = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? outputScale[0] + : outputScale[wei_idx]; + float osh = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? outputShift[0] + : outputShift[wei_idx]; float dst_val = nstl::min(ch, nstl::max(cl, src_val)); dst_val = dst_val * isc + ish; dst_val = roundf(dst_val); dst_val = dst_val * osc + osh; - size_t dst_off = dstDims.size() == 5 ? - n * d_str[0] + c * d_str[1] + d * d_str[2] + h * d_str[3] + w * d_str[4] : - dstDims.size() == 4 ? - n * d_str[0] + c * d_str[1] + h * d_str[2] + w * d_str[3] : - dstDims.size() == 3 ? - n * d_str[0] + c * d_str[1] + h * d_str[2] : - dstDims.size() == 2 ? - n * d_str[0] + c * d_str[1] : - n * d_str[0]; + size_t dst_off = dstDims.size() == 5 + ? n * d_str[0] + c * d_str[1] + d * d_str[2] + h * d_str[3] + w * d_str[4] + : dstDims.size() == 4 ? n * d_str[0] + c * d_str[1] + h * d_str[2] + w * d_str[3] + : dstDims.size() == 3 ? n * d_str[0] + c * d_str[1] + h * d_str[2] + : dstDims.size() == 2 ? n * d_str[0] + c * d_str[1] + : n * d_str[0]; dst[dst_off] = dst_val; }); } } -void FakeQuantize::executeBinarization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { +void FakeQuantize::executeBinarization(const std::unique_ptr<jit_uni_quantize_kernel>& pKernel) const { #if defined(OPENVINO_ARCH_X86_64) auto srcMemory = getSrcMemoryAtPort(0); auto dstMemory = getDstMemoryAtPort(0); @@ -1628,8 +1763,8 @@ void FakeQuantize::executeBinarization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { -void FakeQuantize::executeQuantization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { +void FakeQuantize::executeQuantization(const std::unique_ptr<jit_uni_quantize_kernel>& pKernel) const { #if defined(OPENVINO_ARCH_X86_64) auto srcMemory = getSrcMemoryAtPort(0); auto dstMemory = getDstMemoryAtPort(0); @@ -1651,10 +1786,11 @@ void FakeQuantize::executeQuantization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { - const auto &jqp = pKernel->jqp_; + const auto& jqp = pKernel->jqp_; auto src_type_size = jqp.src_prc.size(); auto dst_type_size = jqp.dst_prc.size(); @@ -1691,15 +1827,20 @@ void FakeQuantize::executeQuantization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { arg.crop_low = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_LOW)] ? &cropLow[0] : &cropLow[c]; - arg.crop_high = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? &cropHigh[0] : &cropHigh[c]; - arg.input_scale = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? &inputScale[0] : &inputScale[c]; - arg.input_shift = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? &inputShift[0] : &inputShift[c]; - arg.output_scale = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? &outputScale[0] : &outputScale[c]; - arg.output_shift = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? &outputShift[0] : &outputShift[c]; - - arg.src_step = (size_t) blk_size * src_type_size; - arg.dst_step = (size_t) blk_size * dst_type_size; - arg.block_size = (size_t) blk_size; + arg.crop_high = + broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? &cropHigh[0] : &cropHigh[c]; + arg.input_scale = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? &inputScale[0] : &inputScale[c]; + arg.input_shift = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? &inputShift[0] : &inputShift[c]; + arg.output_scale = + broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? &outputScale[0] : &outputScale[c]; + arg.output_shift = + broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? &outputShift[0] : &outputShift[c]; + + arg.src_step = (size_t)blk_size * src_type_size; + arg.dst_step = (size_t)blk_size * dst_type_size; + arg.block_size = (size_t)blk_size; arg.work_amount = (size_t)H; (*pKernel)(&arg); @@ -1714,22 +1855,27 @@ void FakeQuantize::executeQuantization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { arg.crop_low = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_LOW)] ? &cropLow[0] : &cropLow[c]; - arg.crop_high = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? &cropHigh[0] : &cropHigh[c]; - arg.input_scale = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? &inputScale[0] : &inputScale[c]; - arg.input_shift = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? &inputShift[0] : &inputShift[c]; - arg.output_scale = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? &outputScale[0] : &outputScale[c]; - arg.output_shift = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? &outputShift[0] : &outputShift[c]; - - arg.src_step = is_blk_format ? (size_t) blk_size * src_type_size : (size_t) C * src_type_size; - arg.dst_step = is_blk_format ? (size_t) blk_size * dst_type_size : (size_t) C * dst_type_size; - arg.block_size = is_blk_format ? (size_t) blk_size : nstl::min(blk_size, C - c); + arg.crop_high = + broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? &cropHigh[0] : &cropHigh[c]; + arg.input_scale = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? &inputScale[0] : &inputScale[c]; + arg.input_shift = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? &inputShift[0] : &inputShift[c]; + arg.output_scale = + broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? &outputScale[0] : &outputScale[c]; + arg.output_shift = + broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? &outputShift[0] : &outputShift[c]; + + arg.src_step = is_blk_format ? (size_t)blk_size * src_type_size : (size_t)C * src_type_size; + arg.dst_step = is_blk_format ? (size_t)blk_size * dst_type_size : (size_t)C * dst_type_size; + arg.block_size = is_blk_format ? (size_t)blk_size : nstl::min(blk_size, C - c); arg.work_amount = (size_t)std::min(static_cast<dim_t>(batch_size), H * W - b * batch_size); (*pKernel)(&arg); @@ -1740,25 +1886,29 @@ void FakeQuantize::executeQuantization(const std::unique_ptr<jit_uni_quantize_kernel> &pKernel) const { arg.crop_low = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_LOW)] ? &cropLow[0] : &cropLow[c]; - arg.crop_high = broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? &cropHigh[0] : &cropHigh[c]; - arg.input_scale = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? &inputScale[0] : &inputScale[c]; - arg.input_shift = broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? &inputShift[0] : &inputShift[c]; - arg.output_scale = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? &outputScale[0] : &outputScale[c]; - arg.output_shift = broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? &outputShift[0] : &outputShift[c]; - - arg.src_step = is_blk_format ? (size_t) blk_size * src_type_size : (size_t) C * src_type_size; - arg.dst_step = is_blk_format ? (size_t) blk_size * dst_type_size : (size_t) C * dst_type_size; - arg.block_size = (is_blk_format && srcDims.size() != 2) ? (size_t) blk_size : nstl::min(blk_size, C - c); - arg.work_amount = (size_t) W; + arg.crop_high = + broadcasted[static_cast<size_t>(FQ_add_input_type::CROP_HIGH)] ? &cropHigh[0] : &cropHigh[c]; + arg.input_scale = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SCALE)] ? &inputScale[0] : &inputScale[c]; + arg.input_shift = + broadcasted[static_cast<size_t>(FQ_add_input_type::INPUT_SHIFT)] ? &inputShift[0] : &inputShift[c]; + arg.output_scale = + broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SCALE)] ? &outputScale[0] : &outputScale[c]; + arg.output_shift = + broadcasted[static_cast<size_t>(FQ_add_input_type::OUTPUT_SHIFT)] ? &outputShift[0] : &outputShift[c]; + + arg.src_step = is_blk_format ? (size_t)blk_size * src_type_size : (size_t)C * src_type_size; + arg.dst_step = is_blk_format ? (size_t)blk_size * dst_type_size : (size_t)C * dst_type_size; + arg.block_size = (is_blk_format && srcDims.size() != 2) ? (size_t)blk_size : nstl::min(blk_size, C - c); + arg.work_amount = (size_t)W; (*pKernel)(&arg); }); @@ -1778,7 +1928,7 @@ void FakeQuantize::execute(dnnl::stream strm) { } } -void FakeQuantize::initializePostOpData(const VectorDims &dims, const size_t bufferAlignment, bool doRounding) { +void FakeQuantize::initializePostOpData(const VectorDims& dims, const size_t bufferAlignment, bool doRounding) { if (postOpDataVersion == parameterVersion) return; @@ -1789,11 +1939,15 @@ void FakeQuantize::initializePostOpData(const VectorDims &dims, const size_t bufferAlignment, bool doRounding) { binarizationOutputMask.resize(axisPaddedSize, 0); if (isInputLowBroadcasted) { - std::fill(binarizationThresholds.begin() + 1, binarizationThresholds.begin() + realAxisSize, binarizationThresholds[0]); + std::fill(binarizationThresholds.begin() + 1, + binarizationThresholds.begin() + realAxisSize, + binarizationThresholds[0]); std::fill(binarizationThresholds.begin() + realAxisSize, binarizationThresholds.end(), 0.f); } if (isOutputHighBroadcasted) { - std::fill(binarizationOutputMask.begin() + 1, binarizationOutputMask.begin() + realAxisSize, binarizationOutputMask[0]); + std::fill(binarizationOutputMask.begin() + 1, + binarizationOutputMask.begin() + realAxisSize, + binarizationOutputMask[0]); std::fill(binarizationThresholds.begin() + realAxisSize, binarizationThresholds.end(), 0.f); } } else { @@ -1803,7 +1957,7 @@ void FakeQuantize::initializePostOpData(const VectorDims &dims, const size_t bufferAlignment, bool doRounding) { postOpDataVersion = parameterVersion; } -void FakeQuantize::initializePostOpDataLegacy(const VectorDims &dims, const size_t bufferAlignment) { +void FakeQuantize::initializePostOpDataLegacy(const VectorDims& dims, const size_t bufferAlignment) { if (legacyPostOpDataVersion == parameterVersion) return; @@ -1815,11 +1969,15 @@ void
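// ---- editor's aside: illustrative sketch, not part of the patch ----
// The reference path above applies the classic FakeQuantize formula per
// element; isolated as a free function (hypothetical helper, same math):
#include <algorithm>
#include <cmath>

static inline float fake_quantize_ref_sketch(float src, float cl, float ch,
                                             float isc, float ish,
                                             float osc, float osh) {
    float v = std::min(ch, std::max(cl, src));  // crop to the input range
    v = v * isc + ish;                          // map onto the quantization grid
    v = std::round(v);                          // snap to an integer level
    return v * osc + osh;                       // map back to the output range
}
// ---- end aside ----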
FakeQuantize::initializePostOpDataLegacy(const VectorDims &dims, const size_t bufferAlignment) { binarizationOutputMask.resize(axisPaddedSize, 0); if (isInputLowBroadcasted) { - std::fill(binarizationThresholds.begin() + 1, binarizationThresholds.begin() + realAxisSize, binarizationThresholds[0]); + std::fill(binarizationThresholds.begin() + 1, + binarizationThresholds.begin() + realAxisSize, + binarizationThresholds[0]); std::fill(binarizationThresholds.begin() + realAxisSize, binarizationThresholds.end(), 0.f); } if (isOutputHighBroadcasted) { - std::fill(binarizationOutputMask.begin() + 1, binarizationOutputMask.begin() + realAxisSize, binarizationOutputMask[0]); + std::fill(binarizationOutputMask.begin() + 1, + binarizationOutputMask.begin() + realAxisSize, + binarizationOutputMask[0]); std::fill(binarizationThresholds.begin() + realAxisSize, binarizationThresholds.end(), 0.f); } @@ -1839,7 +1997,10 @@ void FakeQuantize::initializePostOpDataLegacy(const VectorDims &dims, const size_t bufferAlignment) { legacyPostOpDataVersion = parameterVersion; } -void FakeQuantize::appendMemory(const size_t dataSize, const void *data, MemoryPtr &memPtr, std::vector<MemoryPtr>& postOpsMem) { +void FakeQuantize::appendMemory(const size_t dataSize, + const void* data, + MemoryPtr& memPtr, + std::vector<MemoryPtr>& postOpsMem) { if (!memPtr) { DnnlBlockedMemoryDesc memoryDesc(ov::element::f32, {dataSize}); memPtr = std::make_shared<Memory>(getEngine(), memoryDesc, data); @@ -1848,12 +2009,15 @@ void FakeQuantize::appendMemory(const size_t dataSize, const void *data, MemoryPtr &memPtr, std::vector<MemoryPtr>& postOpsMem) { } } -void FakeQuantize::appendMemory(const size_t dataSize, const void *data, MemoryPtr &memPtr, std::vector<const void*>& postOpsMem) { +void FakeQuantize::appendMemory(const size_t dataSize, + const void* data, + MemoryPtr& memPtr, + std::vector<const void*>& postOpsMem) { postOpsMem.push_back(data); } template <typename T> -void FakeQuantize::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem) { +void FakeQuantize::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector<T>& postOpsMem) { // try to map fakeQuantizeNode using output scale & eltwise first // if failed, fallback to append_quantization() @@ -1865,21 +2029,40 @@ void FakeQuantize::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem) { initializePostOpDataLegacy(postOpDims, bufferAlignment); if (getAlgorithm() == Algorithm::FQBinarization) { - ops.append_binarization(dnnl::algorithm::binarization_depthwise, (const float*)&binarizationThresholds[0], (const float*)&binarizationOutputMask[0]); + ops.append_binarization(dnnl::algorithm::binarization_depthwise, + (const float*)&binarizationThresholds[0], + (const float*)&binarizationOutputMask[0]); } else { - dnnl::algorithm alg = getAlgorithm() == Algorithm::FQQuantization ? dnnl::algorithm::quantization_quantize : - dnnl::algorithm::quantization_quantize_dequantize; + dnnl::algorithm alg = getAlgorithm() == Algorithm::FQQuantization + ?
dnnl::algorithm::quantization_quantize + : dnnl::algorithm::quantization_quantize_dequantize; - std::array<bool, 6> per_channel = {cropLowSize > 1, cropHighSize > 1, inputScaleSize > 1, - inputShiftSize > 1, outputScaleSize > 1, outputShiftSize > 1}; + std::array<bool, 6> per_channel = {cropLowSize > 1, + cropHighSize > 1, + inputScaleSize > 1, + inputShiftSize > 1, + outputScaleSize > 1, + outputShiftSize > 1}; std::array<bool, 6> all_default = {false}; - all_default[0] = std::all_of(cropLow.cbegin(), cropLow.cend(), [](float val){ return val == 0.f; }); - all_default[1] = std::all_of(cropHigh.cbegin(), cropHigh.cend(), [](float val){ return val == 0.f; }); - all_default[2] = std::all_of(inputScale.cbegin(), inputScale.cend(), [](float val){ return val == 1.f; }); - all_default[3] = std::all_of(inputShift.cbegin(), inputShift.cend(), [](float val){ return val == 0.f; }); - all_default[4] = std::all_of(outputScale.cbegin(), outputScale.cend(), [](float val){ return val == 1.f; }); - all_default[5] = std::all_of(outputShift.cbegin(), outputShift.cend(), [](float val){ return val == 0.f; }); + all_default[0] = std::all_of(cropLow.cbegin(), cropLow.cend(), [](float val) { + return val == 0.f; + }); + all_default[1] = std::all_of(cropHigh.cbegin(), cropHigh.cend(), [](float val) { + return val == 0.f; + }); + all_default[2] = std::all_of(inputScale.cbegin(), inputScale.cend(), [](float val) { + return val == 1.f; + }); + all_default[3] = std::all_of(inputShift.cbegin(), inputShift.cend(), [](float val) { + return val == 0.f; + }); + all_default[4] = std::all_of(outputScale.cbegin(), outputScale.cend(), [](float val) { + return val == 1.f; + }); + all_default[5] = std::all_of(outputShift.cbegin(), outputShift.cend(), [](float val) { + return val == 0.f; + }); std::array<size_t, 6> offsets = {0}; offsets[1] = offsets[0] + cropLowSize; @@ -1894,7 +2077,9 @@ void FakeQuantize::appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<T>& postOpsMem) { } } -void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::unordered_map<int, MemoryPtr>& postOpsMem, +void FakeQuantize::appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::unordered_map<int, MemoryPtr>& postOpsMem, const int channelAxis) { std::vector<MemoryPtr> postOpsMemPtrs; appendPostOpsImpl(ops, postOpDims, postOpsMemPtrs); @@ -1906,7 +2091,9 @@ void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::unordered_map<int, MemoryPtr>& postOpsMem, } } -void FakeQuantize::appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector<MemoryPtr>& postOpsMem, +void FakeQuantize::appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector<MemoryPtr>& postOpsMem, const int channelAxis) { appendPostOpsImpl(ops, postOpDims, postOpsMem); } @@ -1957,7 +2144,7 @@ void FakeQuantize::updateOptimizedFormula(bool do_rounding) { // per-channel FQ.
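// ---- editor's aside: illustrative sketch, not part of the patch ----
// updateOptimizedFormula() collapses per-channel vectors to scalars when all
// channels agree within a small tolerance; the isPerTensor(...) test used
// below is equivalent to (hypothetical helper):
#include <cmath>
#include <vector>

static bool is_per_tensor_sketch(const std::vector<float>& v, float ref, float eps) {
    for (float x : v)
        if (std::fabs(x - ref) > eps)
            return false;  // some channel deviates: keep per-channel data
    return true;
}
// With eps = 0.00005f, channel values that differ only by float rounding are
// still treated as a single per-tensor parameter.
// ---- end aside ----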
if (isPerTensor(inputShift, inputShift[0], 0.00005f)) { f.ish.resize(OC); - for (auto & v : f.ish) + for (auto& v : f.ish) v = inputShift[0]; } else { f.ish = inputShift; } @@ -2115,7 +2302,7 @@ bool FakeQuantize::appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, return true; } -FakeQuantize::FakeQuantizeJitExecutor::FakeQuantizeJitExecutor(const jit_quantize_params &_jqp) { +FakeQuantize::FakeQuantizeJitExecutor::FakeQuantizeJitExecutor(const jit_quantize_params& _jqp) { #if defined(OPENVINO_ARCH_X86_64) bool isBinarization = _jqp.op_type == Algorithm::FQBinarization; if (mayiuse(cpu::x64::avx512_core)) { @@ -2157,6 +2344,6 @@ bool FakeQuantize::created() const { return getType() == Type::FakeQuantize; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/fake_quantize.h b/src/plugins/intel_cpu/src/nodes/fake_quantize.h index 62aea6092451a6..af34c0b91a1a7a 100644 --- a/src/plugins/intel_cpu/src/nodes/fake_quantize.h +++ b/src/plugins/intel_cpu/src/nodes/fake_quantize.h @@ -4,25 +4,17 @@ #pragma once -#include "common/primitive_attr.hpp" -#include "node.h" - #include + +#include "common/primitive_attr.hpp" #include "dnnl_postops_composer_legacy.h" +#include "node.h" namespace ov { namespace intel_cpu { namespace node { -enum class FQ_add_input_type { - CROP_LOW, - CROP_HIGH, - INPUT_SCALE, - INPUT_SHIFT, - OUTPUT_SCALE, - OUTPUT_SHIFT, - INPUTS_SIZE -}; +enum class FQ_add_input_type { CROP_LOW, CROP_HIGH, INPUT_SCALE, INPUT_SHIFT, OUTPUT_SCALE, OUTPUT_SHIFT, INPUTS_SIZE }; struct jit_quantize_params { bool is_planar; @@ -33,8 +25,8 @@ struct jit_quantize_params { Algorithm op_type; - int c; // need only for binarization - std::bitset<static_cast<size_t>(FQ_add_input_type::INPUTS_SIZE)> broadcasted; // need only for quantization + int c; // need only for binarization + std::bitset<static_cast<size_t>(FQ_add_input_type::INPUTS_SIZE)> broadcasted; // need only for quantization }; struct jit_quantize_call_args { @@ -57,9 +49,9 @@ struct jit_quantize_call_args { }; struct jit_uni_quantize_kernel { - void (*ker_)(const jit_quantize_call_args *); + void (*ker_)(const jit_quantize_call_args*); - void operator()(const jit_quantize_call_args *args) { + void operator()(const jit_quantize_call_args* args) { assert(ker_); ker_(args); } @@ -82,58 +74,116 @@ class FakeQuantize : public Node { void execute(dnnl::stream strm) override; void executeDynamicImpl(dnnl::stream strm) override; - size_t getAxis() const { return axis; } + size_t getAxis() const { + return axis; + } - bool isBinarization() const { return getAlgorithm() == Algorithm::FQBinarization; } + bool isBinarization() const { + return getAlgorithm() == Algorithm::FQBinarization; + } bool needPrepareParams() const override; void prepareParams() override; void createPrimitive() override; - const float* getBinarizationTresholdsPtr() const { return &binarizationThresholds[0]; } - const float* getBinarizationOutputMaskPtr() const { return reinterpret_cast<const float*>(&binarizationOutputMask[0]); } - size_t getBinarizationTresholdsSize() const { return binarizationThresholds.size(); } - size_t getBinarizationOutputMaskSize() const { return binarizationOutputMask.size(); } + const float* getBinarizationTresholdsPtr() const { + return &binarizationThresholds[0]; + } + const float* getBinarizationOutputMaskPtr() const { + return reinterpret_cast<const float*>(&binarizationOutputMask[0]); + } + size_t getBinarizationTresholdsSize() const { + return binarizationThresholds.size(); + } + size_t
getBinarizationOutputMaskSize() const { + return binarizationOutputMask.size(); + } - const std::vector& getCropLow() const { return cropLow; } - const std::vector& getCropHigh() const { return cropHigh; } - const std::vector& getInputScale() const { return inputScale; } - const std::vector& getInputShift() const { return inputShift; } - const std::vector& getOutputScale() const { return outputScale; } - const std::vector& getOutputShift() const { return outputShift; } - const size_t getLevels() const { return levels; } + const std::vector& getCropLow() const { + return cropLow; + } + const std::vector& getCropHigh() const { + return cropHigh; + } + const std::vector& getInputScale() const { + return inputScale; + } + const std::vector& getInputShift() const { + return inputShift; + } + const std::vector& getOutputScale() const { + return outputScale; + } + const std::vector& getOutputShift() const { + return outputShift; + } + const size_t getLevels() const { + return levels; + } void setCropLow(std::vector newCropLow) { - cropLow = std::move(newCropLow); cropLowSize = cropLow.size(); ++parameterVersion; + cropLow = std::move(newCropLow); + cropLowSize = cropLow.size(); + ++parameterVersion; } void setCropHigh(std::vector newCropHigh) { - cropHigh = std::move(newCropHigh); cropHighSize = cropHigh.size(); ++parameterVersion; + cropHigh = std::move(newCropHigh); + cropHighSize = cropHigh.size(); + ++parameterVersion; } void setInputScale(std::vector newInputScale) { - inputScale = std::move(newInputScale); inputScaleSize = inputScale.size(); ++parameterVersion; + inputScale = std::move(newInputScale); + inputScaleSize = inputScale.size(); + ++parameterVersion; } void setInputShift(std::vector newInputShift) { - inputShift = std::move(newInputShift); inputShiftSize = inputShift.size(); ++parameterVersion; + inputShift = std::move(newInputShift); + inputShiftSize = inputShift.size(); + ++parameterVersion; } void setOutputScale(std::vector newOutputScale) { - outputScale = std::move(newOutputScale); outputScaleSize = outputScale.size(); ++parameterVersion; + outputScale = std::move(newOutputScale); + outputScaleSize = outputScale.size(); + ++parameterVersion; } void setOutputShift(std::vector newOutputShift) { - outputShift = std::move(newOutputShift); outputShiftSize = outputShift.size(); ++parameterVersion; + outputShift = std::move(newOutputShift); + outputShiftSize = outputShift.size(); + ++parameterVersion; } - const std::vector& getFQScales() const { return fqScales; } + const std::vector& getFQScales() const { + return fqScales; + } - bool isInputLowBroadcast() const { return isInputLowBroadcasted; } - bool isInputHighBroadcast() const { return isInputHighBroadcasted; } - bool isOutputLowBroadcast() const { return isOutputLowBroadcasted; } - bool isOutputHighBroadcast() const { return isOutputHighBroadcasted; } + bool isInputLowBroadcast() const { + return isInputLowBroadcasted; + } + bool isInputHighBroadcast() const { + return isInputHighBroadcasted; + } + bool isOutputLowBroadcast() const { + return isOutputLowBroadcasted; + } + bool isOutputHighBroadcast() const { + return isOutputHighBroadcasted; + } - ov::element::Type getInputPrecision() const { return inputPrecision; } - ov::element::Type getOutputPrecision() const { return outputPrecision; } + ov::element::Type getInputPrecision() const { + return inputPrecision; + } + ov::element::Type getOutputPrecision() const { + return outputPrecision; + } - void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, 
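// ---- editor's aside: illustrative sketch, not part of the patch ----
// Every setter above moves the new vector in, caches its size, and bumps
// parameterVersion; consumers such as initializePostOpData() compare that
// counter against a locally stored version to recompute lazily. The pattern
// in isolation (hypothetical types):
#include <cstddef>
#include <utility>
#include <vector>

struct VersionedParamSketch {
    std::vector<float> values;
    size_t version = 0;
    void set(std::vector<float> v) {
        values = std::move(v);
        ++version;  // invalidates every cache keyed on the old version
    }
};

struct DerivedCacheSketch {
    size_t seenVersion = static_cast<size_t>(-1);
    void refreshIfNeeded(const VersionedParamSketch& p) {
        if (seenVersion == p.version)
            return;  // parameters unchanged, keep the cached result
        // ... recompute derived data from p.values ...
        seenVersion = p.version;
    }
};
// ---- end aside ----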
std::unordered_map& postOpsMem, const int channelAxis = 1) override; - void appendPostOps(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem, const int channelAxis = 1) override; + void appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::unordered_map& postOpsMem, + const int channelAxis = 1) override; + void appendPostOps(dnnl::post_ops& ops, + const VectorDims& postOpDims, + std::vector& postOpsMem, + const int channelAxis = 1) override; bool appendAttrPostOps(DnnlPostOpsComposerLegacy& dnnlpoc, bool isLastPostOp, dnnl::memory::data_type outDataType, @@ -143,12 +193,14 @@ class FakeQuantize : public Node { static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; enum BroadcastingPolicy { - PerChannel, // all FQ operations are per channel - PerTensor, // all FQ operations are per tensor - Mixed, // some per channel, some per tensor + PerChannel, // all FQ operations are per channel + PerTensor, // all FQ operations are per tensor + Mixed, // some per channel, some per tensor }; - BroadcastingPolicy getBroadcastingPolicy() const { return broadcastingPolicy; } + BroadcastingPolicy getBroadcastingPolicy() const { + return broadcastingPolicy; + } MemoryPtr cropLowMemory; MemoryPtr cropHighMemory; @@ -165,22 +217,22 @@ class FakeQuantize : public Node { using executorPtr = std::shared_ptr; executorPtr execPtr = nullptr; struct FakeQuantizeJitExecutor : public FakeQuantizeExecutor { - FakeQuantizeJitExecutor(const jit_quantize_params &_jqp); + FakeQuantizeJitExecutor(const jit_quantize_params& _jqp); void exec(const FakeQuantize& node) override; std::unique_ptr pKernel; }; void init() override; std::vector getDataFormats() const; - void initializePostOpData(const VectorDims &postOpDims, const size_t bufferAlignment, bool doRounding); - void initializePostOpDataLegacy(const VectorDims &dims, const size_t bufferAlignment); + void initializePostOpData(const VectorDims& postOpDims, const size_t bufferAlignment, bool doRounding); + void initializePostOpDataLegacy(const VectorDims& dims, const size_t bufferAlignment); void executeReference(); - void executeBinarization(const std::unique_ptr &pKernel) const; - void executeQuantization(const std::unique_ptr &pKernel) const; + void executeBinarization(const std::unique_ptr& pKernel) const; + void executeQuantization(const std::unique_ptr& pKernel) const; - void appendMemory(const size_t dataSize, const void *data, MemoryPtr &memPtr, std::vector& postOpsMem); - void appendMemory(const size_t dataSize, const void *data, MemoryPtr &memPtr, std::vector& postOpsMem); + void appendMemory(const size_t dataSize, const void* data, MemoryPtr& memPtr, std::vector& postOpsMem); + void appendMemory(const size_t dataSize, const void* data, MemoryPtr& memPtr, std::vector& postOpsMem); template - void appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims &postOpDims, std::vector& postOpsMem); + void appendPostOpsImpl(dnnl::post_ops& ops, const VectorDims& postOpDims, std::vector& postOpsMem); size_t levels = 0; @@ -273,6 +325,6 @@ class FakeQuantize : public Node { BroadcastingPolicy broadcastingPolicy; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp index 0f5c46e8bcd7cd..2df6c0ae7522cc 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp +++ 
b/src/plugins/intel_cpu/src/nodes/fullyconnected.cpp @@ -13,6 +13,7 @@ #include "cpu_types.h" #include "dnnl_extension_utils.h" #include "executors/memory_arguments.hpp" +#include "fake_quantize.h" #include "graph_context.h" #include "input.h" #include "memory_desc/blocked_memory_desc.h" @@ -24,17 +25,15 @@ #include "openvino/core/type/element_type.hpp" #include "openvino/runtime/threading/cpu_message.hpp" #include "ov_ops/fully_connected.hpp" +#include "ov_ops/fully_connected_compressed.hpp" #include "ov_ops/fully_connected_quantized.hpp" #include "ov_ops/fully_connected_quantized_legacy.hpp" -#include "ov_ops/fully_connected_compressed.hpp" #include "post_ops.hpp" #include "shape_inference/custom/fullyconnected.hpp" #include "transformations/utils/utils.hpp" #include "utils/debug_capabilities.h" #include "utils/general_utils.h" -#include "fake_quantize.h" - using namespace dnnl; using namespace ov::element; @@ -61,7 +60,8 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr& if (ov::is_type(op)) { if (!ov::op::util::is_on_constant_path(op->input_value(WEIGHT_SCALES)) || !ov::op::util::is_on_constant_path(op->input_value(WEIGHT_ZERO_POINTS))) { - errorMessage = "Only Constant operation on 'weight scales', and 'weight zero points' inputs is supported"; + errorMessage = + "Only Constant operation on 'weight scales', and 'weight zero points' inputs is supported"; return false; } } @@ -137,19 +137,18 @@ FullyConnected::FullyConnected(const std::shared_ptr& op, const GraphC if (!isSupportedOperation(op, errorMessage)) OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); - m_atoi[ARG_SRC] = DATA; - m_atoi[ARG_WEI] = WEIGHTS; + m_atoi[ARG_SRC] = DATA; + m_atoi[ARG_WEI] = WEIGHTS; m_atoi[ARG_BIAS] = BIAS; auto mapArgToInput = [&op](std::unordered_map& argToInput, size_t argId, size_t inputId) { - if (op->get_input_size() > inputId && - op->input(inputId).get_element_type() != ov::element::undefined) { + if (op->get_input_size() > inputId && op->input(inputId).get_element_type() != ov::element::undefined) { argToInput[argId] = inputId; } }; if (ov::is_type(op)) { - mapArgToInput(m_atoi, ARG_WEI | ARG_ATTR_SCALES, WEIGHT_SCALES); + mapArgToInput(m_atoi, ARG_WEI | ARG_ATTR_SCALES, WEIGHT_SCALES); mapArgToInput(m_atoi, ARG_WEI | ARG_ATTR_ZERO_POINTS, WEIGHT_ZERO_POINTS); algorithm = Algorithm::FullyConnectedCompressed; } else if (ov::is_type(op)) { @@ -190,7 +189,8 @@ void FullyConnected::needPrepareParamsForTensorParallel() { dim += dims.size(); } OPENVINO_ASSERT(static_cast(dims[dim]) >= tp_cfg.w_size, - getName() + " dim[" + std::to_string(dim) + "] is " + std::to_string(dims[dim]) + ", which is larger than w_size " + std::to_string(tp_cfg.w_size)); + getName() + " dim[" + std::to_string(dim) + "] is " + std::to_string(dims[dim]) + + ", which is larger than w_size " + std::to_string(tp_cfg.w_size)); auto splited_dim_vec = split_parts(dims[dim], tp_cfg.w_size); VectorDims new_dims = std::move(dims); @@ -269,18 +269,34 @@ void FullyConnected::execTensorParallelSync() { for (int idx = 0; idx < tp_cfg.w_size; idx++) { if (wait_list[idx] > 0 && tp_cfg.sub_memory->_memorys_table[tp_cfg.id][idx].flag) { auto new_ptr = static_cast(tp_cfg.sub_memory->_memorys_table[tp_cfg.id][idx].send_buf); - const auto copySize = splited_dim_vec[idx] * prec.size(); // bytes of half selected dim. + const auto copySize = splited_dim_vec[idx] * prec.size(); // bytes of half selected dim. 
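// ---- editor's aside: illustrative sketch, not part of the patch ----
// The tensor-parallel sync below copies each rank's partial result into the
// destination with an 8x manually unrolled strided copy plus a scalar tail.
// Compact form of the same structure (assuming cpu_memcpy acts like memcpy):
#include <cstddef>
#include <cstring>

static void strided_copy_sketch(unsigned char* dst, const unsigned char* src,
                                size_t rows, size_t dst_stride, size_t src_stride, size_t bytes) {
    constexpr size_t unroll = 8;
    const size_t main = rows & ~(unroll - 1);  // largest multiple of 8 <= rows
    for (size_t i = 0; i < main; ++i)          // the node spreads this over threads
        std::memcpy(dst + i * dst_stride, src + i * src_stride, bytes);
    for (size_t i = main; i < rows; ++i)       // leftover rows
        std::memcpy(dst + i * dst_stride, src + i * src_stride, bytes);
}
// ---- end aside ----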
const size_t unloop = 8; size_t step = count / unloop; - parallel_for(step, [&](size_t i){ - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop) * channel_size, new_ptr + (i * unloop) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 1) * channel_size, new_ptr + (i * unloop + 1) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 2) * channel_size, new_ptr + (i * unloop + 2) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 3) * channel_size, new_ptr + (i * unloop + 3) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 4) * channel_size, new_ptr + (i * unloop + 4) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 5) * channel_size, new_ptr + (i * unloop + 5) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 6) * channel_size, new_ptr + (i * unloop + 6) * copySize, copySize); - cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 7) * channel_size, new_ptr + (i * unloop + 7) * copySize, copySize); + parallel_for(step, [&](size_t i) { + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop) * channel_size, + new_ptr + (i * unloop) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 1) * channel_size, + new_ptr + (i * unloop + 1) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 2) * channel_size, + new_ptr + (i * unloop + 2) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 3) * channel_size, + new_ptr + (i * unloop + 3) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 4) * channel_size, + new_ptr + (i * unloop + 4) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 5) * channel_size, + new_ptr + (i * unloop + 5) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 6) * channel_size, + new_ptr + (i * unloop + 6) * copySize, + copySize); + cpu_memcpy(dst_ptr + idx * strideSize + (i * unloop + 7) * channel_size, + new_ptr + (i * unloop + 7) * copySize, + copySize); }); size_t tail = count & ~(unloop - 1); for (size_t i = tail; i < count; ++i) { @@ -525,8 +541,10 @@ void FullyConnected::needSplitMemoryForTensorParallel() { memory[ARG_SRC] = getSrcMemoryAtPort(DATA); // wgt // split N direction - tp_cfg.cached_splited_weight = attrs.weightsNonTransposed ? split_vertical(context->getEngine(), std::move(wgt), 0, tp_cfg.w_rank, tp_cfg.w_size) - : split_horizontal(context->getEngine(), std::move(wgt), 0, tp_cfg.w_rank, tp_cfg.w_size); + tp_cfg.cached_splited_weight = + attrs.weightsNonTransposed + ? 
split_vertical(context->getEngine(), std::move(wgt), 0, tp_cfg.w_rank, tp_cfg.w_size) + : split_horizontal(context->getEngine(), std::move(wgt), 0, tp_cfg.w_rank, tp_cfg.w_size); memory[ARG_WEI] = tp_cfg.cached_splited_weight; // bias if (attrs.withBias) { @@ -539,21 +557,27 @@ void FullyConnected::needSplitMemoryForTensorParallel() { memory[ARG_BIAS] = tp_cfg.cached_splited_bias; // dst memory[ARG_DST] = getDstMemoryAtPort(0); - tp_cfg.cached_dst = split_horizontal(context->getEngine(), std::move(dst), -1, tp_cfg.w_rank, tp_cfg.w_size, false); + tp_cfg.cached_dst = + split_horizontal(context->getEngine(), std::move(dst), -1, tp_cfg.w_rank, tp_cfg.w_size, false); - memory[ARG_DST | ARG_ATTR_SCALES] = split_horizontal(context->getEngine(), memory[ARG_DST | ARG_ATTR_SCALES], 0, tp_cfg.w_rank, tp_cfg.w_size); + memory[ARG_DST | ARG_ATTR_SCALES] = + split_horizontal(context->getEngine(), memory[ARG_DST | ARG_ATTR_SCALES], 0, tp_cfg.w_rank, tp_cfg.w_size); auto scale_mem = std::const_pointer_cast(memory[ARG_WEI | ARG_ATTR_SCALES]); - memory[ARG_WEI | ARG_ATTR_SCALES] = attrs.weightsNonTransposed ? split_vertical(context->getEngine(), scale_mem, 0, tp_cfg.w_rank, tp_cfg.w_size) - : split_horizontal(context->getEngine(), scale_mem, 0, tp_cfg.w_rank, tp_cfg.w_size); + memory[ARG_WEI | ARG_ATTR_SCALES] = + attrs.weightsNonTransposed + ? split_vertical(context->getEngine(), scale_mem, 0, tp_cfg.w_rank, tp_cfg.w_size) + : split_horizontal(context->getEngine(), scale_mem, 0, tp_cfg.w_rank, tp_cfg.w_size); auto zeropoint_mem = std::const_pointer_cast(memory[ARG_WEI | ARG_ATTR_ZERO_POINTS]); auto element_num = zeropoint_mem->getSize() / zeropoint_mem->getPrecision().size(); if (element_num == 1) { tp_cfg.cached_zeropoint = zeropoint_mem; } else { - tp_cfg.cached_zeropoint = attrs.weightsNonTransposed ? split_vertical(context->getEngine(), zeropoint_mem, 0, tp_cfg.w_rank, tp_cfg.w_size) - : split_horizontal(context->getEngine(), zeropoint_mem, 0, tp_cfg.w_rank, tp_cfg.w_size); + tp_cfg.cached_zeropoint = + attrs.weightsNonTransposed + ? 
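// ---- editor's aside: illustrative sketch, not part of the patch ----
// Weights, scales and zero points are all sliced along the output-channel
// dimension; whether that is a vertical or a horizontal cut of the stored
// matrix depends only on weightsNonTransposed. One common way to shard a
// length among ranks (split_parts may distribute the remainder differently):
#include <cstddef>

static void shard_1d_sketch(size_t total, size_t rank, size_t world, size_t& begin, size_t& len) {
    const size_t base = total / world;
    const size_t rem = total % world;
    len = base + (rank < rem ? 1 : 0);            // low ranks absorb the remainder
    begin = rank * base + (rank < rem ? rank : rem);
}
// ---- end aside ----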
split_vertical(context->getEngine(), zeropoint_mem, 0, tp_cfg.w_rank, tp_cfg.w_size) + : split_horizontal(context->getEngine(), zeropoint_mem, 0, tp_cfg.w_rank, tp_cfg.w_size); } } } diff --git a/src/plugins/intel_cpu/src/nodes/fullyconnected.h b/src/plugins/intel_cpu/src/nodes/fullyconnected.h index 177edd3d426339..0b50d882c9e554 100644 --- a/src/plugins/intel_cpu/src/nodes/fullyconnected.h +++ b/src/plugins/intel_cpu/src/nodes/fullyconnected.h @@ -15,8 +15,8 @@ #include "cpu_memory.h" #include "nodes/executors/executor_factory.hpp" -#include "nodes/executors/memory_arguments.hpp" #include "nodes/executors/fullyconnected_config.hpp" +#include "nodes/executors/memory_arguments.hpp" #include "post_ops.hpp" namespace ov { @@ -105,7 +105,7 @@ class FullyConnected : public Node { static bool isConstantInput(const std::shared_ptr& op, InputId port); - std::unordered_map m_atoi; // memory argument id to input id + std::unordered_map m_atoi; // memory argument id to input id void fuseDecompressionConstant(const MemoryCPtr& memory, MemoryCPtr& decompressionValuesPtr); diff --git a/src/plugins/intel_cpu/src/nodes/gather.h b/src/plugins/intel_cpu/src/nodes/gather.h index 6ee097e9a1fbab..c20a56807b0165 100644 --- a/src/plugins/intel_cpu/src/nodes/gather.h +++ b/src/plugins/intel_cpu/src/nodes/gather.h @@ -5,12 +5,13 @@ #pragma once #include -#include "kernels/x64/gather_uni_kernel.hpp" #include #include #include +#include "kernels/x64/gather_uni_kernel.hpp" + namespace ov { namespace intel_cpu { namespace node { @@ -19,7 +20,7 @@ class Gather : public Node { public: Gather(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void createPrimitive() override; void execute(dnnl::stream strm) override; @@ -115,6 +116,6 @@ class Gather : public Node { std::shared_ptr jitKernel; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/gather_elements.cpp b/src/plugins/intel_cpu/src/nodes/gather_elements.cpp index 8653bda8c483d3..d8f221dcebf34d 100644 --- a/src/plugins/intel_cpu/src/nodes/gather_elements.cpp +++ b/src/plugins/intel_cpu/src/nodes/gather_elements.cpp @@ -2,23 +2,25 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "gather_elements.h" + #include -#include #include +#include + +#include "common/cpu_memcpy.h" #include "openvino/core/parallel.hpp" -#include "gather_elements.h" #include "openvino/opsets/opset1.hpp" #include "utils/general_utils.h" -#include "common/cpu_memcpy.h" namespace ov { namespace intel_cpu { namespace node { -bool GatherElements::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool GatherElements::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { - if (!one_of(op->get_type_info(), - ov::op::v6::GatherElements::get_type_info_static())) { + if (!one_of(op->get_type_info(), ov::op::v6::GatherElements::get_type_info_static())) { errorMessage = "Node is not an instance of the GatherElements operation from operation set v6."; return false; } @@ -88,8 +90,7 @@ void GatherElements::initSupportedPrimitiveDescriptors() { dataTypeSize_ = inDataPrecision.size(); - addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision}, - {LayoutType::ncsp, ov::element::i32}}, + addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision}, 
{LayoutType::ncsp, ov::element::i32}}, {{LayoutType::ncsp, inDataPrecision}}, impl_desc_type::ref_any); } @@ -100,9 +101,9 @@ void GatherElements::executeDynamicImpl(dnnl::stream strm) { template <typename dataType> void GatherElements::directExecution() { - const auto *srcData = getSrcDataAtPortAs<const dataType>(dataIndex_); - const auto *indices = getSrcDataAtPortAs<const int>(indicesIndex_); - auto *dstData = getDstDataAtPortAs<dataType>(0); + const auto* srcData = getSrcDataAtPortAs<const dataType>(dataIndex_); + const auto* indices = getSrcDataAtPortAs<const int>(indicesIndex_); + auto* dstData = getDstDataAtPortAs<dataType>(0); const int outSize = getChildEdgeAt(0)->getMemory().getShape().getElementsCount(); auto threadBody = [&](const int ithr, const int nthr) { @@ -133,14 +134,14 @@ void GatherElements::directExecution() { void GatherElements::execute(dnnl::stream strm) { switch (dataTypeSize_) { - case sizeof(element_type_traits<ov::element::i32>::value_type): - return directExecution<element_type_traits<ov::element::i32>::value_type>(); - case sizeof(element_type_traits<ov::element::i16>::value_type): - return directExecution<element_type_traits<ov::element::i16>::value_type>(); - case sizeof(element_type_traits<ov::element::i8>::value_type): - return directExecution<element_type_traits<ov::element::i8>::value_type>(); - default: - OPENVINO_THROW("Unsupported data type size"); + case sizeof(element_type_traits<ov::element::i32>::value_type): + return directExecution<element_type_traits<ov::element::i32>::value_type>(); + case sizeof(element_type_traits<ov::element::i16>::value_type): + return directExecution<element_type_traits<ov::element::i16>::value_type>(); + case sizeof(element_type_traits<ov::element::i8>::value_type): + return directExecution<element_type_traits<ov::element::i8>::value_type>(); + default: + OPENVINO_THROW("Unsupported data type size"); } } @@ -148,6 +149,6 @@ bool GatherElements::created() const { return getType() == Type::GatherElements; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/gather_elements.h b/src/plugins/intel_cpu/src/nodes/gather_elements.h index 3c2282401f7431..b050cd4e523490 100644 --- a/src/plugins/intel_cpu/src/nodes/gather_elements.h +++ b/src/plugins/intel_cpu/src/nodes/gather_elements.h @@ -14,7 +14,7 @@ class GatherElements : public Node { public: GatherElements(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -40,6 +40,6 @@ class GatherElements : public Node { void directExecution(); }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/gather_nd.cpp b/src/plugins/intel_cpu/src/nodes/gather_nd.cpp index 8c81f9b770a687..e962839e571663 100644 --- a/src/plugins/intel_cpu/src/nodes/gather_nd.cpp +++ b/src/plugins/intel_cpu/src/nodes/gather_nd.cpp @@ -2,15 +2,17 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "gather_nd.h" + #include -#include #include +#include + +#include "common/cpu_memcpy.h" #include "dnnl_types.h" #include "openvino/core/parallel.hpp" -#include "gather_nd.h" -#include #include "utils/general_utils.h" -#include "common/cpu_memcpy.h" #define THROW_ERROR(...)
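// ---- editor's aside: illustrative sketch, not part of the patch ----
// GatherElements::execute() above dispatches on the element *size* rather
// than the element type, so one template instantiation covers every
// precision of the same width. The pattern in isolation (hypothetical names):
#include <cstddef>
#include <cstdint>

template <typename T>
static void direct_execution_sketch() { /* gather T-wide elements ... */ }

static void dispatch_by_size_sketch(size_t typeSize) {
    switch (typeSize) {
    case sizeof(uint32_t): return direct_execution_sketch<uint32_t>();  // f32/i32/u32
    case sizeof(uint16_t): return direct_execution_sketch<uint16_t>();  // f16/bf16/i16
    case sizeof(uint8_t):  return direct_execution_sketch<uint8_t>();   // i8/u8
    default: break;  // unsupported width
    }
}
// ---- end aside ----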
OPENVINO_THROW("GatherND layer with name '", getName(), "' ", __VA_ARGS__) @@ -20,7 +22,9 @@ namespace node { bool GatherND::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - if (!one_of(op->get_type_info(), ov::op::v5::GatherND::get_type_info_static(), ov::op::v8::GatherND::get_type_info_static())) { + if (!one_of(op->get_type_info(), + ov::op::v5::GatherND::get_type_info_static(), + ov::op::v8::GatherND::get_type_info_static())) { errorMessage = "Node is not an instance of the GatherND operation from operation set v5 and v8."; return false; } @@ -70,12 +74,16 @@ void GatherND::initSupportedPrimitiveDescriptors() { ov::element::Type indicesPrecision = getOriginalInputPrecisionAtPort(GATHERND_INDEXES); if (!one_of(indicesPrecision, - ov::element::i32, ov::element::i64, ov::element::i16, ov::element::u16, ov::element::i8, ov::element::u8)) { + ov::element::i32, + ov::element::i64, + ov::element::i16, + ov::element::u16, + ov::element::i8, + ov::element::u8)) { THROW_ERROR("has unsupported 'indices' input precision: ", indicesPrecision); } - addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision}, - {LayoutType::ncsp, ov::element::i32}}, + addSupportedPrimDesc({{LayoutType::ncsp, inDataPrecision}, {LayoutType::ncsp, ov::element::i32}}, {{LayoutType::ncsp, inDataPrecision}}, impl_desc_type::ref_any); } @@ -96,24 +104,33 @@ void GatherND::prepareParams() { attrs.srcDims = srcMemPtr->getStaticDims(); attrs.srcStrides = srcMemPtr->getDescWithType()->getStrides(); attrs.dstElementCount = dstMemPtr->getShape().getElementsCount(); - attrs.sliceRank = idxMemPtr->getStaticDims().back(); + attrs.sliceRank = idxMemPtr->getStaticDims().back(); execPtr = std::make_shared(attrs); } -GatherND::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) : sliceRank(attrs.sliceRank), dataSize(attrs.dataSize) { - batchSize = std::accumulate(attrs.srcDims.begin(), attrs.srcDims.begin() + attrs.batchDims, size_t(1), std::multiplies()); - dataLength = std::accumulate(attrs.srcDims.begin() + sliceRank + attrs.batchDims, attrs.srcDims.end(), size_t(1), +GatherND::GatherNDExecutor::GatherNDExecutor(const GatherNDAttributes& attrs) + : sliceRank(attrs.sliceRank), + dataSize(attrs.dataSize) { + batchSize = std::accumulate(attrs.srcDims.begin(), + attrs.srcDims.begin() + attrs.batchDims, + size_t(1), + std::multiplies()); + dataLength = std::accumulate(attrs.srcDims.begin() + sliceRank + attrs.batchDims, + attrs.srcDims.end(), + size_t(1), std::multiplies()); cycles = attrs.dstElementCount / (dataLength * batchSize); workAmount = batchSize * cycles; - srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, attrs.srcDims.end(), size_t(1), + srcBatchStride = std::accumulate(attrs.srcDims.begin() + attrs.batchDims, + attrs.srcDims.end(), + size_t(1), std::multiplies()); idxBatchStride = cycles * sliceRank; dstBatchStride = cycles * dataLength; srcShifts.resize(attrs.sliceRank, 0); - for (size_t i = 0; i < attrs.sliceRank ; i++) + for (size_t i = 0; i < attrs.sliceRank; i++) srcShifts[i] = attrs.srcStrides[i + attrs.batchDims] * (dataLength > 1 ? 
dataSize : 1); // optimized implementation 'blocks' via memcpy @@ -128,25 +145,33 @@ void GatherND::execute(dnnl::stream strm) { if (!execPtr) THROW_ERROR("has not compiled executor."); - execPtr->exec(getSrcMemoryAtPort(GATHERND_DATA), - getSrcMemoryAtPort(GATHERND_INDEXES), - getDstMemoryAtPort(0)); + execPtr->exec(getSrcMemoryAtPort(GATHERND_DATA), getSrcMemoryAtPort(GATHERND_INDEXES), getDstMemoryAtPort(0)); } -void GatherND::GatherNDExecutor::exec(const MemoryPtr& srcMemPtr, const MemoryPtr& idxMemPtr, const MemoryPtr& dstMemPtr) { +void GatherND::GatherNDExecutor::exec(const MemoryPtr& srcMemPtr, + const MemoryPtr& idxMemPtr, + const MemoryPtr& dstMemPtr) { if (dataLength > 1) { gatherBlocks(srcMemPtr, idxMemPtr, dstMemPtr); return; } - GatherNDContext ctx { this, srcMemPtr, idxMemPtr, dstMemPtr }; - OV_SWITCH(intel_cpu, GatherNDEmitter, ctx, dataSize, - OV_CASE(sizeof(element_type_traits::value_type), element_type_traits::value_type), - OV_CASE(sizeof(element_type_traits::value_type), element_type_traits::value_type), - OV_CASE(sizeof(element_type_traits::value_type), element_type_traits::value_type)); + GatherNDContext ctx{this, srcMemPtr, idxMemPtr, dstMemPtr}; + OV_SWITCH(intel_cpu, + GatherNDEmitter, + ctx, + dataSize, + OV_CASE(sizeof(element_type_traits::value_type), + element_type_traits::value_type), + OV_CASE(sizeof(element_type_traits::value_type), + element_type_traits::value_type), + OV_CASE(sizeof(element_type_traits::value_type), + element_type_traits::value_type)); } -void GatherND::GatherNDExecutor::gatherBlocks(const MemoryPtr& srcMemPtr, const MemoryPtr& idxMemPtr, const MemoryPtr& dstMemPtr) { +void GatherND::GatherNDExecutor::gatherBlocks(const MemoryPtr& srcMemPtr, + const MemoryPtr& idxMemPtr, + const MemoryPtr& dstMemPtr) { const uint8_t* srcData = srcMemPtr->getDataAs(); const int32_t* indices = idxMemPtr->getDataAs(); uint8_t* dstData = dstMemPtr->getDataAs(); @@ -183,7 +208,9 @@ void GatherND::GatherNDExecutor::gatherBlocks(const MemoryPtr& srcMemPtr, const } template -void GatherND::GatherNDExecutor::gatherElementwise(const MemoryPtr& srcMemPtr, const MemoryPtr& idxMemPtr, const MemoryPtr& dstMemPtr) { +void GatherND::GatherNDExecutor::gatherElementwise(const MemoryPtr& srcMemPtr, + const MemoryPtr& idxMemPtr, + const MemoryPtr& dstMemPtr) { const dataType* srcData = srcMemPtr->getDataAs(); const int32_t* indices = idxMemPtr->getDataAs(); dataType* dstData = dstMemPtr->getDataAs(); @@ -227,6 +254,6 @@ bool GatherND::created() const { return getType() == Type::GatherND; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/gather_nd.h b/src/plugins/intel_cpu/src/nodes/gather_nd.h index ed643a2da08899..312cb465bf9e6c 100644 --- a/src/plugins/intel_cpu/src/nodes/gather_nd.h +++ b/src/plugins/intel_cpu/src/nodes/gather_nd.h @@ -14,7 +14,7 @@ class GatherND : public Node { public: GatherND(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -65,7 +65,7 @@ class GatherND : public Node { MemoryPtr dstMemPtr; }; - template + template struct GatherNDEmitter { void operator()(GatherNDContext& ctx) { ctx.executor->gatherElementwise(ctx.srcMemPtr, ctx.idxMemPtr, ctx.dstMemPtr); @@ -80,6 +80,6 @@ class GatherND : public 
Node { executorPtr execPtr = nullptr; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/gather_tree.cpp b/src/plugins/intel_cpu/src/nodes/gather_tree.cpp index 5834cd1e1048ba..2ff9a1ccdb8f59 100644 --- a/src/plugins/intel_cpu/src/nodes/gather_tree.cpp +++ b/src/plugins/intel_cpu/src/nodes/gather_tree.cpp @@ -2,13 +2,14 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "openvino/op/gather_tree.hpp" + +#include #include #include -#include -#include "openvino/op/gather_tree.hpp" -#include "openvino/core/parallel.hpp" #include "gather_tree.h" +#include "openvino/core/parallel.hpp" #include "utils/general_utils.h" namespace ov { @@ -59,11 +60,11 @@ void GatherTree::initSupportedPrimitiveDescriptors() { if (!one_of(precision, ov::element::f32, ov::element::i32)) precision = ov::element::f32; - if (getOriginalInputPrecisionAtPort(GATHER_TREE_PARENT_IDX) != precision || + if (getOriginalInputPrecisionAtPort(GATHER_TREE_PARENT_IDX) != precision || getOriginalInputPrecisionAtPort(GATHER_TREE_MAX_SEQ_LEN) != precision || - getOriginalInputPrecisionAtPort(GATHER_TREE_END_TOKEN) != precision || - getOriginalOutputPrecisionAtPort(0) != precision) { - OPENVINO_THROW(errorPrefix, " has incorrect input/output data precision. Must be the same."); + getOriginalInputPrecisionAtPort(GATHER_TREE_END_TOKEN) != precision || + getOriginalOutputPrecisionAtPort(0) != precision) { + OPENVINO_THROW(errorPrefix, " has incorrect input/output data precision. Must be the same."); } addSupportedPrimDesc({{LayoutType::ncsp, precision}, @@ -121,13 +122,15 @@ void GatherTree::executeDynamicImpl(dnnl::stream strm) { execute(strm); } -GatherTree::GatherTreeExecutor::GatherTreeExecutor(const VectorDims& stepIdxDims, const VectorDims& parentIdxDims, - const VectorDims& maxSeqLenDims, const VectorDims& dstDims) - : maxTime{static_cast(stepIdxDims[0])} - , batchSize{stepIdxDims[1]} - , beamWidth{stepIdxDims[2]} - , bbSize{batchSize * beamWidth} - , parentIdxSize{std::accumulate(parentIdxDims.cbegin(), parentIdxDims.cend(), 1lu, std::multiplies())} { +GatherTree::GatherTreeExecutor::GatherTreeExecutor(const VectorDims& stepIdxDims, + const VectorDims& parentIdxDims, + const VectorDims& maxSeqLenDims, + const VectorDims& dstDims) + : maxTime{static_cast(stepIdxDims[0])}, + batchSize{stepIdxDims[1]}, + beamWidth{stepIdxDims[2]}, + bbSize{batchSize * beamWidth}, + parentIdxSize{std::accumulate(parentIdxDims.cbegin(), parentIdxDims.cend(), 1lu, std::multiplies())} { if (maxTime != static_cast(parentIdxDims[0]) || maxTime != static_cast(dstDims[0]) || batchSize != parentIdxDims[1] || batchSize != dstDims[1] || batchSize != maxSeqLenDims[0] || beamWidth != parentIdxDims[2] || beamWidth != dstDims[2]) { @@ -136,14 +139,17 @@ GatherTree::GatherTreeExecutor::GatherTreeExecutor(const VectorDims& stepIdxDims } } -template -void GatherTree::GatherTreeExecutor::exec(const MemoryPtr& stepIdxMemPtr, const MemoryPtr& parentIdxMemPtr, - const MemoryPtr& maxSeqLenMemPtr, const MemoryPtr& endTokenMemPtr, const MemoryPtr& dstMemPtr) { - const auto *stepIdx = stepIdxMemPtr->getDataAs(); - const auto *parentIdx = parentIdxMemPtr->getDataAs(); - const auto *maxSeqLen = maxSeqLenMemPtr->getDataAs(); +template +void GatherTree::GatherTreeExecutor::exec(const MemoryPtr& stepIdxMemPtr, + const MemoryPtr& parentIdxMemPtr, + const MemoryPtr& maxSeqLenMemPtr, + const MemoryPtr& endTokenMemPtr, + const MemoryPtr& dstMemPtr) { + 
const auto* stepIdx = stepIdxMemPtr->getDataAs(); + const auto* parentIdx = parentIdxMemPtr->getDataAs(); + const auto* maxSeqLen = maxSeqLenMemPtr->getDataAs(); const auto endToken = (endTokenMemPtr->getDataAs())[0]; - auto *finalIdx = dstMemPtr->getDataAs(); + auto* finalIdx = dstMemPtr->getDataAs(); bool incorrectResult = false; parallel_for2d(batchSize, beamWidth, [&](size_t batch, size_t beam) { @@ -164,7 +170,7 @@ void GatherTree::GatherTreeExecutor::exec(const MemoryPtr& stepIdxMemPtr, const } bool finished = false; - auto *final = &finalIdx[batch * beamWidth + beam]; + auto* final = &finalIdx[batch * beamWidth + beam]; for (time = 0; time < maxSequenceInBeam; time++, final += bbSize) { if (finished) (*final) = endToken; @@ -184,6 +190,6 @@ bool GatherTree::created() const { return getType() == Type::GatherTree; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/gather_tree.h b/src/plugins/intel_cpu/src/nodes/gather_tree.h index 69d63f834b555d..9874fceb835ba5 100644 --- a/src/plugins/intel_cpu/src/nodes/gather_tree.h +++ b/src/plugins/intel_cpu/src/nodes/gather_tree.h @@ -14,7 +14,7 @@ class GatherTree : public Node { public: GatherTree(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -32,7 +32,7 @@ class GatherTree : public Node { const VectorDims& dstDims); ~GatherTreeExecutor() = default; - template + template void exec(const MemoryPtr& stepIdxMemPtr, const MemoryPtr& parentIdxMemPtr, const MemoryPtr& maxSeqLenMemPtr, @@ -60,6 +60,6 @@ class GatherTree : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/generate_proposals.cpp b/src/plugins/intel_cpu/src/nodes/generate_proposals.cpp index ae32e1e4729096..0ed50c7b0d73a8 100644 --- a/src/plugins/intel_cpu/src/nodes/generate_proposals.cpp +++ b/src/plugins/intel_cpu/src/nodes/generate_proposals.cpp @@ -2,22 +2,22 @@ // SPDX-License-Identifier: Apache-2.0 // -#include +#include #include #include +#include #include -#include #include -#include +#include #if defined(HAVE_AVX2) -#include +# include #endif -#include "openvino/op/generate_proposals.hpp" -#include "openvino/core/parallel.hpp" #include "common/cpu_memcpy.h" #include "generate_proposals.h" +#include "openvino/core/parallel.hpp" +#include "openvino/op/generate_proposals.hpp" #include "shape_inference/shape_inference_internal_dyn.hpp" namespace ov { @@ -30,21 +30,29 @@ struct Indexer4d { int dim23_; int dim123_; - explicit Indexer4d(int dim0, int dim1, int dim2, int dim3): - dim3_(dim3), dim23_(dim2 * dim3), dim123_(dim1 * dim2 * dim3) { + explicit Indexer4d(int dim0, int dim1, int dim2, int dim3) + : dim3_(dim3), + dim23_(dim2 * dim3), + dim123_(dim1 * dim2 * dim3) { (void)dim0; } int operator()(int i, int j, int k, int n) const { - return i * dim123_ + j * dim23_ + k * dim3_ + n; + return i * dim123_ + j * dim23_ + k * dim3_ + n; } }; - -void refine_anchors(const float* deltas, const float* scores, const float* anchors, - float* proposals, const int anchors_num, const int bottom_H, - const int bottom_W, const float img_H, const float img_W, - const float 
min_box_H, const float min_box_W, +void refine_anchors(const float* deltas, + const float* scores, + const float* anchors, + float* proposals, + const int anchors_num, + const int bottom_H, + const int bottom_W, + const float img_H, + const float img_W, + const float min_box_H, + const float min_box_W, const float max_delta_log_wh, float coordinates_offset) { Indexer4d delta_idx(anchors_num, 4, bottom_H, bottom_W); @@ -111,18 +119,23 @@ void refine_anchors(const float* deltas, const float* scores, const float* ancho void unpack_boxes(const float* p_proposals, float* unpacked_boxes, int* is_dead, int pre_nms_topn) { parallel_for(pre_nms_topn, [&](size_t i) { - unpacked_boxes[0*pre_nms_topn + i] = p_proposals[6*i + 0]; - unpacked_boxes[1*pre_nms_topn + i] = p_proposals[6*i + 1]; - unpacked_boxes[2*pre_nms_topn + i] = p_proposals[6*i + 2]; - unpacked_boxes[3*pre_nms_topn + i] = p_proposals[6*i + 3]; - unpacked_boxes[4*pre_nms_topn + i] = p_proposals[6*i + 4]; - is_dead[i] = (p_proposals[6*i + 5] == 1.0) ? 0 : 1; + unpacked_boxes[0 * pre_nms_topn + i] = p_proposals[6 * i + 0]; + unpacked_boxes[1 * pre_nms_topn + i] = p_proposals[6 * i + 1]; + unpacked_boxes[2 * pre_nms_topn + i] = p_proposals[6 * i + 2]; + unpacked_boxes[3 * pre_nms_topn + i] = p_proposals[6 * i + 3]; + unpacked_boxes[4 * pre_nms_topn + i] = p_proposals[6 * i + 4]; + is_dead[i] = (p_proposals[6 * i + 5] == 1.0) ? 0 : 1; }); } -void nms_cpu(const int num_boxes, int is_dead[], - const float* boxes, int index_out[], size_t* const num_out, - const int base_index, const float nms_thresh, const int max_num_out, +void nms_cpu(const int num_boxes, + int is_dead[], + const float* boxes, + int index_out[], + size_t* const num_out, + const int base_index, + const float nms_thresh, + const int max_num_out, float coordinates_offset) { const int num_proposals = num_boxes; size_t count = 0; @@ -133,9 +146,9 @@ void nms_cpu(const int num_boxes, int is_dead[], const float* y1 = boxes + 3 * num_proposals; #if defined(HAVE_AVX2) - __m256 vc_fone = _mm256_set1_ps(coordinates_offset); + __m256 vc_fone = _mm256_set1_ps(coordinates_offset); __m256i vc_ione = _mm256_set1_epi32(1); - __m256 vc_zero = _mm256_set1_ps(0.0f); + __m256 vc_zero = _mm256_set1_ps(0.0f); __m256 vc_nms_thresh = _mm256_set1_ps(nms_thresh); #endif @@ -156,13 +169,13 @@ void nms_cpu(const int num_boxes, int is_dead[], __m256 vx1i = _mm256_set1_ps(x1[box]); __m256 vy1i = _mm256_set1_ps(y1[box]); - __m256 vA_width = _mm256_sub_ps(vx1i, vx0i); + __m256 vA_width = _mm256_sub_ps(vx1i, vx0i); __m256 vA_height = _mm256_sub_ps(vy1i, vy0i); - __m256 vA_area = _mm256_mul_ps(_mm256_add_ps(vA_width, vc_fone), _mm256_add_ps(vA_height, vc_fone)); + __m256 vA_area = _mm256_mul_ps(_mm256_add_ps(vA_width, vc_fone), _mm256_add_ps(vA_height, vc_fone)); for (; tail <= num_boxes - 8; tail += 8) { - __m256i *pdst = reinterpret_cast<__m256i*>(is_dead + tail); - __m256i vdst = _mm256_loadu_si256(pdst); + __m256i* pdst = reinterpret_cast<__m256i*>(is_dead + tail); + __m256i vdst = _mm256_loadu_si256(pdst); __m256 vx0j = _mm256_loadu_ps(x0 + tail); __m256 vy0j = _mm256_loadu_ps(y0 + tail); @@ -174,13 +187,13 @@ void nms_cpu(const int num_boxes, int is_dead[], __m256 vx1 = _mm256_min_ps(vx1i, vx1j); __m256 vy1 = _mm256_min_ps(vy1i, vy1j); - __m256 vwidth = _mm256_add_ps(_mm256_sub_ps(vx1, vx0), vc_fone); + __m256 vwidth = _mm256_add_ps(_mm256_sub_ps(vx1, vx0), vc_fone); __m256 vheight = _mm256_add_ps(_mm256_sub_ps(vy1, vy0), vc_fone); __m256 varea = _mm256_mul_ps(_mm256_max_ps(vc_zero, vwidth), 
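// Context for the nms_cpu() hunks: the AVX2 block here evaluates, eight boxes per
// iteration, the same IoU test that the scalar tail loop below it performs one box
// at a time. A standalone scalar sketch of that test (illustrative helper, not part
// of this patch; `offset` plays the role of coordinates_offset):

#include <algorithm>

inline float iou(float x0i, float y0i, float x1i, float y1i,
                 float x0j, float y0j, float x1j, float y1j,
                 float offset) {
    const float x0 = std::max(x0i, x0j);  // intersection rectangle
    const float y0 = std::max(y0i, y0j);
    const float x1 = std::min(x1i, x1j);
    const float y1 = std::min(y1i, y1j);
    const float w = std::max(0.0f, x1 - x0 + offset);
    const float h = std::max(0.0f, y1 - y0 + offset);
    const float inter = w * h;
    const float areaA = (x1i - x0i + offset) * (y1i - y0i + offset);
    const float areaB = (x1j - x0j + offset) * (y1j - y0j + offset);
    return inter / (areaA + areaB - inter);  // box j is suppressed when this > nms_thresh
}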
_mm256_max_ps(vc_zero, vheight)); - __m256 vB_width = _mm256_sub_ps(vx1j, vx0j); + __m256 vB_width = _mm256_sub_ps(vx1j, vx0j); __m256 vB_height = _mm256_sub_ps(vy1j, vy0j); - __m256 vB_area = _mm256_mul_ps(_mm256_add_ps(vB_width, vc_fone), _mm256_add_ps(vB_height, vc_fone)); + __m256 vB_area = _mm256_mul_ps(_mm256_add_ps(vB_width, vc_fone), _mm256_add_ps(vB_height, vc_fone)); __m256 vdivisor = _mm256_sub_ps(_mm256_add_ps(vA_area, vB_area), varea); __m256 vintersection_area = _mm256_div_ps(varea, vdivisor); @@ -221,9 +234,9 @@ void nms_cpu(const int num_boxes, int is_dead[], const float y1 = std::min(y1i, y1j); // intersection area - const float width = std::max(0.0f, x1 - x0 + coordinates_offset); - const float height = std::max(0.0f, y1 - y0 + coordinates_offset); - const float area = width * height; + const float width = std::max(0.0f, x1 - x0 + coordinates_offset); + const float height = std::max(0.0f, y1 - y0 + coordinates_offset); + const float area = width * height; // area of A, B const float A_area = (x1i - x0i + coordinates_offset) * (y1i - y0i + coordinates_offset); @@ -241,16 +254,20 @@ void nms_cpu(const int num_boxes, int is_dead[], *num_out = count; } - -void fill_output_blobs(const float* proposals, const int* roi_indices, - float* rois, float* scores, uint8_t* roi_num, - const int num_proposals, const size_t num_rois, const int post_nms_topn, +void fill_output_blobs(const float* proposals, + const int* roi_indices, + float* rois, + float* scores, + uint8_t* roi_num, + const int num_proposals, + const size_t num_rois, + const int post_nms_topn, ov::element::Type roi_num_type) { - const float *src_x0 = proposals + 0 * num_proposals; - const float *src_y0 = proposals + 1 * num_proposals; - const float *src_x1 = proposals + 2 * num_proposals; - const float *src_y1 = proposals + 3 * num_proposals; - const float *src_score = proposals + 4 * num_proposals; + const float* src_x0 = proposals + 0 * num_proposals; + const float* src_y0 = proposals + 1 * num_proposals; + const float* src_x1 = proposals + 2 * num_proposals; + const float* src_y1 = proposals + 3 * num_proposals; + const float* src_score = proposals + 4 * num_proposals; parallel_for(num_rois, [&](size_t i) { int index = roi_indices[i]; @@ -274,8 +291,8 @@ void fill_output_blobs(const float* proposals, const int* roi_indices, } // namespace -bool GenerateProposals::isSupportedOperation - (const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool GenerateProposals::isSupportedOperation(const std::shared_ptr& op, + std::string& errorMessage) noexcept { try { if (!ov::as_type_ptr(op)) { errorMessage = "Node is not an instance of the Proposal from the operations set v0."; @@ -332,13 +349,13 @@ void GenerateProposals::execute(dnnl::stream strm) { } size_t anchor_dims_size = 1; - const auto &anchorDims = getParentEdgeAt(INPUT_ANCHORS)->getMemory().getStaticDims(); + const auto& anchorDims = getParentEdgeAt(INPUT_ANCHORS)->getMemory().getStaticDims(); for (size_t i = 0; i < anchorDims.size(); i++) { anchor_dims_size *= anchorDims[i]; } size_t deltas_dims_size = 1; - const auto &deltaDims = getParentEdgeAt(INPUT_DELTAS)->getMemory().getStaticDims(); + const auto& deltaDims = getParentEdgeAt(INPUT_DELTAS)->getMemory().getStaticDims(); for (size_t i = 1; i < deltaDims.size(); i++) { deltas_dims_size *= deltaDims[i]; } @@ -346,7 +363,7 @@ void GenerateProposals::execute(dnnl::stream strm) { OPENVINO_THROW("'Anchors' blob size for GenerateProposals is incompatible with 'deltas' blob size!"); size_t score_dims_size = 1; - 
const auto &scoreDims = getParentEdgeAt(INPUT_SCORES)->getMemory().getStaticDims(); + const auto& scoreDims = getParentEdgeAt(INPUT_SCORES)->getMemory().getStaticDims(); for (size_t i = 1; i < scoreDims.size(); i++) { score_dims_size *= scoreDims[i]; } @@ -354,16 +371,16 @@ void GenerateProposals::execute(dnnl::stream strm) { OPENVINO_THROW("'Deltas' blob size for GenerateProposals is incompatible with 'scores' blob size!"); size_t im_info_dims_size = 1; - const auto &infoDims = getParentEdgeAt(INPUT_IM_INFO)->getMemory().getStaticDims(); + const auto& infoDims = getParentEdgeAt(INPUT_IM_INFO)->getMemory().getStaticDims(); for (size_t i = 1; i < infoDims.size(); i++) { im_info_dims_size *= infoDims[i]; } // Prepare memory - const float *p_deltas_item = getSrcDataAtPortAs(INPUT_DELTAS); - const float *p_scores_item = getSrcDataAtPortAs(INPUT_SCORES); - const float *p_anchors_item = getSrcDataAtPortAs(INPUT_ANCHORS); - const float *p_img_info_cpu = getSrcDataAtPortAs(INPUT_IM_INFO); + const float* p_deltas_item = getSrcDataAtPortAs(INPUT_DELTAS); + const float* p_scores_item = getSrcDataAtPortAs(INPUT_SCORES); + const float* p_anchors_item = getSrcDataAtPortAs(INPUT_ANCHORS); + const float* p_img_info_cpu = getSrcDataAtPortAs(INPUT_IM_INFO); const int anchors_num = scoreDims[1]; @@ -422,27 +439,50 @@ void GenerateProposals::execute(dnnl::stream strm) { const float min_box_H = min_size_ * scale_h; const float min_box_W = min_size_ * scale_w; - refine_anchors(p_deltas_item, p_scores_item, p_anchors_item, - reinterpret_cast(&proposals_[0]), anchors_num, bottom_H, - bottom_W, img_H, img_W, - min_box_H, min_box_W, + refine_anchors(p_deltas_item, + p_scores_item, + p_anchors_item, + reinterpret_cast(&proposals_[0]), + anchors_num, + bottom_H, + bottom_W, + img_H, + img_W, + min_box_H, + min_box_W, static_cast(std::log(1000. 
/ 16.)), coordinates_offset_); - std::partial_sort(proposals_.begin(), proposals_.begin() + pre_nms_topn, proposals_.end(), - [](const ProposalBox &struct1, const ProposalBox &struct2) { + std::partial_sort(proposals_.begin(), + proposals_.begin() + pre_nms_topn, + proposals_.end(), + [](const ProposalBox& struct1, const ProposalBox& struct2) { return (struct1.score > struct2.score); }); - unpack_boxes(reinterpret_cast(&proposals_[0]), &unpacked_boxes[0], &is_dead[0], pre_nms_topn); - nms_cpu(pre_nms_topn, &is_dead[0], &unpacked_boxes[0], &roi_indices_[0], &num_rois, 0, - nms_thresh_, post_nms_topn_, coordinates_offset_); + unpack_boxes(reinterpret_cast(&proposals_[0]), &unpacked_boxes[0], &is_dead[0], pre_nms_topn); + nms_cpu(pre_nms_topn, + &is_dead[0], + &unpacked_boxes[0], + &roi_indices_[0], + &num_rois, + 0, + nms_thresh_, + post_nms_topn_, + coordinates_offset_); size_t new_num_rois = total_num_rois + num_rois; roi_item.resize(new_num_rois * 4); score_item.resize(new_num_rois); - fill_output_blobs(&unpacked_boxes[0], &roi_indices_[0], &roi_item[total_num_rois * 4], &score_item[total_num_rois], - p_roi_num, pre_nms_topn, num_rois, post_nms_topn_, roi_num_type); + fill_output_blobs(&unpacked_boxes[0], + &roi_indices_[0], + &roi_item[total_num_rois * 4], + &score_item[total_num_rois], + p_roi_num, + pre_nms_topn, + num_rois, + post_nms_topn_, + roi_num_type); p_deltas_item += deltas_dims_size; p_scores_item += score_dims_size; p_img_info_cpu += im_info_dims_size; @@ -451,13 +491,13 @@ void GenerateProposals::execute(dnnl::stream strm) { } // copy to out memory redefineOutputMemory({VectorDims{total_num_rois, 4}, VectorDims{total_num_rois}, VectorDims{batch_size}}); - float *p_roi_item = getDstDataAtPortAs(OUTPUT_ROIS); - float *p_roi_score_item = getDstDataAtPortAs(OUTPUT_SCORES); + float* p_roi_item = getDstDataAtPortAs(OUTPUT_ROIS); + float* p_roi_score_item = getDstDataAtPortAs(OUTPUT_SCORES); uint8_t* p_roi_num_item = getDstDataAtPortAs(OUTPUT_ROI_NUM); memcpy(p_roi_item, &roi_item[0], roi_item.size() * sizeof(float)); memcpy(p_roi_score_item, &score_item[0], score_item.size() * sizeof(float)); memcpy(p_roi_num_item, &roi_num[0], getDstMemoryAtPort(OUTPUT_ROI_NUM)->getSize()); - } catch (const std::exception &e) { + } catch (const std::exception& e) { std::string errorMsg = e.what(); OPENVINO_THROW(errorMsg); } @@ -475,6 +515,6 @@ bool GenerateProposals::needPrepareParams() const { return false; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/generate_proposals.h b/src/plugins/intel_cpu/src/nodes/generate_proposals.h index 5438f30011d986..666338eed3d4aa 100644 --- a/src/plugins/intel_cpu/src/nodes/generate_proposals.h +++ b/src/plugins/intel_cpu/src/nodes/generate_proposals.h @@ -14,7 +14,7 @@ class GenerateProposals : public Node { public: GenerateProposals(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -35,13 +35,13 @@ class GenerateProposals : public Node { // scores, shape [rois_num] // roi_num, shape [N] - const int INPUT_IM_INFO {0}; - const int INPUT_ANCHORS {1}; - const int INPUT_DELTAS {2}; - const int INPUT_SCORES {3}; - const int OUTPUT_ROIS {0}; - const int OUTPUT_SCORES {1}; - const int OUTPUT_ROI_NUM {2}; + 
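// One detail worth noting in the execute() hunk above: std::partial_sort is used
// rather than std::sort because only the first pre_nms_topn proposals are consumed
// afterwards. A self-contained illustration of the pattern (simplified struct, not
// the plugin's ProposalBox):

#include <algorithm>
#include <cstddef>
#include <vector>

struct ScoredBox {
    float score;  // x0/y0/x1/y1 omitted for brevity
};

// Keeps only the n best-scoring boxes; partial_sort orders just the first n
// elements (descending score) and leaves the tail unspecified, which is cheaper
// than a full sort when n is much smaller than boxes.size().
void keep_top_n(std::vector<ScoredBox>& boxes, std::size_t n) {
    n = std::min(n, boxes.size());
    std::partial_sort(boxes.begin(), boxes.begin() + n, boxes.end(),
                      [](const ScoredBox& a, const ScoredBox& b) {
                          return a.score > b.score;
                      });
    boxes.resize(n);
}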
const int INPUT_IM_INFO{0}; + const int INPUT_ANCHORS{1}; + const int INPUT_DELTAS{2}; + const int INPUT_SCORES{3}; + const int OUTPUT_ROIS{0}; + const int OUTPUT_SCORES{1}; + const int OUTPUT_ROI_NUM{2}; float min_size_ = 0.f; int pre_nms_topn_ = 0; @@ -52,6 +52,6 @@ class GenerateProposals : public Node { std::vector roi_indices_; }; -} // namespace node +} // namespace node } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/grid_sample.cpp b/src/plugins/intel_cpu/src/nodes/grid_sample.cpp index c8b73360539b68..9f346a2db14dac 100644 --- a/src/plugins/intel_cpu/src/nodes/grid_sample.cpp +++ b/src/plugins/intel_cpu/src/nodes/grid_sample.cpp @@ -3,15 +3,16 @@ // #include "grid_sample.hpp" -#include "openvino/op/grid_sample.hpp" + #include "openvino/core/parallel.hpp" +#include "openvino/op/grid_sample.hpp" using namespace ov::intel_cpu; using namespace ov::intel_cpu::node; #if defined(OPENVINO_ARCH_X86_64) using namespace dnnl::impl::cpu; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 #define THROW_ERROR(...) OPENVINO_THROW(getTypeStr(), " node with name '", getName(), "' ", __VA_ARGS__) @@ -28,7 +29,7 @@ bool GridSample::isSupportedOperation(const std::shared_ptr& op, } #else return false; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 } catch (...) { return false; } @@ -61,30 +62,30 @@ GridSample::GridSample(const std::shared_ptr& op, const GraphContext:: const auto& attributes = ov::as_type_ptr(op)->get_attributes(); alignCorners = attributes.align_corners; switch (attributes.mode) { - case op::v9::GridSample::InterpolationMode::BILINEAR: - interpolationMode = GridSampleInterpolationMode::BILINEAR; - break; - case op::v9::GridSample::InterpolationMode::BICUBIC: - interpolationMode = GridSampleInterpolationMode::BICUBIC; - break; - case op::v9::GridSample::InterpolationMode::NEAREST: - interpolationMode = GridSampleInterpolationMode::NEAREST; - break; - default: - THROW_CPU_NODE_ERR("supports only BILINEAR, BICUBIC, NEAREST interpolation modes."); + case op::v9::GridSample::InterpolationMode::BILINEAR: + interpolationMode = GridSampleInterpolationMode::BILINEAR; + break; + case op::v9::GridSample::InterpolationMode::BICUBIC: + interpolationMode = GridSampleInterpolationMode::BICUBIC; + break; + case op::v9::GridSample::InterpolationMode::NEAREST: + interpolationMode = GridSampleInterpolationMode::NEAREST; + break; + default: + THROW_CPU_NODE_ERR("supports only BILINEAR, BICUBIC, NEAREST interpolation modes."); } switch (attributes.padding_mode) { - case op::v9::GridSample::PaddingMode::ZEROS: - paddingMode = GridSamplePaddingMode::ZEROS; - break; - case op::v9::GridSample::PaddingMode::BORDER: - paddingMode = GridSamplePaddingMode::BORDER; - break; - case op::v9::GridSample::PaddingMode::REFLECTION: - paddingMode = GridSamplePaddingMode::REFLECTION; - break; - default: - THROW_CPU_NODE_ERR("supports only BORDER, REFLECTION, ZEROS paddings modes."); + case op::v9::GridSample::PaddingMode::ZEROS: + paddingMode = GridSamplePaddingMode::ZEROS; + break; + case op::v9::GridSample::PaddingMode::BORDER: + paddingMode = GridSamplePaddingMode::BORDER; + break; + case op::v9::GridSample::PaddingMode::REFLECTION: + paddingMode = GridSamplePaddingMode::REFLECTION; + break; + default: + THROW_CPU_NODE_ERR("supports only BORDER, REFLECTION, ZEROS paddings modes."); } } @@ -107,8 +108,7 @@ void GridSample::initSupportedPrimitiveDescriptors() { } // 95905 - to add nspc layout support. 
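// Context for the GridSample::prepareParams() hunk further below: the node splits
// the dstShape[2] * dstShape[3] output plane across threads in multiples of the JIT
// kernel's SIMD width, via
//     wpt = ((totalWork / dataElPerVec) / m_threads_num + 1) * dataElPerVec
// so every thread's slice starts on a vector boundary. A standalone sketch of that
// partitioning (hypothetical helper, not the plugin's code):

#include <algorithm>
#include <cstdint>
#include <utility>

// Returns the [start, end) slice of `totalWork` elements for thread `ithr` of
// `nthr`; every boundary except possibly the last is a multiple of `vecLen`.
std::pair<uint64_t, uint64_t> split_aligned(uint64_t totalWork, uint64_t vecLen,
                                            uint64_t ithr, uint64_t nthr) {
    const uint64_t wpt = ((totalWork / vecLen) / nthr + 1) * vecLen;
    const uint64_t start = std::min(wpt * ithr, totalWork);
    const uint64_t end = std::min(start + wpt, totalWork);
    return {start, end};
}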
- addSupportedPrimDesc({{LayoutType::ncsp, dataPrecision}, - {LayoutType::ncsp, gridPrecision}}, + addSupportedPrimDesc({{LayoutType::ncsp, dataPrecision}, {LayoutType::ncsp, gridPrecision}}, {{LayoutType::ncsp, dataPrecision}}, implType); } @@ -116,25 +116,26 @@ void GridSample::initSupportedPrimitiveDescriptors() { void GridSample::createPrimitive() { kernel::GridSampleKernelConfParams jcp; - jcp.inDataPrc = dataPrecision; - jcp.gridPrc = gridPrecision; + jcp.inDataPrc = dataPrecision; + jcp.gridPrc = gridPrecision; jcp.dynamicShapes = isDynamicNode(); - jcp.alignCorners = alignCorners; + jcp.alignCorners = alignCorners; jcp.interpolationMode = interpolationMode; - jcp.paddingMode = paddingMode; + jcp.paddingMode = paddingMode; const auto& srcDataDims = getInputShapeAtPort(IN_DATA).getDims(); if (!jcp.dynamicShapes) { - jcp.batchNum = srcDataDims[0]; - jcp.cannelNum = srcDataDims[1]; - jcp.dynamicBatch = false; + jcp.batchNum = srcDataDims[0]; + jcp.cannelNum = srcDataDims[1]; + jcp.dynamicBatch = false; jcp.dynamicChannel = false; - jcp.srcBatchStepB = std::accumulate(srcDataDims.begin() + 1, srcDataDims.end(), dataTypeSize, std::multiplies()); + jcp.srcBatchStepB = + std::accumulate(srcDataDims.begin() + 1, srcDataDims.end(), dataTypeSize, std::multiplies()); } else { - jcp.dynamicBatch = srcDataDims[0] == Shape::UNDEFINED_DIM; - jcp.batchNum = jcp.dynamicBatch ? 1lu : srcDataDims[0]; + jcp.dynamicBatch = srcDataDims[0] == Shape::UNDEFINED_DIM; + jcp.batchNum = jcp.dynamicBatch ? 1lu : srcDataDims[0]; jcp.dynamicChannel = srcDataDims[1] == Shape::UNDEFINED_DIM; - jcp.cannelNum = jcp.dynamicChannel ? 1lu : srcDataDims[1]; + jcp.cannelNum = jcp.dynamicChannel ? 1lu : srcDataDims[1]; } if (x64::mayiuse(x64::avx512_core)) { @@ -195,7 +196,7 @@ void GridSample::prepareParams() { const uint64_t dataElPerVec = jitKernel->getDataElPerVec(); const auto& srcDataShape = dataMemPtr->getStaticDims(); - const auto& dstShape = dstMemPtr->getStaticDims(); + const auto& dstShape = dstMemPtr->getStaticDims(); const uint64_t totalWork = dstShape[2] * dstShape[3]; const uint64_t wpt = ((totalWork / dataElPerVec) / m_threads_num + 1) * dataElPerVec; @@ -210,26 +211,27 @@ void GridSample::prepareParams() { return; } - p.batchNum = srcDataShape[0]; - p.channelsNum = srcDataShape[1]; + p.batchNum = srcDataShape[0]; + p.channelsNum = srcDataShape[1]; p.srcHeightF[0] = srcDataShape[2]; - p.srcWidthF[0] = srcDataShape[3]; + p.srcWidthF[0] = srcDataShape[3]; p.gridStartB = dstStart * 2 * gridTypeSize; - p.dstStartB = dstStart * dataTypeSize; + p.dstStartB = dstStart * dataTypeSize; - p.srcBatchStepB = std::accumulate(srcDataShape.begin() + 1, srcDataShape.end(), dataTypeSize, std::multiplies()); + p.srcBatchStepB = + std::accumulate(srcDataShape.begin() + 1, srcDataShape.end(), dataTypeSize, std::multiplies()); p.gridBatchStepB = (dstShape[2] * dstShape[3] - p.workAmount) * 2 * gridTypeSize; - p.dstBatchStepB = (dstShape[1] * dstShape[2] * dstShape[3] - p.workAmount) * dataTypeSize; + p.dstBatchStepB = (dstShape[1] * dstShape[2] * dstShape[3] - p.workAmount) * dataTypeSize; p.srcChannelStepB = srcDataShape[2] * srcDataShape[3] * dataTypeSize; p.dstChannelStepB = dstShape[2] * dstShape[3] * dataTypeSize; p.dataTypeSize[0] = dataTypeSize; p.srcHeightSub1F[0] = p.srcHeightF[0] - 1.f; - p.srcWidthSub1F[0] = p.srcWidthF[0] - 1.f; + p.srcWidthSub1F[0] = p.srcWidthF[0] - 1.f; p.srcHeightMul2F[0] = p.srcHeightF[0] * 2.f; - p.srcWidthMul2F[0] = p.srcWidthF[0] * 2.f; + p.srcWidthMul2F[0] = p.srcWidthF[0] * 2.f; if 
(interpolationMode == GridSampleInterpolationMode::BICUBIC && srcDataShape[3] >= 4) { p.srcWidthB[0] = (srcDataShape[3] - 3) * dataTypeSize; } else { @@ -237,24 +239,24 @@ void GridSample::prepareParams() { } if (alignCorners) { p.srcHeightMul2Sub1F[0] = p.srcHeightF[0] == 1.f ? 1.f : p.srcHeightSub1F[0] * 2.f; - p.srcWidthMul2Sub1F[0] = p.srcWidthF[0] == 1.f ? 1.f : p.srcWidthSub1F[0] * 2.f; - p.wDenormCoefF[0] = (p.srcWidthF[0] - 1.f) / 2.f; + p.srcWidthMul2Sub1F[0] = p.srcWidthF[0] == 1.f ? 1.f : p.srcWidthSub1F[0] * 2.f; + p.wDenormCoefF[0] = (p.srcWidthF[0] - 1.f) / 2.f; p.hDenormCoefF[0] = (p.srcHeightF[0] - 1.f) / 2.f; } else { p.srcHeightMul2Sub1F[0] = p.srcHeightMul2F[0] - 1.f; - p.srcWidthMul2Sub1F[0] = p.srcWidthMul2F[0] - 1.f; + p.srcWidthMul2Sub1F[0] = p.srcWidthMul2F[0] - 1.f; } if (!x64::mayiuse(x64::avx512_core)) { - std::fill(p.srcHeightF.begin(), p.srcHeightF.end(), p.srcHeightF[0]); - std::fill(p.srcWidthF.begin(), p.srcWidthF.end(), p.srcWidthF[0]); - std::fill(p.dataTypeSize.begin(), p.dataTypeSize.end(), p.dataTypeSize[0]); - std::fill(p.srcHeightSub1F.begin(), p.srcHeightSub1F.end(), p.srcHeightSub1F[0]); - std::fill(p.srcWidthSub1F.begin(), p.srcWidthSub1F.end(), p.srcWidthSub1F[0]); - std::fill(p.srcHeightMul2F.begin(), p.srcHeightMul2F.end(), p.srcHeightMul2F[0]); - std::fill(p.srcWidthMul2F.begin(), p.srcWidthMul2F.end(), p.srcWidthMul2F[0]); - std::fill(p.srcWidthB.begin(), p.srcWidthB.end(), p.srcWidthB[0]); + std::fill(p.srcHeightF.begin(), p.srcHeightF.end(), p.srcHeightF[0]); + std::fill(p.srcWidthF.begin(), p.srcWidthF.end(), p.srcWidthF[0]); + std::fill(p.dataTypeSize.begin(), p.dataTypeSize.end(), p.dataTypeSize[0]); + std::fill(p.srcHeightSub1F.begin(), p.srcHeightSub1F.end(), p.srcHeightSub1F[0]); + std::fill(p.srcWidthSub1F.begin(), p.srcWidthSub1F.end(), p.srcWidthSub1F[0]); + std::fill(p.srcHeightMul2F.begin(), p.srcHeightMul2F.end(), p.srcHeightMul2F[0]); + std::fill(p.srcWidthMul2F.begin(), p.srcWidthMul2F.end(), p.srcWidthMul2F[0]); + std::fill(p.srcWidthB.begin(), p.srcWidthB.end(), p.srcWidthB[0]); std::fill(p.srcHeightMul2Sub1F.begin(), p.srcHeightMul2Sub1F.end(), p.srcHeightMul2Sub1F[0]); - std::fill(p.srcWidthMul2Sub1F.begin(), p.srcWidthMul2Sub1F.end(), p.srcWidthMul2Sub1F[0]); + std::fill(p.srcWidthMul2Sub1F.begin(), p.srcWidthMul2Sub1F.end(), p.srcWidthMul2Sub1F[0]); if (alignCorners) { std::fill(p.wDenormCoefF.begin(), p.wDenormCoefF.end(), p.wDenormCoefF[0]); std::fill(p.hDenormCoefF.begin(), p.hDenormCoefF.end(), p.hDenormCoefF[0]); @@ -264,9 +266,9 @@ void GridSample::prepareParams() { } void GridSample::execute(dnnl::stream strm) { - const void* srcData = getSrcDataAtPort(IN_DATA); + const void* srcData = getSrcDataAtPort(IN_DATA); const uint8_t* gridData = getSrcDataAtPortAs(IN_GRID); - uint8_t* dstData = getDstDataAtPortAs(0); + uint8_t* dstData = getDstDataAtPortAs(0); auto threadBody = [&](const int ithr, const int nthr) { const auto& p = execParamsPerThread[ithr]; @@ -275,30 +277,30 @@ void GridSample::execute(dnnl::stream strm) { return; } - arg.src = srcData; - arg.grid = gridData + p.gridStartB; - arg.dst = dstData + p.dstStartB; - arg.batchNum = p.batchNum; - arg.channelsNum = p.channelsNum; - arg.srcHeightF = p.srcHeightF.data(); - arg.srcWidthF = p.srcWidthF.data(); - arg.srcWidthB = p.srcWidthB.data(); - arg.srcChannelStepB = p.srcChannelStepB; - arg.dstChannelStepB = p.dstChannelStepB; - arg.srcBatchStepB = p.srcBatchStepB; - arg.gridBatchStepB = p.gridBatchStepB; - arg.dstBatchStepB = p.dstBatchStepB; - 
arg.srcHeightSub1F = p.srcHeightSub1F.data(); - arg.srcWidthSub1F = p.srcWidthSub1F.data(); - arg.srcWidthMul2F = p.srcWidthMul2F.data(); - arg.srcHeightMul2F = p.srcHeightMul2F.data(); + arg.src = srcData; + arg.grid = gridData + p.gridStartB; + arg.dst = dstData + p.dstStartB; + arg.batchNum = p.batchNum; + arg.channelsNum = p.channelsNum; + arg.srcHeightF = p.srcHeightF.data(); + arg.srcWidthF = p.srcWidthF.data(); + arg.srcWidthB = p.srcWidthB.data(); + arg.srcChannelStepB = p.srcChannelStepB; + arg.dstChannelStepB = p.dstChannelStepB; + arg.srcBatchStepB = p.srcBatchStepB; + arg.gridBatchStepB = p.gridBatchStepB; + arg.dstBatchStepB = p.dstBatchStepB; + arg.srcHeightSub1F = p.srcHeightSub1F.data(); + arg.srcWidthSub1F = p.srcWidthSub1F.data(); + arg.srcWidthMul2F = p.srcWidthMul2F.data(); + arg.srcHeightMul2F = p.srcHeightMul2F.data(); arg.srcHeightMul2Sub1F = p.srcHeightMul2Sub1F.data(); - arg.srcWidthMul2Sub1F = p.srcWidthMul2Sub1F.data(); - arg.wDenormCoefF = p.wDenormCoefF.data(); - arg.hDenormCoefF = p.hDenormCoefF.data(); - arg.dataTypeSize = p.dataTypeSize.data(); - arg.buffer = p.buffer.data(); - arg.workAmount = p.workAmount; + arg.srcWidthMul2Sub1F = p.srcWidthMul2Sub1F.data(); + arg.wDenormCoefF = p.wDenormCoefF.data(); + arg.hDenormCoefF = p.hDenormCoefF.data(); + arg.dataTypeSize = p.dataTypeSize.data(); + arg.buffer = p.buffer.data(); + arg.workAmount = p.workAmount; (*jitKernel)(&arg); }; @@ -314,4 +316,4 @@ bool GridSample::created() const { return getType() == Type::GridSample; } -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 diff --git a/src/plugins/intel_cpu/src/nodes/grid_sample.hpp b/src/plugins/intel_cpu/src/nodes/grid_sample.hpp index b4468d58be9b52..eb4fd38b64c878 100644 --- a/src/plugins/intel_cpu/src/nodes/grid_sample.hpp +++ b/src/plugins/intel_cpu/src/nodes/grid_sample.hpp @@ -5,6 +5,7 @@ #pragma once #include + #include "kernels/x64/grid_sample.hpp" namespace ov { @@ -16,35 +17,35 @@ class GridSample : public Node { GridSample(const std::shared_ptr& op, const GraphContext::CPtr context); static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void createPrimitive() override; void execute(dnnl::stream strm) override; bool created() const override; struct threadExecParams { - uint64_t batchNum = 1lu; + uint64_t batchNum = 1lu; uint64_t channelsNum = 1lu; - std::vector srcHeightF{ 1.f }; - std::vector srcWidthF{ 1.f }; - std::vector srcWidthB{ 1lu }; - std::vector dataTypeSize{ 1lu }; - std::vector srcHeightMul2F{ 1.f }; - std::vector srcWidthMul2F{ 1.f }; - std::vector srcHeightMul2Sub1F{ 1.f }; - std::vector srcWidthMul2Sub1F{ 1.f }; - std::vector srcHeightSub1F{ 1.f }; - std::vector srcWidthSub1F{ 1.f }; - std::vector wDenormCoefF{ 1.f }; - std::vector hDenormCoefF{ 1.f }; - uint64_t gridStartB = 0lu; - uint64_t dstStartB = 0lu; + std::vector srcHeightF{1.f}; + std::vector srcWidthF{1.f}; + std::vector srcWidthB{1lu}; + std::vector dataTypeSize{1lu}; + std::vector srcHeightMul2F{1.f}; + std::vector srcWidthMul2F{1.f}; + std::vector srcHeightMul2Sub1F{1.f}; + std::vector srcWidthMul2Sub1F{1.f}; + std::vector srcHeightSub1F{1.f}; + std::vector srcWidthSub1F{1.f}; + std::vector wDenormCoefF{1.f}; + std::vector hDenormCoefF{1.f}; + uint64_t gridStartB = 0lu; + uint64_t dstStartB = 0lu; uint64_t srcChannelStepB = 0lu; uint64_t dstChannelStepB = 0lu; - uint64_t 
srcBatchStepB = 0lu; - uint64_t gridBatchStepB = 0lu; - uint64_t dstBatchStepB = 0lu; - uint64_t workAmount = 0lu; + uint64_t srcBatchStepB = 0lu; + uint64_t gridBatchStepB = 0lu; + uint64_t dstBatchStepB = 0lu; + uint64_t workAmount = 0lu; std::vector buffer; }; @@ -71,6 +72,6 @@ class GridSample : public Node { std::shared_ptr jitKernel; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/grn.cpp b/src/plugins/intel_cpu/src/nodes/grn.cpp index 10de2ef2286f0f..374452812eaf3a 100644 --- a/src/plugins/intel_cpu/src/nodes/grn.cpp +++ b/src/plugins/intel_cpu/src/nodes/grn.cpp @@ -2,11 +2,12 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "grn.h" + #include -#include "openvino/opsets/opset1.hpp" #include "openvino/core/parallel.hpp" -#include "grn.h" +#include "openvino/opsets/opset1.hpp" namespace ov { namespace intel_cpu { @@ -97,11 +98,12 @@ void GRN::execute(dnnl::stream strm) { parallel_for3d(N, H, W, [&](int b, int h, int w) { double variance = 0; for (int c = 0; c < C; c++) { - variance += std::pow(src_data[b*C*H*W + c*H*W + h*W + w], 2); + variance += std::pow(src_data[b * C * H * W + c * H * W + h * W + w], 2); } variance = std::pow(variance + bias, 0.5f); for (int c = 0; c < C; c++) { - dst_data[b*C*H*W + c*H*W + h*W + w] = src_data[b*C*H*W + c*H*W + h*W + w] / static_cast(variance); + dst_data[b * C * H * W + c * H * W + h * W + w] = + src_data[b * C * H * W + c * H * W + h * W + w] / static_cast(variance); } }); } @@ -110,6 +112,6 @@ bool GRN::created() const { return getType() == Type::GRN; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/grn.h b/src/plugins/intel_cpu/src/nodes/grn.h index 52e77318e2132f..17eac4e81b9d6c 100644 --- a/src/plugins/intel_cpu/src/nodes/grn.h +++ b/src/plugins/intel_cpu/src/nodes/grn.h @@ -14,7 +14,7 @@ class GRN : public Node { public: GRN(const std::shared_ptr& op, const GraphContext::CPtr context); - void getSupportedDescriptors() override {}; + void getSupportedDescriptors() override{}; void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; bool created() const override; @@ -34,6 +34,6 @@ class GRN : public Node { std::string errorPrefix; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/if.cpp b/src/plugins/intel_cpu/src/nodes/if.cpp index 1b6102ff954689..8de1cf14920d74 100644 --- a/src/plugins/intel_cpu/src/nodes/if.cpp +++ b/src/plugins/intel_cpu/src/nodes/if.cpp @@ -4,22 +4,22 @@ #include "if.h" -#include "openvino/op/if.hpp" +#include +#include #include "common/cpu_memcpy.h" -#include "shape_inference/shape_inference_internal_dyn.hpp" #include "nodes/common/cpu_convert.h" +#include "openvino/op/if.hpp" +#include "shape_inference/shape_inference_internal_dyn.hpp" #include "transformations/utils/utils.hpp" -#include -#include - namespace ov { namespace intel_cpu { namespace node { -If::PortMapHelper::PortMapHelper(const MemoryPtr &from, const std::deque& to, - const dnnl::engine& eng) : srcMemPtr(from), dstMemPtrs(to) { +If::PortMapHelper::PortMapHelper(const MemoryPtr& from, const std::deque& to, const dnnl::engine& eng) + : srcMemPtr(from), + dstMemPtrs(to) { size = 0; if (srcMemPtr->getDesc().isDefined()) size = 
srcMemPtr->getShape().getElementsCount(); @@ -43,7 +43,7 @@ void If::PortMapHelper::execute(dnnl::stream& strm) { } void If::PortMapHelper::redefineTo() { - const auto &currDesc = dstMemPtrs.front()->getDesc(); + const auto& currDesc = dstMemPtrs.front()->getDesc(); if (currDesc.getShape().isDynamic() || currDesc.getShape().getStaticDims() != srcMemPtr->getStaticDims()) { // TODO : check the entire dstMemPtrs usage considering the proper memory sharing auto newShape = srcMemPtr->getStaticDims(); @@ -60,7 +60,7 @@ bool If::isSupportedOperation(const std::shared_ptr& op, std::st try { if (!one_of(op->get_type_info(), ov::op::v8::If::get_type_info_static())) { errorMessage = "Not supported If operation version " + std::string(op->get_type_info().version_id) + - " with name '" + op->get_friendly_name() + "'. Node If supports only opset8 version."; + " with name '" + op->get_friendly_name() + "'. Node If supports only opset8 version."; return false; } } catch (...) { @@ -69,8 +69,9 @@ bool If::isSupportedOperation(const std::shared_ptr& op, std::st return true; } -If::If(const std::shared_ptr& op, const GraphContext::CPtr context) : - Node(op, context, InternalDynShapeInferFactory()), ovOp(op) { +If::If(const std::shared_ptr& op, const GraphContext::CPtr context) + : Node(op, context, InternalDynShapeInferFactory()), + ovOp(op) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); @@ -111,49 +112,55 @@ void If::getSupportedDescriptors() { } } - const auto &outMapThen = subGraphThen.GetOutputNodesMap(); + const auto& outMapThen = subGraphThen.GetOutputNodesMap(); for (const auto& out : ifOp->get_then_body()->get_results()) { auto outNode = outMapThen.find(ifOp->get_then_body()->get_result_index(out)); if (outNode != outMapThen.end()) { auto outMem = outNode->second->getSrcMemoryAtPort(0); outputMemThen.push_back(outMem); } else { - OPENVINO_THROW("Then body of node If with name ", getName(), " does not have output with name: ", out->get_friendly_name()); + OPENVINO_THROW("Then body of node If with name ", + getName(), + " does not have output with name: ", + out->get_friendly_name()); } } - const auto &outMapElse = subGraphElse.GetOutputNodesMap(); + const auto& outMapElse = subGraphElse.GetOutputNodesMap(); for (const auto& out : ifOp->get_else_body()->get_results()) { auto outNode = outMapElse.find(ifOp->get_else_body()->get_result_index(out)); if (outNode != outMapElse.end()) { auto outMem = outNode->second->getSrcMemoryAtPort(0); outputMemElse.push_back(outMem); } else { - OPENVINO_THROW("Else body of node If with name ", getName(), " does not have output with name: ", out->get_friendly_name()); + OPENVINO_THROW("Else body of node If with name ", + getName(), + " does not have output with name: ", + out->get_friendly_name()); } } // Port map: outputs for (const auto& desc : ifOp->get_output_descriptions(0)) { auto body_output_idx = desc->m_body_value_index; - thenOutputPortMap.emplace_back(PortMap { - static_cast(desc->m_output_index), static_cast(body_output_idx)}); + thenOutputPortMap.emplace_back( + PortMap{static_cast(desc->m_output_index), static_cast(body_output_idx)}); } for (const auto& desc : ifOp->get_output_descriptions(1)) { auto body_output_idx = desc->m_body_value_index; - elseOutputPortMap.emplace_back(PortMap { - static_cast(desc->m_output_index), static_cast(body_output_idx)}); + elseOutputPortMap.emplace_back( + PortMap{static_cast(desc->m_output_index), static_cast(body_output_idx)}); } for (const auto& 
desc : ifOp->get_input_descriptions(0)) { auto body_input_index = desc->m_body_parameter_index; - thenInputPortMap.emplace_back(PortMap { - static_cast(desc->m_input_index), static_cast(body_input_index)}); + thenInputPortMap.emplace_back( + PortMap{static_cast(desc->m_input_index), static_cast(body_input_index)}); } for (const auto& desc : ifOp->get_input_descriptions(1)) { auto body_input_index = desc->m_body_parameter_index; - elseInputPortMap.emplace_back(PortMap { - static_cast(desc->m_input_index), static_cast(body_input_index)}); + elseInputPortMap.emplace_back( + PortMap{static_cast(desc->m_input_index), static_cast(body_input_index)}); } } @@ -166,16 +173,17 @@ void If::initSupportedPrimitiveDescriptors() { config.outConfs.reserve(getChildEdges().size()); for (size_t i = 0; i < inputShapes.size(); i++) { - PortConfig dataConf {}; + PortConfig dataConf{}; auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp); dataConf.setMemDesc(descCreator->createSharedDesc(getOriginalInputPrecisionAtPort(i), getInputShapeAtPort(i))); config.inConfs.emplace_back(dataConf); } for (size_t i = 0; i < outputShapes.size(); i++) { - PortConfig dataConf {}; + PortConfig dataConf{}; auto descCreator = BlockedDescCreator::getCommonCreators().at(LayoutType::ncsp); - dataConf.setMemDesc(descCreator->createSharedDesc(getOriginalOutputPrecisionAtPort(i), getOutputShapeAtPort(i))); + dataConf.setMemDesc( + descCreator->createSharedDesc(getOriginalOutputPrecisionAtPort(i), getOutputShapeAtPort(i))); config.outConfs.push_back(dataConf); } @@ -195,9 +203,9 @@ void If::createPrimitive() { } void If::prepareBeforeMappers(const bool isThen, const dnnl::engine& eng) { - auto &inputPortMap = isThen ? thenInputPortMap : elseInputPortMap; - auto &inputMems = isThen ? inputMemThen : inputMemElse; - auto &beforeMappers = isThen ? beforeThenMappers : beforeElseMappers; + auto& inputPortMap = isThen ? thenInputPortMap : elseInputPortMap; + auto& inputMems = isThen ? inputMemThen : inputMemElse; + auto& beforeMappers = isThen ? beforeThenMappers : beforeElseMappers; for (auto& map_rule : inputPortMap) { auto fromMem = getSrcMemoryAtPort(map_rule.from); auto& toMems = inputMems[map_rule.to]; @@ -216,12 +224,12 @@ void If::prepareBeforeMappers(const bool isThen, const dnnl::engine& eng) { } void If::prepareAfterMappers(const bool isThen, const dnnl::engine& eng) { - auto &outputPortMap = isThen ? thenOutputPortMap : elseOutputPortMap; - auto &outputMems = isThen ? outputMemThen : outputMemElse; - auto &afterMappers = isThen ? afterThenMappers : afterElseMappers; + auto& outputPortMap = isThen ? thenOutputPortMap : elseOutputPortMap; + auto& outputMems = isThen ? outputMemThen : outputMemElse; + auto& afterMappers = isThen ? afterThenMappers : afterElseMappers; for (auto& map_rule : outputPortMap) { auto toMems = getToMemories(this, map_rule.from); - auto &fromMem = outputMems[map_rule.to]; + auto& fromMem = outputMems[map_rule.to]; // Check precision between If node input/output and it's subgrapsh input/output. for (const auto& toMem : toMems) { if (fromMem->getDesc().getPrecision() != toMem->getDesc().getPrecision()) { @@ -250,11 +258,11 @@ void If::execute(dnnl::stream strm) { auto& afterMappers = condition ? afterThenMappers : afterElseMappers; auto& subGraph = condition ? 
subGraphThen : subGraphElse; - for (auto &mapper : beforeMappers) + for (auto& mapper : beforeMappers) mapper->execute(strm); subGraph.ResetInferCount(); subGraph.Infer(); - for (auto &mapper : afterMappers) + for (auto& mapper : afterMappers) mapper->execute(strm); } @@ -266,6 +274,6 @@ bool If::created() const { return getType() == Type::If; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/if.h b/src/plugins/intel_cpu/src/nodes/if.h index f858c92b0b2651..a2babb45b6c803 100644 --- a/src/plugins/intel_cpu/src/nodes/if.h +++ b/src/plugins/intel_cpu/src/nodes/if.h @@ -4,8 +4,8 @@ #pragma once -#include #include +#include #include #include @@ -25,12 +25,18 @@ class If : public Node { void createPrimitive() override; bool created() const override; void execute(dnnl::stream strm) override; - bool isExecutable() const override { return true; } + bool isExecutable() const override { + return true; + } protected: void executeDynamicImpl(dnnl::stream strm) override; - bool needPrepareParams() const override { return false; }; - bool needShapeInfer() const override { return false; } + bool needPrepareParams() const override { + return false; + }; + bool needShapeInfer() const override { + return false; + } private: void prepareBeforeMappers(const bool isThen, const dnnl::engine& eng); @@ -64,21 +70,14 @@ class If : public Node { std::vector> inputMemThen, inputMemElse; std::deque outputMemThen, outputMemElse; - std::vector> - beforeThenMappers, - beforeElseMappers, - afterThenMappers, + std::vector> beforeThenMappers, beforeElseMappers, afterThenMappers, afterElseMappers; - std::vector - thenInputPortMap, - thenOutputPortMap, - elseInputPortMap, - elseOutputPortMap; + std::vector thenInputPortMap, thenOutputPortMap, elseInputPortMap, elseOutputPortMap; const std::shared_ptr ovOp; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/input.cpp b/src/plugins/intel_cpu/src/nodes/input.cpp index 4ccdc87ada25f1..4bb2f714b284fd 100644 --- a/src/plugins/intel_cpu/src/nodes/input.cpp +++ b/src/plugins/intel_cpu/src/nodes/input.cpp @@ -5,12 +5,12 @@ #include "input.h" #include "cpu/x64/jit_generator.hpp" +#include "memory_desc/cpu_memory_desc_utils.h" #include "nodes/node_config.h" #include "openvino/core/parallel.hpp" #include "openvino/core/shape.hpp" #include "openvino/core/type/element_type.hpp" #include "shape_inference/shape_inference_pass_through.hpp" -#include "memory_desc/cpu_memory_desc_utils.h" using namespace dnnl; using namespace dnnl::impl::cpu::x64; @@ -38,16 +38,14 @@ struct jit_has_subnormals_base : public jit_generator { } fn_t get() { - return jit_ker() || create_kernel() == dnnl::impl::status::success - ? (fn_t)jit_ker() - : nullptr; + return jit_ker() || create_kernel() == dnnl::impl::status::success ? 
(fn_t)jit_ker() : nullptr; } protected: - void foreach(const Xbyak::Reg64& idx, - size_t step, - const Xbyak::Reg64& end, - std::function && fn) { + void foreach (const Xbyak::Reg64& idx, + size_t step, + const Xbyak::Reg64& end, + std::function && fn) { Label loop, exit; L(loop); @@ -61,75 +59,76 @@ struct jit_has_subnormals_base : public jit_generator { L(exit); } - void copy_floats(const Xbyak::Reg64& dst, - const Xbyak::Reg64& src, - const Xbyak::Reg64& size) { + void copy_floats(const Xbyak::Reg64& dst, const Xbyak::Reg64& src, const Xbyak::Reg64& size) { push(rsi); push(r15); xor_(rsi, rsi); - foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) { + foreach (rsi, 1, size, [&, this](const Xbyak::Reg64& idx) { mov(r15d, dword[src + idx * sizeof(float)]); mov(dword[dst + idx * sizeof(float)], r15d); - }); + }) + ; pop(r15); pop(rsi); } - void check_subnormals(const Xbyak::Reg64& src, const Xbyak::Ymm &exponent_mask, const Xbyak::Ymm &mantissa_mask, const Xbyak::Ymm &zero) { + void check_subnormals(const Xbyak::Reg64& src, + const Xbyak::Ymm& exponent_mask, + const Xbyak::Ymm& mantissa_mask, + const Xbyak::Ymm& zero) { auto a = ymm1; auto b = ymm2; auto c = ymm3; - vmovdqu(a, yword[src]); // load 8 floats - vpand(b, a, mantissa_mask); // b = a & 00000000011111111111111111111111 - vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 - vpand(c, a, exponent_mask); // c = a & 01111111100000000000000000000000 - vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 - vptest(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 + vmovdqu(a, yword[src]); // load 8 floats + vpand(b, a, mantissa_mask); // b = a & 00000000011111111111111111111111 + vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 + vpand(c, a, exponent_mask); // c = a & 01111111100000000000000000000000 + vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 + vptest(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 } - void check_subnormals(const Xbyak::Reg64& src, const Xbyak::Xmm &exponent_mask, const Xbyak::Xmm &mantissa_mask, const Xbyak::Xmm &zero) { + void check_subnormals(const Xbyak::Reg64& src, + const Xbyak::Xmm& exponent_mask, + const Xbyak::Xmm& mantissa_mask, + const Xbyak::Xmm& zero) { auto a = xmm1; auto b = xmm2; auto c = xmm3; - uni_vmovdqu(a, xword[src]); // load 4 floats - uni_vmovdqu(b, a); // b = a - uni_vmovdqu(c, a); // c = a - uni_vpand(b, b, mantissa_mask); // b = a & 00000000011111111111111111111111 - uni_vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 - uni_vpand(c, c, exponent_mask); // c = a & 01111111100000000000000000000000 - uni_vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 - uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 + uni_vmovdqu(a, xword[src]); // load 4 floats + uni_vmovdqu(b, a); // b = a + uni_vmovdqu(c, a); // c = a + uni_vpand(b, b, mantissa_mask); // b = a & 00000000011111111111111111111111 + uni_vpcmpeqd(b, b, zero); // if (b == 0) b = 1 else b = 0 + uni_vpand(c, c, exponent_mask); // c = a & 01111111100000000000000000000000 + uni_vpcmpeqd(c, c, zero); // if (c == 0) c = 1 else c = 0 + uni_vtestps(b, c); // if ((!b & c) == 0) CF = 1 else CF = 0 } protected: Label exit, has_subnormals, no_subnormals; - const Reg64 ®_src = rax; - const Reg64 ®_dst = rbx; - const Reg64 ®_sz = rdx; - const Reg64 ®_idx = rsi; - const Reg64 ®_mask_addr = r15; + const Reg64& reg_src = rax; + const Reg64& reg_dst = rbx; + const Reg64& reg_sz = rdx; + const Reg64& reg_idx = rsi; + const Reg64& reg_mask_addr = r15; static const uint32_t exponent_mask_data[8]; static const uint32_t 
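// A note on the odd "})" followed by a lone ";" that clang-format produces around
// the foreach(...) calls in this file: clang-format's default ForEachMacros list
// includes the identifier `foreach`, so the helper is formatted as if it were a
// range-for statement -- a space is inserted before the parenthesis and the trailing
// semicolon is treated as an empty loop body pushed to its own line. If that is
// unwanted, a plausible project-level fix (not applied in this patch) is overriding
// the list in .clang-format, e.g.:
//
//     ForEachMacros: ['Q_FOREACH', 'BOOST_FOREACH']
//
// or renaming the helper so it no longer collides with the macro list.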
mantissa_mask_data[8]; }; -const uint32_t jit_has_subnormals_base::exponent_mask_data[8] = { - 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, - 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000 -}; +const uint32_t jit_has_subnormals_base::exponent_mask_data[8] = + {0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000, 0x7f800000}; -const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] = { - 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, - 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff -}; +const uint32_t jit_has_subnormals_base::mantissa_mask_data[8] = + {0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff, 0x007fffff}; -template +template struct jit_has_subnormals : public jit_has_subnormals_base { using Vmm = typename dnnl::impl::utils::conditional::type; @@ -138,7 +137,7 @@ struct jit_has_subnormals : public jit_has_subnormals_base { const Vmm rmm6 = Vmm(6); const int length = isa == sse41 ? 4 : 8; - void generate() override final { // NOLINT + void generate() override final { // NOLINT size_t const vlen = length; const int sh_bits = std::ilogb(vlen); @@ -165,11 +164,12 @@ struct jit_has_subnormals : public jit_has_subnormals_base { mov(r8, reg_sz); shr(r8, sh_bits); - foreach(reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { + foreach (reg_idx, 1, r8, [&, this](const Xbyak::Reg64& idx) { check_subnormals(reg_src, exponent_mask, mantissa_mask, zero); jnc(has_subnormals); add(reg_src, sizeof(float) * vlen); - }); + }) + ; // Tail shl(reg_idx, sh_bits); @@ -216,11 +216,11 @@ jit_has_subnormals_base::fn_t jit_has_subnormals_function() { return nullptr; } -} // namespace +} // namespace #endif Input::Input(const std::shared_ptr& op, const GraphContext::CPtr context) - : Node(op, context, PassThroughShapeInferFactory()) { + : Node(op, context, PassThroughShapeInferFactory()) { if (!one_of(op->get_type_info(), op::v0::Parameter::get_type_info_static(), op::v0::Constant::get_type_info_static(), @@ -260,7 +260,7 @@ void Input::cloneBlobIfRequired() { needFlushDenormalsToZero = false; } - auto cloneBlob = [&, this] () { + auto cloneBlob = [&, this]() { MemoryPtr memory; // CVS-74980 @@ -269,7 +269,8 @@ void Input::cloneBlobIfRequired() { // in that case we make a copy to avoid overflow if (m_constOp->get_byte_size() >= memDesc.getCurrentMemSize()) { if (m_constOp->get_element_type() == element::string) { - memory = std::make_shared(getEngine(), memDesc, m_constOp->get_data_ptr()); + memory = + std::make_shared(getEngine(), memDesc, m_constOp->get_data_ptr()); } else { memory = std::make_shared(getEngine(), memDesc, m_constOp->get_data_ptr()); } @@ -296,12 +297,12 @@ void Input::cloneBlobIfRequired() { return ptr; }; - auto isBlobAligned = [] (const std::shared_ptr& constant) { + auto isBlobAligned = [](const std::shared_ptr& constant) { #if defined(OPENVINO_ARCH_X86) || defined(OPENVINO_ARCH_X86_64) // Majority of arithmetic and data processing instructions in legacy SSE isa requires // the memory address in the operands must be aligned on 16-byte boundary. To ensure // safely reusing ngraph const blob memory, need to check address alignment. - const void *ptr = constant->get_data_ptr(); + const void* ptr = constant->get_data_ptr(); return mayiuse(cpu_isa_t::avx2) || ((reinterpret_cast(ptr) & 15) == 0); #else return true; @@ -309,9 +310,9 @@ void Input::cloneBlobIfRequired() { }; // The presence of subnormals is better to determined at IR read time. 
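// The masks kept in exponent_mask_data / mantissa_mask_data encode the IEEE-754
// single-precision subnormal test: a value is subnormal iff its exponent field
// (0x7f800000) is all zeros while its mantissa field (0x007fffff) is non-zero.
// The scalar equivalent of what the vptest/uni_vtestps sequence computes per lane
// (illustrative helper, not part of this patch):

#include <cstdint>
#include <cstring>

inline bool is_subnormal(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));  // bit-cast without undefined behaviour
    return (bits & 0x7f800000u) == 0u && (bits & 0x007fffffu) != 0u;
}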
- auto hasSubnormals = [&] () { + auto hasSubnormals = [&]() { if (prec == ov::element::f32) { - uint32_t const *u32data = m_constOp->get_data_ptr(); + uint32_t const* u32data = m_constOp->get_data_ptr(); if (!size) return false; @@ -325,11 +326,9 @@ void Input::cloneBlobIfRequired() { parallel_for(iterations_num, [&](int n) { auto ptr = u32data + n * batch_size; - const jit_has_subnormals_base::args_t args = { - reinterpret_cast(ptr), - std::min(batch_size, (size_t)(u32data + size - ptr)), - false - }; + const jit_has_subnormals_base::args_t args = {reinterpret_cast(ptr), + std::min(batch_size, (size_t)(u32data + size - ptr)), + false}; fn(&args); @@ -352,12 +351,10 @@ void Input::cloneBlobIfRequired() { return false; }; - auto blobKey = [&] () { + auto blobKey = [&]() { char ptr[32]; snprintf(ptr, sizeof ptr, "%p", m_constOp->get_data_ptr()); - return getName() - + "_" + std::to_string(size * prec.size()) - + "_" + ptr; + return getName() + "_" + std::to_string(size * prec.size()) + "_" + ptr; }; const auto weightCache = context->getWeightsCache(); @@ -368,7 +365,8 @@ void Input::cloneBlobIfRequired() { isBlobAligned(m_constOp) && (!needFlushDenormalsToZero || !hasSubnormals()) && // Blob should be cloned in cache only if original weights are stored on other numa node. // This is possible only in multistream case on multisocket machine. - // TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where original weights are stored. + // TODO: don't clone blob for multisocket + multistream case if current stream is run on the numa node where + // original weights are stored. (!weightCache || context->getNumNumaNodes() == 1 || context->getCPUStreamExecutor()->get_streams_num() == 1); memoryPtr = clone_is_not_needed ? std::make_shared(getEngine(), memDesc, m_constOp->get_data_ptr()) @@ -376,29 +374,25 @@ void Input::cloneBlobIfRequired() { weightCache ? 
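// The weightCache->findOrCreate(blobKey(), cloneBlob) call completed just below
// follows a find-or-create idiom: a key derived from the blob's name, byte size and
// source pointer either returns an already shared copy or invokes the builder
// exactly once. A generic sketch of the idiom (hypothetical cache type; the
// plugin's weights cache has its own API and NUMA-aware policy):

#include <functional>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>

template <typename T>
class FindOrCreateCache {
public:
    std::shared_ptr<T> findOrCreate(const std::string& key,
                                    const std::function<std::shared_ptr<T>()>& build) {
        std::lock_guard<std::mutex> lock(m_mutex);
        auto it = m_map.find(key);
        if (it != m_map.end()) {
            if (auto alive = it->second.lock())
                return alive;  // hit: share the existing object
        }
        auto created = build();  // miss (or expired entry): build exactly once
        m_map[key] = created;
        return created;
    }

private:
    std::mutex m_mutex;
    std::unordered_map<std::string, std::weak_ptr<T>> m_map;
};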
*weightCache->findOrCreate(blobKey(), cloneBlob) : cloneBlob()); } -static std::vector createInputShapes(const Shape& shape, - const Type type) { +static std::vector createInputShapes(const Shape& shape, const Type type) { if (type == Type::Output) return {shape}; return {}; } -static std::vector createOutputShapes(const Shape& shape, - const Type type) { +static std::vector createOutputShapes(const Shape& shape, const Type type) { if (type == Type::Input) return {shape}; return {}; } -static std::vector createInputPrecisions(const ov::element::Type& prc, - const Type type) { +static std::vector createInputPrecisions(const ov::element::Type& prc, const Type type) { if (type == Type::Output) return {prc}; return {}; } -static std::vector createOutputPrecisions(const ov::element::Type& prc, - const Type type) { +static std::vector createOutputPrecisions(const ov::element::Type& prc, const Type type) { if (type == Type::Input) return {prc}; return {}; @@ -428,17 +422,13 @@ Input::Input(MemoryDescPtr memDesc, const std::string& name, const std::string& extMemDesc = memDesc; } -Input::Input(const std::shared_ptr& op, - const GraphContext::CPtr context, - InputConfig config) +Input::Input(const std::shared_ptr& op, const GraphContext::CPtr context, InputConfig config) : Input(op, context) { extMemDesc = config.desc; m_isInPlace = config.inPlace; } -Input::Input(const std::shared_ptr& op, - const GraphContext::CPtr context, - OutputConfig config) +Input::Input(const std::shared_ptr& op, const GraphContext::CPtr context, OutputConfig config) : Input(op, context) { extMemDesc = config.desc; m_useParentMemoryDescForOutput = config.useParentMemoryDescForOutput; @@ -499,17 +489,23 @@ void Input::createPrimitive() { for (size_t i = 0; i < getChildEdges().size(); i++) { auto dstMemPtr = getDstMemoryAtPort(i); if (!dstMemPtr) - THROW_CPU_NODE_ERR("has null memory object at port ", i, - " to node ", getChildEdgeAt(i)->getChild()->getName(), "."); + THROW_CPU_NODE_ERR("has null memory object at port ", + i, + " to node ", + getChildEdgeAt(i)->getChild()->getName(), + "."); } for (size_t i = 0; i < getParentEdges().size(); i++) { auto srcMemPtr = getSrcMemoryAtPort(i); if (!srcMemPtr) - THROW_CPU_NODE_ERR("has null memory object at port ", i, - " from node ", getParentEdgeAt(i)->getParent()->getName(), "."); + THROW_CPU_NODE_ERR("has null memory object at port ", + i, + " from node ", + getParentEdgeAt(i)->getParent()->getName(), + "."); } - const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor(); + const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor(); if (selected_pd == nullptr) THROW_CPU_NODE_ERR("doesn't have selected primitive descriptor."); } @@ -535,9 +531,7 @@ void Input::initSupportedPdDefault() { inPortConfs.push_back({LayoutType::ncsp, precision}); } - addSupportedPrimDesc(inPortConfs, - outPortConfs, - impl_desc_type::unknown); + addSupportedPrimDesc(inPortConfs, outPortConfs, impl_desc_type::unknown); } void Input::initSupportedPdFromMemDesc() { @@ -553,6 +547,6 @@ void Input::initSupportedPdFromMemDesc() { supportedPrimitiveDescriptors.emplace_back(std::move(config), impl_desc_type::unknown); } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/input.h b/src/plugins/intel_cpu/src/nodes/input.h index e659ea2359aabd..6d1f4c27238540 100644 --- a/src/plugins/intel_cpu/src/nodes/input.h +++ b/src/plugins/intel_cpu/src/nodes/input.h @@ -5,6 +5,7 @@ #pragma 
once #include + #include namespace ov { @@ -42,13 +43,9 @@ class Input : public Node { Input(MemoryDescPtr memDesc, const std::string& name, const std::string& type, const GraphContext::CPtr context); - Input(const std::shared_ptr& op, - const GraphContext::CPtr context, - InputConfig config); + Input(const std::shared_ptr& op, const GraphContext::CPtr context, InputConfig config); - Input(const std::shared_ptr& op, - const GraphContext::CPtr context, - OutputConfig config); + Input(const std::shared_ptr& op, const GraphContext::CPtr context, OutputConfig config); void getSupportedDescriptors() override; void initSupportedPrimitiveDescriptors() override; @@ -66,8 +63,12 @@ class Input : public Node { return false; } - bool needShapeInfer() const override { return false; } - bool needPrepareParams() const override { return false; } + bool needShapeInfer() const override { + return false; + } + bool needPrepareParams() const override { + return false; + } private: void cloneBlobIfRequired(); @@ -83,6 +84,6 @@ class Input : public Node { bool m_isInPlace = false; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/interaction.cpp b/src/plugins/intel_cpu/src/nodes/interaction.cpp index 5ec48e7a263272..905724c3bc829a 100644 --- a/src/plugins/intel_cpu/src/nodes/interaction.cpp +++ b/src/plugins/intel_cpu/src/nodes/interaction.cpp @@ -4,7 +4,10 @@ #include "interaction.h" -#include "transformations/cpu_opset/x64/op/interaction.hpp" +#include +#include +#include + #include "common/bfloat16.hpp" #include "common/cpu_memcpy.h" #include "cpu/x64/cpu_isa_traits.hpp" @@ -16,10 +19,7 @@ #include "memory_desc/dnnl_blocked_memory_desc.h" #include "nodes/common/cpu_convert.h" #include "onednn/dnnl.h" - -#include -#include -#include +#include "transformations/cpu_opset/x64/op/interaction.hpp" using namespace dnnl::impl::cpu::x64; using namespace Xbyak; @@ -36,7 +36,9 @@ template struct jit_move_scale_kernel : public jit_uni_move_scale_kernel, public jit_generator { DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_move_scale_kernel) - explicit jit_move_scale_kernel(const jit_move_scale_compile_params& jcp) : jit_uni_move_scale_kernel(jcp), jit_generator(jit_name()) { + explicit jit_move_scale_kernel(const jit_move_scale_compile_params& jcp) + : jit_uni_move_scale_kernel(jcp), + jit_generator(jit_name()) { runtime_prc = jcp_.src_prc == ov::element::bf16 ? 
ov::element::bf16 : ov::element::f32; if (jcp_.dst_prc == ov::element::i8 || jcp_.dst_prc == ov::element::u8) runtime_prc = ov::element::f32; @@ -50,12 +52,13 @@ struct jit_move_scale_kernel : public jit_uni_move_scale_kernel, public jit_gene } private: - using Vmm = typename dnnl::impl::utils::conditional3::type; + using Vmm = + typename dnnl::impl::utils::conditional3::type; void generate() override { this->preamble(); -#define GET_OFF(field) offsetof(jit_move_scale_call_args, field) +# define GET_OFF(field) offsetof(jit_move_scale_call_args, field) mov(reg_in, ptr[reg_params + GET_OFF(p_in)]); mov(reg_out, ptr[reg_params + GET_OFF(p_out)]); mov(reg_work_amount, jcp_.input_size); @@ -107,7 +110,7 @@ struct jit_move_scale_kernel : public jit_uni_move_scale_kernel, public jit_gene if (jcp_.with_scales) { if (!jcp_.broadcast_scales) { load(vmm_scales, reg_scales, ov::element::f32, ov::element::f32, step, false); - add(reg_scales, sizeof(float) * step); + add(reg_scales, sizeof(float) * step); } uni_vmulps(vmm_in, vmm_in, vmm_scales); } @@ -119,25 +122,39 @@ struct jit_move_scale_kernel : public jit_uni_move_scale_kernel, public jit_gene add(reg_out_aux, jcp_.dst_prc.size() * step); } } -#undef GET_OFF - - inline void load(const Vmm& vmm_dst, const Xbyak::Reg64& reg_src, ov::element::Type src_prc, ov::element::Type dst_prc, const int& elt_num, bool fill) { +# undef GET_OFF + + inline void load(const Vmm& vmm_dst, + const Xbyak::Reg64& reg_src, + ov::element::Type src_prc, + ov::element::Type dst_prc, + const int& elt_num, + bool fill) { const auto seed = load_emitter_params(src_prc, dst_prc, elt_num, fill, "float_min").hash(); if (!emitters[seed]) { - emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num, src_prc, fill, "float_min")); + emitters[seed].reset( + new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num, src_prc, fill, "float_min")); } - emitters[seed]->emit_code({static_cast(reg_src.getIdx()), 0}, {static_cast(vmm_dst.getIdx())}, - pool_aux_vmm_idxs, pool_aux_gpr_idxs); + emitters[seed]->emit_code({static_cast(reg_src.getIdx()), 0}, + {static_cast(vmm_dst.getIdx())}, + pool_aux_vmm_idxs, + pool_aux_gpr_idxs); } - inline void store(const Xbyak::Reg64& reg_dst, const Vmm& vmm_src, ov::element::Type src_prc, ov::element::Type dst_prc, const int& elt_num) { + inline void store(const Xbyak::Reg64& reg_dst, + const Vmm& vmm_src, + ov::element::Type src_prc, + ov::element::Type dst_prc, + const int& elt_num) { const auto seed = store_emitter_params(src_prc, dst_prc, elt_num).hash(); if (!emitters[seed]) { emitters[seed].reset(new jit_store_emitter(this, isa, src_prc, dst_prc, elt_num)); } - emitters[seed]->emit_code({static_cast(vmm_src.getIdx())}, {static_cast(reg_dst.getIdx())}, - pool_aux_vmm_idxs, pool_aux_gpr_idxs); + emitters[seed]->emit_code({static_cast(vmm_src.getIdx())}, + {static_cast(reg_dst.getIdx())}, + pool_aux_vmm_idxs, + pool_aux_gpr_idxs); } size_t vec_size; @@ -155,13 +172,14 @@ struct jit_move_scale_kernel : public jit_uni_move_scale_kernel, public jit_gene Reg64 reg_work_amount = r14; Reg64 reg_params = abi_param1; - const std::vector pool_aux_gpr_idxs = { static_cast(rsi.getIdx()), static_cast(rbp.getIdx()) }; - const std::vector pool_aux_vmm_idxs = { static_cast(xmm_tmp.getIdx()) }; + const std::vector pool_aux_gpr_idxs = {static_cast(rsi.getIdx()), + static_cast(rbp.getIdx())}; + const std::vector pool_aux_vmm_idxs = {static_cast(xmm_tmp.getIdx())}; std::unordered_map> emitters; }; -#endif // OPENVINO_ARCH_X86_64 +#endif // 
OPENVINO_ARCH_X86_64

 Interaction::Interaction(const std::shared_ptr& op, const GraphContext::CPtr context)
     : Node(op, context, NgraphShapeInferFactory(op)) {
@@ -174,7 +192,7 @@ Interaction::Interaction(const std::shared_ptr& op, const GraphContext
         const std::vector& scales = interaction->get_output_scales();
         if (!scales.empty()) {
             fqScales = scales;
-            outputDataType = interaction->get_output_element_type(0);
+            outputDataType = interaction->get_output_element_type(0);
         }
     }

@@ -194,23 +212,12 @@ void Interaction::initSupportedPrimitiveDescriptors() {
     // initialize input ports
     std::vector inPortConfigs;
     for (size_t i = 0; i < getParentEdges().size(); ++i) {
-        inPortConfigs.emplace_back(
-            LayoutType::ncsp,
-            dataPrecision,
-            getInputShapeAtPort(i),
-            false, -1);
+        inPortConfigs.emplace_back(LayoutType::ncsp, dataPrecision, getInputShapeAtPort(i), false, -1);
     }
     // initialize output port
     std::vector outPortConfigs = {
-        PortConfigurator {
-            LayoutType::ncsp,
-            outputDataType,
-            getOutputShapeAtPort(0),
-            false,
-            -1
-        }
-    };
-    //add descriptor
+        PortConfigurator{LayoutType::ncsp, outputDataType, getOutputShapeAtPort(0), false, -1}};
+    // add descriptor
     addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
 }

@@ -221,8 +228,7 @@ static inline void cat(uint8_t* out,
                        size_t elemSize) {
     size_t offset = 0;
     for (size_t j = 0; j < feature_sizes.size(); j++) {
-        cpu_memcpy(out + offset * elemSize, in[j] + bs * feature_sizes[j] * elemSize,
-                   feature_sizes[j] * elemSize);
+        cpu_memcpy(out + offset * elemSize, in[j] + bs * feature_sizes[j] * elemSize, feature_sizes[j] * elemSize);
         offset += feature_sizes[j];
     }
 }

@@ -303,8 +309,7 @@ void Interaction::prepareParams() {
     auto matmul_pd = matmul::primitive_desc(getEngine(), src_md, weights_md, dst_md, matmul_attr);
     prim = matmul(matmul_pd);
     featureSizes.assign(inputSizes, featureSize);
-    auto initMemoryPtr = [&](const ov::element::Type& prc, const intel_cpu::Shape& shape,
-                             MemoryPtr& ptr) {
+    auto initMemoryPtr = [&](const ov::element::Type& prc, const intel_cpu::Shape& shape, MemoryPtr& ptr) {
         ptr = std::make_shared(getEngine(), intel_cpu::DnnlBlockedMemoryDesc(prc, shape));
     };
     initMemoryPtr(dataPrecision, intel_cpu::Shape{inputSizes, featureSize}, inputMemPtr);
@@ -336,7 +341,7 @@ void Interaction::prepareParams() {
         moveFeatureKernel.reset(new jit_move_scale_kernel(jcp));
         moveInteractKernel.reset(new jit_move_scale_kernel(interJcp));
     }
-#endif // OPENVINO_ARCH_X86_64
+#endif  // OPENVINO_ARCH_X86_64

     if (moveFeatureKernel && moveInteractKernel) {
         moveFeatureKernel->create_ker();
@@ -360,8 +365,7 @@ bool Interaction::isExecutable() const {
     return true;
 }

-bool Interaction::isSupportedOperation(const std::shared_ptr& op,
-                                       std::string& errorMessage) noexcept {
+bool Interaction::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept {
     try {
         const auto interaction = std::dynamic_pointer_cast(op);
         if (!interaction) {
@@ -374,7 +378,6 @@ bool Interaction::isSupportedOperation(const std::shared_ptr& op
     return true;
 }

-
-} // namespace node
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
diff --git a/src/plugins/intel_cpu/src/nodes/interaction.h b/src/plugins/intel_cpu/src/nodes/interaction.h
index 448484a2512dd1..794ea0af24a87c 100644
--- a/src/plugins/intel_cpu/src/nodes/interaction.h
+++ b/src/plugins/intel_cpu/src/nodes/interaction.h
@@ -19,31 +19,31 @@ struct jit_move_scale_compile_params {
 };

 struct jit_move_scale_call_args {
-    const void *p_in;
-    void *p_out;
-    const void *p_scales;
+    const void* p_in;
+    void* p_out;
+    const void* p_scales;
 };

 struct jit_uni_move_scale_kernel {
-    void (*ker_)(const jit_move_scale_call_args*);
+    void (*ker_)(const jit_move_scale_call_args*);

-    void operator()(const jit_move_scale_call_args* call_args) {
-        assert(ker_);
-        ker_(call_args);
-    }
+    void operator()(const jit_move_scale_call_args* call_args) {
+        assert(ker_);
+        ker_(call_args);
+    }

-    explicit jit_uni_move_scale_kernel(const jit_move_scale_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
-    virtual ~jit_uni_move_scale_kernel() {}
+    explicit jit_uni_move_scale_kernel(const jit_move_scale_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
+    virtual ~jit_uni_move_scale_kernel() {}

-    virtual void create_ker() = 0;
+    virtual void create_ker() = 0;

-    jit_move_scale_compile_params jcp_;
+    jit_move_scale_compile_params jcp_;
 };

 class Interaction : public Node {
 public:
     Interaction(const std::shared_ptr& op, const GraphContext::CPtr context);
-    void getSupportedDescriptors() override {};
+    void getSupportedDescriptors() override{};
     void initSupportedPrimitiveDescriptors() override;
     void execute(dnnl::stream strm) override;
     bool created() const override;
@@ -74,6 +74,6 @@ class Interaction : public Node {
     std::unique_ptr moveInteractKernel;
 };

-} // namespace node
-} // namespace intel_cpu
-} // namespace ov
+}  // namespace node
+}  // namespace intel_cpu
+}  // namespace ov
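A note on the interaction.h hunks above: jit_move_scale_call_args is a plain-old-data block, and the generated kernel locates each field by byte offset (the GET_OFF macro in interaction.cpp wraps offsetof over exactly this struct) rather than through C++ member access. A minimal, self-contained sketch of that pattern follows; call_args and load_field are hypothetical stand-ins for the real struct and for the generated mov(reg, ptr[reg_params + GET_OFF(...)]) load:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for jit_move_scale_call_args: a POD argument block
// that a generated kernel reads by byte offset, not by member access.
struct call_args {
    const void* p_in;
    void* p_out;
    const void* p_scales;
};

// GET_OFF(field) in the diff expands to offsetof(<args struct>, field);
// the generated code then loads from [abi_param1 + offset].
constexpr size_t off_in = offsetof(call_args, p_in);
constexpr size_t off_out = offsetof(call_args, p_out);

// Software emulation of what mov(reg, ptr[reg_params + GET_OFF(p_in)]) does.
const void* load_field(const call_args* args, size_t byte_offset) {
    auto base = reinterpret_cast<const uint8_t*>(args);
    return *reinterpret_cast<const void* const*>(base + byte_offset);
}

int main() {
    int src = 42;
    call_args args{&src, nullptr, nullptr};
    assert(load_field(&args, off_in) == &src);  // same value the kernel would see
    (void)off_out;
    return 0;
}
```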
diff --git a/src/plugins/intel_cpu/src/nodes/interpolate.cpp b/src/plugins/intel_cpu/src/nodes/interpolate.cpp
index 37008ee17a9603..beb53cb89a831e 100644
--- a/src/plugins/intel_cpu/src/nodes/interpolate.cpp
+++ b/src/plugins/intel_cpu/src/nodes/interpolate.cpp
@@ -4,6 +4,10 @@

 #include "interpolate.h"

+#include
+#include
+#include
+
 #include "common/cpu_memcpy.h"
 #include "cpu/x64/injectors/jit_uni_depthwise_injector.hpp"
 #include "cpu/x64/injectors/jit_uni_eltwise_injector.hpp"
@@ -26,10 +30,6 @@
 #include "utils/cpu_utils.hpp"
 #include "utils/ngraph_utils.hpp"

-#include
-#include
-#include
-
 using namespace dnnl;
 using namespace dnnl::impl;
@@ -38,7 +38,6 @@
 using namespace dnnl::impl::cpu::x64;
 using namespace dnnl::impl::utils;
 using namespace Xbyak;

-
 #define GET_OFF(field) offsetof(jit_interpolate_call_args, field)

 namespace ov {
@@ -55,8 +54,9 @@ template
 struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_interpolate_kernel_f32)

-    explicit jit_uni_interpolate_kernel_f32(jit_interpolate_config_params jcp, const dnnl_primitive_attr &attr)
-        : jit_uni_interpolate_kernel(jcp, attr), jit_generator(jit_name()) {}
+    explicit jit_uni_interpolate_kernel_f32(jit_interpolate_config_params jcp, const dnnl_primitive_attr& attr)
+        : jit_uni_interpolate_kernel(jcp, attr),
+          jit_generator(jit_name()) {}

     void create_ker() override {
         jit_generator::create_kernel();
@@ -69,23 +69,24 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         store_pool_gpr_idxs = {static_cast(reg_tmp_64.getIdx())};
         store_pool_vec_idxs = {static_cast(vmm_zero.getIdx())};

-        const auto &p = attr_.post_ops_;
+        const auto& p = attr_.post_ops_;
         for (int i = 0; i < p.len(); i++) {
-            auto &post_op = p.entry_[i];
+            auto& post_op = p.entry_[i];
             if (post_op.is_eltwise()) {
-                eltwise_injectors.push_back(std::make_shared>(
-                    this,
-                    post_op.eltwise.alg,
-                    post_op.eltwise.alpha,
-                    post_op.eltwise.beta,
-                    1.f));
+                eltwise_injectors.push_back(std::make_shared>(this,
+                                                                              post_op.eltwise.alg,
+                                                                              post_op.eltwise.alpha,
+
post_op.eltwise.beta, + 1.f)); } else if (post_op.is_depthwise()) { - depthwise_injectors.push_back(std::make_shared>( - this, - post_op)); + depthwise_injectors.push_back(std::make_shared>(this, post_op)); } else if (post_op.is_quantization()) { - quantization_injectors.push_back(std::make_shared>( - this, post_op, vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias)); + quantization_injectors.push_back(std::make_shared>(this, + post_op, + vmm_d_weights, + vmm_d_bias, + reg_d_weights, + reg_d_bias)); } } @@ -98,81 +99,82 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi uni_vpxor(vmm_zero, vmm_zero, vmm_zero); switch (jcp_.mode) { - case InterpolateMode::nearest: { - mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); - mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); - mov(reg_index, ptr[reg_params + GET_OFF(index)]); - mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); - - switch (jcp_.layout) { - case InterpolateLayoutType::planar: { - nn_planar(); - break; - } - case InterpolateLayoutType::block: { - nn_blk(); - break; - } - case InterpolateLayoutType::by_channel: { - nn_by_channel(); - break; - } - default: - assert(!"unsupported memory layout for interpolate layer with nearest neighbor mode."); - } + case InterpolateMode::nearest: { + mov(reg_dst, ptr[reg_params + GET_OFF(dst)]); + mov(reg_src, ptr[reg_params + GET_OFF(src_ptr[0])]); + mov(reg_index, ptr[reg_params + GET_OFF(index)]); + mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]); + + switch (jcp_.layout) { + case InterpolateLayoutType::planar: { + nn_planar(); break; } - case InterpolateMode::linear_onnx: { - switch (jcp_.layout) { - case InterpolateLayoutType::planar: { - linear_onnx_planar(); - break; - } - case InterpolateLayoutType::block: - case InterpolateLayoutType::by_channel: { - linear_onnx_c_gathered(); - break; - } - default: - assert(!"unsupported memory layout for interpolate layer with linear_onnx mode."); - } + case InterpolateLayoutType::block: { + nn_blk(); break; } - case InterpolateMode::cubic: { - switch (jcp_.layout) { - case InterpolateLayoutType::planar: { - cubic_planar(); - break; - } - case InterpolateLayoutType::block: - case InterpolateLayoutType::by_channel: { - cubic_c_gathered(); - break; - } - default: - assert(!"unsupported memory layout for interpolate layer with cubic mode."); - } + case InterpolateLayoutType::by_channel: { + nn_by_channel(); break; } - case InterpolateMode::bilinear_pillow: - case InterpolateMode::bicubic_pillow: { - switch (jcp_.layout) { - case InterpolateLayoutType::by_channel: { - pillow_by_channel(); - break; - } - default: - assert(!"unsupported memory layout for interpolate layer with bilinear_pillow and bicubic_pillow modes."); - } + default: + assert(!"unsupported memory layout for interpolate layer with nearest neighbor mode."); + } + break; + } + case InterpolateMode::linear_onnx: { + switch (jcp_.layout) { + case InterpolateLayoutType::planar: { + linear_onnx_planar(); + break; + } + case InterpolateLayoutType::block: + case InterpolateLayoutType::by_channel: { + linear_onnx_c_gathered(); break; } - case InterpolateMode::linear: { - assert(!"unsupported mode for interpolate layer with JITTED implimentation."); + default: + assert(!"unsupported memory layout for interpolate layer with linear_onnx mode."); + } + break; + } + case InterpolateMode::cubic: { + switch (jcp_.layout) { + case InterpolateLayoutType::planar: { + cubic_planar(); + break; + } + case InterpolateLayoutType::block: + case 
InterpolateLayoutType::by_channel: {
+            cubic_c_gathered();
+            break;
+        }
+        default:
+            assert(!"unsupported memory layout for interpolate layer with cubic mode.");
+        }
+        break;
+    }
+    case InterpolateMode::bilinear_pillow:
+    case InterpolateMode::bicubic_pillow: {
+        switch (jcp_.layout) {
+        case InterpolateLayoutType::by_channel: {
+            pillow_by_channel();
             break;
         }
-        default:
-            assert(!"unsupported memory layout for interpolate layer with bilinear_pillow and bicubic_pillow modes.");
+        default:
+            assert(
+                !"unsupported memory layout for interpolate layer with bilinear_pillow and bicubic_pillow modes.");
         }
+        break;
+    }
+    case InterpolateMode::linear: {
+        assert(!"unsupported mode for interpolate layer with JITTED implementation.");
+        break;
+    }
+    default: {
+        assert(!"unsupported mode for interpolate layer.");
+    }
     }

     this->postamble();
@@ -186,8 +188,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
     }

 private:
-    using Vmm = typename conditional3::type;
+    using Vmm =
+        typename conditional3::type;

     const int vlen = cpu_isa_traits::vlen;
     const int vector_step = vlen / sizeof(float);
@@ -216,7 +218,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
     // for cubic planar
     Xbyak::Reg64 reg_tbl_y = rsi;
     Xbyak::Reg64 reg_tbl_x = rbp;
-    Xbyak::Reg64 reg_table = rdx; // do not need reg_index_offset in this mode, so use rdx
+    Xbyak::Reg64 reg_table = rdx;  // do not need reg_index_offset in this mode, so use rdx

     Vmm vmm_val = Vmm(1);
     Vmm vmm_index = Vmm(0);
@@ -292,14 +294,21 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         emit_load(reg_src, vmm_src, ov::element::f32, ov::element::f32, elt_num, offset);
     }

-    inline void emit_load(Xbyak::Reg64 reg_src, Vmm vmm_src, ov::element::Type src_prc, ov::element::Type dst_prc, const int elt_num, const int offset = 0) {
+    inline void emit_load(Xbyak::Reg64 reg_src,
+                          Vmm vmm_src,
+                          ov::element::Type src_prc,
+                          ov::element::Type dst_prc,
+                          const int elt_num,
+                          const int offset = 0) {
         const auto seed = load_emitter_params(src_prc, dst_prc, elt_num).hash();
         if (!emitters[seed]) {
             emitters[seed].reset(new jit_load_emitter(this, isa, src_prc, dst_prc, elt_num));
         }

         emitters[seed]->emit_code({static_cast(reg_src.getIdx()), static_cast(offset)},
-                                  {static_cast(vmm_src.getIdx())}, {}, {load_pool_gpr_idxs});
+                                  {static_cast(vmm_src.getIdx())},
+                                  {},
+                                  {load_pool_gpr_idxs});
     }

     inline void store(Vmm vmm_dst, Xbyak::Reg64 reg_dst, const int elt_num, const int offset = 0) {
@@ -309,12 +318,15 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         }

         // for cases when Store emitter need 2 aux vmm we can use vmm_dst as second aux vmm
-        std::vector local_store_pool_vec_idxs = { static_cast(vmm_dst.getIdx()) };
-        local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(), store_pool_vec_idxs.begin(), store_pool_vec_idxs.end());
+        std::vector local_store_pool_vec_idxs = {static_cast(vmm_dst.getIdx())};
+        local_store_pool_vec_idxs.insert(local_store_pool_vec_idxs.begin(),
+                                         store_pool_vec_idxs.begin(),
+                                         store_pool_vec_idxs.end());

         emitters[seed]->emit_code({static_cast(vmm_dst.getIdx())},
                                   {static_cast(reg_dst.getIdx()), static_cast(offset)},
-                                  {local_store_pool_vec_idxs}, {store_pool_gpr_idxs});
+                                  {local_store_pool_vec_idxs},
+                                  {store_pool_gpr_idxs});
     }

     // kernel for OH * OW * C
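A detail worth calling out in the emit_load()/store() helpers above: emitters is a cache keyed by the hash of the emitter parameters, so each distinct (src_prc, dst_prc, elt_num, ...) combination is constructed only once and reused on later calls. A freestanding sketch of that memoization, assuming a boost-style hash combine and a hypothetical Emitter type in place of jit_load_emitter/jit_store_emitter:

```cpp
#include <cstddef>
#include <functional>
#include <memory>
#include <unordered_map>

// Hypothetical stand-in for jit_load_emitter / jit_store_emitter.
struct Emitter {
    int src_prc, dst_prc, elt_num;
};

std::unordered_map<size_t, std::unique_ptr<Emitter>> emitters;

// Mirror of the pattern above: hash the parameters into a seed, build the
// emitter only on a cache miss, then hand back the cached instance.
Emitter& get_emitter(int src_prc, int dst_prc, int elt_num) {
    size_t seed = std::hash<int>{}(src_prc);
    seed ^= std::hash<int>{}(dst_prc) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed ^= std::hash<int>{}(elt_num) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    auto& slot = emitters[seed];
    if (!slot) {
        slot.reset(new Emitter{src_prc, dst_prc, elt_num});  // built once per config
    }
    return *slot;
}
```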
@@ -397,9 +409,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
             }
             // if int, round
             if (!isFloatCompatible(jcp_.src_prc)) {
-                uni_vroundps(vmm_dst, vmm_dst, 0x0); // Round near
+                uni_vroundps(vmm_dst, vmm_dst, 0x0);  // Round near
             }
-            // src_prc, dst_prc and buf ov::element::Type is the same, otherwise need another store with buf(src) precision
+            // src_prc, dst_prc and buf ov::element::Type is the same, otherwise need another store with
+            // buf(src) precision
             store(vmm_dst, reg_dst_aux, vector_step);
             add(reg_dst_aux, vector_step * jcp_.src_data_size);
             // advance 8/16 facilitate next block
@@ -415,7 +428,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
                 uni_vfmadd231ps(vmm_dst, vmm_val, vmm_weight);
             }
             if (!isFloatCompatible(jcp_.src_prc)) {
-                uni_vroundps(vmm_dst, vmm_dst, 0x0); // Round near
+                uni_vroundps(vmm_dst, vmm_dst, 0x0);  // Round near
             }
             store(vmm_dst, reg_dst_aux, tail_num);
             add(reg_dst_aux, tail_num * jcp_.src_data_size);
@@ -447,7 +460,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
                 uni_vfmadd231ps(vmm_dst, vmm_val, vmm_weight);
             }
             if (!isFloatCompatible(jcp_.src_prc)) {
-                uni_vroundps(vmm_dst, vmm_dst, 0x0); // Round near
+                uni_vroundps(vmm_dst, vmm_dst, 0x0);  // Round near
             }
             store(vmm_dst, reg_dst, vector_step);
             add(reg_dst, vector_step * jcp_.dst_data_size);
@@ -463,7 +476,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
                 uni_vfmadd231ps(vmm_dst, vmm_val, vmm_weight);
             }
             if (!isFloatCompatible(jcp_.src_prc)) {
-                uni_vroundps(vmm_dst, vmm_dst, 0x0); // Round near
+                uni_vroundps(vmm_dst, vmm_dst, 0x0);  // Round near
             }
             store(vmm_dst, reg_dst, tail_num);
             add(reg_dst, tail_num * jcp_.dst_data_size);
@@ -495,7 +508,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         cmp(reg_work_amount_oh, 1);
         jl(out_loop_end, T_NEAR);

-        //reset work_amount to OW
+        // reset work_amount to OW
         mov(reg_work_amount, jcp_.OW);

         Xbyak::Reg64 reg_src_h = rsi;
@@ -512,7 +525,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         Xbyak::Label nn_tail_loop_label;
         Xbyak::Label nn_tail_loop_end_label;

-        L(nn_loop_label); // inner loop
+        L(nn_loop_label);  // inner loop
         {
             cmp(reg_work_amount, vector_step);
             jl(nn_loop_end_label, T_NEAR);
@@ -552,9 +565,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
             jmp(nn_tail_loop_label, T_NEAR);
         }

-        L(nn_tail_loop_end_label); // inner loop end
+        L(nn_tail_loop_end_label);  // inner loop end

-        //increment index_h to next row
+        // increment index_h to next row
         add(reg_index_h, jcp_.indices_size);

         sub(reg_work_amount_oh, 1);
@@ -620,7 +633,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         cmp(reg_work_amount_out, 1);
         jl(out_loop_end, T_NEAR);

-        //inner loop for C
+        // inner loop for C
         Xbyak::Label nn_loop_label;
         Xbyak::Label nn_loop_end_label;
         Xbyak::Label nn_tail_loop_label;
@@ -716,10 +729,12 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi
         mov(reg_work_amount, ptr[reg_params + GET_OFF(work_amount)]);

         int blk = (isa == cpu::x64::sse41) ? (2 * vector_step) : vector_step;
-        int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.dst_data_size) :
-                         (blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size);
-        int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) ? (vector_step * jcp_.src_data_size) :
-                         (blk * jcp_.IW * jcp_.IH * jcp_.ID * jcp_.src_data_size);
+        int dst_stride = (jcp_.layout == InterpolateLayoutType::by_channel)
+                             ?
(vector_step * jcp_.dst_data_size) + : (blk * jcp_.OW * jcp_.OH * jcp_.OD * jcp_.dst_data_size); + int src_stride = (jcp_.layout == InterpolateLayoutType::by_channel) + ? (vector_step * jcp_.src_data_size) + : (blk * jcp_.IW * jcp_.IH * jcp_.ID * jcp_.src_data_size); Xbyak::Label main_loop_label; Xbyak::Label main_loop_end_label; @@ -757,8 +772,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // 2d for end depth linear_onnx_worker_2d(); // 3th dimension - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight - uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight + uni_vfmadd231ps(vmm_valTR, + vmm_d_bias, + vmm_weightF); // start_value * start_weight + end_value * end_weight } if (attr_.post_ops_.len() != 0) { @@ -788,8 +805,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // 2d for end depth linear_onnx_worker_2d(); // 3th dimension - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight - uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight + uni_vfmadd231ps(vmm_valTR, + vmm_d_bias, + vmm_weightF); // start_value * start_weight + end_value * end_weight } if (attr_.post_ops_.len() != 0) { @@ -813,9 +832,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi add(reg_src_aux7, src_stride); } if (jcp_.layout == InterpolateLayoutType::by_channel) { - sub(reg_work_amount, vector_step); // work_amount is c + sub(reg_work_amount, vector_step); // work_amount is c } else { - sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails + sub(reg_work_amount, 1); // work_amount = div_up(c, blk), no tails } jmp(main_loop_label, T_NEAR); @@ -843,8 +862,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // 2d for end depth linear_onnx_worker_2d(); // 3th dimension - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight - uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight + uni_vfmadd231ps(vmm_valTR, + vmm_d_bias, + vmm_weightF); // start_value * start_weight + end_value * end_weight } if (attr_.post_ops_.len() != 0) { @@ -929,8 +950,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi load_weights(reg_src_aux, vmm_weightE, vector_step, 5 * weight_stride); load_weights(reg_src_aux, vmm_weightF, vector_step, 4 * weight_stride); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight - uni_vfmadd231ps(vmm_valTR, vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight + uni_vfmadd231ps(vmm_valTR, + vmm_d_bias, + vmm_weightF); // start_value * start_weight + end_value * end_weight } if (attr_.post_ops_.len() != 0) { @@ -1013,8 +1036,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi load_weights(reg_src_aux, vmm_weightE, scalar_step, 5 * weight_stride); load_weights(reg_src_aux, vmm_weightF, scalar_step, 4 * weight_stride); - uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight - uni_vfmadd231ps(vmm_valTR, 
vmm_d_bias, vmm_weightF); // start_value * start_weight + end_value * end_weight + uni_vmulps(vmm_valTR, vmm_valTR, vmm_weightE); // end_value * end_weight + uni_vfmadd231ps(vmm_valTR, + vmm_d_bias, + vmm_weightF); // start_value * start_weight + end_value * end_weight } if (attr_.post_ops_.len() != 0) { @@ -1089,7 +1114,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi cubic_c_gathered_matrix(false); if (attr_.post_ops_.len() != 0) { - apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value to post_ops and store + apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value to post_ops and store add(reg_oc_off, vector_step * sizeof(float)); } store(vmm_val, reg_dst, vector_step); @@ -1117,7 +1142,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi int src_stride = vector_step * jcp_.src_data_size; add(reg_dst, dst_stride); add(reg_src, src_stride); - sub(reg_work_amount, vector_step); // work_amount is c + sub(reg_work_amount, vector_step); // work_amount is c } else { int dst_stride = blk * jcp_.OW * jcp_.OH * jcp_.dst_data_size; int src_stride = blk * jcp_.IW * jcp_.IH * jcp_.src_data_size; @@ -1142,7 +1167,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi cubic_c_gathered_matrix(true); if (attr_.post_ops_.len() != 0) { - apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value + apply_post_ops(jcp_.dst_prc, false); // vmm_val is default dst value add(reg_oc_off, scalar_step * sizeof(float)); } store(vmm_val, reg_dst, scalar_step); @@ -1151,7 +1176,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi int src_stride = scalar_step * jcp_.src_data_size; add(reg_dst, dst_stride); add(reg_src, src_stride); - sub(reg_work_amount, scalar_step); // work_amount is c + sub(reg_work_amount, scalar_step); // work_amount is c jmp(tail_loop_label, T_NEAR); } @@ -1242,7 +1267,9 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // build weightX used in y0-y3 // weight format: w0_0 w1_0 w2_0 w3_0 w0_1 w1_1 w2_1 w3_1 ... 
uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); - vgatherdps(vmm_weightX0, ptr[reg_weight_x + vmm_val * grid_len], vmm_mask); // 4 in vmm_val for weight_size, another 4 for grid_len + vgatherdps(vmm_weightX0, + ptr[reg_weight_x + vmm_val * grid_len], + vmm_mask); // 4 in vmm_val for weight_size, another 4 for grid_len uni_vpcmpeqd(vmm_mask, vmm_mask, vmm_mask); // shift weight_size then gather second weight @@ -1326,8 +1353,20 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi // gather weightX by input idx, used in y0-y3 gather_i32_indices(vmm_weightX0, reg_weight_x, 0, vmm_val, grid_len, ov::element::f32, true); gather_i32_indices(vmm_weightX1, reg_weight_x, sizeof(float), vmm_val, grid_len, ov::element::f32, true); - gather_i32_indices(vmm_weightX2, reg_weight_x, 2 * sizeof(float), vmm_val, grid_len, ov::element::f32, true); - gather_i32_indices(vmm_weightX3, reg_weight_x, 3 * sizeof(float), vmm_val, grid_len, ov::element::f32, true); + gather_i32_indices(vmm_weightX2, + reg_weight_x, + 2 * sizeof(float), + vmm_val, + grid_len, + ov::element::f32, + true); + gather_i32_indices(vmm_weightX3, + reg_weight_x, + 3 * sizeof(float), + vmm_val, + grid_len, + ov::element::f32, + true); // vmm_val is now relieved and used for dst_value uni_vpxor(vmm_val, vmm_val, vmm_val); @@ -1354,7 +1393,13 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); // weight y2 - gather_i32_indices(vmm_weightY, reg_weight_y, 2 * sizeof(float), vmm_tbl_y, grid_len, ov::element::f32, true); + gather_i32_indices(vmm_weightY, + reg_weight_y, + 2 * sizeof(float), + vmm_tbl_y, + grid_len, + ov::element::f32, + true); cubic_planar_line(true); // y3 @@ -1364,7 +1409,13 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi vpminsd(vmm_index_y_itr, vmm_index_y_itr, cubic_planar_table_val(1)); vpmaxsd(vmm_index_y_itr, vmm_index_y_itr, vmm_zero); // weight y3 - gather_i32_indices(vmm_weightY, reg_weight_y, 3 * sizeof(float), vmm_tbl_y, grid_len, ov::element::f32, true); + gather_i32_indices(vmm_weightY, + reg_weight_y, + 3 * sizeof(float), + vmm_tbl_y, + grid_len, + ov::element::f32, + true); cubic_planar_line(true); if (attr_.post_ops_.len() != 0) { @@ -1453,8 +1504,13 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } // always gather to Vmm, compute with Vmm, store with Xmm if scalar_step - inline void gather_i32_indices(Vmm vmm_src, const Xbyak::Reg64 &base, int offset, Vmm vmm_indices, int scale, - ov::element::Type src_prc, bool is_scalar) { + inline void gather_i32_indices(Vmm vmm_src, + const Xbyak::Reg64& base, + int offset, + Vmm vmm_indices, + int scale, + ov::element::Type src_prc, + bool is_scalar) { Xbyak::Address table_idx = ptr[base + offset + vmm_indices * scale]; if ((isa == cpu::x64::avx512_core) && !is_scalar) { // [0-15] bit of int to mask @@ -1483,8 +1539,8 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi int repeats = is_scalar ? 
1 : vlen / sizeof(float); for (int i = 0; i < repeats; ++i) { - mov(reg_tmp_64.cvt32(), ptr[rsp + i * sizeof(int)]); // sizeof(int) index_size - table_idx = ptr[base + offset + reg_tmp_64 * scale]; // scale: sizeof(float) value_size + mov(reg_tmp_64.cvt32(), ptr[rsp + i * sizeof(int)]); // sizeof(int) index_size + table_idx = ptr[base + offset + reg_tmp_64 * scale]; // scale: sizeof(float) value_size mov(reg_tmp_64.cvt32(), table_idx); mov(ptr[rsp + i * sizeof(int)], reg_tmp_64.cvt32()); } @@ -1497,9 +1553,10 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } } - // is_broadcast for broadcasting param for depth_wise and quantize(channel-sensitive post-ops), for fusion with plain layout. + // is_broadcast for broadcasting param for depth_wise and quantize(channel-sensitive post-ops), for fusion with + // plain layout. void apply_post_ops(ov::element::Type dst_prc, bool is_broadcast) { - const auto &p = attr_.post_ops_; + const auto& p = attr_.post_ops_; int eltwise_inj_idx = 0; int depthwise_inj_idx = 0; int quantization_inj_idx = 0; @@ -1514,8 +1571,11 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi add(reg_d_weights, reg_oc_off); // weight and bias is padded. scalar as vector. - depthwise_injectors[depthwise_inj_idx]->compute_vector_range( - vmm_val.getIdx(), vmm_val.getIdx() + 1, reg_d_weights, reg_d_weights, is_broadcast); + depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), + vmm_val.getIdx() + 1, + reg_d_weights, + reg_d_weights, + is_broadcast); post_ops_data_offset += depthwise_injectors[depthwise_inj_idx]->memoryStep(); depthwise_inj_idx++; @@ -1525,15 +1585,25 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi int s_idx = vmm_val.getIdx(); - quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off); + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_post_ops_data + post_ops_data_offset, + reg_oc_off); quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0, 0, is_broadcast); - quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off); - quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast); + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs( + reg_post_ops_data + post_ops_data_offset, + reg_oc_off); + quantization_injectors[quantization_inj_idx] + ->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast); if (do_dequantization) { - quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off); - quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0, 0, is_broadcast); + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs( + reg_post_ops_data + post_ops_data_offset, + reg_oc_off); + quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, + s_idx + 1, + 0, + 0, + is_broadcast); } post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep(); @@ -1543,7 +1613,7 @@ struct jit_uni_interpolate_kernel_f32 : public jit_uni_interpolate_kernel, publi } }; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 namespace { struct InterpolateKey { @@ -1585,7 +1655,7 @@ size_t InterpolateKey::hash() const { 
 return seed;
 }

-bool InterpolateKey::operator==(const InterpolateKey &rhs) const {
+bool InterpolateKey::operator==(const InterpolateKey& rhs) const {
     if (nodeAttrs.mode != rhs.nodeAttrs.mode)
         return false;
     if (nodeAttrs.coordTransMode != rhs.nodeAttrs.coordTransMode)
         return false;
@@ -1619,7 +1689,7 @@ bool InterpolateKey::operator==(const InterpolateKey &rhs) const {

     return true;
 }

-} // namespace
+}  // namespace

 // shapeND: n     c    d   h  w
 // blockND: ncdhw cdhw dhw hw w 1
@@ -1628,7 +1698,7 @@ inline VectorDims getBlockND(const VectorDims& shape) {
     int shapeRank = shape.size();
     VectorDims blockND(shapeRank + 1, 1);
     for (int i = shapeRank - 1; i >= 0; i--) {
-        blockND[i] = shape[i] * blockND[i+1];
+        blockND[i] = shape[i] * blockND[i + 1];
     }
     return blockND;
 }
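The comment above getBlockND is the whole story: it builds suffix products of the shape, so blockND[i] is the number of elements spanned by dimensions i..rank-1 (and therefore the linear stride of dimension i-1), with blockND[rank] == 1. A small self-contained sketch, using std::vector<size_t> in place of VectorDims:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Suffix products, as in getBlockND: block[i] = shape[i] * block[i + 1].
std::vector<size_t> block_nd(const std::vector<size_t>& shape) {
    std::vector<size_t> block(shape.size() + 1, 1);
    for (int i = static_cast<int>(shape.size()) - 1; i >= 0; i--) {
        block[i] = shape[i] * block[i + 1];
    }
    return block;
}

int main() {
    // shapeND: n c d h w  ->  blockND: ncdhw, cdhw, dhw, hw, w, 1
    auto block = block_nd({2, 3, 4, 5, 6});
    assert(block[0] == 2 * 3 * 4 * 5 * 6);  // whole-tensor element count
    assert(block[4] == 6 && block[5] == 1);  // w, then the trailing 1
    return 0;
}
```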
@@ -1664,32 +1734,47 @@ using ngInterpShapeCalcMode = ov::op::v4::Interpolate::ShapeCalcMode;

 bool Interpolate::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept {
     try {
         if (const auto interp = std::dynamic_pointer_cast(op)) {
-            const auto &interpAttr = interp->get_attrs();
-            const auto &interpMode = interpAttr.mode;
-            if (!one_of(interpMode, ngInterpMode::NEAREST, ngInterpMode::LINEAR, ngInterpMode::LINEAR_ONNX, ngInterpMode::CUBIC)) {
+            const auto& interpAttr = interp->get_attrs();
+            const auto& interpMode = interpAttr.mode;
+            if (!one_of(interpMode,
+                        ngInterpMode::NEAREST,
+                        ngInterpMode::LINEAR,
+                        ngInterpMode::LINEAR_ONNX,
+                        ngInterpMode::CUBIC)) {
                 errorMessage = "Interpolate-4 does not support interpolate mode: " + ov::as_string(interpMode);
                 return false;
             }
-            const auto &interpCoordTransMode = interpAttr.coordinate_transformation_mode;
-            if (!one_of(interpCoordTransMode, ngInterpCoordTransf::HALF_PIXEL, ngInterpCoordTransf::PYTORCH_HALF_PIXEL, ngInterpCoordTransf::ASYMMETRIC,
-                        ngInterpCoordTransf::TF_HALF_PIXEL_FOR_NN, ngInterpCoordTransf::ALIGN_CORNERS)) {
-                errorMessage = "Interpolate-4 does not support coordinate transformation mode: " + ov::as_string(interpCoordTransMode);
+            const auto& interpCoordTransMode = interpAttr.coordinate_transformation_mode;
+            if (!one_of(interpCoordTransMode,
+                        ngInterpCoordTransf::HALF_PIXEL,
+                        ngInterpCoordTransf::PYTORCH_HALF_PIXEL,
+                        ngInterpCoordTransf::ASYMMETRIC,
+                        ngInterpCoordTransf::TF_HALF_PIXEL_FOR_NN,
+                        ngInterpCoordTransf::ALIGN_CORNERS)) {
+                errorMessage = "Interpolate-4 does not support coordinate transformation mode: " +
+                               ov::as_string(interpCoordTransMode);
                 return false;
             }

             if (interpMode == ngInterpMode::NEAREST) {
-                const auto &interpNearestMode = interpAttr.nearest_mode;
-                if (!one_of(interpNearestMode, ngInterpNearMode::ROUND_PREFER_FLOOR, ngInterpNearMode::ROUND_PREFER_CEIL, ngInterpNearMode::FLOOR,
-                            ngInterpNearMode::CEIL, ngInterpNearMode::SIMPLE)) {
-                    errorMessage = "Interpolate-4 does not support nearest round mode: " + ov::as_string(interpNearestMode);
+                const auto& interpNearestMode = interpAttr.nearest_mode;
+                if (!one_of(interpNearestMode,
+                            ngInterpNearMode::ROUND_PREFER_FLOOR,
+                            ngInterpNearMode::ROUND_PREFER_CEIL,
+                            ngInterpNearMode::FLOOR,
+                            ngInterpNearMode::CEIL,
+                            ngInterpNearMode::SIMPLE)) {
+                    errorMessage =
+                        "Interpolate-4 does not support nearest round mode: " + ov::as_string(interpNearestMode);
                     return false;
                 }
             }

-            const auto &interpShapeCalcMode = interpAttr.shape_calculation_mode;
+            const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode;
             if (!one_of(interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
-                errorMessage = "Interpolate-4 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
+                errorMessage =
+                    "Interpolate-4 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
                 return false;
             }

@@ -1700,7 +1785,8 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr& op
             }

             if (dataRank == 5 && interpMode == ngInterpMode::CUBIC) {
-                errorMessage = "Interpolate-4 doesn't support input tensor with rank: " + std::to_string(dataRank) + " for 'cubic' mode ";
+                errorMessage = "Interpolate-4 doesn't support input tensor with rank: " + std::to_string(dataRank) +
+                               " for 'cubic' mode ";
                 return false;
             }

@@ -1710,21 +1796,22 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr& op
                 return false;
             }

-            if (interp->get_input_size() > 3 &&
-                std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(AXES_ID)) == nullptr) {
+            if (interp->get_input_size() > 3 && std::dynamic_pointer_cast(
+                                                    interp->get_input_node_shared_ptr(AXES_ID)) == nullptr) {
                 errorMessage = "Only const 'axes' input is supported in Interpolate-4";
                 return false;
             }
         } else if (const auto interp = std::dynamic_pointer_cast(op)) {
-            const auto &interpAttr = interp->get_attrs();
-            const auto &interpMode = interpAttr.mode;
+            const auto& interpAttr = interp->get_attrs();
+            const auto& interpMode = interpAttr.mode;
             if (!one_of(interpMode, ngInterpMode::BILINEAR_PILLOW, ngInterpMode::BICUBIC_PILLOW)) {
                 errorMessage = "Interpolate-11 does not support interpolate mode: " + ov::as_string(interpMode);
                 return false;
             }
-            const auto &interpShapeCalcMode = interpAttr.shape_calculation_mode;
+            const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode;
             if (!one_of(interpShapeCalcMode, ngInterpShapeCalcMode::SCALES, ngInterpShapeCalcMode::SIZES)) {
-                errorMessage = "Interpolate-11 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
+                errorMessage =
+                    "Interpolate-11 does not support shape_calculation_mode: " + ov::as_string(interpShapeCalcMode);
                 return false;
             }
             const size_t dataRank = interp->get_input_partial_shape(DATA_ID).rank().get_length();
@@ -1734,12 +1821,12 @@ bool Interpolate::isSupportedOperation(const std::shared_ptr& op
                 return false;
             }
             if (!isDynamicNgraphNode(op) &&
-                !ov::is_type(op->get_input_node_ptr(SIZE_OR_SCALE_ID_V11))) {
+                !ov::is_type(op->get_input_node_ptr(SIZE_OR_SCALE_ID_V11))) {
                 errorMessage = "Only const 'scales_or_sizes' input is supported for static shapes in Interpolate-11";
                 return false;
             }
-            if (interp->get_input_size() > 2 &&
-                std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(AXES_ID_V11)) == nullptr) {
+            if (interp->get_input_size() > 2 && std::dynamic_pointer_cast(
+                                                    interp->get_input_node_shared_ptr(AXES_ID_V11)) == nullptr) {
                 errorMessage = "Only const 'axes' input is supported in Interpolate-11";
                 return false;
             }
@@ -1763,7 +1850,7 @@ class InterpolateShapeInferFactory : public ShapeInferFactory {
     InterpolateShapeInferFactory(std::shared_ptr op) : m_op(op) {}
     ShapeInferPtr makeShapeInfer() const override {
         if (auto interp4 = ov::as_type_ptr(m_op)) {
-            const auto &attr = interp4->get_attrs();
+            const auto& attr = interp4->get_attrs();
             const auto is_supported_mode = (attr.shape_calculation_mode == ngInterpShapeCalcMode::SCALES) ||
                                            (attr.shape_calculation_mode == ngInterpShapeCalcMode::SIZES);
             OPENVINO_ASSERT(is_supported_mode, "Unsupported interpolate shape calculation mode");
@@ -1782,10 +1869,10 @@ class InterpolateShapeInferFactory : public ShapeInferFactory {
 private:
     std::shared_ptr m_op;
 };
-} // namespace
+}  // namespace

 Interpolate::Interpolate(const std::shared_ptr& op,
const GraphContext::CPtr context) - : Node(op, context, InterpolateShapeInferFactory(op)) { + : Node(op, context, InterpolateShapeInferFactory(op)) { std::string errorMessage; if (isSupportedOperation(op, errorMessage)) { errorPrefix = "Interpolate node with name '" + getName() + "'"; @@ -1799,9 +1886,9 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext OPENVINO_THROW(errorPrefix, " has incorrect number of output edges"); isAxesSpecified = numInputs != 3; - const auto &interpAttr = interp->get_attrs(); + const auto& interpAttr = interp->get_attrs(); - const auto &interpMode = interpAttr.mode; + const auto& interpMode = interpAttr.mode; if (interpMode == ngInterpMode::NEAREST) { interpAttrs.mode = InterpolateMode::nearest; } else if (interpMode == ngInterpMode::LINEAR) { @@ -1818,7 +1905,7 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext OPENVINO_THROW(errorPrefix, " has unsupported interpolate mode"); } - const auto &interpCoordTransMode = interpAttr.coordinate_transformation_mode; + const auto& interpCoordTransMode = interpAttr.coordinate_transformation_mode; if (interpCoordTransMode == ngInterpCoordTransf::HALF_PIXEL) { interpAttrs.coordTransMode = InterpolateCoordTransMode::half_pixel; } else if (interpCoordTransMode == ngInterpCoordTransf::PYTORCH_HALF_PIXEL) { @@ -1834,7 +1921,7 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext } if (interpAttrs.mode == InterpolateMode::nearest) { - const auto &interpNearestMode = interpAttr.nearest_mode; + const auto& interpNearestMode = interpAttr.nearest_mode; if (interpNearestMode == ngInterpNearMode::ROUND_PREFER_FLOOR) { interpAttrs.nearestMode = InterpolateNearestMode::round_prefer_floor; } else if (interpNearestMode == ngInterpNearMode::ROUND_PREFER_CEIL) { @@ -1853,7 +1940,7 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext } interpAttrs.antialias = interpAttr.antialias; - const auto &interpShapeCalcMode = interpAttr.shape_calculation_mode; + const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode; if (interpShapeCalcMode == ngInterpShapeCalcMode::SCALES) { interpAttrs.shapeCalcMode = InterpolateShapeCalcMode::scales; } else if (interpShapeCalcMode == ngInterpShapeCalcMode::SIZES) { @@ -1878,14 +1965,16 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext interpAttrs.padEnd[i] = static_cast(interpAttr.pads_end[i]); } - const auto scalesNode = std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(SCALES_ID)); + const auto scalesNode = + std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(SCALES_ID)); if (scalesNode) { scales = scalesNode->cast_vector(); isScaleConstant = true; } if (isAxesSpecified) { - axes = std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(AXES_ID))->cast_vector(); + axes = std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(AXES_ID)) + ->cast_vector(); } else { axes.resize(dataRank); for (int i = 0; i < static_cast(dataRank); i++) { @@ -1901,13 +1990,13 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext OPENVINO_THROW(errorPrefix, " has incorrect number of output edges"); isAxesSpecified = numInputs != 2; - const auto &interpAttr = interp->get_attrs(); - const auto &interpMode = interpAttr.mode; + const auto& interpAttr = interp->get_attrs(); + const auto& interpMode = interpAttr.mode; if (interpMode == ngInterpMode::BILINEAR_PILLOW) { interpAttrs.mode = InterpolateMode::bilinear_pillow; } else if (interpMode == ngInterpMode::BICUBIC_PILLOW) 
{ interpAttrs.mode = InterpolateMode::bicubic_pillow; - interpAttrs.cubeCoeff = static_cast(interpAttr.cube_coeff); // fixed to be -0.5 + interpAttrs.cubeCoeff = static_cast(interpAttr.cube_coeff); // fixed to be -0.5 } else { OPENVINO_THROW(errorPrefix, " has unsupported interpolate mode"); } @@ -1916,10 +2005,11 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext interpAttrs.coordTransMode = InterpolateCoordTransMode::tf_half_pixel_for_nn; interpAttrs.antialias = interpAttr.antialias; - const auto &interpShapeCalcMode = interpAttr.shape_calculation_mode; + const auto& interpShapeCalcMode = interpAttr.shape_calculation_mode; if (interpShapeCalcMode == ngInterpShapeCalcMode::SCALES) { interpAttrs.shapeCalcMode = InterpolateShapeCalcMode::scales; - const auto scalesNode = std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(SIZE_OR_SCALE_ID_V11)); + const auto scalesNode = std::dynamic_pointer_cast( + interp->get_input_node_shared_ptr(SIZE_OR_SCALE_ID_V11)); if (scalesNode) { scales = scalesNode->cast_vector(); isScaleConstant = true; @@ -1947,7 +2037,9 @@ Interpolate::Interpolate(const std::shared_ptr& op, const GraphContext } if (isAxesSpecified) { - axes = std::dynamic_pointer_cast(interp->get_input_node_shared_ptr(AXES_ID_V11))->cast_vector(); + axes = std::dynamic_pointer_cast( + interp->get_input_node_shared_ptr(AXES_ID_V11)) + ->cast_vector(); if (dataRank == 4 && axes.size() == 2 && axes[0] == 1 && axes[1] == 2 && mayiuse(cpu::x64::sse41)) { NCHWAsNHWC = true; axes[0] = 2; @@ -1986,7 +2078,7 @@ void Interpolate::getSupportedDescriptors() { break; } } - //correct pad + // correct pad if (hasPad) { NCHWAsNHWC = false; auto correctPad = [&](std::vector pad, int rank) { @@ -2064,15 +2156,21 @@ void Interpolate::initSupportedPrimitiveDescriptors() { } } auto& creatorsMap = BlockedDescCreator::getCommonCreators(); - auto pushDesc = [&](LayoutType dataFormat, impl_desc_type implDetail, bool is_version11, bool useAclExecutor = false) { - config.inConfs[DATA_ID].setMemDesc(creatorsMap.at(dataFormat)->createSharedDesc(inputPrecision, getInputShapeAtPort(DATA_ID))); + auto pushDesc = [&](LayoutType dataFormat, + impl_desc_type implDetail, + bool is_version11, + bool useAclExecutor = false) { + config.inConfs[DATA_ID].setMemDesc( + creatorsMap.at(dataFormat)->createSharedDesc(inputPrecision, getInputShapeAtPort(DATA_ID))); if (is_version11) { if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes) { config.inConfs[SIZE_OR_SCALE_ID_V11].setMemDesc( - creatorsMap.at(LayoutType::ncsp)->createSharedDesc(targetShapeType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11))); + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(targetShapeType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11))); } else { config.inConfs[SIZE_OR_SCALE_ID_V11].setMemDesc( - creatorsMap.at(LayoutType::ncsp)->createSharedDesc(scalesType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11))); + creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(scalesType, getInputShapeAtPort(SIZE_OR_SCALE_ID_V11))); } if (isAxesSpecified) @@ -2080,14 +2178,18 @@ void Interpolate::initSupportedPrimitiveDescriptors() { creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(AXES_ID_V11))); } else { config.inConfs[TARGET_SHAPE_ID].setMemDesc( - creatorsMap.at(LayoutType::ncsp)->createSharedDesc(targetShapeType, getInputShapeAtPort(TARGET_SHAPE_ID))); - config.inConfs[get_scale_id()].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(scalesType, getInputShapeAtPort(get_scale_id()))); + 
creatorsMap.at(LayoutType::ncsp) + ->createSharedDesc(targetShapeType, getInputShapeAtPort(TARGET_SHAPE_ID))); + config.inConfs[get_scale_id()].setMemDesc( + creatorsMap.at(LayoutType::ncsp)->createSharedDesc(scalesType, getInputShapeAtPort(get_scale_id()))); if (isAxesSpecified) - config.inConfs[get_axis_id()].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(get_axis_id()))); + config.inConfs[get_axis_id()].setMemDesc( + creatorsMap.at(LayoutType::ncsp)->createSharedDesc(axesType, getInputShapeAtPort(get_axis_id()))); } - config.outConfs[0].setMemDesc(creatorsMap.at(dataFormat)->createSharedDesc(outputPrecision, getOutputShapeAtPort(0))); + config.outConfs[0].setMemDesc( + creatorsMap.at(dataFormat)->createSharedDesc(outputPrecision, getOutputShapeAtPort(0))); if (useAclExecutor) { std::vector srcMemoryDescs; @@ -2099,8 +2201,11 @@ void Interpolate::initSupportedPrimitiveDescriptors() { dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()); } - auto factory = std::make_shared(interpAttrs, srcMemoryDescs, dstMemoryDescs, - std::make_shared(context, getImplPriority())); + auto factory = std::make_shared( + interpAttrs, + srcMemoryDescs, + dstMemoryDescs, + std::make_shared(context, getImplPriority())); if (!factory->isEmpty()) { supportedPrimitiveDescriptors.push_back({config, implDetail, factory}); } @@ -2109,14 +2214,14 @@ void Interpolate::initSupportedPrimitiveDescriptors() { } }; if (is_version11) { -#if defined (OV_CPU_WITH_ACL) +#if defined(OV_CPU_WITH_ACL) interpAttrs.hasPad = hasPad; pushDesc(LayoutType::nspc, undef, true, true); pushDesc(LayoutType::ncsp, undef, true, true); canUseAclExecutor = !supportedPrimitiveDescriptors.empty(); if (canUseAclExecutor) return; - //fallback to f32 if ref is used + // fallback to f32 if ref is used inputPrecision = outputPrecision = ov::element::f32; #endif @@ -2140,17 +2245,17 @@ void Interpolate::initSupportedPrimitiveDescriptors() { } pushDesc(LayoutType::ncsp, ref, true); } else { - const auto &dataMinDims = getInputShapeAtPort(DATA_ID).getMinDims(); + const auto& dataMinDims = getInputShapeAtPort(DATA_ID).getMinDims(); bool isBlkApplied = dataRank > 1 && dataMinDims[1] != Shape::UNDEFINED_DIM && dataMinDims[1] > 1; -#if defined (OV_CPU_WITH_ACL) +#if defined(OV_CPU_WITH_ACL) interpAttrs.hasPad = hasPad; pushDesc(LayoutType::nspc, undef, false, true); pushDesc(LayoutType::ncsp, undef, false, true); canUseAclExecutor = !supportedPrimitiveDescriptors.empty(); if (canUseAclExecutor) return; - //fallback to f32 if ref is used + // fallback to f32 if ref is used inputPrecision = outputPrecision = ov::element::f32; #endif @@ -2195,7 +2300,7 @@ bool Interpolate::needShapeInfer() const { if (lastScales.empty()) { return true; } - const float *scales = getSrcDataAtPortAs(get_scale_id()); + const float* scales = getSrcDataAtPortAs(get_scale_id()); for (size_t i = 0; i < lastScales.size(); i++) { if (lastScales[i] != scales[i]) { return true; @@ -2205,7 +2310,7 @@ bool Interpolate::needShapeInfer() const { if (lastSizes.empty()) { return true; } - const int32_t *sizes = getSrcDataAtPortAs(TARGET_SHAPE_ID); + const int32_t* sizes = getSrcDataAtPortAs(TARGET_SHAPE_ID); for (size_t i = 0; i < lastSizes.size(); i++) { if (sizes[i] != lastSizes[i]) { return true; @@ -2219,12 +2324,12 @@ void Interpolate::executeDynamicImpl(dnnl::stream strm) { execute(strm); const size_t port = interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::sizes ? 
TARGET_SHAPE_ID : get_scale_id(); - const auto &memory = getParentEdgeAt(port)->getMemory(); + const auto& memory = getParentEdgeAt(port)->getMemory(); if (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales) { - const float *scales = memory.getDataAs(); + const float* scales = memory.getDataAs(); lastScales.assign(scales, scales + memory.getDesc().getShape().getElementsCount()); } else { - const int32_t *sizes = memory.getDataAs(); + const int32_t* sizes = memory.getDataAs(); lastSizes.assign(sizes, sizes + memory.getDesc().getShape().getElementsCount()); } } @@ -2277,19 +2382,19 @@ void Interpolate::prepareParams() { OPENVINO_THROW(errorPrefix, " has undefined axes memory"); } - const NodeDesc *selected_pd = getSelectedPrimitiveDescriptor(); + const NodeDesc* selected_pd = getSelectedPrimitiveDescriptor(); if (selected_pd == nullptr) OPENVINO_THROW(errorPrefix, " did not set preferable primitive descriptor"); - const auto &srcDimsOrign = srcMemPtr->getStaticDims(); - const auto &dstDimsOrign = dstMemPtr->getStaticDims(); + const auto& srcDimsOrign = srcMemPtr->getStaticDims(); + const auto& dstDimsOrign = dstMemPtr->getStaticDims(); VectorDims srcDims = srcDimsOrign; VectorDims dstDims = dstDimsOrign; // layoutAlignment if (NCHWAsNHWC && srcMemPtr->getDesc().hasLayoutType(LayoutType::ncsp)) { - auto logicalShapeAlign = [] (VectorDims& Dims) { + auto logicalShapeAlign = [](VectorDims& Dims) { size_t C = Dims[3]; Dims[3] = Dims[2]; Dims[2] = Dims[1]; @@ -2308,7 +2413,8 @@ void Interpolate::prepareParams() { } } - std::vector dataScales = getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims); + std::vector dataScales = + getScales(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd), dstDims); if (!NCHWAsNHWC && (getOutputShapeAtPort(0).getRank() > 2 && (dataScales[0] != 1.f || dataScales[1] != 1.f))) { OPENVINO_THROW("Interpolate layer only supports resize on spatial dimensions(depth, height and width)"); } @@ -2324,7 +2430,10 @@ void Interpolate::prepareParams() { dstMemoryDescs.push_back(getDstMemoryAtPort(0)->getDescPtr()); auto selectedPD = getSelectedPrimitiveDescriptor(); - aclExecPtr = selectedPD->getExecutorFactoryAs()->makeExecutor(interpAttrs, srcMemoryDescs, dstMemoryDescs, {}); + aclExecPtr = selectedPD->getExecutorFactoryAs()->makeExecutor(interpAttrs, + srcMemoryDescs, + dstMemoryDescs, + {}); selectedPD->setImplementationType(aclExecPtr->getImplType()); return; @@ -2336,26 +2445,25 @@ void Interpolate::prepareParams() { auto buildExecutor = [&](const InterpolateKey& key) -> std::shared_ptr { std::shared_ptr executor; if ((key.nodeAttrs.mode == InterpolateMode::nearest || key.nodeAttrs.mode == InterpolateMode::linear_onnx || - key.nodeAttrs.mode == InterpolateMode::cubic) && + key.nodeAttrs.mode == InterpolateMode::cubic) && ((key.nodeAttrs.layout != InterpolateLayoutType::planar && mayiuse(cpu::x64::sse41)) || - (mayiuse(cpu::x64::avx2) && key.nodeAttrs.inPrc == ov::element::f32))) { + (mayiuse(cpu::x64::avx2) && key.nodeAttrs.inPrc == ov::element::f32))) { executor = std::make_shared(key.nodeAttrs, - key.srcDims, - key.dstDims, - key.dataScales, - key.attr); - } else if ((key.nodeAttrs.mode == InterpolateMode::bilinear_pillow || key.nodeAttrs.mode == InterpolateMode::bicubic_pillow) && - (key.nodeAttrs.layout == InterpolateLayoutType::by_channel)) { + key.srcDims, + key.dstDims, + key.dataScales, + key.attr); + } else if ((key.nodeAttrs.mode == InterpolateMode::bilinear_pillow || + key.nodeAttrs.mode == 
InterpolateMode::bicubic_pillow) && + (key.nodeAttrs.layout == InterpolateLayoutType::by_channel)) { executor = std::make_shared(key.nodeAttrs, - key.srcDims, - key.dstDims, - key.dataScales, - key.attr); + key.srcDims, + key.dstDims, + key.dataScales, + key.attr); } else { - executor = std::make_shared(key.nodeAttrs, - key.srcDims, - key.dstDims, - key.dataScales); + executor = + std::make_shared(key.nodeAttrs, key.srcDims, key.dstDims, key.dataScales); } return executor; }; @@ -2402,18 +2510,18 @@ static inline float triangleCoeff(float x) { return (std::max)(0.0f, 1 - std::abs(x)); } -void Interpolate::setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims) { +void Interpolate::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims) { dnnl::post_ops ops; postOpsDataPtrs.clear(); - for (auto &node : fusedWith) { - auto* fakeQuantizeNode = dynamic_cast(node.get()); + for (auto& node : fusedWith) { + auto* fakeQuantizeNode = dynamic_cast(node.get()); if (fakeQuantizeNode) { fakeQuantizeNode->appendPostOps(ops, {}, postOpsDataPtrs); continue; } - auto* eltwiseNode = dynamic_cast(node.get()); + auto* eltwiseNode = dynamic_cast(node.get()); if (eltwiseNode) { eltwiseNode->appendPostOps(ops, dims, postOpsDataPtrs); continue; @@ -2429,9 +2537,9 @@ void Interpolate::setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims) attr.set_post_ops(ops); } -VectorDims Interpolate::getPaddedInputShape(const VectorDims &srcDims, - const std::vector &padBegin, - const std::vector &padEnd) { +VectorDims Interpolate::getPaddedInputShape(const VectorDims& srcDims, + const std::vector& padBegin, + const std::vector& padEnd) { VectorDims paddedShape; int dataRank = srcDims.size(); for (int i = 0; i < dataRank; i++) { @@ -2443,18 +2551,21 @@ VectorDims Interpolate::getPaddedInputShape(const VectorDims &srcDims, // get scales of data rank size // if "scale" version: set scales with input scales, 1.f for other dims not in axis // if "size" version: scales = shape[target] / shape[input].pad, 1.f for other dims not in axis -// scales is a required input, but should not use input scales when "size" case, which may added eps or is a dummy value, recalculate scales instead. -std::vector Interpolate::getScales(const VectorDims &srcDimPad, const VectorDims &dstDim) { +// scales is a required input, but should not use input scales when "size" case, which may added eps or is a dummy +// value, recalculate scales instead. +std::vector Interpolate::getScales(const VectorDims& srcDimPad, const VectorDims& dstDim) { std::vector fullScales(dataRank, 1.f); const size_t axesRank = axes.size(); for (size_t i = 0; i < axesRank; i++) { int axis = axes[i]; // pillow always re-generate scales with input and output shape - if (interpAttrs.mode == InterpolateMode::bilinear_pillow || interpAttrs.mode == InterpolateMode::bicubic_pillow) { + if (interpAttrs.mode == InterpolateMode::bilinear_pillow || + interpAttrs.mode == InterpolateMode::bicubic_pillow) { fullScales[axis] = static_cast(dstDim[axis]) / static_cast(srcDimPad[axis]); } else { - fullScales[axis] = (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales) ? scales[i] : - static_cast(dstDim[axis]) / static_cast(srcDimPad[axis]); + fullScales[axis] = (interpAttrs.shapeCalcMode == InterpolateShapeCalcMode::scales) + ? 
scales[i] + : static_cast(dstDim[axis]) / static_cast(srcDimPad[axis]); } } return fullScales; @@ -2465,12 +2576,12 @@ void Interpolate::execute(dnnl::stream strm) { auto srcMemPtr = getSrcMemoryAtPort(DATA_ID); if (execPtr) { - uint8_t *dst_data = dstMemPtr->getDataAs(); - const uint8_t *src_data_origin = srcMemPtr->getDataAs(); - const uint8_t *src_data = nullptr; + uint8_t* dst_data = dstMemPtr->getDataAs(); + const uint8_t* src_data_origin = srcMemPtr->getDataAs(); + const uint8_t* src_data = nullptr; std::vector srcPadded; if (hasPad) { - const auto &srcDim = srcMemPtr->getStaticDims(); + const auto& srcDim = srcMemPtr->getStaticDims(); auto srcDimPad = execPtr->getSrcDimPad5d(); size_t dimSize = srcDim.size(); @@ -2489,23 +2600,34 @@ void Interpolate::execute(dnnl::stream strm) { if (interpAttrs.layout == InterpolateLayoutType::planar) { srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); + uint8_t* src_data_pad = static_cast(&srcPadded[0]); parallel_for4d(srcDim5d[0], srcDim5d[1], srcDim5d[2], srcDim5d[3], [&](int n, int c, int d, int h) { - const uint8_t *src = src_data_origin + - (inShapeBlock[1] * n + inShapeBlock[2] * c + inShapeBlock[3] * d + inShapeBlock[4] * h) * srcDataSize; - uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) + - inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * srcDataSize; + const uint8_t* src = src_data_origin + (inShapeBlock[1] * n + inShapeBlock[2] * c + + inShapeBlock[3] * d + inShapeBlock[4] * h) * + srcDataSize; + uint8_t* srcPad = + src_data_pad + (inShapePadBlock[1] * (n + padB0) + inShapePadBlock[2] * (c + padB1) + + inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + padB4) * + srcDataSize; cpu_memcpy(srcPad, src, srcDim5d[4] * srcDataSize); }); src_data = src_data_pad; } else if (interpAttrs.layout == InterpolateLayoutType::by_channel) { srcPadded.resize(inShapePadBlock[0] * srcDataSize, 0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); + uint8_t* src_data_pad = static_cast(&srcPadded[0]); parallel_for4d(srcDim5d[0], srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int d, int h, int w) { - const uint8_t *src = src_data_origin + (inShapeBlock[1] * n + - (inShapeBlock[3] * d + inShapeBlock[4] * h + inShapeBlock[5] * w) * srcDim5d[1]) * srcDataSize; - uint8_t *srcPad = src_data_pad + (inShapePadBlock[1] * (n + padB0) + (inShapePadBlock[3] * (d + padB2) + - inShapePadBlock[4] * (h + padB3) + inShapePadBlock[5] * (w + padB4)) * srcDimPad5d[1] + padB1) * srcDataSize; + const uint8_t* src = + src_data_origin + + (inShapeBlock[1] * n + + (inShapeBlock[3] * d + inShapeBlock[4] * h + inShapeBlock[5] * w) * srcDim5d[1]) * + srcDataSize; + uint8_t* srcPad = + src_data_pad + (inShapePadBlock[1] * (n + padB0) + + (inShapePadBlock[3] * (d + padB2) + inShapePadBlock[4] * (h + padB3) + + inShapePadBlock[5] * (w + padB4)) * + srcDimPad5d[1] + + padB1) * + srcDataSize; cpu_memcpy(srcPad, src, srcDim5d[1] * srcDataSize); }); src_data = src_data_pad; @@ -2514,25 +2636,34 @@ void Interpolate::execute(dnnl::stream strm) { size_t CB = div_up(srcDimPad5d[1], blkSize); size_t eltsTotal = srcDimPad5d[0] * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize; srcPadded.resize(eltsTotal * srcDataSize, 0x0); - uint8_t *src_data_pad = static_cast(&srcPadded[0]); + uint8_t* src_data_pad = static_cast(&srcPadded[0]); if ((srcDim5d[0] != srcDimPad5d[0]) || (srcDim5d[1] != srcDimPad5d[1])) { 
OPENVINO_THROW("Interpolate layer with name '", getName(), "' does not support padding on batch and channel dimensions"); } - parallel_for5d(srcDim5d[0], CB, srcDim5d[2], srcDim5d[3], srcDim5d[4], [&](int n, int cb, int d, int h, int w) { - const uint8_t *src = src_data_origin + (n * CB * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize - + (cb * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize - + (d * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize - + (h * srcDim5d[4] * blkSize) * srcDataSize - + (w * blkSize) * srcDataSize; - uint8_t *srcPad = src_data_pad + (n * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize - + (cb * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize - + ((d + padB2) * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize - + ((h + padB3) * srcDimPad5d[4] * blkSize) * srcDataSize - + ((w + padB4) * blkSize) * srcDataSize; - cpu_memcpy(srcPad, src, blkSize * srcDataSize); - }); + parallel_for5d(srcDim5d[0], + CB, + srcDim5d[2], + srcDim5d[3], + srcDim5d[4], + [&](int n, int cb, int d, int h, int w) { + const uint8_t* src = + src_data_origin + + (n * CB * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (cb * srcDim5d[2] * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (d * srcDim5d[3] * srcDim5d[4] * blkSize) * srcDataSize + + (h * srcDim5d[4] * blkSize) * srcDataSize + (w * blkSize) * srcDataSize; + uint8_t* srcPad = + src_data_pad + + (n * CB * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * + srcDataSize + + (cb * srcDimPad5d[2] * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + ((d + padB2) * srcDimPad5d[3] * srcDimPad5d[4] * blkSize) * srcDataSize + + ((h + padB3) * srcDimPad5d[4] * blkSize) * srcDataSize + + ((w + padB4) * blkSize) * srcDataSize; + cpu_memcpy(srcPad, src, blkSize * srcDataSize); + }); src_data = src_data_pad; } } else { @@ -2549,26 +2680,35 @@ void Interpolate::execute(dnnl::stream strm) { // for ndhwc and nCdhw8c[16c] // input may be f32/bf16/int8, fused->output varies -void Interpolate::InterpolateJitExecutor::NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW) { - int *index_d = static_cast(&auxTable[0]); - int *index_h = static_cast(&auxTable[OD]); - int *index_w = static_cast(&auxTable[OD + OH]); +void Interpolate::InterpolateJitExecutor::NNCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW) { + int* index_d = static_cast(&auxTable[0]); + int* index_h = static_cast(&auxTable[OD]); + int* index_w = static_cast(&auxTable[OD + OH]); bool is_nhwc = (configured_for_layout == by_channel); for (int b = 0; b < B; b++) { if (is_nhwc) { - const uint8_t *in_ptr = in_ptr_ + (IW * IH * ID * C * b) * srcDataSize; - uint8_t *out_ptr = out_ptr_ + (OW * OH * OD * C * b) * dstDataSize; + const uint8_t* in_ptr = in_ptr_ + (IW * IH * ID * C * b) * srcDataSize; + uint8_t* out_ptr = out_ptr_ + (OW * OH * OD * C * b) * dstDataSize; std::vector index_w_kernel(OW); for (int ox = 0; ox < OW; ox++) { index_w_kernel[ox] = index_w[ox] * C * srcDataSize; } parallel_for2d(OD, OH, [&](size_t d, size_t h) { // kernel for C * OW - uint8_t *out_ptr_dh = out_ptr + (C * OW * OH * d + C * OW * h) * dstDataSize; - const uint8_t *in_ptr_dh = in_ptr + (C * IW * IH * index_d[d] + C * IW * index_h[h]) * srcDataSize; + uint8_t* 
out_ptr_dh = out_ptr + (C * OW * OH * d + C * OW * h) * dstDataSize; + const uint8_t* in_ptr_dh = in_ptr + (C * IW * IH * index_d[d] + C * IW * index_h[h]) * srcDataSize; auto arg = jit_interpolate_call_args(); arg.dst = out_ptr_dh; arg.src_ptr[0] = in_ptr_dh; @@ -2581,15 +2721,16 @@ void Interpolate::InterpolateJitExecutor::NNCGathered(const uint8_t *in_ptr_, ui } else { // for blk int blk_size = mayiuse(cpu::x64::avx512_core) ? 16 : 8; int CB = div_up(C, blk_size); - const uint8_t *in_ptr = in_ptr_ + (IW * IH * ID * CB * blk_size * b) * srcDataSize; - uint8_t *out_ptr = out_ptr_ + (OW * OH * OD * CB * blk_size * b) * dstDataSize; + const uint8_t* in_ptr = in_ptr_ + (IW * IH * ID * CB * blk_size * b) * srcDataSize; + uint8_t* out_ptr = out_ptr_ + (OW * OH * OD * CB * blk_size * b) * dstDataSize; std::vector<int> index_w_kernel(OW); for (int ox = 0; ox < OW; ox++) { index_w_kernel[ox] = index_w[ox] * blk_size * srcDataSize; } parallel_for2d(CB, OD, [&](size_t cb, size_t d) { - uint8_t *out_ptr_cbd = out_ptr + (blk_size * OW * OH * OD * cb + blk_size * OW * OH * d) * dstDataSize; - const uint8_t *in_ptr_cbd = in_ptr + (blk_size * IW * IH * ID * cb + blk_size * IW * IH * index_d[d]) * srcDataSize; + uint8_t* out_ptr_cbd = out_ptr + (blk_size * OW * OH * OD * cb + blk_size * OW * OH * d) * dstDataSize; + const uint8_t* in_ptr_cbd = + in_ptr + (blk_size * IW * IH * ID * cb + blk_size * IW * IH * index_d[d]) * srcDataSize; auto arg = jit_interpolate_call_args(); for (int h = 0; h < OH; h++) { // kernel for blk_size * OW arg.dst = out_ptr_cbd + blk_size * OW * h * dstDataSize; @@ -2605,11 +2746,20 @@ void Interpolate::InterpolateJitExecutor::NNCGathered(const uint8_t *in_ptr_, ui } // batch end } -void Interpolate::InterpolateJitExecutor::NNPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW) { - int *index_d = static_cast<int*>(&auxTable[0]); - int *index_h = static_cast<int*>(&auxTable[OD]); - int *index_w = static_cast<int*>(&auxTable[OD + OH]); +void Interpolate::InterpolateJitExecutor::NNPlanar(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW) { + int* index_d = static_cast<int*>(&auxTable[0]); + int* index_h = static_cast<int*>(&auxTable[OD]); + int* index_w = static_cast<int*>(&auxTable[OD + OH]); std::vector<int> index_kernel(OH + OW); // index_h * IW * srcDataSize to reduce and simplify redundant compute @@ -2622,13 +2772,15 @@ void Interpolate::InterpolateJitExecutor::NNPlanar(const uint8_t *in_ptr_, uint8 } parallel_for3d(B, C, OD, [&](size_t b, size_t c, size_t od) { - const uint8_t *in_ptr = in_ptr_ + (IW * IH * ID * C * b + IW * IH * ID * c + IW * IH * index_d[od]) * srcDataSize; - uint8_t *out_ptr = out_ptr_ + (OW * OH * OD * C * b + OW * OH * OD * c + OW * OH * od) * dstDataSize; + const uint8_t* in_ptr = + in_ptr_ + (IW * IH * ID * C * b + IW * IH * ID * c + IW * IH * index_d[od]) * srcDataSize; + uint8_t* out_ptr = out_ptr_ + (OW * OH * OD * C * b + OW * OH * OD * c + OW * OH * od) * dstDataSize; auto arg = jit_interpolate_call_args(); arg.src_ptr[0] = in_ptr; arg.dst = out_ptr; - arg.index = static_cast<int*>(&index_kernel[0]); // need index_h and index_w in kernel, it's in continous memory so one param + arg.index = static_cast<int*>( + &index_kernel[0]); // need index_h and index_w in kernel; it's in contiguous memory, so one param arg.oc_off = static_cast<size_t>(c * sizeof(float)); // work_amount is OH(out loop) and OW(inner loop), can get in kernel
from jcp. arg.post_op_data = post_ops_data_; @@ -2636,18 +2788,27 @@ void Interpolate::InterpolateJitExecutor::NNPlanar(const uint8_t *in_ptr_, uint8 }); } -void Interpolate::InterpolateJitExecutor::linearOnnxPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, int B, int C, - int ID, int IH, int IW, int OD, int OH, int OW) { - // FrontTopLeft:0, FrontTopRight:1, FrontBottomLeft:2, FrontBottomRight:3, EndTopLeft:4, EndTopRight:5, EndBottomLeft:6, EndBottomRight:7 - // weight: Left:0, ritht:1, top:2, bottom:3, front:4, end:5 - int *index = static_cast<int*>(&auxTable[0]); +void Interpolate::InterpolateJitExecutor::linearOnnxPlanar(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW) { + // FrontTopLeft:0, FrontTopRight:1, FrontBottomLeft:2, FrontBottomRight:3, EndTopLeft:4, EndTopRight:5, + // EndBottomLeft:6, EndBottomRight:7 weight: Left:0, right:1, top:2, bottom:3, front:4, end:5 + int* index = static_cast<int*>(&auxTable[0]); int eltInGrid = (spatialDimSize > 2) ? MAX_INPUT_INTERPOLATE : ((spatialDimSize > 1) ? 4 : 2); int scratchLen = rnd_up(eltInGrid * OW * OH * OD, 16); - float *weight = reinterpret_cast<float*>(&auxTable[scratchLen]); + float* weight = reinterpret_cast<float*>(&auxTable[scratchLen]); parallel_for2d(B, C, [&](size_t b, size_t c) { - uint8_t *out_ptr_nc = out_ptr_ + (OH * OW * OD * C * b + OH * OW * OD * c) * dstDataSize; - const uint8_t *in_ptr_nc = in_ptr_ + (IH * IW * ID * C * b + IH * IW * ID * c) * srcDataSize; + uint8_t* out_ptr_nc = out_ptr_ + (OH * OW * OD * C * b + OH * OW * OD * c) * dstDataSize; + const uint8_t* in_ptr_nc = in_ptr_ + (IH * IW * ID * C * b + IH * IW * ID * c) * srcDataSize; auto arg = jit_interpolate_call_args(); arg.src_ptr[0] = in_ptr_nc; arg.index = static_cast<int*>(&index[0]); @@ -2660,8 +2821,17 @@ void Interpolate::InterpolateJitExecutor::linearOnnxPlanar(const uint8_t *in_ptr }); } -void Interpolate::InterpolateJitExecutor::linearOnnxCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW) { +void Interpolate::InterpolateJitExecutor::linearOnnxCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW) { // left:OW right:OW Top:OH Bottom:OH Front:OD End:OD std::vector<int*> indexPtr(MAX_INPUT_INTERPOLATE, 0); std::vector<float*> weightPtr(MAX_INPUT_INTERPOLATE, 0); @@ -2696,18 +2866,18 @@ void Interpolate::InterpolateJitExecutor::linearOnnxCGathered(const uint8_t *in_ int I2 = ID * I1; int I3 = CB * I2; parallel_for3d(B, OD, OH, [&](size_t b, size_t d, size_t h) { - uint8_t *out_ptr_ndh = out_ptr_ + (C3 * b + C1 * d + C0 * h) * dstDataSize; - - const uint8_t *in_ptr_n = in_ptr_ + (I3 * b) * srcDataSize; - const uint8_t *in_ptr_nf = in_ptr_n + (indexPtr[4][d] * I1) * srcDataSize; - const uint8_t *in_ptr_nft = in_ptr_nf + (indexPtr[2][h] * I0) * srcDataSize; - const uint8_t *in_ptr_nfb = in_ptr_nf + (indexPtr[3][h] * I0) * srcDataSize; - const uint8_t *in_ptr_ne = in_ptr_n + (indexPtr[5][d] * I1) * srcDataSize; - const uint8_t *in_ptr_net = in_ptr_ne + (indexPtr[2][h] * I0) * srcDataSize; - const uint8_t *in_ptr_neb = in_ptr_ne + (indexPtr[3][h] * I0) * srcDataSize; + uint8_t* out_ptr_ndh = out_ptr_ + (C3 * b + C1 * d + C0 * h) * dstDataSize; + + const uint8_t* in_ptr_n = in_ptr_ + (I3 * b) * srcDataSize; + const uint8_t* in_ptr_nf = in_ptr_n + (indexPtr[4][d] * I1) *
srcDataSize; + const uint8_t* in_ptr_nft = in_ptr_nf + (indexPtr[2][h] * I0) * srcDataSize; + const uint8_t* in_ptr_nfb = in_ptr_nf + (indexPtr[3][h] * I0) * srcDataSize; + const uint8_t* in_ptr_ne = in_ptr_n + (indexPtr[5][d] * I1) * srcDataSize; + const uint8_t* in_ptr_net = in_ptr_ne + (indexPtr[2][h] * I0) * srcDataSize; + const uint8_t* in_ptr_neb = in_ptr_ne + (indexPtr[3][h] * I0) * srcDataSize; auto arg = jit_interpolate_call_args(); for (int w = 0; w < OW; ++w) { - uint8_t *out_ptr_ndhw = out_ptr_ndh + CGatherLen * w * dstDataSize; + uint8_t* out_ptr_ndhw = out_ptr_ndh + CGatherLen * w * dstDataSize; arg.src_ptr[0] = in_ptr_nft + (indexPtr[0][w] * CGatherLen) * srcDataSize; arg.src_ptr[1] = in_ptr_nft + (indexPtr[1][w] * CGatherLen) * srcDataSize; @@ -2732,13 +2902,20 @@ void Interpolate::InterpolateJitExecutor::linearOnnxCGathered(const uint8_t *in_ }); } -void Interpolate::InterpolateJitExecutor::cubicCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int IH, int IW, int OH, int OW) { +void Interpolate::InterpolateJitExecutor::cubicCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int IH, + int IW, + int OH, + int OW) { const int idxNum = 1; - int *xOrigin = static_cast(&auxTable[0]); - float *xFactor = reinterpret_cast(&auxTable[OW]); - int *yOrigin = static_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW]); - float *yFactor = reinterpret_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]); + int* xOrigin = static_cast(&auxTable[0]); + float* xFactor = reinterpret_cast(&auxTable[OW]); + int* yOrigin = static_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW]); + float* yFactor = reinterpret_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]); int blkSize = mayiuse(cpu::x64::avx512_core) ? 16 : 8; int CB = div_up(C, blkSize); @@ -2747,8 +2924,8 @@ void Interpolate::InterpolateJitExecutor::cubicCGathered(const uint8_t *in_ptr_, int workAmount = configured_for_layout == InterpolateLayoutType::by_channel ? 
C : CB; parallel_for3d(B, OH, OW, [&](size_t b, size_t h, size_t w) { - uint8_t *out_ptr_nhw = out_ptr_ + (OH * OW * CSize * b + OW * CGatherLen * h + CGatherLen * w) * dstDataSize; - const uint8_t *in_ptr_n = in_ptr_ + (IH * IW * CSize * b) * srcDataSize; + uint8_t* out_ptr_nhw = out_ptr_ + (OH * OW * CSize * b + OW * CGatherLen * h + CGatherLen * w) * dstDataSize; + const uint8_t* in_ptr_n = in_ptr_ + (IH * IW * CSize * b) * srcDataSize; std::vector kernelIndex(CUBIC_GRID_LEN * CUBIC_GRID_LEN); // 16 address offset to src(batch) or src(CB) int iy = yOrigin[h]; @@ -2763,41 +2940,48 @@ void Interpolate::InterpolateJitExecutor::cubicCGathered(const uint8_t *in_ptr_, } } auto arg = jit_interpolate_call_args(); - arg.dst = out_ptr_nhw; - arg.src_ptr[0] = in_ptr_n; - arg.index = static_cast(&kernelIndex[0]); - // 0 for weight_W, 1 for weight_H - arg.weight_ptr[0] = static_cast(&xFactor[w * CUBIC_GRID_LEN]); - arg.weight_ptr[1] = static_cast(&yFactor[h * CUBIC_GRID_LEN]); - - // for by channel, src + step, dst + step, process next step on continuous memory - // for blk, src + IW*IH*blkSize, dst + OW*OH*blkSize, process the blkSize on next CB - arg.work_amount = workAmount; - arg.oc_off = 0; - arg.post_op_data = post_ops_data_; - (*interpolateKernel)(&arg); + arg.dst = out_ptr_nhw; + arg.src_ptr[0] = in_ptr_n; + arg.index = static_cast(&kernelIndex[0]); + // 0 for weight_W, 1 for weight_H + arg.weight_ptr[0] = static_cast(&xFactor[w * CUBIC_GRID_LEN]); + arg.weight_ptr[1] = static_cast(&yFactor[h * CUBIC_GRID_LEN]); + + // for by channel, src + step, dst + step, process next step on continuous memory + // for blk, src + IW*IH*blkSize, dst + OW*OH*blkSize, process the blkSize on next CB + arg.work_amount = workAmount; + arg.oc_off = 0; + arg.post_op_data = post_ops_data_; + (*interpolateKernel)(&arg); }); } -void Interpolate::InterpolateJitExecutor::cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int IH, int IW, int OH, int OW) { +void Interpolate::InterpolateJitExecutor::cubicPlanar(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int IH, + int IW, + int OH, + int OW) { int tblAdvance = 0; - int *xOrigin = static_cast(&auxTable[tblAdvance]); + int* xOrigin = static_cast(&auxTable[tblAdvance]); tblAdvance += OW; - float *xFactor = reinterpret_cast(&auxTable[tblAdvance]); + float* xFactor = reinterpret_cast(&auxTable[tblAdvance]); tblAdvance += CUBIC_GRID_LEN * OW; - int *yOrigin = static_cast(&auxTable[tblAdvance]); + int* yOrigin = static_cast(&auxTable[tblAdvance]); tblAdvance += OH; - float *yFactor = reinterpret_cast(&auxTable[tblAdvance]); + float* yFactor = reinterpret_cast(&auxTable[tblAdvance]); tblAdvance += CUBIC_GRID_LEN * OH; - int *sequenceOH = static_cast(&auxTable[tblAdvance]); + int* sequenceOH = static_cast(&auxTable[tblAdvance]); tblAdvance += OW * OH; - int *sequenceOW = static_cast(&auxTable[tblAdvance]); + int* sequenceOW = static_cast(&auxTable[tblAdvance]); parallel_for2d(B, C, [&](size_t n, size_t c) { - const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * n + IW * IH * c) * srcDataSize; - uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * n + OW * OH * c) * dstDataSize; + const uint8_t* in_ptr_nc = in_ptr_ + (IW * IH * C * n + IW * IH * c) * srcDataSize; + uint8_t* out_ptr_nc = out_ptr_ + (OW * OH * C * n + OW * OH * c) * dstDataSize; auto arg = jit_interpolate_call_args(); arg.dst = out_ptr_nc; @@ -2815,8 +2999,15 @@ void Interpolate::InterpolateJitExecutor::cubicPlanar(const 
uint8_t *in_ptr_, ui }); } -void Interpolate::InterpolateJitExecutor::pillowCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int IH, int IW, int OH, int OW) { +void Interpolate::InterpolateJitExecutor::pillowCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int IH, + int IW, + int OH, + int OW) { // workBuffer needed when both pass are true bool xPass = IW != OW; bool yPass = IH != OH; @@ -2848,8 +3039,11 @@ void Interpolate::InterpolateJitExecutor::pillowCGathered(const uint8_t *in_ptr_ // ===================================================================================================================== // index layout: // d_0............d_OD-1, h_0..............h_OH-1, w_0................w_OW-1 -void Interpolate::InterpolateExecutorBase::buildTblNN(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, - const std::vector& dataScales, InterpolateLayoutType layout, InterpolateNearestMode nearestMode) { +void Interpolate::InterpolateExecutorBase::buildTblNN(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout, + InterpolateNearestMode nearestMode) { const int dimSize = dataRank; float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f; float fy = dataScales[dimSize - 2]; @@ -2881,80 +3075,91 @@ void Interpolate::InterpolateExecutorBase::buildTblNN(const VectorDims& srcDimPa // scale is float(outShape) / float(inShape) // strictly consistent with onnx calc manner(div scale, not multiply inverse), given this is done offline // the slight precison diff can produce obvious wrong value due to "nearest round" behavior for NN mode -float Interpolate::InterpolateExecutorBase::coordTransToInput(int outCoord, float scale, int inShape, int outShape) const { +float Interpolate::InterpolateExecutorBase::coordTransToInput(int outCoord, + float scale, + int inShape, + int outShape) const { if (scale == 1.0f || (inShape == outShape)) { return outCoord; } switch (coordTransMode) { - case InterpolateCoordTransMode::half_pixel: { + case InterpolateCoordTransMode::half_pixel: { + return (outCoord + 0.5f) / scale - 0.5f; + break; + } + case InterpolateCoordTransMode::pytorch_half_pixel: { + if (outShape > 1) return (outCoord + 0.5f) / scale - 0.5f; - break; - } - case InterpolateCoordTransMode::pytorch_half_pixel: { - if (outShape > 1) - return (outCoord + 0.5f) / scale - 0.5f; - else - return 0; - break; - } - case InterpolateCoordTransMode::asymmetric: { - return static_cast(outCoord) / scale; - break; - } - case InterpolateCoordTransMode::tf_half_pixel_for_nn: { - return (outCoord + 0.5f) / scale; - break; - } - case InterpolateCoordTransMode::align_corners: { - if (outShape > 1) - return outCoord * (static_cast(inShape - 1) / static_cast(outShape - 1)); - else - return 0; - break; - } - default: { - OPENVINO_THROW("errorPrefix", " does not support specified coordinate transformation mode"); - break; - } + else + return 0; + break; + } + case InterpolateCoordTransMode::asymmetric: { + return static_cast(outCoord) / scale; + break; + } + case InterpolateCoordTransMode::tf_half_pixel_for_nn: { + return (outCoord + 0.5f) / scale; + break; + } + case InterpolateCoordTransMode::align_corners: { + if (outShape > 1) + return outCoord * (static_cast(inShape - 1) / static_cast(outShape - 1)); + else + return 0; + break; + } + default: { + OPENVINO_THROW("errorPrefix", " does not support specified coordinate transformation mode"); + 
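// A quick numeric check of the coordinate transforms handled above
// (illustrative values; scale is outShape / inShape, computed by division, as
// the comment before coordTransToInput requires).
#include <cstdio>

int main() {
    int inShape = 10, outShape = 4, outCoord = 1;
    float scale = static_cast<float>(outShape) / inShape;                     // 0.4
    std::printf("half_pixel    : %.2f\n", (outCoord + 0.5f) / scale - 0.5f);  // 3.25
    std::printf("asymmetric    : %.2f\n", outCoord / scale);                  // 2.50
    std::printf("align_corners : %.2f\n",
                outCoord * (static_cast<float>(inShape - 1) / (outShape - 1)));  // 3.00
    return 0;
}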
break; + } } } -int Interpolate::InterpolateExecutorBase::nearestRound(float originCoord, bool isDownsample, InterpolateNearestMode nearestMode) const { +int Interpolate::InterpolateExecutorBase::nearestRound(float originCoord, + bool isDownsample, + InterpolateNearestMode nearestMode) const { switch (nearestMode) { - case InterpolateNearestMode::round_prefer_floor: { - if (originCoord == (static_cast(originCoord) + 0.5f)) - return static_cast(std::floor(originCoord)); - else - return static_cast(std::round(originCoord)); - break; - } - case InterpolateNearestMode::round_prefer_ceil: { - return static_cast(std::round(originCoord)); - break; - } - case InterpolateNearestMode::floor: { + case InterpolateNearestMode::round_prefer_floor: { + if (originCoord == (static_cast(originCoord) + 0.5f)) return static_cast(std::floor(originCoord)); - break; - } - case InterpolateNearestMode::ceil: { + else + return static_cast(std::round(originCoord)); + break; + } + case InterpolateNearestMode::round_prefer_ceil: { + return static_cast(std::round(originCoord)); + break; + } + case InterpolateNearestMode::floor: { + return static_cast(std::floor(originCoord)); + break; + } + case InterpolateNearestMode::ceil: { + return static_cast(std::ceil(originCoord)); + break; + } + case InterpolateNearestMode::simple: { + if (isDownsample) return static_cast(std::ceil(originCoord)); - break; - } - case InterpolateNearestMode::simple: { - if (isDownsample) - return static_cast(std::ceil(originCoord)); - else - return static_cast(originCoord); - } - default: { - OPENVINO_THROW("errorPrefix", " does not support specified nearest round mode"); - break; - } + else + return static_cast(originCoord); + } + default: { + OPENVINO_THROW("errorPrefix", " does not support specified nearest round mode"); + break; + } } } -void Interpolate::InterpolateExecutorBase::linearOnnxCF(int outCoord, float scale, int inShape, int outShape, - int& index0, int& index1, float& weight0, float& weight1) { +void Interpolate::InterpolateExecutorBase::linearOnnxCF(int outCoord, + float scale, + int inShape, + int outShape, + int& index0, + int& index1, + float& weight0, + float& weight1) { float inCoord = coordTransToInput(outCoord, scale, inShape, outShape); inCoord = std::max(0.0f, std::min(inCoord, static_cast(inShape - 1))); index0 = std::min(static_cast(inCoord), inShape - 1); @@ -2968,8 +3173,10 @@ void Interpolate::InterpolateExecutorBase::linearOnnxCF(int outCoord, float scal } } -void Interpolate::InterpolateExecutorBase::buildTblLinearOnnx(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, - const std::vector& dataScales, InterpolateLayoutType layout) { +void Interpolate::InterpolateExecutorBase::buildTblLinearOnnx(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout) { int dimSize = dataRank; float fz = (spatialDimSize > 2) ? dataScales[dimSize - 3] : 1.f; float fy = (spatialDimSize > 1) ? 
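// A standalone sketch of the endpoint/weight computation in linearOnnxCF,
// paired with the half_pixel transform; the numbers are hand-picked for
// illustration, not taken from this patch.
#include <algorithm>
#include <cstdio>

int main() {
    int inShape = 10, outShape = 4, outCoord = 2;
    float scale = static_cast<float>(outShape) / inShape;
    float inCoord = (outCoord + 0.5f) / scale - 0.5f;  // 5.75
    inCoord = std::max(0.0f, std::min(inCoord, static_cast<float>(inShape - 1)));
    int index0 = std::min(static_cast<int>(inCoord), inShape - 1);  // 5
    int index1 = std::min(index0 + 1, inShape - 1);                 // 6
    float weight1 = inCoord - index0;                               // 0.75
    float weight0 = 1.0f - weight1;                                 // 0.25
    std::printf("%d %d %.2f %.2f\n", index0, index1, weight0, weight1);
    return 0;
}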
dataScales[dimSize - 2] : 1.f; @@ -3028,7 +3235,7 @@ void Interpolate::InterpolateExecutorBase::buildTblLinearOnnx(const VectorDims& indexPtr[1][idxOzOyOx] = (izF * IH * IW + iyT * IW + ixR) * scale; weightPtr[0][idxOzOyOx] = weightL; weightPtr[1][idxOzOyOx] = weightR; - if (spatialDimSize > 1) { + if (spatialDimSize > 1) { indexPtr[2][idxOzOyOx] = (izF * IH * IW + iyB * IW + ixL) * scale; indexPtr[3][idxOzOyOx] = (izF * IH * IW + iyB * IW + ixR) * scale; weightPtr[2][idxOzOyOx] = weightT; @@ -3081,8 +3288,11 @@ void Interpolate::InterpolateExecutorBase::buildTblLinearOnnx(const VectorDims& // wd .........wd, wh............wh, ww.............ww, id...........id, ih............ih, iw..............iw // | | // wh0.....wh_diameter ih0.....ih_diameter -void Interpolate::InterpolateExecutorBase::buildTblLinear(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, - const std::vector& dataScales, int kernel_width, bool antialias) { +void Interpolate::InterpolateExecutorBase::buildTblLinear(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + int kernel_width, + bool antialias) { int dimSize = dataRank; float fz = (dimSize == 5) ? dataScales[dimSize - 3] : 1.f; float fy = dataScales[dimSize - 2]; @@ -3106,15 +3316,15 @@ void Interpolate::InterpolateExecutorBase::buildTblLinear(const VectorDims& srcD int sizeOH = OH * diaOH; int sizeOW = OW * diaOW; auxTable.resize((sizeOD + sizeOH + sizeOW) * 2); - float *weightTable = reinterpret_cast(&auxTable[0]); - float *weightOD = static_cast(&weightTable[0]); - float *weightOH = static_cast(&weightTable[sizeOD]); - float *weightOW = static_cast(&weightTable[sizeOD + sizeOH]); + float* weightTable = reinterpret_cast(&auxTable[0]); + float* weightOD = static_cast(&weightTable[0]); + float* weightOH = static_cast(&weightTable[sizeOD]); + float* weightOW = static_cast(&weightTable[sizeOD + sizeOH]); - int *idxTable = static_cast(&auxTable[sizeOD + sizeOH + sizeOW]); - int *idxOD = static_cast(&idxTable[0]); - int *idxOH = static_cast(&idxTable[sizeOD]); - int *idxOW = static_cast(&idxTable[sizeOD + sizeOH]); + int* idxTable = static_cast(&auxTable[sizeOD + sizeOH + sizeOW]); + int* idxOD = static_cast(&idxTable[0]); + int* idxOH = static_cast(&idxTable[sizeOD]); + int* idxOW = static_cast(&idxTable[sizeOD + sizeOH]); for (size_t oz = 0; oz < OD; oz++) { float iz = coordTransToInput(oz, fz, ID, OD); @@ -3172,8 +3382,11 @@ std::vector Interpolate::InterpolateExecutorBase::getCubicCoeffs(float ma // table layout: // OW OW OW OW OW OH OH OH OH OH // x_idx x_weight0 x_weight1 x_weight2 x_weight3 y_idx y_weight0 y_weight1 y_weight2 y_weight3 -void Interpolate::InterpolateExecutorBase::buildTblCubic(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - float cubicCoeff, InterpolateLayoutType layout) { +void Interpolate::InterpolateExecutorBase::buildTblCubic(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + float cubicCoeff, + InterpolateLayoutType layout) { int dimSize = dataRank; float fy = dataScales[dimSize - 2]; float fx = dataScales[dimSize - 1]; @@ -3191,9 +3404,9 @@ void Interpolate::InterpolateExecutorBase::buildTblCubic(const VectorDims& srcDi } int tblAdvance = 0; - int *xOrigin = static_cast(&auxTable[tblAdvance]); + int* xOrigin = static_cast(&auxTable[tblAdvance]); tblAdvance += OW; - float *xFactor = reinterpret_cast(&auxTable[tblAdvance]); + float* xFactor = reinterpret_cast(&auxTable[tblAdvance]); for (int ox = 0; ox < 
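// A tiny sketch of the triangle filter weight the linear table above is built
// from: triangleCoeff(x) = max(0, 1 - |x|). With antialiasing enabled on a
// downsample, distances are pre-scaled by the sub-unit scale factor (assumed
// ax = 0.5 here for a 2x downsample), which widens the effective support.
#include <algorithm>
#include <cmath>
#include <cstdio>

static float triangleCoeff(float x) {
    return std::max(0.0f, 1.0f - std::abs(x));
}

int main() {
    float ax = 0.5f;  // illustrative antialias factor
    for (float dx : {-1.5f, -0.5f, 0.0f, 0.5f, 1.5f})
        std::printf("w(%+.1f)=%.2f ", dx, triangleCoeff(ax * dx));
    std::printf("\n");  // w(-1.5)=0.25 w(-0.5)=0.75 w(+0.0)=1.00 w(+0.5)=0.75 w(+1.5)=0.25
    return 0;
}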
OW; ox++) { float ix = coordTransToInput(ox, fx, IW, OW); int ix_r = static_cast(std::floor(ix)); @@ -3207,9 +3420,9 @@ void Interpolate::InterpolateExecutorBase::buildTblCubic(const VectorDims& srcDi } tblAdvance += CUBIC_GRID_LEN * OW; - int *yOrigin = static_cast(&auxTable[tblAdvance]); + int* yOrigin = static_cast(&auxTable[tblAdvance]); tblAdvance += OH; - float *yFactor = reinterpret_cast(&auxTable[tblAdvance]); + float* yFactor = reinterpret_cast(&auxTable[tblAdvance]); for (int oy = 0; oy < OH; oy++) { float iy = coordTransToInput(oy, fy, IH, OH); int iy_r = static_cast(std::floor(iy)); @@ -3224,9 +3437,9 @@ void Interpolate::InterpolateExecutorBase::buildTblCubic(const VectorDims& srcDi if (layout == InterpolateLayoutType::planar) { tblAdvance += CUBIC_GRID_LEN * OH; - int *sequenceOH = static_cast(&auxTable[tblAdvance]); + int* sequenceOH = static_cast(&auxTable[tblAdvance]); tblAdvance += OH * OW; - int *sequenceOW = static_cast(&auxTable[tblAdvance]); + int* sequenceOW = static_cast(&auxTable[tblAdvance]); for (int h = 0; h < OH; ++h) { int offset = h * OW; for (int w = 0; w < OW; ++w) { @@ -3256,8 +3469,11 @@ float Interpolate::InterpolateExecutorBase::getPillowBicubicCoeffs(float m) { return 0.0f; } -void Interpolate::InterpolateExecutorBase::buildTblPillow(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - float cubicCoeff, InterpolateLayoutType layout) { +void Interpolate::InterpolateExecutorBase::buildTblPillow(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + float cubicCoeff, + InterpolateLayoutType layout) { int dimSize = dataRank; float fy = dataScales[dimSize - 2]; float fx = dataScales[dimSize - 1]; @@ -3272,15 +3488,15 @@ void Interpolate::InterpolateExecutorBase::buildTblPillow(const VectorDims& srcD }; // pillowScale: e.g. 2.0 means down sample 2 times - auto generateArgs = [&] (float pillowScale) -> filterArgs { + auto generateArgs = [&](float pillowScale) -> filterArgs { filterArgs args; float scaleClip = pillowScale < 1.0f ? 1.0f : pillowScale; args.ScaleClipReciprocal = 1.0f / scaleClip; - args.filterRadius = (mode == InterpolateMode::bilinear_pillow) ? PILLOW_BILINEAR_WINDOW_SCALE * scaleClip : - PILLOW_BICUBIC_WINDOW_SCALE * scaleClip; + args.filterRadius = (mode == InterpolateMode::bilinear_pillow) ? PILLOW_BILINEAR_WINDOW_SCALE * scaleClip + : PILLOW_BICUBIC_WINDOW_SCALE * scaleClip; args.filterLen = static_cast(std::ceil(args.filterRadius) * 2 + 1); - args.weightGen = (mode == InterpolateMode::bilinear_pillow) ? this->getPillowBilinearCoeffs: - this->getPillowBicubicCoeffs; + args.weightGen = + (mode == InterpolateMode::bilinear_pillow) ? 
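// A numeric sketch of the pillow filter sizing in generateArgs above, under
// the assumption that the bilinear window scale constant is 1.0 (bicubic 2.0),
// matching Pillow's conventions: for downsampling the support radius grows
// with the scale, and filterLen = ceil(radius) * 2 + 1.
#include <cmath>
#include <cstdio>

int main() {
    float pillowScale = 2.5f;  // illustrative 2.5x downsample
    float scaleClip = pillowScale < 1.0f ? 1.0f : pillowScale;
    float windowScale = 1.0f;  // assumed bilinear window constant
    float filterRadius = windowScale * scaleClip;
    int filterLen = static_cast<int>(std::ceil(filterRadius)) * 2 + 1;
    std::printf("radius=%.2f len=%d 1/scale=%.2f\n", filterRadius, filterLen, 1.0f / scaleClip);
    return 0;
}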
this->getPillowBilinearCoeffs : this->getPillowBicubicCoeffs; return args; }; @@ -3295,15 +3511,15 @@ void Interpolate::InterpolateExecutorBase::buildTblPillow(const VectorDims& srcD auxTable[offset] = filterArgsX.filterLen; auxTable[offset + 1] = filterArgsY.filterLen; offset += 2; - float *weightX = reinterpret_cast(&auxTable[offset]); + float* weightX = reinterpret_cast(&auxTable[offset]); offset += filterArgsX.filterLen * OW; - float *weightY = reinterpret_cast(&auxTable[offset]); + float* weightY = reinterpret_cast(&auxTable[offset]); offset += filterArgsY.filterLen * OH; - int *indexX = static_cast(&auxTable[offset]); + int* indexX = static_cast(&auxTable[offset]); offset += 2 * OW; - int *indexY = static_cast(&auxTable[offset]); + int* indexY = static_cast(&auxTable[offset]); - auto generateTbl = [&] (int inLen, int outLen, float fScale, filterArgs args, float* weightTbl, int* idxTbl) { + auto generateTbl = [&](int inLen, int outLen, float fScale, filterArgs args, float* weightTbl, int* idxTbl) { int min = 0; int max = 0; for (int ox = 0; ox < outLen; ox++) { @@ -3347,21 +3563,29 @@ void Interpolate::InterpolateExecutorBase::buildTblPillow(const VectorDims& srcD generateTbl(IH, OH, fy, filterArgsY, weightY, indexY); } -void Interpolate::InterpolateRefExecutor::NNRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, - int OD, int OH, int OW) { - int *index_d = static_cast(&auxTable[0]); - int *index_h = static_cast(&auxTable[OD]); - int *index_w = static_cast(&auxTable[OD + OH]); - - const float *in_ptr_f32 = reinterpret_cast(in_ptr_); - float *out_ptr_f32 = reinterpret_cast(out_ptr_); +void Interpolate::InterpolateRefExecutor::NNRef(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW) { + int* index_d = static_cast(&auxTable[0]); + int* index_h = static_cast(&auxTable[OD]); + int* index_w = static_cast(&auxTable[OD + OH]); + + const float* in_ptr_f32 = reinterpret_cast(in_ptr_); + float* out_ptr_f32 = reinterpret_cast(out_ptr_); parallel_for3d(B, C, OD, [&](size_t b, size_t c, size_t od) { - const float *in_ptr = in_ptr_f32 + (IW * IH * ID * C * b + IW * IH * ID * c + IW * IH * index_d[od]); - float *out_ptr = out_ptr_f32 + (OW * OH * OD * C * b + OW * OH * OD * c + OW * OH * od); + const float* in_ptr = in_ptr_f32 + (IW * IH * ID * C * b + IW * IH * ID * c + IW * IH * index_d[od]); + float* out_ptr = out_ptr_f32 + (OW * OH * OD * C * b + OW * OH * OD * c + OW * OH * od); for (int oh = 0; oh < OH; oh++) { - const float *in_ptr_h = in_ptr + (IW * index_h[oh]); - float *out_ptr_h = out_ptr + (OW * oh); + const float* in_ptr_h = in_ptr + (IW * index_h[oh]); + float* out_ptr_h = out_ptr + (OW * oh); for (int ow = 0; ow < OW; ow++) { out_ptr_h[ow] = in_ptr_h[index_w[ow]]; } @@ -3369,8 +3593,16 @@ void Interpolate::InterpolateRefExecutor::NNRef(const uint8_t *in_ptr_, uint8_t }); } -void Interpolate::InterpolateRefExecutor::linearOnnxRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, - int OD, int OH, int OW) { +void Interpolate::InterpolateRefExecutor::linearOnnxRef(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW) { std::vector indexPtr(MAX_INPUT_INTERPOLATE, 0); std::vector weightPtr(MAX_INPUT_INTERPOLATE, 0); // FrontTopLeft:0, FrontTopRight:1, FrontBottomLeft:2, FrontBottomRight:3, @@ -3399,87 +3631,87 @@ void Interpolate::InterpolateRefExecutor::linearOnnxRef(const uint8_t 
*in_ptr_, weightPtr[5] = reinterpret_cast(&auxTable[scratchLen + 5 * OW * OH * OD]); } - const float *in_ptr_f32 = reinterpret_cast(in_ptr_); - float *out_ptr_f32 = reinterpret_cast(out_ptr_); + const float* in_ptr_f32 = reinterpret_cast(in_ptr_); + float* out_ptr_f32 = reinterpret_cast(out_ptr_); parallel_for2d(B, C, [&](size_t b, size_t c) { - float *out_ptr_nc = out_ptr_f32 + (OD * OH * OW * C * b + OD * OH * OW * c); - const float *in_ptr_nc = in_ptr_f32 + (ID * IH * IW * C * b + ID * IH * IW * c); + float* out_ptr_nc = out_ptr_f32 + (OD * OH * OW * C * b + OD * OH * OW * c); + const float* in_ptr_nc = in_ptr_f32 + (ID * IH * IW * C * b + ID * IH * IW * c); // do not combined 1d/2d to 3d unified process to get rid of invalid computing. switch (spatialDimSize) { - case 1: - for (int i = 0; i < OW; i++) { - float src0 = in_ptr_nc[indexPtr[0][i]]; - float src1 = in_ptr_nc[indexPtr[1][i]]; + case 1: + for (int i = 0; i < OW; i++) { + float src0 = in_ptr_nc[indexPtr[0][i]]; + float src1 = in_ptr_nc[indexPtr[1][i]]; - out_ptr_nc[i] = src0 * weightPtr[0][i] + - src1 * weightPtr[1][i]; - } - break; - case 2: - for (int i = 0; i < OH * OW; i++) { - float src00 = in_ptr_nc[indexPtr[0][i]]; - float src01 = in_ptr_nc[indexPtr[1][i]]; - float src10 = in_ptr_nc[indexPtr[2][i]]; - float src11 = in_ptr_nc[indexPtr[3][i]]; - - out_ptr_nc[i] = src00 * weightPtr[2][i] * weightPtr[0][i] + - src01 * weightPtr[2][i] * weightPtr[1][i] + - src10 * weightPtr[3][i] * weightPtr[0][i] + - src11 * weightPtr[3][i] * weightPtr[1][i]; - } - break; - case 3: - for (int i = 0; i < OD * OH * OW; i++) { - float src000 = in_ptr_nc[indexPtr[0][i]]; - float src001 = in_ptr_nc[indexPtr[1][i]]; - float src010 = in_ptr_nc[indexPtr[2][i]]; - float src011 = in_ptr_nc[indexPtr[3][i]]; - float src100 = in_ptr_nc[indexPtr[4][i]]; - float src101 = in_ptr_nc[indexPtr[5][i]]; - float src110 = in_ptr_nc[indexPtr[6][i]]; - float src111 = in_ptr_nc[indexPtr[7][i]]; - - // float dstValue = - // weightPtr[4][i] * weightPtr[2][i] * weightPtr[0][i] * src000 + - // weightPtr[4][i] * weightPtr[2][i] * weightPtr[1][i] * src001 + - // weightPtr[4][i] * weightPtr[3][i] * weightPtr[0][i] * src010 + - // weightPtr[4][i] * weightPtr[3][i] * weightPtr[1][i] * src011 + - // weightPtr[5][i] * weightPtr[2][i] * weightPtr[0][i] * src100 + - // weightPtr[5][i] * weightPtr[2][i] * weightPtr[1][i] * src101 + - // weightPtr[5][i] * weightPtr[3][i] * weightPtr[0][i] * src110 + - // weightPtr[5][i] * weightPtr[3][i] * weightPtr[1][i] * src111; - - out_ptr_nc[i] = - weightPtr[4][i] * (weightPtr[2][i] * (weightPtr[0][i] * src000 + - weightPtr[1][i] * src001) + - weightPtr[3][i] * (weightPtr[0][i] * src010 + - weightPtr[1][i] * src011)) + - weightPtr[5][i] * (weightPtr[2][i] * (weightPtr[0][i] * src100 + - weightPtr[1][i] * src101) + - weightPtr[3][i] * (weightPtr[0][i] * src110 + - weightPtr[1][i] * src111)); - } - break; - default: - break; + out_ptr_nc[i] = src0 * weightPtr[0][i] + src1 * weightPtr[1][i]; + } + break; + case 2: + for (int i = 0; i < OH * OW; i++) { + float src00 = in_ptr_nc[indexPtr[0][i]]; + float src01 = in_ptr_nc[indexPtr[1][i]]; + float src10 = in_ptr_nc[indexPtr[2][i]]; + float src11 = in_ptr_nc[indexPtr[3][i]]; + + out_ptr_nc[i] = src00 * weightPtr[2][i] * weightPtr[0][i] + src01 * weightPtr[2][i] * weightPtr[1][i] + + src10 * weightPtr[3][i] * weightPtr[0][i] + src11 * weightPtr[3][i] * weightPtr[1][i]; + } + break; + case 3: + for (int i = 0; i < OD * OH * OW; i++) { + float src000 = in_ptr_nc[indexPtr[0][i]]; + float src001 = 
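// A standalone check that the factored trilinear blend used in the 3D case
// here matches the expanded 8-term sum kept in the comment; samples and
// weights are arbitrary illustrative values.
#include <cstdio>

int main() {
    float s[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // FrontTopLeft .. EndBottomRight
    float wL = 0.25f, wR = 0.75f, wT = 0.4f, wB = 0.6f, wF = 0.9f, wE = 0.1f;
    float factored = wF * (wT * (wL * s[0] + wR * s[1]) + wB * (wL * s[2] + wR * s[3])) +
                     wE * (wT * (wL * s[4] + wR * s[5]) + wB * (wL * s[6] + wR * s[7]));
    float expanded = 0.0f;
    float wz[2] = {wF, wE}, wy[2] = {wT, wB}, wx[2] = {wL, wR};
    for (int z = 0; z < 2; z++)
        for (int y = 0; y < 2; y++)
            for (int x = 0; x < 2; x++)
                expanded += wz[z] * wy[y] * wx[x] * s[z * 4 + y * 2 + x];
    std::printf("factored=%.4f expanded=%.4f\n", factored, expanded);  // equal
    return 0;
}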
in_ptr_nc[indexPtr[1][i]]; + float src010 = in_ptr_nc[indexPtr[2][i]]; + float src011 = in_ptr_nc[indexPtr[3][i]]; + float src100 = in_ptr_nc[indexPtr[4][i]]; + float src101 = in_ptr_nc[indexPtr[5][i]]; + float src110 = in_ptr_nc[indexPtr[6][i]]; + float src111 = in_ptr_nc[indexPtr[7][i]]; + + // float dstValue = + // weightPtr[4][i] * weightPtr[2][i] * weightPtr[0][i] * src000 + + // weightPtr[4][i] * weightPtr[2][i] * weightPtr[1][i] * src001 + + // weightPtr[4][i] * weightPtr[3][i] * weightPtr[0][i] * src010 + + // weightPtr[4][i] * weightPtr[3][i] * weightPtr[1][i] * src011 + + // weightPtr[5][i] * weightPtr[2][i] * weightPtr[0][i] * src100 + + // weightPtr[5][i] * weightPtr[2][i] * weightPtr[1][i] * src101 + + // weightPtr[5][i] * weightPtr[3][i] * weightPtr[0][i] * src110 + + // weightPtr[5][i] * weightPtr[3][i] * weightPtr[1][i] * src111; + + out_ptr_nc[i] = + weightPtr[4][i] * (weightPtr[2][i] * (weightPtr[0][i] * src000 + weightPtr[1][i] * src001) + + weightPtr[3][i] * (weightPtr[0][i] * src010 + weightPtr[1][i] * src011)) + + weightPtr[5][i] * (weightPtr[2][i] * (weightPtr[0][i] * src100 + weightPtr[1][i] * src101) + + weightPtr[3][i] * (weightPtr[0][i] * src110 + weightPtr[1][i] * src111)); + } + break; + default: + break; } }); } -void Interpolate::InterpolateRefExecutor::cubicRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) { +void Interpolate::InterpolateRefExecutor::cubicRef(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int IH, + int IW, + int OH, + int OW) { const int idxNum = 1; - int *xOrigin = static_cast(&auxTable[0]); - float *xFactor = reinterpret_cast(&auxTable[OW]); - int *yOrigin = static_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW]); - float *yFactor = reinterpret_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]); + int* xOrigin = static_cast(&auxTable[0]); + float* xFactor = reinterpret_cast(&auxTable[OW]); + int* yOrigin = static_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW]); + float* yFactor = reinterpret_cast(&auxTable[(CUBIC_GRID_LEN + idxNum) * OW + OH]); - const float *in_ptr_f32 = reinterpret_cast(in_ptr_); - float *out_ptr_f32 = reinterpret_cast(out_ptr_); + const float* in_ptr_f32 = reinterpret_cast(in_ptr_); + float* out_ptr_f32 = reinterpret_cast(out_ptr_); parallel_for4d(B, C, OH, OW, [&](size_t n, size_t c, size_t oy, size_t ox) { - const float *in_ptr_nc = in_ptr_f32 + (IW * IH * C * n + IW * IH * c); - float *out_ptr_nc = out_ptr_f32 + (OW * OH * C * n + OW * OH * c); + const float* in_ptr_nc = in_ptr_f32 + (IW * IH * C * n + IW * IH * c); + float* out_ptr_nc = out_ptr_f32 + (OW * OH * C * n + OW * OH * c); int iy = yOrigin[oy]; int ix = xOrigin[ox]; @@ -3487,7 +3719,7 @@ void Interpolate::InterpolateRefExecutor::cubicRef(const uint8_t *in_ptr_, uint8 float retY = 0.f; for (int y = iy - 1, i = 0; y <= iy + 2; y++, i++) { int yInRange = std::max(0, std::min(y, IH - 1)); - const float *in_ptr_nch = in_ptr_nc + IW * yInRange; + const float* in_ptr_nch = in_ptr_nc + IW * yInRange; float retX = 0.f; for (int x = ix - 1, j = 0; x <= ix + 2; x++, j++) { int xInRange = std::max(0, std::min(x, IW - 1)); @@ -3499,66 +3731,79 @@ void Interpolate::InterpolateRefExecutor::cubicRef(const uint8_t *in_ptr_, uint8 }); } -float Interpolate::InterpolateRefExecutor::getValue(const uint8_t *base, size_t offset, ov::element::Type prec) { - const uint8_t *baseOffset = base + offset; +float Interpolate::InterpolateRefExecutor::getValue(const uint8_t* base, size_t offset, ov::element::Type prec) { + 
const uint8_t* baseOffset = base + offset; switch (prec) { - case ov::element::u8: { - return static_cast(*baseOffset); - break; - } - case ov::element::i8: { - const int8_t *valuePtr = reinterpret_cast(baseOffset); - return static_cast(*valuePtr); - break; - } - case ov::element::bf16: { - const uint16_t *valuePtr = reinterpret_cast(baseOffset); - return bfloat16_t::from_bits(*valuePtr); - break; - } - case ov::element::f32: { - const float *valuePtr = reinterpret_cast(baseOffset); - return *valuePtr; - break; - } - default: { - OPENVINO_THROW("Interpolate layer does not support precision: ", prec); - break; - } + case ov::element::u8: { + return static_cast(*baseOffset); + break; + } + case ov::element::i8: { + const int8_t* valuePtr = reinterpret_cast(baseOffset); + return static_cast(*valuePtr); + break; + } + case ov::element::bf16: { + const uint16_t* valuePtr = reinterpret_cast(baseOffset); + return bfloat16_t::from_bits(*valuePtr); + break; + } + case ov::element::f32: { + const float* valuePtr = reinterpret_cast(baseOffset); + return *valuePtr; + break; + } + default: { + OPENVINO_THROW("Interpolate layer does not support precision: ", prec); + break; + } } } -void Interpolate::InterpolateRefExecutor::setValue(uint8_t *base, size_t offset, float value, ov::element::Type prec) { - uint8_t *baseOffset = base + offset; +void Interpolate::InterpolateRefExecutor::setValue(uint8_t* base, size_t offset, float value, ov::element::Type prec) { + uint8_t* baseOffset = base + offset; switch (prec) { - case ov::element::u8: { - uint8_t data = static_cast(value < 0 ? 0 : value); - cpu_memcpy(baseOffset, &data, 1); - break; - } - case ov::element::i8: { - int8_t data = static_cast(value); - cpu_memcpy(baseOffset, &data, 1); - break; - } - case ov::element::bf16: { - uint16_t data = bfloat16_t(value).to_bits(); - cpu_memcpy(baseOffset, &data, 2); - break; - } - case ov::element::f32: { - cpu_memcpy(baseOffset, &value, sizeof(float)); - break; - } - default: { - OPENVINO_THROW("Interpolate layer does not support precision: ", prec); - break; - } + case ov::element::u8: { + uint8_t data = static_cast(value < 0 ? 
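// A minimal sketch of the byte-level scalar load/store pattern that
// getValue/setValue implement, reduced to the f32 -> u8 path with the same
// clamp-below-zero behavior; buffers and values are illustrative.
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstring>

static float getF32(const uint8_t* base, size_t offset) {
    float v;
    std::memcpy(&v, base + offset, sizeof(float));  // typed load from raw bytes
    return v;
}

static void setU8(uint8_t* base, size_t offset, float value) {
    uint8_t data = static_cast<uint8_t>(value < 0 ? 0 : value);  // negatives clamp to 0
    std::memcpy(base + offset, &data, 1);
}

int main() {
    float src[2] = {1.5f, -2.0f};
    uint8_t dst[2] = {0, 0};
    for (size_t i = 0; i < 2; i++)
        setU8(dst, i, getF32(reinterpret_cast<const uint8_t*>(src), i * sizeof(float)));
    std::printf("%d %d\n", dst[0], dst[1]);  // prints: 1 0
    return 0;
}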
0 : value); + cpu_memcpy(baseOffset, &data, 1); + break; + } + case ov::element::i8: { + int8_t data = static_cast(value); + cpu_memcpy(baseOffset, &data, 1); + break; + } + case ov::element::bf16: { + uint16_t data = bfloat16_t(value).to_bits(); + cpu_memcpy(baseOffset, &data, 2); + break; + } + case ov::element::f32: { + cpu_memcpy(baseOffset, &value, sizeof(float)); + break; + } + default: { + OPENVINO_THROW("Interpolate layer does not support precision: ", prec); + break; + } } } -void Interpolate::InterpolateRefExecutor::linearInterpolation(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias) { +void Interpolate::InterpolateRefExecutor::linearInterpolation(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int ID, + int IH, + int IW, + float fx, + float fy, + float fz, + int OD, + int OH, + int OW, + int kernel_width, + bool antialias) { if (IW == OW && IH == OH && ID == OD) { size_t spatialDimSize = IW * IH * ID; // TODO: enable when fusing into interp with linear mode will support @@ -3567,8 +3812,8 @@ void Interpolate::InterpolateRefExecutor::linearInterpolation(const uint8_t *in_ cpu_memcpy(out_ptr_, in_ptr_, size); } else { parallel_for2d(B, C, [&](size_t b, size_t c) { - const uint8_t *in_ptr_nc = in_ptr_ + (spatialDimSize * C * b + spatialDimSize * c) * srcDataSize; - uint8_t *out_ptr_nc = out_ptr_ + (spatialDimSize * C * b + spatialDimSize * c) * dstDataSize; + const uint8_t* in_ptr_nc = in_ptr_ + (spatialDimSize * C * b + spatialDimSize * c) * srcDataSize; + uint8_t* out_ptr_nc = out_ptr_ + (spatialDimSize * C * b + spatialDimSize * c) * dstDataSize; for (size_t i = 0; i < spatialDimSize; i++) { float dstValue = getValue(in_ptr_nc, i * srcDataSize, inputPrec); setValue(out_ptr_nc, i * dstDataSize, dstValue, outputPrec); @@ -3593,23 +3838,23 @@ void Interpolate::InterpolateRefExecutor::linearInterpolation(const uint8_t *in_ int sizeOH = OH * diaOH; int sizeOW = OW * diaOW; - float *weightTable = reinterpret_cast(&auxTable[0]); - float *weightOD = static_cast(&weightTable[0]); - float *weightOH = static_cast(&weightTable[sizeOD]); - float *weightOW = static_cast(&weightTable[sizeOD + sizeOH]); + float* weightTable = reinterpret_cast(&auxTable[0]); + float* weightOD = static_cast(&weightTable[0]); + float* weightOH = static_cast(&weightTable[sizeOD]); + float* weightOW = static_cast(&weightTable[sizeOD + sizeOH]); - int *idxTable = static_cast(&auxTable[sizeOD + sizeOH + sizeOW]); - int *idxOD = static_cast(&idxTable[0]); - int *idxOH = static_cast(&idxTable[sizeOD]); - int *idxOW = static_cast(&idxTable[sizeOD + sizeOH]); + int* idxTable = static_cast(&auxTable[sizeOD + sizeOH + sizeOW]); + int* idxOD = static_cast(&idxTable[0]); + int* idxOH = static_cast(&idxTable[sizeOD]); + int* idxOW = static_cast(&idxTable[sizeOD + sizeOH]); parallel_for2d(B, C, [&](size_t b, size_t c) { - const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * ID * C * b + IW * IH * ID * c) * srcDataSize; - uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * OD * C * b + OW * OH * OD * c) * dstDataSize; + const uint8_t* in_ptr_nc = in_ptr_ + (IW * IH * ID * C * b + IW * IH * ID * c) * srcDataSize; + uint8_t* out_ptr_nc = out_ptr_ + (OW * OH * OD * C * b + OW * OH * OD * c) * dstDataSize; for (int oz = 0; oz < OD; oz++) { - uint8_t *out_ptr_ncd = out_ptr_nc + (OW * OH * oz) * dstDataSize; + uint8_t* out_ptr_ncd = out_ptr_nc + (OW * OH * oz) * dstDataSize; for (int oy = 0; oy < OH; oy++) { - 
uint8_t *out_ptr_ncdh = out_ptr_ncd + (OW * oy) * dstDataSize; + uint8_t* out_ptr_ncdh = out_ptr_ncd + (OW * oy) * dstDataSize; for (int ox = 0; ox < OW; ox++) { float sum = 0.f; float wsum = 0.f; @@ -3652,9 +3897,13 @@ void Interpolate::InterpolateRefExecutor::linearInterpolation(const uint8_t *in_ if (weightOW[ox * diaOW + ix] == 0.f) { continue; } - float w = weightOD[oz * diaOD + iz] * weightOH[oy * diaOH + iy] * weightOW[ox * diaOW + ix]; + float w = + weightOD[oz * diaOD + iz] * weightOH[oy * diaOH + iy] * weightOW[ox * diaOW + ix]; float value = getValue(in_ptr_nc, - (idxOD[oz * diaOD + iz] * IH * IW + idxOH[oy * diaOH + iy] * IW + idxOW[ox * diaOW + ix]) * srcDataSize, inputPrec); + (idxOD[oz * diaOD + iz] * IH * IW + idxOH[oy * diaOH + iy] * IW + + idxOW[ox * diaOW + ix]) * + srcDataSize, + inputPrec); sum += w * value; wsum += w; @@ -3674,18 +3923,25 @@ void Interpolate::InterpolateRefExecutor::linearInterpolation(const uint8_t *in_ }); } -void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW) { +void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int IH, + int IW, + int OH, + int OW) { size_t offset = 0; int filterLenX = auxTable[offset]; int filterLenY = auxTable[offset + 1]; offset += 2; - float *weightX = reinterpret_cast(&auxTable[offset]); + float* weightX = reinterpret_cast(&auxTable[offset]); offset += filterLenX * OW; - float *weightY = reinterpret_cast(&auxTable[offset]); + float* weightY = reinterpret_cast(&auxTable[offset]); offset += filterLenY * OH; - int *indexX = static_cast(&auxTable[offset]); + int* indexX = static_cast(&auxTable[offset]); offset += 2 * OW; - int *indexY = static_cast(&auxTable[offset]); + int* indexY = static_cast(&auxTable[offset]); // workBuffer needed when both pass is true bool xPass = IW != OW; @@ -3703,21 +3959,24 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint // | | // ---- auto bc_loop = [&](size_t b, size_t c) { - const uint8_t *in_ptr_nc = in_ptr_ + (IW * IH * C * b + IW * IH * c) * srcDataSize; - uint8_t *out_ptr_nc = out_ptr_ + (OW * OH * C * b + OW * OH * c) * dstDataSize; - uint8_t *xpass_out_ptr_nc = nullptr; - const uint8_t *ypass_in_ptr_nc = nullptr; + const uint8_t* in_ptr_nc = in_ptr_ + (IW * IH * C * b + IW * IH * c) * srcDataSize; + uint8_t* out_ptr_nc = out_ptr_ + (OW * OH * C * b + OW * OH * c) * dstDataSize; + uint8_t* xpass_out_ptr_nc = nullptr; + const uint8_t* ypass_in_ptr_nc = nullptr; if (xPass && yPass) { size_t parallel_num = B * C; // IH * OW buf needed if (parallel_num < m_threads_num) { - xpass_out_ptr_nc = static_cast(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]); - ypass_in_ptr_nc = static_cast(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]); + xpass_out_ptr_nc = + static_cast(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]); + ypass_in_ptr_nc = + static_cast(&pillow_working_buf[(OW * IH * C * b + OW * IH * c) * srcDataSize]); } else { size_t threadsIdx = parallel_get_thread_num(); size_t buffer_size = static_cast(OW * IH); xpass_out_ptr_nc = static_cast(&pillow_working_buf[threadsIdx * buffer_size * srcDataSize]); - ypass_in_ptr_nc = static_cast(&pillow_working_buf[threadsIdx * buffer_size * srcDataSize]); + ypass_in_ptr_nc = + static_cast(&pillow_working_buf[threadsIdx * buffer_size * srcDataSize]); } } else if (xPass && !yPass) { xpass_out_ptr_nc = out_ptr_nc; 
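// A numeric sketch of the two-pass buffer sizing implemented in
// create_pillow_working_buf below for the planar branch: one IH * OW strip per
// parallel (b, c) worker, capped by the thread count. Numbers are made up.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    size_t B = 2, C = 3, IH = 100, OW = 60, srcDataSize = 4;  // illustrative f32 case
    size_t threads = 8;
    size_t parallel_num = B * C;  // 6 (b, c) workers, fewer than threads
    size_t bufBytes = IH * OW * srcDataSize * std::min(threads, parallel_num);
    std::printf("working buffer: %zu bytes\n", bufBytes);  // prints: 144000
    return 0;
}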
@@ -3775,14 +4034,14 @@ void Interpolate::InterpolateRefExecutor::pillowRef(const uint8_t *in_ptr_, uint void Interpolate::InterpolateExecutorBase::create_pillow_working_buf(InterpolateLayoutType layout) { if (srcDimPad5d[3] == dstDim5d[3] || srcDimPad5d[4] == dstDim5d[4]) return; - size_t bufSize = srcDimPad5d[3] * dstDim5d[4] * srcDataSize; // IH * OW + size_t bufSize = srcDimPad5d[3] * dstDim5d[4] * srcDataSize; // IH * OW m_threads_num = parallel_get_max_threads(); if (layout == InterpolateLayoutType::planar) { // B and C execute in parallel, need separate buf size_t parallel_num = srcDimPad5d[0] * srcDimPad5d[1]; bufSize *= std::min(m_threads_num, parallel_num); } else { - bufSize *= srcDimPad5d[1]; // *C + bufSize *= srcDimPad5d[1]; // *C // B execute in parallel, need separate buf size_t parallel_num = srcDimPad5d[0]; bufSize *= std::min(m_threads_num, parallel_num); @@ -3791,11 +4050,14 @@ void Interpolate::InterpolateExecutorBase::create_pillow_working_buf(Interpolate } Interpolate::InterpolateExecutorBase::InterpolateExecutorBase(const InterpolateAttrs& interpAttrs, - const VectorDims &srcDims, - const VectorDims &dstDims, - const std::vector &dataScales) : - mode(interpAttrs.mode), coordTransMode(interpAttrs.coordTransMode), configured_for_layout(interpAttrs.layout), - inputPrec(interpAttrs.inPrc), outputPrec(interpAttrs.outPrc) { + const VectorDims& srcDims, + const VectorDims& dstDims, + const std::vector& dataScales) + : mode(interpAttrs.mode), + coordTransMode(interpAttrs.coordTransMode), + configured_for_layout(interpAttrs.layout), + inputPrec(interpAttrs.inPrc), + outputPrec(interpAttrs.outPrc) { srcDimPad5d = to5Dim(getPaddedInputShape(srcDims, interpAttrs.padBegin, interpAttrs.padEnd)); dstDim5d = to5Dim(dstDims); srcDataSize = interpAttrs.inPrc.size(); @@ -3804,44 +4066,44 @@ Interpolate::InterpolateExecutorBase::InterpolateExecutorBase(const InterpolateA spatialDimSize = getSpatialDimsNum(dataRank); switch (mode) { - case InterpolateMode::nearest: { - buildTblNN(srcDimPad5d, dstDim5d, dataScales, interpAttrs.layout, interpAttrs.nearestMode); - break; - } - case InterpolateMode::linear_onnx: { - buildTblLinearOnnx(srcDimPad5d, dstDim5d, dataScales, interpAttrs.layout); - break; - } - case InterpolateMode::linear: { - static constexpr int LINEAR_KERNEL = 2; - buildTblLinear(srcDimPad5d, dstDim5d, dataScales, LINEAR_KERNEL, interpAttrs.antialias); - break; - } - case InterpolateMode::cubic: { - buildTblCubic(srcDimPad5d, dstDim5d, dataScales, interpAttrs.cubeCoeff, interpAttrs.layout); - break; - } - case InterpolateMode::bilinear_pillow: - case InterpolateMode::bicubic_pillow: { - buildTblPillow(srcDimPad5d, dstDim5d, dataScales, interpAttrs.cubeCoeff, interpAttrs.layout); - if ((srcDimPad5d[4] != dstDim5d[4]) && (srcDimPad5d[3] != dstDim5d[3])) { - create_pillow_working_buf(interpAttrs.layout); - } - break; - } - default: { - OPENVINO_THROW("Interpolate executor does not support interpolate mode: ", mode); - break; + case InterpolateMode::nearest: { + buildTblNN(srcDimPad5d, dstDim5d, dataScales, interpAttrs.layout, interpAttrs.nearestMode); + break; + } + case InterpolateMode::linear_onnx: { + buildTblLinearOnnx(srcDimPad5d, dstDim5d, dataScales, interpAttrs.layout); + break; + } + case InterpolateMode::linear: { + static constexpr int LINEAR_KERNEL = 2; + buildTblLinear(srcDimPad5d, dstDim5d, dataScales, LINEAR_KERNEL, interpAttrs.antialias); + break; + } + case InterpolateMode::cubic: { + buildTblCubic(srcDimPad5d, dstDim5d, dataScales, interpAttrs.cubeCoeff, 
interpAttrs.layout); + break; + } + case InterpolateMode::bilinear_pillow: + case InterpolateMode::bicubic_pillow: { + buildTblPillow(srcDimPad5d, dstDim5d, dataScales, interpAttrs.cubeCoeff, interpAttrs.layout); + if ((srcDimPad5d[4] != dstDim5d[4]) && (srcDimPad5d[3] != dstDim5d[3])) { + create_pillow_working_buf(interpAttrs.layout); } + break; + } + default: { + OPENVINO_THROW("Interpolate executor does not support interpolate mode: ", mode); + break; + } } } Interpolate::InterpolateJitExecutor::InterpolateJitExecutor(const InterpolateAttrs& interpAttrs, - const VectorDims &srcDims, - const VectorDims &dstDims, - const std::vector &dataScales, - const dnnl::primitive_attr &attr) : - InterpolateExecutorBase(interpAttrs, srcDims, dstDims, dataScales) { + const VectorDims& srcDims, + const VectorDims& dstDims, + const std::vector& dataScales, + const dnnl::primitive_attr& attr) + : InterpolateExecutorBase(interpAttrs, srcDims, dstDims, dataScales) { auto jcp = jit_interpolate_config_params(); jcp.mode = mode; jcp.src_prc = interpAttrs.inPrc; @@ -3878,7 +4140,7 @@ Interpolate::InterpolateJitExecutor::InterpolateJitExecutor(const InterpolateAtt } else { OPENVINO_THROW("Can't create InterpolateJitExecutor"); } -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 if (interpolateKernel) { interpolateKernel->create_ker(); } else { @@ -3886,7 +4148,7 @@ Interpolate::InterpolateJitExecutor::InterpolateJitExecutor(const InterpolateAtt } } -void Interpolate::InterpolateJitExecutor::exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) { +void Interpolate::InterpolateJitExecutor::exec(const uint8_t* in_ptr_, uint8_t* out_ptr_, const void* post_ops_data_) { size_t N = srcDimPad5d[0], C = srcDimPad5d[1], ID = srcDimPad5d[2], IH = srcDimPad5d[3], IW = srcDimPad5d[4]; size_t OD = dstDim5d[2], OH = dstDim5d[3], OW = dstDim5d[4]; @@ -3894,103 +4156,115 @@ void Interpolate::InterpolateJitExecutor::exec(const uint8_t *in_ptr_, uint8_t * OPENVINO_THROW("Can't execute, kernel for Interpolate node is not compiled"); } switch (mode) { - case InterpolateMode::nearest: { - if (configured_for_layout == InterpolateLayoutType::planar) { - NNPlanar(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); - } else { - NNCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); - } - break; - } - case InterpolateMode::linear_onnx: { - if (configured_for_layout == InterpolateLayoutType::planar) { - linearOnnxPlanar(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); - } else { - linearOnnxCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); - } - break; + case InterpolateMode::nearest: { + if (configured_for_layout == InterpolateLayoutType::planar) { + NNPlanar(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); + } else { + NNCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); } - case InterpolateMode::cubic: { - if (configured_for_layout == InterpolateLayoutType::planar) { - cubicPlanar(in_ptr_, out_ptr_, post_ops_data_, N, C, IH, IW, OH, OW); - } else { - cubicCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, IH, IW, OH, OW); - } - break; + break; + } + case InterpolateMode::linear_onnx: { + if (configured_for_layout == InterpolateLayoutType::planar) { + linearOnnxPlanar(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); + } else { + linearOnnxCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, ID, IH, IW, OD, OH, OW); } - case InterpolateMode::bilinear_pillow: - case 
InterpolateMode::bicubic_pillow: { - if (configured_for_layout == InterpolateLayoutType::by_channel) { - pillowCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, IH, IW, OH, OW); - } else { - OPENVINO_THROW("Only channel_first jit kernel is supported for pillow mode", mode); - } - break; + break; + } + case InterpolateMode::cubic: { + if (configured_for_layout == InterpolateLayoutType::planar) { + cubicPlanar(in_ptr_, out_ptr_, post_ops_data_, N, C, IH, IW, OH, OW); + } else { + cubicCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, IH, IW, OH, OW); } - default: { - OPENVINO_THROW("InterpolateJitExecutor has unsupported interpolate mode: ", mode); + break; + } + case InterpolateMode::bilinear_pillow: + case InterpolateMode::bicubic_pillow: { + if (configured_for_layout == InterpolateLayoutType::by_channel) { + pillowCGathered(in_ptr_, out_ptr_, post_ops_data_, N, C, IH, IW, OH, OW); + } else { + OPENVINO_THROW("Only channel_first jit kernel is supported for pillow mode", mode); } + break; + } + default: { + OPENVINO_THROW("InterpolateJitExecutor has unsupported interpolate mode: ", mode); + } } } -void Interpolate::InterpolateRefExecutor::exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) { +void Interpolate::InterpolateRefExecutor::exec(const uint8_t* in_ptr_, uint8_t* out_ptr_, const void* post_ops_data_) { size_t N = srcDimPad5d[0], C = srcDimPad5d[1], ID = srcDimPad5d[2], IH = srcDimPad5d[3], IW = srcDimPad5d[4]; size_t OD = dstDim5d[2], OH = dstDim5d[3], OW = dstDim5d[4]; switch (mode) { - case InterpolateMode::nearest: { - NNRef(in_ptr_, out_ptr_, N, C, ID, IH, IW, OD, OH, OW); - break; - } - case InterpolateMode::linear_onnx: { - linearOnnxRef(in_ptr_, out_ptr_, N, C, ID, IH, IW, OD, OH, OW); - break; - } - case InterpolateMode::cubic: { - cubicRef(in_ptr_, out_ptr_, N, C, IH, IW, OH, OW); - break; - } - case InterpolateMode::linear: { - float fz = (dataRank == 5) ? dataScales[dataRank - 3] : 1.f; - float fy = dataScales[dataRank - 2]; - float fx = dataScales[dataRank - 1]; - - bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f); - int kernel_width = 2; - linearInterpolation(in_ptr_, out_ptr_, N, C, ID, IH, IW, fx, fy, fz, OD, OH, OW, kernel_width, isDownsample && antialias); - break; - } - case InterpolateMode::bilinear_pillow: - case InterpolateMode::bicubic_pillow: { - pillowRef(in_ptr_, out_ptr_, N, C, IH, IW, OH, OW); - break; - } - default: { - OPENVINO_THROW("Interpolate layer has unsupported interpolate mode: ", mode); - } + case InterpolateMode::nearest: { + NNRef(in_ptr_, out_ptr_, N, C, ID, IH, IW, OD, OH, OW); + break; + } + case InterpolateMode::linear_onnx: { + linearOnnxRef(in_ptr_, out_ptr_, N, C, ID, IH, IW, OD, OH, OW); + break; + } + case InterpolateMode::cubic: { + cubicRef(in_ptr_, out_ptr_, N, C, IH, IW, OH, OW); + break; + } + case InterpolateMode::linear: { + float fz = (dataRank == 5) ? 
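// A small sketch of how the linear reference path below pulls per-axis factors
// from the tail of dataScales (rank-dependent) and detects downsampling; the
// scale values are illustrative.
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> dataScales{1.f, 1.f, 0.5f, 2.f};  // rank-4: N, C, H, W
    int dataRank = static_cast<int>(dataScales.size());
    float fz = (dataRank == 5) ? dataScales[dataRank - 3] : 1.f;  // depth only for 5D
    float fy = dataScales[dataRank - 2];
    float fx = dataScales[dataRank - 1];
    bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f);
    std::printf("fz=%.1f fy=%.1f fx=%.1f downsample=%d\n", fz, fy, fx, isDownsample);
    return 0;
}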
dataScales[dataRank - 3] : 1.f; + float fy = dataScales[dataRank - 2]; + float fx = dataScales[dataRank - 1]; + + bool isDownsample = (fx < 1.f) || (fy < 1.f) || (fz < 1.f); + int kernel_width = 2; + linearInterpolation(in_ptr_, + out_ptr_, + N, + C, + ID, + IH, + IW, + fx, + fy, + fz, + OD, + OH, + OW, + kernel_width, + isDownsample && antialias); + break; + } + case InterpolateMode::bilinear_pillow: + case InterpolateMode::bicubic_pillow: { + pillowRef(in_ptr_, out_ptr_, N, C, IH, IW, OH, OW); + break; + } + default: { + OPENVINO_THROW("Interpolate layer has unsupported interpolate mode: ", mode); + } } } size_t Interpolate::getSpatialDimsNum(const Dim rank) { switch (rank) { - case 1: - case 3: - return 1; - case 2: - case 4: - return 2; - case 5: - return 3; - default: - OPENVINO_THROW("Can't define number spatial"); + case 1: + case 3: + return 1; + case 2: + case 4: + return 2; + case 5: + return 3; + default: + OPENVINO_THROW("Can't define number spatial"); } } bool Interpolate::canFuse(const NodePtr& node) const { - if (!mayiuse(cpu::x64::sse41) || - interpAttrs.mode == InterpolateMode::linear || - interpAttrs.mode == InterpolateMode::bilinear_pillow || - interpAttrs.mode == InterpolateMode::bicubic_pillow || + if (!mayiuse(cpu::x64::sse41) || interpAttrs.mode == InterpolateMode::linear || + interpAttrs.mode == InterpolateMode::bilinear_pillow || interpAttrs.mode == InterpolateMode::bicubic_pillow || (!one_of(dataRank, 4u, 5u) && !mayiuse(cpu::x64::avx2))) { return false; } @@ -4002,6 +4276,6 @@ bool Interpolate::created() const { return getType() == Type::Interpolate; } -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/interpolate.h b/src/plugins/intel_cpu/src/nodes/interpolate.h index a43b354aa0306a..c6fedf384f449d 100644 --- a/src/plugins/intel_cpu/src/nodes/interpolate.h +++ b/src/plugins/intel_cpu/src/nodes/interpolate.h @@ -31,34 +31,36 @@ struct jit_interpolate_config_params { }; struct jit_interpolate_call_args { - const void *src_ptr[MAX_INPUT_INTERPOLATE]; - const void *weight_ptr[MAX_INPUT_INTERPOLATE]; - const int *index; - void *dst; + const void* src_ptr[MAX_INPUT_INTERPOLATE]; + const void* weight_ptr[MAX_INPUT_INTERPOLATE]; + const int* index; + void* dst; size_t work_amount; size_t oc_off; - //ptr to array of post op inputs pointers (flat list) + // ptr to array of post op inputs pointers (flat list) const void* post_op_data; }; struct jit_uni_interpolate_kernel { - void (*ker_)(const jit_interpolate_call_args *); + void (*ker_)(const jit_interpolate_call_args*); - void operator()(const jit_interpolate_call_args *args) { + void operator()(const jit_interpolate_call_args* args) { assert(ker_); ker_(args); } - explicit jit_uni_interpolate_kernel(jit_interpolate_config_params jcp, const dnnl_primitive_attr &attr) : ker_(nullptr), jcp_(jcp), attr_(attr) {} + explicit jit_uni_interpolate_kernel(jit_interpolate_config_params jcp, const dnnl_primitive_attr& attr) + : ker_(nullptr), + jcp_(jcp), + attr_(attr) {} virtual ~jit_uni_interpolate_kernel() {} virtual void create_ker() = 0; jit_interpolate_config_params jcp_; - const dnnl_primitive_attr &attr_; + const dnnl_primitive_attr& attr_; }; - class Interpolate : public Node { public: static constexpr size_t DATA_ID = 0; @@ -98,8 +100,9 @@ class Interpolate : public Node { bool is_version11 = true; InterpolateAttrs interpAttrs; // Some FEs or preprocessing step resize spatial dimension for tensor 
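// Note: dataScales is indexed per input dimension, so the linear path above
// reads the trailing entries as spatial scale factors (fz exists only for
// 5-D input), and antialiasing kicks in only when at least one axis
// downsamples. Worked example for a 4-D resize 1x3x10x10 -> 1x3x5x20:
//
//   dataRank = 4, dataScales = {1.f, 1.f, 0.5f, 2.f}
//   fz = 1.f (rank < 5), fy = dataScales[2] = 0.5f, fx = dataScales[3] = 2.f
//   isDownsample = true (fy < 1.f), so the antialias kernel is used exactly
//   when the node was also configured with antialias = true.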
with NHWC layout memory, - // but imported as planar layout[abcd] with axis[1,2] for convenience. In this case, for pillow modes without pad for now, - // nhwc layout path and the kernel(nhwc layout executor) can be used for this planar layout and axis settings(NCHWAsNHWC is true) to get higher perf with + // but imported as planar layout[abcd] with axis[1,2] for convenience. In this case, for pillow modes without pad + // for now, nhwc layout path and the kernel(nhwc layout executor) can be used for this planar layout and axis + // settings(NCHWAsNHWC is true) to get higher perf with // 1. logical shape alignment [abcd-nhwc] to [adbc-nchw]. // 2. axis alignment [1,2] to [2,3]. // 3. config planar layout support and treated it as channel_first layout. @@ -107,120 +110,226 @@ class Interpolate : public Node { size_t dataRank = 0; class InterpolateExecutorBase { - public: - InterpolateExecutorBase(const InterpolateAttrs& interpAttrs, - const VectorDims &srcDims, - const VectorDims &dstDims, - const std::vector &dataScales); - - virtual void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) = 0; - virtual ~InterpolateExecutorBase() = default; - VectorDims getSrcDimPad5d() const { return srcDimPad5d; } - - private: - void buildTblNN(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - InterpolateLayoutType layout, InterpolateNearestMode nearestMode); - void buildTblLinearOnnx(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - InterpolateLayoutType layout); - void buildTblLinear(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, int kernel_width, - bool antialias); - void buildTblCubic(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, float cubicCoeff, - InterpolateLayoutType layout); - void buildTblPillow(const VectorDims& srcDimPad5d, const VectorDims& dstDim5d, const std::vector& dataScales, - float cubicCoeff, InterpolateLayoutType layout); - - float coordTransToInput(int outCoord, float scale, int inShape, int outShape) const; - int nearestRound(float origin, bool isDownsample, InterpolateNearestMode nearestMode) const; - void linearOnnxCF(int outCoord, float scale, int inShape, int outShape, int& index0, int& index1, float& weight0, float& weight1); - std::vector getCubicCoeffs(float mantissa, float a); - static float getPillowBilinearCoeffs(float m); - static float getPillowBicubicCoeffs(float m); - inline void create_pillow_working_buf(InterpolateLayoutType layout); - - protected: - InterpolateMode mode; - InterpolateCoordTransMode coordTransMode; - InterpolateLayoutType configured_for_layout; - VectorDims srcDimPad5d, dstDim5d; - ov::element::Type inputPrec, outputPrec; - size_t srcDataSize, dstDataSize; - int spatialDimSize; - size_t dataRank; - std::vector auxTable; - std::vector pillow_working_buf; - size_t m_threads_num = 0lu; + public: + InterpolateExecutorBase(const InterpolateAttrs& interpAttrs, + const VectorDims& srcDims, + const VectorDims& dstDims, + const std::vector& dataScales); + + virtual void exec(const uint8_t* in_ptr_, uint8_t* out_ptr_, const void* post_ops_data_) = 0; + virtual ~InterpolateExecutorBase() = default; + VectorDims getSrcDimPad5d() const { + return srcDimPad5d; + } + + private: + void buildTblNN(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout, + InterpolateNearestMode nearestMode); + void 
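// Note: a concrete instance of the NCHWAsNHWC remapping described in the
// comment above, assuming a planar 4-D input of logical shape [1, 32, 32, 3]
// interpolated along axes {1, 2}: the shape is re-read as [1, 3, 32, 32]
// ([abcd] -> [adbc]), the axes shift from {1, 2} to {2, 3}, and the planar
// tensor is then dispatched to the channel-first executor with no data
// movement at all.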
buildTblLinearOnnx(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + InterpolateLayoutType layout); + void buildTblLinear(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + int kernel_width, + bool antialias); + void buildTblCubic(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + float cubicCoeff, + InterpolateLayoutType layout); + void buildTblPillow(const VectorDims& srcDimPad5d, + const VectorDims& dstDim5d, + const std::vector& dataScales, + float cubicCoeff, + InterpolateLayoutType layout); + + float coordTransToInput(int outCoord, float scale, int inShape, int outShape) const; + int nearestRound(float origin, bool isDownsample, InterpolateNearestMode nearestMode) const; + void linearOnnxCF(int outCoord, + float scale, + int inShape, + int outShape, + int& index0, + int& index1, + float& weight0, + float& weight1); + std::vector getCubicCoeffs(float mantissa, float a); + static float getPillowBilinearCoeffs(float m); + static float getPillowBicubicCoeffs(float m); + inline void create_pillow_working_buf(InterpolateLayoutType layout); + + protected: + InterpolateMode mode; + InterpolateCoordTransMode coordTransMode; + InterpolateLayoutType configured_for_layout; + VectorDims srcDimPad5d, dstDim5d; + ov::element::Type inputPrec, outputPrec; + size_t srcDataSize, dstDataSize; + int spatialDimSize; + size_t dataRank; + std::vector auxTable; + std::vector pillow_working_buf; + size_t m_threads_num = 0lu; }; std::shared_ptr execPtr = nullptr; class InterpolateJitExecutor : public InterpolateExecutorBase { - public: - InterpolateJitExecutor(const InterpolateAttrs& interpAttrs, - const VectorDims &srcDims, - const VectorDims &dstDims, - const std::vector &dataScales, - const dnnl::primitive_attr &attr); - - void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) override; - - private: - // nearest neighbor - void NNPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); - void NNCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); - - // onnx linear - void linearOnnxPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); - void linearOnnxCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); - - // cubic - void cubicPlanar(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int IH, int IW, int OH, int OW); - void cubicCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int IH, int IW, int OH, int OW); - - // pillow bilinear and pillow bicubic - void pillowCGathered(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_, - int B, int C, int IH, int IW, int OH, int OW); - - private: - std::shared_ptr interpolateKernel = nullptr; + public: + InterpolateJitExecutor(const InterpolateAttrs& interpAttrs, + const VectorDims& srcDims, + const VectorDims& dstDims, + const std::vector& dataScales, + const dnnl::primitive_attr& attr); + + void exec(const uint8_t* in_ptr_, uint8_t* out_ptr_, const void* post_ops_data_) override; + + private: + // nearest neighbor + void NNPlanar(const uint8_t* in_ptr_, 
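// Note: the buildTbl* methods above fill auxTable once per shape with
// per-output-pixel source indices and weights, so exec() only gathers and
// blends. A minimal sketch of such a table for 1-D linear interpolation,
// using the half-pixel transform as an example (the layout here is
// illustrative, not the plugin's actual auxTable format):
#include <algorithm>
#include <cmath>
#include <vector>
struct LinearTblSketch {
    std::vector<int> i0, i1;
    std::vector<float> w0, w1;
};
static LinearTblSketch build_linear_tbl_sketch(int in_len, int out_len, float scale) {
    LinearTblSketch t;
    for (int ox = 0; ox < out_len; ++ox) {
        float ix = (ox + 0.5f) / scale - 0.5f;  // half-pixel coordinate transform
        int lo = static_cast<int>(std::floor(ix));
        float w = ix - static_cast<float>(lo);
        t.i0.push_back(std::min(std::max(lo, 0), in_len - 1));      // clamped left tap
        t.i1.push_back(std::min(std::max(lo + 1, 0), in_len - 1));  // clamped right tap
        t.w0.push_back(1.0f - w);
        t.w1.push_back(w);
    }
    return t;
}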
+ uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW); + void NNCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW); + + // onnx linear + void linearOnnxPlanar(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW); + void linearOnnxCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW); + + // cubic + void cubicPlanar(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int IH, + int IW, + int OH, + int OW); + void cubicCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int IH, + int IW, + int OH, + int OW); + + // pillow bilinear and pillow bicubic + void pillowCGathered(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + const void* post_ops_data_, + int B, + int C, + int IH, + int IW, + int OH, + int OW); + + private: + std::shared_ptr interpolateKernel = nullptr; }; class InterpolateRefExecutor : public InterpolateExecutorBase { - public: - InterpolateRefExecutor(const InterpolateAttrs& interpAttrs, - const VectorDims &srcDims, - const VectorDims &dstDims, - const std::vector &_dataScales) : - InterpolateExecutorBase(interpAttrs, srcDims, dstDims, _dataScales), - antialias(interpAttrs.antialias), dataScales(_dataScales) {} - - void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) override; - - private: - void NNRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); - void linearOnnxRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); - - void cubicRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); - void linearInterpolation(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int ID, int IH, int IW, - float fx, float fy, float fz, int OD, int OH, int OW, int kernel_width, bool antialias); - void pillowRef(const uint8_t *in_ptr_, uint8_t *out_ptr_, int B, int C, int IH, int IW, int OH, int OW); - - static float getValue(const uint8_t *base, size_t offset, ov::element::Type prec); - static void setValue(uint8_t *base, size_t offset, float value, ov::element::Type prec); - - private: - bool antialias; - std::vector dataScales; + public: + InterpolateRefExecutor(const InterpolateAttrs& interpAttrs, + const VectorDims& srcDims, + const VectorDims& dstDims, + const std::vector& _dataScales) + : InterpolateExecutorBase(interpAttrs, srcDims, dstDims, _dataScales), + antialias(interpAttrs.antialias), + dataScales(_dataScales) {} + + void exec(const uint8_t* in_ptr_, uint8_t* out_ptr_, const void* post_ops_data_) override; + + private: + void + NNRef(const uint8_t* in_ptr_, uint8_t* out_ptr_, int B, int C, int ID, int IH, int IW, int OD, int OH, int OW); + void linearOnnxRef(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int ID, + int IH, + int IW, + int OD, + int OH, + int OW); + + void cubicRef(const uint8_t* in_ptr_, uint8_t* out_ptr_, int B, int C, int IH, int IW, int OH, int OW); + void linearInterpolation(const uint8_t* in_ptr_, + uint8_t* out_ptr_, + int B, + int C, + int ID, + int IH, + int IW, + float fx, + 
float fy, + float fz, + int OD, + int OH, + int OW, + int kernel_width, + bool antialias); + void pillowRef(const uint8_t* in_ptr_, uint8_t* out_ptr_, int B, int C, int IH, int IW, int OH, int OW); + + static float getValue(const uint8_t* base, size_t offset, ov::element::Type prec); + static void setValue(uint8_t* base, size_t offset, float value, ov::element::Type prec); + + private: + bool antialias; + std::vector dataScales; }; - void setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims); + void setPostOps(dnnl::primitive_attr& attr, const VectorDims& dims); - static VectorDims getPaddedInputShape(const VectorDims &srcDims, const std::vector &padBegin, const std::vector &padEnd); - std::vector getScales(const VectorDims &srcDimPad, const VectorDims &dstDim); + static VectorDims getPaddedInputShape(const VectorDims& srcDims, + const std::vector& padBegin, + const std::vector& padEnd); + std::vector getScales(const VectorDims& srcDimPad, const VectorDims& dstDim); static size_t getSpatialDimsNum(const Dim rank); bool hasPad = false; @@ -244,6 +353,6 @@ class Interpolate : public Node { std::shared_ptr aclExecPtr = nullptr; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov +} // namespace node +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/jit_eltwise_call_args_ptrs.hpp b/src/plugins/intel_cpu/src/nodes/kernels/jit_eltwise_call_args_ptrs.hpp index 7370bb824d8c62..66f119ee839b14 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/jit_eltwise_call_args_ptrs.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/jit_eltwise_call_args_ptrs.hpp @@ -9,21 +9,21 @@ namespace ov { namespace intel_cpu { namespace node { -#define MAX_ELTWISE_INPUTS 7 +#define MAX_ELTWISE_INPUTS 7 #define MAX_ELTWISE_DIM_RANK 12 struct jit_eltwise_call_args_ptrs { - const void *src_ptr[MAX_ELTWISE_INPUTS]; - void *dst_ptr; - //ptr to array of post op inputs pointers (flat list) + const void* src_ptr[MAX_ELTWISE_INPUTS]; + void* dst_ptr; + // ptr to array of post op inputs pointers (flat list) const void** post_op_data; // shape agnostic kernel size_t work_amount; - const void *src_offsets[MAX_ELTWISE_INPUTS]; - const void *dst_offsets; + const void* src_offsets[MAX_ELTWISE_INPUTS]; + const void* dst_offsets; }; -} // namespace node -} // namespace intel_cpu -} // namespace ov \ No newline at end of file +} // namespace node +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp index 755330bd850c4d..b4d38086cefe8a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.cpp @@ -13,10 +13,10 @@ # include #endif -#include "openvino/core/type/bfloat16.hpp" -#include "openvino/core/parallel.hpp" -#include "common.hpp" #include "attn_memcpy.hpp" +#include "common.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/bfloat16.hpp" namespace ov { namespace Extensions { @@ -26,7 +26,7 @@ namespace XARCH { using namespace ov; // float16 <- float -template +template void attn_copy(TA* a, TB* b, size_t n) { size_t i = 0; #if defined(HAVE_AVX512F) @@ -51,14 +51,11 @@ void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, const ov::intel_cpu::PlainTensor& past_k_output, const ov::intel_cpu::PlainTensor& past_v_output) { // For compatibility, all input_kvs are 
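// Note: the reference executor above stays precision-generic by routing every
// scalar access through getValue/setValue with an explicit ov::element::Type,
// trading the JIT path's speed for one implementation that covers every
// precision the node accepts.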
permuted to BHLS - size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], SV = v_input.m_dims[3]; + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], + SV = v_input.m_dims[3]; parallel_for3d(L1, B, H, [&](size_t m, size_t b, size_t h) { - attn_copy(past_k_output.ptr(b, h, m, 0), - k_input.ptr(b, h, m, 0), - S); - attn_copy(past_v_output.ptr(b, h, m, 0), - v_input.ptr(b, h, m, 0), - SV); + attn_copy(past_k_output.ptr(b, h, m, 0), k_input.ptr(b, h, m, 0), S); + attn_copy(past_v_output.ptr(b, h, m, 0), v_input.ptr(b, h, m, 0), SV); }); } @@ -67,14 +64,11 @@ static void attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, const ov::intel_cpu::PlainTensor& past_k_output, const ov::intel_cpu::PlainTensor& past_v_output) { // For compatibility, all input_kvs are permuted to BHLS - size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], SV = v_input.m_dims[3]; + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], + SV = v_input.m_dims[3]; parallel_for3d(L1, B, H, [&](size_t m, size_t b, size_t h) { - std::memcpy(past_k_output.ptr_v(b, h, m, 0), - k_input.ptr_v(b, h, m, 0), - S * k_input.m_element_size); - std::memcpy(past_v_output.ptr_v(b, h, m, 0), - v_input.ptr_v(b, h, m, 0), - SV * v_input.m_element_size); + std::memcpy(past_k_output.ptr_v(b, h, m, 0), k_input.ptr_v(b, h, m, 0), S * k_input.m_element_size); + std::memcpy(past_v_output.ptr_v(b, h, m, 0), v_input.ptr_v(b, h, m, 0), SV * v_input.m_element_size); }); } @@ -84,19 +78,17 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, const ov::intel_cpu::PlainTensor& past_k_output, const ov::intel_cpu::PlainTensor& past_v_output, const ov::intel_cpu::PlainTensor& slot_mapping) { - size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], SV = v_input.m_dims[3]; + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], + SV = v_input.m_dims[3]; size_t block_size = past_k_output.m_dims[2]; parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) return; + if (slot < 0) + return; auto block_number = slot / block_size; auto block_offset = slot % block_size; - attn_copy(past_k_output.ptr(block_number, h, block_offset, 0), - k_input.ptr(b, h, m, 0), - S); - attn_copy(past_v_output.ptr(block_number, h, block_offset, 0), - v_input.ptr(b, h, m, 0), - SV); + attn_copy(past_k_output.ptr(block_number, h, block_offset, 0), k_input.ptr(b, h, m, 0), S); + attn_copy(past_v_output.ptr(block_number, h, block_offset, 0), v_input.ptr(b, h, m, 0), SV); }); } @@ -105,11 +97,13 @@ static void paged_attn_memcpy_kernel(const ov::intel_cpu::PlainTensor& k_input, const ov::intel_cpu::PlainTensor& past_k_output, const ov::intel_cpu::PlainTensor& past_v_output, const ov::intel_cpu::PlainTensor& slot_mapping) { - size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], SV = v_input.m_dims[3]; + size_t B = k_input.m_dims[0], H = k_input.m_dims[1], L1 = k_input.m_dims[2], S = k_input.m_dims[3], + SV = v_input.m_dims[3]; size_t block_size = past_k_output.m_dims[2]; parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) return; + if (slot < 0) + return; auto block_number = slot / block_size; auto block_offset = slot % 
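// Note: paged KV-cache addressing above: slot_mapping stores one flat slot id
// per (batch, token), negative ids mark padding tokens and are skipped, and
// each slot splits into a physical block plus an offset inside it. Worked
// example with block_size = 16:
//
//   slot = 37  ->  block_number = 37 / 16 = 2,  block_offset = 37 % 16 = 5
//
// so the token's K/V vectors land in row 5 of physical block 2, independent
// of the token's position in its logical sequence.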
block_size; std::memcpy(past_k_output.ptr_v(block_number, h, block_offset, 0), @@ -132,7 +126,11 @@ void attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) { attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output); } else { - OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in attn_memcpy"); + OPENVINO_THROW("unsupport src type: ", + k_input.get_precision(), + ", dst type: ", + past_k_output.get_precision(), + " in attn_memcpy"); } } @@ -148,7 +146,11 @@ void paged_attn_memcpy(const ov::intel_cpu::PlainTensor& k_input, } else if (k_input.get_precision() == ov::element::f32 && past_k_output.get_precision() == ov::element::bf16) { paged_attn_memcpy_kernel(k_input, v_input, past_k_output, past_v_output, slot_mapping); } else { - OPENVINO_THROW("unsupport src type: ", k_input.get_precision(), ", dst type: ", past_k_output.get_precision(), " in paged_attn_memcpy"); + OPENVINO_THROW("unsupport src type: ", + k_input.get_precision(), + ", dst type: ", + past_k_output.get_precision(), + " in paged_attn_memcpy"); } } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp index c0e5892db9926b..ea704232e333bd 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_memcpy.hpp @@ -7,6 +7,7 @@ #include #include #include + #include "openvino/core/type/element_type.hpp" #include "utils/plain_tensor.hpp" diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp index ee13118ec80d11..19721b4961fbb0 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.cpp @@ -13,11 +13,11 @@ # include #endif -#include "openvino/core/type/bfloat16.hpp" -#include "openvino/core/parallel.hpp" -#include "common.hpp" #include "attn_quant.hpp" #include "attn_quant_kernel.hpp" +#include "common.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/bfloat16.hpp" namespace ov { namespace Extensions { @@ -26,7 +26,7 @@ namespace XARCH { using namespace ov; -template +template static void find_minmax(const T* src, size_t n, float& min, float& max) { max = -FLT_MAX; min = FLT_MAX; @@ -133,7 +133,7 @@ static void find_minmax(const T* src, size_t n, float& min, float& max) { } } -template +template static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& zp) { size_t i = 0; float max = -FLT_MAX; @@ -176,7 +176,7 @@ static void quant_u8(const T* src, uint8_t* dst, size_t n, float& scale, float& } } -template +template static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) { size_t i = 0; float max = -FLT_MAX; @@ -184,7 +184,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) find_minmax(src, n, min, max); auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 
0 : 4; - return dst | (uint8_t) (val << shift); + return dst | (uint8_t)(val << shift); }; auto dst_ptr = reinterpret_cast(dst); scale = (max - min) / ((1 << 4) - 1); @@ -209,17 +209,17 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) v1_i32 = _mm512_min_epi32(v1_i32, v_upper); __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); - auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); - auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); + auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); + auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); first_half = _mm512_slli_epi32(first_half, 4); auto mask = _mm512_set1_epi32(0x0F); second_half = _mm512_and_epi32(second_half, mask); - auto combined = _mm512_or_epi32(first_half, second_half); + auto combined = _mm512_or_epi32(first_half, second_half); _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } #endif #if defined(HAVE_AVX2) || defined(HAVE_AVX512F) - auto v256_zero = _mm256_set1_epi32(0); + auto v256_zero = _mm256_set1_epi32(0); auto v256_upper = _mm256_set1_epi32(15); auto v256_scale = _mm256_set1_ps(1 / scale); auto v256_zp = _mm256_set1_ps(zp); @@ -264,7 +264,7 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) #endif for (; i < n; i++) { float tmp = src[i]; - #define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) uint8_t src_val = MIN(15, (uint8_t)(std::round(tmp / scale + zp))); uint8_t dst_val = i % 2 == 0 ? 0 : dst_ptr[i / 2]; dst_val = insert_half_byte(dst_val, src_val, (uint8_t)(i % 2)); @@ -272,13 +272,13 @@ static void quant_u4(const T* src, void* dst, size_t n, float& scale, float& zp) } } -template +template static void quant_s4(const T* src, void* dst, size_t n, float& scale) { auto insert_half_byte = [](uint8_t dst, uint8_t val, bool high_half) -> uint8_t { uint8_t shift = high_half ? 
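// Note: quant_u4 packs two 4-bit codes per byte through insert_half_byte:
// even-indexed elements (high_half == false, shift == 4) go to the upper
// nibble and odd-indexed elements stay in the lower nibble. Worked example
// for the codes 3 (even index) and 12 (odd index):
//
//   byte = insert_half_byte(0,    3,  /*high_half=*/false);  // 0x30
//   byte = insert_half_byte(byte, 12, /*high_half=*/true);   // 0x30 | 0x0C = 0x3C
//
// so n values occupy n / 2 code bytes after the group's f32 metadata.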
0 : 4; if (high_half) val &= 0x0F; - return dst | (uint8_t) (val << shift); + return dst | (uint8_t)(val << shift); }; auto dst_ptr = reinterpret_cast(dst); size_t i = 0; @@ -308,18 +308,18 @@ static void quant_s4(const T* src, void* dst, size_t n, float& scale) { __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); - auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); - auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); + auto first_half = _mm512_permutex2var_epi32(v0_i32, idx1, v1_i32); + auto second_half = _mm512_permutex2var_epi32(v0_i32, idx2, v1_i32); auto mask = _mm512_set1_epi32(0x0F); second_half = _mm512_and_epi32(second_half, mask); first_half = _mm512_slli_epi32(first_half, 4); - auto combined = _mm512_or_epi32(first_half, second_half); + auto combined = _mm512_or_epi32(first_half, second_half); _mm512_mask_cvtepi32_storeu_epi8(dst_ptr + i / 2, 0xffff, combined); } #endif #if defined(HAVE_AVX2) || defined(HAVE_AVX512F) - auto v256_lower = _mm256_set1_epi32(-8); + auto v256_lower = _mm256_set1_epi32(-8); auto v256_upper = _mm256_set1_epi32(7); auto v256_scale = _mm256_set1_ps(1 / scale); for (; i + vec_len_f32_avx2 * 2 <= n; i += vec_len_f32_avx2 * 2) { @@ -384,16 +384,8 @@ static void attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, parallel_for3d(L1, B, H, [&](size_t m, size_t b, size_t h) { auto p_k = k_scale_zp.ptr(m, b, h); auto p_v = v_scale_zp.ptr(m, b, h); - quant_u8(k_src.ptr(b, h, m), - k_dst.ptr(b, h, m), - S, - p_k[0], - p_k[1]); - quant_u8(v_src.ptr(b, h, m), - v_dst.ptr(b, h, m), - SV, - p_v[0], - p_v[1]); + quant_u8(k_src.ptr(b, h, m), k_dst.ptr(b, h, m), S, p_k[0], p_k[1]); + quant_u8(v_src.ptr(b, h, m), v_dst.ptr(b, h, m), SV, p_v[0], p_v[1]); }); } @@ -414,11 +406,13 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, size_t _value_group_size = value_group_size == 0 ? 
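// Note: each quantized group in the per-token-per-head row below is laid out
// as |scale(f32)|zp(f32)|group_size u8 codes|, so dst_offset advances by
// group_size + 2 * sizeof(float) per group, and a group size of 0 falls back
// to a single group spanning the whole feature. Worked example for u8 keys
// with S = 128 and key_group_size = 32:
//
//   groups per row  = 128 / 32    = 4
//   bytes per group = 32 + 2 * 4  = 40
//   bytes per row   = 4 * 40      = 160
//
// The sub-byte value path instead advances by
// _value_group_size / sub_byte_multiplier + sizeof(float), i.e. two 4-bit
// codes per byte plus a single f32 of metadata per group.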
SV : value_group_size; parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) return; + if (slot < 0) + return; auto block_number = slot / block_size; auto block_offset = slot % block_size; // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { auto p_k = reinterpret_cast( @@ -475,11 +469,13 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) return; + if (slot < 0) + return; auto block_number = slot / block_size; auto block_offset = slot % block_size; // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { auto p_k = reinterpret_cast( @@ -530,11 +526,13 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, size_t sub_byte_multiplier = 8 / v_dst.get_precision().bitwidth(); parallel_for3d(B, L1, H, [&](size_t b, size_t m, size_t h) { auto slot = slot_mapping.ptr(b)[m]; - if (slot < 0) return; + if (slot < 0) + return; auto block_number = slot / block_size; auto block_offset = slot % block_size; // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| for (size_t src_offset = 0, dst_offset = 0; src_offset < S; src_offset += _key_group_size, dst_offset += _key_group_size + sizeof(float) + sizeof(float)) { auto p_k = reinterpret_cast( @@ -553,8 +551,8 @@ static void paged_attn_quant_mt(const ov::intel_cpu::PlainTensor& k_src, p_k[1]); } - for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; src_offset += _value_group_size, - dst_offset += _value_group_size / sub_byte_multiplier + sizeof(float)) { + for (size_t src_offset = 0, dst_offset = 0; src_offset < SV; + src_offset += _value_group_size, dst_offset += _value_group_size / sub_byte_multiplier + sizeof(float)) { uint8_t* v_base = reinterpret_cast( v_dst.m_ptr.get() + (block_number * v_dst.m_strides[0] + h * v_dst.m_strides[1] + block_offset * v_dst.m_strides[2]) / @@ -580,7 +578,11 @@ void attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, } else if (k_src.get_precision() == ov::element::f16 && k_dst.get_precision() == ov::element::u8) { attn_quant_mt(k_src, v_src, k_dst, v_dst, k_scale_zp, v_scale_zp); } else { - OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in attn_quantkv"); + OPENVINO_THROW("unsupport src type: ", + k_src.get_precision(), + ", dst type: ", + k_dst.get_precision(), + " in attn_quantkv"); } } @@ 
-592,29 +594,33 @@ void paged_attn_quantkv(const ov::intel_cpu::PlainTensor& k_src, const size_t key_group_size, const size_t value_group_size) { using function_type = void (*)(const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const ov::intel_cpu::PlainTensor&, - const size_t, - const size_t); + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const ov::intel_cpu::PlainTensor&, + const size_t, + const size_t); static constexpr function_type funcs_fp32[] = { - paged_attn_quant_mt, - paged_attn_quant_mt, - paged_attn_quant_mt, + paged_attn_quant_mt, + paged_attn_quant_mt, + paged_attn_quant_mt, }; static constexpr function_type funcs_bf16[] = { - paged_attn_quant_mt, - paged_attn_quant_mt, - paged_attn_quant_mt, + paged_attn_quant_mt, + paged_attn_quant_mt, + paged_attn_quant_mt, }; static constexpr function_type funcs_f16[] = { - paged_attn_quant_mt, - paged_attn_quant_mt, - paged_attn_quant_mt, + paged_attn_quant_mt, + paged_attn_quant_mt, + paged_attn_quant_mt, }; if (k_dst.get_precision() != ov::element::u8) { - OPENVINO_THROW("unsupport src type: ", k_src.get_precision(), ", dst type: ", k_dst.get_precision(), " in paged_attn_quantkv"); + OPENVINO_THROW("unsupport src type: ", + k_src.get_precision(), + ", dst type: ", + k_dst.get_precision(), + " in paged_attn_quantkv"); } std::map dispatch_table = { {ov::element::u8, 0}, diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp index 70711ee49cbc3c..364e5775861ed2 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant.hpp @@ -7,6 +7,7 @@ #include #include #include + #include "openvino/core/type/element_type.hpp" #include "utils/plain_tensor.hpp" diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp index 434286766d5188..6083715c917910 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/attn_quant_kernel.hpp @@ -3,19 +3,21 @@ // #pragma once -#include +#include "nodes/kernels/scaled_attn/common.hpp" + +#if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + #include #include -#include -#include "openvino/core/type/element_type.hpp" -#include "utils/plain_tensor.hpp" namespace ov { namespace Extensions { namespace Cpu { namespace XARCH { -template +template void attn_dequant_u8_kernel(const uint8_t* src, TDST* dst, size_t n, float scale, float zp) { size_t i = 0; // loadu_si128/epi64 does not support const qualifier @@ -218,4 +220,4 @@ void attn_dequant_s4_kernel(const uint8_t* src, TDST* dst, size_t n, float scale } // namespace XARCH } // namespace Cpu } // namespace Extensions -} // namespace ov \ No newline at end of file +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index 2956c8a6a6b5b8..4e14cf5894b04d 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -4,16 +4,20 @@ #pragma once #include +#include #include #include #include -#include #include 
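// Note: the scalar quantization above is q = round(x / scale + zp), clamped
// to the code range, so the matching dequantization (which
// attn_dequant_u8_kernel is expected to invert) is x ~= (q - zp) * scale.
// Minimal scalar sketch of the round trip, clamping omitted for brevity:
#include <cmath>
#include <cstdint>
static float u8_quant_roundtrip_sketch(float x, float scale, float zp) {
    uint8_t q = static_cast<uint8_t>(std::round(x / scale + zp));  // quantize
    return (static_cast<float>(q) - zp) * scale;                   // dequantize
}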
"openvino/core/type/bfloat16.hpp" #include "openvino/core/type/float16.hpp" +#if defined(HAVE_SSE) || defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# include +#endif + #if defined(OPENVINO_ARCH_ARM64) -#include "arm_neon.h" +# include "arm_neon.h" #endif namespace ov { @@ -32,307 +36,307 @@ static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float); static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); #ifdef HAVE_AVX512F - inline __m512 cvt_bf16_to_fp32(const __m256i src) { - __m512i y = _mm512_cvtepu16_epi32(src); - return _mm512_castsi512_ps(_mm512_slli_epi32(y, 16)); - } - - // load addr to __m512 reg - inline __m512 mm512_uni_loadu_ps(const float* a) { - return _mm512_loadu_ps(a); - } - - inline __m512 mm512_uni_loadu_ps(const ov::bfloat16* a) { - auto vec_bf16 = _mm256_loadu_si256(reinterpret_cast(a)); - return cvt_bf16_to_fp32(vec_bf16); - } - - inline __m512 mm512_uni_loadu_ps(const ov::float16* a) { - auto vec_f16 = _mm256_loadu_si256(reinterpret_cast(a)); - return _mm512_cvtph_ps(vec_f16); - } - - // load addr to __m512 reg - inline __m512 mm512_uni_loadu_tail_ps(const float* a, size_t count) { - __mmask16 mask = (1 << count) - 1; - return _mm512_maskz_loadu_ps(mask, a); - } - - inline __m512 mm512_uni_loadu_tail_ps(const ov::bfloat16* a, size_t count) { - auto mask = (1 << count) - 1; - auto bf16_vec = _mm256_maskz_loadu_epi16(mask, a); - return cvt_bf16_to_fp32(bf16_vec); - } - - inline __m512 mm512_uni_loadu_tail_ps(const ov::float16* a, size_t count) { - auto mask = (1 << count) - 1; - auto f16_vec = _mm256_maskz_loadu_epi16(mask, a); - return _mm512_cvtph_ps(f16_vec); - } - - // store __m512 reg to addr - inline void mm512_uni_storeu_ps(float* a, __m512 v) { - _mm512_storeu_ps(a, v); - } - inline void mm512_uni_storeu_ps(ov::bfloat16 *addr, __m512 xps) { - __m512i xpi32 = _mm512_castps_si512(xps); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask = _mm512_cmp_ps_mask(xps, xps, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16] - x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB - x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; - x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 - _mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), _mm512_cvtepi32_epi16(x)); - } - - inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) { - __m256i vec_f16 = _mm512_cvtps_ph(v, 0); - _mm256_storeu_si256(reinterpret_cast<__m256i *>(addr), vec_f16); - } - - // store __m512 reg to addr - inline void mm512_uni_mask_storeu_ps(ov::bfloat16 *addr, __mmask16 mask_addr, __m512 xps) { - __m512i xpi32 = _mm512_castps_si512(xps); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask = _mm512_cmp_ps_mask(xps, xps, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16] - x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB - x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; - x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 - _mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x); - } - - inline void mm512_uni_storeu_tail_ps(float *addr, __m512 v, size_t count) { - __mmask16 mask_addr = (1 << count) - 1; - 
_mm512_mask_storeu_ps(addr, mask_addr, v); - } - - inline void mm512_uni_storeu_tail_ps(ov::bfloat16 *addr, __m512 v, size_t count) { - __mmask16 mask_addr = (1 << count) - 1; - __m512i xpi32 = _mm512_castps_si512(v); - __m512i nan = _mm512_set1_epi32(0xffff); - auto mask = _mm512_cmp_ps_mask(v, v, _CMP_ORD_Q); - __m512i ones = _mm512_set1_epi32(0x1); - __m512i vec_bias = _mm512_set1_epi32(0x7fff); - auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16] - x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB - x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; - x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 - _mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x); - } - - inline void mm512_uni_storeu_tail_ps(ov::float16 *addr, __m512 v, size_t count) { - __mmask16 mask_addr = (1 << count) - 1; - __m256i vec_f16 = _mm512_cvtps_ph(v, 0); - _mm256_mask_storeu_epi16(reinterpret_cast<__m256i *>(addr), mask_addr, vec_f16); - } +inline __m512 cvt_bf16_to_fp32(const __m256i src) { + __m512i y = _mm512_cvtepu16_epi32(src); + return _mm512_castsi512_ps(_mm512_slli_epi32(y, 16)); +} + +// load addr to __m512 reg +inline __m512 mm512_uni_loadu_ps(const float* a) { + return _mm512_loadu_ps(a); +} + +inline __m512 mm512_uni_loadu_ps(const ov::bfloat16* a) { + auto vec_bf16 = _mm256_loadu_si256(reinterpret_cast(a)); + return cvt_bf16_to_fp32(vec_bf16); +} + +inline __m512 mm512_uni_loadu_ps(const ov::float16* a) { + auto vec_f16 = _mm256_loadu_si256(reinterpret_cast(a)); + return _mm512_cvtph_ps(vec_f16); +} + +// load addr to __m512 reg +inline __m512 mm512_uni_loadu_tail_ps(const float* a, size_t count) { + __mmask16 mask = (1 << count) - 1; + return _mm512_maskz_loadu_ps(mask, a); +} + +inline __m512 mm512_uni_loadu_tail_ps(const ov::bfloat16* a, size_t count) { + auto mask = (1 << count) - 1; + auto bf16_vec = _mm256_maskz_loadu_epi16(mask, a); + return cvt_bf16_to_fp32(bf16_vec); +} + +inline __m512 mm512_uni_loadu_tail_ps(const ov::float16* a, size_t count) { + auto mask = (1 << count) - 1; + auto f16_vec = _mm256_maskz_loadu_epi16(mask, a); + return _mm512_cvtph_ps(f16_vec); +} + +// store __m512 reg to addr +inline void mm512_uni_storeu_ps(float* a, __m512 v) { + _mm512_storeu_ps(a, v); +} +inline void mm512_uni_storeu_ps(ov::bfloat16* addr, __m512 xps) { + __m512i xpi32 = _mm512_castps_si512(xps); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask = _mm512_cmp_ps_mask(xps, xps, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16] + x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB + x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; + x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 + _mm256_storeu_si256(reinterpret_cast<__m256i*>(addr), _mm512_cvtepi32_epi16(x)); +} + +inline void mm512_uni_storeu_ps(ov::float16* addr, __m512 v) { + __m256i vec_f16 = _mm512_cvtps_ph(v, 0); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(addr), vec_f16); +} + +// store __m512 reg to addr +inline void mm512_uni_mask_storeu_ps(ov::bfloat16* addr, __mmask16 mask_addr, __m512 xps) { + __m512i xpi32 = _mm512_castps_si512(xps); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask = _mm512_cmp_ps_mask(xps, xps, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i 
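// Note: the bf16 stores above hand-roll round-to-nearest-even: bit 16 of the
// f32 pattern is the LSB of the target mantissa, 0x7FFF + LSB is added, the
// result is shifted right by 16, and NaNs are blended to 0xFFFF first. A
// scalar reference of the same rounding, for comparison:
#include <cstdint>
#include <cstring>
static uint16_t f32_to_bf16_rne_sketch(float f) {
    uint32_t x;
    std::memcpy(&x, &f, sizeof(x));
    if ((x & 0x7FFFFFFFu) > 0x7F800000u)    // NaN: force a quiet-NaN pattern
        return 0xFFFFu;
    uint32_t lsb = (x >> 16) & 1u;          // LSB = x[16]
    x += 0x7FFFu + lsb;                     // rounding_bias = 0x7fff + LSB
    return static_cast<uint16_t>(x >> 16);  // (x + rounding_bias) >> 16
}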
vec_bias = _mm512_set1_epi32(0x7fff); + auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16] + x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB + x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; + x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 + _mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x); +} + +inline void mm512_uni_storeu_tail_ps(float* addr, __m512 v, size_t count) { + __mmask16 mask_addr = (1 << count) - 1; + _mm512_mask_storeu_ps(addr, mask_addr, v); +} + +inline void mm512_uni_storeu_tail_ps(ov::bfloat16* addr, __m512 v, size_t count) { + __mmask16 mask_addr = (1 << count) - 1; + __m512i xpi32 = _mm512_castps_si512(v); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask = _mm512_cmp_ps_mask(v, v, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + auto x = _mm512_and_si512(_mm512_srli_epi32(xpi32, 16), ones); // LSB = x[16] + x = _mm512_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB + x = _mm512_srli_epi32(_mm512_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; + x = _mm512_mask_blend_epi32(mask, nan, x); // Check NaN before converting back to bf16 + _mm512_mask_cvtepi32_storeu_epi16(addr, mask_addr, x); +} + +inline void mm512_uni_storeu_tail_ps(ov::float16* addr, __m512 v, size_t count) { + __mmask16 mask_addr = (1 << count) - 1; + __m256i vec_f16 = _mm512_cvtps_ph(v, 0); + _mm256_mask_storeu_epi16(reinterpret_cast<__m256i*>(addr), mask_addr, vec_f16); +} #endif #ifdef HAVE_AVX2 - inline __m256i get_mask(int N7) { - static __m256i mask[] = { - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1), - _mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1), - }; - return _mm256_loadu_si256(&mask[N7]); - } - - // load addr to __m256 reg - inline __m256 mm256_uni_loadu_ps(const float* a) { - return _mm256_loadu_ps(a); - } - - inline __m256 mm256_uni_loadu_ps(const ov::bfloat16* a) { - auto vec_bf16 = _mm_loadu_si128(reinterpret_cast(a)); - auto o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(vec_bf16), 16)); - return o; - } - - inline __m256 mm256_uni_loadu_ps(const ov::float16* a) { - auto vec_f16 = _mm_loadu_si128(reinterpret_cast(a)); - auto o = _mm256_cvtph_ps(vec_f16); - return o; - } - - // load addr tail to __m256 reg - inline __m256 mm256_uni_loadu_tail_ps(const float* a, const size_t count) { - auto mask = get_mask(count); - return _mm256_maskload_ps(a, mask); - } - - inline __m256 mm256_uni_loadu_tail_ps(const ov::bfloat16* a, const size_t count) { - assert("AVX2 version of bfloat16 tail load is just for compilation pass"); - ov::bfloat16 tmp_values[8] = {0}; - std::memcpy(tmp_values, a, count * sizeof(ov::bfloat16)); - return mm256_uni_loadu_ps(tmp_values); - } - - inline __m256 mm256_uni_loadu_tail_ps(const ov::float16* a, const size_t count) { - ov::float16 tmp_values[8] = {0}; - std::memcpy(tmp_values, a, count * sizeof(ov::float16)); - return mm256_uni_loadu_ps(tmp_values); - } - - // store __m256 reg to addr - inline void mm256_uni_storeu_ps(float* a, __m256 v) { - _mm256_storeu_ps(a, v); - } - - inline void mm256_uni_storeu_ps(ov::bfloat16 
*addr, __m256 xps) { - __m256i xpi32 = _mm256_castps_si256(xps); - __m256i nan = _mm256_set1_epi32(0xffff); - __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(xps, xps, _CMP_ORD_Q)); - __m256i ones = _mm256_set1_epi32(0x1); - __m256i vec_bias = _mm256_set1_epi32(0x7fff); - auto x = _mm256_and_si256(_mm256_srli_epi32(xpi32, 16), ones); // LSB = x[16] - x = _mm256_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB - x = _mm256_srli_epi32(_mm256_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; - x = _mm256_blendv_epi8(nan, x, mask); // Check NaN before converting back to bf16 - x = _mm256_packus_epi32(x, x); - x = _mm256_permute4x64_epi64(x, 0xd8); - __m128i bf16_o = _mm256_extractf128_si256(x, 0); - _mm_storeu_si128(reinterpret_cast<__m128i *>(addr), bf16_o); - } - - inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) { - __m128i vec_f16 = _mm256_cvtps_ph(v, 0); - _mm_storeu_si128(reinterpret_cast<__m128i *>(a), vec_f16); - } - - // store __m256 to addr - inline void mm256_uni_storeu_tail_ps(float *addr, __m256 v, size_t count) { - const auto mask = get_mask(count); - return _mm256_maskstore_ps(addr, mask, v); - } - - inline void hsum(__m256& x) { - __m256 y; // x: 0 1 2 3 4 5 6 7 - y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 - x = _mm256_add_ps(x, y); // X: 01 12 23 30 45 56 67 74 - y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 - x = _mm256_add_ps(x, y); // x: 0123 x x x 4567 x x x - y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x - x = _mm256_add_ps(x, y); // x: 01234567 x x x x x x x - } - inline void hmax(__m256& x) { - __m256 y; // x: 0 1 2 3 4 5 6 7 - y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 - x = _mm256_max_ps(x, y); // X: 01 12 23 30 45 56 67 74 - y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 - x = _mm256_max_ps(x, y); // x: 0123 x x x 4567 x x x - y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x - x = _mm256_max_ps(x, y); // x: 01234567 x x x x x x x - } - inline void hmin(__m256& x) { - __m256 y; // x: 0 1 2 3 4 5 6 7 - y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 - x = _mm256_min_ps(x, y); // X: 01 12 23 30 45 56 67 74 - y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 - x = _mm256_min_ps(x, y); // x: 0123 x x x 4567 x x x - y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x - x = _mm256_min_ps(x, y); // x: 01234567 x x x x x x x - } +inline __m256i get_mask(int N7) { + static __m256i mask[] = { + _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, 0), + _mm256_set_epi32(0, 0, 0, 0, 0, 0, 0, -1), + _mm256_set_epi32(0, 0, 0, 0, 0, 0, -1, -1), + _mm256_set_epi32(0, 0, 0, 0, 0, -1, -1, -1), + _mm256_set_epi32(0, 0, 0, 0, -1, -1, -1, -1), + _mm256_set_epi32(0, 0, 0, -1, -1, -1, -1, -1), + _mm256_set_epi32(0, 0, -1, -1, -1, -1, -1, -1), + _mm256_set_epi32(0, -1, -1, -1, -1, -1, -1, -1), + _mm256_set_epi32(-1, -1, -1, -1, -1, -1, -1, -1), + }; + return _mm256_loadu_si256(&mask[N7]); +} + +// load addr to __m256 reg +inline __m256 mm256_uni_loadu_ps(const float* a) { + return _mm256_loadu_ps(a); +} + +inline __m256 mm256_uni_loadu_ps(const ov::bfloat16* a) { + auto vec_bf16 = _mm_loadu_si128(reinterpret_cast(a)); + auto o = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(vec_bf16), 16)); + return o; +} + +inline __m256 mm256_uni_loadu_ps(const ov::float16* a) { + auto vec_f16 = _mm_loadu_si128(reinterpret_cast(a)); + auto o = _mm256_cvtph_ps(vec_f16); + return o; +} + +// load addr tail to __m256 reg +inline __m256 mm256_uni_loadu_tail_ps(const 
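// Note: get_mask(count) above returns an __m256i whose first `count` 32-bit
// lanes are all-ones and whose remaining lanes are zero, so the masked
// load/store touches exactly the tail elements. E.g. get_mask(3) yields
// {-1, -1, -1, 0, 0, 0, 0, 0} and _mm256_maskload_ps then reads three floats,
// zero-filling the rest of the register.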
float* a, const size_t count) { + auto mask = get_mask(count); + return _mm256_maskload_ps(a, mask); +} + +inline __m256 mm256_uni_loadu_tail_ps(const ov::bfloat16* a, const size_t count) { + assert("AVX2 version of bfloat16 tail load is just for compilation pass"); + ov::bfloat16 tmp_values[8] = {0}; + std::memcpy(tmp_values, a, count * sizeof(ov::bfloat16)); + return mm256_uni_loadu_ps(tmp_values); +} + +inline __m256 mm256_uni_loadu_tail_ps(const ov::float16* a, const size_t count) { + ov::float16 tmp_values[8] = {0}; + std::memcpy(tmp_values, a, count * sizeof(ov::float16)); + return mm256_uni_loadu_ps(tmp_values); +} + +// store __m256 reg to addr +inline void mm256_uni_storeu_ps(float* a, __m256 v) { + _mm256_storeu_ps(a, v); +} + +inline void mm256_uni_storeu_ps(ov::bfloat16* addr, __m256 xps) { + __m256i xpi32 = _mm256_castps_si256(xps); + __m256i nan = _mm256_set1_epi32(0xffff); + __m256i mask = _mm256_castps_si256(_mm256_cmp_ps(xps, xps, _CMP_ORD_Q)); + __m256i ones = _mm256_set1_epi32(0x1); + __m256i vec_bias = _mm256_set1_epi32(0x7fff); + auto x = _mm256_and_si256(_mm256_srli_epi32(xpi32, 16), ones); // LSB = x[16] + x = _mm256_add_epi32(x, vec_bias); // rounding_bias = 0x7fff + LSB + x = _mm256_srli_epi32(_mm256_add_epi32(x, xpi32), 16); // x = (x + rounding_bias) >> 16; + x = _mm256_blendv_epi8(nan, x, mask); // Check NaN before converting back to bf16 + x = _mm256_packus_epi32(x, x); + x = _mm256_permute4x64_epi64(x, 0xd8); + __m128i bf16_o = _mm256_extractf128_si256(x, 0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(addr), bf16_o); +} + +inline void mm256_uni_storeu_ps(ov::float16* a, __m256 v) { + __m128i vec_f16 = _mm256_cvtps_ph(v, 0); + _mm_storeu_si128(reinterpret_cast<__m128i*>(a), vec_f16); +} + +// store __m256 to addr +inline void mm256_uni_storeu_tail_ps(float* addr, __m256 v, size_t count) { + const auto mask = get_mask(count); + return _mm256_maskstore_ps(addr, mask, v); +} + +inline void hsum(__m256& x) { + __m256 y; // x: 0 1 2 3 4 5 6 7 + y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 + x = _mm256_add_ps(x, y); // X: 01 12 23 30 45 56 67 74 + y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 + x = _mm256_add_ps(x, y); // x: 0123 x x x 4567 x x x + y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x + x = _mm256_add_ps(x, y); // x: 01234567 x x x x x x x +} +inline void hmax(__m256& x) { + __m256 y; // x: 0 1 2 3 4 5 6 7 + y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 + x = _mm256_max_ps(x, y); // X: 01 12 23 30 45 56 67 74 + y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 + x = _mm256_max_ps(x, y); // x: 0123 x x x 4567 x x x + y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x + x = _mm256_max_ps(x, y); // x: 01234567 x x x x x x x +} +inline void hmin(__m256& x) { + __m256 y; // x: 0 1 2 3 4 5 6 7 + y = _mm256_permute_ps(x, 0x39); // y: 1 2 3 0 5 6 7 4 + x = _mm256_min_ps(x, y); // X: 01 12 23 30 45 56 67 74 + y = _mm256_permute_ps(x, 0x4e); // y: 23 30 01 12 67 74 45 56 + x = _mm256_min_ps(x, y); // x: 0123 x x x 4567 x x x + y = _mm256_permute2f128_ps(x, x, 1); // y: 4567 x x x 0123 x x x + x = _mm256_min_ps(x, y); // x: 01234567 x x x x x x x +} #endif #ifdef OPENVINO_ARCH_ARM64 - inline float32x4_t exp_ps_neon_f32(const float32x4_t& src) { - const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6)); - const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb)); - const auto c3 = vreinterpretq_f32_u32(vdupq_n_u32(0x3e2aaf33)); - const auto c4 = 
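// Note: hsum/hmax/hmin above reduce eight lanes with three shuffle+op pairs,
// folding neighbours first and the 128-bit halves last. Worked hsum trace for
// the lanes {1, 2, 3, 4, 5, 6, 7, 8}:
//
//   permute 0x39 + add   -> {3, 5, 7, 5, 11, 13, 15, 13}
//   permute 0x4e + add   -> {10, 10, 10, 10, 26, 26, 26, 26}
//   permute2f128 + add   -> {36, 36, 36, 36, 36, 36, 36, 36}
//
// leaving the total (1 + ... + 8 = 36) in every lane, including lane 0.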
vreinterpretq_f32_u32(vdupq_n_u32(0x3d2b9f17)); - const auto c5 = vreinterpretq_f32_u32(vdupq_n_u32(0x3c072010)); - - const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f - const auto one = vdupq_n_f32(1.0f); // 1 - const auto two = vdupq_n_f32(2.0f); // 2 - const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); - const auto neg_ln2_hi = vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); - const auto neg_ln2_lo = vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); - - const auto inf = vdupq_n_f32(std::numeric_limits::infinity()); - const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5) - const auto zero = vdupq_n_f32(0.f); - const auto min_input = vdupq_n_f32(-86.64f); // Approximately ln(2^-125) - - const auto z = vmlaq_f32(shift, src, inv_ln2); - auto n = z - shift; - n = vsubq_f32(n, one); - const auto scale = vreinterpretq_f32_u32(vreinterpretq_u32_f32(z) << 23); // 2^n - - const auto r_hi = vfmaq_f32(src, n, neg_ln2_hi); - const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo); - - const auto r2 = r * r; - - const auto p1 = c1 * r; - const auto p23 = vfmaq_f32(c2, c3, r); - const auto p45 = vfmaq_f32(c4, c5, r); - const auto p2345 = vfmaq_f32(p23, p45, r2); - const auto p12345 = vfmaq_f32(p1, p2345, r2); - - auto poly = vfmaq_f32(scale, p12345, scale); - poly = vmulq_f32(poly, two); - - poly = vbslq_f32(vcltq_f32(src, min_input), zero, poly); - poly = vbslq_f32(vcgtq_f32(src, max_input), inf, poly); - - return poly; - } - inline float32x4_t __vld1q_f32(const ov::bfloat16* a) { - uint16x4_t vec_bf16 = vld1_u16(reinterpret_cast(a)); - - float32x4_t vec_f32 = vcvtq_f32_u32(vmovl_u16(vec_bf16)); - return vec_f32; - } - inline float32x4_t __vld1q_f32(const float* a) { - return vld1q_f32(a); - } - inline float32x4_t __vld1q_f32(const ov::float16* a) { - auto _a = reinterpret_cast(a); - return vcvt_f32_f16(vld1_f16(_a)); - } - inline void __vst1q_f32(float* a, float32x4_t b) { - vst1q_f32(a, b); - } - inline void __vst1q_f32(ov::float16* a, float32x4_t b) { - float16x4_t v_f16 = vcvt_f16_f32(b); - vst1_f16(reinterpret_cast(a), v_f16); - } - inline void __vst1q_f32(ov::bfloat16* a, float32x4_t b) { - uint32x4_t v_int32 = vreinterpretq_u32_f32(b); - uint16x4_t v_bf16 = vshrn_n_u32(v_int32, 16); - - vst1_u16(reinterpret_cast(a), v_bf16); - } +inline float32x4_t exp_ps_neon_f32(const float32x4_t& src) { + const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6)); + const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb)); + const auto c3 = vreinterpretq_f32_u32(vdupq_n_u32(0x3e2aaf33)); + const auto c4 = vreinterpretq_f32_u32(vdupq_n_u32(0x3d2b9f17)); + const auto c5 = vreinterpretq_f32_u32(vdupq_n_u32(0x3c072010)); + + const auto shift = vreinterpretq_f32_u32(vdupq_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto one = vdupq_n_f32(1.0f); // 1 + const auto two = vdupq_n_f32(2.0f); // 2 + const auto inv_ln2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3fb8aa3b)); + const auto neg_ln2_hi = vreinterpretq_f32_u32(vdupq_n_u32(0xbf317200)); + const auto neg_ln2_lo = vreinterpretq_f32_u32(vdupq_n_u32(0xb5bfbe8e)); + + const auto inf = vdupq_n_f32(std::numeric_limits::infinity()); + const auto max_input = vdupq_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = vdupq_n_f32(0.f); + const auto min_input = vdupq_n_f32(-86.64f); // Approximately ln(2^-125) + + const auto z = vmlaq_f32(shift, src, inv_ln2); + auto n = z - shift; + n = vsubq_f32(n, one); + const auto scale = 
+
+    const auto r_hi = vfmaq_f32(src, n, neg_ln2_hi);
+    const auto r = vfmaq_f32(r_hi, n, neg_ln2_lo);
+
+    const auto r2 = r * r;
+
+    const auto p1 = c1 * r;
+    const auto p23 = vfmaq_f32(c2, c3, r);
+    const auto p45 = vfmaq_f32(c4, c5, r);
+    const auto p2345 = vfmaq_f32(p23, p45, r2);
+    const auto p12345 = vfmaq_f32(p1, p2345, r2);
+
+    auto poly = vfmaq_f32(scale, p12345, scale);
+    poly = vmulq_f32(poly, two);
+
+    poly = vbslq_f32(vcltq_f32(src, min_input), zero, poly);
+    poly = vbslq_f32(vcgtq_f32(src, max_input), inf, poly);
+
+    return poly;
+}
+inline float32x4_t __vld1q_f32(const ov::bfloat16* a) {
+    uint16x4_t vec_bf16 = vld1_u16(reinterpret_cast<const uint16_t*>(a));
+
+    float32x4_t vec_f32 = vcvtq_f32_u32(vmovl_u16(vec_bf16));
+    return vec_f32;
+}
+inline float32x4_t __vld1q_f32(const float* a) {
+    return vld1q_f32(a);
+}
+inline float32x4_t __vld1q_f32(const ov::float16* a) {
+    auto _a = reinterpret_cast<const float16_t*>(a);
+    return vcvt_f32_f16(vld1_f16(_a));
+}
+inline void __vst1q_f32(float* a, float32x4_t b) {
+    vst1q_f32(a, b);
+}
+inline void __vst1q_f32(ov::float16* a, float32x4_t b) {
+    float16x4_t v_f16 = vcvt_f16_f32(b);
+    vst1_f16(reinterpret_cast<float16_t*>(a), v_f16);
+}
+inline void __vst1q_f32(ov::bfloat16* a, float32x4_t b) {
+    uint32x4_t v_int32 = vreinterpretq_u32_f32(b);
+    uint16x4_t v_bf16 = vshrn_n_u32(v_int32, 16);
+
+    vst1_u16(reinterpret_cast<uint16_t*>(a), v_bf16);
+}
 #endif

 #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
-    inline float16x8_t exp_ps_neon_f16(float16x8_t x) {
-        const float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x));
-        const float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x));
-
-        // We use f32 to maintain accuracy
-        const float16x8_t res = vcombine_f16(vcvt_f16_f32(exp_ps_neon_f32(x_low)), vcvt_f16_f32(exp_ps_neon_f32(x_high)));
-        return res;
-    }
-    inline float16_t hsum(float16x8_t vec) {
-        float16x4_t sum1 = vpadd_f16(vget_low_f16(vec), vget_high_f16(vec));
-        float16x4_t sum2 = vpadd_f16(sum1, sum1);
-        float16x4_t sum3 = vpadd_f16(sum2, sum2);
-        return vget_lane_f16(sum3, 0);
-    }
+inline float16x8_t exp_ps_neon_f16(float16x8_t x) {
+    const float32x4_t x_high = vcvt_f32_f16(vget_high_f16(x));
+    const float32x4_t x_low = vcvt_f32_f16(vget_low_f16(x));
+
+    // We use f32 to maintain accuracy
+    const float16x8_t res = vcombine_f16(vcvt_f16_f32(exp_ps_neon_f32(x_low)), vcvt_f16_f32(exp_ps_neon_f32(x_high)));
+    return res;
+}
+inline float16_t hsum(float16x8_t vec) {
+    float16x4_t sum1 = vpadd_f16(vget_low_f16(vec), vget_high_f16(vec));
+    float16x4_t sum2 = vpadd_f16(sum1, sum1);
+    float16x4_t sum3 = vpadd_f16(sum2, sum2);
+    return vget_lane_f16(sum3, 0);
+}
 #endif
 }  // namespace XARCH
 }  // namespace Cpu
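// Editor's note (illustrative sketch, not part of the patch): exp_ps_neon_f32
// above uses the classic range reduction exp(x) = 2^n * exp(r) with r = x - n*ln2,
// ln2 split into hi/lo parts for accuracy, and the 2^n scale built directly in
// the float exponent field; n appears to be offset by one and the result doubled
// at the end so the scale stays finite near the upper input bound. A plain
// scalar version of the same decomposition (needs <cmath>):
//
//   float exp_sketch(float x) {
//       float n = std::nearbyint(x * 1.44269504f);  // round(x / ln2)
//       float r = std::fma(n, -0.69314718f, x);     // r = x - n*ln2 (one step where
//                                                   // the kernel uses two)
//       // degree-4 Taylor stand-in for the kernel's tuned degree-5 polynomial
//       float p = 1.0f + r * (1.0f + r * (0.5f + r * (1.0f / 6 + r / 24)));
//       return std::ldexp(p, static_cast<int>(n));  // apply the 2^n scale
//   }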
"openvino/core/type/bfloat16.hpp" +#include "openvino/core/type/float16.hpp" #include "softmax_kernel.hpp" #include "transpose_kernel.hpp" #include "utils/plain_tensor.hpp" -#include "attn_memcpy.hpp" -#include "attn_quant.hpp" -#include "nodes/kernels/x64/brgemm_kernel.hpp" namespace ov { namespace Extensions { @@ -38,42 +38,45 @@ using namespace ov::intel_cpu; // currently depends on brgemm which only support x64 #ifdef OPENVINO_ARCH_X86_64 -#if defined(HAVE_AVX2) || defined(HAVE_AVX512F) +# if defined(HAVE_AVX2) || defined(HAVE_AVX512F) -#define prefetch_bytes(bytes, sel, advance, src) { \ - auto *p = reinterpret_cast(src); \ - for (size_t i = 0; i < bytes; i += 64) \ - _mm_prefetch(p + i + advance, sel); \ -} +# define prefetch_bytes(bytes, sel, advance, src) \ + { \ + auto* p = reinterpret_cast(src); \ + for (size_t i = 0; i < bytes; i += 64) \ + _mm_prefetch(p + i + advance, sel); \ + } -#else +# else -#define prefetch_bytes(bytes, sel, advance, src) +# define prefetch_bytes(bytes, sel, advance, src) -#endif +# endif -template +template void cvt_copy(TA* dst, TB* src, size_t n) { size_t i = 0; -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) for (; i + vec_len_f32_avx512 <= n; i += vec_len_f32_avx512) { auto vb = mm512_uni_loadu_ps(src + i); mm512_uni_storeu_ps(dst + i, vb); } -#elif defined(HAVE_AVX2) +# elif defined(HAVE_AVX2) for (; i + vec_len_f32_avx2 <= n; i += vec_len_f32_avx2) { auto vb = mm256_uni_loadu_ps(src + i); mm256_uni_storeu_ps(dst + i, vb); } -#endif +# endif for (; i < n; i++) { dst[i] = src[i]; } } -template::type = true> +template ::type = true> static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size_t block_size, size_t group_size = 0) { -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) size_t j = 0; for (; j + 4 <= block_size; j += 4) { auto attn_w_vec0 = _mm512_set1_ps(weight[0]); @@ -132,7 +135,7 @@ static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size } } return; -#elif defined(HAVE_AVX2) +# elif defined(HAVE_AVX2) size_t j = 0; for (; j + 4 <= block_size; j += 4) { auto attn_w_vec0 = _mm256_set1_ps(weight[0]); @@ -191,7 +194,7 @@ static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size } } return; -#endif +# endif for (size_t j = 0; j < block_size; j++) { for (size_t i = 0; i < S; i++) { out[i] += weight[j] * v[i]; @@ -199,18 +202,25 @@ static void attn_acc_value_block(float* out, float* weight, T* v, size_t S, size v += S; } } -template::type = true> -static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) { +template ::type = true> +static void attn_acc_value_block(float* out, + float* weight, + uint8_t* v, + size_t S, + size_t block_size, + size_t group_size = 0) { // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) size_t src_offset = 0; size_t dst_offset = 0; const size_t _group_size = group_size ? 
     const size_t params_offset = sizeof(float) * 2;
     const size_t src_stride = S / _group_size * (_group_size + params_offset);
-#if defined(HAVE_AVX512F)
+# if defined(HAVE_AVX512F)
     size_t j = 0;
     for (; j + 4 <= block_size; j += 4) {
         dst_offset = 0;
@@ -276,7 +286,9 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
             size_t i = 0;
             for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) {
                 auto v_out = mm512_uni_loadu_ps((out + dst_offset + i));
-                auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0);
+                auto v0 = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(
+                                            _mm_loadu_si128(reinterpret_cast<__m128i*>(v_data_ptr + i)))),
+                                        zp0);
                 v_out = _mm512_fmadd_ps(attn_w_vec0, v0, v_out);
                 _mm512_storeu_ps((out + dst_offset + i), v_out);
@@ -291,7 +303,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
         weight++;
     }
     return;
-#elif defined(HAVE_AVX2)
+# elif defined(HAVE_AVX2)
     size_t j = 0;
     for (; j < block_size; j++) {
         dst_offset = 0;
@@ -306,7 +318,9 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
             v += 8;
             for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) {
                 auto v_out = mm256_uni_loadu_ps(out + dst_offset + i);
-                auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(v_data_ptr + i)))), zp0);
+                auto v0 = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+                                            _mm_loadl_epi64(reinterpret_cast<__m128i*>(v_data_ptr + i)))),
+                                        zp0);
                 v_out = _mm256_fmadd_ps(attn_w_vec0, v0, v_out);
                 mm256_uni_storeu_ps(out + dst_offset + i, v_out);
@@ -321,7 +335,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
         weight++;
     }
     return;
-#endif
+# endif
     for (size_t j = 0; j < block_size; j++) {
         dst_offset = 0;
         src_offset = 0;
@@ -337,8 +351,15 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
     }
 }
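// Editor's note (illustrative sketch, not part of the patch): a scalar view of
// the grouped u8 path above. Each group carries a float scale and zero-point
// header, and the weight is pre-multiplied by the scale once per group:
//
//   void acc_value_u8(float* out, float w, const uint8_t* v, size_t S, size_t group) {
//       for (size_t dst = 0; dst < S; dst += group) {
//           float scale, zp;
//           std::memcpy(&scale, v, sizeof(float));   // group header: scale
//           std::memcpy(&zp, v + 4, sizeof(float));  // group header: zero point
//           v += 2 * sizeof(float);
//           for (size_t i = 0; i < group; i++)
//               out[dst + i] += (w * scale) * (static_cast<float>(v[i]) - zp);
//           v += group;                              // advance to the next group record
//       }
//   }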
-template::type = true>
-static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) {
+template ::type = true>
+static void attn_acc_value_block(float* out,
+                                 float* weight,
+                                 uint8_t* v,
+                                 size_t S,
+                                 size_t block_size,
+                                 size_t group_size = 0) {
     size_t src_offset = 0;
     size_t dst_offset = 0;
     const size_t _group_size = group_size ? group_size : S;
@@ -348,7 +369,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
     auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t {
         uint8_t shift = high_half ? 0 : 4;
-        return (uint8_t) ((val >> shift) & 0x000F);
+        return (uint8_t)((val >> shift) & 0x000F);
     };
     for (size_t j = 0; j < block_size; j++) {
         dst_offset = 0;
@@ -356,11 +377,11 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
         while (dst_offset < S) {
             auto v0 = reinterpret_cast<float*>(v + src_offset);
             size_t i = 0;
-#if defined(HAVE_AVX512F)
+# if defined(HAVE_AVX512F)
             auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]);
             auto v_zp = _mm512_set1_ps(v0[1]);
             for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) {
-                auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset));
+                auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset));
                 auto v_i32 = _mm512_cvtepu8_epi32(data);

                 auto v_512_low_half = _mm512_srli_epi32(v_i32, 4);
@@ -385,11 +406,11 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
                 mm512_uni_storeu_ps(out + dst_offset + i, v_out0);
                 mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1);
             }
-#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F)
+# elif defined(HAVE_AVX2) || defined(HAVE_AVX512F)
             auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]);
             auto v256_zp = _mm256_set1_ps(v0[1]);
             for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) {
-                auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset));
+                auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset));
                 auto v_i32 = _mm256_cvtepu8_epi32(data);

                 auto v_256_low_half = _mm256_srli_epi32(v_i32, 4);
@@ -416,9 +437,9 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
                 mm256_uni_storeu_ps(out + dst_offset + i, v_out0);
                 mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1);
             }
-#endif
+# endif
             for (; i < _group_size; i += 2) {
-                uint8_t data = v[i/2 + src_offset + params_offset];
+                uint8_t data = v[i / 2 + src_offset + params_offset];
                 float tmp0 = extract_half_byte(data, static_cast<bool>(i % 2));
                 float tmp1 = extract_half_byte(data, static_cast<bool>((i + 1) % 2));
                 out[dst_offset + i] += weight[j] * (tmp0 - v0[1]) * v0[0];
@@ -431,8 +452,15 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
     }
 }

-template::type = true>
-static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S, size_t block_size, size_t group_size = 0) {
+template ::type = true>
+static void attn_acc_value_block(float* out,
+                                 float* weight,
+                                 uint8_t* v,
+                                 size_t S,
+                                 size_t block_size,
+                                 size_t group_size = 0) {
     size_t src_offset = 0;
     size_t dst_offset = 0;
     const size_t _group_size = group_size ? group_size : S;
@@ -442,7 +470,7 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
     auto extract_half_byte = [](uint8_t val, bool high_half) -> uint8_t {
        uint8_t shift = high_half ? 0 : 4;
-        return (uint8_t) ((val >> shift) & 0x000F);
+        return (uint8_t)((val >> shift) & 0x000F);
     };

     for (size_t j = 0; j < block_size; j++) {
@@ -451,12 +479,12 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
         while (dst_offset < S) {
             auto v0 = reinterpret_cast<float*>(v + src_offset);
             size_t i = 0;
-#if defined(HAVE_AVX512F)
+# if defined(HAVE_AVX512F)
             auto attn_w_vec0 = _mm512_set1_ps(weight[j] * v0[0]);
             for (; i + vec_len_f32_avx512 * 2 <= _group_size; i += vec_len_f32_avx512 * 2) {
-                auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset));
+                auto data = _mm_loadu_si128(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset));
                 auto v_i32 = _mm512_cvtepi8_epi32(data);
-                //cvt to f32
+                // cvt to f32
                 auto v_256_low_half = _mm512_srai_epi32(v_i32, 4);
                 auto v_256_high_half = _mm512_slli_epi32(v_i32, 28);
                 v_256_high_half = _mm512_srai_epi32(v_256_high_half, 28);
@@ -475,10 +503,10 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
                 mm512_uni_storeu_ps(out + dst_offset + i, v_out0);
                 mm512_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx512, v_out1);
             }
-#elif defined(HAVE_AVX2) || defined(HAVE_AVX512F)
+# elif defined(HAVE_AVX2) || defined(HAVE_AVX512F)
             auto v256_attn_w_vec0 = _mm256_set1_ps(weight[j] * v0[0]);
             for (; i + vec_len_f32_avx2 * 2 <= _group_size; i += vec_len_f32_avx2 * 2) {
-                auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i/2 + src_offset + params_offset));
+                auto data = _mm_loadl_epi64(reinterpret_cast<__m128i*>(v + i / 2 + src_offset + params_offset));
                 auto v_i32 = _mm256_cvtepi8_epi32(data);

                 auto v_256_low_half = _mm256_srai_epi32(v_i32, 4);
@@ -500,15 +528,15 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
                 mm256_uni_storeu_ps(out + dst_offset + i, v_out0);
                 mm256_uni_storeu_ps(out + dst_offset + i + vec_len_f32_avx2, v_out1);
             }
-#endif
+# endif
             for (; i < _group_size; i += 2) {
-                uint8_t data = v[i/2 + src_offset + params_offset];
+                uint8_t data = v[i / 2 + src_offset + params_offset];
                 float tmp0 = extract_half_byte(data, static_cast<bool>(i % 2));
                 tmp0 = tmp0 > 8 ? (tmp0 - 16) : tmp0;
                 float tmp1 = extract_half_byte(data, static_cast<bool>((i + 1) % 2));
                tmp1 = tmp1 > 8 ? (tmp1 - 16) : tmp1;
-                out[dst_offset + i] += weight[j] * (tmp0) * v0[0];
-                out[dst_offset + i + 1] += weight[j] * (tmp1) * v0[0];
+                out[dst_offset + i] += weight[j] * (tmp0)*v0[0];
+                out[dst_offset + i + 1] += weight[j] * (tmp1)*v0[0];
             }
             dst_offset += _group_size;
             src_offset += _group_size / sub_byte_multiplyer + params_offset;
@@ -517,9 +545,9 @@ static void attn_acc_value_block(float* out, float* weight, uint8_t* v, size_t S
     }
 }
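// Editor's note (illustrative sketch, not part of the patch): the 4-bit paths
// above keep two quantized values per byte. The u4 variant masks out a nibble;
// the i4 variant additionally sign-extends it, which the srai/slli pair does
// in the vector code. One common scalar decode:
//
//   inline int decode_nibble(uint8_t byte, bool high, bool is_signed) {
//       int v = high ? (byte >> 4) : (byte & 0x0F);  // pick one nibble
//       if (is_signed && v >= 8)
//           v -= 16;  // two's-complement sign extension from 4 bits
//       return v;
//   }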
-template
+template <typename TA, typename TB>
 static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_size, size_t group_size = 0) {
-#if defined(HAVE_AVX512F)
+# if defined(HAVE_AVX512F)
     size_t j = 0;
     for (; j + 4 <= block_size; j += 4) {
         auto vsum0 = _mm512_setzero_ps();
@@ -549,7 +577,7 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz
         c[2] = sum2;
         c[3] = sum3;
         c += 4;
-        b += 4 * n;
+        b += 4 * n;
     }
     for (; j < block_size; j++) {
         auto vsum = _mm512_setzero_ps();
@@ -566,7 +594,7 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz
         *c++ = sum;
     }
     return;
-#elif defined(HAVE_AVX2)
+# elif defined(HAVE_AVX2)
     size_t j = 0;
     for (; j + 4 <= block_size; j += 4) {
         auto vsum0 = _mm256_set1_ps(0.0f);
@@ -600,7 +628,7 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz
         c[2] = sum2;
         c[3] = sum3;
         c += 4;
-        b += 4 * n;
+        b += 4 * n;
     }
     for (; j < block_size; j++) {
         auto vsum = _mm256_set1_ps(0.0f);
@@ -618,7 +646,7 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz
         *c++ = sum;
     }
     return;
-#endif
+# endif
     for (size_t j = 0; j < block_size; j++) {
         float sum = 0;
         for (size_t i = 0; i < n; i++) {
@@ -629,17 +657,17 @@ static void dot_product_block(TA* a, TB* b, float* c, size_t n, size_t block_siz
     }
 }

-template
+template <typename TA>
 static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t block_size, size_t group_size = 0) {
     // The layout for per token per head:
-    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-    // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
+    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
+    // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
     size_t src_offset = 0;
     size_t dst_offset = 0;
     const size_t _group_size = group_size ? group_size : n;
     const size_t params_offset = sizeof(float) * 2;
     const size_t src_stride = n / _group_size * (_group_size + params_offset);
-#if defined(HAVE_AVX512F)
+# if defined(HAVE_AVX512F)
     size_t j = 0;
     for (; j + 4 <= block_size; j += 4) {
         src_offset = 0;
@@ -719,7 +747,9 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
             uint8_t* b_data_ptr = b + src_offset + params_offset;
             for (; i + vec_len_f32_avx512 <= _group_size; i += vec_len_f32_avx512) {
                 auto va = mm512_uni_loadu_ps(a + dst_offset + i);
-                auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp);
+                auto vb = _mm512_sub_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(
+                                            _mm_loadu_si128(reinterpret_cast<__m128i*>(b_data_ptr + i)))),
+                                        v_zp);
                 vsum = _mm512_fmadd_ps(va, vb, vsum);
             }
             float group_sum = _mm512_reduce_add_ps(vsum);
@@ -734,7 +764,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         *c++ = sum;
     }
     return;
-#elif defined(HAVE_AVX2)
+# elif defined(HAVE_AVX2)
     size_t j = 0;
     for (; j + 4 <= block_size; j += 4) {
         src_offset = 0;
@@ -788,7 +818,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         float group_sum3 = _mm256_cvtss_f32(vsum3);
         for (; i < _group_size; i++) {
             group_sum0 += a[dst_offset + i] * (b[i] - b0[1]);
-            group_sum1 += a[dst_offset + i] * (b[i +src_stride] - b1[1]);
+            group_sum1 += a[dst_offset + i] * (b[i + src_stride] - b1[1]);
             group_sum2 += a[dst_offset + i] * (b[i + 2 * src_stride] - b2[1]);
             group_sum3 += a[dst_offset + i] * (b[i + 3 * src_stride] - b3[1]);
         }
@@ -804,7 +834,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         c[2] = sum2;
         c[3] = sum3;
         c += 4;
-        b += 4 * src_stride;
+        b += 4 * src_stride;
     }
     for (; j < block_size; j++) {
         src_offset = 0;
@@ -818,7 +848,9 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
             uint8_t* b_data_ptr = b + src_offset + params_offset;
             for (; i + vec_len_f32_avx2 <= _group_size; i += vec_len_f32_avx2) {
                 auto va = mm256_uni_loadu_ps(a + dst_offset + i);
-                auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))), v_zp);
+                auto vb = _mm256_sub_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
+                                            _mm_loadl_epi64(reinterpret_cast<__m128i*>(b_data_ptr + i)))),
+                                        v_zp);
                 vsum = _mm256_fmadd_ps(va, vb, vsum);
             }
             hsum(vsum);
@@ -834,7 +866,7 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
         *c++ = sum;
     }
     return;
-#endif
+# endif
     for (size_t j = 0; j < block_size; j++) {
         float sum = 0;
         dst_offset = 0;
@@ -854,11 +886,11 @@ static void dot_product_block(TA* a, uint8_t* b, float* c, size_t n, size_t bloc
     }
 }
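// Editor's note (illustrative sketch, not part of the patch): a scalar view of
// the grouped u8 dot product above; values are dequantized on the fly and the
// per-group partial sum is scaled once when the group ends:
//
//   float dot_u8(const float* a, const uint8_t* b, size_t n, size_t group) {
//       float sum = 0.0f;
//       for (size_t dst = 0; dst < n; dst += group) {
//           float scale, zp;
//           std::memcpy(&scale, b, sizeof(float));
//           std::memcpy(&zp, b + sizeof(float), sizeof(float));
//           b += 2 * sizeof(float);
//           float gsum = 0.0f;
//           for (size_t i = 0; i < group; i++)
//               gsum += a[dst + i] * (static_cast<float>(b[i]) - zp);
//           sum += gsum * scale;  // one multiply by the group scale
//           b += group;
//       }
//       return sum;
//   }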
-template
+template <typename T>
 static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) {
     size_t i = 0;
-#if defined(HAVE_AVX512F)
-    for (; i + vec_len_f32_avx512 <= S; i+= vec_len_f32_avx512) {
+# if defined(HAVE_AVX512F)
+    for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) {
         auto* src = temp + i;
         auto result_vec_fp32 = _mm512_setzero_ps();
         for (size_t m = 0; m < M; m++) {
@@ -869,7 +901,7 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
         // save to bf16
         mm512_uni_storeu_ps(dst + i, result_vec_fp32);
     }
-#elif defined(HAVE_AVX2)
+# elif defined(HAVE_AVX2)
     for (; i + vec_len_f32_avx2 <= S; i += vec_len_f32_avx2) {
         auto* src = temp + i;
         auto result_vec_fp32 = _mm256_set1_ps(0.0f);
@@ -880,7 +912,7 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
         }
         mm256_uni_storeu_ps(dst + i, result_vec_fp32);
     }
-#endif
+# endif
     for (; i < S; i++) {
         auto* src = temp + i;
         float sum = 0.0f;
@@ -894,8 +926,17 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str
 }

 // N must be multiple of 16
-template::type = true>
-void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::type = true>
+void transpose_16NxK(TDST* dst,
+                     void* src,
+                     TDST* tmp,
+                     size_t N,
+                     size_t K,
+                     size_t dst_stride,
+                     size_t src_stride,
+                     size_t group_size = 0) {
     size_t k = 0;
     auto* src_ptr = reinterpret_cast::value_type*>(src);
     for (; k + 16 <= K; k += 16) {
@@ -912,7 +953,7 @@ void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t
         }
     }
 }
-#if defined(HAVE_AVX512F)
+# if defined(HAVE_AVX512F)
 template ::type = true>
-void transpose_16NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) {
+template ::type = true>
+void transpose_16NxK(TDST* dst,
+                     void* src,
+                     TDST* tmp,
+                     size_t N,
+                     size_t K,
+                     size_t dst_stride,
+                     size_t src_stride,
+                     size_t group_size = 0) {
     // The layout for per token per head:
-    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)|
-    // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
+    // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized
+    // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float)
     auto s = reinterpret_cast::value_type*>(src);
     auto t = tmp;
     // if group_size not set, the whole row is used as a group
    size_t _group_size = group_size ? 
group_size : K; const size_t src_stride = K / _group_size * (_group_size + params_offset); - for (size_t n = 0; n < N; n ++) { + for (size_t n = 0; n < N; n++) { size_t group_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { @@ -998,17 +1060,19 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) } } -template::type = true> +template ::type = true> void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float) * 2; const size_t _group_size = group_size ? group_size : K; const size_t sub_byte_mulitplier = 2; - for (size_t n = 0; n < N; n ++) { + for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { @@ -1022,17 +1086,19 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) } } -template::type = true> +template ::type = true> void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) { // The layout for per token per head: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The quantized feature will start from 8bytes=sizeof(float)+sizeof(float) auto s = src; const size_t params_offset = sizeof(float); const size_t _group_size = group_size ? 
group_size : K; const size_t sub_byte_mulitplier = 2; - for (size_t n = 0; n < N; n ++) { + for (size_t n = 0; n < N; n++) { size_t src_offset = 0; size_t dst_offset = 0; while (dst_offset < K) { @@ -1046,18 +1112,24 @@ void dequant(TDST* dst, uint8_t* src, size_t N, size_t K, size_t group_size = 0) } } -#if defined(HAVE_AVX512F) -template::value || std::is_same::value), bool>::type> +# if defined(HAVE_AVX512F) +template ::value || std::is_same::value), bool>::type> static void pack_32x32_kernel(T* dst, T* src, size_t dst_stride, size_t src_stride) { static const uint64_t idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; auto midx = _mm512_loadu_si512(idx); for (size_t i = 0; i < 16; i++) { - auto a = _mm512_loadu_si512(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit + auto a = _mm512_loadu_si512(src); // [a1 a2 a3 a4 | a5 a6 a7 a8] total 512-bits in 8 64bits unit auto b = _mm512_loadu_si512(src + src_stride); // [b1 b2 b3 b4 | b5 b6 b7 b8] total 512-bits a = _mm512_permutexvar_epi64(midx, a); // [a1 a5 | a2 a6 | a3 a7 | a4 a8] b = _mm512_permutexvar_epi64(midx, b); // [b1 b5 | b2 b6 | b3 b7 | b4 b8] - auto B0 = _mm512_unpacklo_epi16(a, b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits - auto B1 = _mm512_unpackhi_epi16(a, b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits + auto B0 = _mm512_unpacklo_epi16( + a, + b); // [ a1&b1 a2&b2 a3&b3 a4&b4] for each 128-bits lane, interleave word in low 64 bits + auto B1 = _mm512_unpackhi_epi16( + a, + b); // [ a5&b5 a6&b6 a7&b7 a8&b8] for each 128-bits lane, interleave word in high 64 bits _mm512_storeu_si512(dst, B0); _mm512_storeu_si512(dst + 32, B1); src += 2 * src_stride; @@ -1065,17 +1137,20 @@ static void pack_32x32_kernel(T* dst, T* src, size_t dst_stride, size_t src_stri } } -template::value || std::is_same::value), bool>::type> +template ::value || std::is_same::value), bool>::type> static void pack_32x16_kernel(T* dst, T* src, size_t dst_stride, size_t src_stride) { static const uint64_t idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; auto midx = _mm512_loadu_si512(idx); for (size_t i = 0; i < 16; i++) { - auto x = _mm256_loadu_si256(reinterpret_cast<__m256i*>(src)); // [a1 a2 a3 a4] total 256-bits in 4 64bits unit + auto x = + _mm256_loadu_si256(reinterpret_cast<__m256i*>(src)); // [a1 a2 a3 a4] total 256-bits in 4 64bits unit auto y = _mm256_loadu_si256(reinterpret_cast<__m256i*>(src + src_stride)); // [b1 b2 b3 b4] total 256-bits auto a = _mm512_castsi256_si512(x); auto b = _mm512_castsi256_si512(y); - a = _mm512_permutexvar_epi64(midx, a); // [a1 x | a2 x | a3 x | a4 x] - b = _mm512_permutexvar_epi64(midx, b); // [b1 x | b2 x | b3 x | b4 x] + a = _mm512_permutexvar_epi64(midx, a); // [a1 x | a2 x | a3 x | a4 x] + b = _mm512_permutexvar_epi64(midx, b); // [b1 x | b2 x | b3 x | b4 x] auto B0 = _mm512_unpacklo_epi16(a, b); _mm512_storeu_si512(dst, B0); src += 2 * src_stride; @@ -1083,18 +1158,20 @@ static void pack_32x16_kernel(T* dst, T* src, size_t dst_stride, size_t src_stri } } -template::value || std::is_same::value), bool>::type> +template ::value || std::is_same::value), bool>::type> static void pack_32xK_kernel(T* dst, T* src, size_t dst_stride, size_t src_stride, size_t K) { static const uint64_t idx[8] = {0, 4, 1, 5, 2, 6, 3, 7}; auto midx = _mm512_loadu_si512(idx); __mmask16 mask = (1 << K) - 1; for (size_t i = 0; i < K; i++) { - auto x = _mm256_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4] total 256-bits in 4 64bits unit - auto y = 
_mm256_maskz_loadu_epi16(mask, src + src_stride); // [b1 b2 b3 b4] total 256-bits + auto x = _mm256_maskz_loadu_epi16(mask, src); // [a1 a2 a3 a4] total 256-bits in 4 64bits unit + auto y = _mm256_maskz_loadu_epi16(mask, src + src_stride); // [b1 b2 b3 b4] total 256-bits auto a = _mm512_castsi256_si512(x); auto b = _mm512_castsi256_si512(y); - a = _mm512_permutexvar_epi64(midx, a); // [a1 x | a2 x | a3 x | a4 x] - b = _mm512_permutexvar_epi64(midx, b); // [b1 x | b2 x | b3 x | b4 x] + a = _mm512_permutexvar_epi64(midx, a); // [a1 x | a2 x | a3 x | a4 x] + b = _mm512_permutexvar_epi64(midx, b); // [b1 x | b2 x | b3 x | b4 x] auto B0 = _mm512_unpacklo_epi16(a, b); _mm512_mask_storeu_epi32(dst, mask, B0); src += 2 * src_stride; @@ -1239,8 +1316,17 @@ static void pack_32NxK(TDST* dst, } # endif -template::value == ov::element::f32, bool>::type = true> -static void pack_32NxK(TDST* dst, void* src, TDST* tmp, size_t N, size_t K, size_t dst_stride, size_t src_stride, size_t group_size = 0) { +template ::value == ov::element::f32, bool>::type = true> +static void pack_32NxK(TDST* dst, + void* src, + TDST* tmp, + size_t N, + size_t K, + size_t dst_stride, + size_t src_stride, + size_t group_size = 0) { // never called OPENVINO_THROW("pack_32NxK: should not be called."); } @@ -1260,10 +1346,10 @@ struct MHAHelper { size_t _key_group_size = 0; size_t _value_group_size = 0; - PlainTensor _weight; // [nthr, H, 32, rnd_up(kv_len, block_size)], shared by first and second loop along bh - PlainTensor _output; // [nthr, 32, H, S], shared by first and second loop along bh - PlainTensor _qk_scratch_a; // [nthr, scratch_a_size] - PlainTensor _qk_scratch_b; // [B, rnd_up(kv_len, block_size), Hk, scratch_b_size] + PlainTensor _weight; // [nthr, H, 32, rnd_up(kv_len, block_size)], shared by first and second loop along bh + PlainTensor _output; // [nthr, 32, H, S], shared by first and second loop along bh + PlainTensor _qk_scratch_a; // [nthr, scratch_a_size] + PlainTensor _qk_scratch_b; // [B, rnd_up(kv_len, block_size), Hk, scratch_b_size] PlainTensor _wv_scratch_a; PlainTensor _wv_scratch_b; PlainTensor _alibi_lookup; @@ -1288,12 +1374,22 @@ struct MHAHelper { _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); } - explicit MHAHelper(size_t key_group_size, size_t value_group_size) : _key_group_size(key_group_size), _value_group_size(value_group_size) { + explicit MHAHelper(size_t key_group_size, size_t value_group_size) + : _key_group_size(key_group_size), + _value_group_size(value_group_size) { _weight.resize({size_t{1}, size_t{1}, size_t{1}, size_t{1}}); } - void init(size_t H, size_t S, size_t SV, size_t Hk, size_t h_each_group_len, size_t block_size, size_t sliding_window, - float d_scale, size_t kv_len, bool init_alibi_lookup) { + void init(size_t H, + size_t S, + size_t SV, + size_t Hk, + size_t h_each_group_len, + size_t block_size, + size_t sliding_window, + float d_scale, + size_t kv_len, + bool init_alibi_lookup) { // query shape: [B, H, L, S] // present_key shape: [block, H, 32, S] // Q*K': [M1, S] * [M2, S]' @@ -1335,25 +1431,27 @@ struct MHAHelper { _weight.stride(2), false, in_type); - _wv_gemm[i] = std::make_shared(i + 1, - _SV, - _block_size, - // if it's bf16, the stride needs double due to reuse float buffer - (in_type == ov::element::Type_t::f32 ? 
1 : 2) * _weight.stride(2), - _SV, - _output.stride(1), - false, - in_type); - _wv_gemm_acc[i] = std::make_shared(i + 1, - _SV, - _block_size, - // if it's bf16, the stride needs double due to reuse float buffer - (in_type == ov::element::Type_t::f32 ? 1 : 2) * _weight.stride(2), - _SV, - _output.stride(1), - false, - in_type, - true); + _wv_gemm[i] = + std::make_shared(i + 1, + _SV, + _block_size, + // if it's bf16, the stride needs double due to reuse float buffer + (in_type == ov::element::Type_t::f32 ? 1 : 2) * _weight.stride(2), + _SV, + _output.stride(1), + false, + in_type); + _wv_gemm_acc[i] = + std::make_shared(i + 1, + _SV, + _block_size, + // if it's bf16, the stride needs double due to reuse float buffer + (in_type == ov::element::Type_t::f32 ? 1 : 2) * _weight.stride(2), + _SV, + _output.stride(1), + false, + in_type, + true); } // wsp is used to compute beta when K is blocked @@ -1361,8 +1459,10 @@ struct MHAHelper { _wsp.resize(_nthr * _wsp_size_per_thread); // allocate scratch a/b, notice get_scratch_a_size/get_scratch_b_size returns in bytes - _qk_scratch_a.resize({_nthr, _qk_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); - _wv_scratch_a.resize({_nthr, _wv_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); + _qk_scratch_a.resize( + {_nthr, _qk_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); + _wv_scratch_a.resize( + {_nthr, _wv_gemm[_block_size - 1]->get_scratch_a_size() / sizeof(DATA_TYPE)}); if ((S % 32 == 0) && (block_size % 16 == 0) && (S <= 32 * 6)) { if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::amx_bf16) && @@ -1378,14 +1478,16 @@ struct MHAHelper { } } if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16) && !_gemv) { - _gemv = std::make_shared(static_cast(S), static_cast(block_size), _fastpath_valid_prec); + _gemv = std::make_shared(static_cast(S), + static_cast(block_size), + _fastpath_valid_prec); } } if (init_alibi_lookup && (!_alibi_lookup || _alibi_lookup.m_dims[0] < kv_len)) { _alibi_lookup.resize({kv_len * 2}); for (size_t i = 0; i < _alibi_lookup.m_dims[0]; i++) - _alibi_lookup.ptr()[i] = - static_cast((_alibi_lookup.m_dims[0] - 1 - i)); + _alibi_lookup.ptr()[i] = -static_cast((_alibi_lookup.m_dims[0] - 1 - i)); } } @@ -1421,9 +1523,21 @@ struct MHAHelper { // output_emb: [L, H * S] // qk_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size] // wv_scratch_b: [rnd_up(kv_len, block_size), Hk, scratch_b_size] - void exec_kernel_multiple(const PlainTensor& query, const PlainTensor& present_value, const PlainTensor& output_emb, - const PlainTensor& qk_scratch_b, const PlainTensor& wv_scratch_b, const int32_t* block_table, size_t ithr, size_t q_blk, - size_t hq_beg, size_t hq_end, size_t hk, size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) { + void exec_kernel_multiple(const PlainTensor& query, + const PlainTensor& present_value, + const PlainTensor& output_emb, + const PlainTensor& qk_scratch_b, + const PlainTensor& wv_scratch_b, + const int32_t* block_table, + size_t ithr, + size_t q_blk, + size_t hq_beg, + size_t hq_end, + size_t hk, + size_t q_len, + size_t cur_kv_len, + const PlainTensor& alibi_slopes, + float* score_output) { auto q_start = q_blk * _block_size; auto q_end = std::min(q_start + _block_size, q_len); auto q_cnt = q_end - q_start; @@ -1496,13 +1610,16 @@ struct MHAHelper { alibi_slope); } if (score_output) { - cvt_copy(score_output + h * rnd_up(cur_kv_len, 16), reinterpret_cast(score), cur_kv_len); + cvt_copy(score_output 
+ h * rnd_up(cur_kv_len, 16), + reinterpret_cast(score), + cur_kv_len); } } // reuse float buffer, need to use float to compute offset auto* w_ptr = reinterpret_cast(_weight.ptr(ithr, h, 0, 0)); - float* fp32_out_ptr = q_is_xf16 ? _output.ptr(ithr, 0, h, 0) : output_emb.ptr(q_start, h * _SV); + float* fp32_out_ptr = + q_is_xf16 ? _output.ptr(ithr, 0, h, 0) : output_emb.ptr(q_start, h * _SV); // for each weight block, loop through all value block for (size_t v_blk = 0; v_blk < cur_kv_len_blocks; v_blk++) { @@ -1520,12 +1637,13 @@ struct MHAHelper { _wsp.data() + ithr * _wsp_size_per_thread, _wv_scratch_a ? _wv_scratch_a.ptr(ithr, 0) : nullptr); } else { - _wv_gemm_acc[q_cnt - 1]->executeGemm(q_cnt < _block_size, - w_ptr + v_blk * _block_size, - v_ptr, - fp32_out_ptr, - _wsp.data() + ithr * _wsp_size_per_thread, - _wv_scratch_a ? _wv_scratch_a.ptr(ithr, 0) : nullptr); + _wv_gemm_acc[q_cnt - 1]->executeGemm( + q_cnt < _block_size, + w_ptr + v_blk * _block_size, + v_ptr, + fp32_out_ptr, + _wsp.data() + ithr * _wsp_size_per_thread, + _wv_scratch_a ? _wv_scratch_a.ptr(ithr, 0) : nullptr); } } if (q_is_xf16) { @@ -1548,17 +1666,28 @@ struct MHAHelper { // output_emb: [L, H * S] // weight: [nthr, H, 32, rnd_up(kv_len, block_size)] // output: [nthr, 32, H, S] - void exec_kernel_one_bh(const PlainTensor& query, const PlainTensor& present_key, const PlainTensor& present_value, const PlainTensor& output_emb, - const int32_t* block_table, size_t ithr, size_t hq_beg, size_t hq_end, size_t hk, - size_t q_len, size_t cur_kv_len, const PlainTensor& alibi_slopes, float* score_output) { + void exec_kernel_one_bh(const PlainTensor& query, + const PlainTensor& present_key, + const PlainTensor& present_value, + const PlainTensor& output_emb, + const int32_t* block_table, + size_t ithr, + size_t hq_beg, + size_t hq_end, + size_t hk, + size_t q_len, + size_t cur_kv_len, + const PlainTensor& alibi_slopes, + float* score_output) { if (one_of(_fastpath_valid_prec, ov::element::bf16, ov::element::f16)) { _gemv->tile_config(); for (size_t pk = 0, i = 0; pk < cur_kv_len; pk += _block_size, i++) { auto block_number = block_table[i]; for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - (*_gemv)(query.ptr(h, pq), present_key.ptr(block_number, hk), - _weight.ptr(ithr, h, pq) + pk); + (*_gemv)(query.ptr(h, pq), + present_key.ptr(block_number, hk), + _weight.ptr(ithr, h, pq) + pk); } } } @@ -1568,8 +1697,12 @@ struct MHAHelper { auto block_number = block_table[i]; for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - dot_product_block(query.ptr(h, pq), present_key.ptr(block_number, hk), - _weight.ptr(ithr, h, pq) + pk, _S, std::min(_block_size, cur_kv_len - pk), _key_group_size); + dot_product_block(query.ptr(h, pq), + present_key.ptr(block_number, hk), + _weight.ptr(ithr, h, pq) + pk, + _S, + std::min(_block_size, cur_kv_len - pk), + _key_group_size); } } } @@ -1597,7 +1730,9 @@ struct MHAHelper { ov::element::f32, alibi_slope); if (score_output) { - memcpy(score_output + h * rnd_up(cur_kv_len, 16), _weight.ptr(ithr, h, pq), cur_kv_len * sizeof(float)); + memcpy(score_output + h * rnd_up(cur_kv_len, 16), + _weight.ptr(ithr, h, pq), + cur_kv_len * sizeof(float)); } } } @@ -1610,26 +1745,28 @@ struct MHAHelper { for (size_t h = hq_beg; h < hq_end; h++) { if (present_value.get_precision() == ov::element::u4) { auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = (block_number * present_value.m_strides[0] + hk * 
present_value.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / + sub_byte_multiplyer; auto* v_ptr = present_value.m_ptr.get() + v_stride; - attn_acc_value_block( - _output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); + attn_acc_value_block(_output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); } else if (present_value.get_precision() == ov::element::i4) { auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / + sub_byte_multiplyer; auto* v_ptr = present_value.m_ptr.get() + v_stride; - attn_acc_value_block( - _output.ptr(ithr, pq, h), - _weight.ptr(ithr, h, pq) + pv, - v_ptr, - _SV, - std::min(_block_size, cur_kv_len - pv), - _value_group_size); + attn_acc_value_block(_output.ptr(ithr, pq, h), + _weight.ptr(ithr, h, pq) + pv, + v_ptr, + _SV, + std::min(_block_size, cur_kv_len - pv), + _value_group_size); } else { attn_acc_value_block::value>( _output.ptr(ithr, pq, h), @@ -1648,9 +1785,9 @@ struct MHAHelper { cvt_copy(output_emb.ptr(pq, h * _SV), _output.ptr(ithr, pq, h), _SV); } - // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small batch tokens. - // It will assume NO mixture execution of first and second token. - // all tensors such as query... have batch dimension which is DIFFERENT from above + // compute one token, loop along batch, head dimensions and kv_len, it's special for very long kv_len with small + // batch tokens. It will assume NO mixture execution of first and second token. all tensors such as query... 
have + // batch dimension which is DIFFERENT from above // query: [B, H, L, S] // present_*: [block_number, H, 32, S] // output_emb: [B, L, H * S] @@ -1689,17 +1826,18 @@ struct MHAHelper { // for bigger batch skip the test to save the cost prefer_static_loop = false; } - auto get_h_params = [] (bool loop_hk, size_t hx, size_t h_each_group_len, size_t& hq_beg, size_t& hq_end, size_t& hk) { - if (loop_hk) { - hk = hx; - hq_beg = hk * h_each_group_len; - hq_end = (hk + 1) * h_each_group_len; - } else { - hq_beg = hx; - hq_end = hx + 1; - hk = hx / h_each_group_len; - } - }; + auto get_h_params = + [](bool loop_hk, size_t hx, size_t h_each_group_len, size_t& hq_beg, size_t& hq_end, size_t& hk) { + if (loop_hk) { + hk = hx; + hq_beg = hk * h_each_group_len; + hq_end = (hk + 1) * h_each_group_len; + } else { + hq_beg = hx; + hq_end = hx + 1; + hk = hx / h_each_group_len; + } + }; auto loop_qk = [&](size_t b, size_t pk_in_blocks, size_t hx) { auto context_len = static_cast(past_lens.ptr()[b]) + 1; size_t hk, hq_beg, hq_end; @@ -1713,16 +1851,21 @@ struct MHAHelper { _gemv->tile_config(); for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - (*_gemv)(query.ptr(b, h, pq), present_key.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk); + (*_gemv)(query.ptr(b, h, pq), + present_key.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk); } } _gemv->tile_release(); } else { for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = hq_beg; h < hq_end; h++) { - dot_product_block(query.ptr(b, h, pq), present_key.ptr(block_number, hk), - _weight_bhl.ptr(b, h, pq) + pk, _S, std::min(_block_size, context_len - pk), _key_group_size); + dot_product_block(query.ptr(b, h, pq), + present_key.ptr(block_number, hk), + _weight_bhl.ptr(b, h, pq) + pk, + _S, + std::min(_block_size, context_len - pk), + _key_group_size); } } } @@ -1794,7 +1937,9 @@ struct MHAHelper { for (size_t h = hq_beg; h < hq_end; h++) { if (present_value.get_precision() == ov::element::u4) { auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / + sub_byte_multiplyer; auto* v_ptr = present_value.m_ptr.get() + v_stride; attn_acc_value_block( _output_bhl.ptr(ithr, b, pq, h), @@ -1805,7 +1950,9 @@ struct MHAHelper { _value_group_size); } else if (present_value.get_precision() == ov::element::i4) { auto sub_byte_multiplyer = 8 / present_value.get_precision().bitwidth(); - size_t v_stride = (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * present_value.m_strides[0] + hk * present_value.m_strides[1]) / + sub_byte_multiplyer; auto* v_ptr = present_value.m_ptr.get() + v_stride; attn_acc_value_block( _output_bhl.ptr(ithr, b, pq, h), @@ -1847,26 +1994,29 @@ template & _helper; struct AttnWorkItem { - int32_t batch_in_reorder; // which batch in reorder buffer will be used - int32_t batch_in_seq; // batch idx in sequence - int32_t q_len; // current sequence length, 1 for second token, 2+ for first token - int32_t q_block_id; // block id in this seq, valid at first token + int32_t batch_in_reorder; // which batch in reorder buffer will be used + int32_t batch_in_seq; // batch idx in sequence + int32_t q_len; // current sequence length, 1 for second token, 2+ for first token + int32_t q_block_id; 
// block id in this seq, valid at first token }; struct ReorderWorkItem { - int32_t batch_in_seq; // batch idx in sequence - int32_t batch_in_reorder; // which batch in reorder buffer will be used - int32_t kv_block_id; // block id in this kv cache seq + int32_t batch_in_seq; // batch idx in sequence + int32_t batch_in_reorder; // which batch in reorder buffer will be used + int32_t kv_block_id; // block id in this kv cache seq }; struct WorkItems { private: std::vector attn_items; std::vector reorder_items; - int32_t max_kv_len_in_reorder; // max kv len between first tokens + int32_t max_kv_len_in_reorder; // max kv len between first tokens int32_t max_batch_in_reorder; int32_t total_kv_len; public: - void reset(const PlainTensor& query, const PlainTensor& past_lens, const PlainTensor& subsequence_begins, size_t block_size) { + void reset(const PlainTensor& query, + const PlainTensor& past_lens, + const PlainTensor& subsequence_begins, + size_t block_size) { attn_items.clear(); reorder_items.clear(); max_kv_len_in_reorder = 0; @@ -1879,21 +2029,19 @@ struct MHA { auto kv_len = past_lens.ptr()[i] + q_len; auto kv_len_in_block = static_cast(div_up(kv_len, block_size)); if (q_len == 1) { - attn_items.emplace_back(AttnWorkItem{ - 0, // batch_in_reorder - i, // batch_in_seq - 1ull, // q_len - // kv_len in blocks, used in the sort function - kv_len_in_block - 1 - }); + attn_items.emplace_back(AttnWorkItem{0, // batch_in_reorder + i, // batch_in_seq + 1ull, // q_len + // kv_len in blocks, used in the sort function + kv_len_in_block - 1}); } else { auto reorder_sub_work_count = kv_len_in_block; max_kv_len_in_reorder = std::max(max_kv_len_in_reorder, kv_len); for (int32_t block_id = 0; block_id < reorder_sub_work_count; block_id++) { reorder_items.emplace_back(ReorderWorkItem{ - i, // batch_in_seq - max_batch_in_reorder, // batch_in_reorder - block_id // kv_block_id + i, // batch_in_seq + max_batch_in_reorder, // batch_in_reorder + block_id // kv_block_id }); } @@ -1901,17 +2049,18 @@ struct MHA { auto attn_sub_work_count = static_cast(div_up(q_len, block_size)); for (int32_t block_id = 0; block_id < attn_sub_work_count; block_id++) { attn_items.emplace_back(AttnWorkItem{ - max_batch_in_reorder, // batch_in_reorder - i, // batch_in_seq - q_len, // q_len - block_id // q_block_id + max_batch_in_reorder, // batch_in_reorder + i, // batch_in_seq + q_len, // q_len + block_id // q_block_id }); } max_batch_in_reorder++; } total_kv_len += kv_len; } - // std::sort(attn_items.begin(), attn_items.end(), [] (const AttnWorkItem& left, const AttnWorkItem& right) { + // std::sort(attn_items.begin(), attn_items.end(), [] (const AttnWorkItem& left, const AttnWorkItem& right) + // { // // kv block number which will be acessed later // auto left_kv_blocks = left.q_block_id; // auto right_kv_blocks = right.q_block_id; @@ -1965,7 +2114,8 @@ struct MHA { auto reorder_work_count = _workitems.reorder_work_size(); // buffer for transpose and repack - _helper.init_reorder_buffers(_workitems.get_reorder_max_batch_size(), div_up(_workitems.get_reorder_max_kv_len(), _helper._block_size)); + _helper.init_reorder_buffers(_workitems.get_reorder_max_batch_size(), + div_up(_workitems.get_reorder_max_kv_len(), _helper._block_size)); // packed k, v parallel_for2d_dynamic(reorder_work_count, Hk, [&](size_t w, size_t hk) { @@ -1973,7 +2123,8 @@ struct MHA { const auto batch_in_seq = item.batch_in_seq; const auto batch_in_reorder = item.batch_in_reorder; const auto kv_block = item.kv_block_id; - auto block_number = 
block_indices.ptr()[block_indices_begins.ptr()[batch_in_seq] + kv_block]; + auto block_number = + block_indices.ptr()[block_indices_begins.ptr()[batch_in_seq] + kv_block]; if (block_number < 0) return; @@ -1981,46 +2132,55 @@ struct MHA { auto* k_ptr = k_cache.ptr(block_number, hk); auto* v_ptr = v_cache.ptr(block_number, hk); - transpose_16NxK::value>(_helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - k_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._S, _helper._block_size, _helper._S, _helper._key_group_size); + transpose_16NxK::value>( + _helper._qk_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + k_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._S, + _helper._block_size, + _helper._S, + _helper._key_group_size); if (q_is_xf16) { if (v_cache.get_precision() == ov::element::u4) { auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; auto* v_ptr = v_cache.m_ptr.get() + v_stride; - pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, - _helper._value_group_size); + pack_32NxK( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); } else if (v_cache.get_precision() == ov::element::i4) { auto sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); - size_t v_stride = (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; + size_t v_stride = + (block_number * v_cache.m_strides[0] + hk * v_cache.m_strides[1]) / sub_byte_multiplyer; auto* v_ptr = v_cache.m_ptr.get() + v_stride; - pack_32NxK(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, - _helper._value_group_size); + pack_32NxK( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); } else { - pack_32NxK::value>(_helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), - v_ptr, - _helper._output.template ptr(ithr), - _helper._block_size, - _helper._SV, - rnd_up(_helper._SV, _helper._block_size), - _helper._SV, - _helper._value_group_size); + pack_32NxK::value>( + _helper._wv_scratch_b.template ptr(batch_in_reorder, kv_block, hk), + v_ptr, + _helper._output.template ptr(ithr), + _helper._block_size, + _helper._SV, + rnd_up(_helper._SV, _helper._block_size), + _helper._SV, + _helper._value_group_size); } } else { // need to decompress @@ -2059,10 +2219,13 @@ struct MHA { } }); - // loop along HK dimension: if mixed first/second token and elements count is enough, loop HK to reuse KV in the CPU cache + // loop along HK dimension: if mixed first/second token and elements count is enough, loop HK to reuse KV in the + // CPU cache // else if elements count is small, prefer to loop H to 
get more work to avoid thread imbalance - bool loop_hk = _workitems.get_reorder_max_batch_size() == past_lens.m_dims[0] || // if only first token, loop H - attn_work_count * Hk <= 2 * _helper._nthr ? false : true; // or less than 2 work items per thread, loop H + bool loop_hk = _workitems.get_reorder_max_batch_size() == past_lens.m_dims[0] || // if only first token, loop H + attn_work_count * Hk <= 2 * _helper._nthr + ? false + : true; // or less than 2 work items per thread, loop H parallel_for2d_dynamic(attn_work_count, loop_hk ? Hk : _helper._H, [&](size_t w, size_t hx) { size_t hk, hq_beg, hq_end; @@ -2090,16 +2253,26 @@ struct MHA { score_output = _helper._score_output.template ptr() + score_offset * _helper._H; } - _helper.exec_kernel_one_bh(q.slice(0, batch_in_token, batch_in_token), k_cache, v_cache, + _helper.exec_kernel_one_bh( + q.slice(0, batch_in_token, batch_in_token), + k_cache, + v_cache, output_emb.slice(0, batch_in_token, batch_in_token), block_indices.ptr() + block_indices_begins.ptr()[batch_in_seq], - ithr, hq_beg, hq_end, hk, 1ul, cur_kv_len, alibi_slopes, + ithr, + hq_beg, + hq_end, + hk, + 1ul, + cur_kv_len, + alibi_slopes, score_output); } else { const auto batch_in_reorder = item.batch_in_reorder; const auto q_blk = item.q_block_id; const auto q_cnt = std::min(_helper._block_size, q_len - q_blk * _helper._block_size); - const auto cur_kv_len = static_cast(past_lens.ptr()[batch_in_seq]) + q_blk * _helper._block_size + q_cnt; + const auto cur_kv_len = + static_cast(past_lens.ptr()[batch_in_seq]) + q_blk * _helper._block_size + q_cnt; float* score_output = nullptr; if (output_score) { // last block @@ -2112,9 +2285,11 @@ struct MHA { PlainTensor sub_query; sub_query.resize({q_len, _helper._H, _helper._S}, q.ptr(batch_in_token)); sub_query = sub_query.permute({1, 0, 2}); - _helper.exec_kernel_multiple(sub_query, + _helper.exec_kernel_multiple( + sub_query, v_cache, - output_emb.slice(0, batch_in_token, batch_in_token + q_len).reshape({q_len, _helper._H * _helper._SV}), + output_emb.slice(0, batch_in_token, batch_in_token + q_len) + .reshape({q_len, _helper._H * _helper._SV}), _helper._qk_scratch_b.slice(0, batch_in_reorder, batch_in_reorder), _helper._wv_scratch_b.slice(0, batch_in_reorder, batch_in_reorder), block_indices.ptr() + block_indices_begins.ptr()[batch_in_seq], @@ -2131,7 +2306,8 @@ struct MHA { }); if (output_score) { parallel_for2d_dynamic(past_lens.m_dims[0], 1, [&](size_t b, size_t pq) { - auto seq_len = static_cast(subsequence_begins.ptr()[b + 1] - subsequence_begins.ptr()[b]); + auto seq_len = static_cast(subsequence_begins.ptr()[b + 1] - + subsequence_begins.ptr()[b]); auto cur_kv_len = static_cast(past_lens.ptr()[b]) + seq_len; auto src_offset = _helper._score_offsets_aligned.template ptr()[b]; auto* src = _helper._score_output.template ptr() + src_offset * _helper._H; @@ -2162,11 +2338,29 @@ struct MHA { auto nthr = static_cast(parallel_get_max_threads()); if (past_lens.m_dims[0] >= nthr || _workitems.get_reorder_max_batch_size() > 0) { - exec_loop_mixed(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins, - block_indices, block_indices_begins, alibi_slopes); + exec_loop_mixed(query, + present_key, + present_value, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes); } else { - _helper.exec_loop_bhl(query, present_key, present_value, output_emb, output_score, max_context_len, past_lens, subsequence_begins, 
- block_indices, block_indices_begins, alibi_slopes); + _helper.exec_loop_bhl(query, + present_key, + present_value, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes); } } }; @@ -2183,18 +2377,32 @@ struct AttentionExecutor : public PagedAttentionExecutor { : _helper(MHAHelper(key_group_size, value_group_size)), _kernel(_helper) {} - void init(const std::vector& inputs, const std::vector& outputs, PlainTensor& q, PlainTensor& k, PlainTensor& v, PlainTensor& k_cache, - PlainTensor& v_cache, PlainTensor& past_lens, PlainTensor& subsequence_begins, PlainTensor& block_indices, PlainTensor& block_indices_begins, - float& scale, size_t& sliding_window, PlainTensor& alibi_slopes, size_t& max_context_len, PlainTensor& output_emb, PlainTensor& output_score) { - q.reset(inputs[ID_Q]); // [B_token, H * S] + void init(const std::vector& inputs, + const std::vector& outputs, + PlainTensor& q, + PlainTensor& k, + PlainTensor& v, + PlainTensor& k_cache, + PlainTensor& v_cache, + PlainTensor& past_lens, + PlainTensor& subsequence_begins, + PlainTensor& block_indices, + PlainTensor& block_indices_begins, + float& scale, + size_t& sliding_window, + PlainTensor& alibi_slopes, + size_t& max_context_len, + PlainTensor& output_emb, + PlainTensor& output_score) { + q.reset(inputs[ID_Q]); // [B_token, H * S] k.reset(inputs[ID_K]); v.reset(inputs[ID_V]); - k_cache.reset(inputs[ID_KCACHE]); // [NUM_BLOCKS, H, 32, S] - v_cache.reset(inputs[ID_VCACHE]); // [NUM_BLOCKS, H, 32, S] - past_lens.reset(inputs[ID_PAST_LENS]); // [B_seq] - subsequence_begins.reset(inputs[ID_SUBSEQUENCE_BEGINS]); // [B_seq+1] - block_indices.reset(inputs[ID_BLOCK_INDICES]); // [num_blocks] - block_indices_begins.reset(inputs[ID_BLOCK_INDICES_BEGINS]);// [B_seq+1] + k_cache.reset(inputs[ID_KCACHE]); // [NUM_BLOCKS, H, 32, S] + v_cache.reset(inputs[ID_VCACHE]); // [NUM_BLOCKS, H, 32, S] + past_lens.reset(inputs[ID_PAST_LENS]); // [B_seq] + subsequence_begins.reset(inputs[ID_SUBSEQUENCE_BEGINS]); // [B_seq+1] + block_indices.reset(inputs[ID_BLOCK_INDICES]); // [num_blocks] + block_indices_begins.reset(inputs[ID_BLOCK_INDICES_BEGINS]); // [B_seq+1] scale = *inputs[ID_SCALE]->getDataAs(); sliding_window = static_cast(*inputs[ID_SLIDING_WINDOW]->getDataAs()); if (!inputs[ID_ALIBI_SLOPES]->getShape().hasZeroDims()) @@ -2209,13 +2417,14 @@ struct AttentionExecutor : public PagedAttentionExecutor { auto _key_group_size = _helper._key_group_size; auto _value_group_size = _helper._key_group_size; // The layout for per token per head for u8 kv cache: - // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized feature(u8,idx_S)| - // The actual size needs to deduct scale and zeropoint. + // |scale(f32)|zeropoint(f32)|quantized feature(u8,idx_1)|quantized feature(u8,idx_2)|...|quantized + // feature(u8,idx_S)| The actual size needs to deduct scale and zeropoint. const size_t key_sub_byte_multiplyer = 8 / k_cache.get_precision().bitwidth(); const size_t value_sub_byte_multiplyer = 8 / v_cache.get_precision().bitwidth(); const size_t key_params_size = sizeof(float) * 2 * key_sub_byte_multiplyer; // u4 needs scale + zp. s4 needs scale. - const size_t param_size = one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? sizeof(float) * 2 : sizeof(float); + const size_t param_size = + one_of(v_cache.get_precision(), ov::element::u4, ov::element::u8) ? 
@@ -2261,8 +2470,14 @@ _helper.init(H, S, SV, Hk, h_each_group_len, block_size, sliding_window, scale, max_context_len, alibi_slopes); } - void concat_pastkv(const PlainTensor& k, const PlainTensor& v, const PlainTensor& k_cache, const PlainTensor& v_cache, - const PlainTensor& past_lens, const PlainTensor& subsequence_begins, const PlainTensor& block_indices, const PlainTensor& block_indices_begins) { + void concat_pastkv(const PlainTensor& k, + const PlainTensor& v, + const PlainTensor& k_cache, + const PlainTensor& v_cache, + const PlainTensor& past_lens, + const PlainTensor& subsequence_begins, + const PlainTensor& block_indices, + const PlainTensor& block_indices_begins) { auto B_token = k.size(0); _slot_mapping.resize({B_token}); @@ -2274,13 +2489,21 @@ auto block_offset_start = kv_len - q_len; for (int32_t j = 0; j < q_len; j++) { auto block_offset = block_offset_start + j; - auto block_number = block_indices.ptr()[block_number_start + block_offset / _helper._block_size]; - _slot_mapping.ptr()[idx++] = block_number * _helper._block_size + block_offset % _helper._block_size; + auto block_number = + block_indices.ptr()[block_number_start + block_offset / _helper._block_size]; + _slot_mapping.ptr()[idx++] = + block_number * _helper._block_size + block_offset % _helper._block_size; } } if (k_cache.m_dt == ov::element::Type_t::u8) { - paged_attn_quantkv(k, v, k_cache, v_cache, _slot_mapping, _helper._key_group_size, _helper._value_group_size); + paged_attn_quantkv(k, + v, + k_cache, + v_cache, + _slot_mapping, + _helper._key_group_size, + _helper._value_group_size); } else { paged_attn_memcpy(k, v, k_cache, v_cache, _slot_mapping); } @@ -2296,12 +2519,36 @@ PlainTensor output_emb; PlainTensor output_score; - init(inputs, outputs, q, k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins, - scale, sliding_window, alibi_slopes, max_context_len, output_emb, output_score); + init(inputs, + outputs, + q, + k, + v, + k_cache, + v_cache, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + scale, + sliding_window, + alibi_slopes, + max_context_len, + output_emb, + output_score); concat_pastkv(k, v, k_cache, v_cache, past_lens, subsequence_begins, block_indices, block_indices_begins); - _kernel(q, k_cache, v_cache, output_emb, output_score, max_context_len, past_lens, subsequence_begins, block_indices, - block_indices_begins, alibi_slopes); + _kernel(q, + k_cache, + v_cache, + output_emb, + output_score, + max_context_len, + past_lens, + subsequence_begins, + block_indices, + block_indices_begins, + alibi_slopes); } }; #endif @@ -2315,32 +2562,35 @@ std::shared_ptr make_pa_executor(ov::element::Type data_ #ifdef OPENVINO_ARCH_X86_64 if (data_type == ov::element::bf16) { -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { - executor = std::make_shared>(key_group_size, value_group_size); + executor = + std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type ==
ov::element::bf16, "expect kvcache type bf16, current: ", key_cache_type); executor = std::make_shared>(); } -#else +# else OPENVINO_THROW("make_pa_executor: bf16 needs avx512+ hardware."); -#endif +# endif } else if (data_type == ov::element::f16) { -#if defined(HAVE_AVX512F) +# if defined(HAVE_AVX512F) if (key_cache_type == ov::element::u8) { - executor = std::make_shared>(key_group_size, value_group_size); + executor = + std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::f16, "expect kvcache type f16, current: ", key_cache_type); executor = std::make_shared>(); } -#else +# else OPENVINO_THROW("make_pa_executor: f16 needs avx512+ hardware."); -#endif +# endif } else if (data_type == ov::element::f32) { if (key_cache_type == ov::element::u8) { executor = std::make_shared>(key_group_size, value_group_size); } else if (key_cache_type == ov::element::f16) { - executor = std::make_shared>(key_group_size, value_group_size); + executor = + std::make_shared>(key_group_size, value_group_size); } else { OPENVINO_ASSERT(key_cache_type == ov::element::f32, "expect kvcache type f32, current: ", key_cache_type); executor = std::make_shared>(key_group_size, value_group_size); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp index d386dc9f44e321..64e4eefc3b760d 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa.hpp @@ -6,8 +6,9 @@ #include #include #include -#include #include +#include + #include "cpu_memory.h" #include "executor_pa_common.hpp" diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp index 70723a577b0c2b..8a7fa211f8f4ce 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.cpp @@ -1,6 +1,8 @@ // Copyright (C) 2018-2024 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // +#include "executor_pa_common.hpp" + #include #include @@ -9,10 +11,9 @@ #include #include +#include "openvino/core/parallel.hpp" #include "openvino/core/type/bfloat16.hpp" #include "openvino/core/type/float16.hpp" -#include "openvino/core/parallel.hpp" -#include "executor_pa_common.hpp" #include "utils/plain_tensor.hpp" namespace ov { @@ -58,20 +59,23 @@ void TileConfiger::generate() { ret(); } -JitMatMulVecAMX::JitMatMulVecAMX(int head_size, int block_size, ov::element::Type amx_prec) : - jit_generator(jit_name()), m_head_size(head_size), m_block_size(block_size), m_amx_prec(amx_prec) { +JitMatMulVecAMX::JitMatMulVecAMX(int head_size, int block_size, ov::element::Type amx_prec) + : jit_generator(jit_name()), + m_head_size(head_size), + m_block_size(block_size), + m_amx_prec(amx_prec) { create_kernel(); m_tile_cfg.reset(1, 0, { - {16, 4}, // C:0 M x 1 (4b) - {16, 64}, // A:1 M x 32/64 (64b) - {16, 4}, // B:2 32/64 x 1 (4b) - {16, 4}, // B:3 - {16, 4}, // B:4 - {16, 4}, // B:5 - {16, 4}, // B:6 - {16, 4}, // B:7 + {16, 4}, // C:0 M x 1 (4b) + {16, 64}, // A:1 M x 32/64 (64b) + {16, 4}, // B:2 32/64 x 1 (4b) + {16, 4}, // B:3 + {16, 4}, // B:4 + {16, 4}, // B:5 + {16, 4}, // B:6 + {16, 4}, // B:7 }); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp 
b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp index bc21457a3285b4..81c54c84d9453a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/executor_pa_common.hpp @@ -6,11 +6,12 @@ #include #include #include -#include #include -#include "cpu_memory.h" +#include + #include "cpu/x64/cpu_isa_traits.hpp" #include "cpu/x64/jit_generator.hpp" +#include "cpu_memory.h" namespace ov { namespace Extensions { @@ -20,20 +21,21 @@ namespace Cpu { struct PagedAttentionExecutor { // PagedAttention input index - static const size_t ID_Q = 0; // [B_token, H * S], float - static const size_t ID_K = 1; // [B_token, Hk * S], float - static const size_t ID_V = 2; // [B_token, Hk * S], float - static const size_t ID_KCACHE = 3; // [block_number, H, block_size, S], float - static const size_t ID_VCACHE = 4; // [block_number, H, block_size, S], float - static const size_t ID_PAST_LENS = 5; // [B_seq] - static const size_t ID_SUBSEQUENCE_BEGINS = 6; // [B_seq+1] - static const size_t ID_BLOCK_INDICES = 7; // [num_blocks] - static const size_t ID_BLOCK_INDICES_BEGINS = 8; // [B_seq+1] - static const size_t ID_SCALE = 9; // [], float - static const size_t ID_SLIDING_WINDOW = 10; // [] - static const size_t ID_ALIBI_SLOPES = 11; // [H|0], float - static const size_t ID_MAX_CONTEXT_LEN = 12; // [] - virtual void execute(const std::vector& inputs, const std::vector outputs) = 0; + static const size_t ID_Q = 0; // [B_token, H * S], float + static const size_t ID_K = 1; // [B_token, Hk * S], float + static const size_t ID_V = 2; // [B_token, Hk * S], float + static const size_t ID_KCACHE = 3; // [block_number, H, block_size, S], float + static const size_t ID_VCACHE = 4; // [block_number, H, block_size, S], float + static const size_t ID_PAST_LENS = 5; // [B_seq] + static const size_t ID_SUBSEQUENCE_BEGINS = 6; // [B_seq+1] + static const size_t ID_BLOCK_INDICES = 7; // [num_blocks] + static const size_t ID_BLOCK_INDICES_BEGINS = 8; // [B_seq+1] + static const size_t ID_SCALE = 9; // [], float + static const size_t ID_SLIDING_WINDOW = 10; // [] + static const size_t ID_ALIBI_SLOPES = 11; // [H|0], float + static const size_t ID_MAX_CONTEXT_LEN = 12; // [] + virtual void execute(const std::vector& inputs, + const std::vector outputs) = 0; virtual ~PagedAttentionExecutor() = default; }; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index 25ddbb1b4246b1..f2180b5314cc07 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -13,11 +13,10 @@ # include #endif - -#include "openvino/core/type/bfloat16.hpp" -#include "openvino/core/parallel.hpp" -#include "mha_single_token.hpp" #include "common.hpp" +#include "mha_single_token.hpp" +#include "openvino/core/parallel.hpp" +#include "openvino/core/type/bfloat16.hpp" #include "softmax_kernel.hpp" #if defined(OPENVINO_ARCH_ARM64) @@ -33,19 +32,20 @@ using namespace ov; #if defined(HAVE_AVX2) -#define prefetch_bytes(bytes, sel, advance, src) { \ - auto *p = reinterpret_cast(src); \ - for (size_t i = 0; i < bytes; i += 64) \ - _mm_prefetch(p + i + advance, sel); \ -} +# define prefetch_bytes(bytes, sel, advance, src) \ + { \ + auto* p = reinterpret_cast(src); \ + for (size_t i = 0; i < bytes; i += 64) \ + _mm_prefetch(p + i + advance, sel); \ + 
} #else -#define prefetch_bytes(bytes, sel, advance, src) +# define prefetch_bytes(bytes, sel, advance, src) #endif -template +template void cvt_copy(TA* dst, TB* src, size_t n) { size_t i = 0; #if defined(HAVE_AVX512F) @@ -65,21 +65,21 @@ void cvt_copy(TA* dst, TB* src, size_t n) { __vst1q_f32(dst + i, vb1); } } -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) if (std::is_same::value && std::is_same::value) { for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) { auto vb1 = vld1q_f16(reinterpret_cast(src + i)); vst1q_f16(reinterpret_cast(dst + i), vb1); } } -#endif +# endif #endif for (; i < n; i++) { dst[i] = src[i]; } } -template +template static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scale, float* zp) { size_t i = 0; #if defined(HAVE_AVX512F) @@ -113,12 +113,12 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal } #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -template +template static void attn_acc_value(ov::float16* out, ov::float16 weight, T* v, size_t S, float* scale, float* zp) { size_t i = 0; auto attn_w_vec_fp16 = vdupq_n_f16(weight); - auto _v = reinterpret_cast(v); - auto _out = reinterpret_cast(out); + auto _v = reinterpret_cast(v); + auto _out = reinterpret_cast(out); for (; i + vec_len_f16_neon <= S; i += vec_len_f16_neon) { auto v_value = vld1q_f16(_v + i); auto v_out = vld1q_f16(_out + i); @@ -131,7 +131,6 @@ static void attn_acc_value(ov::float16* out, ov::float16 weight, T* v, size_t S, } #endif - static void attn_acc_value(float* out, float weight, uint8_t* v, size_t S, float* scale, float* zp) { size_t i = 0; weight *= *scale; @@ -285,7 +284,7 @@ static void attn_acc_value(float* out, float weight, uint8_t* v, size_t S, float } } -template +template static float sum_q_head(T* a, size_t n) { float sum = 0.0f; size_t i = 0; @@ -406,7 +405,7 @@ static float sum_q_head(T* a, size_t n) { return sum; } -template +template static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float* head_sum) { size_t i = 0; float sum = 0.0f; @@ -552,7 +551,12 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float* } #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) -static ov::float16 dot_product_fp16(ov::float16* a, ov::float16* b, size_t n, float* scale, float* zp, float* head_sum) { +static ov::float16 dot_product_fp16(ov::float16* a, + ov::float16* b, + size_t n, + float* scale, + float* zp, + float* head_sum) { size_t i = 0; ov::float16 sum = 0.0f; auto vsum0 = vdupq_n_f16(0.0f); @@ -609,7 +613,7 @@ static ov::float16 dot_product_fp16(ov::float16* a, ov::float16* b, size_t n, fl } #endif -template +template static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, float* head_sum) { size_t i = 0; float sum = 0.0f; @@ -763,11 +767,11 @@ static float dot_product(TA* a, uint8_t* b, size_t n, float* scale, float* zp, f #endif } -template +template static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_stride) { size_t i = 0; #if defined(HAVE_AVX512F) - for (; i + vec_len_f32_avx512 <= S; i+= vec_len_f32_avx512) { + for (; i + vec_len_f32_avx512 <= S; i += vec_len_f32_avx512) { auto* src = temp + i; auto result_vec_fp32 = _mm512_setzero_ps(); for (size_t m = 0; m < M; m++) { @@ -903,11 +907,16 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, for (size_t iwork = start; iwork < end; ++iwork) { auto p = past_k_scale_zp.ptr(pk, 0, h_group); #if 
defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - if (std::is_same::value && std::is_same::value && std::is_same::value) { + if (std::is_same::value && std::is_same::value && + std::is_same::value) { auto p_k = present_key.ptr(0, h_group, pk); prefetch_bytes(S, _MM_HINT_T0, 4096, p_k); - auto _qk = dot_product_fp16(query.ptr(0, h_group), p_k, - S, p, p + 1, head_sum.ptr(0, h_group)); + auto _qk = dot_product_fp16(query.ptr(0, h_group), + p_k, + S, + p, + p + 1, + head_sum.ptr(0, h_group)); buf_attn_w.ptr(0, h_group, 0)[pk] = _qk; parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); continue; @@ -915,8 +924,8 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, #endif auto p_k = present_key.ptr(0, h_group, pk); prefetch_bytes(S, _MM_HINT_T0, 4096, p_k); - buf_attn_w.ptr(0, h_group, 0)[pk] = dot_product(query.ptr(0, h_group), p_k, - S, p, p + 1, head_sum.ptr(0, h_group));; + buf_attn_w.ptr(0, h_group, 0)[pk] = + dot_product(query.ptr(0, h_group), p_k, S, p, p + 1, head_sum.ptr(0, h_group)); parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); } } else { @@ -924,10 +934,15 @@ auto b_kv = beams ? beams.ptr(b)[pk] : b; auto p = past_k_scale_zp.ptr(pk, b_kv, h_group); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - if (std::is_same::value && std::is_same::value && std::is_same::value) { + if (std::is_same::value && std::is_same::value && + std::is_same::value) { auto p_k = present_key.ptr(b_kv, h_group, pk); - auto _qk = dot_product_fp16(query.ptr(b, h_group), p_k, - S, p, p + 1, head_sum.ptr(b, h_group)); + auto _qk = dot_product_fp16(query.ptr(b, h_group), + p_k, + S, + p, + p + 1, + head_sum.ptr(b, h_group)); buf_attn_w.ptr(b, h_group, 0)[pk] = _qk; parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); continue; @@ -935,8 +950,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, #endif auto p_k = present_key.ptr(b_kv, h_group, pk); buf_attn_w.ptr(b, h_group, 0)[pk] = - dot_product(query.ptr(b, h_group), p_k, - S, p, p + 1, head_sum.ptr(b, h_group)); + dot_product(query.ptr(b, h_group), p_k, S, p, p + 1, head_sum.ptr(b, h_group)); parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); } } @@ -947,17 +961,25 @@ auto p = past_k_scale_zp.ptr(pk, b_kv, h_group); for (size_t h = h_group * h_each_group_len; h < (h_group + 1) * h_each_group_len; h++) { #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - if (std::is_same::value && std::is_same::value && std::is_same::value) { + if (std::is_same::value && std::is_same::value && + std::is_same::value) { auto p_k = present_key.ptr(b_kv, h_group, pk); - auto _qk = dot_product_fp16(query.ptr(b, h, pq), p_k, - S, p, p + 1, head_sum.ptr(b, h, pq)); + auto _qk = dot_product_fp16(query.ptr(b, h, pq), + p_k, + S, + p, + p + 1, + head_sum.ptr(b, h, pq)); buf_attn_w.ptr(b, h, pq)[pk] = _qk; continue; } #endif - buf_attn_w.ptr(b, h, pq)[pk] = - dot_product(query.ptr(b, h, pq), present_key.ptr(b_kv, h_group, pk), - S, p, p + 1, head_sum.ptr(b, h, pq)); + buf_attn_w.ptr(b, h, pq)[pk] = dot_product(query.ptr(b, h, pq), + present_key.ptr(b_kv, h_group, pk), + S, + p, + p + 1, + head_sum.ptr(b, h, pq)); } } parallel_it_step(pk, kv_len, b, B, h_group, h_group_num); @@ -1001,7 +1023,8 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, auto* v = present_value.ptr(b_kv, h_group, pv); auto p = past_v_scale_zp.ptr(pv,
b_kv, h_group); for (size_t pq = 0; pq < q_len; pq++) { - for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; h++, group_idx++) { + for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; + h++, group_idx++) { attn_acc_value(buf_attn_score.ptr(ithr, pq, group_idx), buf_attn_w.ptr(b, h, pq)[pv], v, @@ -1014,7 +1037,7 @@ static void mha_single_token_kernel(const ov::intel_cpu::PlainTensor& query, // convert to dst for (size_t pq = 0; pq < q_len; pq++) { for (size_t h = h_group * h_each_group_len, group_idx = 0; h < (h_group + 1) * h_each_group_len; - h++, group_idx++) { + h++, group_idx++) { auto* dst = has_out_transpose ? output_emb.ptr(b, pq, h * SV) : output_emb.ptr(b, h, pq); cvt_copy(dst, buf_attn_score.ptr(ithr, pq, group_idx), SV); } diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp index e29e2bae0aa07a..2ef0f62d7e0df0 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.hpp @@ -6,8 +6,9 @@ #include #include #include -#include #include +#include + #include "utils/plain_tensor.hpp" namespace ov { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp index 28755e69eaf589..c02f9770a37be9 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.cpp @@ -13,10 +13,10 @@ # include #endif +#include "common.hpp" #include "openvino/core/type/bfloat16.hpp" #include "softmax.hpp" #include "softmax_kernel.hpp" -#include "common.hpp" namespace ov { namespace Extensions { @@ -39,13 +39,33 @@ void attn_softmax(void* a, if (precision == ov::element::f16) { auto _a = reinterpret_cast(a); auto _alibi = reinterpret_cast(alibi); - attn_softmax_kernel(_a, a_dst, scale, _alibi, attn_mask, causal_mask, select_nfltmax_at_0, len, total_size, attn_mask_prec, dst_precision); + attn_softmax_kernel(_a, + a_dst, + scale, + _alibi, + attn_mask, + causal_mask, + select_nfltmax_at_0, + len, + total_size, + attn_mask_prec, + dst_precision); return; } #endif auto _a = reinterpret_cast(a); auto _alibi = reinterpret_cast(alibi); - attn_softmax_kernel(_a, a_dst, scale, _alibi, attn_mask, causal_mask, select_nfltmax_at_0, len, total_size, attn_mask_prec, dst_precision); + attn_softmax_kernel(_a, + a_dst, + scale, + _alibi, + attn_mask, + causal_mask, + select_nfltmax_at_0, + len, + total_size, + attn_mask_prec, + dst_precision); } } // namespace XARCH diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp index ee264924e8f256..d620a01e221788 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax.hpp @@ -6,8 +6,8 @@ #include #include #include -#include #include +#include namespace ov { namespace Extensions { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp index 60c6a24ec5f2fa..48b92b53fa2727 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp @@ -3,16 +3,16 @@ // 
#pragma once -#include "common.hpp" -#include "openvino/core/type/element_type.hpp" - #include #include #include #include +#include "common.hpp" +#include "openvino/core/type/element_type.hpp" + #if defined(OPENVINO_ARCH_ARM64) -#include "arm_neon.h" +# include "arm_neon.h" #endif namespace ov { @@ -22,7 +22,7 @@ namespace XARCH { #if defined(HAVE_AVX2) inline void exp_ps_avx2(__m256& src) { -#define REPEAT8(x) x, x, x, x, x, x, x, x +# define REPEAT8(x) x, x, x, x, x, x, x, x static const uint32_t c_min[] = {REPEAT8(0xc2aeac50)}; static const uint32_t c_max[] = {REPEAT8(0x42b17218)}; static const uint32_t c_e[] = {REPEAT8(0x3fb8aa3b)}; @@ -36,21 +36,21 @@ inline void exp_ps_avx2(__m256& src) { static const uint32_t c_p4[] = {REPEAT8(0x3d2b9d0d)}; static const uint32_t c_p5[] = {REPEAT8(0x3c07cfce)}; static const uint32_t c_2[] = {REPEAT8(0x40000000)}; -#undef REPEAT8 +# undef REPEAT8 static constexpr int n_mantissa_bits = 23; - __m256 exp_ln_flt_min_f = _mm256_loadu_ps(reinterpret_cast(c_min)); // log(FLT_MIN) - __m256 exp_ln_flt_max_f = _mm256_loadu_ps(reinterpret_cast(c_max)); // log(FLT_MAX) - __m256 exp_log2ef = _mm256_loadu_ps(reinterpret_cast(c_e)); // log2(e) - __m256 half = _mm256_loadu_ps(reinterpret_cast(c_half)); // 0.5f - __m256 ln2f = _mm256_loadu_ps(reinterpret_cast(c_ln2)); // ln(2) - __m256 one = _mm256_loadu_ps(reinterpret_cast(c_1)); // 1.0f - __m256i exponent_bias = _mm256_loadu_si256(reinterpret_cast(c_bias));// 127 - __m256 exp_pol1 = _mm256_loadu_ps(reinterpret_cast(c_p1)); // p1 = 0.999999701f - __m256 exp_pol2 = _mm256_loadu_ps(reinterpret_cast(c_p2)); // p2 = 0.499991506f - __m256 exp_pol3 = _mm256_loadu_ps(reinterpret_cast(c_p3)); // p3 = 0.166676521f - __m256 exp_pol4 = _mm256_loadu_ps(reinterpret_cast(c_p4)); // p4 = 0.0418978221f - __m256 exp_pol5 = _mm256_loadu_ps(reinterpret_cast(c_p5)); // p5 = 0.00828929059f - __m256 two = _mm256_loadu_ps(reinterpret_cast(c_2)); // 2 + __m256 exp_ln_flt_min_f = _mm256_loadu_ps(reinterpret_cast(c_min)); // log(FLT_MIN) + __m256 exp_ln_flt_max_f = _mm256_loadu_ps(reinterpret_cast(c_max)); // log(FLT_MAX) + __m256 exp_log2ef = _mm256_loadu_ps(reinterpret_cast(c_e)); // log2(e) + __m256 half = _mm256_loadu_ps(reinterpret_cast(c_half)); // 0.5f + __m256 ln2f = _mm256_loadu_ps(reinterpret_cast(c_ln2)); // ln(2) + __m256 one = _mm256_loadu_ps(reinterpret_cast(c_1)); // 1.0f + __m256i exponent_bias = _mm256_loadu_si256(reinterpret_cast(c_bias)); // 127 + __m256 exp_pol1 = _mm256_loadu_ps(reinterpret_cast(c_p1)); // p1 = 0.999999701f + __m256 exp_pol2 = _mm256_loadu_ps(reinterpret_cast(c_p2)); // p2 = 0.499991506f + __m256 exp_pol3 = _mm256_loadu_ps(reinterpret_cast(c_p3)); // p3 = 0.166676521f + __m256 exp_pol4 = _mm256_loadu_ps(reinterpret_cast(c_p4)); // p4 = 0.0418978221f + __m256 exp_pol5 = _mm256_loadu_ps(reinterpret_cast(c_p5)); // p5 = 0.00828929059f + __m256 two = _mm256_loadu_ps(reinterpret_cast(c_2)); // 2 // exp(x) = // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression @@ -195,32 +195,33 @@ inline void scale_add2_reduce_max(float* a, // process vector body // unroll to avoid dependency caused by _mm256_max_ps for (; i + 4 * vec_len_f32_avx512 <= size; i += 4 * vec_len_f32_avx512) { - #define ITEM(n) \ - v_a = _mm512_loadu_ps(a + i + n * vec_len_f32_avx512); \ - v_a = _mm512_mul_ps(v_a, v_scale); \ - if (has_alibi) { \ - auto v_lookup = _mm512_loadu_ps(alibi_lookup + i + n * vec_len_f32_avx512); \ - v_a = _mm512_fmadd_ps(v_lookup, v_alibi_slope, v_a); 
\ - } \ - if (has_attn_mask) { \ - auto v_mask = mm512_uni_loadu_ps(attn_mask + i + n * vec_len_f32_avx512); \ - v_a = _mm512_add_ps(v_a, v_mask); \ - } \ - if (has_causal_mask) { \ - auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i + n * vec_len_f32_avx512)); \ - auto v_maski32 = _mm512_cvtepi8_epi32(v_maski8); \ - auto kmask = _mm512_cmp_epi32_mask(v_maski32, v_zeroi32, _MM_CMPINT_NE); \ - kmask = _kxor_mask16(kmask, kmask_xor); \ - v_a = _mm512_mask_blend_ps(kmask, v_a, v_nfltmax); \ - } \ - v_max##n = _mm512_max_ps(v_max##n, v_a); \ +# define ITEM(n) \ + v_a = _mm512_loadu_ps(a + i + n * vec_len_f32_avx512); \ + v_a = _mm512_mul_ps(v_a, v_scale); \ + if (has_alibi) { \ + auto v_lookup = _mm512_loadu_ps(alibi_lookup + i + n * vec_len_f32_avx512); \ + v_a = _mm512_fmadd_ps(v_lookup, v_alibi_slope, v_a); \ + } \ + if (has_attn_mask) { \ + auto v_mask = mm512_uni_loadu_ps(attn_mask + i + n * vec_len_f32_avx512); \ + v_a = _mm512_add_ps(v_a, v_mask); \ + } \ + if (has_causal_mask) { \ + auto v_maski8 = \ + _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i + n * vec_len_f32_avx512)); \ + auto v_maski32 = _mm512_cvtepi8_epi32(v_maski8); \ + auto kmask = _mm512_cmp_epi32_mask(v_maski32, v_zeroi32, _MM_CMPINT_NE); \ + kmask = _kxor_mask16(kmask, kmask_xor); \ + v_a = _mm512_mask_blend_ps(kmask, v_a, v_nfltmax); \ + } \ + v_max##n = _mm512_max_ps(v_max##n, v_a); \ _mm512_storeu_ps(a + i + n * vec_len_f32_avx512, v_a); ITEM(0); ITEM(1); ITEM(2); ITEM(3); - #undef ITEM +# undef ITEM } while (i + vec_len_f32_avx512 <= size) { v_a = _mm512_loadu_ps(a + i); @@ -295,32 +296,32 @@ inline void scale_add2_reduce_max(float* a, // process vector body // unroll to avoid dependency caused by _mm512_max_ps for (; i + 4 * vec_len_f32_avx2 <= size; i += 4 * vec_len_f32_avx2) { - #define ITEM(n) \ - v_a = _mm256_loadu_ps(a + i + n * vec_len_f32_avx2); \ - v_a = _mm256_mul_ps(v_a, v_scale); \ - if (has_alibi) { \ - auto v_lookup = _mm256_loadu_ps(alibi_lookup + i + n * vec_len_f32_avx2); \ - v_a = _mm256_fmadd_ps(v_lookup, v_alibi_slope, v_a); \ - } \ - if (has_attn_mask) { \ - auto v_mask = mm256_uni_loadu_ps(attn_mask + i + n * vec_len_f32_avx2); \ - v_a = _mm256_add_ps(v_a, v_mask); \ - } \ - if (has_causal_mask) { \ - auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i + n * vec_len_f32_avx2)); \ - auto v_maski32 = _mm256_cvtepi8_epi32(v_maski8); \ - v_maski32 = _mm256_cmpeq_epi32(v_maski32, v_zeroi32);\ - v_maski32 = _mm256_xor_si256(v_maski32, v_mask_xor);\ - v_a = _mm256_blendv_ps(v_nfltmax, v_a, _mm256_castsi256_ps(v_maski32)); \ - } \ - v_max##n = _mm256_max_ps(v_max##n, v_a); \ +# define ITEM(n) \ + v_a = _mm256_loadu_ps(a + i + n * vec_len_f32_avx2); \ + v_a = _mm256_mul_ps(v_a, v_scale); \ + if (has_alibi) { \ + auto v_lookup = _mm256_loadu_ps(alibi_lookup + i + n * vec_len_f32_avx2); \ + v_a = _mm256_fmadd_ps(v_lookup, v_alibi_slope, v_a); \ + } \ + if (has_attn_mask) { \ + auto v_mask = mm256_uni_loadu_ps(attn_mask + i + n * vec_len_f32_avx2); \ + v_a = _mm256_add_ps(v_a, v_mask); \ + } \ + if (has_causal_mask) { \ + auto v_maski8 = _mm_loadu_si128(reinterpret_cast<__m128i const*>(causal_mask + i + n * vec_len_f32_avx2)); \ + auto v_maski32 = _mm256_cvtepi8_epi32(v_maski8); \ + v_maski32 = _mm256_cmpeq_epi32(v_maski32, v_zeroi32); \ + v_maski32 = _mm256_xor_si256(v_maski32, v_mask_xor); \ + v_a = _mm256_blendv_ps(v_nfltmax, v_a, _mm256_castsi256_ps(v_maski32)); \ + } \ + v_max##n = _mm256_max_ps(v_max##n, v_a); \ 
_mm256_storeu_ps(a + i + n * vec_len_f32_avx2, v_a); ITEM(0); ITEM(1); ITEM(2); ITEM(3); - #undef ITEM +# undef ITEM } while (i + vec_len_f32_avx2 <= size) { @@ -415,7 +416,7 @@ inline void scale_add2_reduce_max(float* a, uint32x4_t v_maski32[2] = {v_maski32_low, v_maski32_high}; for (int j = 0; j < 2; ++j) { uint32x4_t kmask = vceqq_u32(v_maski32[j], v_zeroi32); // ==0 - v_a = vbslq_f32(kmask, v_nfltmax, v_a); // mask => -FLT_MAX + v_a = vbslq_f32(kmask, v_nfltmax, v_a); // mask => -FLT_MAX } } @@ -521,7 +522,7 @@ inline void scale_add2_reduce_max(ov::float16* a, #if defined(HAVE_AVX512F) static inline void exp_ps_avx512(__m512& src) { -#define REPEAT16(x) x, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x +# define REPEAT16(x) x, x, x, x, x, x, x, x, x, x, x, x, x, x, x, x static const uint32_t c_min[] = {REPEAT16(0xc2aeac50)}; static const uint32_t c_max[] = {REPEAT16(0x42b17218)}; static const uint32_t c_e[] = {REPEAT16(0x3fb8aa3b)}; @@ -535,21 +536,21 @@ static inline void exp_ps_avx512(__m512& src) { static const uint32_t c_p4[] = {REPEAT16(0x3d2b9d0d)}; static const uint32_t c_p5[] = {REPEAT16(0x3c07cfce)}; static const uint32_t c_2[] = {REPEAT16(0x40000000)}; -#undef REPEAT16 +# undef REPEAT16 static constexpr int n_mantissa_bits = 23; - __m512 exp_ln_flt_min_f = _mm512_loadu_ps(reinterpret_cast(c_min)); // log(FLT_MIN) - __m512 exp_ln_flt_max_f = _mm512_loadu_ps(reinterpret_cast(c_max)); // log(FLT_MAX) - __m512 exp_log2ef = _mm512_loadu_ps(reinterpret_cast(c_e)); // log2(e) - __m512 half = _mm512_loadu_ps(reinterpret_cast(c_half)); // 0.5f - __m512 ln2f = _mm512_loadu_ps(reinterpret_cast(c_ln2)); // ln(2) - __m512 one = _mm512_loadu_ps(reinterpret_cast(c_1)); // 1.0f - __m512i exponent_bias = _mm512_loadu_si512(c_bias); // 127 - __m512 exp_pol1 = _mm512_loadu_ps(reinterpret_cast(c_p1)); // p1 = 0.999999701f - __m512 exp_pol2 = _mm512_loadu_ps(reinterpret_cast(c_p2)); // p2 = 0.499991506f - __m512 exp_pol3 = _mm512_loadu_ps(reinterpret_cast(c_p3)); // p3 = 0.166676521f - __m512 exp_pol4 = _mm512_loadu_ps(reinterpret_cast(c_p4)); // p4 = 0.0418978221f - __m512 exp_pol5 = _mm512_loadu_ps(reinterpret_cast(c_p5)); // p5 = 0.00828929059f - __m512 two = _mm512_loadu_ps(reinterpret_cast(c_2)); // 2 + __m512 exp_ln_flt_min_f = _mm512_loadu_ps(reinterpret_cast(c_min)); // log(FLT_MIN) + __m512 exp_ln_flt_max_f = _mm512_loadu_ps(reinterpret_cast(c_max)); // log(FLT_MAX) + __m512 exp_log2ef = _mm512_loadu_ps(reinterpret_cast(c_e)); // log2(e) + __m512 half = _mm512_loadu_ps(reinterpret_cast(c_half)); // 0.5f + __m512 ln2f = _mm512_loadu_ps(reinterpret_cast(c_ln2)); // ln(2) + __m512 one = _mm512_loadu_ps(reinterpret_cast(c_1)); // 1.0f + __m512i exponent_bias = _mm512_loadu_si512(c_bias); // 127 + __m512 exp_pol1 = _mm512_loadu_ps(reinterpret_cast(c_p1)); // p1 = 0.999999701f + __m512 exp_pol2 = _mm512_loadu_ps(reinterpret_cast(c_p2)); // p2 = 0.499991506f + __m512 exp_pol3 = _mm512_loadu_ps(reinterpret_cast(c_p3)); // p3 = 0.166676521f + __m512 exp_pol4 = _mm512_loadu_ps(reinterpret_cast(c_p4)); // p4 = 0.0418978221f + __m512 exp_pol5 = _mm512_loadu_ps(reinterpret_cast(c_p5)); // p5 = 0.00828929059f + __m512 two = _mm512_loadu_ps(reinterpret_cast(c_2)); // 2 // exp(x) = // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem // = 2^n * exp(r) // simplify the exp(n*ln(2)) expression
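The comment above summarizes the range reduction that both exp_ps_avx2 and exp_ps_avx512 implement: split x = n * ln(2) + r with |r| <= ln(2)/2, approximate exp(r) with the degree-5 polynomial whose coefficients p1..p5 appear in the comments, then rebuild 2^n through the float exponent field (n_mantissa_bits, exponent_bias). A scalar sketch of the same idea follows; it is illustrative only, not a line-by-line port of the kernels, which additionally clamp the input to [log(FLT_MIN), log(FLT_MAX)] and use the half/two constants to widen the representable range:

#include <cmath>
#include <cstdint>
#include <cstring>
#include <cstdio>

// Scalar model of the range reduction in exp_ps_avx2/exp_ps_avx512:
//   exp(x) = 2^n * exp(r),  x = n*ln(2) + r,  |r| <= ln(2)/2
// exp(r) uses the polynomial coefficients p1..p5 from the comments above;
// 2^n is built by writing n into the IEEE-754 exponent bits (bias 127,
// 23 mantissa bits).
static float exp_reduced(float x) {
    const float log2e = 1.44269504f, ln2 = 0.69314718f;
    float n = std::floor(x * log2e + 0.5f);  // nearest integer (half-up; the SIMD code rounds to nearest even)
    float r = x - n * ln2;                   // remainder, |r| <= ln(2)/2
    // Horner evaluation of exp(r) ~= 1 + p1*r + p2*r^2 + p3*r^3 + p4*r^4 + p5*r^5.
    float p = 0.00828929059f;
    p = p * r + 0.0418978221f;
    p = p * r + 0.166676521f;
    p = p * r + 0.499991506f;
    p = p * r + 0.999999701f;
    p = p * r + 1.0f;
    // Scale by 2^n via the exponent field.
    uint32_t bits = static_cast<uint32_t>(static_cast<int32_t>(n) + 127) << 23;
    float two_pow_n;
    std::memcpy(&two_pow_n, &bits, sizeof(bits));
    return p * two_pow_n;
}

int main() {
    const float xs[] = {-2.0f, -0.5f, 0.0f, 1.0f, 3.5f};
    for (float x : xs)
        std::printf("x=%5.2f  approx=%.6f  std=%.6f\n", x, exp_reduced(x), std::exp(x));
    return 0;
}

Note that p2..p5 sit close to the Taylor terms 1/2!, 1/3!, 1/4!, 1/5!, nudged slightly for better float accuracy on the reduced range.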
@@ -793,7 +794,9 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_ } } -template::value || std::is_same::value), bool>::type> +template ::value || std::is_same::value), + bool>::type> inline void multiply_scalar(float* a, T* a_dst, const float val, const size_t size) { size_t i = 0; #if defined(HAVE_AVX512F) @@ -899,47 +902,68 @@ inline void attn_softmax_kernel(float* a, ov::element::Type attn_mask_prec, ov::element::Type dst_precision, float alibi_slope) { - using func_fp32_type = void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float, float&); - using func_bf16_type = void (*)(float*, float, const float*, const ov::bfloat16*, const uint8_t*, bool, size_t, float, float&); - using func_f16_type = void (*)(float*, float, const float*, const ov::float16*, const uint8_t*, bool, size_t, float, float&); - static constexpr func_fp32_type funcs_fp32[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; - static constexpr func_bf16_type funcs_bf16[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; - static constexpr func_f16_type funcs_f16[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; + using func_fp32_type = + void (*)(float*, float, const float*, const float*, const uint8_t*, bool, size_t, float, float&); + using func_bf16_type = + void (*)(float*, float, const float*, const ov::bfloat16*, const uint8_t*, bool, size_t, float, float&); + using func_f16_type = + void (*)(float*, float, const float*, const ov::float16*, const uint8_t*, bool, size_t, float, float&); + static constexpr func_fp32_type funcs_fp32[] = {scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max}; + static constexpr func_bf16_type funcs_bf16[] = {scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max}; + static constexpr func_f16_type funcs_f16[] = {scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max}; int dispatch = (alibi ? 0b100 : 0) | (attn_mask ? 0b010 : 0) | (causal_mask ?
0b001 : 0); float max = std::numeric_limits::lowest(); if (attn_mask_prec == ov::element::f32) { - funcs_fp32[dispatch](a, scale, alibi, static_cast(attn_mask), causal_mask, select_nfltmax_at_0, len, alibi_slope, max); + funcs_fp32[dispatch](a, + scale, + alibi, + static_cast(attn_mask), + causal_mask, + select_nfltmax_at_0, + len, + alibi_slope, + max); } else if (attn_mask_prec == ov::element::bf16) { - funcs_bf16[dispatch](a, scale, alibi, static_cast(attn_mask), causal_mask, select_nfltmax_at_0, len, alibi_slope, max); + funcs_bf16[dispatch](a, + scale, + alibi, + static_cast(attn_mask), + causal_mask, + select_nfltmax_at_0, + len, + alibi_slope, + max); } else { - funcs_f16[dispatch](a, scale, alibi, static_cast(attn_mask), causal_mask, select_nfltmax_at_0, len, alibi_slope, max); + funcs_f16[dispatch](a, + scale, + alibi, + static_cast(attn_mask), + causal_mask, + select_nfltmax_at_0, + len, + alibi_slope, + max); } float sum = 0.0f; @@ -978,47 +1002,89 @@ inline void attn_softmax_kernel(ov::float16* a, ov::element::Type attn_mask_prec, ov::element::Type dst_precision, float alibi_slope) { - using func_fp32_type = void (*)(ov::float16*, float, const ov::float16*, const float*, const uint8_t*, bool, size_t, float, ov::float16&); - using func_bf16_type = void (*)(ov::float16*, float, const ov::float16*, const ov::bfloat16*, const uint8_t*, bool, size_t, float, ov::float16&); - using func_fp16_type = void (*)(ov::float16*, float, const ov::float16*, const ov::float16*, const uint8_t*, bool, size_t, float, ov::float16&); - static constexpr func_fp32_type funcs_fp32[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; - static constexpr func_bf16_type funcs_bf16[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; - static constexpr func_fp16_type funcs_fp16[] = { - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max, - scale_add2_reduce_max - }; + using func_fp32_type = void (*)(ov::float16*, + float, + const ov::float16*, + const float*, + const uint8_t*, + bool, + size_t, + float, + ov::float16&); + using func_bf16_type = void (*)(ov::float16*, + float, + const ov::float16*, + const ov::bfloat16*, + const uint8_t*, + bool, + size_t, + float, + ov::float16&); + using func_fp16_type = void (*)(ov::float16*, + float, + const ov::float16*, + const ov::float16*, + const uint8_t*, + bool, + size_t, + float, + ov::float16&); + static constexpr func_fp32_type funcs_fp32[] = {scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max}; + static constexpr func_bf16_type funcs_bf16[] = {scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max}; + static constexpr func_fp16_type funcs_fp16[] = {scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max, + scale_add2_reduce_max}; int 
dispatch = (alibi ? 0b100 : 0) | (attn_mask ? 0b010 : 0) | (causal_mask ? 0b001 : 0); ov::float16 max = std::numeric_limits::lowest(); if (attn_mask_prec == ov::element::f32) { - funcs_fp32[dispatch](a, scale, alibi, static_cast(attn_mask), causal_mask, select_nfltmax_at_0, len, alibi_slope, max); + funcs_fp32[dispatch](a, + scale, + alibi, + static_cast(attn_mask), + causal_mask, + select_nfltmax_at_0, + len, + alibi_slope, + max); } else if (attn_mask_prec == ov::element::f16) { - funcs_fp16[dispatch](a, scale, alibi, static_cast(attn_mask), causal_mask, select_nfltmax_at_0, len, alibi_slope, max); + funcs_fp16[dispatch](a, + scale, + alibi, + static_cast(attn_mask), + causal_mask, + select_nfltmax_at_0, + len, + alibi_slope, + max); } else { - funcs_bf16[dispatch](a, scale, alibi, static_cast(attn_mask), causal_mask, select_nfltmax_at_0, len, alibi_slope, max); + funcs_bf16[dispatch](a, + scale, + alibi, + static_cast(attn_mask), + causal_mask, + select_nfltmax_at_0, + len, + alibi_slope, + max); } ov::float16 sum = 0.0f; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp index b719246e4976a1..93d7db55107951 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/transpose_kernel.hpp @@ -3,96 +3,108 @@ // #pragma once -#include "common.hpp" -#include "openvino/core/type/element_type.hpp" - #include #include #include #include +#include "common.hpp" +#include "openvino/core/type/element_type.hpp" + namespace ov { namespace Extensions { namespace Cpu { namespace XARCH { #if defined(HAVE_AVX512F) -inline void transpose_m512i_16x16(__m512i& r0, __m512i& r1, __m512i& r2, __m512i& r3, - __m512i& r4, __m512i& r5, __m512i& r6, __m512i& r7, - __m512i& r8, __m512i& r9, __m512i& ra, __m512i& rb, - __m512i& rc, __m512i& rd, __m512i& re, __m512i& rf) { +inline void transpose_m512i_16x16(__m512i& r0, + __m512i& r1, + __m512i& r2, + __m512i& r3, + __m512i& r4, + __m512i& r5, + __m512i& r6, + __m512i& r7, + __m512i& r8, + __m512i& r9, + __m512i& ra, + __m512i& rb, + __m512i& rc, + __m512i& rd, + __m512i& re, + __m512i& rf) { __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf; - t0 = _mm512_unpacklo_epi32(r0, r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 - t1 = _mm512_unpackhi_epi32(r0, r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 - t2 = _mm512_unpacklo_epi32(r2, r3); // 32 48 33 49 ... - t3 = _mm512_unpackhi_epi32(r2, r3); // 34 50 35 51 ... - t4 = _mm512_unpacklo_epi32(r4, r5); // 64 80 65 81 ... - t5 = _mm512_unpackhi_epi32(r4, r5); // 66 82 67 83 ... - t6 = _mm512_unpacklo_epi32(r6, r7); // 96 112 97 113 ... - t7 = _mm512_unpackhi_epi32(r6, r7); // 98 114 99 115 ... - t8 = _mm512_unpacklo_epi32(r8, r9); // 128 ... - t9 = _mm512_unpackhi_epi32(r8, r9); // 130 ... - ta = _mm512_unpacklo_epi32(ra, rb); // 160 ... - tb = _mm512_unpackhi_epi32(ra, rb); // 162 ... - tc = _mm512_unpacklo_epi32(rc, rd); // 196 ... - td = _mm512_unpackhi_epi32(rc, rd); // 198 ... - te = _mm512_unpacklo_epi32(re, rf); // 228 ... - tf = _mm512_unpackhi_epi32(re, rf); // 230 ... - - r0 = _mm512_unpacklo_epi64(t0, t2); // 0 16 32 48 ... - r1 = _mm512_unpackhi_epi64(t0, t2); // 1 17 33 49 ... - r2 = _mm512_unpacklo_epi64(t1, t3); // 2 18 34 49 ... - r3 = _mm512_unpackhi_epi64(t1, t3); // 3 19 35 51 ... - r4 = _mm512_unpacklo_epi64(t4, t6); // 64 80 96 112 ... 
- r5 = _mm512_unpackhi_epi64(t4, t6); // 65 81 97 114 ... - r6 = _mm512_unpacklo_epi64(t5, t7); // 66 82 98 113 ... - r7 = _mm512_unpackhi_epi64(t5, t7); // 67 83 99 115 ... - r8 = _mm512_unpacklo_epi64(t8, ta); // 128 144 160 176 ... - r9 = _mm512_unpackhi_epi64(t8, ta); // 129 145 161 178 ... - ra = _mm512_unpacklo_epi64(t9, tb); // 130 146 162 177 ... - rb = _mm512_unpackhi_epi64(t9, tb); // 131 147 163 179 ... - rc = _mm512_unpacklo_epi64(tc, te); // 192 208 228 240 ... - rd = _mm512_unpackhi_epi64(tc, te); // 193 209 229 241 ... - re = _mm512_unpacklo_epi64(td, tf); // 194 210 230 242 ... - rf = _mm512_unpackhi_epi64(td, tf); // 195 211 231 243 ... - - t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... - t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... - t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... - t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... - t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... - t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... - t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... - t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... - t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... - t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... - ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... - tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... - tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... - td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... - te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... - tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... - - r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 - r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 - r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 - r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 - r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... - r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... - r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... - r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... - r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... - r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... - ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... - rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... - rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... - rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... - re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... - rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 255 + t0 = _mm512_unpacklo_epi32(r0, r1); // 0 16 1 17 4 20 5 21 8 24 9 25 12 28 13 29 + t1 = _mm512_unpackhi_epi32(r0, r1); // 2 18 3 19 6 22 7 23 10 26 11 27 14 30 15 31 + t2 = _mm512_unpacklo_epi32(r2, r3); // 32 48 33 49 ... + t3 = _mm512_unpackhi_epi32(r2, r3); // 34 50 35 51 ... + t4 = _mm512_unpacklo_epi32(r4, r5); // 64 80 65 81 ... + t5 = _mm512_unpackhi_epi32(r4, r5); // 66 82 67 83 ... + t6 = _mm512_unpacklo_epi32(r6, r7); // 96 112 97 113 ... + t7 = _mm512_unpackhi_epi32(r6, r7); // 98 114 99 115 ... + t8 = _mm512_unpacklo_epi32(r8, r9); // 128 ... + t9 = _mm512_unpackhi_epi32(r8, r9); // 130 ... + ta = _mm512_unpacklo_epi32(ra, rb); // 160 ... + tb = _mm512_unpackhi_epi32(ra, rb); // 162 ... + tc = _mm512_unpacklo_epi32(rc, rd); // 196 ... + td = _mm512_unpackhi_epi32(rc, rd); // 198 ... + te = _mm512_unpacklo_epi32(re, rf); // 228 ... 
+ tf = _mm512_unpackhi_epi32(re, rf); // 230 ... + + r0 = _mm512_unpacklo_epi64(t0, t2); // 0 16 32 48 ... + r1 = _mm512_unpackhi_epi64(t0, t2); // 1 17 33 49 ... + r2 = _mm512_unpacklo_epi64(t1, t3); // 2 18 34 49 ... + r3 = _mm512_unpackhi_epi64(t1, t3); // 3 19 35 51 ... + r4 = _mm512_unpacklo_epi64(t4, t6); // 64 80 96 112 ... + r5 = _mm512_unpackhi_epi64(t4, t6); // 65 81 97 114 ... + r6 = _mm512_unpacklo_epi64(t5, t7); // 66 82 98 113 ... + r7 = _mm512_unpackhi_epi64(t5, t7); // 67 83 99 115 ... + r8 = _mm512_unpacklo_epi64(t8, ta); // 128 144 160 176 ... + r9 = _mm512_unpackhi_epi64(t8, ta); // 129 145 161 178 ... + ra = _mm512_unpacklo_epi64(t9, tb); // 130 146 162 177 ... + rb = _mm512_unpackhi_epi64(t9, tb); // 131 147 163 179 ... + rc = _mm512_unpacklo_epi64(tc, te); // 192 208 228 240 ... + rd = _mm512_unpackhi_epi64(tc, te); // 193 209 229 241 ... + re = _mm512_unpacklo_epi64(td, tf); // 194 210 230 242 ... + rf = _mm512_unpackhi_epi64(td, tf); // 195 211 231 243 ... + + t0 = _mm512_shuffle_i32x4(r0, r4, 0x88); // 0 16 32 48 8 24 40 56 64 80 96 112 ... + t1 = _mm512_shuffle_i32x4(r1, r5, 0x88); // 1 17 33 49 ... + t2 = _mm512_shuffle_i32x4(r2, r6, 0x88); // 2 18 34 50 ... + t3 = _mm512_shuffle_i32x4(r3, r7, 0x88); // 3 19 35 51 ... + t4 = _mm512_shuffle_i32x4(r0, r4, 0xdd); // 4 20 36 52 ... + t5 = _mm512_shuffle_i32x4(r1, r5, 0xdd); // 5 21 37 53 ... + t6 = _mm512_shuffle_i32x4(r2, r6, 0xdd); // 6 22 38 54 ... + t7 = _mm512_shuffle_i32x4(r3, r7, 0xdd); // 7 23 39 55 ... + t8 = _mm512_shuffle_i32x4(r8, rc, 0x88); // 128 144 160 176 ... + t9 = _mm512_shuffle_i32x4(r9, rd, 0x88); // 129 145 161 177 ... + ta = _mm512_shuffle_i32x4(ra, re, 0x88); // 130 146 162 178 ... + tb = _mm512_shuffle_i32x4(rb, rf, 0x88); // 131 147 163 179 ... + tc = _mm512_shuffle_i32x4(r8, rc, 0xdd); // 132 148 164 180 ... + td = _mm512_shuffle_i32x4(r9, rd, 0xdd); // 133 149 165 181 ... + te = _mm512_shuffle_i32x4(ra, re, 0xdd); // 134 150 166 182 ... + tf = _mm512_shuffle_i32x4(rb, rf, 0xdd); // 135 151 167 183 ... + + r0 = _mm512_shuffle_i32x4(t0, t8, 0x88); // 0 16 32 48 64 80 96 112 ... 240 + r1 = _mm512_shuffle_i32x4(t1, t9, 0x88); // 1 17 33 49 66 81 97 113 ... 241 + r2 = _mm512_shuffle_i32x4(t2, ta, 0x88); // 2 18 34 50 67 82 98 114 ... 242 + r3 = _mm512_shuffle_i32x4(t3, tb, 0x88); // 3 19 35 51 68 83 99 115 ... 243 + r4 = _mm512_shuffle_i32x4(t4, tc, 0x88); // 4 ... + r5 = _mm512_shuffle_i32x4(t5, td, 0x88); // 5 ... + r6 = _mm512_shuffle_i32x4(t6, te, 0x88); // 6 ... + r7 = _mm512_shuffle_i32x4(t7, tf, 0x88); // 7 ... + r8 = _mm512_shuffle_i32x4(t0, t8, 0xdd); // 8 ... + r9 = _mm512_shuffle_i32x4(t1, t9, 0xdd); // 9 ... + ra = _mm512_shuffle_i32x4(t2, ta, 0xdd); // 10 ... + rb = _mm512_shuffle_i32x4(t3, tb, 0xdd); // 11 ... + rc = _mm512_shuffle_i32x4(t4, tc, 0xdd); // 12 ... + rd = _mm512_shuffle_i32x4(t5, td, 0xdd); // 13 ... + re = _mm512_shuffle_i32x4(t6, te, 0xdd); // 14 ... + rf = _mm512_shuffle_i32x4(t7, tf, 0xdd); // 15 31 47 63 79 96 111 127 ... 
255 } -template +template inline void transpose_16x16_kernel(float* _dst, T* src, size_t dst_stride, size_t src_stride) { auto* dst = reinterpret_cast(_dst); __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; @@ -133,7 +145,7 @@ inline void transpose_16x16_kernel(float* _dst, T* src, size_t dst_stride, size_ _mm512_storeu_si512(dst + 15 * dst_stride, rf); } -template +template inline void transpose_16xK_kernel(float* _dst, T* src, size_t K, size_t dst_stride, size_t src_stride) { auto* dst = reinterpret_cast(_dst); __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf; @@ -156,24 +168,110 @@ inline void transpose_16xK_kernel(float* _dst, T* src, size_t K, size_t dst_stri transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); -#define S(m) _mm512_storeu_si512(dst + 0x##m * dst_stride, r##m) -#define S8() S(0); S(1); S(2); S(3); S(4); S(5); S(6); S(7); +# define S(m) _mm512_storeu_si512(dst + 0x##m * dst_stride, r##m) +# define S8() \ + S(0); \ + S(1); \ + S(2); \ + S(3); \ + S(4); \ + S(5); \ + S(6); \ + S(7); switch (K) { - case 8: S8(); break; - case 9: S8() S(8); break; - case 10: S8(); S(8); S(9); break; - case 11: S8(); S(8); S(9); S(a); break; - case 12: S8(); S(8); S(9); S(a); S(b); break; - case 13: S8(); S(8); S(9); S(a); S(b); S(c); break; - case 14: S8(); S(8); S(9); S(a); S(b); S(c); S(d); break; - case 15: S8(); S(8); S(9); S(a); S(b); S(c); S(d); S(e); break; - case 1: S(0); break; - case 2: S(0); S(1); break; - case 3: S(0); S(1); S(2); break; - case 4: S(0); S(1); S(2); S(3); break; - case 5: S(0); S(1); S(2); S(3); S(4); break; - case 6: S(0); S(1); S(2); S(3); S(4); S(5); break; - case 7: S(0); S(1); S(2); S(3); S(4); S(5); S(6); break; + case 8: + S8(); + break; + case 9: + S8() S(8); + break; + case 10: + S8(); + S(8); + S(9); + break; + case 11: + S8(); + S(8); + S(9); + S(a); + break; + case 12: + S8(); + S(8); + S(9); + S(a); + S(b); + break; + case 13: + S8(); + S(8); + S(9); + S(a); + S(b); + S(c); + break; + case 14: + S8(); + S(8); + S(9); + S(a); + S(b); + S(c); + S(d); + break; + case 15: + S8(); + S(8); + S(9); + S(a); + S(b); + S(c); + S(d); + S(e); + break; + case 1: + S(0); + break; + case 2: + S(0); + S(1); + break; + case 3: + S(0); + S(1); + S(2); + break; + case 4: + S(0); + S(1); + S(2); + S(3); + break; + case 5: + S(0); + S(1); + S(2); + S(3); + S(4); + break; + case 6: + S(0); + S(1); + S(2); + S(3); + S(4); + S(5); + break; + case 7: + S(0); + S(1); + S(2); + S(3); + S(4); + S(5); + S(6); + break; } } @@ -240,30 +338,109 @@ inline void transpose_16xK_kernel(uint32_t* dst, uint32_t* src, size_t K, size_t transpose_m512i_16x16(r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf); switch (K) { - case 8: S8(); break; - case 9: S8() S(8); break; - case 10: S8(); S(8); S(9); break; - case 11: S8(); S(8); S(9); S(a); break; - case 12: S8(); S(8); S(9); S(a); S(b); break; - case 13: S8(); S(8); S(9); S(a); S(b); S(c); break; - case 14: S8(); S(8); S(9); S(a); S(b); S(c); S(d); break; - case 15: S8(); S(8); S(9); S(a); S(b); S(c); S(d); S(e); break; - case 1: S(0); break; - case 2: S(0); S(1); break; - case 3: S(0); S(1); S(2); break; - case 4: S(0); S(1); S(2); S(3); break; - case 5: S(0); S(1); S(2); S(3); S(4); break; - case 6: S(0); S(1); S(2); S(3); S(4); S(5); break; - case 7: S(0); S(1); S(2); S(3); S(4); S(5); S(6); break; + case 8: + S8(); + break; + case 9: + S8() S(8); + break; + case 10: + S8(); + S(8); + S(9); + break; + case 11: + S8(); + S(8); + S(9); + S(a); + 
break; + case 12: + S8(); + S(8); + S(9); + S(a); + S(b); + break; + case 13: + S8(); + S(8); + S(9); + S(a); + S(b); + S(c); + break; + case 14: + S8(); + S(8); + S(9); + S(a); + S(b); + S(c); + S(d); + break; + case 15: + S8(); + S(8); + S(9); + S(a); + S(b); + S(c); + S(d); + S(e); + break; + case 1: + S(0); + break; + case 2: + S(0); + S(1); + break; + case 3: + S(0); + S(1); + S(2); + break; + case 4: + S(0); + S(1); + S(2); + S(3); + break; + case 5: + S(0); + S(1); + S(2); + S(3); + S(4); + break; + case 6: + S(0); + S(1); + S(2); + S(3); + S(4); + S(5); + break; + case 7: + S(0); + S(1); + S(2); + S(3); + S(4); + S(5); + S(6); + break; } -#undef S -#undef S8 +# undef S +# undef S8 } #elif defined(HAVE_AVX2) // https://stackoverflow.com/questions/25622745/transpose-an-8x8-float-using-avx-avx2 -inline void transpose_8x8(__m256& r0, __m256& r1, __m256& r2, __m256& r3, __m256& r4, __m256& r5, __m256& r6, __m256& r7) { +inline void +transpose_8x8(__m256& r0, __m256& r1, __m256& r2, __m256& r3, __m256& r4, __m256& r5, __m256& r6, __m256& r7) { __m256 t0, t1, t2, t3, t4, t5, t6, t7; __m256 tt0, tt1, tt2, tt3, tt4, tt5, tt6, tt7; t0 = _mm256_unpacklo_ps(r0, r1); @@ -292,7 +469,7 @@ inline void transpose_8x8(__m256& r0, __m256& r1, __m256& r2, __m256& r3, __m256 r7 = _mm256_permute2f128_ps(tt3, tt7, 0x31); } -template +template inline void transpose_16x16_kernel(float* dst, T* src, size_t dst_stride, size_t src_stride) { __m256 r0, r1, r2, r3, r4, r5, r6, r7; @@ -323,7 +500,7 @@ inline void transpose_16x16_kernel(float* dst, T* src, size_t dst_stride, size_t } } -template +template inline void transpose_16xK_kernel(float* dst, T* src, size_t K, size_t dst_stride, size_t src_stride) { __m256 r0, r1, r2, r3, r4, r5, r6, r7; @@ -366,24 +543,59 @@ inline void transpose_16xK_kernel(float* dst, T* src, size_t K, size_t dst_strid transpose_8x8(r0, r1, r2, r3, r4, r5, r6, r7); -#define S(m) _mm256_storeu_ps(dst + j + m * dst_stride, r##m) +# define S(m) _mm256_storeu_ps(dst + j + m * dst_stride, r##m) switch (K) { - case 1: S(0); break; - case 2: S(0); S(1); break; - case 3: S(0); S(1); S(2); break; - case 4: S(0); S(1); S(2); S(3); break; - case 5: S(0); S(1); S(2); S(3); S(4); break; - case 6: S(0); S(1); S(2); S(3); S(4); S(5); break; - case 7: S(0); S(1); S(2); S(3); S(4); S(5); S(6); break; + case 1: + S(0); + break; + case 2: + S(0); + S(1); + break; + case 3: + S(0); + S(1); + S(2); + break; + case 4: + S(0); + S(1); + S(2); + S(3); + break; + case 5: + S(0); + S(1); + S(2); + S(3); + S(4); + break; + case 6: + S(0); + S(1); + S(2); + S(3); + S(4); + S(5); + break; + case 7: + S(0); + S(1); + S(2); + S(3); + S(4); + S(5); + S(6); + break; } -#undef S +# undef S } } } #else -template +template inline void transpose_16x16_kernel(TDST* dst, TSRC* src, size_t dst_stride, size_t src_stride) { for (size_t i = 0; i < 16; i++) { for (size_t j = 0; j < 16; j++) { @@ -392,7 +604,7 @@ inline void transpose_16x16_kernel(TDST* dst, TSRC* src, size_t dst_stride, size } } -template +template inline void transpose_16xK_kernel(TDST* dst, TSRC* src, size_t K, size_t dst_stride, size_t src_stride) { for (size_t i = 0; i < K; i++) { for (size_t j = 0; j < 16; j++) { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp index 2895a272b982b5..7df2e2371a843a 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/brgemm_kernel.cpp @@ -4,11 +4,12 @@ #include 
"brgemm_kernel.hpp" -#include "dnnl_extension_utils.h" -#include "utils/cpu_utils.hpp" #include #include +#include "dnnl_extension_utils.h" +#include "utils/cpu_utils.hpp" + using namespace dnnl::impl::cpu::x64; using namespace dnnl::impl; using namespace dnnl::impl::cpu::x64::matmul; @@ -100,8 +101,9 @@ BrgemmKernel::BrgemmKernel(size_t M, brgemmCtx.M = M_; brgemmCtx.N = N_; brgemmCtx.K = K_; - brgemmCtx.LDA = k ? K_blk : (is_avx_f16_only ? K : lda); // f16 use f32 internally - brgemmCtx.LDB = (!is_f32 || b_transposed) ? rnd_up(N, N_blk) : ldb; // bf16/fp16/b_transposed needs copy + brgemmCtx.LDA = k ? K_blk : (is_avx_f16_only ? K : lda); // f16 use f32 internally + brgemmCtx.LDB = + (!is_f32 || b_transposed) ? rnd_up(N, N_blk) : ldb; // bf16/fp16/b_transposed needs copy brgemmCtx.LDC = ldc; brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::ElementTypeToDataType(srcType)); brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::ElementTypeToDataType(weiType)); @@ -158,8 +160,8 @@ const size_t BrgemmKernel::get_scratch_b_size() const { } void BrgemmKernel::init_brgemm(brgemmCtx& ctx, - std::unique_ptr& brgKernel, - bool use_amx) { + std::unique_ptr& brgKernel, + bool use_amx) { brgemm_desc_t brgDesc; const bool is_int8 = @@ -208,7 +210,8 @@ void BrgemmKernel::init_brgemm(brgemmCtx& ctx, brgattr.max_bs = 1; brgattr.wary_tail_read = false; brgattr.hint_innermost_loop = brgemm_innermost_undef; - // if b_accumulate is true, it means we want c+=a*b. jit_brgemm_amx_uker_base_t::load_accumulators can support this using tileload(c) without postops + // if b_accumulate is true, it means we want c+=a*b. jit_brgemm_amx_uker_base_t::load_accumulators can support + // this using tileload(c) without postops brgattr.use_uker = true; brgattr.use_interleave_stores = true; brgattr.hint_prefetching = brgemm_kernel_prefetching_t::brgemm_prf1; @@ -248,7 +251,7 @@ void BrgemmKernel::init_brgemm_copy_a( brgCopyKernelConf.K_tail = K_tail; brgCopyKernelConf.K_blk = K_blk; brgCopyKernelConf.use_buffer_a_tail_only = false; - //padding K tail to K_blk, LDA is the stride for target tensor + // padding K tail to K_blk, LDA is the stride for target tensor brgCopyKernelConf.LDA = LDA; brgCopyKernelConf.has_zero_point_b = false; brgCopyKernelConf.s8s8_compensation_required = false; @@ -258,9 +261,13 @@ void BrgemmKernel::init_brgemm_copy_a( brgCopyKernelConf.copy_A_src_stride = copy_A_src_stride; // copy_a_kernel assumes that in/out tensor has same data type except f16 // copy_a_kernel has special path for f16: assuming input(f16) -> output(f32) - brgCopyKernelConf.a_dt_sz = is_avx_f16_only ? sizeof(ov::float16) : DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); + brgCopyKernelConf.a_dt_sz = is_avx_f16_only + ? sizeof(ov::float16) + : DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); // copied A has the same precision of original - brgCopyKernelConf.tr_a_dt_sz = is_avx_f16_only ? sizeof(float) : DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); + brgCopyKernelConf.tr_a_dt_sz = + is_avx_f16_only ? sizeof(float) + : DnnlExtensionUtils::sizeOfDataType(static_cast(dt_in0)); brgCopyKernelConf.transposed_A = transpose; brgCopyKernelConf.isa = is_avx_f16_only ? avx512_core_fp16 : avx512_core_amx; @@ -284,7 +291,7 @@ void BrgemmKernel::init_brgemm_copy_b( brgCopyKernelConf.wei_dt = is_avx_f16_only ? dnnl_data_type_t::dnnl_f32 : dt_in1; brgCopyKernelConf.orig_wei_dt = dt_in1; brgCopyKernelConf.wei_n_blk = N_blk; - brgCopyKernelConf.wei_tag = transpose ? dnnl_ba : dnnl_ab; + brgCopyKernelConf.wei_tag = transpose ? 
dnnl_ba : dnnl_ab; brgCopyKernelConf.copy_B_wei_stride = copy_B_wei_stride; brgCopyKernelConf.transposed_B = transpose; @@ -298,10 +305,14 @@ void BrgemmKernel::init_brgemm_copy_b( brgCopyKernelConf.K_tail = 0; brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; // f16 is computed by upconverting. in(f16) -> out(f32) - brgCopyKernelConf.b_dt_sz = is_avx_f16_only ? sizeof(ov::float16) : - DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); - brgCopyKernelConf.tr_b_dt_sz = is_avx_f16_only ? sizeof(float) : - DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.b_dt_sz = + is_avx_f16_only + ? sizeof(ov::float16) + : DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.tr_b_dt_sz = + is_avx_f16_only + ? sizeof(float) + : DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); brgCopyKernelConf.req_wei_vnni_downconvert = false; if (is_with_amx) { @@ -390,12 +401,7 @@ void BrgemmKernel::executeGemm(bool is_M_tail, void* a, void* b, void* c, void* auto weight_ptr = ptr_scartch_b + B_stride; auto C_stride = n * count_N * ov::element::f32.size(); auto out_ptr = ptr_C + C_stride; - callBrgemm(brgemmCtx, - brgKernels[getBrgIdx(mIdx, k, n)], - local_a_ptr, - weight_ptr, - out_ptr, - wsp); + callBrgemm(brgemmCtx, brgKernels[getBrgIdx(mIdx, k, n)], local_a_ptr, weight_ptr, out_ptr, wsp); // stride K, N if body kernel is executed. if (k == 0) { count_K = brgemmCtx.K * brgemmCtx.LDB; diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.cpp index f8b0df611258a7..1d5e81410a0bf3 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.cpp @@ -4,7 +4,6 @@ #include "dft_uni_kernel.hpp" - using namespace dnnl::impl; using namespace dnnl::impl::utils; using namespace dnnl::impl::cpu::x64; @@ -16,7 +15,8 @@ namespace ov { namespace intel_cpu { template -jit_uni_dft_kernel_f32::jit_uni_dft_kernel_f32() : jit_uni_dft_kernel(), jit_generator(jit_name()) {} +jit_uni_dft_kernel_f32::jit_uni_dft_kernel_f32() : jit_uni_dft_kernel(), + jit_generator(jit_name()) {} template void jit_uni_dft_kernel_f32::create_ker() { @@ -115,11 +115,9 @@ template struct jit_uni_dft_kernel_f32; template struct jit_uni_dft_kernel_f32; template struct jit_uni_dft_kernel_f32; - template -jit_uni_fft_kernel_f32::jit_uni_fft_kernel_f32() - : jit_uni_fft_kernel(), - jit_generator(jit_name()) {} +jit_uni_fft_kernel_f32::jit_uni_fft_kernel_f32() : jit_uni_fft_kernel(), + jit_generator(jit_name()) {} template void jit_uni_fft_kernel_f32::create_ker() { diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.hpp index b70c99e5f8a527..095a3db97d2a64 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/dft_uni_kernel.hpp @@ -130,7 +130,6 @@ struct jit_uni_fft_kernel_f32 : public jit_uni_fft_kernel, public dnnl::impl::cp Vmm vmm_data_result = vmm_data_odd_2; - template void loop_process(int step); @@ -138,5 +137,5 @@ struct jit_uni_fft_kernel_f32 : public jit_uni_fft_kernel, public dnnl::impl::cp void move_data(const Xbyak::Xmm& x, const Xbyak::Address& addr, int count); }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git 
a/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.cpp index 5aaefb086f119c..c0de6520b7099c 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.cpp @@ -3,6 +3,7 @@ // #include "gather_uni_kernel.hpp" + #include "openvino/core/except.hpp" using namespace dnnl::impl::cpu; @@ -10,23 +11,52 @@ using namespace dnnl::impl::cpu; namespace ov { namespace intel_cpu { -const unsigned jitGatherKernelBase::shufMask8bitUni[16] = {0x0C080400, 0x80808080, 0x80808080, 0x80808080, 0x0C080400, 0x80808080, 0x80808080, 0x80808080, - 0x0C080400, 0x80808080, 0x80808080, 0x80808080, 0x0C080400, 0x80808080, 0x80808080, 0x80808080}; -const unsigned jitGatherKernelBase::permMask8bitA2[8] = {0, 4, 1, 5, 2, 6, 3, 7}; -const unsigned jitGatherKernelBase::permMask8bitA5[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; - -const unsigned jitGatherKernelBase::shufMask16bitUni[16] = {0x05040100, 0x0D0C0908, 0x80808080, 0x80808080, 0x05040100, 0x0D0C0908, 0x80808080, 0x80808080, - 0x05040100, 0x0D0C0908, 0x80808080, 0x80808080, 0x05040100, 0x0D0C0908, 0x80808080, 0x80808080}; -const unsigned jitGatherKernelBase::permMask16bitA2[8] = {0, 1, 4, 5, 2, 3, 6, 7}; -const unsigned jitGatherKernelBase::permMask16bitA5[16] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15}; +const unsigned jitGatherKernelBase::shufMask8bitUni[16] = {0x0C080400, + 0x80808080, + 0x80808080, + 0x80808080, + 0x0C080400, + 0x80808080, + 0x80808080, + 0x80808080, + 0x0C080400, + 0x80808080, + 0x80808080, + 0x80808080, + 0x0C080400, + 0x80808080, + 0x80808080, + 0x80808080}; +const unsigned jitGatherKernelBase::permMask8bitA2[8] = {0, 4, 1, 5, 2, 6, 3, 7}; +const unsigned jitGatherKernelBase::permMask8bitA5[16] = {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15}; + +const unsigned jitGatherKernelBase::shufMask16bitUni[16] = {0x05040100, + 0x0D0C0908, + 0x80808080, + 0x80808080, + 0x05040100, + 0x0D0C0908, + 0x80808080, + 0x80808080, + 0x05040100, + 0x0D0C0908, + 0x80808080, + 0x80808080, + 0x05040100, + 0x0D0C0908, + 0x80808080, + 0x80808080}; +const unsigned jitGatherKernelBase::permMask16bitA2[8] = {0, 1, 4, 5, 2, 3, 6, 7}; +const unsigned jitGatherKernelBase::permMask16bitA5[16] = {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15}; const unsigned jitGatherKernelBase::incVec[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; #define GET_OFF(field) offsetof(gatherJitExecArgs, field) template -jitUniGatherKernel::jitUniGatherKernel(const jGatherConfParams& jcp) : - jitGatherKernelBase(jcp), x64::jit_generator(jit_name()) { +jitUniGatherKernel::jitUniGatherKernel(const jGatherConfParams& jcp) + : jitGatherKernelBase(jcp), + x64::jit_generator(jit_name()) { vlen = x64::cpu_isa_traits::vlen; dataElPerVec = vlen / jcp.dataTypeSize; idxElPerVec = vlen / indicesTypeSize; @@ -74,7 +104,7 @@ void jitUniGatherKernel::generate() { if (!jcp.dynamicShapes) { mov(regAux1, ptr[regParams + GET_OFF(specIndicesSize)]); uni_vpbroadcastd(vmmSpecIdxSizeB, ptr[regAux1]); - uni_vpslld(vmmSpecIdxSizeB, vmmSpecIdxSizeB, idxTypeShift); // multiply by indices type size. + uni_vpslld(vmmSpecIdxSizeB, vmmSpecIdxSizeB, idxTypeShift); // multiply by indices type size. 
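The uni_vpslld by idxTypeShift above is the usual shift-as-multiply idiom: the kernel's indices are 4-byte values (idxTypeShift is declared as 2 in gather_uni_kernel.hpp), so a left shift by 2 turns an element count into a byte offset. A minimal scalar sketch of the same computation (illustrative helper, not part of the kernel):

    #include <cstdint>

    // Byte offset of the i-th 4-byte index: i << 2 == i * sizeof(int32_t).
    inline uint32_t indexByteOffset(uint32_t i) {
        constexpr uint32_t idxTypeShift = 2;  // log2(sizeof(int32_t))
        return i << idxTypeShift;
    }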
mov(regAux1, ptr[regParams + GET_OFF(specIdxB)]); uni_vmovups(vmmSpecIdxB, ptr[regAux1]); @@ -84,7 +114,7 @@ void jitUniGatherKernel::generate() { uni_vmovups(vmmSrcBeforeAxisSumB, ptr[regAux1]); } - if (jcp.afterAxisSize == 1lu) { // Elementwise case. + if (jcp.afterAxisSize == 1lu) { // Elementwise case. uni_vmovd(reg32SpecIdxSizeB, xmmSpecIdxSizeB); if (jcp.beforeAxisSize != 1lu) { mov(regAux1, ptr[regParams + GET_OFF(axisAndAfterAxisSizeB)]); @@ -98,8 +128,9 @@ void jitUniGatherKernel::generate() { mov(regBetweenBatchAndAxisSize, ptr[regAux1]); mov(regBetweenBatchAndAxisIter, ptr[regParams + GET_OFF(betweenBatchAndAxisIter)]); - if (jcp.specIdxSize < idxElPerVec) { // Short case. - if (jcp.specIdxSize != 1 && jcp.specIdxSize != 2 && jcp.specIdxSize != 4 && jcp.specIdxSize != 8 && jcp.specIdxSize != 16) { + if (jcp.specIdxSize < idxElPerVec) { // Short case. + if (jcp.specIdxSize != 1 && jcp.specIdxSize != 2 && jcp.specIdxSize != 4 && jcp.specIdxSize != 8 && + jcp.specIdxSize != 16) { mov(regAux1, ptr[regParams + GET_OFF(permIdxMask)]); uni_vmovups(vmmPermIdxMask, ptr[regAux1]); } @@ -107,7 +138,7 @@ void jitUniGatherKernel::generate() { mov(regAux1, ptr[regParams + GET_OFF(beforeAxisDiff)]); uni_vmovups(vmmBeforeAxDiffB, ptr[regAux1]); if (jcp.dataTypeSize != 1) - uni_vpslld(vmmBeforeAxDiffB, vmmBeforeAxDiffB, dataTypeShift); // multiply by data type size + uni_vpslld(vmmBeforeAxDiffB, vmmBeforeAxDiffB, dataTypeShift); // multiply by data type size } if (jcp.batchDims > 0lu) { mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); @@ -115,14 +146,14 @@ void jitUniGatherKernel::generate() { } process(true, false); - } else { // Long case. + } else { // Long case. uni_vmovd(reg32IdxIter, xmmSpecIdxB); fillVlenVector(); process(false, false); } - } else { // Blocked case. - if (jcp.afterAxisSize <= idxElPerVec) { // Short case. + } else { // Blocked case. + if (jcp.afterAxisSize <= idxElPerVec) { // Short case. mov(regAux1, ptr[regParams + GET_OFF(afterAxIdxB)]); uni_vmovups(vmmAfterAxisIdxB, ptr[regAux1]); mov(regAux1, ptr[regParams + GET_OFF(afterAxisPermMask)]); @@ -146,18 +177,19 @@ void jitUniGatherKernel::generate() { } const uint64_t specIdxAndAfterAxisSize = jcp.specIdxSize * jcp.afterAxisSize; if (specIdxAndAfterAxisSize != 1 && specIdxAndAfterAxisSize != 2 && specIdxAndAfterAxisSize != 4 && - specIdxAndAfterAxisSize != 8 && specIdxAndAfterAxisSize != 16) { + specIdxAndAfterAxisSize != 8 && specIdxAndAfterAxisSize != 16) { mov(regAux1, ptr[regParams + GET_OFF(beforeAxisPermMask)]); uni_vmovups(vmmBeforeAxPermMask, ptr[regAux1]); } } process(true, true); - } else { // Long case. - OPENVINO_THROW("Gather kernel does not support static shape with after axis size greater than elements in vector."); + } else { // Long case. + OPENVINO_THROW("Gather kernel does not support static shape with after axis size greater than elements " + "in vector."); } } - } else { // Dynamic shapes. + } else { // Dynamic shapes. mov(regAux1, ptr[regParams + GET_OFF(start)]); uni_vpbroadcastd(vmmSpecIdxB, ptr[regAux1]); mov(regAux1, reinterpret_cast(incVec)); @@ -172,8 +204,8 @@ void jitUniGatherKernel::generate() { uni_vroundps(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0x1); uni_vfnmadd231ps(vmmSpecIdxB, vmmSrcBeforeAxisSumB, vAux1); uni_vcvtps2dq(vmmSpecIdxB, vmmSpecIdxB); - uni_vpslld(vmmSpecIdxB, vmmSpecIdxB, idxTypeShift); // multiply by indices type size. - uni_vpslld(vmmSpecIdxSizeB, vmmSpecIdxSizeB, idxTypeShift); // multiply by indices type size. 
+ uni_vpslld(vmmSpecIdxB, vmmSpecIdxB, idxTypeShift); // multiply by indices type size. + uni_vpslld(vmmSpecIdxSizeB, vmmSpecIdxSizeB, idxTypeShift); // multiply by indices type size. uni_vmovd(reg32SpecIdxSizeB, xmmSpecIdxSizeB); mov(regAux1, ptr[regParams + GET_OFF(betweenBatchAndAxisSize)]); @@ -189,7 +221,8 @@ void jitUniGatherKernel::generate() { mov(regAux1, ptr[regParams + GET_OFF(axisAndAfterAxisSizeB)]); uni_vpbroadcastd(vmmAxisAndAfterAxisSizeB, ptr[regAux1]); - // Formula: srcBeforeAxisSum = ((start / specIndicesSize) % betweenBatchAndAxis) * axisAndAfterAxisSize + srcAfterBatchSize * idxBatchSum + // Formula: srcBeforeAxisSum = ((start / specIndicesSize) % betweenBatchAndAxis) * axisAndAfterAxisSize + + // srcAfterBatchSize * idxBatchSum if (jcp.beforeAxisSize != 1lu) { uni_vpmulld(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); @@ -210,28 +243,29 @@ void jitUniGatherKernel::generate() { cmp(regSpecIdxSizeB, vlen); jl(lLessThanVector1, T_NEAR); - uni_vmovd(reg32IdxIter, xmmSpecIdxB); - fillVlenVector(); + uni_vmovd(reg32IdxIter, xmmSpecIdxB); + fillVlenVector(); - process(false, false); - jmp(lE1, T_NEAR); + process(false, false); + jmp(lE1, T_NEAR); L(lLessThanVector1); - mov(regAux1, ptr[regParams + GET_OFF(permIdxMask)]); - uni_vmovups(vmmPermIdxMask, ptr[regAux1]); - if (jcp.beforeAxisSize != 1lu) { - mov(regAux1, ptr[regParams + GET_OFF(beforeAxisDiff)]); - uni_vmovups(vmmBeforeAxDiffB, ptr[regAux1]); - if (jcp.dataTypeSize != 1) - uni_vpslld(vmmBeforeAxDiffB, vmmBeforeAxDiffB, dataTypeShift); // multiply by data type size - } - mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); - uni_vpbroadcastd(vmmSrcAfterBatchSizeB, ptr[regAux1]); + mov(regAux1, ptr[regParams + GET_OFF(permIdxMask)]); + uni_vmovups(vmmPermIdxMask, ptr[regAux1]); + if (jcp.beforeAxisSize != 1lu) { + mov(regAux1, ptr[regParams + GET_OFF(beforeAxisDiff)]); + uni_vmovups(vmmBeforeAxDiffB, ptr[regAux1]); + if (jcp.dataTypeSize != 1) + uni_vpslld(vmmBeforeAxDiffB, vmmBeforeAxDiffB, dataTypeShift); // multiply by data type size + } + mov(regAux1, ptr[regParams + GET_OFF(srcAfterBatchSizeB)]); + uni_vpbroadcastd(vmmSrcAfterBatchSizeB, ptr[regAux1]); - process(true, false); + process(true, false); L(lE1); jmp(lEnd, T_NEAR); } - L(lBlock); { + L(lBlock); + { mov(regAux1, ptr[regParams + GET_OFF(start)]); uni_vpbroadcastd(vmmAfterAxisIdxB, ptr[regAux1]); mov(regAux1, reinterpret_cast(incVec)); @@ -246,40 +280,40 @@ void jitUniGatherKernel::generate() { uni_vroundps(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, 0x1); uni_vfnmadd231ps(vmmAfterAxisIdxB, vmmSrcBeforeAxisSumB, vAux1); uni_vcvtps2dq(vmmAfterAxisIdxB, vmmAfterAxisIdxB); - uni_vpslld(vmmAfterAxisIdxB, vmmAfterAxisIdxB, idxTypeShift); // multiply by indices type size. + uni_vpslld(vmmAfterAxisIdxB, vmmAfterAxisIdxB, idxTypeShift); // multiply by indices type size. 
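The Formula comment above maps a flat output position back to a source shift for the dynamic-shapes path. A scalar model of the same arithmetic (names mirror the kernel's variables; this is a sketch of the formula, not kernel code):

    #include <cstdint>

    // srcBeforeAxisSum for a given flat output position `start`, as stated
    // in the Formula comment above.
    inline uint64_t srcBeforeAxisSum(uint64_t start, uint64_t specIndicesSize,
                                     uint64_t betweenBatchAndAxis,
                                     uint64_t axisAndAfterAxisSize,
                                     uint64_t srcAfterBatchSize,
                                     uint64_t idxBatchSum) {
        return ((start / specIndicesSize) % betweenBatchAndAxis) * axisAndAfterAxisSize +
               srcAfterBatchSize * idxBatchSum;
    }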
Xbyak::Label lLessThanVector2, lTail3, lTail4, lE2; cmp(regAux2, dataElPerVec); jl(lLessThanVector2, T_NEAR); - uni_vmovd(reg32IdxIter, xmmSpecIdxB); - fillVlenVector(); + uni_vmovd(reg32IdxIter, xmmSpecIdxB); + fillVlenVector(); -// process(false, true); - jmp(lE2, T_NEAR); + // process(false, true); + jmp(lE2, T_NEAR); L(lLessThanVector2); - auto& vAux2 = vmmAuxContainer[2]; - // Calculate permute mask - uni_vmovd(xAux0, reg32Aux2); - uni_vpbroadcastd(vAux1, xAux0); - mov(regAux1, reinterpret_cast(&idxElPerVec)); - uni_vpbroadcastd(vAux0, ptr[regAux1]); - uni_vpsubd(vmmAfterAxisPermMask, vAux0, vAux1); - mov(regAux1, reinterpret_cast(incVec)); - uni_vpaddd(vmmAfterAxisPermMask, vmmAfterAxisPermMask, ptr[regAux1]); - for (int i = 0; i < 6; i++) { - if (isa == x64::avx512_core) { - Xbyak::Opmask kMask2 = Xbyak::Opmask(vAux2.getIdx()); - vpcmpgtd(kMask2, vAux0, vmmAfterAxisPermMask); - uni_vpsubd(vmmAfterAxisPermMask | kMask2, vmmAfterAxisPermMask, vAux1); - } else { - vpcmpgtd(vAux2, vAux0, vmmAfterAxisPermMask); - vpandn(vAux2, vAux2, vAux1); - uni_vpsubd(vmmAfterAxisPermMask, vmmAfterAxisPermMask, vAux2); - } + auto& vAux2 = vmmAuxContainer[2]; + // Calculate permute mask + uni_vmovd(xAux0, reg32Aux2); + uni_vpbroadcastd(vAux1, xAux0); + mov(regAux1, reinterpret_cast(&idxElPerVec)); + uni_vpbroadcastd(vAux0, ptr[regAux1]); + uni_vpsubd(vmmAfterAxisPermMask, vAux0, vAux1); + mov(regAux1, reinterpret_cast(incVec)); + uni_vpaddd(vmmAfterAxisPermMask, vmmAfterAxisPermMask, ptr[regAux1]); + for (int i = 0; i < 6; i++) { + if (isa == x64::avx512_core) { + Xbyak::Opmask kMask2 = Xbyak::Opmask(vAux2.getIdx()); + vpcmpgtd(kMask2, vAux0, vmmAfterAxisPermMask); + uni_vpsubd(vmmAfterAxisPermMask | kMask2, vmmAfterAxisPermMask, vAux1); + } else { + vpcmpgtd(vAux2, vAux0, vmmAfterAxisPermMask); + vpandn(vAux2, vAux2, vAux1); + uni_vpsubd(vmmAfterAxisPermMask, vmmAfterAxisPermMask, vAux2); } + } - process(true, true); + process(true, true); L(lE2); } L(lEnd); @@ -323,7 +357,7 @@ void jitUniGatherKernel::normalizeRawIndices(Vmm& vRawIndices, } // Check boundaries. vpcmpgtd(kAuxMask, vmmAxisDim, vRawIndices); - vpcmpd(kDstMask | kAuxMask, vmmZeros, vRawIndices, 2); // 2 - LE + vpcmpd(kDstMask | kAuxMask, vmmZeros, vRawIndices, 2); // 2 - LE // Multiply by type size. 
if (jcp.dataTypeSize > 1) uni_vpslld(vRawIndices, vRawIndices, dataTypeShift); @@ -338,7 +372,7 @@ void jitUniGatherKernel::normWithUpperBound(Vmm& vTarget, Vmm& vMax, template <> void jitUniGatherKernel::normWithUpperBound(Vmm& vTarget, Vmm& vMax, Vmask& kAuxMask) { - vpcmpd(kAuxMask, vMax, vTarget, 2); // 2 -> LE + vpcmpd(kAuxMask, vMax, vTarget, 2); // 2 -> LE uni_vpsubd(vTarget | kAuxMask, vTarget, vMax); } @@ -359,77 +393,77 @@ void jitUniGatherKernel::calcSrcShiftLong(Vmm* vAuxPool, bool shiftFi add(regIdxIter, vlen); cmp(regIdxIter, regSpecIdxSizeB); jge(lIdxStride, T_NEAR); + if (jcp.batchDims > 0lu) { + uni_vpaddd(vDstShifts, vmmIdxBatchSumB, vmmSpecIdxB); + uni_vmovd(reg32Aux1, xmmAuxContainer[vDstShifts.getIdx()]); + } else { + uni_vmovd(reg32Aux1, xmmSpecIdxB); + } + vmovdqu(vDstShifts, ptr[regIndices + regAux1]); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + jmp(lExit, T_NEAR); + L(lIdxStride); + sub(regIdxIter, regSpecIdxSizeB); + vpcmpeqd(kDstMask, vAux0, vAux0); + if (shiftFirst) { + vpcmpgtd(vAux0, vmmSpecIdxSizeB, vmmSpecIdxB); + vpandn(vAux1, vAux0, vmmSpecIdxSizeB); + uni_vpsubd(vAux1, vmmSpecIdxB, vAux1); + if (jcp.batchDims > 0lu) + uni_vpaddd(vAux1, vmmIdxBatchSumB, vAux1); + uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vmmSpecIdxSizeB); + } else { if (jcp.batchDims > 0lu) { - uni_vpaddd(vDstShifts, vmmIdxBatchSumB, vmmSpecIdxB); - uni_vmovd(reg32Aux1, xmmAuxContainer[vDstShifts.getIdx()]); + uni_vpaddd(vAux0, vmmIdxBatchSumB, vmmSpecIdxB); + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); } else { - uni_vmovd(reg32Aux1, xmmSpecIdxB); + uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); } - vmovdqu(vDstShifts, ptr[regIndices + regAux1]); normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); - if (jcp.beforeAxisSize != 1lu) - uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); - jmp(lExit, T_NEAR); - L(lIdxStride); - sub(regIdxIter, regSpecIdxSizeB); - vpcmpeqd(kDstMask, vAux0, vAux0); - if (shiftFirst) { - vpcmpgtd(vAux0, vmmSpecIdxSizeB, vmmSpecIdxB); - vpandn(vAux1, vAux0, vmmSpecIdxSizeB); - uni_vpsubd(vAux1, vmmSpecIdxB, vAux1); - if (jcp.batchDims > 0lu) - uni_vpaddd(vAux1, vmmIdxBatchSumB, vAux1); - uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vmmSpecIdxSizeB); - } else { - if (jcp.batchDims > 0lu) { - uni_vpaddd(vAux0, vmmIdxBatchSumB, vmmSpecIdxB); - uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); - } else { - uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); - } - normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); - uni_vpbroadcastd(vAux0, xmmSpecIdxB); - vpcmpgtd(vAux1, vAux0, vmmSpecIdxB); - vpandn(vAux0, vAux1, vmmSpecIdxSizeB); - uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vAux0); + uni_vpbroadcastd(vAux0, xmmSpecIdxB); + vpcmpgtd(vAux1, vAux0, vmmSpecIdxB); + vpandn(vAux0, vAux1, vmmSpecIdxSizeB); + uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vAux0); - if (jcp.beforeAxisSize != 1lu) { - uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); - vpandn(vAux0, vAux1, vmmAxisAndAfterAxisSizeB); - uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vAux0); - } + if (jcp.beforeAxisSize != 1lu) { + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + vpandn(vAux0, vAux1, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vAux0); } + } - if (jcp.batchDims > 0lu) { - Xbyak::Label l1; - inc(regBetweenBatchAndAxisIter); - cmp(regBetweenBatchAndAxisIter, 
regBetweenBatchAndAxisSize); - jl(l1, T_NEAR); - mov(regBetweenBatchAndAxisIter, 0); - if (shiftFirst) { - uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); - vpandn(vDstShifts, vAux0, vmmSpecIdxSizeB); - uni_vpaddd(vAux1, vAux1, vDstShifts); - } else { - vpandn(vAux0, vAux1, vmmSpecIdxSizeB); - uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vAux0); - } - L(l1); + if (jcp.batchDims > 0lu) { + Xbyak::Label l1; + inc(regBetweenBatchAndAxisIter); + cmp(regBetweenBatchAndAxisIter, regBetweenBatchAndAxisSize); + jl(l1, T_NEAR); + mov(regBetweenBatchAndAxisIter, 0); + if (shiftFirst) { + uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); + vpandn(vDstShifts, vAux0, vmmSpecIdxSizeB); + uni_vpaddd(vAux1, vAux1, vDstShifts); + } else { + vpandn(vAux0, vAux1, vmmSpecIdxSizeB); + uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vAux0); } + L(l1); + } - if (shiftFirst) { - uniVpGatherDd(vDstShifts, ptr[regIndices + vAux1], kDstMask); - normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (shiftFirst) { + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux1], kDstMask); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); - if (jcp.beforeAxisSize != 1lu) { - vpandn(vAux0, vAux0, vmmAxisAndAfterAxisSizeB); - uni_vpaddd(vAux0, vAux0, vmmSrcBeforeAxisSumB); - uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + if (jcp.beforeAxisSize != 1lu) { + vpandn(vAux0, vAux0, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vAux0, vAux0, vmmSrcBeforeAxisSumB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); - uni_vpaddd(vDstShifts, vDstShifts, vAux0); - } + uni_vpaddd(vDstShifts, vDstShifts, vAux0); } + } L(lExit); } @@ -451,81 +485,81 @@ void jitUniGatherKernel::calcSrcShiftLong(Vmm* vAuxPool, bool add(regIdxIter, vlen); cmp(regIdxIter, regSpecIdxSizeB); jge(lIdxStride, T_NEAR); + if (jcp.batchDims > 0lu) { + uni_vpaddd(vDstShifts, vmmIdxBatchSumB, vmmSpecIdxB); + uni_vmovd(reg32Aux1, xmmAuxContainer[vDstShifts.getIdx()]); + } else { + uni_vmovd(reg32Aux1, xmmSpecIdxB); + } + vmovdqu64(vDstShifts, ptr[regIndices + regAux1]); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (jcp.beforeAxisSize != 1lu) + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + jmp(lExit, T_NEAR); + L(lIdxStride); + sub(regIdxIter, regSpecIdxSizeB); + vpcmpeqd(kDstMask, vDstShifts, vDstShifts); + if (shiftFirst) { + vpcmpd(kAuxMask1, vmmSpecIdxSizeB, vmmSpecIdxB, 2); // 2 -> LE if (jcp.batchDims > 0lu) { - uni_vpaddd(vDstShifts, vmmIdxBatchSumB, vmmSpecIdxB); - uni_vmovd(reg32Aux1, xmmAuxContainer[vDstShifts.getIdx()]); + uni_vpaddd(vAux1, vmmIdxBatchSumB, vmmSpecIdxB); + uni_vpsubd(vAux1 | kAuxMask1, vAux1, vmmSpecIdxSizeB); } else { - uni_vmovd(reg32Aux1, xmmSpecIdxB); + uni_vmovups(vAux1, vmmSpecIdxB); + uni_vpsubd(vAux1 | kAuxMask1, vmmSpecIdxB, vmmSpecIdxSizeB); } - vmovdqu64(vDstShifts, ptr[regIndices + regAux1]); - normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); - if (jcp.beforeAxisSize != 1lu) - uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); - jmp(lExit, T_NEAR); - L(lIdxStride); - sub(regIdxIter, regSpecIdxSizeB); - vpcmpeqd(kDstMask, vDstShifts, vDstShifts); - if (shiftFirst) { - vpcmpd(kAuxMask1, vmmSpecIdxSizeB, vmmSpecIdxB, 2); // 2 -> LE - if (jcp.batchDims > 0lu) { - uni_vpaddd(vAux1, vmmIdxBatchSumB, vmmSpecIdxB); - uni_vpsubd(vAux1 | kAuxMask1, vAux1, vmmSpecIdxSizeB); - } else { - uni_vmovups(vAux1, vmmSpecIdxB); - uni_vpsubd(vAux1 | kAuxMask1, vmmSpecIdxB, vmmSpecIdxSizeB); - } - uni_vpsubd(vmmSpecIdxB, 
vmmSpecIdxB, vmmSpecIdxSizeB); + uni_vpsubd(vmmSpecIdxB, vmmSpecIdxB, vmmSpecIdxSizeB); + } else { + if (jcp.batchDims > 0lu) { + uni_vpaddd(vAux0, vmmIdxBatchSumB, vmmSpecIdxB); + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); } else { - if (jcp.batchDims > 0lu) { - uni_vpaddd(vAux0, vmmIdxBatchSumB, vmmSpecIdxB); - uniVpGatherDd(vDstShifts, ptr[regIndices + vAux0], kDstMask); - } else { - uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); - } - normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + uniVpGatherDd(vDstShifts, ptr[regIndices + vmmSpecIdxB], kDstMask); + } + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); - uni_vpbroadcastd(vAux0, xmmSpecIdxB); - vpcmpd(kAuxMask1, vAux0, vmmSpecIdxB, 2); // 2 -> LE - uni_vpsubd(vmmSpecIdxB | kAuxMask1, vmmSpecIdxB, vmmSpecIdxSizeB); + uni_vpbroadcastd(vAux0, xmmSpecIdxB); + vpcmpd(kAuxMask1, vAux0, vmmSpecIdxB, 2); // 2 -> LE + uni_vpsubd(vmmSpecIdxB | kAuxMask1, vmmSpecIdxB, vmmSpecIdxSizeB); - if (jcp.beforeAxisSize != 1lu) { - uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); - uni_vpaddd(vmmSrcBeforeAxisSumB | kAuxMask1, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); - } + if (jcp.beforeAxisSize != 1lu) { + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + uni_vpaddd(vmmSrcBeforeAxisSumB | kAuxMask1, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); } + } - if (jcp.batchDims > 0lu) { - Xbyak::Label l1; - inc(regBetweenBatchAndAxisIter); - cmp(regBetweenBatchAndAxisIter, regBetweenBatchAndAxisSize); - jl(l1, T_NEAR); - mov(regBetweenBatchAndAxisIter, 0); - if (shiftFirst) { - uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); - uni_vpaddd(vAux1 | kAuxMask1, vAux1, vmmSpecIdxSizeB); - } else { - uni_vpaddd(vmmIdxBatchSumB | kAuxMask1, vmmIdxBatchSumB, vmmSpecIdxSizeB); - } - L(l1); + if (jcp.batchDims > 0lu) { + Xbyak::Label l1; + inc(regBetweenBatchAndAxisIter); + cmp(regBetweenBatchAndAxisIter, regBetweenBatchAndAxisSize); + jl(l1, T_NEAR); + mov(regBetweenBatchAndAxisIter, 0); + if (shiftFirst) { + uni_vpaddd(vmmIdxBatchSumB, vmmIdxBatchSumB, vmmSpecIdxSizeB); + uni_vpaddd(vAux1 | kAuxMask1, vAux1, vmmSpecIdxSizeB); + } else { + uni_vpaddd(vmmIdxBatchSumB | kAuxMask1, vmmIdxBatchSumB, vmmSpecIdxSizeB); } + L(l1); + } - if (shiftFirst) { - uniVpGatherDd(vDstShifts, ptr[regIndices + vAux1], kDstMask); - normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); + if (shiftFirst) { + uniVpGatherDd(vDstShifts, ptr[regIndices + vAux1], kDstMask); + normalizeRawIndices(vDstShifts, kDstMask, kAuxMask0); - if (jcp.beforeAxisSize != 1lu) { - uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); - uni_vpaddd(vDstShifts | kAuxMask1, vDstShifts, vmmAxisAndAfterAxisSizeB); - uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); - } + if (jcp.beforeAxisSize != 1lu) { + uni_vpaddd(vDstShifts, vDstShifts, vmmSrcBeforeAxisSumB); + uni_vpaddd(vDstShifts | kAuxMask1, vDstShifts, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); } + } L(lExit); } template void jitUniGatherKernel::calcSrcShiftLongBlock(Vmm* vAuxPool, bool shiftFirst) { - // Most likely there will no significant performance gain vs memcpy in reference implementation on big blocks after axis, - // therefore no time was invested to this case yet. + // Most likely there will be no significant performance gain vs memcpy in reference implementation on big blocks + // after axis, therefore no time was invested in this case yet. 
OPENVINO_THROW("Unsupported case."); } @@ -541,7 +575,8 @@ void jitUniGatherKernel::calcSrcShiftShort(Vmm* vAuxPool, bool shiftFirst) if (jcp.beforeAxisSize != 1lu) uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmBeforeAxDiffB); // No sense to permute if specIdxSize is one of {1, 2, 4, 8, 16}. 0 is reserved for dynamic case. - if (jcp.specIdxSize != 1 && jcp.specIdxSize != 2 && jcp.specIdxSize != 4 && jcp.specIdxSize != 8 && jcp.specIdxSize != 16) { + if (jcp.specIdxSize != 1 && jcp.specIdxSize != 2 && jcp.specIdxSize != 4 && jcp.specIdxSize != 8 && + jcp.specIdxSize != 16) { vpermd(vmmSpecIdxB, vmmPermIdxMask, vmmSpecIdxB); if (jcp.beforeAxisSize != 1lu) vpermd(vmmBeforeAxDiffB, vmmPermIdxMask, vmmBeforeAxDiffB); @@ -588,7 +623,8 @@ void jitUniGatherKernel::calcSrcShiftShortBlock(Vmm* vAuxPool, bool shiftFi normWithUpperBound(vmmSpecIdxB, vmmSpecIdxSizeB, kAuxMask0); } // No sense to permute if afterAxisSize is one of {1, 2, 4, 8, 16}. 0 is reserved for dynamic case. - if (jcp.afterAxisSize != 1 && jcp.afterAxisSize != 2 && jcp.afterAxisSize != 4 && jcp.afterAxisSize != 8 && jcp.afterAxisSize != 16) { + if (jcp.afterAxisSize != 1 && jcp.afterAxisSize != 2 && jcp.afterAxisSize != 4 && jcp.afterAxisSize != 8 && + jcp.afterAxisSize != 16) { vpermd(vmmAfterAxisIdxB, vmmAfterAxisPermMask, vmmAfterAxisIdxB); if (jcp.specIdxSize != 1) vpermd(vmmSpecIdxDiff, vmmAfterAxisPermMask, vmmSpecIdxDiff); @@ -600,33 +636,33 @@ void jitUniGatherKernel::calcSrcShiftShortBlock(Vmm* vAuxPool, bool shiftFi uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmBeforeAxDiffB); uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); if (specIdxAndAfterAxisSize != 1 && specIdxAndAfterAxisSize != 2 && specIdxAndAfterAxisSize != 4 && - specIdxAndAfterAxisSize != 8 && specIdxAndAfterAxisSize != 16) + specIdxAndAfterAxisSize != 8 && specIdxAndAfterAxisSize != 16) vpermd(vmmBeforeAxDiffB, vmmBeforeAxPermMask, vmmBeforeAxDiffB); } else { Xbyak::Label lBeforeAxStep, lBeforeAxStepEnd; add(rSpecIdxAndAfterAxIterB, idxElPerVec * jcp.dataTypeSize); cmp(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); jl(lBeforeAxStep, T_NEAR); - sub(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); - - vpmulld(vAux0, vmmSpecIdxB, vmmAfterAxisSize); - uni_vpaddd(vAux0, vAux0, vmmAfterAxisIdxB); - Xbyak::Xmm& xAux0 = xmmAuxContainer[vAux0.getIdx()]; - uni_vpbroadcastd(vAux1, xAux0); - if (isa == x64::avx512_core) { - Xbyak::Opmask kMask0 = Xbyak::Opmask(kAuxMask0.getIdx()); - vpcmpgtd(kMask0, vAux1, vAux0); - uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); - uni_vpaddd(vAux1 | kMask0, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); - } else { - vpcmpgtd(vAux1, vAux1, vAux0); - vpand(vAux1, vAux1, vmmAxisAndAfterAxisSizeB); - uni_vpaddd(vAux1, vmmSrcBeforeAxisSumB, vAux1); - } - uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); - jmp(lBeforeAxStepEnd); - L(lBeforeAxStep); + sub(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); + + vpmulld(vAux0, vmmSpecIdxB, vmmAfterAxisSize); + uni_vpaddd(vAux0, vAux0, vmmAfterAxisIdxB); + Xbyak::Xmm& xAux0 = xmmAuxContainer[vAux0.getIdx()]; + uni_vpbroadcastd(vAux1, xAux0); + if (isa == x64::avx512_core) { + Xbyak::Opmask kMask0 = Xbyak::Opmask(kAuxMask0.getIdx()); + vpcmpgtd(kMask0, vAux1, vAux0); uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); + uni_vpaddd(vAux1 | kMask0, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + } else { + vpcmpgtd(vAux1, vAux1, vAux0); + vpand(vAux1, vAux1, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vAux1, vmmSrcBeforeAxisSumB, vAux1); + } + 
uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + jmp(lBeforeAxStepEnd); + L(lBeforeAxStep); + uni_vmovups(vAux1, vmmSrcBeforeAxisSumB); L(lBeforeAxStepEnd); } } else { @@ -648,10 +684,10 @@ void jitUniGatherKernel::calcSrcShiftShortBlock(Vmm* vAuxPool, bool shiftFi add(rSpecIdxAndAfterAxIterB, idxElPerVec * jcp.dataTypeSize); cmp(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); jl(lBeforeAxStepEnd1, T_NEAR); - sub(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); + sub(rSpecIdxAndAfterAxIterB, rSpecIdxAndAfterAxSizeB); cmp(rSpecIdxAndAfterAxIterB, 0); jne(lBeforeAxStepEnd1, T_NEAR); - uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); + uni_vpaddd(vmmSrcBeforeAxisSumB, vmmSrcBeforeAxisSumB, vmmAxisAndAfterAxisSizeB); L(lBeforeAxStepEnd1); } } @@ -689,15 +725,15 @@ void jitUniGatherKernel::process(bool isShortIdx, bool blocked) { Xbyak::Label lTailProc, lEndProc; cmp(regWorkAmount, dataElPerVec); jl(lTailProc, T_NEAR); - if (jcp.dataTypeSize == 4) - process32b(isShortIdx, blocked); - else if (jcp.dataTypeSize == 2) - process16b(isShortIdx, blocked); - else if (jcp.dataTypeSize == 1) - process8b(isShortIdx, blocked); + if (jcp.dataTypeSize == 4) + process32b(isShortIdx, blocked); + else if (jcp.dataTypeSize == 2) + process16b(isShortIdx, blocked); + else if (jcp.dataTypeSize == 1) + process8b(isShortIdx, blocked); jmp(lEndProc, T_NEAR); L(lTailProc); - tail(isShortIdx, false, blocked); + tail(isShortIdx, false, blocked); L(lEndProc); } @@ -735,11 +771,11 @@ void jitUniGatherKernel::process16b(bool isShortIdx, bool blocked) { if (isa == x64::avx512_core) { vPermMask = vmmAuxContainer[7]; vShufMask = vmmAuxContainer[8]; - vBuff0 = vmmAuxContainer[9]; + vBuff0 = vmmAuxContainer[9]; } else { vPermMask = vmmAuxContainer[1]; vShufMask = vmmAuxContainer[4]; - vBuff0 = vmmAuxContainer[5]; + vBuff0 = vmmAuxContainer[5]; } mov(regAux1, reinterpret_cast(shufMask16bitUni)); @@ -799,13 +835,13 @@ void jitUniGatherKernel::process8b(bool isShortIdx, bool blocked) { if (isa == x64::avx512_core) { vPermMask = vmmAuxContainer[7]; vShufMask = vmmAuxContainer[8]; - vBuff0 = vmmAuxContainer[9]; - vBuff1 = vmmAuxContainer[10]; + vBuff0 = vmmAuxContainer[9]; + vBuff1 = vmmAuxContainer[10]; } else { vPermMask = vmmAuxContainer[1]; vShufMask = vmmAuxContainer[4]; - vBuff0 = vmmAuxContainer[5]; - vBuff1 = vmmAuxContainer[6]; + vBuff0 = vmmAuxContainer[5]; + vBuff1 = vmmAuxContainer[6]; } mov(regAux1, reinterpret_cast(shufMask8bitUni)); uni_vmovups(vShufMask, ptr[regAux1]); @@ -951,24 +987,30 @@ void jitUniGatherKernel::tail(bool isShortIdx, bool shiftFirst, bool blocke } template <> -void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, Vmm& vmmAux, const Xbyak::Reg64& rWorkRest, - const Xbyak::Reg64& rAux0, const Xbyak::Reg64& rAux1) { +void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, + Vmm& vmmAux, + const Xbyak::Reg64& rWorkRest, + const Xbyak::Reg64& rAux0, + const Xbyak::Reg64& rAux1) { Xbyak::Label lKmov; Xbyak::Reg32 rOnes(rAux1.getIdx()); mov(rOnes, 0x0000FFFF); cmp(rWorkRest, idxElPerVec); jge(lKmov); - Xbyak::Reg8 rShift(Xbyak::Operand::CL); - mov(rShift, idxElPerVec); - sub(rShift, rWorkRest); - shr(rOnes, rShift); + Xbyak::Reg8 rShift(Xbyak::Operand::CL); + mov(rShift, idxElPerVec); + sub(rShift, rWorkRest); + shr(rOnes, rShift); L(lKmov); kmovw(kDstMask, rOnes); } template <> -void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, Vmm& vAux, const Xbyak::Reg64& rWorkRest, - const Xbyak::Reg64& rAux0, const 
Xbyak::Reg64& rAux1) { +void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, + Vmm& vAux, + const Xbyak::Reg64& rWorkRest, + const Xbyak::Reg64& rAux0, + const Xbyak::Reg64& rAux1) { Xbyak::Label lEnd; mov(rAux0, rWorkRest); Xbyak::Reg32 rOnes(rAux1.getIdx()); @@ -990,7 +1032,10 @@ void jitUniGatherKernel::fillRestWorkMask(Vmask& kDstMask, Vmm& vAux, } template -void jitUniGatherKernel::storeVectorPart(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rToStoreCounter, Vmm& vmmSrc, Vmm& vAux) { +void jitUniGatherKernel::storeVectorPart(const Xbyak::Reg64& rDst, + const Xbyak::Reg64& rToStoreCounter, + Vmm& vmmSrc, + Vmm& vAux) { Xbyak::Label lEnd; Xbyak::Xmm xAux(vAux.getIdx()); for (size_t j = 0; j < vlen / vlenXmm; j++) { @@ -1025,7 +1070,7 @@ void jitUniGatherKernel::fillVlenVector() { template <> void jitUniGatherKernel::fillVlenVector() { vpcmpeqd(vmmVecLenB, vmmVecLenB, vmmVecLenB); - vpsrld(vmmVecLenB, vmmVecLenB, 31); // Right shift to 1. + vpsrld(vmmVecLenB, vmmVecLenB, 31); // Right shift to 1. uni_vpslld(vmmVecLenB, vmmVecLenB, 5); // Left shift to 32. } @@ -1047,5 +1092,5 @@ bool jitUniGatherKernel::isSupportedConfiguration(uint64_t afterAxisSize) { template struct jitUniGatherKernel; template struct jitUniGatherKernel; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.hpp index 765efb17d091e2..de8cda30d06499 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/gather_uni_kernel.hpp @@ -19,12 +19,11 @@ // 1 | X | X | X | X | X | X | //-------------------------------------------------------------- - #pragma once -#include "jit_kernel_base.hpp" #include "cpu/x64/jit_generator.hpp" #include "dnnl_types.h" +#include "jit_kernel_base.hpp" namespace ov { namespace intel_cpu { @@ -71,8 +70,8 @@ struct gatherJitExecArgs { }; struct jitGatherKernelBase { - void (*ker_)(const gatherJitExecArgs *); - void operator()(const gatherJitExecArgs *args) { + void (*ker_)(const gatherJitExecArgs*); + void operator()(const gatherJitExecArgs* args) { assert(ker_); ker_(args); } @@ -120,8 +119,10 @@ struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu:: bool isSupportedConfiguration(uint64_t afterAxisSize) override; protected: - using Vmm = typename dnnl::impl::utils::conditional::type; - using Vmask = typename dnnl::impl::utils::conditional::type; + using Vmm = + typename dnnl::impl::utils::conditional::type; + using Vmask = + typename dnnl::impl::utils::conditional::type; static const uint32_t vlenXmm = dnnl::impl::cpu::x64::cpu_isa_traits::vlen; static const uint32_t indicesTypeSize = sizeof(uint32_t); static const uint8_t idxTypeShift = 2; @@ -155,7 +156,8 @@ struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu:: // Masks pool. Do not use k0 with gather instruction! Vmask masksContainer[8] = {Vmask(0), Vmask(1), Vmask(2), Vmask(3), Vmask(4), Vmask(5), Vmask(6), Vmask(7)}; // Auxiliary pool. - Vmm vmmAuxContainer[12] = {Vmm(0), Vmm(1), Vmm(2), Vmm(3), Vmm(4), Vmm(5), Vmm(6), /*AVX5*/ Vmm(16), Vmm(17), Vmm(18), Vmm(19), Vmm(20)}; + Vmm vmmAuxContainer[12] = + {Vmm(0), Vmm(1), Vmm(2), Vmm(3), Vmm(4), Vmm(5), Vmm(6), /*AVX5*/ Vmm(16), Vmm(17), Vmm(18), Vmm(19), Vmm(20)}; // Common. 
Vmm vmmZeros = Vmm(7); Vmm vmmSrcBeforeAxisSumB = Vmm(8); @@ -165,13 +167,13 @@ struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu:: Vmm vmmAxisAndAfterAxisSizeB = Vmm(12); // Only short. - Vmm vmmSrcAfterBatchSizeB = Vmm(13); - Vmm vmmPermIdxMask = Vmm(14); + Vmm vmmSrcAfterBatchSizeB = Vmm(13); + Vmm vmmPermIdxMask = Vmm(14); Vmm& vmmBeforeAxDiffB = vmmAxisAndAfterAxisSizeB; // Blocked short. Vmm& vmmSpecIdxDiff = vmmAuxContainer[4]; Vmm& vmmAfterAxisSize = vmmAuxContainer[5]; - Vmm vmmAfterAxisIdxB = Vmm(15); + Vmm vmmAfterAxisIdxB = Vmm(15); Vmm& vmmAfterAxisPermMask = vmmPermIdxMask; Vmm& vmmBeforeAxPermMask = vmmAuxContainer[6]; // Only long. @@ -179,13 +181,13 @@ struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu:: Vmm vmmIdxBatchSumB = Vmm(14); // XMM - Xbyak::Xmm xmmAuxContainer[6] = {Xbyak::Xmm(0), Xbyak::Xmm(1), Xbyak::Xmm(2), Xbyak::Xmm(3), Xbyak::Xmm(4), Xbyak::Xmm(16)}; + Xbyak::Xmm xmmAuxContainer[6] = + {Xbyak::Xmm(0), Xbyak::Xmm(1), Xbyak::Xmm(2), Xbyak::Xmm(3), Xbyak::Xmm(4), Xbyak::Xmm(16)}; Xbyak::Xmm xmmZeros = Xbyak::Xmm(vmmZeros.getIdx()); Xbyak::Xmm xmmSrcBeforeAxisSum = Xbyak::Xmm(vmmSrcBeforeAxisSumB.getIdx()); Xbyak::Xmm xmmSpecIdxSizeB = Xbyak::Xmm(vmmSpecIdxSizeB.getIdx()); Xbyak::Xmm xmmSpecIdxB = Xbyak::Xmm(vmmSpecIdxB.getIdx()); - void calcSrcShiftLong(Vmm* vAuxPool, bool shiftFirst = true); void calcSrcShiftLongBlock(Vmm* vAuxPool, bool shiftFirst = true); void calcSrcShiftShort(Vmm* vAuxPool, bool shiftFirst = true); @@ -199,7 +201,11 @@ struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu:: // Aux functions. void normalizeRawIndices(Vmm& rawIndices, Vmask& dstMask, Vmask& aux); void normWithUpperBound(Vmm& vTarget, Vmm& vMax, Vmask& kAuxMask); - void fillRestWorkMask(Vmask& kMask, Vmm& vAux, const Xbyak::Reg64& rWorkRest, const Xbyak::Reg64& rAux0, const Xbyak::Reg64& rAux1); + void fillRestWorkMask(Vmask& kMask, + Vmm& vAux, + const Xbyak::Reg64& rWorkRest, + const Xbyak::Reg64& rAux0, + const Xbyak::Reg64& rAux1); void storeVectorPart(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rToStoreCounter, Vmm& vmmSrc, Vmm& vAux); void uniVpGatherDd(Vmm& vDst, const Xbyak::Address& srcAddr, Vmask& vMask); void fillVlenVector(); @@ -208,5 +214,5 @@ struct jitUniGatherKernel : public jitGatherKernelBase, public dnnl::impl::cpu:: const unsigned* permMask16bitUni; }; -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.cpp index d91688689b86c0..908de00cbb0534 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.cpp @@ -13,8 +13,8 @@ namespace kernel { #define GET_OFF(field) offsetof(GridSamplesKernelExecArgs, field) template -GridSampleKernel::GridSampleKernel(const GridSampleKernelConfParams& jcp) : - GridSampleKernelBase(jit_name(), jcp, isa) { +GridSampleKernel::GridSampleKernel(const GridSampleKernelConfParams& jcp) + : GridSampleKernelBase(jit_name(), jcp, isa) { vlen = x64::cpu_isa_traits::vlen; dataTypeSize = jcp.inDataPrc.size(); gridTypeSize = jcp.gridPrc.size(); @@ -39,15 +39,15 @@ void GridSampleKernel::generate() { this->preamble(); registersPool = RegistersPool::create(isa, {rax, rcx, rsp, rdi, k0}); - regSrc = getReg64(); + regSrc = getReg64(); regGrid = getReg64(); - regDst = getReg64(); + regDst = getReg64(); 
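The getReg64() calls above draw from the RegistersPool created at the top of generate(), which hands out free registers while keeping reserved ones (rax, rcx, rsp, rdi, k0) off limits. A toy sketch of the allocation idea only (hypothetical simplified pool, not the plugin's actual RegistersPool API):

    #include <bitset>
    #include <initializer_list>
    #include <stdexcept>

    // Toy register pool: reserved ids are marked busy up front, the rest
    // are handed out on demand and returned on release.
    class ToyRegPool {
        std::bitset<16> busy_;  // one bit per GPR id
    public:
        explicit ToyRegPool(std::initializer_list<int> reserved) {
            for (int r : reserved)
                busy_.set(r);
        }
        int acquire() {  // returns the id of a free register
            for (int i = 0; i < 16; ++i)
                if (!busy_.test(i)) {
                    busy_.set(i);
                    return i;
                }
            throw std::runtime_error("no free registers");
        }
        void release(int i) {
            busy_.reset(i);
        }
    };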
regSrcChannelStepB = getReg64(); regDstChannelStepB = getReg64(); - mov(regSrc, ptr[regParams + GET_OFF(src)]); + mov(regSrc, ptr[regParams + GET_OFF(src)]); mov(regGrid, ptr[regParams + GET_OFF(grid)]); - mov(regDst, ptr[regParams + GET_OFF(dst)]); + mov(regDst, ptr[regParams + GET_OFF(dst)]); mov(regSrcChannelStepB, ptr[regParams + GET_OFF(srcChannelStepB)]); mov(regDstChannelStepB, ptr[regParams + GET_OFF(dstChannelStepB)]); @@ -82,7 +82,7 @@ void GridSampleKernel::initVectors() { if (one_of(jcp.interpolationMode, GridSampleInterpolationMode::BICUBIC, GridSampleInterpolationMode::BILINEAR)) { vOnesF = getVmm(); - mov(r32Aux, 0x3f800000); // 1.f + mov(r32Aux, 0x3f800000); // 1.f vpbroadcastd(vOnesF, r32Aux); } @@ -96,11 +96,11 @@ void GridSampleKernel::initVectors() { uni_vpbroadcastd(vHDenormCoefF, ptr[rAux]); } else { vHalfF = getVmm(); - mov(r32Aux, 0x3f000000); // 0.5f + mov(r32Aux, 0x3f000000); // 0.5f vpbroadcastd(vHalfF, r32Aux); } - static const unsigned gridPermMask[16] = { 0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15 }; + static const unsigned gridPermMask[16] = {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15}; mov(rAux, reinterpret_cast(gridPermMask)); vGridPermMask = getVmm(); uni_vmovups(vGridPermMask, ptr[rAux]); @@ -141,24 +141,24 @@ void GridSampleKernel::initVectors() { if (jcp.interpolationMode == GridSampleInterpolationMode::BICUBIC) { vConst_0_75 = getVmm(); - mov(r32Aux, 0xbf400000); // -0.75f + mov(r32Aux, 0xbf400000); // -0.75f vpbroadcastd(vConst_0_75, r32Aux); vConst_1_25 = getVmm(); - mov(r32Aux, 0x3fa00000); // 1.25f + mov(r32Aux, 0x3fa00000); // 1.25f vpbroadcastd(vConst_1_25, r32Aux); vConst_1_50 = getVmm(); - mov(r32Aux, 0x3fc00000); // 1.5f + mov(r32Aux, 0x3fc00000); // 1.5f vpbroadcastd(vConst_1_50, r32Aux); vConst_2_00 = getVmm(); - mov(r32Aux, 0x40000000); // 2.0f + mov(r32Aux, 0x40000000); // 2.0f vpbroadcastd(vConst_2_00, r32Aux); vConst_2_25 = getVmm(); - mov(r32Aux, 0x40100000); // 2.25f + mov(r32Aux, 0x40100000); // 2.25f vpbroadcastd(vConst_2_25, r32Aux); } } -template // Works for AVX2, AVX, SSE41 +template // Works for AVX2, AVX, SSE41 void GridSampleKernel::initVectors() { auto rAux = getReg64(); @@ -167,9 +167,10 @@ void GridSampleKernel::initVectors() { uni_vmovups(vSrcWidthF, ptr[rAux]); if (one_of(jcp.interpolationMode, GridSampleInterpolationMode::BILINEAR, GridSampleInterpolationMode::NEAREST) || - (jcp.interpolationMode == GridSampleInterpolationMode::BICUBIC && (jcp.paddingMode == GridSamplePaddingMode::REFLECTION || - (jcp.paddingMode == GridSamplePaddingMode::BORDER && !jcp.alignCorners) || - jcp.paddingMode == GridSamplePaddingMode::ZEROS)) ) { + (jcp.interpolationMode == GridSampleInterpolationMode::BICUBIC && + (jcp.paddingMode == GridSamplePaddingMode::REFLECTION || + (jcp.paddingMode == GridSamplePaddingMode::BORDER && !jcp.alignCorners) || + jcp.paddingMode == GridSamplePaddingMode::ZEROS))) { vSrcHeightF = getVmm(); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); uni_vmovups(vSrcHeightF, ptr[rAux]); @@ -184,7 +185,8 @@ void GridSampleKernel::initVectors() { if (jcp.interpolationMode != GridSampleInterpolationMode::BICUBIC) { if (one_of(jcp.paddingMode, GridSamplePaddingMode::BORDER, GridSamplePaddingMode::ZEROS) && - ((isa == x64::avx2 && jcp.interpolationMode == GridSampleInterpolationMode::NEAREST) || one_of(isa, x64::avx, x64::sse41))) { + ((isa == x64::avx2 && jcp.interpolationMode == GridSampleInterpolationMode::NEAREST) || + one_of(isa, x64::avx, x64::sse41))) { vZeros = getVmm(); uni_vpxor(vZeros, vZeros, 
vZeros); } @@ -193,20 +195,21 @@ void GridSampleKernel::initVectors() { mov(rAux, ptr[regParams + GET_OFF(wDenormCoefF)]); vWDenormCoefF = getVmm(); uni_vmovups(vWDenormCoefF, ptr[rAux]); - if (!(jcp.interpolationMode == GridSampleInterpolationMode::BILINEAR && jcp.paddingMode == GridSamplePaddingMode::ZEROS)) { + if (!(jcp.interpolationMode == GridSampleInterpolationMode::BILINEAR && + jcp.paddingMode == GridSamplePaddingMode::ZEROS)) { mov(rAux, ptr[regParams + GET_OFF(hDenormCoefF)]); vHDenormCoefF = getVmm(); uni_vmovups(vHDenormCoefF, ptr[rAux]); } } else { - static const float halfArr[8] = { 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f }; + static const float halfArr[8] = {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; mov(rAux, reinterpret_cast(halfArr)); vHalfF = getVmm(); uni_vmovups(vHalfF, ptr[rAux]); } if (isa == x64::avx2 && jcp.interpolationMode == GridSampleInterpolationMode::NEAREST) { - static const unsigned gridPermMask[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; + static const unsigned gridPermMask[8] = {0, 2, 4, 6, 1, 3, 5, 7}; mov(rAux, reinterpret_cast(gridPermMask)); vGridPermMask = getVmm(); uni_vmovups(vGridPermMask, ptr[rAux]); @@ -214,15 +217,16 @@ void GridSampleKernel::initVectors() { } if (jcp.interpolationMode == GridSampleInterpolationMode::BICUBIC || - (jcp.interpolationMode == GridSampleInterpolationMode::BILINEAR && jcp.paddingMode != GridSamplePaddingMode::ZEROS)) { - static const float onesArr[8] = { 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f }; + (jcp.interpolationMode == GridSampleInterpolationMode::BILINEAR && + jcp.paddingMode != GridSamplePaddingMode::ZEROS)) { + static const float onesArr[8] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; mov(rAux, reinterpret_cast(onesArr)); vOnesF = getVmm(); uni_vmovups(vOnesF, ptr[rAux]); } } -template // Works for AVX512, AVX2, AVX, SSE41 +template // Works for AVX512, AVX2, AVX, SSE41 void GridSampleKernel::process() { regWorkAmount = getReg64(); @@ -244,12 +248,12 @@ void GridSampleKernel::process() { spatialLoop(); if (jcp.dynamicShapes) { - add(regSrc, ptr[regParams + GET_OFF(srcBatchStepB)]); + add(regSrc, ptr[regParams + GET_OFF(srcBatchStepB)]); } else { add(regSrc, jcp.srcBatchStepB); } add(regGrid, ptr[regParams + GET_OFF(gridBatchStepB)]); - add(regDst, ptr[regParams + GET_OFF(dstBatchStepB)]); + add(regDst, ptr[regParams + GET_OFF(dstBatchStepB)]); if (jcp.dynamicBatch) { dec(regBatch); @@ -259,7 +263,7 @@ void GridSampleKernel::process() { } } -template // Works for AVX512, AVX2, AVX, SSE41 +template // Works for AVX512, AVX2, AVX, SSE41 void GridSampleKernel::spatialLoop() { auto vHCoord = getVmm(); auto vWCoord = getVmm(); @@ -286,7 +290,7 @@ void GridSampleKernel::spatialLoop() { tail(); } -template // Works for AVX512, AVX2, AVX, SSE41 +template // Works for AVX512, AVX2, AVX, SSE41 void GridSampleKernel::interpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { if (jcp.interpolationMode == GridSampleInterpolationMode::BILINEAR) { bilinearInterpolation(vWCoord, vHCoord, tail); @@ -297,7 +301,7 @@ void GridSampleKernel::interpolation(const Vmm& vWCoord, const Vmm& vHCoord } } -template // Works for AVX512, AVX2, AVX, SSE41 +template // Works for AVX512, AVX2, AVX, SSE41 void GridSampleKernel::tail() { Xbyak::Label lEnd; cmp(regWorkAmount, 0); @@ -311,7 +315,7 @@ void GridSampleKernel::tail() { interpolation(vWCoord, vHCoord, true); if (dataTypeSize > 1) - sal(regWorkAmount, dataTypeShift); // Multiply by source data type size. + sal(regWorkAmount, dataTypeShift); // Multiply by source data type size. 
add(regDst, regWorkAmount); L(lEnd); @@ -319,15 +323,15 @@ void GridSampleKernel::tail() { template <> void GridSampleKernel::getCoordinates(const Vmm& vHCoord, const Vmm& vWCoord) { - vpermd(vWCoord, vGridPermMask, ptr[regGrid]); // Permute to XXXX.XXXX.YYYY.YYYY - vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component + vpermd(vWCoord, vGridPermMask, ptr[regGrid]); // Permute to XXXX.XXXX.YYYY.YYYY + vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component add(regGrid, vlen); auto vAux = getVmm(); - vpermd(vAux, vGridPermMask, ptr[regGrid]); // Permute to XXXX.XXXX.YYYY.YYYY - vshuff64x2(vWCoord, vWCoord, vAux, 0B01000100); // Extract X component - vshuff64x2(vHCoord, vHCoord, vAux, 0B11100100); // Extract Y component + vpermd(vAux, vGridPermMask, ptr[regGrid]); // Permute to XXXX.XXXX.YYYY.YYYY + vshuff64x2(vWCoord, vWCoord, vAux, 0B01000100); // Extract X component + vshuff64x2(vHCoord, vHCoord, vAux, 0B11100100); // Extract Y component add(regGrid, vlen); } @@ -349,19 +353,19 @@ void GridSampleKernel::getCoordinates(const Vmm& vHCoord, const Vmm& uni_vmovups(vPermMask, ptr[rAux]); } - vpermd(vWCoord, vPermMask, ptr[regGrid]); // Permute to XXXX.YYYY - vperm2f128(vHCoord, vHCoord, vWCoord, 0B00000011); // Extract Y component + vpermd(vWCoord, vPermMask, ptr[regGrid]); // Permute to XXXX.YYYY + vperm2f128(vHCoord, vHCoord, vWCoord, 0B00000011); // Extract Y component add(regGrid, vlen); - vpermd(vAux, vPermMask, ptr[regGrid]); // Permute to XXXX.YYYY - vperm2f128(vWCoord, vWCoord, vAux, 0B00100000); // Extract X component - vperm2f128(vHCoord, vHCoord, vAux, 0B00110000); // Extract Y component + vpermd(vAux, vPermMask, ptr[regGrid]); // Permute to XXXX.YYYY + vperm2f128(vWCoord, vWCoord, vAux, 0B00100000); // Extract X component + vperm2f128(vHCoord, vHCoord, vAux, 0B00110000); // Extract Y component add(regGrid, vlen); } -template // Works for AVX, SSE41 +template // Works for AVX, SSE41 void GridSampleKernel::getCoordinates(const Vmm& vHCoord, const Vmm& vWCoord) { auto vAux = getVmm(); Xbyak::Xmm xmmWCoord(vWCoord.getIdx()); @@ -417,12 +421,12 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, auto rAux = getReg64(); mov(rAux, regWorkAmount); - sal(rAux, 0x1); // Multiply by gridShape[3]. + sal(rAux, 0x1); // Multiply by gridShape[3]. 
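The vpermd permutations above unpack the grid buffer, which stores interleaved (x, y) pairs, so that all X lanes land in one vector and all Y lanes in another; the permute mask {0, 2, 4, 6, 8, 10, 12, 14, 1, 3, 5, 7, 9, 11, 13, 15} groups even (X) positions before odd (Y) positions. A scalar sketch of the same deinterleave (illustrative only):

    #include <cstddef>

    // Split n interleaved (x, y) grid pairs into separate X and Y arrays.
    void deinterleaveGrid(const float* grid, float* x, float* y, size_t n) {
        for (size_t i = 0; i < n; ++i) {
            x[i] = grid[2 * i];      // even positions hold X coordinates
            y[i] = grid[2 * i + 1];  // odd positions hold Y coordinates
        }
    }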
cmp(regWorkAmount, dataElPerVec / 2); jl(lRest, T_NEAR); { vpermd(vWCoord, vGridPermMask, ptr[regGrid]); - vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component + vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component add(regGrid, vlen); sub(rAux, dataElPerVec); @@ -433,8 +437,8 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, uni_vmovups((Vmm)vAux | kTailMask, ptr[regGrid]); vpermd(vAux, vGridPermMask, vAux); Xbyak::Ymm ymmAux(vAux.getIdx()); - vshuff64x2(vWCoord, vWCoord, vAux, 0B01000100); // Extract X component - vshuff64x2(vHCoord, vHCoord, vAux, 0B11100100); // Extract Y component + vshuff64x2(vWCoord, vWCoord, vAux, 0B01000100); // Extract X component + vshuff64x2(vHCoord, vHCoord, vAux, 0B11100100); // Extract Y component jmp(lGridShift, T_NEAR); } @@ -443,12 +447,12 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, fillRestWorkMask(kTailMask, rAux); uni_vmovups(vWCoord | kTailMask, ptr[regGrid]); vpermd(vWCoord, vGridPermMask, vWCoord); - vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component + vshuff64x2(vHCoord, vWCoord, vHCoord, 0B11101110); // Extract Y component } L(lGridShift); if (dataTypeSize > 1) - sal(rAux, dataTypeShift); // Multiply by source data type size. + sal(rAux, dataTypeShift); // Multiply by source data type size. add(regGrid, rAux); L(lEnd); @@ -475,36 +479,36 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, const V } mov(rAux, regWorkAmount); - sal(rAux, 0x1); // multiply by gridShape[3] == 2 + sal(rAux, 0x1); // multiply by gridShape[3] == 2 cmp(regWorkAmount, dataElPerVec / 2); jl(lRest, T_NEAR); { - vpermd(vWCoord, vPermMask, ptr[regGrid]); // Permute to XXXX.YYYY - vperm2f128(vHCoord, vHCoord, vWCoord, 0B00000011); // Extract Y component + vpermd(vWCoord, vPermMask, ptr[regGrid]); // Permute to XXXX.YYYY + vperm2f128(vHCoord, vHCoord, vWCoord, 0B00000011); // Extract Y component add(regGrid, vlen); sub(rAux, dataElPerVec); cmp(rAux, 0); jle(lEnd, T_NEAR); - auto vAux = getVmm(); + auto vAux = getVmm(); load(vAux, ptr[regGrid], rAux, dataTypeSize); vpermd(vAux, vPermMask, vAux); - vperm2f128(vWCoord, vWCoord, vAux, 0B00100000); // Extract X component - vperm2f128(vHCoord, vHCoord, vAux, 0B00110000); // Extract Y component + vperm2f128(vWCoord, vWCoord, vAux, 0B00100000); // Extract X component + vperm2f128(vHCoord, vHCoord, vAux, 0B00110000); // Extract Y component jmp(lGridShift, T_NEAR); } L(lRest); { load(vWCoord, ptr[regGrid], rAux, dataTypeSize); - vpermd(vWCoord, vPermMask, vWCoord); // Permute to XXXX.YYYY - vperm2f128(vHCoord, vHCoord, vWCoord, 0B00000011); // Extract Y component + vpermd(vWCoord, vPermMask, vWCoord); // Permute to XXXX.YYYY + vperm2f128(vHCoord, vHCoord, vWCoord, 0B00000011); // Extract Y component } L(lGridShift); if (dataTypeSize > 1) - sal(rAux, dataTypeShift); // Multiply by source data type size. + sal(rAux, dataTypeShift); // Multiply by source data type size. 
add(regGrid, rAux); L(lEnd); @@ -519,7 +523,7 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, const Vm auto rGridRest = getReg64(); mov(rGridRest, regWorkAmount); - sal(rGridRest, 0x1); // multiply by gridShape[3] == 2 + sal(rGridRest, 0x1); // multiply by gridShape[3] == 2 for (size_t i = 0; i < dataElPerVec; i++) { cmp(rGridRest, 0); @@ -566,7 +570,7 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, const auto rAux = getReg64(); mov(rAux, regWorkAmount); - sal(rAux, 0x1); // Multiply by gridShape[3] == 2 + sal(rAux, 0x1); // Multiply by gridShape[3] == 2 cmp(regWorkAmount, dataElPerVec / 2); jl(lRest, T_NEAR); { @@ -584,31 +588,31 @@ void GridSampleKernel::getTailCoordinates(const Vmm& vHCoord, const auto vAux = getVmm(); load(vAux, ptr[regGrid], rAux, dataTypeSize); pshufd(vAux, vAux, 0B11011000); - shufpd(vWCoord, vAux, 0x0); // Extract X component - shufpd(vHCoord, vAux, 0B00000011); // Extract Y component + shufpd(vWCoord, vAux, 0x0); // Extract X component + shufpd(vHCoord, vAux, 0B00000011); // Extract Y component jmp(lGridShift, T_NEAR); L(lHShuf); - shufpd(vHCoord, vHCoord, 0B00000001); // Extract Y component + shufpd(vHCoord, vHCoord, 0B00000001); // Extract Y component jmp(lEnd, T_NEAR); } L(lRest); { load(vWCoord, ptr[regGrid], rAux, dataTypeSize); - pshufd(vWCoord, vWCoord, 0B11011000); // Extract X component - shufpd(vHCoord, vWCoord, 0B00000010); // Extract Y component + pshufd(vWCoord, vWCoord, 0B11011000); // Extract X component + shufpd(vHCoord, vWCoord, 0B00000010); // Extract Y component shufpd(vHCoord, vHCoord, 0B00000001); } L(lGridShift); if (dataTypeSize > 1) - sal(rAux, dataTypeShift); // Multiply by source data type size. + sal(rAux, dataTypeShift); // Multiply by source data type size. add(regGrid, rAux); L(lEnd); } -template // Works for AVX512, AVX2, AVX, SSE41 +template // Works for AVX512, AVX2, AVX, SSE41 void GridSampleKernel::denormalizeRawCoordinates(const Vmm& vWCoord, const Vmm& vHCoord) { if (jcp.alignCorners) { if (vWDenormCoefF.isInitialized()) { @@ -640,7 +644,7 @@ void GridSampleKernel::denormalizeRawCoordinates(const Vmm& vWCoord, const halfHolder = getVmm(); vHalfTmp = halfHolder; static const float halfValues[x64::cpu_isa_traits::vlen / sizeof(float)] = - { 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f }; + {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}; mov(rAux, reinterpret_cast(halfValues)); uni_vmovups(vHalfTmp, ptr[rAux]); } @@ -671,14 +675,14 @@ void GridSampleKernel::denormalizeRawCoordinates(const Vmm& vWCoord, const template <> void GridSampleKernel::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord) { - vcmpps(kDst, vCoord, vSrcWidthF, CMP_LT_PS); // vCoord < vUpperBound - vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros + vcmpps(kDst, vCoord, vSrcWidthF, CMP_LT_PS); // vCoord < vUpperBound + vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros } template <> void GridSampleKernel::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord, const Vmask& kMaskW) { - vcmpps(kDst | kMaskW, vCoord, vSrcHeightF, CMP_LT_PS); // vCoord < vUpperBound - vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros + vcmpps(kDst | kMaskW, vCoord, vSrcHeightF, CMP_LT_PS); // vCoord < vUpperBound + vcmpps(kDst | kDst, vZeros, vCoord, CMP_LE_PS); // vCoord >= vZeros } template <> @@ -692,7 +696,7 @@ void GridSampleKernel::zerosPaddingW(const Vmask& kDst, const Vmm& v auto vAux = 
getVmm(); if (vSrcWidthF.isInitialized()) { - uni_vcmpps(vAux, vWCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF + uni_vcmpps(vAux, vWCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF } else { auto rAux = getReg64(); mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]); @@ -700,8 +704,8 @@ void GridSampleKernel::zerosPaddingW(const Vmask& kDst, const Vmm& v } uni_vpxor(kDst, kDst, kDst); - uni_vcmpps(kDst, kDst, vWCoord, CMP_LE_PS); // vWCoord >= vZeros - uni_vpand(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF + uni_vcmpps(kDst, kDst, vWCoord, CMP_LE_PS); // vWCoord >= vZeros + uni_vpand(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF } template <> @@ -709,18 +713,18 @@ void GridSampleKernel::zerosPaddingH(const Vmask& kDst, const Vmm& v auto vAux = getVmm(); if (vSrcHeightF.isInitialized()) { - uni_vcmpps(vAux, vHCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF + uni_vcmpps(vAux, vHCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF } else { auto rAux = getReg64(); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); - uni_vcmpps(vAux, vHCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF + uni_vcmpps(vAux, vHCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF } uni_vmovups(kDst, kMaskW); - uni_vpand(kDst, kDst, vAux); // vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF + uni_vpand(kDst, kDst, vAux); // vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF uni_vpxor(vAux, vAux, vAux); - uni_vcmpps(vAux, vAux, vHCoord, CMP_LE_PS); // vHCoord >= vZeros - uni_vpand(kDst, kDst, vAux); // vZeros <= vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF + uni_vcmpps(vAux, vAux, vHCoord, CMP_LE_PS); // vHCoord >= vZeros + uni_vpand(kDst, kDst, vAux); // vZeros <= vHCoord < vSrcHeightF && vZeros <= vWCoord < vSrcWidthF } template <> @@ -729,7 +733,7 @@ void GridSampleKernel::zerosPadding(const Vmask& kDst, const Vmm& vH zerosPaddingH(kDst, vHCoord, kDst); } -template // Works for AVX2, AVX +template // Works for AVX2, AVX void GridSampleKernel::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord) { auto vAux = getVmm(); Vmm vZerosTmp; @@ -743,18 +747,18 @@ void GridSampleKernel::zerosPaddingW(const Vmask& kDst, const Vmm& vCoord) } if (vSrcWidthF.isInitialized()) { - uni_vcmpps(vAux, vCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF + uni_vcmpps(vAux, vCoord, vSrcWidthF, CMP_LT_PS); // vWCoord < vSrcWidthF } else { auto rAux = getReg64(); mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]); uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vWCoord < vSrcWidthF } - uni_vcmpps(kDst, vZerosTmp, vCoord, CMP_LE_PS); // vWCoord >= vZeros - uni_vandps(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF + uni_vcmpps(kDst, vZerosTmp, vCoord, CMP_LE_PS); // vWCoord >= vZeros + uni_vandps(kDst, kDst, vAux); // vZeros <= vWCoord < vSrcWidthF } -template // Works for AVX2, AVX +template // Works for AVX2, AVX void GridSampleKernel::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord, const Vmask& kMaskW) { auto vAux = getVmm(); Vmm vZerosTmp; @@ -768,19 +772,19 @@ void GridSampleKernel::zerosPaddingH(const Vmask& kDst, const Vmm& vCoord, } if (vSrcHeightF.isInitialized()) { - uni_vcmpps(vAux, vCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF + uni_vcmpps(vAux, vCoord, vSrcHeightF, CMP_LT_PS); // vHCoord < vSrcHeightF } else { auto rAux = getReg64(); mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); - uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF + uni_vcmpps(vAux, vCoord, ptr[rAux], CMP_LT_PS); // vHCoord < vSrcHeightF } 
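The zerosPaddingW/zerosPaddingH pair above builds a per-lane validity mask: a lane survives only if its coordinate satisfies vZeros <= coord < size in both dimensions, as the trailing comments state. Scalar equivalent of the combined predicate (a sketch, not kernel code):

    // A sampling point contributes to the output only if it lies strictly
    // inside the source plane; out-of-range lanes are zeroed by the mask.
    inline bool insideSource(float x, float y, float srcW, float srcH) {
        return 0.0f <= x && x < srcW && 0.0f <= y && y < srcH;
    }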
uni_vandps(kDst, kMaskW, vAux); - uni_vcmpps(vAux, vZerosTmp, vCoord, CMP_LE_PS); // vHCoord >= vZeros + uni_vcmpps(vAux, vZerosTmp, vCoord, CMP_LE_PS); // vHCoord >= vZeros uni_vandps(kDst, kDst, vAux); } -template // Works for AVX2, AVX +template // Works for AVX2, AVX void GridSampleKernel::zerosPadding(const Vmask& kDst, const Vmm& vHCoord, const Vmm& vWCoord) { bool releaseZeroVec = false; if (!vZeros.isInitialized()) { @@ -799,11 +803,14 @@ void GridSampleKernel::zerosPadding(const Vmask& kDst, const Vmm& vHCoord, template <> void GridSampleKernel::borderPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) { - vrangeps(vCoordDst, vCoordOrigin, dim == coord::w ? vSrcWidthSub1F : vSrcHeightSub1F, 0x0); // vWCoord >= vSrcWidthF - vrangeps(vCoordDst, vCoordDst, vZeros, 0x1); // vWCoord < vZeros + vrangeps(vCoordDst, + vCoordOrigin, + dim == coord::w ? vSrcWidthSub1F : vSrcHeightSub1F, + 0x0); // vWCoord >= vSrcWidthF + vrangeps(vCoordDst, vCoordDst, vZeros, 0x1); // vWCoord < vZeros } -template // Works for AVX2, AVX, SSE41 +template // Works for AVX2, AVX, SSE41 void GridSampleKernel::borderPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) { auto rAux = getReg64(); auto vAux = getVmm(); @@ -836,7 +843,7 @@ void GridSampleKernel::borderPadding(const Vmm& vCoordDst, const Vmm& vCoor uni_vaddps(vCoordDst, vCoordDst, vAux); if (vZeros.isInitialized()) { - uni_vcmpps(vAux, vCoordDst, vZeros, 0x6); // vCoord >= vZeros + uni_vcmpps(vAux, vCoordDst, vZeros, 0x6); // vCoord >= vZeros } else { if (isa == x64::sse41) { if (!vAux1.isInitialized()) { @@ -844,27 +851,29 @@ void GridSampleKernel::borderPadding(const Vmm& vCoordDst, const Vmm& vCoor vSub1F = vAux1; } uni_vpxor(vSub1F, vSub1F, vSub1F); - uni_vcmpps(vAux, vCoordDst, vSub1F, 0x6); // vCoord >= vZeros + uni_vcmpps(vAux, vCoordDst, vSub1F, 0x6); // vCoord >= vZeros } else { uni_vpxor(vAux, vAux, vAux); - uni_vcmpps(vAux, vCoordDst, vAux, 0x6); // vCoord >= vZeros + uni_vcmpps(vAux, vCoordDst, vAux, 0x6); // vCoord >= vZeros } } uni_vandps(vCoordDst, vCoordDst, vAux); } template <> -void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) { +void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, + const Vmm& vCoordOrigin, + const coord dim) { auto vAux = getVmm(); auto kAux = getMask(); const auto& vSrcDimMul2Sub1F = dim == coord::w ? vSrcWidthMul2Sub1F : vSrcHeightMul2Sub1F; if (jcp.alignCorners) { // abs(x) % D21 - uni_vandps(vCoordDst, vCoordOrigin, vAbsMask); // abs(x) + uni_vandps(vCoordDst, vCoordOrigin, vAbsMask); // abs(x) uni_vdivps(vAux, vCoordDst, vSrcDimMul2Sub1F); - uni_vroundps(vAux, vAux, 0x3); // Truncation - uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2Sub1F); // abs(x) % D21 + uni_vroundps(vAux, vAux, 0x3); // Truncation + uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2Sub1F); // abs(x) % D21 // Check that the result does not exceed the divisor. 
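// NB: the check below is needed because the modulo above is computed in
// floating point as m = x - trunc(x / d) * d; when x / d rounds up to the
// next representable value, m can come out equal to d, so the kernel masks
// that case back into range. Scalar sketch of the whole alignCorners
// reflection (illustrative; the period in vSrcDimMul2Sub1F is assumed to
// hold 2 * (dim - 1)):
//
//     #include <cmath>
//     float reflect_align_corners(float x, float dim) {
//         const float period = 2.f * (dim - 1.f);
//         float m = std::fabs(x);
//         m -= std::truncf(m / period) * period;  // fp modulo, may overshoot
//         if (m >= period) m -= period;           // the "divisor" check below
//         if (m >= dim) m = period - m;           // mirror back toward [0, dim)
//         return m;
//     }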
vcmpps(kAux, vSrcDimMul2Sub1F, vCoordDst, CMP_LE_PS); @@ -876,12 +885,12 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, if (vCoordDst.getIdx() != vCoordOrigin.getIdx()) uni_vmovups(vCoordDst, vCoordOrigin); uni_vdivps(vAux, vCoordDst, vSrcDimMul2F); - uni_vroundps(vAux, vAux, 0x3); // Truncation - uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // x % D2 - uni_vaddps(vCoordDst, vCoordDst, vSrcDimMul2F); // x % D2 + D2 + uni_vroundps(vAux, vAux, 0x3); // Truncation + uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // x % D2 + uni_vaddps(vCoordDst, vCoordDst, vSrcDimMul2F); // x % D2 + D2 uni_vdivps(vAux, vCoordDst, vSrcDimMul2F); - uni_vroundps(vAux, vAux, 0x3); // Truncation - uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // (x % D2 + D2) % D2 + uni_vroundps(vAux, vAux, 0x3); // Truncation + uni_vfnmadd231ps(vCoordDst, vAux, vSrcDimMul2F); // (x % D2 + D2) % D2 // Check that the result does not exceed the divisor. vcmpps(kAux, vSrcDimMul2F, vCoordDst, CMP_LE_PS); @@ -890,13 +899,13 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, } uni_vsubps(vAux, vSrcDimMul2Sub1F, vCoordDst); - vcmpps(kAux, dim == coord::w ? vSrcWidthF : vSrcHeightF, vCoordDst, CMP_LE_PS); // vCoordDst >= vSrcDimF + vcmpps(kAux, dim == coord::w ? vSrcWidthF : vSrcHeightF, vCoordDst, CMP_LE_PS); // vCoordDst >= vSrcDimF uni_vmovups(vCoordDst | kAux, vAux); } -template // Works for AVX2, AVX, SSE41 +template // Works for AVX2, AVX, SSE41 void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& vCoordOrigin, const coord dim) { - auto rAux = getReg64(); + auto rAux = getReg64(); auto vAux0 = getVmm(); auto vAux1 = getVmm(); @@ -904,14 +913,15 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& v // D21 = (Dim - 1) * 2 if (jcp.alignCorners) { // x' = abs(x) % D21 - D21 - static const unsigned absMask[8] = { 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff }; - if (isa ==x64::sse41) { - static const unsigned *absPtr = absMask + (reinterpret_cast(absMask) % 16) / sizeof(unsigned); + static const unsigned absMask[8] = + {0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff, 0x7fffffff}; + if (isa == x64::sse41) { + static const unsigned* absPtr = absMask + (reinterpret_cast(absMask) % 16) / sizeof(unsigned); mov(rAux, reinterpret_cast(absPtr)); } else { mov(rAux, reinterpret_cast(absMask)); } - uni_vandps(vCoordDst, vCoordOrigin, ptr[rAux]); // abs(x) + uni_vandps(vCoordDst, vCoordOrigin, ptr[rAux]); // abs(x) Vmm vMul2Sub1; if (dim == coord::w) { @@ -932,8 +942,8 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& v } } uni_vdivps(vAux0, vCoordDst, vMul2Sub1); - uni_vroundps(vAux0, vAux0, 0x3); // Truncation - uni_vfnmadd231ps(vCoordDst, vAux0, vMul2Sub1); // abs(x) % D21 + uni_vroundps(vAux0, vAux0, 0x3); // Truncation + uni_vfnmadd231ps(vCoordDst, vAux0, vMul2Sub1); // abs(x) % D21 // Check that the result does not exceed the divisor. 
uni_vcmpps(vAux0, vCoordDst, vMul2Sub1, CMP_LT_PS); @@ -942,7 +952,7 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& v uni_vcmpps(vAux0, vAux0, vCoordDst, CMP_LE_PS); uni_vandps(vCoordDst, vCoordDst, vAux0); - uni_vsubps(vAux0, vCoordDst, vMul2Sub1); // abs(x) % D21 - D21 + uni_vsubps(vAux0, vCoordDst, vMul2Sub1); // abs(x) % D21 - D21 } else { // x' = (x % D2 + D2) % D2 - D21 if (vCoordDst.getIdx() != vCoordOrigin.getIdx()) @@ -966,12 +976,12 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& v } } uni_vdivps(vAux0, vCoordOrigin, vMul2); - uni_vroundps(vAux0, vAux0, 0x3); // Truncation - uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // x % D2 - uni_vaddps(vCoordDst, vCoordDst, vMul2); // x % D2 + D2 + uni_vroundps(vAux0, vAux0, 0x3); // Truncation + uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // x % D2 + uni_vaddps(vCoordDst, vCoordDst, vMul2); // x % D2 + D2 uni_vdivps(vAux0, vCoordDst, vMul2); - uni_vroundps(vAux0, vAux0, 0x3); // Truncation - uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // (x % D2 + D2) % D2 + uni_vroundps(vAux0, vAux0, 0x3); // Truncation + uni_vfnmadd231ps(vCoordDst, vAux0, vMul2); // (x % D2 + D2) % D2 // Check that the result does not exceed the divisor. uni_vcmpps(vAux0, vCoordDst, vMul2, CMP_LT_PS); @@ -1002,20 +1012,20 @@ void GridSampleKernel::reflectionPadding(const Vmm& vCoordDst, const Vmm& v uni_vcmpps(vAux1, vCoordDst, vSrcWidthF, CMP_LT_PS); // vCoordDst < vUpperBound } else { mov(rAux, ptr[regParams + GET_OFF(srcWidthF)]); - uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound + uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound } } else { if (vSrcHeightF.isInitialized()) { - uni_vcmpps(vAux1, vCoordDst, vSrcHeightF, CMP_LT_PS); // vCoordDst < vUpperBound + uni_vcmpps(vAux1, vCoordDst, vSrcHeightF, CMP_LT_PS); // vCoordDst < vUpperBound } else { mov(rAux, ptr[regParams + GET_OFF(srcHeightF)]); - uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound + uni_vcmpps(vAux1, vCoordDst, ptr[rAux], CMP_LT_PS); // vCoordDst < vUpperBound } } uni_vandps(vCoordDst, vCoordDst, vAux1); uni_vandnps(vAux1, vAux1, vAux0); - uni_vsubps(vCoordDst, vCoordDst, vAux1); // set -x' for vCoordDst >= Dim + uni_vsubps(vCoordDst, vCoordDst, vAux1); // set -x' for vCoordDst >= Dim } template <> @@ -1045,12 +1055,13 @@ void GridSampleKernel::bicubicCoefficients(const Vmm& vCoef, c template <> void GridSampleKernel::bicubicCoefficients(const Vmm& vCoef, const Vmm& vDDim, const uint8_t idx) { - static const size_t elPerVec = x64::cpu_isa_traits::vlen / sizeof(float);; - static const float const_0_75[elPerVec] = { -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f }; - static const float const_1_25[elPerVec] = { 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f }; - static const float const_1_50[elPerVec] = { 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f }; - static const float const_2_00[elPerVec] = { 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f }; - static const float const_2_25[elPerVec] = { 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f }; + static const size_t elPerVec = x64::cpu_isa_traits::vlen / sizeof(float); + ; + static const float const_0_75[elPerVec] = {-0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f}; + static const float const_1_25[elPerVec] = {1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f}; + static const float const_1_50[elPerVec] = {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}; + static const 
float const_2_00[elPerVec] = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}; + static const float const_2_25[elPerVec] = {2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f}; auto rAux = getReg64(); @@ -1088,11 +1099,11 @@ void GridSampleKernel::bicubicCoefficients(const Vmm& vCoef, const Vm template <> void GridSampleKernel::bicubicCoefficients(const Vmm& vCoef, const Vmm& vDDim, const uint8_t idx) { static const size_t elPerVec = x64::cpu_isa_traits::vlen / sizeof(float); - static const float const_0_75[elPerVec] = { -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f }; - static const float const_1_25[elPerVec] = { 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f }; - static const float const_1_50[elPerVec] = { 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f }; - static const float const_2_00[elPerVec] = { 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f }; - static const float const_2_25[elPerVec] = { 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f }; + static const float const_0_75[elPerVec] = {-0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f}; + static const float const_1_25[elPerVec] = {1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f}; + static const float const_1_50[elPerVec] = {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}; + static const float const_2_00[elPerVec] = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}; + static const float const_2_25[elPerVec] = {2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f}; auto rAux = getReg64(); auto vAux = getVmm(); @@ -1136,11 +1147,11 @@ template <> void GridSampleKernel::bicubicCoefficients(const Vmm& vCoef, const Vmm& vDDim, const uint8_t idx) { static const size_t elToAllocate = 2 * x64::cpu_isa_traits::vlen / sizeof(float); // Allocation with a margin for address alignment. - static const float c_0_75[elToAllocate] = { -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f }; - static const float c_1_25[elToAllocate] = { 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f }; - static const float c_1_50[elToAllocate] = { 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f }; - static const float c_2_00[elToAllocate] = { 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f }; - static const float c_2_25[elToAllocate] = { 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f }; + static const float c_0_75[elToAllocate] = {-0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f, -0.75f}; + static const float c_1_25[elToAllocate] = {1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f, 1.25f}; + static const float c_1_50[elToAllocate] = {1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f, 1.5f}; + static const float c_2_00[elToAllocate] = {2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}; + static const float c_2_25[elToAllocate] = {2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f, 2.25f}; // Address alignment for XMM. 
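// NB: continuing the comment above — the c_* tables are over-allocated
// (elToAllocate is twice the data actually needed) so that a 16-byte-aligned
// window can be carved out of them for SSE loads; the pointer is then
// advanced past the base address' misalignment. A conventional round-up
// sketch (illustrative, not the kernel's exact arithmetic):
//
//     #include <cstdint>
//     const float* align16(const float* p) {   // buffer must have 16B margin
//         auto a = reinterpret_cast<std::uintptr_t>(p);
//         return reinterpret_cast<const float*>((a + 15u) & ~std::uintptr_t{15});
//     }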
static const float* const_0_75 = c_0_75 + (reinterpret_cast(c_0_75) % 16) / sizeof(float); static const float* const_1_25 = c_1_25 + (reinterpret_cast(c_1_25) % 16) / sizeof(float); @@ -1193,15 +1204,15 @@ void GridSampleKernel::bicubicCoefficients(const Vmm& vCoef, const V } } -template // Works for AVX512, AVX2, AVX, SSE41 +template // Works for AVX512, AVX2, AVX, SSE41 void GridSampleKernel::nearestInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { const auto& vSrcShift = vWCoord; - const auto& vAux = vHCoord; - auto kGatherMask = getMask(); - auto kAuxMask = getMask(); + const auto& vAux = vHCoord; + auto kGatherMask = getMask(); + auto kAuxMask = getMask(); - uni_vroundps(vWCoord, vWCoord, 0x0); // Round near - uni_vroundps(vHCoord, vHCoord, 0x0); // Round near + uni_vroundps(vWCoord, vWCoord, 0x0); // Round near + uni_vroundps(vHCoord, vHCoord, 0x0); // Round near bool useMask = false, zeroFill = false; if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { @@ -1272,15 +1283,15 @@ template <> void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { const auto& vDX = vWCoord; const auto& vDY = vHCoord; - auto shift00 = getVmm(); - auto shift01 = getVmm(); - auto shift10 = getVmm(); - auto shift11 = getVmm(); - auto vAux = getVmm(); + auto shift00 = getVmm(); + auto shift01 = getVmm(); + auto shift10 = getVmm(); + auto shift11 = getVmm(); + auto vAux = getVmm(); RegistersPool::Reg kMask00, kMask01, kMask10, kMask11; - uni_vroundps(shift00, vWCoord, 0x1); // Round floor - uni_vroundps(shift01, vHCoord, 0x1); // Round floor + uni_vroundps(shift00, vWCoord, 0x1); // Round floor + uni_vroundps(shift01, vHCoord, 0x1); // Round floor uni_vsubps(vDX, vWCoord, shift00); uni_vsubps(vDY, vHCoord, shift01); uni_vaddps(shift10, shift00, vOnesF); @@ -1294,10 +1305,10 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoor kMask10 = getMask(); kMask11 = getMask(); - zerosPadding(kMask00, shift01, shift00); // (y; x) - zerosPadding(kMask01, shift01, shift10); // (y; x + 1) - zerosPadding(kMask11, shift11, shift10); // (y + 1; x + 1) - zerosPadding(kMask10, shift11, shift00); // (y + 1; x) + zerosPadding(kMask00, shift01, shift00); // (y; x) + zerosPadding(kMask01, shift01, shift10); // (y; x + 1) + zerosPadding(kMask11, shift11, shift10); // (y + 1; x + 1) + zerosPadding(kMask10, shift11, shift00); // (y + 1; x) hwShiftPs2dq(shift00, shift01, shift00, vSrcWidthF); uni_vpaddd(shift01, shift00, vDataTypeSizeB); @@ -1330,8 +1341,8 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoor // PER CHANNEL LOOP Xbyak::Label lChannelLoopBegin, lChannelLoopEnd; RegistersPool::Reg rChannel; - auto rSrcTmp = getReg64(); - auto rDstTmp = getReg64(); + auto rSrcTmp = getReg64(); + auto rDstTmp = getReg64(); mov(rSrcTmp, regSrc); mov(rDstTmp, regDst); @@ -1349,11 +1360,11 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoor if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { kmovw(kAuxMask, kMask00); } - gatherdd(vQ0, rSrcTmp, shift00, kAuxMask, useMask, zeroFill); // v00 -> vQ0 + gatherdd(vQ0, rSrcTmp, shift00, kAuxMask, useMask, zeroFill); // v00 -> vQ0 if (jcp.inDataPrc == ov::element::i32) { uni_vcvtdq2ps(vQ0, vQ0); } - uni_vfmsub213ps(vQ0, vDX, vQ0); // q0 = -(v00 - dx * v00) + uni_vfmsub213ps(vQ0, vDX, vQ0); // q0 = -(v00 - dx * v00) // (y; x + 1) if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { @@ -1363,7 +1374,7 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoor if (jcp.inDataPrc == 
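// NB: the fused-multiply pattern in this loop evaluates the horizontal lerp
// without materializing (1 - dx): vfmsub213ps forms -(v00 - dx * v00), and
// the vfmsub231ps below folds in dx * v01 while cancelling the sign, giving
// q0 = v00 * (1 - dx) + v01 * dx. Scalar sketch of the full bilinear blend
// (illustrative):
//
//     float bilinear(float v00, float v01, float v10, float v11,
//                    float dx, float dy) {
//         const float q0 = v00 + dx * (v01 - v00);  // top row
//         const float q1 = v10 + dx * (v11 - v10);  // bottom row
//         return q0 + dy * (q1 - q0);               // final vfmadd132ps step
//     }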
ov::element::i32) { uni_vcvtdq2ps(vAux, vAux); } - uni_vfmsub231ps(vQ0, vAux, vDX); // q0 = -q0 + dx * v01 + uni_vfmsub231ps(vQ0, vAux, vDX); // q0 = -q0 + dx * v01 // (y + 1; x + 1) if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { @@ -1383,14 +1394,14 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoor uni_vcvtdq2ps(vQ1, vQ1); } - uni_vfmsub213ps(vQ1, vDX, vQ1); // q1 = -(v10 - dx * v10) - uni_vfmsub231ps(vQ1, vAux, vDX); // q1 = -q1 + dx * v11 + uni_vfmsub213ps(vQ1, vDX, vQ1); // q1 = -(v10 - dx * v10) + uni_vfmsub231ps(vQ1, vAux, vDX); // q1 = -q1 + dx * v11 // Res = q0 + dy * (q1 - q0) uni_vsubps(vQ1, vQ1, vQ0); uni_vfmadd132ps(vQ1, vQ0, vDY); if (jcp.inDataPrc == ov::element::i32) { - uni_vroundps(vQ1, vQ1, 0x3); // Truncation + uni_vroundps(vQ1, vQ1, 0x3); // Truncation uni_vcvtps2dq(vQ1, vQ1); } @@ -1410,20 +1421,20 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoor } } -template // Works for AVX2, AVX, SSE41 +template // Works for AVX2, AVX, SSE41 void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { auto vWRound = getVmm(); auto vHRound = getVmm(); - auto& vDX = vWCoord; - auto& vDY = vHCoord; - auto vAux = getVmm(); + auto& vDX = vWCoord; + auto& vDY = vHCoord; + auto vAux = getVmm(); Vmm shift00, shift01, shift10, shift11; RegistersPool::Reg shift10Holder, shift11Holder; // For ZEROS padding only. RegistersPool::Reg vMask00, vMask01, vMask10, vMask11; - uni_vroundps(vWRound, vWCoord, 0x1); // Round floor - uni_vroundps(vHRound, vHCoord, 0x1); // Round floor + uni_vroundps(vWRound, vWCoord, 0x1); // Round floor + uni_vroundps(vHRound, vHCoord, 0x1); // Round floor uni_vsubps(vDX, vDX, vWRound); uni_vsubps(vDY, vDY, vHRound); @@ -1444,9 +1455,9 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& useMask = zeroFill = true; { auto rAux = getReg64(); - static const float onesArr[8] = { 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f }; - if (isa ==x64::sse41) { - static const float *onesPtr = onesArr + (reinterpret_cast(onesArr) % 16) / sizeof(float); + static const float onesArr[8] = {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}; + if (isa == x64::sse41) { + static const float* onesPtr = onesArr + (reinterpret_cast(onesArr) % 16) / sizeof(float); mov(rAux, reinterpret_cast(onesPtr)); } else { mov(rAux, reinterpret_cast(onesArr)); @@ -1463,10 +1474,10 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& uni_vaddps(vMask00, vWRound, vAux); uni_vaddps(vAux, vAux, vHRound); - zerosPadding(vMask01, vHRound, vMask00); // (y; x + 1) - zerosPadding(vMask10, vAux, vWRound); // (y + 1; x) - zerosPadding(vMask11, vAux, vMask00); // (y + 1; x + 1) - zerosPadding(vMask00, vHRound, vWRound); // (y; x) + zerosPadding(vMask01, vHRound, vMask00); // (y; x + 1) + zerosPadding(vMask10, vAux, vWRound); // (y + 1; x) + zerosPadding(vMask11, vAux, vMask00); // (y + 1; x + 1) + zerosPadding(vMask00, vHRound, vWRound); // (y; x) hwShiftPs2dq(shift00, vHRound, vWRound, vSrcWidthF); } else if (jcp.paddingMode == GridSamplePaddingMode::BORDER) { @@ -1490,17 +1501,17 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& } auto vGatherMask = getVmm(); - auto vQ0 = getVmm(); - auto vQ1 = getVmm(); + auto vQ0 = getVmm(); + auto vQ1 = getVmm(); // PER CHANNEL LOOP Xbyak::Label lChannelLoopBegin, lChannelLoopEnd; RegistersPool::Reg rChannel; - auto rSrcTmp = getReg64(); - auto rDstTmp = getReg64(); + auto rSrcTmp = getReg64(); + auto rDstTmp = getReg64(); auto rTypeSize = 
getReg64(); - mov(rSrcTmp, regSrc); - mov(rDstTmp, regDst); + mov(rSrcTmp, regSrc); + mov(rDstTmp, regDst); mov(rTypeSize, ptr[regParams + GET_OFF(dataTypeSize)]); for (uint64_t ch = 0; ch < jcp.cannelNum; ch++) { @@ -1517,12 +1528,17 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& if (jcp.paddingMode == GridSamplePaddingMode::ZEROS && isa == x64::avx2) { uni_vmovups(vGatherMask, vMask00); } - gatherdd(vQ0, rSrcTmp, shift00, (isa == x64::avx2 || !vMask00.isInitialized()) ? vGatherMask : vMask00, useMask, zeroFill); // v00 -> vQ0 + gatherdd(vQ0, + rSrcTmp, + shift00, + (isa == x64::avx2 || !vMask00.isInitialized()) ? vGatherMask : vMask00, + useMask, + zeroFill); // v00 -> vQ0 if (jcp.inDataPrc == ov::element::i32) { uni_vcvtdq2ps(vQ0, vQ0); } if (isa == x64::avx2) { - uni_vfmsub213ps(vQ0, vDX, vQ0); // q0 = -(v00 - dx * v00) + uni_vfmsub213ps(vQ0, vDX, vQ0); // q0 = -(v00 - dx * v00) } else { uni_vmulps(vGatherMask, vQ0, vDX); uni_vsubps(vQ0, vQ0, vGatherMask); @@ -1534,13 +1550,17 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& if (isa == x64::avx2) uni_vmovups(vGatherMask, vMask01); } - gatherdd(vAux, rSrcTmp, jcp.paddingMode != GridSamplePaddingMode::ZEROS ? shift01 : shift10, - (isa == x64::avx2 || !vMask01.isInitialized()) ? vGatherMask : vMask01, useMask, zeroFill); + gatherdd(vAux, + rSrcTmp, + jcp.paddingMode != GridSamplePaddingMode::ZEROS ? shift01 : shift10, + (isa == x64::avx2 || !vMask01.isInitialized()) ? vGatherMask : vMask01, + useMask, + zeroFill); if (jcp.inDataPrc == ov::element::i32) { uni_vcvtdq2ps(vAux, vAux); } if (isa == x64::avx2) { - uni_vfmsub231ps(vQ0, vAux, vDX); // q0 = -q0 + dx * v01 + uni_vfmsub231ps(vQ0, vAux, vDX); // q0 = -q0 + dx * v01 } else { uni_vmulps(vAux, vAux, vDX); uni_vaddps(vQ0, vQ0, vAux); @@ -1556,8 +1576,12 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& if (isa == x64::avx2) uni_vmovups(vGatherMask, vMask11); } - gatherdd(vAux, rSrcTmp, jcp.paddingMode != GridSamplePaddingMode::ZEROS ? shift11 : shift10, - (isa == x64::avx2 || !vMask11.isInitialized()) ? vGatherMask : vMask11, useMask, zeroFill); + gatherdd(vAux, + rSrcTmp, + jcp.paddingMode != GridSamplePaddingMode::ZEROS ? shift11 : shift10, + (isa == x64::avx2 || !vMask11.isInitialized()) ? vGatherMask : vMask11, + useMask, + zeroFill); if (jcp.inDataPrc == ov::element::i32) { uni_vcvtdq2ps(vAux, vAux); } @@ -1568,7 +1592,12 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& if (isa == x64::avx2) uni_vmovups(vGatherMask, vMask10); } - gatherdd(vQ1, rSrcTmp, shift10, (isa == x64::avx2 || !vMask10.isInitialized()) ? vGatherMask : vMask10, useMask, zeroFill); + gatherdd(vQ1, + rSrcTmp, + shift10, + (isa == x64::avx2 || !vMask10.isInitialized()) ? 
vGatherMask : vMask10, + useMask, + zeroFill); if (jcp.inDataPrc == ov::element::i32) { uni_vcvtdq2ps(vQ1, vQ1); } @@ -1585,13 +1614,13 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& uni_vmovups(vQ1, vGatherMask); } } - uni_vfmsub231ps(vQ1, vAux, vDX); // q1 = -q1 + dx * v11 + uni_vfmsub231ps(vQ1, vAux, vDX); // q1 = -q1 + dx * v11 // Res = q0 + dy * (q1 - q0) uni_vsubps(vQ1, vQ1, vQ0); uni_vfmadd132ps(vQ1, vQ0, vDY); if (jcp.inDataPrc == ov::element::i32) { - uni_vroundps(vQ1, vQ1, 0x3); // Truncation + uni_vroundps(vQ1, vQ1, 0x3); // Truncation uni_vcvtps2dq(vQ1, vQ1); } @@ -1614,27 +1643,27 @@ void GridSampleKernel::bilinearInterpolation(const Vmm& vWCoord, const Vmm& template <> void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { - auto vHTop = getVmm(); - auto vWLeft = getVmm(); - auto vDX = getVmm(); - auto vDY = getVmm(); - auto vXDotProd = getVmm(); + auto vHTop = getVmm(); + auto vWLeft = getVmm(); + auto vDX = getVmm(); + auto vDY = getVmm(); + auto vXDotProd = getVmm(); auto& vYDotProd = vDX; auto vSrcShift0 = getVmm(); - auto vSrcShift = getVmm(); - auto vAux = getVmm(); - auto kAuxMask = getMask(); + auto vSrcShift = getVmm(); + auto vAux = getVmm(); + auto kAuxMask = getMask(); RegistersPool::Reg kMaskH; std::vector> wMasks; - uni_vroundps(vHTop, vHCoord, 0x1); // Round floor - uni_vroundps(vWLeft, vWCoord, 0x1); // Round floor + uni_vroundps(vHTop, vHCoord, 0x1); // Round floor + uni_vroundps(vWLeft, vWCoord, 0x1); // Round floor uni_vsubps(vDY, vHCoord, vHTop); uni_vsubps(vDX, vWCoord, vWLeft); uni_vsubps(vHTop, vHTop, vOnesF); uni_vsubps(vWLeft, vWLeft, vOnesF); - RegistersPool::Reg vCX[4] = {getVmm(), getVmm(), getVmm(), getVmm() }; + RegistersPool::Reg vCX[4] = {getVmm(), getVmm(), getVmm(), getVmm()}; for (int i = 0; i < 4; i++) { bicubicCoefficients(vCX[i], vDX, i); } @@ -1659,8 +1688,8 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord // PER CHANNEL LOOP Xbyak::Label lChannelLoopBegin, lChannelLoopEnd; RegistersPool::Reg rChannel; - auto rSrcTmp = getReg64(); - auto rDstTmp = getReg64(); + auto rSrcTmp = getReg64(); + auto rDstTmp = getReg64(); mov(rSrcTmp, regSrc); mov(rDstTmp, regDst); @@ -1742,7 +1771,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord } if (jcp.inDataPrc == ov::element::i32) { - uni_vroundps(vYDotProd, vYDotProd, 0x3); // Truncation + uni_vroundps(vYDotProd, vYDotProd, 0x3); // Truncation uni_vcvtps2dq(vYDotProd, vYDotProd); } @@ -1762,15 +1791,15 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord } } -template // Works for AVX2, AVX, SSE41 +template // Works for AVX2, AVX, SSE41 void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& vHCoord, bool tail) { - auto vHTop = getVmm(); + auto vHTop = getVmm(); auto vWLeft = getVmm(); - auto vDX = getVmm(); - auto vDY = getVmm(); + auto vDX = getVmm(); + auto vDY = getVmm(); - uni_vroundps(vHTop, vHCoord, 0x1); // Round floor - uni_vroundps(vWLeft, vWCoord, 0x1); // Round floor + uni_vroundps(vHTop, vHCoord, 0x1); // Round floor + uni_vroundps(vWLeft, vWCoord, 0x1); // Round floor uni_vsubps(vDY, vHCoord, vHTop); uni_vsubps(vDX, vWCoord, vWLeft); uni_vsubps(vHTop, vHTop, vOnesF); @@ -1791,7 +1820,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& } auto vW0 = getVmm(), vW1 = getVmm(); - Vmm vW[4] = { vW0, vW1, vHCoord, vWCoord }; + Vmm vW[4] = {vW0, vW1, vHCoord, vWCoord}; for (int w = 0; w < 4; w++) { borderPadding(vW[w], vWLeft, 
coord::w); if (w < 3) { @@ -1806,7 +1835,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& mov(rAux, ptr[regParams + GET_OFF(srcHeightSub1F)]); uni_vmovups(vSrcHeightSub1F, ptr[rAux]); } - auto vH = getVmm(); + auto vH = getVmm(); size_t bufShift = 0lu; for (int h = 0; h < 4; h++) { @@ -1839,7 +1868,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& } auto vW0 = getVmm(), vW1 = getVmm(); - Vmm vW[4] = { vW0, vW1, vHCoord, vWCoord }; + Vmm vW[4] = {vW0, vW1, vHCoord, vWCoord}; for (int w = 0; w < 4; w++) { reflectionPadding(vW[w], vWLeft, coord::w); if (w < 3) { @@ -1860,7 +1889,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& mov(rAux, ptr[regParams + GET_OFF(srcHeightMul2Sub1F)]); uni_vmovups(vSrcHeightMul2Sub1F, ptr[rAux]); } - auto vH = getVmm(); + auto vH = getVmm(); size_t bufShift = 0lu; for (int h = 0; h < 4; h++) { @@ -1883,7 +1912,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& } else if (jcp.paddingMode == GridSamplePaddingMode::ZEROS) { useMask = zeroFill = true; - RegistersPool::Reg vWMask[4] = { getVmm(), getVmm(), getVmm(), getVmm() }; + RegistersPool::Reg vWMask[4] = {getVmm(), getVmm(), getVmm(), getVmm()}; for (int w = 0; w < 4; w++) { if (w == 0) { zerosPaddingW(vWMask[w], vWLeft); @@ -1933,21 +1962,21 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& vDataTypeSizeB.release(); } - RegistersPool::Reg vCX[4] = { getVmm(), getVmm(), getVmm(), getVmm() }; + RegistersPool::Reg vCX[4] = {getVmm(), getVmm(), getVmm(), getVmm()}; for (int w = 0; w < 4; w++) { bicubicCoefficients(vCX[w], vDX, w); } auto vCY0 = getVmm(), vCY1 = getVmm(); - Vmm vCY[4] = { vCY0, vCY1, vHCoord, vWCoord }; + Vmm vCY[4] = {vCY0, vCY1, vHCoord, vWCoord}; for (int h = 0; h < 4; h++) { bicubicCoefficients(vCY[h], vDY, h); } const auto& vXDotProd = vDX; const auto& vYDotProd = vDY; - auto vSrcShift = getVmm(); + auto vSrcShift = getVmm(); auto kGatherMask = getVmm(); - auto vAux = getVmm(); + auto vAux = getVmm(); // PER CHANNEL LOOP Xbyak::Label lChannelLoopBegin, lChannelLoopEnd; @@ -2003,7 +2032,7 @@ void GridSampleKernel::bicubicInterpolation(const Vmm& vWCoord, const Vmm& } if (jcp.inDataPrc == ov::element::i32) { - uni_vroundps(vYDotProd, vYDotProd, 0x3); // Truncation + uni_vroundps(vYDotProd, vYDotProd, 0x3); // Truncation uni_vcvtps2dq(vYDotProd, vYDotProd); } @@ -2028,7 +2057,7 @@ void GridSampleKernel::dataTypeShiftPs2Dq(const Vmm& vDst, const Vmm& vSrc) if (dataTypeSize == 1) return; - if (isa == x64::avx) { // vpslld works just with XMM for AVX, so use vmulps for YMM + if (isa == x64::avx) { // vpslld works just with XMM for AVX, so use vmulps for YMM auto rAux = getReg64(); static const float val = dataTypeSize; static const float dataTypeSizeArr[8] = {val, val, val, val, val, val, val, val}; @@ -2038,7 +2067,7 @@ void GridSampleKernel::dataTypeShiftPs2Dq(const Vmm& vDst, const Vmm& vSrc) } else { uni_vcvtps2dq(vDst, vSrc); if (dataTypeSize > 1) - uni_vpslld(vDst, vDst, dataTypeShift); // multiply by source data type size. + uni_vpslld(vDst, vDst, dataTypeShift); // multiply by source data type size. 
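// NB: this routine turns per-lane float element indices into byte offsets
// for the gathers: convert to int32, then shift left by log2(dataTypeSize).
// On AVX1 the vmulps path above is taken instead, since vpslld only accepts
// XMM operands there. Scalar equivalent (sketch):
//
//     int32_t byte_offset(float idx, uint8_t dataTypeShift) {
//         return static_cast<int32_t>(idx) << dataTypeShift;  // idx * typeSize
//     }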
} } @@ -2066,7 +2095,7 @@ void GridSampleKernel::hwShiftPs2dq(const Vmm& vDst, const Vmm& vHCoord, co } } - if (isa == x64::avx) { // vpslld works just with XMM for AVX, so use vmulps for YMM + if (isa == x64::avx) { // vpslld works just with XMM for AVX, so use vmulps for YMM if (dataTypeSize > 1) { auto rAux = getReg64(); const float val = dataTypeSize; @@ -2078,7 +2107,7 @@ void GridSampleKernel::hwShiftPs2dq(const Vmm& vDst, const Vmm& vHCoord, co } else { uni_vcvtps2dq(vDst, vDst); if (dataTypeSize > 1) - uni_vpslld(vDst, vDst, dataTypeShift); // multiply by source data type size. + uni_vpslld(vDst, vDst, dataTypeShift); // multiply by source data type size. } } @@ -2086,6 +2115,6 @@ template class GridSampleKernel; template class GridSampleKernel; template class GridSampleKernel; -} // namespace kernel -} // namespace intel_cpu -} // namespace ov +} // namespace kernel +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.hpp index cb13d62c3509d1..f276580a837bd2 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/grid_sample.hpp @@ -4,9 +4,10 @@ #pragma once -#include "jit_kernel_base.hpp" #include +#include "jit_kernel_base.hpp" + namespace ov { namespace intel_cpu { @@ -20,16 +21,16 @@ class GridSampleKernelBase; #if defined(OPENVINO_ARCH_X86_64) struct GridSampleKernelConfParams { - bool dynamicShapes = false; - bool dynamicBatch = false; + bool dynamicShapes = false; + bool dynamicBatch = false; bool dynamicChannel = false; - bool alignCorners = false; + bool alignCorners = false; GridSampleInterpolationMode interpolationMode = GridSampleInterpolationMode::BILINEAR; GridSamplePaddingMode paddingMode = GridSamplePaddingMode::ZEROS; ov::element::Type inDataPrc; ov::element::Type gridPrc; - uint64_t batchNum = 1lu; - uint64_t cannelNum = 1lu; + uint64_t batchNum = 1lu; + uint64_t cannelNum = 1lu; uint64_t srcBatchStepB = 0lu; }; @@ -37,13 +38,13 @@ struct GridSamplesKernelExecArgs { const void* src; const void* grid; void* dst; - uint64_t batchNum = 1lu; + uint64_t batchNum = 1lu; uint64_t channelsNum = 1lu; const float* srcWidthF; const float* srcHeightF; - uint64_t srcBatchStepB = 0lu; - uint64_t gridBatchStepB = 0lu; - uint64_t dstBatchStepB = 0lu; + uint64_t srcBatchStepB = 0lu; + uint64_t gridBatchStepB = 0lu; + uint64_t dstBatchStepB = 0lu; uint64_t srcChannelStepB = 0lu; uint64_t dstChannelStepB = 0lu; const void* wDenormCoefF; @@ -60,19 +61,21 @@ struct GridSamplesKernelExecArgs { uint64_t workAmount = 0lu; }; -enum coord { - w, h -}; +enum coord { w, h }; -class GridSampleKernelBase: public JitKernelBase { +class GridSampleKernelBase : public JitKernelBase { public: - void (*ker_)(const GridSamplesKernelExecArgs *); - void operator()(const GridSamplesKernelExecArgs *args) { + void (*ker_)(const GridSamplesKernelExecArgs*); + void operator()(const GridSamplesKernelExecArgs* args) { assert(ker_); ker_(args); } - explicit GridSampleKernelBase(const char* name, const GridSampleKernelConfParams& jcp, dnnl::impl::cpu::x64::cpu_isa_t isa) - : JitKernelBase(name, isa), ker_(nullptr), jcp(jcp) {} + explicit GridSampleKernelBase(const char* name, + const GridSampleKernelConfParams& jcp, + dnnl::impl::cpu::x64::cpu_isa_t isa) + : JitKernelBase(name, isa), + ker_(nullptr), + jcp(jcp) {} virtual void create_ker() = 0; uint64_t getVecLen() { @@ -87,7 +90,7 @@ class GridSampleKernelBase: public 
JitKernelBase { protected: GridSampleKernelConfParams jcp; - uint64_t vlen = 16lu; + uint64_t vlen = 16lu; uint64_t dataTypeSize = 1lu; uint64_t gridTypeSize = 1lu; uint64_t dataElPerVec = 1lu; @@ -104,12 +107,16 @@ class GridSampleKernel : public GridSampleKernelBase { void create_ker() override; void generate() override; - using Vmm = typename dnnl::impl::utils::conditional3::type; - using Vmask = typename dnnl::impl::utils::conditional3::type; + using Vmm = typename dnnl::impl::utils::conditional3::type; + using Vmask = typename dnnl::impl::utils::conditional3::type; private: uint8_t dataTypeShift = 0; @@ -138,23 +145,23 @@ class GridSampleKernel : public GridSampleKernelBase { RegistersPool::Reg vWDenormCoefF; RegistersPool::Reg vHDenormCoefF; RegistersPool::Reg vGridPermMask; - RegistersPool::Reg vDataTypeSizeB; // for ZEROS padding - RegistersPool::Reg vSrcWidthB; // for ZEROS padding + RegistersPool::Reg vDataTypeSizeB; // for ZEROS padding + RegistersPool::Reg vSrcWidthB; // for ZEROS padding - RegistersPool::Reg vSrcHeightSub1F; // for BORDER padding - RegistersPool::Reg vSrcWidthSub1F; // for BORDER padding + RegistersPool::Reg vSrcHeightSub1F; // for BORDER padding + RegistersPool::Reg vSrcWidthSub1F; // for BORDER padding - RegistersPool::Reg vSrcHeightMul2F; // for REFLECTION padding - RegistersPool::Reg vSrcWidthMul2F; // for REFLECTION padding - RegistersPool::Reg vSrcHeightMul2Sub1F; // for REFLECTION padding - RegistersPool::Reg vSrcWidthMul2Sub1F; // for REFLECTION padding - RegistersPool::Reg vAbsMask; // for REFLECTION padding + RegistersPool::Reg vSrcHeightMul2F; // for REFLECTION padding + RegistersPool::Reg vSrcWidthMul2F; // for REFLECTION padding + RegistersPool::Reg vSrcHeightMul2Sub1F; // for REFLECTION padding + RegistersPool::Reg vSrcWidthMul2Sub1F; // for REFLECTION padding + RegistersPool::Reg vAbsMask; // for REFLECTION padding - RegistersPool::Reg vConst_0_75; // for BICUBIC interpolation - RegistersPool::Reg vConst_1_25; // for BICUBIC interpolation - RegistersPool::Reg vConst_1_50; // for BICUBIC interpolation - RegistersPool::Reg vConst_2_00; // for BICUBIC interpolation - RegistersPool::Reg vConst_2_25; // for BICUBIC interpolation + RegistersPool::Reg vConst_0_75; // for BICUBIC interpolation + RegistersPool::Reg vConst_1_25; // for BICUBIC interpolation + RegistersPool::Reg vConst_1_50; // for BICUBIC interpolation + RegistersPool::Reg vConst_2_00; // for BICUBIC interpolation + RegistersPool::Reg vConst_2_25; // for BICUBIC interpolation void initVectors(); void process(); @@ -179,8 +186,8 @@ class GridSampleKernel : public GridSampleKernelBase { void hwShiftPs2dq(const Vmm& vDst, const Vmm& vHCoord, const Vmm& vWCoord, const Vmm& vWidth); }; -#endif // OPENVINO_ARCH_X86_64 +#endif // OPENVINO_ARCH_X86_64 -} // namespace kernel -} // namespace intel_cpu -} // namespace ov +} // namespace kernel +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.cpp index cd8b32d9ad2a38..2eb981007f2217 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.cpp @@ -3,9 +3,10 @@ // #include "jit_kernel.hpp" -#include -#include + #include +#include +#include #include using namespace dnnl::impl; @@ -17,16 +18,16 @@ namespace intel_cpu { namespace { -template +template using registers = std::array, 16>; bool isRegAllocable(int id) { - return id != abi_param1.getIdx() // 
function argument - && id != Operand::Code::RSP; // stack pointer + return id != abi_param1.getIdx() // function argument + && id != Operand::Code::RSP; // stack pointer } -template -const RegType & reserveReg(jit_kernel::reg_indices & freeRegs, const registers & regs) { +template +const RegType& reserveReg(jit_kernel::reg_indices& freeRegs, const registers& regs) { if (freeRegs.empty()) throw std::runtime_error("No free registers"); const auto idx = freeRegs.back(); @@ -34,8 +35,8 @@ const RegType & reserveReg(jit_kernel::reg_indices & freeRegs, const registers -void freeReg(jit_kernel::reg_indices & freeRegs, const registers & regs, const RegType & reg) { +template +void freeReg(jit_kernel::reg_indices& freeRegs, const registers& regs, const RegType& reg) { const auto idx = reg.getIdx(); // Debug: // auto it = std::find(freeRegs.begin(), freeRegs.end(), idx); @@ -46,105 +47,189 @@ void freeReg(jit_kernel::reg_indices & freeRegs, const registers & regs OPENVINO_THROW("Some register was freed twice"); } -const registers & x64regs() { +const registers& x64regs() { using namespace Xbyak::util; - static const registers _x64regs {{ - rax, rcx, rdx, rbx, - rsp, rbp, rsi, rdi, - r8, r9, r10, r11, - r12, r13, r14, r15, + static const registers _x64regs{{ + rax, + rcx, + rdx, + rbx, + rsp, + rbp, + rsi, + rdi, + r8, + r9, + r10, + r11, + r12, + r13, + r14, + r15, }}; return _x64regs; } -const registers & x32regs() { +const registers& x32regs() { using namespace Xbyak::util; - static const registers _x32regs {{ - eax, ecx, edx, ebx, - esp, ebp, esi, edi, - r8d, r9d, r10d, r11d, - r12d, r13d, r14d, r15d, + static const registers _x32regs{{ + eax, + ecx, + edx, + ebx, + esp, + ebp, + esi, + edi, + r8d, + r9d, + r10d, + r11d, + r12d, + r13d, + r14d, + r15d, }}; return _x32regs; } -const registers & x16regs() { +const registers& x16regs() { using namespace Xbyak::util; - static const registers _x16regs {{ - ax, cx, dx, bx, - sp, bp, si, di, - r8w, r9w, r10w, r11w, - r12w, r13w, r14w, r15w, + static const registers _x16regs{{ + ax, + cx, + dx, + bx, + sp, + bp, + si, + di, + r8w, + r9w, + r10w, + r11w, + r12w, + r13w, + r14w, + r15w, }}; return _x16regs; } -const registers & x8regs() { +const registers& x8regs() { using namespace Xbyak::util; - static const registers _x8regs {{ - al, cl, dl, bl, - spl, bpl, sil, dil, - r8b, r9b, r10b, r11b, - r12b, r13b, r14b, r15b, + static const registers _x8regs{{ + al, + cl, + dl, + bl, + spl, + bpl, + sil, + dil, + r8b, + r9b, + r10b, + r11b, + r12b, + r13b, + r14b, + r15b, }}; return _x8regs; } -const registers & xmmregs() { - static const registers _xmmregs {{ - Xbyak::util::xmm0, Xbyak::util::xmm1, Xbyak::util::xmm2, Xbyak::util::xmm3, - Xbyak::util::xmm4, Xbyak::util::xmm5, Xbyak::util::xmm6, Xbyak::util::xmm7, - Xbyak::util::xmm8, Xbyak::util::xmm9, Xbyak::util::xmm10, Xbyak::util::xmm11, - Xbyak::util::xmm12, Xbyak::util::xmm13, Xbyak::util::xmm14, Xbyak::util::xmm15, +const registers& xmmregs() { + static const registers _xmmregs{{ + Xbyak::util::xmm0, + Xbyak::util::xmm1, + Xbyak::util::xmm2, + Xbyak::util::xmm3, + Xbyak::util::xmm4, + Xbyak::util::xmm5, + Xbyak::util::xmm6, + Xbyak::util::xmm7, + Xbyak::util::xmm8, + Xbyak::util::xmm9, + Xbyak::util::xmm10, + Xbyak::util::xmm11, + Xbyak::util::xmm12, + Xbyak::util::xmm13, + Xbyak::util::xmm14, + Xbyak::util::xmm15, }}; return _xmmregs; } -const registers & ymmregs() { - static const registers _ymmregs {{ - Xbyak::util::ymm0, Xbyak::util::ymm1, Xbyak::util::ymm2, Xbyak::util::ymm3, - Xbyak::util::ymm4, 
Xbyak::util::ymm5, Xbyak::util::ymm6, Xbyak::util::ymm7, - Xbyak::util::ymm8, Xbyak::util::ymm9, Xbyak::util::ymm10, Xbyak::util::ymm11, - Xbyak::util::ymm12, Xbyak::util::ymm13, Xbyak::util::ymm14, Xbyak::util::ymm15, +const registers& ymmregs() { + static const registers _ymmregs{{ + Xbyak::util::ymm0, + Xbyak::util::ymm1, + Xbyak::util::ymm2, + Xbyak::util::ymm3, + Xbyak::util::ymm4, + Xbyak::util::ymm5, + Xbyak::util::ymm6, + Xbyak::util::ymm7, + Xbyak::util::ymm8, + Xbyak::util::ymm9, + Xbyak::util::ymm10, + Xbyak::util::ymm11, + Xbyak::util::ymm12, + Xbyak::util::ymm13, + Xbyak::util::ymm14, + Xbyak::util::ymm15, }}; return _ymmregs; } -const registers & zmmregs() { - static const registers _zmmregs {{ - Xbyak::util::zmm0, Xbyak::util::zmm1, Xbyak::util::zmm2, Xbyak::util::zmm3, - Xbyak::util::zmm4, Xbyak::util::zmm5, Xbyak::util::zmm6, Xbyak::util::zmm7, - Xbyak::util::zmm8, Xbyak::util::zmm9, Xbyak::util::zmm10, Xbyak::util::zmm11, - Xbyak::util::zmm12, Xbyak::util::zmm13, Xbyak::util::zmm14, Xbyak::util::zmm15, +const registers& zmmregs() { + static const registers _zmmregs{{ + Xbyak::util::zmm0, + Xbyak::util::zmm1, + Xbyak::util::zmm2, + Xbyak::util::zmm3, + Xbyak::util::zmm4, + Xbyak::util::zmm5, + Xbyak::util::zmm6, + Xbyak::util::zmm7, + Xbyak::util::zmm8, + Xbyak::util::zmm9, + Xbyak::util::zmm10, + Xbyak::util::zmm11, + Xbyak::util::zmm12, + Xbyak::util::zmm13, + Xbyak::util::zmm14, + Xbyak::util::zmm15, }}; return _zmmregs; } -} // namespace +} // namespace namespace internal { -template<> +template <> ov::element::Type type2precision() { return ov::element::f32; } -template<> +template <> ov::element::Type type2precision() { return ov::element::i32; } -template<> +template <> ov::element::Type type2precision() { return ov::element::bf16; } -template<> +template <> ov::element::Type type2precision() { return ov::element::u8; } -template<> +template <> ov::element::Type type2precision() { return ov::element::i8; } @@ -157,27 +242,24 @@ cpu_isa_t get_current_isa() { return cpu_isa_t::sse41; } -stack_frame::stack_frame(ov::intel_cpu::jit_kernel & kernel, size_t size, uint32_t alignment) - : _kernel(kernel) - , _size(size) - , _alignment(alignment) { +stack_frame::stack_frame(ov::intel_cpu::jit_kernel& kernel, size_t size, uint32_t alignment) + : _kernel(kernel), + _size(size), + _alignment(alignment) { if (_size || _alignment) { if (_size && _alignment == 1) { _kernel.sub(_kernel.rsp, _size); } else { auto tmp = _kernel.var(); tmp = _kernel.rsp; - _kernel.sub(_kernel.rsp, sizeof(size_t) + size); // allocate - _kernel.and_(_kernel.rsp, ~(alignment - 1)); // align - _kernel.mov(_kernel.ptr[_kernel.rsp + size], tmp); // remember previous rsp + _kernel.sub(_kernel.rsp, sizeof(size_t) + size); // allocate + _kernel.and_(_kernel.rsp, ~(alignment - 1)); // align + _kernel.mov(_kernel.ptr[_kernel.rsp + size], tmp); // remember previous rsp } } } -stack_frame::stack_frame(stack_frame && rhs) - : _kernel(rhs._kernel) - , _size(rhs._size) - , _alignment(rhs._alignment) { +stack_frame::stack_frame(stack_frame&& rhs) : _kernel(rhs._kernel), _size(rhs._size), _alignment(rhs._alignment) { rhs._size = 0; rhs._alignment = 0; } @@ -192,25 +274,29 @@ stack_frame::~stack_frame() { } } -const Xbyak::Reg64 & stack_frame::pointer() const { +const Xbyak::Reg64& stack_frame::pointer() const { return _kernel.rsp; } void stack_frame::clear() const { const size_t end = _size & ~(size_t)7u; - _kernel.foreach(0, end, [&](const Reg64 & idx) { - _kernel.mov(_kernel.qword[pointer() + idx], 0); - }, 
sizeof(size_t)); + _kernel.foreach ( + 0, + end, + [&](const Reg64& idx) { + _kernel.mov(_kernel.qword[pointer() + idx], 0); + }, + sizeof(size_t)); if (end < _size) { - _kernel.foreach(end, _size, [&](const Reg64 & idx) { + _kernel.foreach (end, _size, [&](const Reg64& idx) { _kernel.mov(_kernel.byte[pointer() + idx], 0); }); } } -const void * consts_table::store(const void *data, size_t size) { +const void* consts_table::store(const void* data, size_t size) { if (size > chunk_size) throw std::runtime_error("Data size is too large"); const size_t capacity = _chunks.size() * chunk_size; @@ -218,17 +304,16 @@ const void * consts_table::store(const void *data, size_t size) { _size = _chunks.size() * chunk_size; _chunks.emplace_back(); } - auto & dst = _chunks.back(); + auto& dst = _chunks.back(); const size_t offset = _size % chunk_size; memcpy(&dst[offset], data, size); _size += size; return &dst[offset]; } -} // namespace internal +} // namespace internal -jit_kernel::jit_kernel(const char* name) - : jit_generator(name) { +jit_kernel::jit_kernel(const char* name) : jit_generator(name) { _free_rmmregs.reserve(16); _free_rmmregs.reserve(16); @@ -239,73 +324,73 @@ jit_kernel::jit_kernel(const char* name) } } -template<> -const Reg64 & jit_kernel::reserve() { +template <> +const Reg64& jit_kernel::reserve() { return reserveReg(_free_x64regs, x64regs()); } -template<> -const Reg32 & jit_kernel::reserve() { +template <> +const Reg32& jit_kernel::reserve() { return reserveReg(_free_x64regs, x32regs()); } -template<> -const Reg16 & jit_kernel::reserve() { +template <> +const Reg16& jit_kernel::reserve() { return reserveReg(_free_x64regs, x16regs()); } -template<> -const Reg8 & jit_kernel::reserve() { +template <> +const Reg8& jit_kernel::reserve() { return reserveReg(_free_x64regs, x8regs()); } -template<> -void jit_kernel::free(const Reg64 & reg) { +template <> +void jit_kernel::free(const Reg64& reg) { freeReg(_free_x64regs, x64regs(), reg); } -template<> -void jit_kernel::free(const Reg32 & reg) { +template <> +void jit_kernel::free(const Reg32& reg) { freeReg(_free_x64regs, x32regs(), reg); } -template<> -void jit_kernel::free(const Reg16 & reg) { +template <> +void jit_kernel::free(const Reg16& reg) { freeReg(_free_x64regs, x16regs(), reg); } -template<> -void jit_kernel::free(const Reg8 & reg) { +template <> +void jit_kernel::free(const Reg8& reg) { freeReg(_free_x64regs, x8regs(), reg); } -template<> -const Xmm & jit_kernel::reserve() { +template <> +const Xmm& jit_kernel::reserve() { return reserveReg(_free_rmmregs, xmmregs()); } -template<> -void jit_kernel::free(const Xmm & reg) { +template <> +void jit_kernel::free(const Xmm& reg) { freeReg(_free_rmmregs, xmmregs(), reg); } -template<> -const Ymm & jit_kernel::reserve() { +template <> +const Ymm& jit_kernel::reserve() { return reserveReg(_free_rmmregs, ymmregs()); } -template<> -void jit_kernel::free(const Ymm & reg) { +template <> +void jit_kernel::free(const Ymm& reg) { freeReg(_free_rmmregs, ymmregs(), reg); } -template<> -const Zmm & jit_kernel::reserve() { +template <> +const Zmm& jit_kernel::reserve() { return reserveReg(_free_rmmregs, zmmregs()); } -template<> -void jit_kernel::free(const Zmm & reg) { +template <> +void jit_kernel::free(const Zmm& reg) { freeReg(_free_rmmregs, zmmregs(), reg); } @@ -317,26 +402,33 @@ void jit_kernel::postamble() { } } -const AddressFrame & jit_kernel::address_frame(size_t size) const { - switch (size) { - case 1: return byte; - case 2: return word; - case 4: return dword; - case 8: return qword; 
- case 16: return xword; - case 32: return yword; - case 64: return zword; - default: - break; - } - return ptr; +const AddressFrame& jit_kernel::address_frame(size_t size) const { + switch (size) { + case 1: + return byte; + case 2: + return word; + case 4: + return dword; + case 8: + return qword; + case 16: + return xword; + case 32: + return yword; + case 64: + return zword; + default: + break; + } + return ptr; } -const jit_kernel::reg_indices & jit_kernel::free_x64regs() const { +const jit_kernel::reg_indices& jit_kernel::free_x64regs() const { return _free_x64regs; } -const jit_kernel::reg_indices & jit_kernel::free_rmmregs() const { +const jit_kernel::reg_indices& jit_kernel::free_rmmregs() const { return _free_rmmregs; } @@ -386,5 +478,5 @@ void jit_kernel::uni_vblendps(const Xbyak::Zmm& z1, const Xbyak::Zmm& z2, uint16 vblendmps(z1 | k1, z1, z2); } -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.hpp index 8934bf5dff052b..0073ca91d0b76f 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel.hpp @@ -3,14 +3,15 @@ // #pragma once -#include "cpu/x64/jit_generator.hpp" -#include "emitters/plugin/x64/jit_load_store_emitters.hpp" +#include #include -#include #include -#include -#include #include +#include +#include + +#include "cpu/x64/jit_generator.hpp" +#include "emitters/plugin/x64/jit_load_store_emitters.hpp" namespace ov { namespace intel_cpu { @@ -19,113 +20,103 @@ struct jit_kernel; namespace internal { -template +template struct reg_traits_by_size; -template +template struct reg_traits; -template +template struct reg_traits; -template +template struct isa_traits; -template<> +template <> struct reg_traits_by_size<1> { using type = Xbyak::Reg8; - constexpr static size_t size = 1; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; + constexpr static size_t size = 1; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; }; -template<> +template <> struct reg_traits_by_size<2> { using type = Xbyak::Reg16; - constexpr static size_t size = 2; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; + constexpr static size_t size = 2; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; }; -template<> +template <> struct reg_traits_by_size<4> { using type = Xbyak::Reg32; - constexpr static size_t size = 4; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; + constexpr static size_t size = 4; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; }; -template<> +template <> struct reg_traits_by_size<8> { using type = Xbyak::Reg64; - constexpr static size_t size = 8; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; + constexpr static size_t size = 8; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; }; -template<> +template <> struct reg_traits_by_size<16> { using type = Xbyak::Xmm; - constexpr static size_t size = 16; // in bytes - constexpr static 
dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::sse41; + constexpr static size_t size = 16; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::sse41; }; -template<> +template <> struct reg_traits_by_size<32> { using type = Xbyak::Ymm; - constexpr static size_t size = 32; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::avx2; + constexpr static size_t size = 32; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::avx2; }; -template<> +template <> struct reg_traits_by_size<64> { using type = Xbyak::Zmm; - constexpr static size_t size = 64; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::avx512_core; + constexpr static size_t size = 64; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::avx512_core; }; -template +template struct reg_traits : public reg_traits_by_size {}; -template +template struct vec_min_size { - constexpr static size_t size = N <= 16 ? 16 : - N <= 32 ? 32 : - 64; + constexpr static size_t size = N <= 16 ? 16 : N <= 32 ? 32 : 64; }; -template +template struct reg_traits : public reg_traits_by_size::size> {}; -template<> +template <> struct reg_traits { using type = Xbyak::Fpu; - constexpr static size_t size = 10; // in bytes - constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa - = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; + constexpr static size_t size = 10; // in bytes + constexpr static dnnl::impl::cpu::x64::cpu_isa_t isa = dnnl::impl::cpu::x64::cpu_isa_t::isa_undef; }; -template<> +template <> struct reg_traits : public reg_traits {}; -template<> +template <> struct isa_traits { struct reg { using type = Xbyak::Xmm; - constexpr static size_t size = 4 * 4; // in bytes - constexpr static size_t length = 4; // in dwords + constexpr static size_t size = 4 * 4; // in bytes + constexpr static size_t length = 4; // in dwords }; }; -template<> +template <> struct isa_traits { struct reg { using type = Xbyak::Ymm; - constexpr static size_t size = 8 * 4; // in bytes - constexpr static size_t length = 8; // in dwords + constexpr static size_t size = 8 * 4; // in bytes + constexpr static size_t length = 8; // in dwords }; }; -template<> +template <> struct isa_traits { struct reg { using type = Xbyak::Zmm; @@ -134,39 +125,39 @@ struct isa_traits { }; }; -template +template class variable; -template +template class if_expression; -template +template class then_expression; -template +template using shared_reg = std::shared_ptr; -template -shared_reg make_shared(Reg & reg, jit_kernel & kernel); +template +shared_reg make_shared(Reg& reg, jit_kernel& kernel); -template +template class boolean_expression { public: using reg_type = const typename reg_traits::type; enum class type { - eq, // == - neq, // != - ls, // < - gt, // > - le, // <= - ge // >= + eq, // == + neq, // != + ls, // < + gt, // > + le, // <= + ge // >= }; - boolean_expression(jit_kernel & kernel, type t, const shared_reg & lhs, const shared_reg & rhs); - boolean_expression(jit_kernel & kernel, type t, const shared_reg & lhs, T rhs); + boolean_expression(jit_kernel& kernel, type t, const shared_reg& lhs, const shared_reg& rhs); + boolean_expression(jit_kernel& kernel, type t, const shared_reg& lhs, T rhs); private: - void cmp(const Xbyak::Label & exit) const; + void cmp(const Xbyak::Label& exit) const; - jit_kernel & _kernel; + 
jit_kernel& _kernel; type _type; shared_reg _lhs; shared_reg _rhs; @@ -176,33 +167,33 @@ class boolean_expression { friend class then_expression; }; -template +template class then_expression { public: - then_expression(if_expression & expr); + then_expression(if_expression& expr); - template - void _else(F && fn); + template + void _else(F&& fn); private: - if_expression & _if_expr; + if_expression& _if_expr; }; -template +template class if_expression { public: - if_expression(const boolean_expression & expr) - : _expr(expr) {} + if_expression(const boolean_expression& expr) : _expr(expr) {} ~if_expression() { try { if (!_is_exit_valid) _expr._kernel.assignL(_exit, _else); - } catch(...) {} + } catch (...) { + } } - template - then_expression _then(F && fn) { + template + then_expression _then(F&& fn) { using namespace Xbyak; _expr.cmp(_else); @@ -214,7 +205,7 @@ class if_expression { } private: - const boolean_expression & _expr; + const boolean_expression& _expr; Xbyak::Label _exit; Xbyak::Label _else; bool _is_exit_valid = false; @@ -222,287 +213,291 @@ class if_expression { friend class then_expression; }; -typedef struct register_tag {} register_tag; -typedef struct memory_tag {} memory_tag; +typedef struct register_tag { +} register_tag; +typedef struct memory_tag { +} memory_tag; -template +template class variable_base; -template +template class variable_base { public: using reg_type = const typename reg_traits::type; - variable_base & operator = (const variable_base &) = delete; + variable_base& operator=(const variable_base&) = delete; - variable_base(const variable_base &); - variable_base(variable_base &&); + variable_base(const variable_base&); + variable_base(variable_base&&); - reg_type & reg() const { - return *_reg; + reg_type& reg() const { + return *_reg; } - const shared_reg & shreg() const { + const shared_reg& shreg() const { return _reg; } - operator reg_type &() const { + operator reg_type&() const { return reg(); } - operator Xbyak::RegExp () const { + operator Xbyak::RegExp() const { return reg(); } protected: - variable_base(jit_kernel & krnl, const shared_reg & reg); + variable_base(jit_kernel& krnl, const shared_reg& reg); ~variable_base() = default; - jit_kernel & _kernel; + jit_kernel& _kernel; shared_reg _reg; }; -template +template class variable_base { public: using reg_type = const typename reg_traits::type; - variable_base & operator = (const variable_base &) = delete; + variable_base& operator=(const variable_base&) = delete; - variable_base(const variable_base &); - variable_base(variable_base &&); + variable_base(const variable_base&); + variable_base(variable_base&&); - reg_type & reg() const { - return *_addr; + reg_type& reg() const { + return *_addr; } protected: - variable_base(jit_kernel & krnl, const shared_reg & addr); + variable_base(jit_kernel& krnl, const shared_reg& addr); ~variable_base() = default; - jit_kernel & _kernel; + jit_kernel& _kernel; shared_reg _addr; }; -template -class variable : public variable_base::value, T>::type, register_tag> { +template +class variable + : public variable_base::value, T>::type, register_tag> { public: using type = T; using base = variable_base; using reg_type = const typename base::reg_type; using arithmetic_type = typename std::conditional::value, size_t, T>::type; - variable(variable &&) = default; - variable(jit_kernel & krnl); - variable(jit_kernel & krnl, const shared_reg & reg); + variable(variable&&) = default; + variable(jit_kernel& krnl); + variable(jit_kernel& krnl, const shared_reg& reg); - 
typename std::conditional::value - && !std::is_pointer::type>::value, - variable::type, memory_tag>, void>::type - operator *() const { + typename std::conditional::value && + !std::is_pointer::type>::value, + variable::type, memory_tag>, + void>::type + operator*() const { return variable::type, memory_tag>(base::_kernel, base::shreg()); } - const variable & operator = (reg_type & rhs) const { + const variable& operator=(reg_type& rhs) const { base::_kernel.mov(base::reg(), rhs); return *this; } - template - const variable & operator = (U *rhs) const { + template + const variable& operator=(U* rhs) const { // interpret pointers as size_t base::_kernel.mov(base::reg(), reinterpret_cast(rhs)); return *this; } - const variable & operator = (arithmetic_type rhs) const { + const variable& operator=(arithmetic_type rhs) const { base::_kernel.mov(base::reg(), static_cast(rhs)); return *this; } - const variable & operator += (reg_type & rhs) const { + const variable& operator+=(reg_type& rhs) const { base::_kernel.add(base::reg(), rhs); return *this; } - variable operator + (reg_type & rhs) const { + variable operator+(reg_type& rhs) const { variable res(base::_kernel); res = base::reg(); res += rhs; return res; } - const variable & operator += (arithmetic_type rhs) const { + const variable& operator+=(arithmetic_type rhs) const { base::_kernel.add(base::reg(), rhs); return *this; } - variable operator + (arithmetic_type rhs) const { + variable operator+(arithmetic_type rhs) const { variable res(base::_kernel); res = base::reg(); res += rhs; return res; } - const variable & operator -= (reg_type & rhs) const { + const variable& operator-=(reg_type& rhs) const { base::_kernel.sub(base::reg(), rhs); return *this; } - variable operator - (reg_type & rhs) const { + variable operator-(reg_type& rhs) const { variable res(base::_kernel); res = base::reg(); res -= rhs; return res; } - const variable & operator -= (arithmetic_type rhs) const { + const variable& operator-=(arithmetic_type rhs) const { base::_kernel.sub(base::reg(), rhs); return *this; } - variable operator - (arithmetic_type rhs) const { + variable operator-(arithmetic_type rhs) const { variable res(base::_kernel); res = base::reg(); res -= rhs; return res; } - const variable & operator *= (reg_type & rhs) const { + const variable& operator*=(reg_type& rhs) const { base::_kernel.imul(base::reg(), rhs); return *this; } - variable operator * (reg_type & rhs) const { + variable operator*(reg_type& rhs) const { variable res(base::_kernel); res = base::reg(); res *= rhs; return res; } - const variable & operator *= (arithmetic_type rhs) const { + const variable& operator*=(arithmetic_type rhs) const { base::_kernel.imul(base::reg(), base::reg(), static_cast(rhs)); return *this; } - variable operator * (arithmetic_type rhs) const { + variable operator*(arithmetic_type rhs) const { variable res(base::_kernel); res = base::reg(); res *= rhs; return res; } - const variable & operator &= (reg_type & rhs) const { + const variable& operator&=(reg_type& rhs) const { base::_kernel.and_(base::reg(), rhs); return *this; } - variable operator & (reg_type & rhs) const { + variable operator&(reg_type& rhs) const { variable res(base::_kernel); res = base::reg(); res &= rhs; return res; } - const variable & operator &= (T rhs) const { + const variable& operator&=(T rhs) const { base::_kernel.and_(base::reg(), rhs); return *this; } - variable operator & (T rhs) const { + variable operator&(T rhs) const { variable res(base::_kernel); res = base::reg(); res &= 
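// NOTE: each binary operator above is built from its compound form: it reserves a fresh
// register via variable(base::_kernel), copies the left operand, then reuses +=, -= or
// *=, so every temporary costs one GPR from the kernel's pool. A hedged sketch of what a
// small expression lowers to (names illustrative):
//
//     auto a = krnl.var<size_t>(2);   // mov  r_a, 2
//     auto b = a * 3;                 // mov  r_b, r_a ; imul r_b, r_b, 3
//     auto c = b + 7;                 // mov  r_c, r_b ; add  r_c, 7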
rhs; return res; } - const variable & operator |= (reg_type & rhs) const { + const variable& operator|=(reg_type& rhs) const { base::_kernel.or_(base::reg(), rhs); return *this; } - variable operator | (reg_type & rhs) const { + variable operator|(reg_type& rhs) const { variable res(base::_kernel); res = base::reg(); res |= rhs; return res; } - const variable & operator |= (T rhs) const { + const variable& operator|=(T rhs) const { base::_kernel.or_(base::reg(), rhs); return *this; } - variable operator | (T rhs) const { + variable operator|(T rhs) const { variable res(base::_kernel); res = base::reg(); res |= rhs; return res; } - const variable & operator >>= (size_t rhs) const { + const variable& operator>>=(size_t rhs) const { base::_kernel.shr(base::reg(), rhs); return *this; } - variable operator >> (size_t rhs) const { + variable operator>>(size_t rhs) const { variable res(base::_kernel); res = base::reg(); res >>= rhs; return res; } - const variable & operator <<= (size_t rhs) const { + const variable& operator<<=(size_t rhs) const { base::_kernel.shl(base::reg(), rhs); return *this; } - variable operator << (size_t rhs) const { + variable operator<<(size_t rhs) const { variable res(base::_kernel); res = base::reg(); res <<= rhs; return res; } - boolean_expression operator == (const variable & rhs) const { + boolean_expression operator==(const variable& rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::eq, base::shreg(), rhs.shreg()); } - boolean_expression operator == (T rhs) const { + boolean_expression operator==(T rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::eq, base::shreg(), rhs); } - boolean_expression operator != (const variable & rhs) const { + boolean_expression operator!=(const variable& rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::neq, base::shreg(), rhs.shreg()); } - boolean_expression operator != (T rhs) const { + boolean_expression operator!=(T rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::neq, base::shreg(), rhs); } - boolean_expression operator < (const variable & rhs) const { + boolean_expression operator<(const variable& rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::ls, base::shreg(), rhs.shreg()); } - boolean_expression operator < (T rhs) const { + boolean_expression operator<(T rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::ls, base::shreg(), rhs); } - boolean_expression operator > (const variable & rhs) const { + boolean_expression operator>(const variable& rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::gt, base::shreg(), rhs.shreg()); } - boolean_expression operator > (T rhs) const { + boolean_expression operator>(T rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::gt, base::shreg(), rhs); } - boolean_expression operator <= (const variable & rhs) const { + boolean_expression operator<=(const variable& rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::le, base::shreg(), rhs.shreg()); } - boolean_expression operator <= (T rhs) const { + boolean_expression operator<=(T rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::le, base::shreg(), rhs); } - boolean_expression operator >= (const variable & rhs) const { + boolean_expression operator>=(const variable& rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::ge, 
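// NOTE: unlike the arithmetic operators, the relational operators above emit no code at
// the point of use: each one only captures the kernel, an eq/neq/ls/gt/le/ge tag and the
// operand registers in a boolean_expression. The cmp and conditional jump are produced
// later, from if_expression::_then() (see boolean_expression::cmp() further below), e.g.
//
//     idx < limit   // -> boolean_expression(krnl, type::ls, lhs, rhs); no instructions yet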
base::shreg(), rhs.shreg()); } - boolean_expression operator >= (T rhs) const { + boolean_expression operator>=(T rhs) const { return boolean_expression(base::_kernel, boolean_expression::type::ge, base::shreg(), rhs); } // TODO: add necessary operations }; -template +template class variable : public variable_base { public: using type = T; using base = variable_base; using reg_type = const typename base::reg_type; - variable(variable &&) = default; - variable(jit_kernel & krnl, const shared_reg & reg); + variable(variable&&) = default; + variable(jit_kernel& krnl, const shared_reg& reg); - const variable & operator = (const variable & rhs) const; + const variable& operator=(const variable& rhs) const; }; -template +template class variable : public variable_base { public: using type = T[N]; @@ -510,34 +505,34 @@ class variable : public variable_base { using reg_type = const typename base::reg_type; constexpr static size_t length = N; - variable(variable &&) = default; - variable(jit_kernel & krnl); - variable(jit_kernel & krnl, const shared_reg & reg); + variable(variable&&) = default; + variable(jit_kernel& krnl); + variable(jit_kernel& krnl, const shared_reg& reg); - const variable & operator = (reg_type & rhs) const { + const variable& operator=(reg_type& rhs) const { base::_kernel.uni_vmovups(base::reg(), rhs); return *this; } - const variable & operator = (const type & rhs) const { - const type & cref = base::_kernel.constant(rhs); + const variable& operator=(const type& rhs) const { + const type& cref = base::_kernel.constant(rhs); variable creg(base::_kernel); creg = &cref; base::_kernel.uni_vmovdqu(base::reg(), base::_kernel.ptr[creg]); return *this; } - const variable & blend(reg_type & rhs, uint16_t mask) const { + const variable& blend(reg_type& rhs, uint16_t mask) const { base::_kernel.uni_vblendps(base::reg(), rhs, mask); return *this; } - const variable & permute(const std::array & order) const { + const variable& permute(const std::array& order) const { base::_kernel.uni_vpermps(base::reg(), order.data(), base::reg()); return *this; } - const variable & permute(const uint8_t * order) const { + const variable& permute(const uint8_t* order) const { base::_kernel.uni_vpermps(base::reg(), order, base::reg()); return *this; } @@ -546,139 +541,132 @@ class variable : public variable_base { }; class stack_frame { - stack_frame(const stack_frame &) = delete; - stack_frame & operator = (const stack_frame &) = delete; + stack_frame(const stack_frame&) = delete; + stack_frame& operator=(const stack_frame&) = delete; public: - stack_frame(jit_kernel & kernel, size_t size, uint32_t alignment = 1); - stack_frame(stack_frame && rhs); + stack_frame(jit_kernel& kernel, size_t size, uint32_t alignment = 1); + stack_frame(stack_frame&& rhs); ~stack_frame(); - const Xbyak::Reg64 & pointer() const; + const Xbyak::Reg64& pointer() const; void clear() const; private: - jit_kernel & _kernel; + jit_kernel& _kernel; size_t _size; uint32_t _alignment; }; -template +template ov::element::Type type2precision(); dnnl::impl::cpu::x64::cpu_isa_t get_current_isa(); class consts_table { - consts_table(const consts_table &) = delete; - consts_table & operator = (const consts_table &) = delete; + consts_table(const consts_table&) = delete; + consts_table& operator=(const consts_table&) = delete; public: consts_table() = default; - const void * store(const void *data, size_t size); + const void* store(const void* data, size_t size); private: static constexpr const size_t chunk_size = 512; using chunk = 
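// NOTE: consts_table below hands out stable pointers into 512-byte chunks; because the
// chunks live in a std::list, growing the table never relocates stored constants. A
// minimal sketch of a store() satisfying this interface, assuming `chunk` is
// std::array<uint8_t, chunk_size> (the real implementation is not shown in this hunk):
//
//     const void* consts_table::store(const void* data, size_t size) {
//         if (size > chunk_size)
//             throw std::runtime_error("constant does not fit into a chunk");
//         size_t offset = _size % chunk_size;
//         if (_chunks.empty() || offset + size > chunk_size) {
//             _chunks.emplace_back();                     // open a fresh chunk
//             _size = (_chunks.size() - 1) * chunk_size;  // skip the unused tail
//             offset = 0;
//         }
//         uint8_t* dst = _chunks.back().data() + offset;
//         std::memcpy(dst, data, size);
//         _size += size;
//         return dst;
//     }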
std::array; std::list _chunks; - size_t _size {}; + size_t _size{}; }; -} // namespace internal +} // namespace internal struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator { using reg_indices = std::vector; - template + template using reg_traits = internal::reg_traits; - template + template using reg_traits_by_size = internal::reg_traits_by_size; - template + template using isa_traits = internal::isa_traits; using stack_frame = internal::stack_frame; using register_tag = internal::register_tag; using memory_tag = internal::memory_tag; - template + template using variable = internal::variable; - template + template using if_expression = internal::if_expression; - template + template using boolean_expression = internal::boolean_expression; - template + template Xbyak::Address argPtr(U T::*member) const { auto memPtr = &(reinterpret_cast(0)->*member); - const size_t offs = reinterpret_cast(memPtr) - reinterpret_cast(0); + const size_t offs = reinterpret_cast(memPtr) - reinterpret_cast(0); return address_frame(sizeof(U))[param1 + offs]; } - template + template variable arg(U T::*member) { using traits = internal::reg_traits; using reg_type = typename traits::type; - const auto & res = reserve(); + const auto& res = reserve(); if (sizeof(T) < traits::size) movzx(res, argPtr(member)); else mov(res, argPtr(member)); - return { *this, internal::make_shared(res, *this) }; + return {*this, internal::make_shared(res, *this)}; } - template + template variable arg(U T::*member) { using traits = internal::reg_traits; using reg_type = typename traits::type; - const auto & res = reserve(); + const auto& res = reserve(); if (sizeof(T) < traits::size) movzx(res, argPtr(member)); else mov(res, argPtr(member)); - return { *this, internal::make_shared(res, *this) }; - } - - jit_kernel(const char *name); - - template - const RegType & reserve(); - - template - void free(const RegType & reg); - - template - void copy(const Xbyak::Reg64& dst, - const Xbyak::Reg64& src, - const Xbyak::Reg64& size); - template - void copy(const Xbyak::Address& dst, - const Xbyak::Reg64& src, - const Xbyak::Reg64& size); - - template - void load(const variable & dst, const variable & src, size_t length = N); - template - void load(const variable & dst, const variable & src, const variable & length); - template - void store(const variable & dst, const variable & src, size_t length = N); - template - void store(const variable & dst, const variable & src, const variable & length); - - template - void foreach(const B & begin, - const E & end, - std::function&)> && fn, - const S & step = 1); - - template + return {*this, internal::make_shared(res, *this)}; + } + + jit_kernel(const char* name); + + template + const RegType& reserve(); + + template + void free(const RegType& reg); + + template + void copy(const Xbyak::Reg64& dst, const Xbyak::Reg64& src, const Xbyak::Reg64& size); + template + void copy(const Xbyak::Address& dst, const Xbyak::Reg64& src, const Xbyak::Reg64& size); + + template + void load(const variable& dst, const variable& src, size_t length = N); + template + void load(const variable& dst, const variable& src, const variable& length); + template + void store(const variable& dst, const variable& src, size_t length = N); + template + void store(const variable& dst, const variable& src, const variable& length); + + template + void foreach (const B& begin, const E& end, std::function&)> && fn, const S& step = 1); + + template variable var(); - template - variable var(const T & val); + template + variable 
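// NOTE: a hedged example of how the argument helpers above are used from a generator;
// the args struct and field names are illustrative, not from this patch:
//
//     struct call_args { const float* src; float* dst; size_t work_amount; };
//
//     // inside generate():
//     auto src = arg(&call_args::src);          // mov reg, [param1 + offsetof(src)]
//     auto dst = arg(&call_args::dst);
//     auto n   = arg(&call_args::work_amount);
//     foreach (0, n, [&](const variable<size_t>& idx) {
//         // per-element body; idx lives in a reserved Reg64 counter
//     });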
var(const T& val); - template - const T & constant(const T & c); - template - const T * constant(const T * c, size_t size); + template + const T& constant(const T& c); + template + const T* constant(const T* c, size_t size); stack_frame stack(size_t size, uint32_t alignment = 1); - template - if_expression _if(const boolean_expression & expr) const; + template + if_expression _if(const boolean_expression& expr) const; void uni_vpermps(const Xbyak::Xmm& x1, const uint8_t mask[4], const Xbyak::Operand& op); void uni_vpermps(const Xbyak::Ymm& y1, const uint8_t mask[8], const Xbyak::Operand& op); @@ -689,9 +677,9 @@ struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator { void postamble(); - const Xbyak::AddressFrame & address_frame(size_t size) const; - const reg_indices & free_x64regs() const; - const reg_indices & free_rmmregs() const; + const Xbyak::AddressFrame& address_frame(size_t size) const; + const reg_indices& free_x64regs() const; + const reg_indices& free_rmmregs() const; private: reg_indices _free_x64regs; @@ -703,44 +691,40 @@ struct jit_kernel : public dnnl::impl::cpu::x64::jit_generator { template <> const Xbyak::Reg64& jit_kernel::reserve(); -template -void jit_kernel::copy(const Xbyak::Reg64& dst, - const Xbyak::Reg64& src, - const Xbyak::Reg64& size) { - const auto & addr_frame = address_frame(sizeof(T)); +template +void jit_kernel::copy(const Xbyak::Reg64& dst, const Xbyak::Reg64& src, const Xbyak::Reg64& size) { + const auto& addr_frame = address_frame(sizeof(T)); auto p = reserve::type>(); - foreach(0, size, [&](const Xbyak::Reg64& idx) { + foreach (0, size, [&](const Xbyak::Reg64& idx) { mov(p, addr_frame[src + idx * sizeof(T)]); mov(addr_frame[dst + idx * sizeof(T)], p); - }); + }) + ; free(p); } -template -void jit_kernel::copy(const Xbyak::Address& dst, - const Xbyak::Reg64& src, - const Xbyak::Reg64& size) { - const auto & addr_frame = address_frame(sizeof(T)); +template +void jit_kernel::copy(const Xbyak::Address& dst, const Xbyak::Reg64& src, const Xbyak::Reg64& size) { + const auto& addr_frame = address_frame(sizeof(T)); auto p = reserve::type>(); auto d = reserve(); lea(d, dst); - foreach(0, size, [&](const Xbyak::Reg64& idx) { + foreach (0, size, [&](const Xbyak::Reg64& idx) { mov(p, addr_frame[src + idx * sizeof(T)]); mov(addr_frame[d + idx * sizeof(T)], p); - }); + }) + ; free(d); free(p); } -template -void jit_kernel::load(const variable & dst, const variable & src, size_t length) { +template +void jit_kernel::load(const variable& dst, const variable& src, size_t length) { static_assert(std::is_same::reg_type, const Xbyak::Reg64>::value, - "Source register must be Reg64"); + "Source register must be Reg64"); - using src_type = typename std::remove_cv< - typename std::remove_pointer::type>::type; - using dst_type = typename std::remove_cv< - typename std::remove_pointer::type>::type; + using src_type = typename std::remove_cv::type>::type; + using dst_type = typename std::remove_cv::type>::type; const std::vector pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end()); const std::vector pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end()); @@ -752,17 +736,15 @@ void jit_kernel::load(const variable & dst, const variable & src, if (!_emitters[key]) { _emitters[key].reset(new jit_load_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length)); } - _emitters[key]->emit_code( - { static_cast(static_cast(src).getIdx()) }, - { static_cast(static_cast(dst).getIdx()) }, - pool_vec_idxs, - pool_gpr_idxs); + 
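// NOTE: copy<T>() above is the scalar element mover: it reserves one temporary of T's
// register width and emits a mov-in/mov-out pair per element through the address_frame
// for sizeof(T). It is what the variable-length load()/store() overloads use to stage
// tails through the stack. Hedged usage (register names illustrative):
//
//     copy<float>(dst_reg, src_reg, n_reg);  // mov tmp, [src+idx*4] ; mov [dst+idx*4], tmp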
_emitters[key]->emit_code({static_cast(static_cast(src).getIdx())}, + {static_cast(static_cast(dst).getIdx())}, + pool_vec_idxs, + pool_gpr_idxs); } -template -void jit_kernel::load(const variable & dst, const variable & src, const variable & length) { - using src_type = typename std::remove_cv< - typename std::remove_pointer::type>::type; +template +void jit_kernel::load(const variable& dst, const variable& src, const variable& length) { + using src_type = typename std::remove_cv::type>::type; auto s = stack(N * sizeof(src_type)); s.clear(); @@ -775,15 +757,13 @@ void jit_kernel::load(const variable & dst, const variable & src, load(dst, tmp); } -template -void jit_kernel::store(const variable & dst, const variable & src, size_t length) { +template +void jit_kernel::store(const variable& dst, const variable& src, size_t length) { static_assert(std::is_same::reg_type, const Xbyak::Reg64>::value, - "Destination register must be Reg64"); + "Destination register must be Reg64"); - using src_type = typename std::remove_cv< - typename std::remove_pointer::type>::type; - using dst_type = typename std::remove_cv< - typename std::remove_pointer::type>::type; + using src_type = typename std::remove_cv::type>::type; + using dst_type = typename std::remove_cv::type>::type; const std::vector pool_vec_idxs(_free_rmmregs.begin(), _free_rmmregs.end()); const std::vector pool_gpr_idxs(_free_x64regs.begin(), _free_x64regs.end()); @@ -795,17 +775,15 @@ void jit_kernel::store(const variable & dst, const variable & src if (!_emitters[key]) { _emitters[key].reset(new jit_store_emitter(this, internal::get_current_isa(), src_prc, dst_prc, length)); } - _emitters[key]->emit_code( - { static_cast(static_cast(src).getIdx()) }, - { static_cast(static_cast(dst).getIdx()) }, - pool_vec_idxs, - pool_gpr_idxs); + _emitters[key]->emit_code({static_cast(static_cast(src).getIdx())}, + {static_cast(static_cast(dst).getIdx())}, + pool_vec_idxs, + pool_gpr_idxs); } -template -void jit_kernel::store(const variable & dst, const variable & src, const variable & length) { - using dst_type = typename std::remove_cv< - typename std::remove_pointer::type>::type; +template +void jit_kernel::store(const variable& dst, const variable& src, const variable& length) { + using dst_type = typename std::remove_cv::type>::type; auto s = stack(N * sizeof(dst_type)); @@ -817,11 +795,11 @@ void jit_kernel::store(const variable & dst, const variable & src copy(dst, tmp, length); } -template -void jit_kernel::foreach(const B & begin, - const E & end, - std::function&)> && fn, - const S & step) { +template +void jit_kernel::foreach (const B& begin, + const E& end, + std::function&)> && fn, + const S& step) { using namespace Xbyak; Label loop, exit; @@ -841,36 +819,36 @@ void jit_kernel::foreach(const B & begin, L(exit); } -template +template jit_kernel::variable jit_kernel::var() { using reg_type = typename reg_traits::type; - const auto & reg = reserve(); + const auto& reg = reserve(); return variable(*this, internal::make_shared(reg, *this)); } -template -jit_kernel::variable jit_kernel::var(const T & val) { +template +jit_kernel::variable jit_kernel::var(const T& val) { using reg_type = typename reg_traits::type; - const auto & reg = reserve(); + const auto& reg = reserve(); variable res(*this, internal::make_shared(reg, *this)); res = val; return res; } -template -const T & jit_kernel::constant(const T & c) { +template +const T& jit_kernel::constant(const T& c) { auto res = _consts.store(&c, sizeof c); return *reinterpret_cast(res); } -template 
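// NOTE: the variable-length load()/store() overloads above never touch memory out of
// bounds: load() stages `length` scalars through a zero-cleared stack frame, then does
// one full-width vector read from the frame; store() mirrors it (full-width write, then
// a scalar copy of `length` elements out). The price is a stack round-trip, so the
// fixed-length overloads are preferred when the element count is known at JIT time:
//
//     kernel.load(v_dst, v_src);           // compile-time length: one emitter call
//     kernel.load(v_dst, v_src, v_len);    // run-time tail: stack staging + clear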
-const T * jit_kernel::constant(const T * c, size_t size) { +template +const T* jit_kernel::constant(const T* c, size_t size) { auto res = _consts.store(c, size * sizeof(T)); return reinterpret_cast(res); } -template -jit_kernel::if_expression jit_kernel::_if(const boolean_expression & expr) const { +template +jit_kernel::if_expression jit_kernel::_if(const boolean_expression& expr) const { return if_expression(expr); } @@ -879,12 +857,13 @@ namespace internal { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // shared_reg -template -shared_reg make_shared(Reg & reg, jit_kernel & kernel) { - std::shared_ptr ptr(®, [&kernel](Reg *preg) { +template +shared_reg make_shared(Reg& reg, jit_kernel& kernel) { + std::shared_ptr ptr(®, [&kernel](Reg* preg) { try { kernel.free(*preg); - } catch(...) {} + } catch (...) { + } }); return ptr; } @@ -892,68 +871,68 @@ shared_reg make_shared(Reg & reg, jit_kernel & kernel) { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // boolean_expression -template -boolean_expression::boolean_expression(jit_kernel & kernel, type t, const shared_reg & lhs, const shared_reg & rhs) - : _kernel(kernel) - , _type(t) - , _lhs(lhs) - , _rhs(rhs) - , _rvalue {} { -} - -template -boolean_expression::boolean_expression(jit_kernel & kernel, type t, const shared_reg & lhs, T rhs) - : _kernel(kernel) - , _type(t) - , _lhs(lhs) - , _rvalue(rhs) { -} - -template -void boolean_expression::cmp(const Xbyak::Label & exit) const { +template +boolean_expression::boolean_expression(jit_kernel& kernel, + type t, + const shared_reg& lhs, + const shared_reg& rhs) + : _kernel(kernel), + _type(t), + _lhs(lhs), + _rhs(rhs), + _rvalue{} {} + +template +boolean_expression::boolean_expression(jit_kernel& kernel, type t, const shared_reg& lhs, T rhs) + : _kernel(kernel), + _type(t), + _lhs(lhs), + _rvalue(rhs) {} + +template +void boolean_expression::cmp(const Xbyak::Label& exit) const { if (_rhs) _kernel.cmp(*_lhs, *_rhs); else _kernel.cmp(*_lhs, _rvalue); switch (_type) { - case type::eq: { - _kernel.jne(exit, Xbyak::CodeGenerator::T_NEAR); - break; - } - case type::neq: { - _kernel.je(exit, Xbyak::CodeGenerator::T_NEAR); - break; - } - case type::ls: { - _kernel.jge(exit, Xbyak::CodeGenerator::T_NEAR); - break; - } - case type::gt: { - _kernel.jle(exit, Xbyak::CodeGenerator::T_NEAR); - break; - } - case type::le: { - _kernel.jg(exit, Xbyak::CodeGenerator::T_NEAR); - break; - } - case type::ge: { - _kernel.jl(exit, Xbyak::CodeGenerator::T_NEAR); - break; - } + case type::eq: { + _kernel.jne(exit, Xbyak::CodeGenerator::T_NEAR); + break; + } + case type::neq: { + _kernel.je(exit, Xbyak::CodeGenerator::T_NEAR); + break; + } + case type::ls: { + _kernel.jge(exit, Xbyak::CodeGenerator::T_NEAR); + break; + } + case type::gt: { + _kernel.jle(exit, Xbyak::CodeGenerator::T_NEAR); + break; + } + case type::le: { + _kernel.jg(exit, Xbyak::CodeGenerator::T_NEAR); + break; + } + case type::ge: { + _kernel.jl(exit, Xbyak::CodeGenerator::T_NEAR); + break; + } } } // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // then_expression -template -then_expression::then_expression(if_expression & expr) - : _if_expr(expr) {} +template +then_expression::then_expression(if_expression& expr) : _if_expr(expr) {} -template -template -void then_expression::_else(F && fn) { +template +template +void then_expression::_else(F&& fn) { fn(); _if_expr._expr._kernel.L(_if_expr._exit); _if_expr._is_exit_valid = true; @@ -962,75 +941,57 @@ void then_expression::_else(F && fn) { // 
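// NOTE: cmp() above deliberately emits the *inverse* jump for every comparison, so that
// execution falls through into the _then block and branches to the exit/_else label only
// when the condition fails:
//
//     eq -> jne    neq -> je    ls -> jge    gt -> jle    le -> jg    ge -> jl
//
// All jumps use T_NEAR, so _then/_else bodies may exceed the 127-byte short-jump range.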
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ // variable -template -variable_base::variable_base(jit_kernel & krnl, const shared_reg & reg) - : _kernel(krnl) - , _reg(reg) { -} +template +variable_base::variable_base(jit_kernel& krnl, const shared_reg& reg) + : _kernel(krnl), + _reg(reg) {} -template -variable_base::variable_base(const variable_base & rhs) - : _kernel(rhs._kernel) - , _reg(rhs._reg) { -} +template +variable_base::variable_base(const variable_base& rhs) : _kernel(rhs._kernel), + _reg(rhs._reg) {} -template -variable_base::variable_base(variable_base && rhs) - : _kernel(rhs._kernel) - , _reg(std::move(rhs._reg)) { -} +template +variable_base::variable_base(variable_base&& rhs) : _kernel(rhs._kernel), + _reg(std::move(rhs._reg)) {} -template -variable_base::variable_base(jit_kernel & krnl, const shared_reg & addr) - : _kernel(krnl) - , _addr(addr) { -} +template +variable_base::variable_base(jit_kernel& krnl, const shared_reg& addr) + : _kernel(krnl), + _addr(addr) {} -template -variable_base::variable_base(const variable_base & rhs) - : _kernel(rhs._kernel) - , _addr(rhs._addr) { -} +template +variable_base::variable_base(const variable_base& rhs) : _kernel(rhs._kernel), + _addr(rhs._addr) {} -template -variable_base::variable_base(variable_base && rhs) - : _kernel(rhs._kernel) - , _addr(std::move(rhs._addr)) { -} +template +variable_base::variable_base(variable_base&& rhs) : _kernel(rhs._kernel), + _addr(std::move(rhs._addr)) {} -template -variable::variable(jit_kernel & krnl) - : base(krnl, make_shared(krnl.reserve::type>(), krnl)) { -} +template +variable::variable(jit_kernel& krnl) + : base(krnl, make_shared(krnl.reserve::type>(), krnl)) {} -template -variable::variable(jit_kernel & krnl, const shared_reg & reg) - : base(krnl, reg) { -} +template +variable::variable(jit_kernel& krnl, const shared_reg& reg) : base(krnl, reg) {} -template -variable::variable(jit_kernel & krnl, const shared_reg & reg) - : base(krnl, reg) { -} +template +variable::variable(jit_kernel& krnl, const shared_reg& reg) : base(krnl, reg) {} -template -const variable & variable::operator = (const variable & rhs) const { - const auto & addr_frame = base::_kernel.address_frame(sizeof(T)); +template +const variable& variable::operator=(const variable& rhs) const { + const auto& addr_frame = base::_kernel.address_frame(sizeof(T)); base::_kernel.mov(addr_frame[base::reg()], rhs); return *this; } -template -variable::variable(jit_kernel & krnl) - : base(krnl, make_shared(krnl.reserve::type>(), krnl)) { -} +template +variable::variable(jit_kernel& krnl) + : base(krnl, make_shared(krnl.reserve::type>(), krnl)) {} -template -variable::variable(jit_kernel & krnl, const shared_reg & reg) - : base(krnl, reg) { -} +template +variable::variable(jit_kernel& krnl, const shared_reg& reg) : base(krnl, reg) {} -} // namespace internal +} // namespace internal -} // namespace intel_cpu -} // namespace ov +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.cpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.cpp index 8fd3a966e13887..ffc0286431b279 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.cpp @@ -10,14 +10,11 @@ namespace ov { namespace intel_cpu { namespace kernel { -JitKernelBase::JitKernelBase(const char* name, x64::cpu_isa_t isa) - : x64::jit_generator(name, isa), m_isa(isa) { +JitKernelBase::JitKernelBase(const char* name, 
x64::cpu_isa_t isa) : x64::jit_generator(name, isa), m_isa(isa) {
     vlen = x64::isa_max_vlen(isa);
 }

-void JitKernelBase::uni_vfmsub132ps(const Xbyak::Xmm& v_dst,
-                                    const Xbyak::Xmm& v_src,
-                                    const Xbyak::Operand& op) {
+void JitKernelBase::uni_vfmsub132ps(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vfmsub132ps(v_dst, v_src, op);
     } else if (isValidIsa(x64::avx)) {
@@ -31,9 +28,7 @@ void JitKernelBase::uni_vfmsub132ps(const Xbyak::Xmm& v_dst,
     }
 }

-void JitKernelBase::uni_vfnmadd132ps(const Xbyak::Xmm& v_dst,
-                                     const Xbyak::Xmm& v_src,
-                                     const Xbyak::Operand& op) {
+void JitKernelBase::uni_vfnmadd132ps(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vfnmadd132ps(v_dst, v_src, op);
     } else if (isValidIsa(x64::avx)) {
@@ -48,9 +43,7 @@ void JitKernelBase::uni_vfnmadd132ps(const Xbyak::Xmm& v_dst,
     }
 }

-void JitKernelBase::uni_vfmsub231ps(const Xbyak::Xmm& v_dst,
-                                    const Xbyak::Xmm& v_src,
-                                    const Xbyak::Operand& op) {
+void JitKernelBase::uni_vfmsub231ps(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vfmsub231ps(v_dst, v_src, op);
     } else if (isValidIsa(x64::avx)) {
@@ -65,9 +58,7 @@ void JitKernelBase::uni_vfmsub231ps(const Xbyak::Xmm& v_dst,
     }
 }

-void JitKernelBase::uni_vpaddd(const Xbyak::Ymm& v_dst,
-                               const Xbyak::Ymm& v_src,
-                               const Xbyak::Operand& op) {
+void JitKernelBase::uni_vpaddd(const Xbyak::Ymm& v_dst, const Xbyak::Ymm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vpaddd(v_dst, v_src, op);
     } else if (isValidIsa(x64::avx)) {
@@ -99,9 +90,7 @@ void JitKernelBase::uni_vpaddd(const Xbyak::Ymm& v_dst,
     }
 }

-void JitKernelBase::uni_vpaddq(const Xbyak::Xmm& v_dst,
-                               const Xbyak::Xmm& v_src,
-                               const Xbyak::Operand& op) {
+void JitKernelBase::uni_vpaddq(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vpaddq(v_dst, v_src, op);
     } else {
@@ -112,9 +101,7 @@ void JitKernelBase::uni_vpaddq(const Xbyak::Xmm& v_dst,
     }
 }

-void JitKernelBase::uni_vpsubd(const Xbyak::Ymm& v_dst,
-                               const Xbyak::Ymm& v_src,
-                               const Xbyak::Operand& op) {
+void JitKernelBase::uni_vpsubd(const Xbyak::Ymm& v_dst, const Xbyak::Ymm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vpsubd(v_dst, v_src, op);
     } else if (isValidIsa(x64::avx)) {
@@ -146,9 +133,7 @@ void JitKernelBase::uni_vpsubd(const Xbyak::Ymm& v_dst,
     }
 }

-void JitKernelBase::uni_vsubpd(const Xbyak::Xmm& v_dst,
-                               const Xbyak::Xmm& v_src,
-                               const Xbyak::Operand& op) {
+void JitKernelBase::uni_vsubpd(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx)) {
         vsubpd(v_dst, v_src, op);
     } else {
@@ -159,9 +144,7 @@ void JitKernelBase::uni_vsubpd(const Xbyak::Xmm& v_dst,
     }
 }

-void JitKernelBase::uni_vmulpd(const Xbyak::Xmm& v_dst,
-                               const Xbyak::Xmm& v_src,
-                               const Xbyak::Operand& op) {
+void JitKernelBase::uni_vmulpd(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx)) {
         vmulpd(v_dst, v_src, op);
     } else {
@@ -172,9 +155,7 @@ void JitKernelBase::uni_vmulpd(const Xbyak::Xmm& v_dst,
     }
 }

-void JitKernelBase::uni_vpmuludq(const Xbyak::Xmm& v_dst,
-                                 const Xbyak::Xmm& v_src,
-                                 const Xbyak::Operand& op) {
+void JitKernelBase::uni_vpmuludq(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) {
     if (isValidIsa(x64::avx2)) {
         vpmuludq(v_dst, v_src, op);
     } else {
@@ -185,9 +166,7 @@ void
JitKernelBase::uni_vpmuludq(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::uni_vdivps(const Xbyak::Xmm& v_dst, - const Xbyak::Operand& op1, - const Xbyak::Operand& op2) { +void JitKernelBase::uni_vdivps(const Xbyak::Xmm& v_dst, const Xbyak::Operand& op1, const Xbyak::Operand& op2) { if (isValidIsa(x64::avx)) { vdivps(v_dst, op1, op2); } else { @@ -198,9 +177,7 @@ void JitKernelBase::uni_vdivps(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::uni_vdivpd(const Xbyak::Xmm& v_dst, - const Xbyak::Xmm& v_src, - const Xbyak::Operand& op) { +void JitKernelBase::uni_vdivpd(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op) { if (isValidIsa(x64::avx)) { vdivpd(v_dst, v_src, op); } else { @@ -211,9 +188,7 @@ void JitKernelBase::uni_vdivpd(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::uni_vandps(const Xbyak::Xmm& v_dst, - const Xbyak::Xmm& vSrs, - const Xbyak::Operand &op) { +void JitKernelBase::uni_vandps(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& vSrs, const Xbyak::Operand& op) { if (isValidIsa(x64::avx)) { vandps(v_dst, vSrs, op); } else { @@ -224,9 +199,7 @@ void JitKernelBase::uni_vandps(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::uni_vandnps(const Xbyak::Xmm& v_dst, - const Xbyak::Xmm& vSrs, - const Xbyak::Operand &op) { +void JitKernelBase::uni_vandnps(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& vSrs, const Xbyak::Operand& op) { if (isValidIsa(x64::avx)) { vandnps(v_dst, vSrs, op); } else { @@ -237,9 +210,9 @@ void JitKernelBase::uni_vandnps(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, - const Xbyak::Reg64& rSrcPtr, - const Xbyak::Xmm& vSrcShift, +void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, + const Xbyak::Reg64& rSrcPtr, + const Xbyak::Xmm& vSrcShift, const Xbyak::Opmask& kReadMask, const bool useMask, const bool zeroFill) { @@ -254,17 +227,18 @@ void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, vpgatherdd(v_dst | kReadMask, ptr[rSrcPtr + vSrcShift]); } -void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, +void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, const Xbyak::Reg64& rSrcPtr, - const Xbyak::Xmm& vSrcShift, - const Xbyak::Xmm& vReadMask, + const Xbyak::Xmm& vSrcShift, + const Xbyak::Xmm& vReadMask, const bool useMask, const bool zeroFill) { - if (v_dst.getIdx() == vSrcShift.getIdx() || v_dst.getIdx() == vReadMask.getIdx() || vSrcShift.getIdx() == vReadMask.getIdx()) { + if (v_dst.getIdx() == vSrcShift.getIdx() || v_dst.getIdx() == vReadMask.getIdx() || + vSrcShift.getIdx() == vReadMask.getIdx()) { OPENVINO_THROW("Any pair of the index, mask, or destination registers cannot be the same."); } if (zeroFill) - pxor(v_dst, v_dst); // Don't use vpxor. It zeros the rest of the YMM register. + pxor(v_dst, v_dst); // Don't use vpxor. It zeros the rest of the YMM register. 
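// NOTE: every uni_* helper in this file follows the same dispatch shape: prefer the
// single VEX-encoded instruction, fall back to AVX where it exists, otherwise emulate
// with the destructive two-operand SSE form. A hedged generic skeleton (uni_vfoo, vfoo
// and foo are placeholders, not real mnemonics; real fallbacks differ per instruction):
//
//     void JitKernelBase::uni_vfoo(const Xbyak::Xmm& v_dst,
//                                  const Xbyak::Xmm& v_src,
//                                  const Xbyak::Operand& op) {
//         if (isValidIsa(x64::avx2)) {
//             vfoo(v_dst, v_src, op);       // one three-operand instruction
//         } else {
//             if (v_dst.getIdx() != v_src.getIdx())
//                 movups(v_dst, v_src);     // SSE form overwrites its destination
//             foo(v_dst, op);
//         }
//     }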
if (isValidIsa(x64::avx2)) { if (!useMask) @@ -280,7 +254,7 @@ void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, Xbyak::Label lLoopNext; if (useMask) { uni_vpextrd(r32Aux, vReadMask, i); - cmp(r32Aux, 0); // TODO: check significant bit + cmp(r32Aux, 0); // TODO: check significant bit je(lLoopNext, T_NEAR); } uni_vpextrd(r32Aux, vSrcShift, i); @@ -292,13 +266,14 @@ void JitKernelBase::gatherdd(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::gatherdd(const Xbyak::Ymm& v_dst, +void JitKernelBase::gatherdd(const Xbyak::Ymm& v_dst, const Xbyak::Reg64& rSrcPtr, - const Xbyak::Ymm& vSrcShift, - const Xbyak::Ymm& vReadMask, + const Xbyak::Ymm& vSrcShift, + const Xbyak::Ymm& vReadMask, const bool useMask, const bool zeroFill) { - if (v_dst.getIdx() == vSrcShift.getIdx() || v_dst.getIdx() == vReadMask.getIdx() || vSrcShift.getIdx() == vReadMask.getIdx()) { + if (v_dst.getIdx() == vSrcShift.getIdx() || v_dst.getIdx() == vReadMask.getIdx() || + vSrcShift.getIdx() == vReadMask.getIdx()) { OPENVINO_THROW("Any pair of the index, mask, or destination registers cannot be the same."); } if (isValidIsa(x64::avx2)) { @@ -309,8 +284,7 @@ void JitKernelBase::gatherdd(const Xbyak::Ymm& v_dst, vpgatherdd(v_dst, ptr[rSrcPtr + vSrcShift], vReadMask); } else { - Xbyak::Xmm xmmDst = Xbyak::Xmm(v_dst.getIdx()), - xmmSrcShft = Xbyak::Xmm(vSrcShift.getIdx()), + Xbyak::Xmm xmmDst = Xbyak::Xmm(v_dst.getIdx()), xmmSrcShft = Xbyak::Xmm(vSrcShift.getIdx()), xmmReadMask = Xbyak::Xmm(vReadMask.getIdx()); for (uint8_t i = 0; i < 2; i++) { gatherdd(xmmDst, rSrcPtr, xmmSrcShft, xmmReadMask, useMask, zeroFill); @@ -323,7 +297,7 @@ void JitKernelBase::gatherdd(const Xbyak::Ymm& v_dst, } } -void JitKernelBase::uni_vpbroadcastq(const Xbyak::Xmm &x, const Xbyak::Operand &op) { +void JitKernelBase::uni_vpbroadcastq(const Xbyak::Xmm& x, const Xbyak::Operand& op) { if (isValidIsa(x64::avx2)) { vpbroadcastq(x, op); } else { @@ -332,7 +306,7 @@ void JitKernelBase::uni_vpbroadcastq(const Xbyak::Xmm &x, const Xbyak::Operand & } } -void JitKernelBase::uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op) { +void JitKernelBase::uni_vpbroadcastd(const Xbyak::Xmm& x, const Xbyak::Operand& op) { if (isValidIsa(x64::avx2)) { vpbroadcastd(x, op); } else if (isValidIsa(x64::avx)) { @@ -348,7 +322,7 @@ void JitKernelBase::uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand & } } -void JitKernelBase::uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op) { +void JitKernelBase::uni_vpbroadcastd(const Xbyak::Ymm& x, const Xbyak::Operand& op) { if (isValidIsa(x64::avx2)) { vpbroadcastd(x, op); } else { @@ -375,8 +349,7 @@ void JitKernelBase::uni_vroundpd(const Xbyak::Xmm& v_dst, const Xbyak::Operand& } } -void JitKernelBase::uni_vcvtdq2pd(const Xbyak::Xmm& v_dst, - const Xbyak::Operand& op) { +void JitKernelBase::uni_vcvtdq2pd(const Xbyak::Xmm& v_dst, const Xbyak::Operand& op) { if (isValidIsa(x64::avx)) { vcvtdq2pd(v_dst, op); } else { @@ -384,8 +357,7 @@ void JitKernelBase::uni_vcvtdq2pd(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::uni_vcvtpd2dq(const Xbyak::Xmm& v_dst, - const Xbyak::Operand& op) { +void JitKernelBase::uni_vcvtpd2dq(const Xbyak::Xmm& v_dst, const Xbyak::Operand& op) { if (isValidIsa(x64::avx)) { vcvtpd2dq(v_dst, op); } else { @@ -393,8 +365,7 @@ void JitKernelBase::uni_vcvtpd2dq(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::uni_vpmovzxdq(const Xbyak::Xmm& v_dst, - const Xbyak::Operand& op) { +void JitKernelBase::uni_vpmovzxdq(const Xbyak::Xmm& v_dst, const Xbyak::Operand& op) { if 
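// NOTE: the pre-AVX2 branch of gatherdd above emulates the gather one dword lane at a
// time: uni_vpextrd pulls lane i of the mask (the lane is skipped when it is zero) and
// of the offset vector, then the scalar is read from [rSrcPtr + offset] and inserted
// back into lane i of v_dst. The per-lane scalar model (illustrative C++, not the
// emitted code):
//
//     if (!useMask || mask[i] != 0)
//         dst[i] = *reinterpret_cast<const uint32_t*>(src_base + shift[i]);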
(isValidIsa(x64::avx2)) { vpmovzxdq(v_dst, op); } else { @@ -416,8 +387,7 @@ void JitKernelBase::uni_vshufpd(const Xbyak::Xmm& v_dst, } } -void JitKernelBase::fillRestWorkMask(const Xbyak::Opmask& dstMask, - const Xbyak::Reg64& rWorkRest) { +void JitKernelBase::fillRestWorkMask(const Xbyak::Opmask& dstMask, const Xbyak::Reg64& rWorkRest) { auto rOnes = getReg64(); mov(rOnes, 0xFFFFFFFFFFFFFFFF); @@ -493,11 +463,11 @@ void JitKernelBase::fillRestWorkMask(const Xbyak::Ymm& ymmDstMask, L(lEnd); } -void JitKernelBase::load(const Xbyak::Xmm& v_dst, +void JitKernelBase::load(const Xbyak::Xmm& v_dst, const Xbyak::Address& srcAddr, - const Xbyak::Reg64& rLoadNum, - const size_t typeSize, - const bool zeroFilling) { + const Xbyak::Reg64& rLoadNum, + const size_t typeSize, + const bool zeroFilling) { if (!one_of(typeSize, 1u, 2u, 4u, 8u)) { OPENVINO_THROW("Could not load data with type size ", typeSize); } @@ -523,11 +493,11 @@ void JitKernelBase::load(const Xbyak::Xmm& v_dst, L(lEnd); } -void JitKernelBase::load(const Xbyak::Ymm& v_dst, +void JitKernelBase::load(const Xbyak::Ymm& v_dst, const Xbyak::Address& srcAddr, - const Xbyak::Reg64& rLoadNum, - const size_t typeSize, - const bool zeroFilling) { + const Xbyak::Reg64& rLoadNum, + const size_t typeSize, + const bool zeroFilling) { if (!one_of(typeSize, 1u, 2u, 4u, 8u)) { OPENVINO_THROW("Could not load data with type size ", typeSize); } @@ -564,9 +534,9 @@ void JitKernelBase::load(const Xbyak::Ymm& v_dst, } void JitKernelBase::store(const Xbyak::Address& dstAddr, - const Xbyak::Xmm& v_src, - const Xbyak::Reg64& rToStoreNum, - const size_t typeSize) { + const Xbyak::Xmm& v_src, + const Xbyak::Reg64& rToStoreNum, + const size_t typeSize) { if (!one_of(typeSize, 1u, 2u, 4u, 8u)) { OPENVINO_THROW("Could not store data with type size ", typeSize); } @@ -592,9 +562,9 @@ void JitKernelBase::store(const Xbyak::Address& dstAddr, } void JitKernelBase::store(const Xbyak::Address& dstAddr, - const Xbyak::Ymm& v_src, - const Xbyak::Reg64& rToStoreNum, - const size_t typeSize) { + const Xbyak::Ymm& v_src, + const Xbyak::Reg64& rToStoreNum, + const size_t typeSize) { if (!one_of(typeSize, 1u, 2u, 4u, 8u)) { OPENVINO_THROW("Could not store data with type size ", typeSize); } @@ -631,11 +601,11 @@ void JitKernelBase::store(const Xbyak::Address& dstAddr, void JitKernelBase::memMovDD(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rSrc, - const Xbyak::Xmm& vReadMask, - const Xbyak::Xmm& vSrcShift, + const Xbyak::Xmm& vReadMask, + const Xbyak::Xmm& vSrcShift, const Xbyak::Reg64& rToStoreNum, - const bool useMask, - const bool zeroFill) { + const bool useMask, + const bool zeroFill) { Xbyak::Label lEnd; auto rAux = getReg64(); Xbyak::Reg32 r32Aux = Xbyak::Reg32(rAux.getIdx()); @@ -671,11 +641,11 @@ void JitKernelBase::memMovDD(const Xbyak::Reg64& rDst, void JitKernelBase::memMovDD(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rSrc, - const Xbyak::Ymm& vReadMask, - const Xbyak::Ymm& vSrcShift, + const Xbyak::Ymm& vReadMask, + const Xbyak::Ymm& vSrcShift, const Xbyak::Reg64& rToStoreNum, - const bool useMask, - const bool zeroFill) { + const bool useMask, + const bool zeroFill) { Xbyak::Label lEnd; if (isValidIsa(x64::avx2)) { auto vAux = RegistersPool::Reg(registersPool); @@ -684,8 +654,7 @@ void JitKernelBase::memMovDD(const Xbyak::Reg64& rDst, } else if (isValidIsa(x64::avx)) { const uint8_t typeSize = sizeof(int); const uint8_t elPerXmm = x64::cpu_isa_traits::vlen / typeSize; - Xbyak::Xmm xmmReadMask = Xbyak::Xmm(vReadMask.getIdx()), - xmmSrcShft = 
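// NOTE: the Reg64-counted load()/store() overloads above are the run-time tail path for
// vector registers: every element access is guarded by the rLoadNum/rToStoreNum counter,
// so a partially filled vector never reads or writes past the buffer, and typeSize is
// validated to be 1, 2, 4 or 8. Hedged usage (register names illustrative):
//
//     mov(r_work, r_total);
//     and_(r_work, vec_width - 1);                          // remaining tail elements
//     load(vmm, ptr[r_src], r_work, sizeof(float), true);   // zero-fill unread lanes
//     // ... compute on vmm ...
//     store(ptr[r_dst], vmm, r_work, sizeof(float));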
Xbyak::Xmm(vSrcShift.getIdx()); + Xbyak::Xmm xmmReadMask = Xbyak::Xmm(vReadMask.getIdx()), xmmSrcShft = Xbyak::Xmm(vSrcShift.getIdx()); for (uint8_t i = 0; i < 2; i++) { memMovDD(rDst, rSrc, xmmReadMask, xmmSrcShft, rToStoreNum, useMask, zeroFill); @@ -707,6 +676,6 @@ void JitKernelBase::memMovDD(const Xbyak::Reg64& rDst, L(lEnd); } -} // namespace kernel -} // namespace intel_cpu -} // namespace ov +} // namespace kernel +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.hpp b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.hpp index 260d7196331a7f..eee4ff4d8c0708 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/x64/jit_kernel_base.hpp @@ -7,9 +7,9 @@ #include "openvino/core/visibility.hpp" #if defined(OPENVINO_ARCH_X86_64) -#include "cpu/x64/jit_generator.hpp" -#include "registers_pool.hpp" -#endif // OPENVINO_ARCH_X86_64 +# include "cpu/x64/jit_generator.hpp" +# include "registers_pool.hpp" +#endif // OPENVINO_ARCH_X86_64 namespace ov { namespace intel_cpu { @@ -19,18 +19,22 @@ class JitKernelBase; #if defined(OPENVINO_ARCH_X86_64) -#define getReg64() RegistersPool::Reg(registersPool) -#define getReg32() RegistersPool::Reg(registersPool) -#define getVmm() RegistersPool::Reg(registersPool) -#define getMask() RegistersPool::Reg(registersPool) +# define getReg64() RegistersPool::Reg(registersPool) +# define getReg32() RegistersPool::Reg(registersPool) +# define getVmm() RegistersPool::Reg(registersPool) +# define getMask() RegistersPool::Reg(registersPool) -class JitKernelBase: public dnnl::impl::cpu::x64::jit_generator { +class JitKernelBase : public dnnl::impl::cpu::x64::jit_generator { public: JitKernelBase(const char* name, dnnl::impl::cpu::x64::cpu_isa_t max_cpu_isa); - dnnl::impl::cpu::x64::cpu_isa_t getIsa() { return m_isa; } + dnnl::impl::cpu::x64::cpu_isa_t getIsa() { + return m_isa; + } - size_t getVectorLen() { return vlen; } + size_t getVectorLen() { + return vlen; + } void uni_vfmsub132ps(const Xbyak::Xmm& vDst, const Xbyak::Xmm& vSrc, const Xbyak::Operand& op); @@ -62,9 +66,9 @@ class JitKernelBase: public dnnl::impl::cpu::x64::jit_generator { void uni_vdivpd(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_src, const Xbyak::Operand& op2); - void uni_vandps(const Xbyak::Xmm& vDst, const Xbyak::Xmm& vSrs, const Xbyak::Operand &op); + void uni_vandps(const Xbyak::Xmm& vDst, const Xbyak::Xmm& vSrs, const Xbyak::Operand& op); - void uni_vandnps(const Xbyak::Xmm& vDst, const Xbyak::Xmm& vSrs, const Xbyak::Operand &op); + void uni_vandnps(const Xbyak::Xmm& vDst, const Xbyak::Xmm& vSrs, const Xbyak::Operand& op); void uni_kmovd(const Xbyak::Opmask& kDst, const Xbyak::Opmask& kSrc) { kmovd(kDst, kSrc); @@ -82,11 +86,11 @@ class JitKernelBase: public dnnl::impl::cpu::x64::jit_generator { uni_vandps(kDst, kSrc1, kSrc2); } - void uni_vpbroadcastd(const Xbyak::Xmm &x, const Xbyak::Operand &op); + void uni_vpbroadcastd(const Xbyak::Xmm& x, const Xbyak::Operand& op); - void uni_vpbroadcastd(const Xbyak::Ymm &x, const Xbyak::Operand &op); + void uni_vpbroadcastd(const Xbyak::Ymm& x, const Xbyak::Operand& op); - void uni_vpbroadcastq(const Xbyak::Xmm &x, const Xbyak::Operand &op); + void uni_vpbroadcastq(const Xbyak::Xmm& x, const Xbyak::Operand& op); void uni_vroundpd(const Xbyak::Xmm& v_dst, const Xbyak::Operand& op, const uint8_t imm); @@ -98,76 +102,71 @@ class JitKernelBase: public dnnl::impl::cpu::x64::jit_generator { void 
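// NOTE: the getReg64()/getReg32()/getVmm()/getMask() macros below are RAII register
// leases: RegistersPool::Reg<R> checks a register out of `registersPool` and returns it
// when the object leaves scope, which is why the helpers in the .cpp can simply write:
//
//     auto rOnes = getReg64();            // leased from the pool
//     mov(rOnes, 0xFFFFFFFFFFFFFFFF);
//     ...                                 // rOnes released automatically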
uni_vshufpd(const Xbyak::Xmm& v_dst, const Xbyak::Xmm& v_srs, const Xbyak::Operand& op, uint8_t imm); - void gatherdd(const Xbyak::Xmm& vDst, - const Xbyak::Reg64& rSrcPtr, - const Xbyak::Xmm& vSrcShift, + void gatherdd(const Xbyak::Xmm& vDst, + const Xbyak::Reg64& rSrcPtr, + const Xbyak::Xmm& vSrcShift, const Xbyak::Opmask& kReadMask, - const bool useMask = true, - const bool zeroFill = false); + const bool useMask = true, + const bool zeroFill = false); - void gatherdd(const Xbyak::Xmm& vDst, + void gatherdd(const Xbyak::Xmm& vDst, const Xbyak::Reg64& rSrcPtr, - const Xbyak::Xmm& vSrcShift, - const Xbyak::Xmm& vReadMask, - const bool useMask = true, + const Xbyak::Xmm& vSrcShift, + const Xbyak::Xmm& vReadMask, + const bool useMask = true, const bool zeroFill = false); - void gatherdd(const Xbyak::Ymm& vDst, + void gatherdd(const Xbyak::Ymm& vDst, const Xbyak::Reg64& rSrcPtr, - const Xbyak::Ymm& vSrcShift, - const Xbyak::Ymm& vReadMask, - const bool useMask = true, + const Xbyak::Ymm& vSrcShift, + const Xbyak::Ymm& vReadMask, + const bool useMask = true, const bool zeroFill = false); - void fillRestWorkMask(const Xbyak::Opmask& kDstMask, - const Xbyak::Reg64& rWorkRest); + void fillRestWorkMask(const Xbyak::Opmask& kDstMask, const Xbyak::Reg64& rWorkRest); - void fillRestWorkMask(const Xbyak::Xmm& ymmDstMask, - const Xbyak::Reg64& rWorkRest, - const uint64_t typeSize = 4); + void fillRestWorkMask(const Xbyak::Xmm& ymmDstMask, const Xbyak::Reg64& rWorkRest, const uint64_t typeSize = 4); - void fillRestWorkMask(const Xbyak::Ymm& ymmDstMask, - const Xbyak::Reg64& rWorkRest, - const uint64_t typeSize = 4); + void fillRestWorkMask(const Xbyak::Ymm& ymmDstMask, const Xbyak::Reg64& rWorkRest, const uint64_t typeSize = 4); - void load(const Xbyak::Xmm& vDst, + void load(const Xbyak::Xmm& vDst, const Xbyak::Address& srcAddr, - const Xbyak::Reg64& rLoadNum, - const size_t typeSize, + const Xbyak::Reg64& rLoadNum, + const size_t typeSize, const bool zeroFill = false); - void load(const Xbyak::Ymm& vDst, + void load(const Xbyak::Ymm& vDst, const Xbyak::Address& srcAddr, - const Xbyak::Reg64& rLoadNum, - const size_t typeSize, + const Xbyak::Reg64& rLoadNum, + const size_t typeSize, const bool zeroFill = false); void store(const Xbyak::Address& dstAddr, - const Xbyak::Xmm& vSrc, - const Xbyak::Reg64& rToStoreNum, - const size_t typeSize); + const Xbyak::Xmm& vSrc, + const Xbyak::Reg64& rToStoreNum, + const size_t typeSize); void store(const Xbyak::Address& dstAddr, - const Xbyak::Ymm& vSrc, - const Xbyak::Reg64& rToStoreNum, - const size_t typeSize); + const Xbyak::Ymm& vSrc, + const Xbyak::Reg64& rToStoreNum, + const size_t typeSize); // Makes gather from memory under the vReadMask and writes to the memory m128. void memMovDD(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rSrc, - const Xbyak::Xmm& vReadMask, - const Xbyak::Xmm& vSrcShift, + const Xbyak::Xmm& vReadMask, + const Xbyak::Xmm& vSrcShift, const Xbyak::Reg64& rToStoreCounter, - const bool useMask = true, + const bool useMask = true, const bool zeroFill = false); // Makes gather from the memory under the vReadMask and writes to the memory m256. 
void memMovDD(const Xbyak::Reg64& rDst, const Xbyak::Reg64& rSrc, - const Xbyak::Ymm& vReadMask, - const Xbyak::Ymm& vSrcShift, + const Xbyak::Ymm& vReadMask, + const Xbyak::Ymm& vSrcShift, const Xbyak::Reg64& rToStoreCounter, - const bool useMask = true, + const bool useMask = true, const bool zeroFill = false); protected: @@ -181,32 +180,37 @@ class JitKernelBase: public dnnl::impl::cpu::x64::jit_generator { enum { // Comparison predicate operand (immediate byte) for single-precision floating-point values. - CMP_EQ_PS = 0, // Equal (ordered, non-signaling) - CMP_LT_PS, // Less-than (ordered, signaling) - CMP_LE_PS, // Less-than-or-equal (ordered, signaling) - CMP_UNORD_PS, // Unordered (non-signaling) - CMP_NEQ_PS, // Not-equal (unordered, non-signaling) - CMP_NLT_PS, // Not-less-than (unordered, signaling) - CMP_NLE_PS, // Not-less-than-or-equal (unordered, signaling) - CMP_ORD_PS // Ordered (non-signaling) + CMP_EQ_PS = 0, // Equal (ordered, non-signaling) + CMP_LT_PS, // Less-than (ordered, signaling) + CMP_LE_PS, // Less-than-or-equal (ordered, signaling) + CMP_UNORD_PS, // Unordered (non-signaling) + CMP_NEQ_PS, // Not-equal (unordered, non-signaling) + CMP_NLT_PS, // Not-less-than (unordered, signaling) + CMP_NLE_PS, // Not-less-than-or-equal (unordered, signaling) + CMP_ORD_PS // Ordered (non-signaling) }; }; -template +template class JitKernel : public JitKernelBase { public: - using KernelFunc = void (*)(const CallArgs *); + using KernelFunc = void (*)(const CallArgs*); explicit JitKernel(const char* name, const CompileParams& jcp, dnnl::impl::cpu::x64::cpu_isa_t max_cpu_isa) - : JitKernelBase{name, max_cpu_isa}, m_jcp{jcp}, m_func{nullptr} {} + : JitKernelBase{name, max_cpu_isa}, + m_jcp{jcp}, + m_func{nullptr} {} ~JitKernel() override = default; dnnl::impl::status_t create_kernel() override { const dnnl::impl::status_t code = jit_generator::create_kernel(); if (code != dnnl::impl::status::success) { - OPENVINO_THROW("Could not create kernel. Error code: ", std::to_string(code), ". ", - "Xbyak error code: ", Xbyak::ConvertErrorToString(Xbyak::GetError())); + OPENVINO_THROW("Could not create kernel. Error code: ", + std::to_string(code), + ". ", + "Xbyak error code: ", + Xbyak::ConvertErrorToString(Xbyak::GetError())); } m_func = (decltype(m_func))jit_ker(); return code; @@ -221,21 +225,21 @@ class JitKernel : public JitKernelBase { this->operator()(&args); } - template class KernelT> + template
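// NOTE: an end-to-end sketch of a kernel built on the JitKernel<CompileParams, CallArgs>
// template above; the struct, field and class names are illustrative, not from this
// patch:
//
//     struct jit_params { size_t vec_len; };                // baked in at JIT time
//     struct jit_args   { const float* src; float* dst; };  // passed on every call
//
//     struct MyKernel : public JitKernel<jit_params, jit_args> {
//         MyKernel(const jit_params& jcp, dnnl::impl::cpu::x64::cpu_isa_t isa)
//             : JitKernel("MyKernel", jcp, isa) {}
//         void generate() override { /* preamble(); ...; postamble(); */ }
//     };
//
//     MyKernel k{params, dnnl::impl::cpu::x64::avx2};
//     k.create_kernel();        // throws via OPENVINO_THROW with the Xbyak error text
//     k(jit_args{src, dst});    // operator() forwards the args to the compiled m_func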