From b543d0b0a703b86acbc92189dd22ede1c9a59dee Mon Sep 17 00:00:00 2001 From: Nishant Prabhu <168545601+NishantPrabhuFujitsu@users.noreply.github.com> Date: Tue, 17 Dec 2024 10:34:25 +0530 Subject: [PATCH 1/5] [ARM] [SDPA] SVE implementation of MHASingleToken for FP32 (#27273) ### Details: - Adds SVE FP32 implementations for functions called during execution of `MHASingleToken` for SVE-128, SVE-256 and SVE-512 platforms. - SVE implementations are compiled only if runtime support for SVE is detected on the hardware, otherwise it falls back to Neon. - Adds a new implementation for exponential function `exp_ps_` using fewer FMA operations. Executes ~18% faster and has better output precision. **Note:** I am aware of the Neon FP16 implementation of SDPA added recently. To accommodate for this, the current SVE changes will be used only if the hardware does not have ARM FP16 support. I will follow up with SVE FP16 implementations soon. ### [SVE] Benchmarking results Below are the benchmarking results of execution time of each ported function. Measurements were performed by running each function individually on dummy inputs (128 fp32 elements) for 1,000,000 iterations and computing average time (in micro-seconds). ![image](https://github.com/user-attachments/assets/3f82238f-af7e-4b68-b4b1-259cf389e41a) Execution time of `MHASingleToken` as a whole was also measured for two LLMs, the results of which are shown below. For LlaMA-3-8B, the SVE-128 and SVE-512 systems at my disposal did not have enough memory, so only SVE-256 results are shown. While there is an improvement overall, these results could be contaminated with run-to-run variation due to the small execution time of the kernel. **Benchmarking details:** Prompt length of 108 tokens was used; total time for generating 50 tokens was measured and average execution time was computed. ![image](https://github.com/user-attachments/assets/893c1c46-085f-46af-ab5a-2c1481c75f68) ### New exponential implementation It is based on the discussion in [these slides](https://www.slideshare.net/slideshow/hpc-phys20201203/239717194#23) (this is based on a past talk in Fujitsu hence the document is in Japanese, sorry!). The algorithm followed is slightly different from the current implementation, in that it uses `fexpa` instruction available on ARM and requires only 3 Taylor expansion terms (2 FMA operations) to be precise until the 8th decimal place. Our benchmarking results showed this implementation to be 44%-58% faster than the existing Neon implementation. It is ~18% faster than the SVE implementation of the current algorithm in Neon. ![image](https://github.com/user-attachments/assets/117df21d-3977-499c-8ab8-8f4346286113) In this PR, the new implementation is called by default. The SVE port of the existing Neon implementation has also been retained, if needed. --- .../compile_flags/os_flags.cmake | 88 +++++++++ .../cross_compiled_disp_gen.cmake | 1 + .../cross_compile/cross_compiled_func.cmake | 7 +- cmake/developer_package/features.cmake | 2 + .../dev_api/openvino/runtime/system_conf.hpp | 7 + src/inference/src/system_conf.cpp | 19 ++ src/plugins/intel_cpu/CMakeLists.txt | 22 ++- .../src/nodes/kernels/scaled_attn/common.hpp | 80 ++++++++ .../kernels/scaled_attn/mha_single_token.cpp | 174 +++++++++++++++++- .../kernels/scaled_attn/softmax_kernel.hpp | 46 ++++- 10 files changed, 437 insertions(+), 9 deletions(-) diff --git a/cmake/developer_package/compile_flags/os_flags.cmake b/cmake/developer_package/compile_flags/os_flags.cmake index fdfd7211c8e815..660fd6160893ae 100644 --- a/cmake/developer_package/compile_flags/os_flags.cmake +++ b/cmake/developer_package/compile_flags/os_flags.cmake @@ -4,6 +4,7 @@ include(ProcessorCount) include(CheckCXXCompilerFlag) +include(CheckCXXSourceCompiles) # # ov_disable_deprecated_warnings() @@ -91,6 +92,50 @@ macro(ov_dev_package_no_errors) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${ov_c_cxx_dev_no_errors}") endmacro() +# +# ov_check_compiler_supports_sve(flags) +# +# Checks whether CXX compiler for passed language supports SVE code compilation +# +macro(ov_check_compiler_supports_sve flags) + # Code to compile + set(SVE_CODE " + #include + int main() { + svfloat64_t a; + a = svdup_n_f64(0); + return 0; + }") + + # Save the current state of required flags + set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS}) + + # Set the flags necessary for compiling the test code with SVE support + set(CMAKE_REQUIRED_FLAGS "${CMAKE_CXX_FLAGS_INIT} ${flags}") + + # Check if the source code compiles with the given flags for C++ + CHECK_CXX_SOURCE_COMPILES("${SVE_CODE}" CXX_HAS_SVE) + + # If the compilation test is successful, set appropriate variables indicating support + if(CXX_HAS_SVE) + set(CXX_SVE_FOUND TRUE CACHE BOOL "SVE available on host") + set(CXX_SVE_FOUND TRUE CACHE BOOL "CXX SVE support") + set(CXX_SVE_FLAGS "${flags}" CACHE STRING "CXX SVE flags") + endif() + + # Restore the original state of required flags + set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE}) + + # If the compilation test fails, indicate that the support is not found + if(NOT CXX_SVE_FOUND) + set(CXX_SVE_FOUND FALSE CACHE BOOL "CXX SVE support") + set(CXX_SVE_FLAGS "" CACHE STRING "CXX SVE flags") + endif() + + # Mark the variables as advanced to hide them in the default CMake GUI + mark_as_advanced(CXX_SVE_FOUND CXX_SVE_FLAGS) +endmacro() + # # ov_sse42_optimization_flags() # @@ -208,6 +253,49 @@ macro(ov_arm_neon_fp16_optimization_flags flags) endif() endmacro() +# +# ov_arm_sve_optimization_flags() +# +macro(ov_arm_sve_optimization_flags flags) + # Check for compiler SVE support + ov_check_compiler_supports_sve("-march=armv8-a+sve") + + if(OV_COMPILER_IS_INTEL_LLVM) + message(WARNING "Unsupported CXX compiler ${CMAKE_CXX_COMPILER_ID}") + elseif(CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + # nothing should be required here + elseif(ANDROID) + if(ANDROID_ABI STREQUAL "arm64-v8a") + set(${flags} -Wno-unused-command-line-argument) + if(CXX_SVE_FOUND) + list(APPEND ${flags} -march=armv8-a+sve) + else() + message(WARNING "SVE is not supported on this Android ABI: ${ANDROID_ABI}") + endif() + else() + message(WARNING "SVE is not supported on this Android ABI: ${ANDROID_ABI}") + endif() + else() + if(AARCH64) + set(${flags} -O2) + + # Add flag for SVE if supported + if(CXX_SVE_FOUND) + list(APPEND ${flags} -march=armv8-a+sve) + endif() + if(NOT CMAKE_CL_64) + list(APPEND ${flags} -ftree-vectorize) + endif() + + set(${flags} ${${flags}}) + elseif(ARM) + message(WARNING "SVE is not supported on 32-bit ARM architectures.") + else() + message(WARNING "SVE is not supported by architecture ${CMAKE_SYSTEM_PROCESSOR}") + endif() + endif() +endmacro() + # # ov_disable_all_warnings() # diff --git a/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake b/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake index c33d64635eb10b..fd534f3e600bfe 100644 --- a/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake +++ b/cmake/developer_package/cross_compile/cross_compiled_disp_gen.cmake @@ -18,6 +18,7 @@ set(_CPU_CHECK_ANY "true") set(_CPU_CHECK_SSE42 "with_cpu_x86_sse42()") set(_CPU_CHECK_AVX "with_cpu_x86_avx()") set(_CPU_CHECK_NEON_FP16 "with_cpu_neon_fp16()") +set(_CPU_CHECK_SVE "with_cpu_sve()") set(_CPU_CHECK_AVX2 "with_cpu_x86_avx2()") set(_CPU_CHECK_AVX512F "with_cpu_x86_avx512f()") diff --git a/cmake/developer_package/cross_compile/cross_compiled_func.cmake b/cmake/developer_package/cross_compile/cross_compiled_func.cmake index 1e92fe3bfdaf8c..962aa5d373a4db 100644 --- a/cmake/developer_package/cross_compile/cross_compiled_func.cmake +++ b/cmake/developer_package/cross_compile/cross_compiled_func.cmake @@ -3,7 +3,7 @@ # ## list of available instruction sets -set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16) +set(_ARCH_LIST ANY SSE42 AVX AVX2 AVX512F NEON_FP16 SVE) set(_ACCEPTED_ARCHS_ANY "^(ANY)$") set(_ACCEPTED_ARCHS_SSE42 "^(ANY|SSE42)$") @@ -11,6 +11,7 @@ set(_ACCEPTED_ARCHS_AVX "^(ANY|SSE42|AVX)$") set(_ACCEPTED_ARCHS_AVX2 "^(ANY|SSE42|AVX|AVX2)$") set(_ACCEPTED_ARCHS_AVX512F "^(ANY|SSE42|AVX|AVX2|AVX512F)$") set(_ACCEPTED_ARCHS_NEON_FP16 "^(ANY|NEON_FP16)$") +set(_ACCEPTED_ARCHS_SVE "^(ANY|SVE)$") ## Arch specific definitions set(_DEFINE_ANY "") @@ -19,12 +20,14 @@ set(_DEFINE_AVX "HAVE_AVX" ${_DEFINE_SSE42}) set(_DEFINE_AVX2 "HAVE_AVX2" ${_DEFINE_AVX}) set(_DEFINE_AVX512F "HAVE_AVX512F" ${_DEFINE_AVX2}) set(_DEFINE_NEON_FP16 "HAVE_NEON_FP16" ${_DEFINE_ANY}) +set(_DEFINE_SVE "HAVE_SVE" ${_DEFINE_SVE}) ## Arch specific compile options ov_avx512_optimization_flags(_FLAGS_AVX512F) ov_avx2_optimization_flags (_FLAGS_AVX2) ov_sse42_optimization_flags (_FLAGS_SSE42) ov_arm_neon_fp16_optimization_flags(_FLAGS_NEON_FP16) +ov_arm_sve_optimization_flags(_FLAGS_SVE) set(_FLAGS_AVX "") ## TBD is not defined for OV project yet set(_FLAGS_ANY "") ## @@ -185,6 +188,8 @@ endfunction() function(_currently_requested_top_arch VAR) if(ENABLE_NEON_FP16) set(RES NEON_FP16) + elseif(ENABLE_SVE) + set(RES SVE) elseif(ENABLE_AVX512F) set(RES AVX512F) elseif(ENABLE_AVX2) diff --git a/cmake/developer_package/features.cmake b/cmake/developer_package/features.cmake index 8d1f3696c6759c..ae5313cea8a8b4 100644 --- a/cmake/developer_package/features.cmake +++ b/cmake/developer_package/features.cmake @@ -51,6 +51,8 @@ ov_dependent_option (ENABLE_AVX512F "Enable AVX512 optimizations" ON "X86_64 OR ov_dependent_option(ENABLE_NEON_FP16 "Enable ARM FP16 optimizations" ON "AARCH64" OFF) +ov_dependent_option(ENABLE_SVE "Enable SVE optimizations" ON "AARCH64" OFF) + # Type of build, we add this as an explicit option to default it to ON get_property(BUILD_SHARED_LIBS_DEFAULT GLOBAL PROPERTY TARGET_SUPPORTS_SHARED_LIBS) ov_option (BUILD_SHARED_LIBS "Build as a shared library" ${BUILD_SHARED_LIBS_DEFAULT}) diff --git a/src/inference/dev_api/openvino/runtime/system_conf.hpp b/src/inference/dev_api/openvino/runtime/system_conf.hpp index 59d56dfdd49d73..bebc2014ab8028 100644 --- a/src/inference/dev_api/openvino/runtime/system_conf.hpp +++ b/src/inference/dev_api/openvino/runtime/system_conf.hpp @@ -83,6 +83,13 @@ OPENVINO_RUNTIME_API bool with_cpu_x86_sse42(); */ OPENVINO_RUNTIME_API bool with_cpu_neon_fp16(); +/** + * @brief Checks whether CPU supports ARM SVE capability + * @ingroup ov_dev_api_system_conf + * @return `True` if ARM SVE instructions are available, `false` otherwise + */ +OPENVINO_RUNTIME_API bool with_cpu_sve(); + /** * @brief Checks whether CPU supports AVX capability * @ingroup ov_dev_api_system_conf diff --git a/src/inference/src/system_conf.cpp b/src/inference/src/system_conf.cpp index 27c671d07ad5c9..3227b1a3034903 100644 --- a/src/inference/src/system_conf.cpp +++ b/src/inference/src/system_conf.cpp @@ -22,6 +22,7 @@ # include # define ARM_COMPUTE_CPU_FEATURE_HWCAP_FPHP (1 << 9) # define ARM_COMPUTE_CPU_FEATURE_HWCAP_ASIMDHP (1 << 10) +# define ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE (1 << 24) #elif defined(__APPLE__) && defined(__aarch64__) # include # include @@ -114,6 +115,10 @@ bool with_cpu_neon_fp16() { return false; } +bool with_cpu_sve() { + return false; +} + #else // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64 bool with_cpu_x86_sse42() { @@ -173,6 +178,20 @@ bool with_cpu_neon_fp16() { return false; # endif } +bool with_cpu_sve() { +# if !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + !defined(__arm__) && defined(__aarch64__) + const uint32_t hwcaps = getauxval(AT_HWCAP); + return hwcaps & ARM_COMPUTE_CPU_FEATURE_HWCAP_SVE; +# elif !defined(_WIN64) && !defined(BARE_METAL) && !defined(__APPLE__) && !defined(__OpenBSD__) && \ + !defined(__aarch64__) && defined(__arm__) + return false; +# elif defined(__aarch64__) && defined(__APPLE__) + return false; +# else + return false; +# endif +} #endif // OPENVINO_ARCH_X86 || OPENVINO_ARCH_X86_64 bool check_open_mp_env_vars(bool include_omp_num_threads) { diff --git a/src/plugins/intel_cpu/CMakeLists.txt b/src/plugins/intel_cpu/CMakeLists.txt index eb56a3fb39503e..aa6ce49a051e00 100644 --- a/src/plugins/intel_cpu/CMakeLists.txt +++ b/src/plugins/intel_cpu/CMakeLists.txt @@ -278,6 +278,24 @@ target_include_directories(${TARGET_NAME} SYSTEM PRIVATE $) +# ARCH lists for softmax.cpp and mha_single_token.cpp +# Based on result of above calls, decide whether to add SVE +set(SOFTMAX_ARCH_LIST AVX512F AVX2) +set(MHA_SINGLE_TOKEN_ARCH_LIST AVX512F AVX2) + +if(ENABLE_NEON_FP16) + list(APPEND SOFTMAX_ARCH_LIST NEON_FP16) + list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST NEON_FP16) +endif() + +if(ENABLE_SVE) + list(APPEND SOFTMAX_ARCH_LIST SVE) + list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST SVE) +endif() + +list(APPEND SOFTMAX_ARCH_LIST ANY) +list(APPEND MHA_SINGLE_TOKEN_ARCH_LIST ANY) + # Cross compiled function # TODO: The same for proposal, proposalONNX, topk cross_compiled_file(${TARGET_NAME} @@ -288,14 +306,14 @@ cross_compiled_file(${TARGET_NAME} NAMESPACE ov::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} - ARCH AVX512F AVX2 NEON_FP16 ANY + ARCH ${SOFTMAX_ARCH_LIST} src/nodes/kernels/scaled_attn/softmax.cpp API src/nodes/kernels/scaled_attn/softmax.hpp NAME attn_softmax NAMESPACE ov::Extensions::Cpu::XARCH ) cross_compiled_file(${TARGET_NAME} - ARCH AVX512F AVX2 NEON_FP16 ANY + ARCH ${MHA_SINGLE_TOKEN_ARCH_LIST} src/nodes/kernels/scaled_attn/mha_single_token.cpp API src/nodes/kernels/scaled_attn/mha_single_token.hpp NAME mha_single_token diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp index 4e14cf5894b04d..63cbbb4464ee92 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/common.hpp @@ -17,6 +17,9 @@ #endif #if defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) +# include "arm_sve.h" +# endif # include "arm_neon.h" #endif @@ -35,6 +38,10 @@ static constexpr size_t vec_len_f32_avx2 = vec_len_avx2 / sizeof(float); static constexpr size_t vec_len_f32_neon = vec_len_neon / sizeof(float); static constexpr size_t vec_len_f16_neon = vec_len_neon / sizeof(ov::float16); +#if defined(HAVE_SVE) +static constexpr size_t vec_len_f32_sve = svcntw(); +#endif + #ifdef HAVE_AVX512F inline __m512 cvt_bf16_to_fp32(const __m256i src) { __m512i y = _mm512_cvtepu16_epi32(src); @@ -250,6 +257,79 @@ inline void hmin(__m256& x) { #endif #ifdef OPENVINO_ARCH_ARM64 +# if defined(HAVE_SVE) +inline svfloat32_t exp_ps_sve(svbool_t& pg, svfloat32_t& src) { + // Constants + const auto log2_e = svdup_n_f32(1.4426950409f); + const auto ln2 = svdup_n_f32(0.6931473921f); + const auto half_ln2_sq = svdup_n_f32(0.2413862043f); + const auto not_mask17 = svdup_n_u32(~((1u << 17) - 1)); + const auto one = svdup_n_f32(1.0f); + + // Algorithm starts here + svfloat32_t t0 = svmul_f32_z(pg, src, log2_e); // y = x * log2(e) + svfloat32_t t1 = svrintm_f32_z(pg, t0); // rount to int (float) + svint32_t t2 = svcvt_s32_f32_z(pg, t1); // n + + t1 = svsub_f32_z(pg, t0, t1); // a = y - floor(y) + t1 = svadd_f32_z(pg, t1, one); // b = a + 1 + + svuint32_t t3 = svlsr_n_u32_z(pg, svreinterpret_u32_f32(t1), 17); // v = b >> 17 (u32) + svfloat32_t t4 = svexpa_f32(t3); // c = fexpa(v) + t4 = svscale_f32_z(pg, t4, t2); // fexpa(v) * 2^(n) + + // and_(t2.d, t1.d, not_mask17.d) + svfloat32_t t5 = svreinterpret_f32_u32(svand_u32_z(pg, svreinterpret_u32_f32(t1), not_mask17)); + t5 = svsub_f32_z(pg, t1, t5); // z + t0 = svmla_f32_z(pg, ln2, t5, half_ln2_sq); // ln2 + half_ln2_sq * z + t0 = svmla_f32_z(pg, one, t5, t0); // 1 + (ln2 * z) + (half_ln2_sq * z * z) + t0 = svmul_f32_z(pg, t0, t4); // Final result + + return t0; +} +inline svfloat32_t exp_ps_sve_legacy(svbool_t& pg, svfloat32_t& src) { + const auto c1 = svreinterpret_f32_u32(svdup_n_u32(0x3f7ffff6)); + const auto c2 = svreinterpret_f32_u32(svdup_n_u32(0x3efffedb)); + const auto c3 = svreinterpret_f32_u32(svdup_n_u32(0x3e2aaf33)); + const auto c4 = svreinterpret_f32_u32(svdup_n_u32(0x3d2b9f17)); + const auto c5 = svreinterpret_f32_u32(svdup_n_u32(0x3c072010)); + + const auto shift = svreinterpret_f32_u32(svdup_n_u32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto one = svdup_n_f32(1.0f); // 1 + const auto two = svdup_n_f32(2.0f); // 2 + const auto inv_ln2 = svreinterpret_f32_u32(svdup_n_u32(0x3fb8aa3b)); + const auto neg_ln2_hi = svreinterpret_f32_u32(svdup_n_u32(0xbf317200)); + const auto neg_ln2_lo = svreinterpret_f32_u32(svdup_n_u32(0xb5bfbe8e)); + + const auto inf = svdup_n_f32(std::numeric_limits::infinity()); + const auto max_input = svdup_n_f32(88.37f); // Approximately ln(2^127.5) + const auto zero = svdup_n_f32(0.f); + const auto min_input = svdup_n_f32(-86.64f); // Approximately ln(2^-125) + + const auto z = svmla_f32_z(pg, shift, src, inv_ln2); + auto n = svsub_f32_z(pg, z, shift); + n = svsub_f32_z(pg, n, one); + const auto scale = svreinterpret_f32_u32(svlsl_n_u32_z(pg, svreinterpret_u32_f32(z), 23)); // 2^n + + const auto r_hi = svmla_f32_z(pg, src, n, neg_ln2_hi); + const auto r = svmla_f32_z(pg, r_hi, n, neg_ln2_lo); + const auto r2 = svmul_f32_z(pg, r, r); + + const auto p1 = svmul_f32_z(pg, c1, r); + const auto p23 = svmla_f32_z(pg, c2, c3, r); + const auto p45 = svmla_f32_z(pg, c4, c5, r); + const auto p2345 = svmla_f32_z(pg, p23, p45, r2); + const auto p12345 = svmla_f32_z(pg, p1, p2345, r2); + + auto poly = svmla_f32_z(pg, scale, p12345, scale); + poly = svmul_f32_z(pg, poly, two); + + poly = svsel_f32(svcmplt_f32(pg, src, min_input), zero, poly); + poly = svsel_f32(svcmpgt_f32(pg, src, max_input), inf, poly); + + return poly; +} +# endif inline float32x4_t exp_ps_neon_f32(const float32x4_t& src) { const auto c1 = vreinterpretq_f32_u32(vdupq_n_u32(0x3f7ffff6)); const auto c2 = vreinterpretq_f32_u32(vdupq_n_u32(0x3efffedb)); diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp index f2180b5314cc07..5a6f0d66f1f221 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/mha_single_token.cpp @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -20,6 +21,9 @@ #include "softmax_kernel.hpp" #if defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) +# include +# endif # include #endif @@ -59,19 +63,35 @@ void cvt_copy(TA* dst, TB* src, size_t n) { mm256_uni_storeu_ps(dst + i, vb); } #elif defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) + auto _dst = reinterpret_cast(dst); + size_t inc = vec_len_f32_sve; + svbool_t pg = svptrue_b32(); + + while (i < n) { + if (n - i < vec_len_f32_sve) { + inc = n - i; + pg = svwhilelt_b32(0, static_cast(inc)); + } + svfloat32_t b1 = svld1_f32(pg, src + i); + svst1_f32(pg, _dst + i, b1); + i += inc; + } +# else if (std::is_same::value && std::is_same::value) { for (; i + vec_len_f32_neon <= n; i += vec_len_f32_neon) { float32x4_t vb1 = __vld1q_f32(src + i); __vst1q_f32(dst + i, vb1); } } -# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) +# if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) if (std::is_same::value && std::is_same::value) { for (; i + vec_len_f16_neon <= n; i += vec_len_f16_neon) { auto vb1 = vld1q_f16(reinterpret_cast(src + i)); vst1q_f16(reinterpret_cast(dst + i), vb1); } } +# endif # endif #endif for (; i < n; i++) { @@ -99,6 +119,27 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal mm256_uni_storeu_ps(out + i, v_out); } #elif defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) + auto _v = reinterpret_cast(v); + svfloat32_t attn_w_vec_fp32 = svdup_n_f32(weight); + size_t inc = vec_len_f32_sve; + svbool_t pg = svptrue_b32(); + + while (i < S) { + if (S - i < vec_len_f32_sve) { + inc = S - i; + pg = svwhilelt_b32(0, static_cast(inc)); + } + svfloat32_t v_value = svld1_f32(pg, _v + i); + svfloat32_t v_out = svld1_f32(pg, out + i); + + // svmla with merging to preserve inactive lane values when there's ... + // fewer than vec_len elements left + v_out = svmla_f32_m(pg, v_out, attn_w_vec_fp32, v_value); + svst1_f32(pg, out + i, v_out); + i += inc; + } +# else float32x4_t attn_w_vec_fp32 = vdupq_n_f32(weight); for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) { float32x4_t v_value = __vld1q_f32(v + i); @@ -106,6 +147,7 @@ static void attn_acc_value(float* out, float weight, T* v, size_t S, float* scal v_out = vmlaq_f32(v_out, attn_w_vec_fp32, v_value); __vst1q_f32(out + i, v_out); } +# endif #endif for (; i < S; i++) { out[i] += weight * v[i]; @@ -356,7 +398,50 @@ static float sum_q_head(T* a, size_t n) { hsum(vsum0); sum = _mm256_cvtss_f32(vsum0); #elif defined(OPENVINO_ARCH_ARM64) - size_t vec_len_f32_neon = 4; +# if defined(HAVE_SVE) + svfloat32_t sum0 = svdup_n_f32(0.0f); + svfloat32_t sum1 = svdup_n_f32(0.0f); + svfloat32_t sum2 = svdup_n_f32(0.0f); + svfloat32_t sum3 = svdup_n_f32(0.0f); + svbool_t pg = svptrue_b32(); + + for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) { + svfloat32_t a0 = svld1_f32(pg, a + i); + svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve); + svfloat32_t a2 = svld1_f32(pg, a + i + vec_len_f32_sve * 2); + svfloat32_t a3 = svld1_f32(pg, a + i + vec_len_f32_sve * 3); + + sum0 = svadd_f32_z(pg, a0, sum0); + sum1 = svadd_f32_z(pg, a1, sum1); + sum2 = svadd_f32_z(pg, a2, sum2); + sum3 = svadd_f32_z(pg, a3, sum3); + } + if (i + 2 * vec_len_f32_sve <= n) { + svfloat32_t a0 = svld1_f32(pg, a + i); + svfloat32_t a1 = svld1_f32(pg, a + i + vec_len_f32_sve); + + sum0 = svadd_f32_z(pg, a0, sum0); + sum1 = svadd_f32_z(pg, a1, sum1); + i += 2 * vec_len_f32_sve; + } + if (i + vec_len_f32_sve <= n) { + svfloat32_t a0 = svld1_f32(pg, a + i); + sum0 = svadd_f32_z(pg, a0, sum0); + i += vec_len_f32_sve; + } + // Process tail elements parallely as well (if any) + if (i != n) { + svbool_t pg_rem = svwhilelt_b32(0, static_cast(n - i)); + svfloat32_t a0 = svld1_f32(pg_rem, a + i); + sum0 = svadd_f32_m(pg_rem, sum0, a0); + i = n; + } + float32_t sum_0 = svaddv_f32(pg, sum0); + float32_t sum_1 = svaddv_f32(pg, sum1); + float32_t sum_2 = svaddv_f32(pg, sum2); + float32_t sum_3 = svaddv_f32(pg, sum3); + sum = static_cast(sum_0 + sum_1 + sum_2 + sum_3); +# else float32x4_t vsum0 = vdupq_n_f32(0.0f); float32x4_t vsum1 = vdupq_n_f32(0.0f); float32x4_t vsum2 = vdupq_n_f32(0.0f); @@ -396,8 +481,8 @@ static float sum_q_head(T* a, size_t n) { sum_low = vadd_f32(sum_low, sum_high); sum_low = vpadd_f32(sum_low, sum_low); sum = vget_lane_f32(sum_low, 0); +# endif #endif - for (; i < n; i++) { float tmp = a[i]; sum += tmp; @@ -496,6 +581,63 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float* sum = _mm256_cvtss_f32(vsum0); #elif defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) + svbool_t pg = svptrue_b32(); + svfloat32_t sum0 = svdup_n_f32(0.0f); + svfloat32_t sum1 = svdup_n_f32(0.0f); + svfloat32_t sum2 = svdup_n_f32(0.0f); + svfloat32_t sum3 = svdup_n_f32(0.0f); + + auto _a = reinterpret_cast(a); + auto _b = reinterpret_cast(b); + + for (; i + 4 * vec_len_f32_sve <= n; i += 4 * vec_len_f32_sve) { + svfloat32_t a0 = svld1_f32(pg, _a + i); + svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve); + svfloat32_t a2 = svld1_f32(pg, _a + i + vec_len_f32_sve * 2); + svfloat32_t a3 = svld1_f32(pg, _a + i + vec_len_f32_sve * 3); + + svfloat32_t b0 = svld1_f32(pg, _b + i); + svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve); + svfloat32_t b2 = svld1_f32(pg, _b + i + vec_len_f32_sve * 2); + svfloat32_t b3 = svld1_f32(pg, _b + i + vec_len_f32_sve * 3); + + sum0 = svmla_f32_z(pg, sum0, a0, b0); + sum1 = svmla_f32_z(pg, sum1, a1, b1); + sum2 = svmla_f32_z(pg, sum2, a2, b2); + sum3 = svmla_f32_z(pg, sum3, a3, b3); + } + if (i + 2 * vec_len_f32_sve <= n) { + svfloat32_t a0 = svld1_f32(pg, _a + i); + svfloat32_t a1 = svld1_f32(pg, _a + i + vec_len_f32_sve); + + svfloat32_t b0 = svld1_f32(pg, _b + i); + svfloat32_t b1 = svld1_f32(pg, _b + i + vec_len_f32_sve); + + sum0 = svmla_f32_z(pg, sum0, a0, b0); + sum1 = svmla_f32_z(pg, sum1, a1, b1); + i += 2 * vec_len_f32_sve; + } + if (i + vec_len_f32_sve <= n) { + svfloat32_t a0 = svld1_f32(pg, _a + i); + svfloat32_t b0 = svld1_f32(pg, _b + i); + sum0 = svmla_f32_z(pg, sum0, a0, b0); + i += vec_len_f32_sve; + } + // Process the tail elements parallely as well (if any) + if (i != n) { + svbool_t pg_rem = svwhilelt_b32(0, static_cast(n - i)); + svfloat32_t a0 = svld1_f32(pg_rem, _a + i); + svfloat32_t b0 = svld1_f32(pg_rem, _b + i); + sum0 = svmla_f32_m(pg_rem, sum0, a0, b0); + i = n; + } + float32_t sum_0 = svaddv_f32(pg, sum0); + float32_t sum_1 = svaddv_f32(pg, sum1); + float32_t sum_2 = svaddv_f32(pg, sum2); + float32_t sum_3 = svaddv_f32(pg, sum3); + sum = static_cast(sum_0 + sum_1 + sum_2 + sum_3); +# else float32x4_t vsum0 = vdupq_n_f32(0.0f); float32x4_t vsum1 = vdupq_n_f32(0.0f); float32x4_t vsum2 = vdupq_n_f32(0.0f); @@ -542,8 +684,8 @@ static float dot_product(TA* a, TB* b, size_t n, float* scale, float* zp, float* float32x2_t temp_sum = vadd_f32(vget_low_f32(vsum0), vget_high_f32(vsum0)); temp_sum = vpadd_f32(temp_sum, temp_sum); sum = vget_lane_f32(temp_sum, 0); +# endif #endif - for (; i < n; i++) { sum += a[i] * b[i]; } @@ -794,6 +936,28 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str mm256_uni_storeu_ps(dst + i, result_vec_fp32); } #elif defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) + auto _dst = reinterpret_cast(dst); + size_t inc = vec_len_f32_sve; + svbool_t pg = svptrue_b32(); + + while (i < S) { + if (S - i < vec_len_f32_sve) { + inc = S - i; + pg = svwhilelt_b32(0, static_cast(inc)); + } + auto* src = temp + i; + auto result_vec_fp32 = svdup_n_f32(0.0f); + + for (size_t m = 0; m < M; m++) { + auto o_vec_fp32 = svld1_f32(pg, src); + result_vec_fp32 = svadd_f32_m(pg, result_vec_fp32, o_vec_fp32); + src += temp_stride; + } + svst1_f32(pg, _dst + i, result_vec_fp32); + i += inc; + } +# else for (; i + vec_len_f32_neon <= S; i += vec_len_f32_neon) { auto* src = temp + i; auto result_vec_fp32 = vdupq_n_f32(0.0f); @@ -804,6 +968,7 @@ static void attn_reduce(T* dst, float* temp, size_t M, size_t S, size_t temp_str } __vst1q_f32(dst + i, result_vec_fp32); } +# endif #endif for (; i < S; i++) { auto* src = temp + i; @@ -1262,6 +1427,7 @@ void mha_single_token(const ov::intel_cpu::PlainTensor& query, OPENVINO_THROW("Unsupported precision: ", query.get_precision()); } } + } // namespace XARCH } // namespace Cpu } // namespace Extensions diff --git a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp index 48b92b53fa2727..35aab5b59c7d0e 100644 --- a/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp +++ b/src/plugins/intel_cpu/src/nodes/kernels/scaled_attn/softmax_kernel.hpp @@ -12,6 +12,9 @@ #include "openvino/core/type/element_type.hpp" #if defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) +# include "arm_sve.h" +# endif # include "arm_neon.h" #endif @@ -657,6 +660,28 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float& hsum(v_sum); sum = _mm256_cvtss_f32(v_sum); #elif defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) + svfloat32_t v_a; + svfloat32_t v_max = svdup_n_f32(max); + svfloat32_t v_sum = svdup_n_f32(0.0f); + size_t vec_len_f32_sve = svcntw(); + size_t inc = vec_len_f32_sve; + svbool_t pg = svptrue_b32(); + + while (i < size) { + if (size - i < vec_len_f32_sve) { + inc = size - i; + pg = svwhilelt_b32(0, static_cast(inc)); + } + v_a = svld1_f32(pg, a + i); + v_a = svsub_f32_z(pg, v_a, v_max); + v_a = exp_ps_sve(pg, v_a); + v_sum = svadd_f32_m(pg, v_sum, v_a); + svst1_f32(pg, a + i, v_a); + i += inc; + } + sum = svaddv_f32(svptrue_b32(), v_sum); +# else float32x4_t v_a; float32x4_t v_max = vdupq_n_f32(max); float32x4_t v_sum = vdupq_n_f32(0.0f); @@ -670,7 +695,7 @@ inline void exp_reduce_sum(float* a, const float max, const size_t size, float& i += vec_len_f32_neon; } sum = vaddvq_f32(v_sum); - +# endif #endif for (; i < size; i++) { a[i] = exp(a[i] - max); @@ -781,6 +806,22 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_ i += (size - i); } #elif defined(OPENVINO_ARCH_ARM64) +# if defined(HAVE_SVE) + svfloat32_t v_scale = svdup_n_f32(val); + size_t inc = vec_len_f32_sve; + svbool_t pg = svptrue_b32(); + + while (i < size) { + if (size - i < vec_len_f32_sve) { + inc = size - i; + pg = svwhilelt_b32(0, static_cast(inc)); + } + svfloat32_t v_a = svld1_f32(pg, a + i); + v_a = svmul_f32_z(pg, v_a, v_scale); + svst1_f32(pg, a_dst + i, v_a); + i += inc; + } +# else float32x4_t v_scale = vdupq_n_f32(val); while (i + vec_len_f32_neon <= size) { float32x4_t v_a = vld1q_f32(a + i); @@ -788,6 +829,7 @@ inline void multiply_scalar(float* a, float* a_dst, const float val, const size_ vst1q_f32(a_dst + i, v_a); i += vec_len_f32_neon; } +# endif #endif for (; i < size; i++) { a_dst[i] = a[i] * val; @@ -972,7 +1014,7 @@ inline void attn_softmax_kernel(float* a, // divide sum float scalar = 1.0f / sum; if (dst_precision == ov::element::f32) { - multiply_scalar(a, static_cast(a_dst), scalar, len); + multiply_scalar(a, reinterpret_cast(a_dst), scalar, len); // apply causual mask to final result instead of attn_score if (total_size > len) memset(static_cast(a_dst) + len, 0, sizeof(float) * (total_size - len)); From 8f0094dabda2dfe02c8414fd13f7d268c06ce6c7 Mon Sep 17 00:00:00 2001 From: Chenhu Wang Date: Tue, 17 Dec 2024 13:12:21 +0800 Subject: [PATCH 2/5] [CPU] sns f16_mha_on_avx512_core_amx_f16_target (#27514) ### Details: - *support f16 precision mha on GNR* ### Tickets: - *CVS-122494, CVS-122495* --- src/common/snippets/src/op/brgemm.cpp | 3 +- .../x64/jit_brgemm_copy_b_emitter.hpp | 3 +- .../snippets/x64/jit_brgemm_emitter.cpp | 3 +- src/plugins/intel_cpu/src/nodes/subgraph.cpp | 15 +- .../snippets/x64/op/brgemm_copy_b.cpp | 2 +- .../snippets/x64/op/brgemm_utils.cpp | 16 ++- .../snippets/x64/op/brgemm_utils.hpp | 4 +- .../x64/pass/brgemm_to_brgemm_cpu.hpp | 2 +- .../snippets/x64/pass/enforce_precision.cpp | 11 +- .../transformation_pipeline.cpp | 28 ++-- .../custom/subgraph_tests/src/x64/mha.cpp | 22 ++- .../skip_tests_config.cpp | 5 + .../shared_tests_instances/snippets/mha.cpp | 135 ++++++++++++++---- .../shared_tests_instances/snippets/utils.hpp | 11 ++ .../plugin/shared/src/snippets/mha.cpp | 2 + 15 files changed, 201 insertions(+), 61 deletions(-) diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 72fc692fff5d70..7190074c8ae30b 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -87,7 +87,8 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1); const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8; const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1); - if (is_f32 || is_bf16) { + const bool is_f16 = utils::everyone_is(element::f16, in_type0, in_type1); + if (is_f32 || is_bf16 || is_f16) { return element::f32; } else if (is_int8) { return element::i32; diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp index 96a80153bba4b6..d937e646b603da 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp @@ -17,13 +17,12 @@ class jit_brgemm_copy_b_emitter : public jit_emitter { const ov::snippets::lowered::ExpressionPtr& expr, const snippets::KernelExecutorTablePtr& kernel_table, const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache); - size_t get_inputs_num() const override { return 1; } static std::set> get_supported_precisions( const std::shared_ptr& node = nullptr) { - return {{element::i8}, {element::bf16}, {element::f32}}; + return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}}; } private: diff --git a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp index 172a1cc0b98284..8d343cec908732 100644 --- a/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp +++ b/src/plugins/intel_cpu/src/emitters/snippets/x64/jit_brgemm_emitter.cpp @@ -79,7 +79,8 @@ std::set> jit_brgemm_emitter::get_supported_precision } else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX) { return {{element::i8, element::i8, element::u8}, {element::u8, element::i8, element::u8}, - {element::bf16, element::bf16, element::u8}}; + {element::bf16, element::bf16, element::u8}, + {element::f16, element::f16, element::u8}}; } OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type"); } diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 94e01cd89a39fa..2b0c7b55fb043d 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -458,11 +458,12 @@ void Subgraph::initSupportedPrimitiveDescriptors() { config.inConfs.resize(inputShapes.size()); for (size_t i = 0; i < inputShapes.size(); i++) { const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i); - const auto precision = ((originalInputPrecision == ov::element::f32) && - context->getConfig().inferencePrecision == ov::element::bf16 && - subgraph_attrs->snippet->has_domain_sensitive_ops()) - ? static_cast(ov::element::bf16) - : originalInputPrecision; + const auto precision = + ((originalInputPrecision == ov::element::f32) && + one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && + subgraph_attrs->snippet->has_domain_sensitive_ops()) + ? context->getConfig().inferencePrecision + : originalInputPrecision; if (supportedPrecisions.count(precision) == 0) OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision."); @@ -653,7 +654,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { ov::snippets::pass::Canonicalization, ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs); - if (context->getConfig().inferencePrecision == ov::element::bf16 && + if (one_of(context->getConfig().inferencePrecision, ov::element::bf16, ov::element::f16) && subgraph_attrs->snippet->has_domain_sensitive_ops()) { // enforce BF16 precisions to supported operations // MatMul has to be decomposed to Brgemm operations before enforcement @@ -663,7 +664,7 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() { ov::snippets::pass::MatMulToBrgemm, pass::EnforcePrecision, element::f32, - element::bf16); + context->getConfig().inferencePrecision); } SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::snippets::pass::PropagatePrecision, diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp index a513299a516f5f..7e52905145869f 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp @@ -91,7 +91,7 @@ void BrgemmCopyB::validate_and_infer_types() { } void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) { - OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::i8), + OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::f16, element::i8), "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); } diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp index e1802d2914127a..386941fd94bb98 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.cpp @@ -35,7 +35,12 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) { // Note: AMX might be not used even if it's supported by the hardware, check the BrgemmToBrgemmCPU pass for details if (is_with_amx) { - SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms") + if (dt_in0 == ov::element::f16) + SUPPORT_ONE(avx512_core_amx_fp16, + "Unsupported hardware configuration: amx is supported only on avx512 platforms") + else + SUPPORT_ONE(avx512_core_amx, + "Unsupported hardware configuration: amx is supported only on avx512 platforms") } else if (dt_in0 == ov::element::bf16) { SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms") } else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) { @@ -59,12 +64,15 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, bool transp return transpose_b ? BRGEMM_TYPE::REPACKING_ONLY : BRGEMM_TYPE::STAND_ALONE; OPENVINO_ASSERT(element_type_a != element::bf16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16), - "BF16 precision is not supported on this hardware"); + "BrgemmCPU BF16 precision is not supported on non avx512_core_bf16 system"); + OPENVINO_ASSERT(element_type_a != element::f16 || mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16), + "BrgemmCPU FP16 precision is not supported on non avx512_core_amx_fp16 system"); if (one_of(element_type_a, element::u8, element::i8, element::bf16) && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) return BRGEMM_TYPE::WITH_AMX; - + if (element_type_a == ov::element::f16 && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16)) + return BRGEMM_TYPE::WITH_AMX; // Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the // backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp if (element_type_a == ov::element::i8) @@ -96,6 +104,8 @@ size_t compute_inner_n_block(const ov::element::Type& precision) { return 64; case element::bf16: return 32; + case element::f16: + return 32; case element::f32: return 16; default: diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp index 46428828e7139c..b5f470c1c695ba 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_utils.hpp @@ -15,8 +15,8 @@ namespace intel_cpu { namespace brgemm_utils { enum class BRGEMM_TYPE { - STAND_ALONE, // No extra requirements, used for f32|f32 - WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad + STAND_ALONE, // No extra requirements, used for f32|f32 + WITH_AMX, // i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system - needs BrgemmCopyB and scratchpad WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations REPACKING_ONLY, // u8|i8, or bf16|bf16 (non-AMX system), or brgemm with transpose_b=true - needs BrgemmCopyB on // second input for data repacking diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp index 2cbf2d7e087919..9475171b24f65d 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp @@ -25,7 +25,7 @@ namespace pass { * \ Buffer (with repacked data) Buffer (with compensations) * \ | / * BrgemmCPU - * - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system: + * - f32|f32 with transpose_b, u8|i8, i8|i8 or bf16|bf16 on AMX system or fp16|fp16 on AMX_FP16 system: * \ BrgemmCopyB * \ Buffer (with repacked data) Buffer (with new memory) * \ | / diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp index 92b5be2692f3b2..6b7d5d31a5b12f 100644 --- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp +++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/enforce_precision.cpp @@ -121,9 +121,12 @@ bool EnforcePrecision::run_on_model(const std::shared_ptr& f) { std::set> EnforcePrecision::get_supported_precisions_default( const std::shared_ptr& op) noexcept { - if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16) && - ov::is_type(op)) { - return {{element::bf16, element::bf16}}; + std::set> types; + if (ov::is_type(op)) { + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16)) + types.insert({element::f16, element::f16}); + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) + types.insert({element::bf16, element::bf16}); } - return {}; + return types; } diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 909f6b7531d421..4013c1c3cd84f9 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -866,7 +866,7 @@ void Transformations::PostLpt() { postLPTPassManager, [](const std::shared_ptr& node) -> bool { if (!ov::is_type(node) && - node->get_output_element_type(0) != node->get_input_element_type(0)) + node->get_output_element_type(0).size() > node->get_input_element_type(0).size()) return true; if (node->get_input_size() >= 2) { return node->get_input_element_type(1) == ov::element::i8 || @@ -986,7 +986,7 @@ void Transformations::MainSnippets(void) { // MatMul and Result. However there may be Convert [f32->bf16] before Result since: // - bf16 Brgemm has f32 output; // - CPU Node Subgraph requires bf16 on output when inference precision is bf16. - // To avoid sitations when Transpose is not alone node between MatMul and Result, + // To avoid situations when Transpose is not alone node between MatMul and Result, // Plugin disables Transpose tokenization on output bool mha_token_enable_transpose_on_output = one_of(config.inferencePrecision, element::f32, element::undefined); size_t concurrency = config.streamExecutorConfig.get_threads_per_stream(); @@ -1023,6 +1023,7 @@ void Transformations::MainSnippets(void) { ov::pass::Manager snippetsManager("CPU:Snippets"); snippetsManager.set_per_pass_validation(false); + // if callback needed for better perf, enable SnippetsMarkSkipped, and disable TokenizeFCSnippets. if (!ignoreCallback) { #if defined(OPENVINO_ARCH_ARM64) CPU_REGISTER_PASS_ARM(snippetsManager, SnippetsMarkSkipped); @@ -1033,9 +1034,7 @@ void Transformations::MainSnippets(void) { } CPU_REGISTER_PASS_COMMON(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config); - // - MHA has BRGEMM that is supported only on AVX512 platforms - // - CPU Plugin Subgraph supports only f32, bf16 (and quantized) BRGEMM - // [122494] Need to add support of f16 + // - CPU Plugin Subgraph supports f32, bf16, quantized and fp16(on avx_512_core_amx_fp16 target) BRGEMM const bool isMHASupported = #if defined(OPENVINO_ARCH_ARM64) false; @@ -1043,7 +1042,9 @@ void Transformations::MainSnippets(void) { (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) && one_of(config.inferencePrecision, ov::element::f32, element::undefined)) || (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && - one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)); + one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) || + (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) && + one_of(config.inferencePrecision, ov::element::f16)); #endif if (!isMHASupported) { CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets); @@ -1059,13 +1060,13 @@ void Transformations::MainSnippets(void) { const auto in_type1 = matmul->get_input_element_type(1); const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 && one_of(config.inferencePrecision, element::f32, element::undefined)); - const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16); + const auto is_fp16 = + (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) || + (in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16); const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) || ((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::bf16)); const auto is_int8 = in_type0 == ov::element::i8; - if (is_fp16) - return false; if (is_fp32) return true; // Only FP32 dynamic MHA is supported @@ -1076,13 +1077,14 @@ void Transformations::MainSnippets(void) { // brgemm_copy_b kernel if (matmul->get_transpose_a() || matmul->get_transpose_b()) return false; - // [150842] The execution of Brgemm INT8/BF16 on AMX platforms depends on the value of "K % VNNIFactor". + // [150842] The execution of Brgemm INT8/BF16/FP16 on AMX platforms depends on the value of "K % VNNIFactor". // For more details, please teake a look at the ticket 150842 if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) { const auto& b_shape = matmul->get_input_partial_shape(1); const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin(); - if (is_bf16) - return K.is_static() && (K.get_length() % 2 == 0); + const size_t brgemm_vnni_factor_for_real16 = 2; // 4/2(size in term of byte for bf16/fp16) + if (is_bf16 || is_fp16) + return K.is_static() && (K.get_length() % brgemm_vnni_factor_for_real16 == 0); if (is_int8) return K.is_static(); } @@ -1091,6 +1093,8 @@ void Transformations::MainSnippets(void) { dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni); if (is_bf16) return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16); + if (is_fp16) + return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16); return true; }; auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr& n, diff --git a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp index 8517612a348f68..a94f52be91df02 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/subgraph_tests/src/x64/mha.cpp @@ -189,7 +189,7 @@ class MHATest : public testing::WithParamInterface, virtual public Sub for (size_t i = 0; i < funcInputs.size(); ++i) { const auto& funcInput = funcInputs[i]; ov::Tensor tensor; - if (funcInput.get_element_type() == ov::element::bf16) { + if (funcInput.get_element_type() == ov::element::bf16 || funcInput.get_element_type() == ov::element::f16) { ov::test::utils::InputGenerateData in_data; in_data.start_from = -1; in_data.range = 2; @@ -232,6 +232,9 @@ class MHATest : public testing::WithParamInterface, virtual public Sub configuration.insert({ov::hint::inference_precision(ov::element::bf16)}); } + if (inputPrecisions[0] == ElementType::f16) + configuration.insert({ov::hint::inference_precision(ov::element::f16)}); + // Snippets MHA tokenization has limitations to avoid performance degradations. These limitations depend on // target machine. Just for testing, we disable these limitations to allow Snippets to tokenize pattern on all // machines for validation. @@ -253,6 +256,9 @@ TEST_P(MHATest, CompareWithRefs) { if (inputPrecisions[0] == ElementType::bf16 && !ov::with_cpu_x86_bfloat16()) GTEST_SKIP(); + if (inputPrecisions[0] == ElementType::f16 && !ov::with_cpu_x86_avx512_core_amx_fp16()) + GTEST_SKIP(); + if (!ov::with_cpu_x86_avx512_core()) GTEST_SKIP(); @@ -308,6 +314,20 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(ov::test::utils::DEVICE_CPU)), MHATest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P( + smoke_MHA_FP16, + MHATest, + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)), + ::testing::Values( + std::vector{ElementType::f16, ElementType::f16, ElementType::f16, ElementType::f16}), + ::testing::ValuesIn(matMulIn0Precisions), + ::testing::ValuesIn(patternTypes), + ::testing::Values(ExpectedNodes{{"Subgraph", 1}, + {"Transpose", 1}}), // Plugin disables tokenization of Transpose on output + ::testing::Values(ov::test::utils::DEVICE_CPU)), + MHATest::getTestCaseName); + } // namespace static std::shared_ptr initMHAQuantSubgraph0(std::vector& inputDynamicShapes, diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index 089a03b4d6bba7..e9b38fedc0b4e5 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -565,6 +565,11 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(.*smoke_Snippets_MHA.*EnforceBF16.*)"); retVector.emplace_back(R"(.*ConcatSDPTest.*bf16.*)"); } + // MHA FP16 precision is only supported on amx_fp16 platform + if (!ov::with_cpu_x86_avx512_core_amx_fp16()) { + retVector.emplace_back(R"(.*smoke_Snippets_MHA.*FP16.*)"); + } + #ifdef SNIPPETS_LIBXSMM_TPP // GN in TPP requires exposing tmp Buffer results outside the loop (ticket: 151234) retVector.emplace_back(R"(.*smoke_Snippets_GroupNormalization.*)"); diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp index 63f5176684ccc1..df0b69f99ef06d 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/mha.cpp @@ -12,32 +12,41 @@ namespace snippets { namespace { -std::vector> transposedShape_4D(bool with_dynamic = true) { - auto shapes = SNIPPETS_TESTS_STATIC_SHAPES( - {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, - {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, - {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, - {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}); +std::vector> transposedShape_4D(bool with_static = true, bool with_dynamic = true) { + std::vector> shapes; + if (with_static) { + auto static_shapes = + SNIPPETS_TESTS_STATIC_SHAPES({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}}, + {{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}}, + {{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}}, + {{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}}); + shapes.insert(shapes.end(), static_shapes.begin(), static_shapes.end()); + } if (with_dynamic) { - std::vector> dynamic_shapes = {{ - {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}}, - {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, - {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}}, - {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, - }, - { - {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}}, - {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}}, - {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}}, - }, - { - {PartialShape{-1, -1, 12, 64}, {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}}, - {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}}, - {PartialShape{-1, 12, -1, -1}, {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}}, - {PartialShape{-1, -1, 12, 64}, {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}}, - }}; + std::vector> dynamic_shapes = { + { + {PartialShape{-1, -1, -1, 100}, {{1, 64, 4, 100}, {2, 16, 2, 100}, {1, 72, 4, 100}}}, + {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, + {PartialShape{-1, -1, -1, 128}, {{1, 4, 64, 128}, {2, 2, 16, 128}, {1, 4, 72, 128}}}, + {PartialShape{-1, 128, -1, 100}, {{1, 128, 4, 100}, {2, 128, 2, 100}, {1, 128, 4, 100}}}, + }, + { + {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 16, 2, 100}, {1, 128, 3, 64}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 128, 1, 64}, {2, 128, 2, 100}, {1, 128, 1, 64}}}, + {PartialShape{-1, -1, -1, -1}, {{2, 1, 128, 128}, {2, 2, 16, 128}, {2, 1, 128, 128}}}, + {PartialShape{-1, -1, -1, -1}, {{1, 128, 3, 64}, {2, 128, 2, 100}, {1, 128, 3, 64}}}, + }, + { + {PartialShape{-1, -1, 12, 64}, + {{1, 70, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 20, 12, 64}, {1, 70, 12, 64}}}, + {PartialShape{-1, -1, 12, 64}, + {{1, 35, 12, 64}, {2, 10, 12, 64}, {2, 1, 12, 64}, {2, 10, 12, 64}, {1, 35, 12, 64}}}, + {PartialShape{-1, 12, -1, -1}, + {{2, 12, 70, 35}, {1, 12, 20, 10}, {1, 12, 20, 10}, {1, 12, 20, 1}, {2, 12, 70, 35}}}, + {PartialShape{-1, -1, 12, 64}, + {{1, 35, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 10, 12, 64}, {1, 35, 12, 64}}}, + }}; shapes.insert(shapes.end(), dynamic_shapes.begin(), dynamic_shapes.end()); } return shapes; @@ -74,7 +83,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D, INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_4D_WithScalarMul, MHA, - ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false)), + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)), ::testing::ValuesIn(precision_f32(4)), ::testing::Values(ov::element::f32), ::testing::Values(true), @@ -137,6 +146,80 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceBF16, ::testing::Values(CPUTestUtils::cpu_bf16_plugin_config)), MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_Without_Multiply, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D()), + ::testing::ValuesIn(precision_fp16_if_supported(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({false}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::empty_plugin_config)), + MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Static, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)), + ::testing::ValuesIn(precision_fp16_if_supported(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::empty_plugin_config)), + MHA::getTestCaseName); +// 3 nodes and 2 subgraph for dynamic with multiply case. +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA_FP16_4D_With_Multiply_Dynamic, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)), + ::testing::ValuesIn(precision_fp16_if_supported(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(3), + ::testing::Values(2), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::empty_plugin_config)), + MHA::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_Without_Multiply, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D()), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({false}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), + MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Static, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(true, false)), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(2), + ::testing::Values(1), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), + MHA::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAEnforceFP16_With_Multiply_Dynamic, + MHA, + ::testing::Combine(::testing::ValuesIn(transposedShape_4D(false, true)), + ::testing::ValuesIn(precision_f32(4)), + ::testing::Values(ov::element::f16), + ::testing::ValuesIn({true}), + ::testing::Values(MHA::default_thread_count), + ::testing::Values(3), + ::testing::Values(2), + ::testing::Values(ov::test::utils::DEVICE_CPU), + ::testing::Values(CPUTestUtils::cpu_f16_plugin_config)), + MHA::getTestCaseName); } // namespace } // namespace snippets } // namespace test diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp index 6c0d54da973086..6815cdab671cea 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/utils.hpp @@ -16,6 +16,10 @@ static inline bool is_bf16_supported_by_brgemm() { return ov::with_cpu_x86_bfloat16() || ov::with_cpu_x86_avx512_core_amx_bf16(); } +static inline bool is_fp16_supported_by_brgemm() { + return ov::with_cpu_x86_avx512_core_amx_fp16(); +} + static inline bool is_i8_supported_by_brgemm() { return ov::with_cpu_x86_avx512_core_vnni() || ov::with_cpu_x86_avx512_core_amx_int8(); } @@ -33,6 +37,13 @@ static inline std::vector> precision_bf16_if_supporte return prc; } +static inline std::vector> precision_fp16_if_supported(size_t count) { + std::vector> prc; + if (is_fp16_supported_by_brgemm()) + prc.emplace_back(std::vector(count, element::f16)); + return prc; +} + static inline std::vector> quantized_precisions_if_supported() { std::vector> prc = {}; // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms diff --git a/src/tests/functional/plugin/shared/src/snippets/mha.cpp b/src/tests/functional/plugin/shared/src/snippets/mha.cpp index 8d0cb8613bc47e..0a8fcc77717c42 100644 --- a/src/tests/functional/plugin/shared/src/snippets/mha.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/mha.cpp @@ -65,6 +65,8 @@ void MHABase::SetUp() { #endif if (inType == ov::element::bf16) rel_threshold = 0.05f; + if (inType == ov::element::f16) + abs_threshold = 2e-2; } std::string MHA::getTestCaseName(testing::TestParamInfo obj) { From adf097b31f033535c61101767ea66c06f717fedd Mon Sep 17 00:00:00 2001 From: Wilson Seok Date: Tue, 17 Dec 2024 16:27:30 +0900 Subject: [PATCH 3/5] [GPU] Add ConvolutionBackpropData in is_decompression_multiply() of MarkDequantizationSubgraph callback (#28075) ### Details: - Add ConvolutionBackpropData in target_consumers,convolutions of is_decompression_multiply() ### Tickets: - 159207 --- src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index a53b6ddac7332a..010318703dde09 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -181,10 +181,14 @@ static bool is_decompression_multiply(const std::shared_ptr node ov::op::v8::Gather::get_type_info_static(), ov::op::v1::Convolution::get_type_info_static(), ov::opset1::Convolution::get_type_info_static(), + ov::op::v1::ConvolutionBackpropData::get_type_info_static(), + ov::opset1::ConvolutionBackpropData::get_type_info_static(), ov::opset1::GroupConvolution::get_type_info_static() }; std::vector convolutions = { ov::op::v1::Convolution::get_type_info_static(), ov::opset1::Convolution::get_type_info_static(), + ov::op::v1::ConvolutionBackpropData::get_type_info_static(), + ov::opset1::ConvolutionBackpropData::get_type_info_static(), ov::opset1::GroupConvolution::get_type_info_static() }; auto all_has_types = [](const std::set>& consumers, const std::vector& types) { From b0a8c14d452e1c2af6a178046889e43793e9e968 Mon Sep 17 00:00:00 2001 From: Arshad Mehmood Date: Tue, 17 Dec 2024 15:43:32 +0800 Subject: [PATCH 4/5] [GPU] Updated GPU cache size retrieval and refined closest_pow_of_2 (#28059) Details: Existing method for cache size calculation was static and need continious updates to the sku table which was already being missed for latest skus e.g DG2. This update introduces a new member variable, max_global_cache_size, to store the GPU's global cache size, obtained via the OpenCL property CL_DEVICE_GLOBAL_MEM_CACHE_SIZE. The existing hard coded cache calculations are removed. Additionally, the closest_pow_of_2 function has been enhanced to return the nearest power of 2, favoring the upper value if the input is within 30% of the range for the upper bound. These changes improve memory management and ensure better utilization of GPU resources towards bottle neck situations. Tickets: CVS-159076 Signed-off-by: Arshad Mehmood --- .../include/intel_gpu/runtime/device_info.hpp | 1 + src/plugins/intel_gpu/src/plugin/plugin.cpp | 49 ++++++++----------- .../intel_gpu/src/runtime/ocl/ocl_device.cpp | 1 + 3 files changed, 22 insertions(+), 29 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/device_info.hpp b/src/plugins/intel_gpu/include/intel_gpu/runtime/device_info.hpp index d44b8c0536fe4a..8a3c4409246ad1 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/device_info.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/device_info.hpp @@ -56,6 +56,7 @@ struct device_info { uint64_t max_local_mem_size; ///< Maximum size of local memory arena in bytes. uint64_t max_global_mem_size; ///< Maximum size of global device memory in bytes. uint64_t max_alloc_mem_size; ///< Maximum size of memory object allocation in bytes. + uint64_t max_global_cache_size; ///< Maximum size of cache memory bytes. uint64_t max_image2d_width; ///< Maximum image 2d width supported by the device. uint64_t max_image2d_height; ///< Maximum image 2d height supported by the device. diff --git a/src/plugins/intel_gpu/src/plugin/plugin.cpp b/src/plugins/intel_gpu/src/plugin/plugin.cpp index f6c15bc2e8943a..5650f5a66a2ae6 100644 --- a/src/plugins/intel_gpu/src/plugin/plugin.cpp +++ b/src/plugins/intel_gpu/src/plugin/plugin.cpp @@ -797,12 +797,24 @@ uint32_t Plugin::get_optimal_batch_size(const ov::AnyMap& options) const { auto device_id = get_property(ov::device::id.name(), options).as(); auto context = get_default_contexts().at(device_id); const auto& device_info = context->get_engine().get_device_info(); - auto next_pow_of_2 = [] (float x) { - return pow(2, ceil(std::log(x)/std::log(2))); - }; + auto closest_pow_of_2 = [] (float x) { - return pow(2, floor(std::log(x)/std::log(2))); + int lower_power = static_cast(floor(std::log(x) / std::log(2))); + double lower_value = pow(2, lower_power); // Current power of 2 + double upper_value = pow(2, lower_power + 1); // Next power of 2 + + // Determine the threshold (70% of the range between lower and upper values) + // If x is within the upper 30% of the range, return the upper power of 2. + double threshold = 0.7 * (upper_value - lower_value); + + // Compare x with the threshold and return the appropriate power of 2 + if (x - lower_value > threshold) { + return upper_value; // Return the next power of 2 + } else { + return lower_value; // Return the current power of 2 + } }; + auto model_param = options.find(ov::hint::model.name()); if (model_param == options.end()) { GPU_DEBUG_INFO << "[OPTIMAL_BATCH_SIZE] ov::hint::model is not set: return 1" << std::endl; @@ -816,31 +828,10 @@ uint32_t Plugin::get_optimal_batch_size(const ov::AnyMap& options) const { } GPU_DEBUG_INFO << "DEVICE_INFO:" << "gfx_version.major, " << device_info.gfx_ver.major - << "gfx_version.minor " << std::to_string(device_info.gfx_ver.minor) << std::endl; - static std::map gen_kbytes_per_bank = { - {{12, 0, 0}, 480}, // TGL - {{12, 1, 0}, 2048}, // DG1 - {{12, 5, 0}, 320}, - {{12, 7, 0}, 512}, - }; - size_t L3_cache_size = device_info.gfx_ver.major && (device_info.gfx_ver.major <= 9) - ? 768 * 1024 // Gen9 - : 2 * 768 * 1024; //reasonable default when no arch has been detected (e.g. due to old driver ver) - cldnn::gfx_version gen = {device_info.gfx_ver.major, device_info.gfx_ver.minor, 0 /*ignore the revision*/}; - auto val = gen_kbytes_per_bank.find(gen); - if (gen_kbytes_per_bank.end() != val) { - auto kbytes_per_bank = val->second; - auto num_banks_per_slice = device_info.num_sub_slices_per_slice > 4 - ? next_pow_of_2(device_info.num_sub_slices_per_slice) - : 2 * device_info.num_sub_slices_per_slice; - L3_cache_size = kbytes_per_bank * 1024 * num_banks_per_slice * device_info.num_slices; - GPU_DEBUG_INFO << "DEVICE_INFO:" - << "num_slices " << device_info.num_slices - << ", num_sub_slices_per_slice " << device_info.num_sub_slices_per_slice - << ", num_banks_per_slice " << num_banks_per_slice - << ", gen_kbytes_per_bank : " << kbytes_per_bank - << ", L3_cache_size is (MB): " << float(L3_cache_size) / 1024 / 1024 << std::endl; - } + << "gfx_version.minor " << std::to_string(device_info.gfx_ver.minor) + << "Cache size " << std::to_string(device_info.max_global_cache_size) << std::endl; + + size_t L3_cache_size = device_info.max_global_cache_size; auto config = m_configs_map.at(device_id); auto cloned_model = clone_and_transform_model(model, config, context); ov::MemBandwidthPressure memPressure = ov::mem_bandwidth_pressure_tolerance(cloned_model, L3_cache_size); diff --git a/src/plugins/intel_gpu/src/runtime/ocl/ocl_device.cpp b/src/plugins/intel_gpu/src/runtime/ocl/ocl_device.cpp index 7ab48308cfeaf7..74dbc016c65d31 100644 --- a/src/plugins/intel_gpu/src/runtime/ocl/ocl_device.cpp +++ b/src/plugins/intel_gpu/src/runtime/ocl/ocl_device.cpp @@ -224,6 +224,7 @@ device_info init_device_info(const cl::Device& device, const cl::Context& contex info.max_local_mem_size = static_cast(device.getInfo()); info.max_global_mem_size = static_cast(device.getInfo()); info.max_alloc_mem_size = static_cast(device.getInfo()); + info.max_global_cache_size = static_cast(device.getInfo()); info.supports_image = static_cast(device.getInfo()); info.max_image2d_width = static_cast(device.getInfo()); From 5ce61572919f431eed31960cac58815b047f4ae6 Mon Sep 17 00:00:00 2001 From: Mikhail Ryzhov Date: Tue, 17 Dec 2024 10:04:46 +0100 Subject: [PATCH 5/5] [GHA] Set checkout timeout (#27995) ### Details: - Set checkout timeouts because the git clone command could freeze. ### Tickets: - *156678* --- .github/workflows/android_arm64.yml | 4 ++++ .github/workflows/android_x64.yml | 4 ++++ .github/workflows/build_doc.yml | 1 + .github/workflows/check_pr_commits.yml | 1 + .github/workflows/cleanup_caches.yml | 2 ++ .github/workflows/code_snippets.yml | 1 + .github/workflows/code_style.yml | 3 +++ .github/workflows/coverage.yml | 1 + .github/workflows/coverity.yml | 4 ++++ .github/workflows/debian_10_arm.yml | 2 ++ .github/workflows/dependency_review.yml | 1 + .github/workflows/dev_cpu_linux_snippets_libxsmm.yml | 4 ++++ .github/workflows/export_workflow_metrics.yml | 1 + .github/workflows/fedora_29.yml | 2 ++ .github/workflows/files_size.yml | 1 + .github/workflows/job_build_linux.yml | 2 ++ .github/workflows/job_build_windows.yml | 2 ++ .github/workflows/job_cpu_functional_tests.yml | 1 + .github/workflows/job_jax_models_tests.yml | 1 + .github/workflows/job_onnx_runtime.yml | 2 ++ .github/workflows/job_openvino_js.yml | 1 + .github/workflows/job_python_api_tests.yml | 2 ++ .github/workflows/job_python_unit_tests.yml | 1 + .github/workflows/job_pytorch_layer_tests.yml | 1 + .github/workflows/job_pytorch_models_tests.yml | 1 + .github/workflows/job_samples_tests.yml | 1 + .github/workflows/job_tensorflow_layer_tests.yml | 1 + .github/workflows/job_tensorflow_models_tests.yml | 1 + .github/workflows/job_tokenizers.yml | 2 ++ .github/workflows/labeler.yml | 1 + .github/workflows/linux_arm64.yml | 2 ++ .github/workflows/linux_conditional_compilation.yml | 6 ++++++ .github/workflows/linux_riscv.yml | 3 +++ .github/workflows/linux_sanitizers.yml | 5 +++++ .github/workflows/mac.yml | 3 +++ .github/workflows/mac_arm64.yml | 3 +++ .github/workflows/manylinux_2014.yml | 3 +++ .github/workflows/ovc.yml | 1 + .github/workflows/py_checks.yml | 1 + .github/workflows/ubuntu_20.yml | 2 ++ .github/workflows/ubuntu_22.yml | 4 ++++ .github/workflows/ubuntu_22_dpcpp.yml | 2 ++ .github/workflows/ubuntu_24.yml | 2 ++ .github/workflows/webassembly.yml | 3 +++ .github/workflows/windows_conditional_compilation.yml | 6 ++++++ .github/workflows/windows_vs2019_debug.yml | 1 + .github/workflows/windows_vs2019_release.yml | 6 ++++++ .github/workflows/workflow_rerunner.yml | 2 ++ .github/workflows/workflows_scans.yml | 1 + 49 files changed, 108 insertions(+) diff --git a/.github/workflows/android_arm64.yml b/.github/workflows/android_arm64.yml index e0954871f4b51e..b760d9746d7842 100644 --- a/.github/workflows/android_arm64.yml +++ b/.github/workflows/android_arm64.yml @@ -25,6 +25,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -54,6 +55,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -99,6 +101,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' @@ -117,6 +120,7 @@ jobs: - name: Clone vcpkg uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'microsoft/vcpkg' ref: ${{ env.VCPKG_VERSION }} diff --git a/.github/workflows/android_x64.yml b/.github/workflows/android_x64.yml index b0b46c662abdbb..efd14541010730 100644 --- a/.github/workflows/android_x64.yml +++ b/.github/workflows/android_x64.yml @@ -28,6 +28,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -57,6 +58,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -98,12 +100,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' - name: Clone OpenVINO GenAI uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino.genai' path: ${{ env.OPENVINO_GENAI_REPO }} diff --git a/.github/workflows/build_doc.yml b/.github/workflows/build_doc.yml index 8c78375e61769c..c0dac9816598e1 100644 --- a/.github/workflows/build_doc.yml +++ b/.github/workflows/build_doc.yml @@ -19,6 +19,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'true' lfs: 'true' diff --git a/.github/workflows/check_pr_commits.yml b/.github/workflows/check_pr_commits.yml index f7f66be299876c..91d6a2a497a8cd 100644 --- a/.github/workflows/check_pr_commits.yml +++ b/.github/workflows/check_pr_commits.yml @@ -10,6 +10,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - name: Install dependencies run: python3 -m pip install -r ./.github/github_org_control/requirements.txt diff --git a/.github/workflows/cleanup_caches.yml b/.github/workflows/cleanup_caches.yml index 3fc69b21374093..d6633fd9dab3ee 100644 --- a/.github/workflows/cleanup_caches.yml +++ b/.github/workflows/cleanup_caches.yml @@ -49,6 +49,7 @@ jobs: steps: - name: Checkout cach action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/cache @@ -71,6 +72,7 @@ jobs: steps: - name: Checkout cach action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/cache diff --git a/.github/workflows/code_snippets.yml b/.github/workflows/code_snippets.yml index 9337fdff4b2905..5916f91447abc9 100644 --- a/.github/workflows/code_snippets.yml +++ b/.github/workflows/code_snippets.yml @@ -29,6 +29,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'true' diff --git a/.github/workflows/code_style.yml b/.github/workflows/code_style.yml index d4da2a16d38923..3969da2b97c5a1 100644 --- a/.github/workflows/code_style.yml +++ b/.github/workflows/code_style.yml @@ -15,6 +15,7 @@ jobs: pull-requests: write steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'true' @@ -75,6 +76,7 @@ jobs: pull-requests: write steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'true' @@ -107,6 +109,7 @@ jobs: if: ${{ github.repository_owner == 'openvinotoolkit' }} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'true' diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml index cde1b9cf67e2fc..fd6a029abfaa67 100644 --- a/.github/workflows/coverage.yml +++ b/.github/workflows/coverage.yml @@ -33,6 +33,7 @@ jobs: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'true' diff --git a/.github/workflows/coverity.yml b/.github/workflows/coverity.yml index 5a08ec084dadac..52ac10c9a6882a 100644 --- a/.github/workflows/coverity.yml +++ b/.github/workflows/coverity.yml @@ -35,6 +35,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -63,6 +64,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -98,6 +100,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' @@ -105,6 +108,7 @@ jobs: - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: ${{ env.OPENVINO_CONTRIB_REPO }} diff --git a/.github/workflows/debian_10_arm.yml b/.github/workflows/debian_10_arm.yml index cf628d12c29b89..20b1daa0c5dc8d 100644 --- a/.github/workflows/debian_10_arm.yml +++ b/.github/workflows/debian_10_arm.yml @@ -25,6 +25,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -59,6 +60,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker diff --git a/.github/workflows/dependency_review.yml b/.github/workflows/dependency_review.yml index 59a1eaa6e1c26f..690c789cb65222 100644 --- a/.github/workflows/dependency_review.yml +++ b/.github/workflows/dependency_review.yml @@ -10,6 +10,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - name: Dependency Review uses: actions/dependency-review-action@72eb03d02c7872a771aacd928f3123ac62ad6d3a # v4.3.3 diff --git a/.github/workflows/dev_cpu_linux_snippets_libxsmm.yml b/.github/workflows/dev_cpu_linux_snippets_libxsmm.yml index ba458da5d3ec1a..5ed82e8330778c 100644 --- a/.github/workflows/dev_cpu_linux_snippets_libxsmm.yml +++ b/.github/workflows/dev_cpu_linux_snippets_libxsmm.yml @@ -33,6 +33,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -66,6 +67,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -110,6 +112,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' @@ -296,6 +299,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/export_workflow_metrics.yml b/.github/workflows/export_workflow_metrics.yml index 084dfbdc34af7f..39bb699b8caa91 100644 --- a/.github/workflows/export_workflow_metrics.yml +++ b/.github/workflows/export_workflow_metrics.yml @@ -40,6 +40,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: '.github' diff --git a/.github/workflows/fedora_29.yml b/.github/workflows/fedora_29.yml index f3b101327f76dc..0dd101225dc533 100644 --- a/.github/workflows/fedora_29.yml +++ b/.github/workflows/fedora_29.yml @@ -25,6 +25,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -59,6 +60,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker diff --git a/.github/workflows/files_size.yml b/.github/workflows/files_size.yml index 2768e731b6578b..c263afed1fe465 100644 --- a/.github/workflows/files_size.yml +++ b/.github/workflows/files_size.yml @@ -13,6 +13,7 @@ jobs: if: ${{ github.repository_owner == 'openvinotoolkit' }} steps: - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - name: git ls-tree run: git ls-tree -r -t -l --full-name HEAD | sort -n -r -k 4 diff --git a/.github/workflows/job_build_linux.yml b/.github/workflows/job_build_linux.yml index 3964f049be2abb..c56de5872cc2df 100644 --- a/.github/workflows/job_build_linux.yml +++ b/.github/workflows/job_build_linux.yml @@ -92,6 +92,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' @@ -107,6 +108,7 @@ jobs: - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: ${{ env.OPENVINO_CONTRIB_REPO }} diff --git a/.github/workflows/job_build_windows.yml b/.github/workflows/job_build_windows.yml index 7b682f208c3435..4e3969d978cb83 100644 --- a/.github/workflows/job_build_windows.yml +++ b/.github/workflows/job_build_windows.yml @@ -60,12 +60,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: 'openvino_contrib' diff --git a/.github/workflows/job_cpu_functional_tests.yml b/.github/workflows/job_cpu_functional_tests.yml index 0366ec47ff437e..568c33d39e307b 100644 --- a/.github/workflows/job_cpu_functional_tests.yml +++ b/.github/workflows/job_cpu_functional_tests.yml @@ -72,6 +72,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_jax_models_tests.yml b/.github/workflows/job_jax_models_tests.yml index 43fa8f2a7f1740..07155db1016057 100644 --- a/.github/workflows/job_jax_models_tests.yml +++ b/.github/workflows/job_jax_models_tests.yml @@ -65,6 +65,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_onnx_runtime.yml b/.github/workflows/job_onnx_runtime.yml index df50c4f3e2ad3c..92f86511e99e4a 100644 --- a/.github/workflows/job_onnx_runtime.yml +++ b/.github/workflows/job_onnx_runtime.yml @@ -64,6 +64,7 @@ jobs: - name: Fetch ONNX runtime version and skip tests list uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | src/frontends/onnx/tests/ci_utils/onnxruntime @@ -78,6 +79,7 @@ jobs: - name: Clone ONNX Runtime uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'microsoft/onnxruntime' path: ${{ env.ONNX_RUNTIME_REPO }} diff --git a/.github/workflows/job_openvino_js.yml b/.github/workflows/job_openvino_js.yml index ecb278fdb54ca3..fd04d8842daae7 100644 --- a/.github/workflows/job_openvino_js.yml +++ b/.github/workflows/job_openvino_js.yml @@ -33,6 +33,7 @@ jobs: steps: - name: Fetch OpenVINO JS sources uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | src/bindings/js diff --git a/.github/workflows/job_python_api_tests.yml b/.github/workflows/job_python_api_tests.yml index 654d634f4f56f3..81092db2bb808c 100644 --- a/.github/workflows/job_python_api_tests.yml +++ b/.github/workflows/job_python_api_tests.yml @@ -62,6 +62,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml @@ -115,6 +116,7 @@ jobs: - name: Clone API snippets if: runner.os != 'macOS' uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 + timeout-minutes: 15 with: sparse-checkout: docs/articles_en/assets/snippets path: ${{ env.OPENVINO_REPO }} diff --git a/.github/workflows/job_python_unit_tests.yml b/.github/workflows/job_python_unit_tests.yml index 47506c83bf0945..b04f719c8e296f 100644 --- a/.github/workflows/job_python_unit_tests.yml +++ b/.github/workflows/job_python_unit_tests.yml @@ -77,6 +77,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_pytorch_layer_tests.yml b/.github/workflows/job_pytorch_layer_tests.yml index 271b7948d435dc..9a9abaf72ade62 100644 --- a/.github/workflows/job_pytorch_layer_tests.yml +++ b/.github/workflows/job_pytorch_layer_tests.yml @@ -86,6 +86,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_pytorch_models_tests.yml b/.github/workflows/job_pytorch_models_tests.yml index d52b819981d821..af304b18a5688f 100644 --- a/.github/workflows/job_pytorch_models_tests.yml +++ b/.github/workflows/job_pytorch_models_tests.yml @@ -78,6 +78,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_samples_tests.yml b/.github/workflows/job_samples_tests.yml index 6f95d316abfc3f..07fc17b797592e 100644 --- a/.github/workflows/job_samples_tests.yml +++ b/.github/workflows/job_samples_tests.yml @@ -68,6 +68,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_tensorflow_layer_tests.yml b/.github/workflows/job_tensorflow_layer_tests.yml index 98f385e990f5e6..fb905f8ec4820b 100644 --- a/.github/workflows/job_tensorflow_layer_tests.yml +++ b/.github/workflows/job_tensorflow_layer_tests.yml @@ -86,6 +86,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_tensorflow_models_tests.yml b/.github/workflows/job_tensorflow_models_tests.yml index 5321beb8703de1..de5cf95484256a 100644 --- a/.github/workflows/job_tensorflow_models_tests.yml +++ b/.github/workflows/job_tensorflow_models_tests.yml @@ -70,6 +70,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/job_tokenizers.yml b/.github/workflows/job_tokenizers.yml index 1068ec550d1752..89d572885b1abe 100644 --- a/.github/workflows/job_tokenizers.yml +++ b/.github/workflows/job_tokenizers.yml @@ -58,6 +58,7 @@ jobs: - name: checkout actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python @@ -79,6 +80,7 @@ jobs: - name: Clone OpenVINO Tokenizers uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_tokenizers' path: ${{ env.OPENVINO_TOKENIZERS_REPO }} diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 00f3a321e0dd1f..063b920eed80e9 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -27,6 +27,7 @@ jobs: steps: - name: Checkout Labeller Script uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: '.github' diff --git a/.github/workflows/linux_arm64.yml b/.github/workflows/linux_arm64.yml index 255a30cbc88770..9ca6c5461a62ea 100644 --- a/.github/workflows/linux_arm64.yml +++ b/.github/workflows/linux_arm64.yml @@ -29,6 +29,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -63,6 +64,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker diff --git a/.github/workflows/linux_conditional_compilation.yml b/.github/workflows/linux_conditional_compilation.yml index ce78a9f3ae63b7..f198e64f7ad2ed 100644 --- a/.github/workflows/linux_conditional_compilation.yml +++ b/.github/workflows/linux_conditional_compilation.yml @@ -30,6 +30,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -64,6 +65,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -110,12 +112,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' - name: Clone test models uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/testdata' path: ${{ env.MODELS_PATH }} @@ -282,12 +286,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' - name: Clone test models uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/testdata' path: ${{ env.MODELS_PATH }} diff --git a/.github/workflows/linux_riscv.yml b/.github/workflows/linux_riscv.yml index 85b0db8c36294e..8966a63f611d36 100644 --- a/.github/workflows/linux_riscv.yml +++ b/.github/workflows/linux_riscv.yml @@ -29,6 +29,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -64,6 +65,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -103,6 +105,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' diff --git a/.github/workflows/linux_sanitizers.yml b/.github/workflows/linux_sanitizers.yml index 4bb597d83fadc8..cf8e1642fa5f51 100644 --- a/.github/workflows/linux_sanitizers.yml +++ b/.github/workflows/linux_sanitizers.yml @@ -25,6 +25,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -53,6 +54,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -108,12 +110,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: ${{ env.OPENVINO_CONTRIB_REPO }} @@ -281,6 +285,7 @@ jobs: - name: Fetch Sanitizer Suppression Lists uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | tests/sanitizers/lsan/suppressions.txt diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 26289e969c4e00..94460a2721b60f 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -43,6 +43,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -83,12 +84,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: 'openvino_contrib' diff --git a/.github/workflows/mac_arm64.yml b/.github/workflows/mac_arm64.yml index d3fb10082adfd4..3340ce62e0104f 100644 --- a/.github/workflows/mac_arm64.yml +++ b/.github/workflows/mac_arm64.yml @@ -43,6 +43,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -83,12 +84,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: 'openvino_contrib' diff --git a/.github/workflows/manylinux_2014.yml b/.github/workflows/manylinux_2014.yml index aa0b06b6cf05bd..4b5fc137c1504e 100644 --- a/.github/workflows/manylinux_2014.yml +++ b/.github/workflows/manylinux_2014.yml @@ -28,6 +28,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -62,6 +63,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -113,6 +115,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: ${{ env.OPENVINO_REPO }} submodules: 'true' diff --git a/.github/workflows/ovc.yml b/.github/workflows/ovc.yml index 4d69563a741d3a..3e7dedf50ad51b 100644 --- a/.github/workflows/ovc.yml +++ b/.github/workflows/ovc.yml @@ -20,6 +20,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - name: Setup Python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 diff --git a/.github/workflows/py_checks.yml b/.github/workflows/py_checks.yml index caed37eee89056..dcf0932df8024e 100644 --- a/.github/workflows/py_checks.yml +++ b/.github/workflows/py_checks.yml @@ -29,6 +29,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - name: Setup Python uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 diff --git a/.github/workflows/ubuntu_20.yml b/.github/workflows/ubuntu_20.yml index ac00405ae71ed3..19760ff2551773 100644 --- a/.github/workflows/ubuntu_20.yml +++ b/.github/workflows/ubuntu_20.yml @@ -31,6 +31,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -65,6 +66,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker diff --git a/.github/workflows/ubuntu_22.yml b/.github/workflows/ubuntu_22.yml index a32caecfbd073d..d749164abbefd0 100644 --- a/.github/workflows/ubuntu_22.yml +++ b/.github/workflows/ubuntu_22.yml @@ -33,6 +33,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -67,6 +68,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -185,6 +187,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml @@ -471,6 +474,7 @@ jobs: - name: Clone OpenVINO Contrib uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/openvino_contrib' path: ${{ env.OPENVINO_CONTRIB_REPO }} diff --git a/.github/workflows/ubuntu_22_dpcpp.yml b/.github/workflows/ubuntu_22_dpcpp.yml index 48230155f7e903..ad11a31f7403bf 100644 --- a/.github/workflows/ubuntu_22_dpcpp.yml +++ b/.github/workflows/ubuntu_22_dpcpp.yml @@ -21,6 +21,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -55,6 +56,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker diff --git a/.github/workflows/ubuntu_24.yml b/.github/workflows/ubuntu_24.yml index 1ad3951ecd3347..2c76149ecdcb94 100644 --- a/.github/workflows/ubuntu_24.yml +++ b/.github/workflows/ubuntu_24.yml @@ -28,6 +28,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -62,6 +63,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker diff --git a/.github/workflows/webassembly.yml b/.github/workflows/webassembly.yml index 45d6c9ce98317a..350df3113b0f3a 100644 --- a/.github/workflows/webassembly.yml +++ b/.github/workflows/webassembly.yml @@ -25,6 +25,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -59,6 +60,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 - uses: ./.github/actions/handle_docker id: handle_docker @@ -92,6 +94,7 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' diff --git a/.github/workflows/windows_conditional_compilation.yml b/.github/workflows/windows_conditional_compilation.yml index 2c8ba236d8503c..0f965eabd3c1ad 100644 --- a/.github/workflows/windows_conditional_compilation.yml +++ b/.github/workflows/windows_conditional_compilation.yml @@ -31,6 +31,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -74,12 +75,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' - name: Clone test models uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/testdata' path: 'testdata' @@ -283,12 +286,14 @@ jobs: steps: - name: Clone OpenVINO uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: path: 'openvino' submodules: 'true' - name: Clone test models uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: repository: 'openvinotoolkit/testdata' path: 'testdata' @@ -370,6 +375,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/windows_vs2019_debug.yml b/.github/workflows/windows_vs2019_debug.yml index 68a99055f5bdb8..4fcdc6b58b79d1 100644 --- a/.github/workflows/windows_vs2019_debug.yml +++ b/.github/workflows/windows_vs2019_debug.yml @@ -27,6 +27,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci diff --git a/.github/workflows/windows_vs2019_release.yml b/.github/workflows/windows_vs2019_release.yml index da526e57bed1ec..f1fd0be596baa2 100644 --- a/.github/workflows/windows_vs2019_release.yml +++ b/.github/workflows/windows_vs2019_release.yml @@ -29,6 +29,7 @@ jobs: steps: - name: checkout action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: .github/actions/smart-ci @@ -112,6 +113,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml @@ -185,6 +187,7 @@ jobs: steps: - name: Fetch OpenVINO JS sources uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | src/bindings/js @@ -283,6 +286,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml @@ -425,6 +429,7 @@ jobs: - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml @@ -544,6 +549,7 @@ jobs: - name: Fetch setup_python action uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: | .github/actions/setup_python/action.yml diff --git a/.github/workflows/workflow_rerunner.yml b/.github/workflows/workflow_rerunner.yml index 55ecc2500635b1..0d8d6610bea588 100644 --- a/.github/workflows/workflow_rerunner.yml +++ b/.github/workflows/workflow_rerunner.yml @@ -38,6 +38,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: '.github/scripts/workflow_rerun' @@ -73,6 +74,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: sparse-checkout: '.github/scripts/workflow_rerun' lfs: true diff --git a/.github/workflows/workflows_scans.yml b/.github/workflows/workflows_scans.yml index 0a293a4152b9a0..1a3d091544e784 100644 --- a/.github/workflows/workflows_scans.yml +++ b/.github/workflows/workflows_scans.yml @@ -29,6 +29,7 @@ jobs: steps: - name: Checkout uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + timeout-minutes: 15 with: submodules: 'false' sparse-checkout: .github/workflows