diff --git a/1_17_patch.patch b/1_17_patch.patch index 40866ef..c10f70e 100644 --- a/1_17_patch.patch +++ b/1_17_patch.patch @@ -1,8 +1,37254 @@ +diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml +index ff5179e61..6e551d818 100644 +--- a/.pipelines/windowsai-steps.yml ++++ b/.pipelines/windowsai-steps.yml +@@ -84,7 +84,7 @@ jobs: + 7z x cmake-3.26.3-windows-x86_64.zip + set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools + set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools +- $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe ++ $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_qspectre --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Generate cmake config' + +diff --git a/VERSION_NUMBER b/VERSION_NUMBER +index 84cc52946..092afa15d 100644 +--- a/VERSION_NUMBER ++++ b/VERSION_NUMBER +@@ -1 +1 @@ +-1.18.0 ++1.17.0 +diff --git a/build_arm64x.bat b/build_arm64x.bat +index 1ed268ae9..fbcdd3730 100644 +--- a/build_arm64x.bat ++++ b/build_arm64x.bat +@@ -5,6 +5,7 @@ + + setlocal + set PATH=C:\Program Files\Git\usr\bin;%PATH% ++set LINK_REPRO_NAME=/mylink.rsp + + rem Requires a Python install to be available in your PATH + python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %* +diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json +index efd901787..03e3f8454 100644 +--- a/cgmanifests/generated/cgmanifest.json ++++ b/cgmanifests/generated/cgmanifest.json +@@ -42,16 +42,6 @@ + "comments": "abseil_cpp" + } + }, +- { +- "component": { +- "type": "git", +- "git": { +- "commitHash": "dbb0094fd0cb936469e35320bf37e866ef7a1da4", +- "repositoryUrl": "https://github.com/apple/coremltools.git" +- }, +- "comments": "coremltools" +- } +- }, + { + "component": { + "type": "git", +diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt +index 34e7687e9..94d650f68 100644 +--- a/cmake/CMakeLists.txt ++++ b/cmake/CMakeLists.txt +@@ -79,7 +79,6 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) + 
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) + + option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF) +-option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF) + option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) + option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) + option(onnxruntime_USE_COREML "Build with CoreML support" OFF) +@@ -88,7 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) + option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) + option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) + option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +-option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) ++option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) + option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) + option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) + option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) +@@ -641,7 +640,6 @@ else() + check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) + check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) + check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) +- check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) + check_function_exists(reallocarray HAS_REALLOCARRAY) + if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") + check_cxx_compiler_flag(-march=armv8.2-a+bf16 HAS_ARM64_BFLOAT16) +diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake +index 2c7bf9f1c..30d8cbf78 100644 +--- a/cmake/adjust_global_compile_flags.cmake ++++ b/cmake/adjust_global_compile_flags.cmake +@@ -123,11 +123,6 @@ if (onnxruntime_DISABLE_RTTI) + add_compile_options("$<$:/GR->" "$<$:/we4541>") + else() + add_compile_options("$<$:-fno-rtti>") +- if (onnxruntime_USE_WEBNN) +- # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled +- # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/7001 +- add_compile_options("$<$:-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0>") +- endif() + endif() + else() + #MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on. 
+diff --git a/cmake/deps.txt b/cmake/deps.txt +index cb431f8c7..ba9c2bb73 100644 +--- a/cmake/deps.txt ++++ b/cmake/deps.txt +@@ -13,7 +13,6 @@ + # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 + # + abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 +-coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a + cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 + date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 + dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 +@@ -56,4 +55,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 + cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee + utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 + extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c +-composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 ++composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 +\ No newline at end of file +diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake +index 9eb5fed7a..d7b706407 100644 +--- a/cmake/external/dnnl.cmake ++++ b/cmake/external/dnnl.cmake +@@ -2,7 +2,7 @@ include (ExternalProject) + + set(DNNL_URL https://github.com/oneapi-src/onednn.git) + # If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated. +-set(DNNL_TAG v3.0.1) ++set(DNNL_TAG v3.0) + + if(WIN32) + set(DNNL_SHARED_LIB dnnl.dll) +diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake +index 22d12b128..78f63227c 100644 +--- a/cmake/external/onnxruntime_external_deps.cmake ++++ b/cmake/external/onnxruntime_external_deps.cmake +@@ -108,14 +108,41 @@ FetchContent_Declare( + ) + + # Download a protoc binary from Internet if needed +-if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) ++if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) + # This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually + # download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE + # variable. +- if (CMAKE_HOST_APPLE) +- # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices. +- # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html +- # To keep it simple, just download and use the universal protoc binary for all Apple host builds. 
++ message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") ++ if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") ++ if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") ++ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) ++ FetchContent_Populate(protoc_binary) ++ elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") ++ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) ++ FetchContent_Populate(protoc_binary) ++ endif() ++ if(protoc_binary_SOURCE_DIR) ++ message("Use prebuilt protoc") ++ set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) ++ set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) ++ endif() ++ elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") ++ if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") ++ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) ++ FetchContent_Populate(protoc_binary) ++ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") ++ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) ++ FetchContent_Populate(protoc_binary) ++ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") ++ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) ++ FetchContent_Populate(protoc_binary) ++ endif() ++ if(protoc_binary_SOURCE_DIR) ++ message("Use prebuilt protoc") ++ set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) ++ set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) ++ endif() ++ elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin") + FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal}) + FetchContent_Populate(protoc_binary) + if(protoc_binary_SOURCE_DIR) +@@ -123,38 +150,6 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE) + set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) + set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) + endif() +- elseif (CMAKE_CROSSCOMPILING) +- message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}") +- if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") +- if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64") +- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64}) +- FetchContent_Populate(protoc_binary) +- elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86") +- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32}) +- FetchContent_Populate(protoc_binary) +- endif() +- if(protoc_binary_SOURCE_DIR) +- message("Use prebuilt protoc") +- set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe) +- set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) +- endif() +- elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux") +- if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") +- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64}) +- FetchContent_Populate(protoc_binary) +- elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") +- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86}) +- FetchContent_Populate(protoc_binary) +- elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") +- 
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64}) +- FetchContent_Populate(protoc_binary) +- endif() +- if(protoc_binary_SOURCE_DIR) +- message("Use prebuilt protoc") +- set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc) +- set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE}) +- endif() +- endif() + endif() + endif() + +@@ -189,9 +184,9 @@ FetchContent_Declare( + ) + + set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE) +-#TODO: we'd better to turn the following option off. However, it will cause ++#TODO: we'd better to turn the following option off. However, it will cause + # ".\build.bat --config Debug --parallel --skip_submodule_sync --update" fail with an error message: +-# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is ++# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is + # not in any export set. + #set(protobuf_INSTALL OFF CACHE BOOL "Install protobuf binaries and files" FORCE) + set(protobuf_USE_EXTERNAL_GTEST ON CACHE BOOL "" FORCE) +@@ -224,6 +219,8 @@ FetchContent_Declare( + URL_HASH SHA1=${DEP_SHA1_mp11} + ) + ++set(JSON_BuildTests OFF CACHE INTERNAL "") ++set(JSON_Install OFF CACHE INTERNAL "") + set(JSON_BuildTests OFF CACHE INTERNAL "") + set(JSON_Install OFF CACHE INTERNAL "") + +@@ -539,17 +536,6 @@ if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxrunt + onnxruntime_fetchcontent_makeavailable(cxxopts) + endif() + +-if (onnxruntime_USE_COREML) +- FetchContent_Declare( +- coremltools +- URL ${DEP_URL_coremltools} +- URL_HASH SHA1=${DEP_SHA1_coremltools} +- PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/coremltools/crossplatformbuild.patch +- ) +- # we don't build directly so use Populate. 
selected files are built from onnxruntime_providers_coreml.cmake +- FetchContent_Populate(coremltools) +-endif() +- + message("Finished fetching external dependencies") + + +@@ -576,3 +562,4 @@ endif() + + FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) + FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) ++ +diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake +index 41f02ce6f..e661aa51b 100644 +--- a/cmake/external/xnnpack.cmake ++++ b/cmake/external/xnnpack.cmake +@@ -6,14 +6,10 @@ set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + set(PTHREADPOOL_BUILD_TESTS OFF CACHE INTERNAL "") + set(PTHREADPOOL_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + +-if(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") +- set(XNNPACK_USE_SYSTEM_LIBS OFF) +-endif() +- + # BF16 instructions cause ICE in Android NDK compiler + if(CMAKE_ANDROID_ARCH_ABI STREQUAL armeabi-v7a) + set(XNNPACK_ENABLE_ARM_BF16 OFF) +-endif() ++ENDIF() + + # fp16 depends on psimd + FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) +diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake +index 2ead13e55..c900f4d4b 100644 +--- a/cmake/onnxruntime.cmake ++++ b/cmake/onnxruntime.cmake +@@ -189,6 +189,7 @@ set(onnxruntime_INTERNAL_LIBRARIES + ${PROVIDERS_SNPE} + ${PROVIDERS_TVM} + ${PROVIDERS_RKNPU} ++ ${PROVIDERS_VITISAI} + ${PROVIDERS_XNNPACK} + ${PROVIDERS_WEBNN} + ${PROVIDERS_AZURE} +diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake +index 6b8c2560b..43d5fa9bd 100644 +--- a/cmake/onnxruntime_common.cmake ++++ b/cmake/onnxruntime_common.cmake +@@ -189,8 +189,6 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set(ARM TRUE) + elseif(dumpmachine_output MATCHES "^aarch64.*") + set(ARM64 TRUE) +- elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") +- set(RISCV64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + set(X86 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") +@@ -200,7 +198,7 @@ elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + endif() + + +-if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) ++if (ARM64 OR ARM OR X86 OR X64 OR X86_64) + if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) + # msvc compiler report syntax error with cpuinfo arm source files + # and cpuinfo does not have code for getting arm uarch info under windows +diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake +index c6c9d8f48..8d3ea403f 100644 +--- a/cmake/onnxruntime_providers.cmake ++++ b/cmake/onnxruntime_providers.cmake +@@ -67,7 +67,7 @@ if(onnxruntime_USE_CUDA) + endif() + if(onnxruntime_USE_COREML) + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +- set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) ++ set(PROVIDERS_COREML onnxruntime_providers_coreml onnxruntime_coreml_proto) + else() + set(PROVIDERS_COREML onnxruntime_providers_coreml) + endif() +diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake +index 2ca4a22ac..aa8c35526 100644 +--- a/cmake/onnxruntime_providers_coreml.cmake ++++ b/cmake/onnxruntime_providers_coreml.cmake +@@ -1,119 +1,107 @@ + # Copyright (c) Microsoft Corporation. All rights reserved. + # Licensed under the MIT License. + +-if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) +- message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. 
Please build with '--minimal_build extended'") +-endif() +- +-add_compile_definitions(USE_COREML=1) +- +-# Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto +-set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) +-file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") +- +-onnxruntime_add_static_library(coreml_proto ${coreml_proto_srcs}) +-target_include_directories(coreml_proto +- PUBLIC $ +- "${CMAKE_CURRENT_BINARY_DIR}") +-target_compile_definitions(coreml_proto +- PUBLIC $) +-set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") +-set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") +-set(_src_sub_dir "coreml_proto/") +- +-onnxruntime_protobuf_generate( +- APPEND_PATH +- GEN_SRC_SUB_DIR ${_src_sub_dir} +- IMPORT_DIRS ${COREML_PROTO_ROOT} +- TARGET coreml_proto +-) +- +-if (NOT onnxruntime_BUILD_SHARED_LIB) +- install(TARGETS coreml_proto +- ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} +- LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +- RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +- FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} ++ if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) ++ message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. Please build with '--minimal_build extended'") ++ endif() ++ ++ add_compile_definitions(USE_COREML=1) ++ ++ # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml ++ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") ++ set(COREML_PROTO_ROOT ${PROJECT_SOURCE_DIR}/../onnxruntime/core/providers/coreml/mlmodel_format) ++ file(GLOB coreml_proto_srcs ++ "${COREML_PROTO_ROOT}/*.proto" ++ ) ++ onnxruntime_add_static_library(onnxruntime_coreml_proto ${coreml_proto_srcs}) ++ target_include_directories(onnxruntime_coreml_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") ++ target_compile_definitions(onnxruntime_coreml_proto PUBLIC $) ++ set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") ++ set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") ++ set(_src_sub_dir "coreml/") ++ onnxruntime_protobuf_generate( ++ APPEND_PATH ++ GEN_SRC_SUB_DIR ${_src_sub_dir} ++ IMPORT_DIRS ${COREML_PROTO_ROOT} ++ TARGET onnxruntime_coreml_proto ++ ) ++ ++ if (NOT onnxruntime_BUILD_SHARED_LIB) ++ install(TARGETS onnxruntime_coreml_proto ++ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ++ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ++ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ++ FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} ++ ) ++ endif() ++ endif() ++ ++ # These are shared utils, ++ # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML ++ file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS ++ "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" + ) +-endif() +- +-# Add the .proto and generated .cc/.h files to the External/coreml_proto folder in Visual Studio. +-# Separate source_group for each as the .proto files are in the repo and the .cc/.h files are generated in the build +-# output directory. 
+-set_target_properties(coreml_proto PROPERTIES FOLDER "External") +-source_group(TREE ${COREML_PROTO_ROOT} PREFIX coreml_proto FILES ${coreml_proto_srcs}) +- +-# filter to the generated .cc/.h files +-get_target_property(coreml_proto_generated_srcs coreml_proto SOURCES) +-list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$") +-source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs}) +- +-# These are shared utils, +-# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML +-file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" +- "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" +-) +- +-file(GLOB +- onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" +-) + +-# Add builder source code +-file(GLOB_RECURSE +- onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" +-) +-if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") +- list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" +- ) +-endif() +- +-# Add CoreML objective c++ source code +-if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") + file(GLOB +- onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" +- "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ++ onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" + ) +-endif() +- +-set(onnxruntime_providers_coreml_cc_srcs +- ${onnxruntime_providers_coreml_cc_srcs_top} +- ${onnxruntime_providers_coreml_cc_srcs_nested} +- ${onnxruntime_providers_shared_utils_cc_srcs} +-) +- +-source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) +-onnxruntime_add_static_library(onnxruntime_providers_coreml +- ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} +-) +-onnxruntime_add_include_to_target(onnxruntime_providers_coreml +- onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface +-) +-if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +- onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) +- target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto "-framework Foundation" "-framework CoreML") +- add_dependencies(onnxruntime_providers_coreml coreml_proto) +-endif() +-add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +-set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) +-set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") +-target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) 
+-set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) ++ # Add builder source code ++ file(GLOB_RECURSE ++ onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" ++ ) ++ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") ++ list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" ++ ) ++ endif() ++ ++ # Add CoreML objective c++ source code ++ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") ++ file(GLOB ++ onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ++ ) ++ endif() ++ ++ set(onnxruntime_providers_coreml_cc_srcs ++ ${onnxruntime_providers_coreml_cc_srcs_top} ++ ${onnxruntime_providers_coreml_cc_srcs_nested} ++ ${onnxruntime_providers_shared_utils_cc_srcs} ++ ) + +-if (NOT onnxruntime_BUILD_SHARED_LIB) +- install(TARGETS onnxruntime_providers_coreml +- ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} +- LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +- RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +- FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +-endif() ++ source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) ++ onnxruntime_add_static_library(onnxruntime_providers_coreml ++ ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} ++ ) ++ onnxruntime_add_include_to_target(onnxruntime_providers_coreml ++ onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ++ ) ++ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") ++ onnxruntime_add_include_to_target(onnxruntime_providers_coreml onnxruntime_coreml_proto) ++ target_link_libraries(onnxruntime_providers_coreml PRIVATE onnxruntime_coreml_proto "-framework Foundation" "-framework CoreML") ++ add_dependencies(onnxruntime_providers_coreml onnxruntime_coreml_proto) ++ endif() ++ add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) ++ ++ set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) ++ set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") ++ target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) ++ set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) ++ ++ if (NOT onnxruntime_BUILD_SHARED_LIB) ++ install(TARGETS onnxruntime_providers_coreml ++ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ++ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ++ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ++ FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) ++ endif() +\ No newline at end of file +diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake +index 9887d615c..84d1376f9 100644 +--- a/cmake/onnxruntime_providers_cuda.cmake ++++ b/cmake/onnxruntime_providers_cuda.cmake +@@ -1,25 +1,10 @@ + # Copyright (c) Microsoft Corporation. All rights reserved. 
+ # Licensed under the MIT License. + +- +- if (onnxruntime_CUDA_MINIMAL) +- file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h" +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc" +- ) +- # Remove pch files +- list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc" +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h" +- ) +- else() +- file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" +- ) +- endif() ++ file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS ++ "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h" ++ "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc" ++ ) + # Remove pch files + list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs + "${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h" +@@ -31,16 +16,11 @@ + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) ++ file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS ++ "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" ++ "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" ++ ) + +- +- if (onnxruntime_CUDA_MINIMAL) +- set(onnxruntime_providers_cuda_shared_srcs "") +- else() +- file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu" +- "${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh" +- ) +- endif() + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) + set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs}) + +@@ -176,15 +156,10 @@ + endif() + + add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) +- if(onnxruntime_CUDA_MINIMAL) +- target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) +- target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) +- else() +- target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) +- if(onnxruntime_CUDNN_HOME) +- target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) +- target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) +- endif() ++ target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) ++ if(onnxruntime_CUDNN_HOME) ++ target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) ++ target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) + endif() + + if (onnxruntime_USE_TRITON_KERNEL) +diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake +index 183a3e196..0951c2d02 100644 +--- a/cmake/onnxruntime_providers_vitisai.cmake ++++ b/cmake/onnxruntime_providers_vitisai.cmake +@@ -14,19 +14,14 @@ + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" +- 
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" +- "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" + ) + source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) +- onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) +- onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} nlohmann_json::nlohmann_json safeint_interface flatbuffers::flatbuffers) +- target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED}) +- if(MSVC) +- onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp) +- set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/vitisai/symbols.def") +- else(MSVC) +- set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/vitisai/version_script.lds -Xlinker --gc-sections") +- endif(MSVC) ++ onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) ++ onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) ++ target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) ++ if(NOT MSVC) ++ target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) ++ endif(NOT MSVC) + + target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) + if(MSVC) +@@ -35,18 +30,17 @@ + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") + # for unused formal parameter + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") +- # for type name first seen using 'class' now seen using 'struct' +- target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4099") + else(MSVC) +- target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) + target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) + endif(MSVC) + + set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") + set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) + +- install(TARGETS onnxruntime_providers_vitisai +- ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} +- LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +- RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} +- FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) ++ if (NOT onnxruntime_BUILD_SHARED_LIB) ++ install(TARGETS onnxruntime_providers_vitisai ++ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} ++ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} ++ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ++ FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) ++ endif() +diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake +index 6342c24b2..9c00703ca 100644 +--- a/cmake/onnxruntime_providers_xnnpack.cmake ++++ b/cmake/onnxruntime_providers_xnnpack.cmake +@@ -19,12 +19,6 @@ + flatbuffers::flatbuffers Boost::mp11 safeint_interface + ) + +- # TODO fix stringop-overflow warnings +- # Add compile option to suppress stringop-overflow error in Flatbuffers. 
+- if (HAS_STRINGOP_OVERFLOW) +- target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=stringop-overflow) +- endif() +- + add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + set_target_properties(onnxruntime_providers_xnnpack PROPERTIES FOLDER "ONNXRuntime") + +diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake +index 456344aa3..2e3594f25 100644 +--- a/cmake/onnxruntime_python.cmake ++++ b/cmake/onnxruntime_python.cmake +@@ -170,6 +170,7 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE + onnxruntime_session + ${onnxruntime_libs} + ${PROVIDERS_TVM} ++ ${PROVIDERS_VITISAI} + ${PROVIDERS_NNAPI} + ${PROVIDERS_XNNPACK} + ${PROVIDERS_COREML} +@@ -851,16 +852,6 @@ if (onnxruntime_USE_DNNL) + ) + endif() + +-if (onnxruntime_USE_VITISAI) +- add_custom_command( +- TARGET onnxruntime_pybind11_state POST_BUILD +- COMMAND ${CMAKE_COMMAND} -E copy +- ${DNNL_DLL_PATH} $ +- $ +- $/onnxruntime/capi/ +- ) +-endif() +- + if (onnxruntime_USE_TENSORRT) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD +diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake +index d485abe6b..f70961a66 100644 +--- a/cmake/onnxruntime_rocm_hipify.cmake ++++ b/cmake/onnxruntime_rocm_hipify.cmake +@@ -47,9 +47,6 @@ set(contrib_ops_excluded_files + "diffusion/group_norm.cc" + "diffusion/group_norm_impl.cu" + "diffusion/group_norm_impl.h" +- "diffusion/group_norm_impl_kernel.cuh" +- "diffusion/group_norm_common_base.h" +- "diffusion/group_norm_common_base.cc" + "diffusion/nhwc_conv.cc" + "math/gemm_float8.cc" + "math/gemm_float8.cu" +diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake +index 5b4a007d6..fa395802d 100644 +--- a/cmake/onnxruntime_unittests.cmake ++++ b/cmake/onnxruntime_unittests.cmake +@@ -566,7 +566,7 @@ endif() + + if(onnxruntime_USE_COREML) + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +- list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) ++ list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) + else() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) + endif() +@@ -591,6 +591,7 @@ set(ONNXRUNTIME_TEST_LIBS + # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime + ${PROVIDERS_NNAPI} + ${PROVIDERS_JS} ++ ${PROVIDERS_VITISAI} + ${PROVIDERS_QNN} + ${PROVIDERS_SNPE} + ${PROVIDERS_RKNPU} +@@ -675,9 +676,9 @@ endif() + if(onnxruntime_USE_COREML) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) + if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +- list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) +- list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) +- list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) ++ list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) ++ list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) ++ list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) + else() + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) + list(APPEND onnxruntime_test_providers_dependencies 
onnxruntime_providers_coreml) +@@ -823,8 +824,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + "${TEST_SRC_DIR}/providers/memcpy_test.cc" + ) + endif() +- list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" +- "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") + endif() + + set(test_all_args) +@@ -907,7 +906,7 @@ if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js) +- set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") ++ set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_all_adapter.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1 -s DEMANGLE_SUPPORT=1") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") + endif() +@@ -1278,9 +1277,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + if (onnxruntime_USE_CUDA) + list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + endif() +- if (onnxruntime_USE_ROCM) +- list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) +- endif() + if (onnxruntime_USE_TENSORRT) + list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) + endif() +@@ -1298,10 +1294,6 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) + endif() +- if (onnxruntime_USE_ROCM) +- target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) +- target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) +- endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_sources(onnxruntime_shared_lib_test PRIVATE + "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" +diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake +index 546d50c1c..858583e64 100644 +--- a/cmake/onnxruntime_webassembly.cmake ++++ b/cmake/onnxruntime_webassembly.cmake +@@ -268,10 +268,7 @@ else() + endif() + + if (onnxruntime_USE_WEBNN) +- set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") +- if (onnxruntime_DISABLE_RTTI) +- set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") +- endif() ++ set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT") + endif() + + # Set link flag to enable exceptions support, this will override default disabling 
exception throwing behavior when disable exceptions. +diff --git a/cmake/patches/coremltools/crossplatformbuild.patch b/cmake/patches/coremltools/crossplatformbuild.patch +deleted file mode 100644 +index 7f2268f50..000000000 +--- a/cmake/patches/coremltools/crossplatformbuild.patch ++++ /dev/null +@@ -1,155 +0,0 @@ +-diff --git a/mlmodel/src/MILBlob/Blob/FileWriter.cpp b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +-index adc7bfcf..7b2bf9cc 100644 +---- a/mlmodel/src/MILBlob/Blob/FileWriter.cpp +-+++ b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +-@@ -8,8 +8,12 @@ +- +- #include +- #include +-+ +-+// ORT_EDIT: Exclude mmap on Windows. Not used in this file anyway. +-+#if !defined(_WIN32) +- #include +- #include +-+#endif +- +- using namespace MILBlob; +- using namespace MILBlob::Blob; +-diff --git a/mlmodel/src/MILBlob/Fp16.cpp b/mlmodel/src/MILBlob/Fp16.cpp +-index ae1e71a1..77a7161f 100644 +---- a/mlmodel/src/MILBlob/Fp16.cpp +-+++ b/mlmodel/src/MILBlob/Fp16.cpp +-@@ -5,6 +5,8 @@ +- +- #include "MILBlob/Fp16.hpp" +- +-+// ORT_EDIT: Exclude clang specific pragmas from other builds +-+#if defined(__clang__) +- // fp16 lib code has some conversion warnings we don't want to globally ignore +- #pragma clang diagnostic push +- #pragma clang diagnostic ignored "-Wincompatible-pointer-types" +-@@ -12,6 +14,9 @@ +- #pragma clang diagnostic ignored "-Wconversion" +- #include "fp16/fp16.h" +- #pragma clang diagnostic pop +-+#else +-+#include "fp16/fp16.h" +-+#endif +- +- using namespace MILBlob; +- +-diff --git a/modelpackage/src/ModelPackage.cpp b/modelpackage/src/ModelPackage.cpp +-index 8fee56b9..99e0d8d6 100644 +---- a/modelpackage/src/ModelPackage.cpp +-+++ b/modelpackage/src/ModelPackage.cpp +-@@ -26,7 +26,14 @@ namespace std { +- #else +- #error "missing required header " +- #endif +-+ +-+// ORT_EDIT: Use UuidCreate on Windows. +-+#if defined(_WIN32) +-+#pragma comment(lib, "rpcrt4.lib") // UuidCreate +-+#include +-+#else +- #include +-+#endif +- #include +- +- #if defined(__cplusplus) +-@@ -187,7 +194,10 @@ public: +- ModelPackageItemInfo createFile(const std::string& name, const std::string& author, const std::string& description); +- }; +- +-+// ORT_EDIT: pragma only available on APPLE platforms +-+#if defined(__APPLE__) +- #pragma mark ModelPackageImpl +-+#endif +- +- ModelPackageImpl::ModelPackageImpl(const std::filesystem::path& path, bool createIfNecessary, bool readOnly) +- : m_packagePath(path), +-@@ -372,6 +382,20 @@ std::filesystem::path ModelPackageImpl::getItemPath(const std::string& name, con +- } +- +- std::string ModelPackageImpl::generateIdentifier() const { +-+// ORT_EDIT: Use built-in UUID generation on Windows +-+#if defined(_WIN32) +-+ UUID uuid; +-+ UuidCreate(&uuid); +-+ +-+ RPC_CSTR uuidStr; +-+ UuidToStringA(&uuid, &uuidStr); +-+ +-+ std::string uuidStrCpp(reinterpret_cast(uuidStr)); +-+ +-+ RpcStringFreeA(&uuidStr); +-+ +-+ return uuidStrCpp; +-+#else +- uuid_t uuid; +- +- // uuid_unparse generates a 36-character null-terminated string (37 bytes). 
+-@@ -383,6 +407,7 @@ std::string ModelPackageImpl::generateIdentifier() const { +- uuid_unparse(uuid, buf); +- +- return std::string(buf); +-+#endif +- } +- +- ModelPackageItemInfo ModelPackageImpl::createFile(const std::string& name, const std::string& author, const std::string& description) { +-@@ -468,7 +493,13 @@ std::shared_ptr ModelPackageImpl::findItem(const std::stri +- auto author = itemInfoEntry->getString(kModelPackageItemInfoAuthorKey); +- auto description = itemInfoEntry->getString(kModelPackageItemInfoDescriptionKey); +- +-+// ORT_EDIT: need to use path.string() on Windows +-+#if defined(_WIN32) +-+ return std::make_shared(std::make_shared(identifier, path.string(), name, author, description)); +-+ +-+#else +- return std::make_shared(std::make_shared(identifier, path, name, author, description)); +-+#endif +- } +- +- std::shared_ptr ModelPackageImpl::findItem(const std::string& name, const std::string& author) const +-@@ -514,7 +545,9 @@ void ModelPackageImpl::removeItem(const std::string& identifier) +- } +- +- auto path = m_packageDataDirPath / itemInfoEntry->getString(kModelPackageItemInfoPathKey); +-- if (0 != std::remove(path.c_str())) { +-+ // ORT_EDIT: std::remove doesn't work on Windows. Use std::filesystem::remove instead. +-+ // if (0 != std::remove(path.c_str())) { +-+ if (!std::filesystem::remove(path)) { +- throw std::runtime_error("Failed to remove file at path: " + path.string()); +- } +- +-@@ -525,13 +558,16 @@ bool ModelPackageImpl::isValid(const std::filesystem::path& path) +- { +- try { +- ModelPackageImpl(path, false, true); +-- } catch (std::runtime_error& e) { +-+ } catch (std::runtime_error& /*e*/) { // ORT_EDIT: comment out unused variable +- return false; +- } +- return true; +- } +- +-+// ORT_EDIT: pragma only available on APPLE platforms +-+#if defined(__APPLE__) +- #pragma mark ModelPackage +-+#endif +- +- ModelPackage::ModelPackage(const std::string& packagePath, bool createIfNecessary, bool readOnly) +- : m_modelPackageImpl(std::make_shared(packagePath, createIfNecessary, readOnly)) +-@@ -544,7 +580,12 @@ ModelPackage::~ModelPackage() +- +- std::string ModelPackage::path() const +- { +-+// ORT_EDIT: Windows doesn't automatically convert to std::string as the native format could be char or wchar. 
+-+#if defined(_WIN32) +-+ return m_modelPackageImpl->path().string(); +-+#else +- return m_modelPackageImpl->path(); +-+#endif +- } +- +- std::string ModelPackage::setRootModel(const std::string& path, const std::string& name, const std::string& author, const std::string& description) +diff --git a/cmake/patches/flatbuffers/flatbuffers.patch b/cmake/patches/flatbuffers/flatbuffers.patch +index f141d358c..fb2678ef1 100644 +--- a/cmake/patches/flatbuffers/flatbuffers.patch ++++ b/cmake/patches/flatbuffers/flatbuffers.patch +@@ -7,7 +7,7 @@ index 3987eac9..5e5462f1 100644 + endif(CYGWIN) + set(CMAKE_CXX_FLAGS + - "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow") +-+ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Wextra -Werror=shadow -Wno-error=stringop-overflow") +++ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow -Wno-error=stringop-overflow") + set(FLATBUFFERS_PRIVATE_CXX_FLAGS "-Wold-style-cast") + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.4) + if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) +diff --git a/cmake/riscv64.toolchain.cmake b/cmake/riscv64.toolchain.cmake +deleted file mode 100644 +index 0fda239f9..000000000 +--- a/cmake/riscv64.toolchain.cmake ++++ /dev/null +@@ -1,35 +0,0 @@ +-# Copyright (c) 2024 SiFive, Inc. All rights reserved. +-# Copyright (c) 2024, Phoebe Chen +-# Licensed under the MIT License. +- +-set(CMAKE_SYSTEM_NAME Linux) +-set(CMAKE_SYSTEM_PROCESSOR riscv64) +- +-list(APPEND CMAKE_TRY_COMPILE_PLATFORM_VARIABLES RISCV_TOOLCHAIN_ROOT) +- +-if(NOT RISCV_TOOLCHAIN_ROOT) +- message(FATAL_ERROR "RISCV_TOOLCHAIN_ROOT is not defined. Please set the RISCV_TOOLCHAIN_ROOT variable.") +-endif() +- +-set(CMAKE_C_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +-set(CMAKE_ASM_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-gcc") +-set(CMAKE_CXX_COMPILER "${RISCV_TOOLCHAIN_ROOT}/bin/riscv64-unknown-linux-gnu-g++") +- +-set(CMAKE_FIND_ROOT_PATH ${RISCV_TOOLCHAIN_ROOT}) +-set(CMAKE_SYSROOT "${RISCV_TOOLCHAIN_ROOT}/sysroot") +-set(CMAKE_INCLUDE_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/include/") +-set(CMAKE_LIBRARY_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/lib/") +-set(CMAKE_PROGRAM_PATH "${RISCV_TOOLCHAIN_ROOT}/sysroot/usr/bin/") +- +-if(RISCV_QEMU_PATH) +- message(STATUS "RISCV_QEMU_PATH=${RISCV_QEMU_PATH} is defined during compilation.") +- set(CMAKE_CROSSCOMPILING_EMULATOR "${RISCV_QEMU_PATH};-L;${CMAKE_SYSROOT}") +-endif() +- +-set(CMAKE_CROSSCOMPILING TRUE) +- +-set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +-set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +-set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +-set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) +- +diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +index 7fe16f415..68a399f8b 100644 +--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs ++++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +@@ -65,10 +65,10 @@ namespace Microsoft.ML.OnnxRuntime + DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi)); + + // TODO: Make this save the pointer, and not copy the whole structure across +- api_ = (OrtApi)OrtGetApi(18 /*ORT_API_VERSION*/); ++ api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/); + + OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, 
typeof(DOrtGetTrainingApi)); +- trainingApiPtr = OrtGetTrainingApi(18 /*ORT_API_VERSION*/); ++ trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/); + if (trainingApiPtr != IntPtr.Zero) + { + trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi)); +diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +index fec0d46e9..877677dca 100644 +--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs ++++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +@@ -282,48 +282,6 @@ namespace Microsoft.ML.OnnxRuntime + } + } + +- /// +- /// This function performs a training step that computes the outputs of the training model and the gradients +- /// of the trainable parameters for the given OrtValue inputs. The train step is performed based on the training model +- /// that was provided to the training session. +- /// The TrainStep method is equivalent of running forward propagation and backward propagation in a single +- /// step. +- /// The gradients computed are stored inside the training session state so they can be later consumed +- /// by the OptimizerStep function. +- /// The gradients can be lazily reset by invoking the LazyResetGrad function. +- /// Example usage: +- /// +- /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); +- /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); +- /// List inputValues = new List { x, label }; +- /// using (var loss = trainingSession.TrainStep(inputValues)) +- /// { +- /// // process output values +- /// } +- /// +- /// +- /// Specify a collection of that indicates the input values to the training model. +- /// Output Tensors in a Collection of NamedOnnxValue. User must dispose the output. +- public IDisposableReadOnlyCollection TrainStep(IReadOnlyCollection inputValues) +- { +- IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); +- IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount]; +- +- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, +- inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray)); +- +- +- var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); +- try +- { +- return CreateDisposableResult(disposableHandles); +- } +- finally +- { +- disposableHandles.Dispose(); +- } +- } +- + /// + /// Convert native OrtValue handles to OrtValue instances + /// in an exceptions safe manner. +@@ -412,42 +370,6 @@ namespace Microsoft.ML.OnnxRuntime + inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); + } + +- /// +- /// This function performs an eval step that computes the outputs of the eval model for the given inputs. +- /// Inputs are expected to be of type OrtValue. The eval step is performed based on the eval model that was +- /// provided to the training session. +- /// Example usage: +- /// +- /// using OrtValue x = OrtValue.CreateTensorValueFromMemory(...); +- /// using OrtValue label = OrtValue.CreateTensorValueFromMemory(...); +- /// List inputValues = new List { x, label }; +- /// using (var loss = trainingSession.EvalSteps(inputValues)) +- /// { +- /// // process output values +- /// } +- /// +- /// +- /// Specify a collection of that indicates the input values to the eval model. 
+- public IDisposableReadOnlyCollection EvalStep(IReadOnlyCollection inputValues) +- { +- IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues); +- IntPtr[] outputValuesArray = new IntPtr[(int)_evalOutputCount]; +- +- NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, IntPtr.Zero, (UIntPtr)inputValues.Count, +- inputValuesArray, (UIntPtr)_evalOutputCount, outputValuesArray)); +- +- +- var disposableHandles = new DisposableOrtValueHandleArray(outputValuesArray); +- try +- { +- return CreateDisposableResult(disposableHandles); +- } +- finally +- { +- disposableHandles.Dispose(); +- } +- } +- + + /// + /// Sets the learning rate for this training session. +@@ -780,35 +702,6 @@ namespace Microsoft.ML.OnnxRuntime + return valuesArray; + } + +- private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection inputValues) +- { +- var valuesArray = new IntPtr[inputValues.Count]; +- for (int index = 0; index < inputValues.Count; ++index) +- { +- valuesArray[index] = inputValues.ElementAt(index).Handle; +- } +- return valuesArray; +- } +- +- private static IDisposableReadOnlyCollection CreateDisposableResult(DisposableOrtValueHandleArray disposableHandles) +- { +- var outputValues = new DisposableList(disposableHandles.Span.Length); +- try +- { +- for (int i = 0; i < disposableHandles.Span.Length; i++) +- { +- outputValues.Add(new OrtValue(disposableHandles.Span[i])); +- disposableHandles.Span[i] = IntPtr.Zero; +- } +- return outputValues; +- } +- catch (Exception) +- { +- outputValues.Dispose(); +- throw; +- } +- } +- + private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection names, DisposableList cleanupList) + { + cleanupList.Capacity += names.Count; +diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +index 9b7232620..68b1d5bcc 100644 +--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs ++++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +@@ -612,81 +612,6 @@ namespace Microsoft.ML.OnnxRuntime.Tests + } + } + +- [Fact(DisplayName = "TestTrainingSessionTrainStepWithOrtValues")] +- public void TestTrainingSessionTrainStepWithOrtValues() +- { +- string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); +- using (var cleanUp = new DisposableListTest()) +- { +- var state = CheckpointState.LoadCheckpoint(checkpointPath); +- cleanUp.Add(state); +- Assert.NotNull(state); +- string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); +- string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); +- +- var trainingSession = new TrainingSession(state, trainingPath, optimizerPath); +- cleanUp.Add(trainingSession); +- +- float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); +- var expectedOutputDimensions = new int[] { 1 }; +- float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); +- long[] inputShape = { 2, 784 }; +- Int32[] labelsData = { 1, 1 }; +- long[] labelsShape = { 2 }; +- +- using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); +- using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); +- var inputValues = new List { inputOrtValue, labelsOrtValue }; +- +- using (var results = trainingSession.TrainStep(inputValues)) +- { +- Assert.Single(results); +- var outputOrtValue = results[0]; +- Assert.True(outputOrtValue.IsTensor); +- var resultSpan = 
outputOrtValue.GetTensorDataAsSpan().ToArray(); +- Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); +- } +- } +- } +- +- [Fact(DisplayName = "TestTrainingSessionEvalStepWithOrtValues")] +- public void TestTrainingSessionEvalStepWithOrtValues() +- { +- string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); +- using (var cleanUp = new DisposableListTest()) +- { +- var state = CheckpointState.LoadCheckpoint(checkpointPath); +- cleanUp.Add(state); +- Assert.NotNull(state); +- string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); +- string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); +- string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); +- +- var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); +- cleanUp.Add(trainingSession); +- +- float[] expectedOutput = TestDataLoader.LoadTensorFromFile("loss_1.out"); +- var expectedOutputDimensions = new int[] { 1 }; +- float[] inputData = TestDataLoader.LoadTensorFromFile("input-0.in"); +- long[] inputShape = { 2, 784 }; +- Int32[] labelsData = { 1, 1 }; +- long[] labelsShape = { 2 }; +- +- using OrtValue inputOrtValue = OrtValue.CreateTensorValueFromMemory(inputData, inputShape); +- using OrtValue labelsOrtValue = OrtValue.CreateTensorValueFromMemory(labelsData, labelsShape); +- var inputValues = new List { inputOrtValue, labelsOrtValue }; +- +- using (var results = trainingSession.EvalStep(inputValues)) +- { +- Assert.Single(results); +- var outputOrtValue = results[0]; +- Assert.True(outputOrtValue.IsTensor); +- var resultSpan = outputOrtValue.GetTensorDataAsSpan().ToArray(); +- Assert.Equal(expectedOutput, resultSpan, new FloatComparer()); +- } +- } +- } +- + internal class FloatComparer : IEqualityComparer + { + private float atol = 1e-3f; +diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md +index e7b537d68..fd26b09b0 100644 +--- a/docs/ContribOperators.md ++++ b/docs/ContribOperators.md +@@ -5769,7 +5769,7 @@ This version of the operator has been available since version 1 of the 'com.micr +
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
+ + +-#### Inputs (5 - 15) ++#### Inputs (5 - 14) + +
+
input_ids : F
+@@ -5800,8 +5800,6 @@ This version of the operator has been available since version 1 of the 'com.micr +
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default: collect all. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
extra_decoding_ids (optional) : I
+
Part of the decoder_input_ids for which cross QK is needed; it has shape (batch_size, extra_decoding_ids_len). In this case, these ids should be removed from the tail of the decoder_input_ids and supplied here instead. Ids < 0 within it (for multiple batches) are treated as the stop of the extra_decoding_ids for the corresponding batch.
+-
temperature (optional) : T
+-
Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
+
+ + #### Outputs (1 - 5) +diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md +index 2ea557b7d..6e5d84200 100644 +--- a/docs/OperatorKernels.md ++++ b/docs/OperatorKernels.md +@@ -499,7 +499,7 @@ Do not modify directly.* + |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| + |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| + |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| +-|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| ++|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| + |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| + | | + | | +@@ -682,8 +682,7 @@ Do not modify directly.* + |PRelu|*in* X:**T**
*in* slope:**T**
*out* Y:**T**|16+|**T** = tensor(double), tensor(float), tensor(float16)| + |||[9, 15]|**T** = tensor(double), tensor(float), tensor(float16)| + |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| +-|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|18+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| +-|||[13, 17]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| ++|Pad|*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*in* axes:**Tind**
*out* output:**T**

or

*in* data:**T**
*in* pads:**tensor(int64)**
*in* constant_value:**T**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16)| + |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| + |||[2, 10]|**T** = tensor(double), tensor(float), tensor(float16)| + |ParametricSoftplus|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +@@ -744,9 +743,7 @@ Do not modify directly.* + |||[9, 10]|**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||8|**I** = tensor(int64)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |Scatter|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|[9, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +-|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +-|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +-|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| ++|ScatterElements|*in* data:**T**
*in* indices:**Tind**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| + |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| + |ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +@@ -879,7 +876,7 @@ Do not modify directly.* + |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| + |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| ++|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| + | | + | | + +@@ -925,12 +922,10 @@ Do not modify directly.* + |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| ++|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| ++|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| + |||6+|**T** = tensor(float), tensor(float16)| + |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)| +@@ -957,8 +952,7 @@ Do not modify directly.* + |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +-|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| ++|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| + |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| + |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +@@ -967,8 +961,7 @@ Do not modify directly.* + |DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| + |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| + |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| +-|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +-|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| ++|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| + |||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| + |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| + |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| +@@ -1011,8 +1004,7 @@ Do not modify directly.* + |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| + |||11+|**T** = tensor(float), tensor(float16)| + |||1+|**T** = tensor(float), tensor(float16)| +-|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| ++|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +@@ -1107,8 +1099,7 @@ Do not modify directly.* + |||7+|**T** = tensor(float), tensor(float16)| + |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| + |QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| +-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +-|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| ++|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| + |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| + |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| + |||7+|**T** = tensor(float), tensor(float16)| +@@ -1159,8 +1150,7 @@ Do not modify directly.* + |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| + |||13+|**T** = tensor(float), tensor(float16)| + |||6+|**T** = tensor(float), tensor(float16)| +-|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +-|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| ++|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +@@ -1188,8 +1178,7 @@ Do not modify directly.* + |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| + |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| + |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| +-|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +-|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| ++|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| + |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| + |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| + |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| +@@ -1199,8 +1188,7 @@ Do not modify directly.* + |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| + |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| + |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| +-|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +-|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| ++|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| + |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| + |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| + |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| +diff --git a/docs/python/README.rst b/docs/python/README.rst +index bbc8571fe..32bb3729e 100644 +--- a/docs/python/README.rst ++++ b/docs/python/README.rst +@@ -8,11 +8,6 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime (); ++ } + } + + /* +@@ -271,6 +274,19 @@ class IExecutionProvider { + return logger_; + } + ++ /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. ++ The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. ++ @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. ++ @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model. ++ This is created using the model path if available, ++ or the model input names and the output names from all nodes in the main graph. ++ @remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches ++ compiled kernels, so the name must be unique and deterministic across models and sessions. ++ NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and ++ virtual, and ModelMetadefIdGenerator but be defined in the header as well. ++ */ ++ virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; ++ + virtual std::unique_ptr GetProfiler() { + return {}; + } +@@ -324,5 +340,18 @@ class IExecutionProvider { + + // It will be set when this object is registered to a session + const logging::Logger* logger_ = nullptr; ++ ++ // helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across ++ // multiple sessions. ++ class ModelMetadefIdGenerator { ++ public: ++ int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash); ++ ++ private: ++ std::unordered_map main_graph_hash_; // map graph instance hash to model contents hash ++ std::unordered_map model_metadef_id_; // current unique id for model ++ }; ++ ++ std::unique_ptr metadef_id_generator_; + }; + } // namespace onnxruntime +diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h +index 1370f5c4c..9416fad5f 100644 +--- a/include/onnxruntime/core/providers/cuda/cuda_context.h ++++ b/include/onnxruntime/core/providers/cuda/cuda_context.h +@@ -16,10 +16,9 @@ + #include "core/providers/custom_op_context.h" + #include + #include +-#ifndef USE_CUDA_MINIMAL + #include + #include +-#endif ++ + namespace Ort { + + namespace Custom { +diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h +index 5577c840c..64095a31a 100644 +--- a/include/onnxruntime/core/session/onnxruntime_c_api.h ++++ b/include/onnxruntime/core/session/onnxruntime_c_api.h +@@ -38,7 +38,7 @@ + * + * This value is used by some API functions to behave as this version of the header expects. 
+ */ +-#define ORT_API_VERSION 18 ++#define ORT_API_VERSION 17 + + #ifdef __cplusplus + extern "C" { +@@ -496,7 +496,6 @@ typedef struct OrtROCMProviderOptions { + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, +- enable_hip_graph{false}, + tunable_op_enable{false}, + tunable_op_tuning_enable{false}, + tunable_op_max_tuning_duration_ms{} {} +@@ -549,8 +548,6 @@ typedef struct OrtROCMProviderOptions { + */ + OrtArenaCfg* default_memory_arena_cfg; + +- int enable_hip_graph; +- + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. +@@ -4569,23 +4566,6 @@ struct OrtApi { + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); +- +- /** \brief Append VitisAI provider to session options +- * +- * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. +- * +- * \param[in] options +- * \param[in] provider_options_keys +- * \param[in] provider_options_values +- * \param[in] num_keys +- * +- * \snippet{doc} snippets.dox OrtStatus Return Value +- */ +- ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, +- _In_ OrtSessionOptions* options, +- _In_reads_(num_keys) const char* const* provider_options_keys, +- _In_reads_(num_keys) const char* const* provider_options_values, +- _In_ size_t num_keys); + }; + + /* +diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h +index ae4c4bef9..7a553f9f9 100644 +--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h ++++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h +@@ -901,9 +901,6 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { + SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); + + SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction +- +- ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI +- SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); + }; + } // namespace detail + +diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +index 23246adff..957e849cf 100644 +--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h ++++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +@@ -885,25 +885,6 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Ope + return *this; + } + +-template +-inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { +- auto num_entries = provider_options.size(); +- std::vector keys, values; +- if (num_entries > 0) { +- keys.reserve(num_entries); +- values.reserve(num_entries); +- +- for (const auto& entry : provider_options) { +- keys.push_back(entry.first.c_str()); +- values.push_back(entry.second.c_str()); +- } +- } +- +- ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); +- +- return *this; +-} +- + template + inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const 
ORTCHAR_T* library_name, + const CustomOpConfigs& custom_op_configs) { +diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +index 4a5e2b7ef..3a1c0d1bb 100644 +--- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c ++++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +@@ -8,7 +8,7 @@ + #include "onnxruntime/core/session/onnxruntime_c_api.h" + #include "OrtJniUtil.h" + #include "ai_onnxruntime_OrtSession_SessionOptions.h" +-#ifdef _WIN32 ++#ifdef WIN32 + #include + #else + #include +@@ -318,7 +318,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC + + // Iterate the handles, calling the appropriate close function + for (jint i = 0; i < numHandles; i++) { +-#ifdef _WIN32 ++#ifdef WIN32 + FreeLibrary((void*)handles[i]); + #else + dlclose((void*)handles[i]); +diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts +index 4f85c3b46..1221b52cd 100644 +--- a/js/common/lib/inference-session.ts ++++ b/js/common/lib/inference-session.ts +@@ -111,7 +111,7 @@ export declare namespace InferenceSession { + optimizedModelFilePath?: string; + + /** +- * Whether enable profiling. ++ * Wether enable profiling. + * + * This setting is a placeholder for a future use. + */ +@@ -154,12 +154,6 @@ export declare namespace InferenceSession { + */ + preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + +- /** +- * Whether enable graph capture. +- * This setting is available only in ONNXRuntime Web for WebGPU EP. +- */ +- enableGraphCapture?: boolean; +- + /** + * Store configurations for a session. See + * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ +diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts +index 40f970ddf..96c2361cc 100644 +--- a/js/common/lib/version.ts ++++ b/js/common/lib/version.ts +@@ -4,4 +4,4 @@ + // This file is generated by /js/scripts/update-version.ts + // Do not modify file content manually. + +-export const version = '1.18.0'; ++export const version = '1.17.0'; +diff --git a/js/common/package-lock.json b/js/common/package-lock.json +index a5ada877b..84f6dba83 100644 +--- a/js/common/package-lock.json ++++ b/js/common/package-lock.json +@@ -1,12 +1,12 @@ + { + "name": "onnxruntime-common", +- "version": "1.18.0", ++ "version": "1.17.0", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "name": "onnxruntime-common", +- "version": "1.18.0", ++ "version": "1.17.0", + "license": "MIT", + "devDependencies": { + "typedoc": "^0.23.22" +diff --git a/js/common/package.json b/js/common/package.json +index 64ab2736a..beab7d29b 100644 +--- a/js/common/package.json ++++ b/js/common/package.json +@@ -2,7 +2,7 @@ + "license": "MIT", + "type": "module", + "name": "onnxruntime-common", +- "version": "1.18.0", ++ "version": "1.17.0", + "repository": { + "url": "https://github.com/Microsoft/onnxruntime.git", + "type": "git" +diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts +index 40f970ddf..96c2361cc 100644 +--- a/js/node/lib/version.ts ++++ b/js/node/lib/version.ts +@@ -4,4 +4,4 @@ + // This file is generated by /js/scripts/update-version.ts + // Do not modify file content manually. 
+ +-export const version = '1.18.0'; ++export const version = '1.17.0'; +diff --git a/js/node/package-lock.json b/js/node/package-lock.json +index 2d7c39c86..542eebe74 100644 +--- a/js/node/package-lock.json ++++ b/js/node/package-lock.json +@@ -1,12 +1,12 @@ + { + "name": "onnxruntime-node", +- "version": "1.18.0", ++ "version": "1.17.0", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "name": "onnxruntime-node", +- "version": "1.18.0", ++ "version": "1.17.0", + "license": "MIT", + "os": [ + "win32", +@@ -27,7 +27,7 @@ + }, + "../common": { + "name": "onnxruntime-common", +- "version": "1.18.0", ++ "version": "1.17.0", + "license": "MIT", + "devDependencies": { + "typedoc": "^0.23.22" +diff --git a/js/node/package.json b/js/node/package.json +index 026840742..8e591d8f4 100644 +--- a/js/node/package.json ++++ b/js/node/package.json +@@ -13,7 +13,7 @@ + 3 + ] + }, +- "version": "1.18.0", ++ "version": "1.17.0", + "dependencies": { + "onnxruntime-common": "file:../common" + }, +diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts +index 40f970ddf..96c2361cc 100644 +--- a/js/react_native/lib/version.ts ++++ b/js/react_native/lib/version.ts +@@ -4,4 +4,4 @@ + // This file is generated by /js/scripts/update-version.ts + // Do not modify file content manually. + +-export const version = '1.18.0'; ++export const version = '1.17.0'; +diff --git a/js/react_native/package.json b/js/react_native/package.json +index 47324a76f..39e6cb08b 100644 +--- a/js/react_native/package.json ++++ b/js/react_native/package.json +@@ -36,7 +36,7 @@ + "registry": "https://registry.npmjs.org/" + }, + "source": "lib/index", +- "version": "1.18.0", ++ "version": "1.17.0", + "main": "dist/commonjs/index", + "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", + "files": [ +diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock +index 4dca90d74..ff9be7fbe 100644 +--- a/js/react_native/yarn.lock ++++ b/js/react_native/yarn.lock +@@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2: + mimic-fn "^2.1.0" + + "onnxruntime-common@file:../common": +- version "1.18.0" ++ version "1.17.0" + + open@^6.2.0: + version "6.4.0" +diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md +index 2557971eb..2f510308d 100644 +--- a/js/web/docs/webgpu-operators.md ++++ b/js/web/docs/webgpu-operators.md +@@ -52,7 +52,6 @@ Do not modify directly.* + | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | + | Greater | ai.onnx(7-8,9-12,13+) | | + | GreaterOrEqual | ai.onnx(12-15,16+) | | +-| HardSigmoid | ai.onnx(6+) | | + | If | ai.onnx(1-10,11-12,13-18,19+) | | + | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | + | LayerNormalization | ai.onnx(17+) | | +diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts +index 2c9cd88a3..b3868871a 100644 +--- a/js/web/lib/build-def.d.ts ++++ b/js/web/lib/build-def.d.ts +@@ -21,6 +21,10 @@ interface BuildDefinitions { + /** + * defines whether to disable the whole WebNN backend in the build. + */ ++ readonly DISABLE_WEBNN: boolean; ++ /** ++ * defines whether to disable the whole WebAssembly backend in the build. ++ */ + readonly DISABLE_WASM: boolean; + /** + * defines whether to disable proxy feature in WebAssembly backend in the build. 
+diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts +index b212c0f49..baf45e74a 100644 +--- a/js/web/lib/index.ts ++++ b/js/web/lib/index.ts +@@ -23,10 +23,12 @@ if (!BUILD_DEFS.DISABLE_WASM) { + require('./backend-wasm-training').wasmBackend; + if (!BUILD_DEFS.DISABLE_WEBGPU) { + registerBackend('webgpu', wasmBackend, 5); +- registerBackend('webnn', wasmBackend, 5); + } + registerBackend('cpu', wasmBackend, 10); + registerBackend('wasm', wasmBackend, 10); ++ if (!BUILD_DEFS.DISABLE_WEBNN) { ++ registerBackend('webnn', wasmBackend, 9); ++ } + } + + Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); +diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts +index 40f970ddf..96c2361cc 100644 +--- a/js/web/lib/version.ts ++++ b/js/web/lib/version.ts +@@ -4,4 +4,4 @@ + // This file is generated by /js/scripts/update-version.ts + // Do not modify file content manually. + +-export const version = '1.18.0'; ++export const version = '1.17.0'; +diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts +index 5dd715191..9d4d58753 100644 +--- a/js/web/lib/wasm/binding/ort-wasm.d.ts ++++ b/js/web/lib/wasm/binding/ort-wasm.d.ts +@@ -13,9 +13,6 @@ export declare namespace JSEP { + type ReleaseKernelFunction = (kernel: number) => void; + type RunFunction = + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; +- type CaptureBeginFunction = () => void; +- type CaptureEndFunction = () => void; +- type ReplayFunction = () => void; + } + + export interface OrtWasmModule extends EmscriptenModule { +@@ -34,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule { + + _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; + +- _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; ++ _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtReleaseSession(sessionHandle: number): void; + _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; + _OrtGetInputName(sessionHandle: number, index: number): number; +@@ -131,8 +128,7 @@ export interface OrtWasmModule extends EmscriptenModule { + jsepInit? + (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, + download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, +- releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, +- captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; ++ releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). +@@ -162,6 +158,12 @@ export interface OrtWasmModule extends EmscriptenModule { + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; ++ /** ++ * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. ++ * ++ * @param sessionId - specify the session ID. ++ */ ++ jsepUnregisterBuffers?: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. 
+ * +@@ -180,19 +182,6 @@ export interface OrtWasmModule extends EmscriptenModule { + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; +- /** +- * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before +- * _OrtRun[WithBinding]() is called. +- * @param sessionId - specify the session ID. +- */ +- jsepOnRunStart: (sessionId: number) => void; +- /** +- * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is +- * called. +- * @param sessionId - specify the session ID. +- * @returns +- */ +- jsepOnReleaseSession: (sessionId: number) => void; + // #endregion + } + +diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts +index 98990a6fe..2956ec1ca 100644 +--- a/js/web/lib/wasm/jsep/backend-webgpu.ts ++++ b/js/web/lib/wasm/jsep/backend-webgpu.ts +@@ -3,21 +3,14 @@ + + import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; + +-import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; ++import {tensorDataTypeEnumToString} from '../wasm-common'; + + import {configureLogger, LOG_DEBUG} from './log'; + import {createView, TensorView} from './tensor-view'; + import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; + import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; + import {ProgramManager} from './webgpu/program-manager'; +-import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; +- +-interface CommandInfo { +- readonly kernelId: number; +- readonly computePipeline: GPUComputePipeline; +- readonly bindGroup: GPUBindGroup; +- readonly dispatchGroup: [number, number, number]; +-} ++import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types'; + + interface KernelInfo { + readonly kernelType: string; +@@ -110,13 +103,6 @@ export class WebGpuBackend { + */ + programManager: ProgramManager; + +- /** +- * representing the session ID of which is currently being run. +- * `null` means no session is being run. +- * only valid when session.run is executed. +- */ +- currentSessionId: number|null = null; +- + /** + * representing the kernel ID of which is currently being computed (CPU code perspective). + * `null` means no kernel is being computed. +@@ -169,16 +155,6 @@ export class WebGpuBackend { + queryType: TimestampQuery; + + env: Env; +- sessionStatus: SessionState = 'default'; +- /** +- * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. +- */ +- capturedCommandList: Map = new Map(); +- +- /** +- * a SessionID -> PendingKernelInfo[] mapping for profiling. +- */ +- private capturedPendingKernels: Map = new Map(); + + /** + * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. 
+@@ -232,7 +208,7 @@ export class WebGpuBackend { + + Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); + +- // init queryType, which is necessary for InferenceSession.create ++ // init queryType, which is necessary for createKernel + this.setQueryType(); + } + +@@ -246,13 +222,24 @@ export class WebGpuBackend { + getCommandEncoder(): GPUCommandEncoder { + if (!this.commandEncoder) { + this.commandEncoder = this.device.createCommandEncoder(); ++ ++ // refresh queryType, as sometimes we only need to enable query for a specific run ++ this.setQueryType(); ++ if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { ++ this.querySet = this.device.createQuerySet({ ++ type: 'timestamp', ++ count: this.maxDispatchNumber * 2, ++ }); ++ this.queryResolveBuffer = this.device.createBuffer( ++ // eslint-disable-next-line no-bitwise ++ {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); ++ } + } + return this.commandEncoder; + } + + getComputePassEncoder(): GPUComputePassEncoder { + if (!this.computePassEncoder) { +- const commandEncoder = this.getCommandEncoder(); + const computePassDescriptor: GPUComputePassDescriptor = {}; + + if (this.queryType === 'at-passes') { +@@ -263,7 +250,7 @@ export class WebGpuBackend { + }; + } + +- this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); ++ this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + } + return this.computePassEncoder; + } +@@ -453,26 +440,13 @@ export class WebGpuBackend { + return; + } + // https://www.w3.org/TR/WGSL/#alignof +- const sizeOfElement = v.type === DataType.float16 ? 2 : 4; +- let sizeOfVecOrMat; +- let baseAlignment; +- if (v.type === DataType.float16) { +- baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); +- sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; +- } else { +- baseAlignment = data.length <= 2 ? data.length * sizeOfElement : 16; +- sizeOfVecOrMat = 16; +- } ++ const baseAlignment = data.length <= 2 ? data.length * 4 : 16; + currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment; + offsets.push(currentOffset); +- // For non-float16 type, when data.length > 4, the uniform variable is of type array,N>, where +- // N = Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * +- // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type +- // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte +- // length is N * SizeOf(mat2x4). +- const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; +- currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : +- data.length * sizeOfElement; ++ // When data.length > 4, the uniform variable is of type array,N>, where N = ++ // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * ++ // SizeOf(vec4). ++ currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4; + }); + + // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set +@@ -483,17 +457,12 @@ export class WebGpuBackend { + programUniforms.forEach((v, i) => { + const offset = offsets[i]; + const data = typeof v.data === 'number' ? 
[v.data] : v.data; +- if (v.type === DataType.int32) { ++ if (v.type === 'int32') { + new Int32Array(arrayBuffer, offset, data.length).set(data); +- } else if (v.type === DataType.uint32) { ++ } else if (v.type === 'uint32') { + new Uint32Array(arrayBuffer, offset, data.length).set(data); +- } else if (v.type === DataType.float16) { +- // TODO: use Float16Array. +- new Uint16Array(arrayBuffer, offset, data.length).set(data); +- } else if (v.type === DataType.float) { +- new Float32Array(arrayBuffer, offset, data.length).set(data); + } else { +- throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`); ++ new Float32Array(arrayBuffer, offset, data.length).set(data); + } + }); + +@@ -521,7 +490,7 @@ export class WebGpuBackend { + () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ + normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); + +- if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { ++ if (this.queryType !== 'none') { + const pendingKernelInfo: PendingKernelInfo = { + kernelId: this.currentKernelId!, + programName: artifact.programInfo.name, +@@ -529,11 +498,6 @@ export class WebGpuBackend { + outputTensorViews, + }; + this.pendingKernels.push(pendingKernelInfo); +- +- if (this.sessionStatus === 'capturing') { +- const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); +- sessionPendingKernels!.push(pendingKernelInfo); +- } + } + + this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); +@@ -675,7 +639,6 @@ export class WebGpuBackend { + return createView(data.buffer, type); + }; + } +- // #endregion + writeTimestamp(index: number): void { + if (this.queryType !== 'inside-passes') { + return; +@@ -692,81 +655,7 @@ export class WebGpuBackend { + } else if (this.device.features.has('timestamp-query')) { + this.queryType = 'at-passes'; + } +- +- if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { +- this.querySet = this.device.createQuerySet({ +- type: 'timestamp', +- count: this.maxDispatchNumber * 2, +- }); +- this.queryResolveBuffer = this.device.createBuffer( +- // eslint-disable-next-line no-bitwise +- {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); +- } +- } +- } +- +- captureBegin(): void { +- LOG_DEBUG('info', 'captureBegin'); +- if (!this.capturedCommandList.get(this.currentSessionId!)) { +- this.capturedCommandList.set(this.currentSessionId!, []); +- } +- if (!this.capturedPendingKernels.get(this.currentSessionId!)) { +- this.capturedPendingKernels.set(this.currentSessionId!, []); + } +- // flush the left commands before we change the status. +- this.flush(); +- this.sessionStatus = 'capturing'; +- } +- captureEnd(): void { +- LOG_DEBUG('info', 'captureEnd'); +- // flush the left commands before we change the status. 
+- this.flush(); +- this.sessionStatus = 'default'; +- } +- replay(): void { +- LOG_DEBUG('info', 'replay'); +- this.sessionStatus = 'replaying'; +- const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); +- const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); +- const length = sessionCommandList!.length; +- this.pendingKernels = []; +- for (let i = 0; i < length; i++) { +- const computePassEncoder = this.getComputePassEncoder(); +- const command = sessionCommandList![i]; +- this.writeTimestamp(this.pendingDispatchNumber * 2); +- computePassEncoder.setPipeline(command.computePipeline); +- computePassEncoder.setBindGroup(0, command.bindGroup); +- computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); +- this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); +- this.pendingDispatchNumber++; +- if (this.queryType !== 'none') { +- this.pendingKernels.push(sessionPendingKernels![i]); +- } +- if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { +- this.endComputePass(); +- } +- if (this.pendingDispatchNumber >= this.maxDispatchNumber) { +- this.flush(); +- } +- } +- // flush the left commands before we change the status. +- this.flush(); +- this.sessionStatus = 'default'; +- } +- +- onReleaseSession(sessionId: number): void { +- this.unregisterBuffers(sessionId); +- if (this.capturedCommandList.has(sessionId)) { +- this.capturedCommandList.delete(sessionId); +- } +- if (this.capturedPendingKernels.has(sessionId)) { +- this.capturedPendingKernels.delete(sessionId); +- } +- this.gpuDataManager.onReleaseSession(sessionId); +- } +- +- onRunStart(sessionId: number): void { +- this.currentSessionId = sessionId; +- this.setQueryType(); + } ++ // #endregion + } +diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts +index 786ae4164..f1794d715 100644 +--- a/js/web/lib/wasm/jsep/init.ts ++++ b/js/web/lib/wasm/jsep/init.ts +@@ -201,11 +201,5 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte + contextDataOffset}`); + const context = new ComputeContextImpl(module, backend, contextDataOffset); + return backend.computeKernel(kernel, context, errors); +- }, +- // jsepCaptureBegin +- () => backend.captureBegin(), +- // jsepCaptureEnd +- () => backend.captureEnd(), +- // jsepReplay +- () => backend.replay()); ++ }); + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +index c17bd1e14..6f3d9a52d 100644 +--- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts ++++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +@@ -60,15 +60,9 @@ export interface GpuDataManager { + unregisterExternalBuffer(buffer: GPUBuffer): void; + + /** +- * destroy all gpu buffers. ++ * destroy all gpu buffers. Call this when the session.release is called. + */ + dispose(): void; +- +- /** +- * release session related data. +- * @param sessionId - specify the session ID. +- */ +- onReleaseSession(sessionId: number): void; + } + + interface StorageCacheValue { +@@ -145,10 +139,6 @@ class GpuDataManagerImpl implements GpuDataManager { + // The external buffers registered users for IO Binding. + private externalBuffers: Map; + +- // The pendingBuffers for capture graph. +- // a SessionID -> GPUBuffer[] mapping. 
+- private capturedPendingBuffers: Map; +- + constructor(private backend: WebGpuBackend) { + this.storageCache = new Map(); + this.freeBuffers = new Map(); +@@ -156,7 +146,6 @@ class GpuDataManagerImpl implements GpuDataManager { + this.buffersForUploadingPending = []; + this.buffersPending = []; + this.externalBuffers = new Map(); +- this.capturedPendingBuffers = new Map(); + } + + upload(id: GpuDataId, data: Uint8Array): void { +@@ -231,9 +220,6 @@ class GpuDataManagerImpl implements GpuDataManager { + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id}, buffer is the same, skip.`); + return id; +- } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { +- throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. +- Please use the previous external buffer!`); + } + this.externalBuffers.delete(previousBuffer); + } else { +@@ -326,39 +312,20 @@ class GpuDataManagerImpl implements GpuDataManager { + buffer.destroy(); + } + this.buffersForUploadingPending = []; +- +- if (this.buffersPending.length === 0) { +- return; +- } +- +- if (this.backend.sessionStatus === 'default') { +- for (const buffer of this.buffersPending) { ++ for (const buffer of this.buffersPending) { ++ // eslint-disable-next-line no-bitwise ++ if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { ++ // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. ++ this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise +- if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { +- // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. +- this.freeBuffers.get(buffer.size)!.push(buffer); +- // eslint-disable-next-line no-bitwise +- } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { +- // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. +- this.freeUniformBuffers.get(buffer.size)!.push(buffer); +- } else { +- buffer.destroy(); +- } +- } +- this.buffersPending = []; +- } else { +- // Don't release intermediate tensors in non-default mode. +- // TODO: reuse the storage buffers in non-default mode. +- let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); +- if (!capturedBuffers) { +- capturedBuffers = []; +- this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); +- } +- for (const buffer of this.buffersPending) { +- capturedBuffers.push(buffer); ++ } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { ++ // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. ++ this.freeUniformBuffers.get(buffer.size)!.push(buffer); ++ } else { ++ buffer.destroy(); + } +- this.buffersPending = []; + } ++ this.buffersPending = []; + } + + dispose() { +@@ -377,26 +344,9 @@ class GpuDataManagerImpl implements GpuDataManager { + storage.gpuData.buffer.destroy(); + }); + +- this.capturedPendingBuffers.forEach((buffers) => { +- buffers.forEach(buffer => { +- buffer.destroy(); +- }); +- }); + this.storageCache = new Map(); + this.freeBuffers = new Map(); + this.freeUniformBuffers = new Map(); +- this.capturedPendingBuffers = new Map(); +- } +- +- onReleaseSession(sessionId: number) { +- // release the captured pending buffers. 
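
The buffer-flush logic restored in the gpu-data-manager hunks here recycles finished buffers into size-keyed free lists chosen by their usage bits, and the capture-mode branch that instead deferred them per session is removed. A minimal sketch of that usage-flag dispatch, with hard-coded flag values standing in for the real GPUBufferUsage constants:

// Editorial sketch, not part of the patch: usage-flag based buffer recycling.
// Flag values mirror the WebGPU spec but are hard-coded here as an assumption.
const STORAGE = 0x0080;
const UNIFORM = 0x0040;

interface PooledBuffer { size: number; usage: number; destroy(): void }

const freeStorageBuffers = new Map<number, PooledBuffer[]>();
const freeUniformBuffers = new Map<number, PooledBuffer[]>();

const pushToPool = (pool: Map<number, PooledBuffer[]>, buffer: PooledBuffer): void => {
  const list = pool.get(buffer.size) ?? [];
  list.push(buffer);
  pool.set(buffer.size, list);
};

const recyclePending = (buffersPending: PooledBuffer[]): void => {
  for (const buffer of buffersPending) {
    // eslint-disable-next-line no-bitwise
    if ((buffer.usage & STORAGE) === STORAGE) {
      pushToPool(freeStorageBuffers, buffer);   // reuse storage buffers by size
      // eslint-disable-next-line no-bitwise
    } else if ((buffer.usage & UNIFORM) === UNIFORM) {
      pushToPool(freeUniformBuffers, buffer);   // reuse uniform buffers by size
    } else {
      buffer.destroy();                         // everything else is destroyed
    }
  }
};

Under graph capture (the branch being deleted above), these buffers instead stayed alive in capturedPendingBuffers until onReleaseSession, since a later replay may still reference them.
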
+- const pendingBuffers = this.capturedPendingBuffers.get(sessionId); +- if (pendingBuffers) { +- pendingBuffers.forEach(buffer => { +- buffer.destroy(); +- }); +- this.capturedPendingBuffers.delete(sessionId); +- } + } + } + +diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +index d737a2865..90e02da98 100644 +--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts ++++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +@@ -25,7 +25,7 @@ import * as pool from './ops/pool'; + import {range} from './ops/range'; + import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; + import {parseResizeAttributes, resize} from './ops/resize'; +-import {skipLayerNorm} from './ops/skip-layer-norm'; ++import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; + import {parseSliceAttributes, slice} from './ops/slice'; + import {parseSoftmaxAttributes, softmax} from './ops/softmax'; + import {parseSplitAttributes, split} from './ops/split'; +@@ -82,7 +82,6 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new + ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], + ['Greater', [binaryOps.greater]], + ['GreaterOrEqual', [binaryOps.greaterOrEqual]], +- ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], + ['InstanceNormalization', [instanceNorm]], + ['LayerNormalization', [layerNorm]], + ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], +@@ -116,7 +115,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new + ['Sin', [unaryOps.sin]], + ['Sinh', [unaryOps.sinh]], + ['Slice', [slice, parseSliceAttributes]], +- ['SkipLayerNormalization', [skipLayerNorm]], ++ ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]], + ['Split', [split, parseSplitAttributes]], + ['Sqrt', [unaryOps.sqrt]], + ['Softmax', [softmax, parseSoftmaxAttributes]], +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +index 24006d393..3638938df 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +@@ -19,13 +19,12 @@ + // + // modified to fit the needs of the project + +-import {DataType} from '../../../../wasm-common'; + import {LOG_DEBUG} from '../../../log'; + import {TensorView} from '../../../tensor-view'; +-import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; ++import {ProgramInfo, ProgramUniform} from '../../types'; ++import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; + import {ConvAttributes} from '../conv'; +-import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; ++import {getActivationSnippet} from '../fuse-utils'; + + import {biasSnippet, typeSnippet} from './activation_util'; + import {utilFunctions} from './conv_util'; +@@ -89,10 +88,10 @@ const conv2dCommonSnippet = + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + +- let WRow = ${col} / (i32(uniforms.w_shape[1]) * inChannels); +- let WCol = ${col} / inChannels % i32(uniforms.w_shape[1]); +- let xRow = outRow * 
uniforms.stride[0] + uniforms.dilation[0] * WRow - uniforms.pad[0]; +- let xCol = outCol * uniforms.stride[1] + uniforms.dilation[1] * WCol - uniforms.pad[1]; ++ let WRow = ${col} / (filterDims[1] * inChannels); ++ let WCol = ${col} / inChannels % filterDims[1]; ++ let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; ++ let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; + let xCh = ${col} % inChannels; + var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); + // The bounds checking is always needed since we use it to pad zero for +@@ -109,7 +108,7 @@ const conv2dCommonSnippet = + ${readXSnippet}` : + ` + let col = colIn * ${innerElementSizeX}; +- if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ++ if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + ${readXSnippet} + } + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : +@@ -118,7 +117,7 @@ const conv2dCommonSnippet = + ${readXSnippet}` : + ` + let col = colIn * ${innerElementSizeX}; +- if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ++ if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + ${readXSnippet} + } + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); +@@ -130,8 +129,9 @@ const conv2dCommonSnippet = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); +- const applyActivation = getActivationSnippet(attributes, resType, dataType); ++ const {activationFunction, applyActivation} = getActivationSnippet(attributes, resType); + const userCode = ` ++ ${activationFunction} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { + ${isChannelsLast ? sampleX : sampleW} + } +@@ -142,7 +142,7 @@ const conv2dCommonSnippet = + + fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) { + let col = colIn * ${innerElementSize}; +- if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) ++ if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) + { + var value = valueIn; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; +@@ -181,40 +181,31 @@ export const createConv2DMatMulProgramInfo = + LOG_DEBUG('verbose', () => `[conv2d_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : 1; ++ + const tileAOuter = workGroupSize[1] * elementsPerThread[1]; + const tileBOuter = workGroupSize[0] * elementsPerThread[0]; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); ++ + const fitAOuter = dimAOuter % tileAOuter === 0; + const fitBOuter = dimBOuter % tileBOuter === 0; + const fitInner = dimInner % tileInner === 0; ++ + const elementsSize = isVec4 ? 
[innerElementSize, 4, 4] : [1, 1, 1]; ++ const t = tensorTypeToWsglStorageType(inputs[0].dataType); + +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, +- {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, +- {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} +- ]; +- appendActivationUniformsData(attributes, programUniforms); +- programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); +- const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; +- if (hasBias) { +- programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); +- inputDependencies.push('rank'); +- } +- programUniforms.push(...createTensorShapeVariables(outputShape)); ++ // TODO: support component 2, 3. ++ const components = isVec4 ? 4 : 1; ++ const programUniforms: ProgramUniform[] = ++ [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; ++ const x = ++ inputVariable('x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); ++ const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); ++ const inputVariables = [x, w]; + +- const getShaderSource = (shaderHelper: ShaderHelper) => { +- const uniforms: UniformsArrayType = [ +- {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, +- {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, +- {name: 'dilation', type: 'i32', length: 2} +- ]; +- appendActivationUniforms(attributes, uniforms); ++ programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); ++ programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + +- // TODO: support component 2, 3. +- const components = isVec4 ? 4 : 1; +- const t = tensorTypeToWsglStorageType(inputs[0].dataType); +- let declareFunctions = ` ++ let declareFunctions = ` + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { + result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); + } +@@ -222,50 +213,51 @@ export const createConv2DMatMulProgramInfo = + let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); + setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); + }`; +- const x = inputVariable( +- 'x', inputs[0].dataType, inputs[0].dims.length, innerElementSize === 3 ? 1 : innerElementSize); +- const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, components); +- const inputVariables = [x, w]; +- const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); +- if (hasBias) { +- const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); +- inputVariables.push(bias); +- declareFunctions += ` ++ if (hasBias) { ++ const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); ++ inputVariables.push(bias); ++ ++ programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); ++ ++ declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? 
'/ 4' : ''}]; + }`; +- } +- +- return ` ++ } ++ const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); ++ programUniforms.push(...createTensorShapeVariables(outputShape)); ++ return { ++ name: 'Conv2DMatMul', ++ shaderCache: {hint: attributes.cacheKey}, ++ getRunData: () => ({ ++ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], ++ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, ++ programUniforms, ++ }), ++ getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} + //struct Uniforms { xShape : vec4, wShape : vec4, outShape : vec4, + // outShapeStrides: vec3, filterDims : vec2, pad : vec2, stride : vec2, + // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; +- ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ++ ${ ++ shaderHelper.registerUniform('dimAOuter', 'i32') ++ .registerUniform('dimBOuter', 'i32') ++ .registerUniform('dimInner', 'i32') ++ .declareVariables(...inputVariables, output)} ++ const filterDims : vec2 = vec2(${attributes.kernelShape[0]}, ${attributes.kernelShape[1]}); ++ const pad : vec2 = vec2(${attributes.pads[0]}, ${attributes.pads[1]}); ++ const stride : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); ++ const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + ${declareFunctions} + ${ + conv2dCommonSnippet( + isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, attributes, elementsSize[0], elementsSize[1], + elementsSize[2], t)} +- ${ ++ ${ + isVec4 ? + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, +- sequentialAccessByThreads)}`; +- }; +- return { +- name: 'Conv2DMatMul', +- shaderCache: { +- hint: `${attributes.cacheKey};${innerElementSize};${isVec4};${fitAOuter};${fitBOuter};${fitInner};${ +- tileAOuter};${tileBOuter};${tileInner}`, +- inputDependencies +- }, +- getRunData: () => ({ +- outputs: [{dims: outputShape, dataType: inputs[0].dataType}], +- dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, +- programUniforms, +- }), +- getShaderSource ++ sequentialAccessByThreads)}` + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +index b5b6a2a15..d42515585 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +@@ -19,13 +19,12 @@ + // + // modified to fit the needs of the project + +-import {DataType} from '../../../../wasm-common'; + import {LOG_DEBUG} from '../../../log'; + import {TensorView} from '../../../tensor-view'; +-import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; ++import {ProgramInfo, ProgramUniform} from '../../types'; ++import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from '../common'; + import {ConvTransposeAttributes} from '../conv-transpose'; +-import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; ++import {getActivationSnippet} from '../fuse-utils'; + + import {biasSnippet, typeSnippet} from 
'./activation_util'; + import {utilFunctions} from './conv_util'; +@@ -75,21 +74,21 @@ const conv2dTransposeCommonSnippet = + col % outWidth); + `; + +- const xHeight = isChannelsLast ? 'i32(uniforms.x_shape[1])' : 'i32(uniforms.x_shape[2])'; +- const xWidth = isChannelsLast ? 'i32(uniforms.x_shape[2])' : 'i32(uniforms.x_shape[3])'; ++ const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; ++ const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + + const readASnippet = ` +- let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; ++ let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + +- let WRow = ${col} / (uniforms.filter_dims[1] * inChannels); +- let WCol = ${col} / inChannels % uniforms.filter_dims[1]; +- let xR = f32(outRow - uniforms.pads[0] + uniforms.dilations[0] * WRow) / f32(uniforms.strides[0]); +- let xC = f32(outCol - uniforms.pads[1] + uniforms.dilations[1] * WCol) / f32(uniforms.strides[1]); ++ let WRow = ${col} / (filterDims[1] * inChannels); ++ let WCol = ${col} / inChannels % filterDims[1]; ++ let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); ++ let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { + return ${type}(0.0); + } +@@ -104,25 +103,25 @@ const conv2dTransposeCommonSnippet = + + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; +- if (row < uniforms.dim_a_outer && col < uniforms.dim_inner) { ++ if (row < uniforms.dimAOuter && col < uniforms.dimInner) { + ${readASnippet} + } + return ${type}(0.0);` : + ` + let col = colIn * ${innerElementSize}; +- if (row < uniforms.dim_inner && col < uniforms.dim_b_outer) { ++ if (row < uniforms.dimInner && col < uniforms.dimBOuter) { + ${readASnippet} + } + return ${type}(0.0);`; + + const sampleW = ` + let col = colIn * ${innerElementSize}; +- let inChannels = ${isChannelsLast ? 'i32(uniforms.x_shape[3])' : 'i32(uniforms.x_shape[1])'}; +- let coordX = uniforms.filter_dims[0] - 1 - row / (uniforms.filter_dims[1] * inChannels); +- let coordY = uniforms.filter_dims[1] - 1 - (row / inChannels) % uniforms.filter_dims[1]; ++ let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; ++ let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); ++ let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ +- isChannelsLast ? 'row < uniforms.dim_inner && col < uniforms.dim_b_outer' : +- 'row < uniforms.dim_inner && col < uniforms.dim_a_outer'} && coordX >= 0 && coordY >= 0) { ++ isChannelsLast ? 
'row < uniforms.dimInner && col < uniforms.dimBOuter' : ++ 'row < uniforms.dimInner && col < uniforms.dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} +@@ -130,8 +129,9 @@ const conv2dTransposeCommonSnippet = + return ${type}(0.0); + `; + +- const applyActivation = getActivationSnippet(attributes, type); ++ const {activationFunction, applyActivation} = getActivationSnippet(attributes, type); + const userCode = ` ++ ${activationFunction} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleA : sampleW} + } +@@ -142,7 +142,7 @@ const conv2dTransposeCommonSnippet = + + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { + let col = colIn * ${innerElementSize}; +- if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { ++ if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + var value = valueInput; + let outWidth = ${isChannelsLast ? 'i32(uniforms.result_shape[2])' : 'i32(uniforms.result_shape[3])'}; + ${coordResSnippet} +@@ -186,59 +186,65 @@ export const createConv2DTransposeMatMulProgramInfo = + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + const components = isVec4 ? 4 : 1; +- const filterDims = +- [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; +- const effectiveFilterDims = [ +- filterDims[0] + (attributes.dilations[0] <= 1 ? 0 : (filterDims[0] - 1) * (attributes.dilations[0] - 1)), +- filterDims[1] + (attributes.dilations[1] <= 1 ? 0 : (filterDims[1] - 1) * (attributes.dilations[1] - 1)) +- ]; +- const pads = [ +- effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), +- effectiveFilterDims[1] - 1 - Math.floor((attributes.pads[1] + attributes.pads[3]) / 2) +- ]; +- +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, +- {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, +- {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, +- {type: DataType.int32, data: pads} +- ]; +- appendActivationUniformsData(attributes, programUniforms); +- programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); ++ const programUniforms: ProgramUniform[] = ++ [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; ++ const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); ++ const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); ++ const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); ++ const inputVariables = [x, w]; ++ programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); ++ programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); + +- const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; ++ let declareFunctions = ''; + if (hasBias) { ++ const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); ++ inputVariables.push(bias); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); +- inputDependencies.push('rank'); +- } +- programUniforms.push(...createTensorShapeVariables(outputShape)); + +- const getShaderSource = (shaderHelper: ShaderHelper) => 
{ +- const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length, components); +- const w = inputVariable('w', inputs[1].dataType, inputs[1].dims.length, 1); +- const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); +- const inputVariables = [x, w]; ++ declareFunctions += ` ++ fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { ++ return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; ++ }`; ++ } + +- let declareFunctions = ''; +- if (hasBias) { +- const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); +- inputVariables.push(bias); +- declareFunctions += ` +- fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { +- return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; +- }`; +- } ++ programUniforms.push(...createTensorShapeVariables(outputShape)); + +- const uniforms: UniformsArrayType = [ +- {name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}, +- {name: 'strides', type: 'i32', length: 2}, {name: 'dilations', type: 'i32', length: 2}, +- {name: 'filter_dims', type: 'i32', length: filterDims.length}, +- {name: 'pads', type: 'i32', length: pads.length} +- ]; +- appendActivationUniforms(attributes, uniforms); +- return ` ++ return { ++ name: 'Conv2DTransposeMatMul', ++ shaderCache: {hint: attributes.cacheKey}, ++ getRunData: () => ({ ++ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], ++ dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, ++ programUniforms ++ }), ++ getShaderSource: (shaderHelper: ShaderHelper) => ` + ${utilFunctions('uniforms.result_strides')} +- ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ++ ${ ++ shaderHelper.registerUniform('dimAOuter', 'i32') ++ .registerUniform('dimBOuter', 'i32') ++ .registerUniform('dimInner', 'i32') ++ .declareVariables(...inputVariables, output)}; ++ const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); ++ const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ ++ attributes.kernelShape[isChannelsLast ? 2 : 3]}); ++ const effectiveFilterDims : vec2 = filterDims + vec2( ++ ${ ++ attributes.dilations[0] <= 1 ? ++ 0 : ++ (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, ++ ${ ++ attributes.dilations[1] <= 1 ? ++ 0 : ++ (attributes.kernelShape[isChannelsLast ? 
2 : 3] - 1) * (attributes.dilations[1] - 1)}); ++ const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ ++ attributes.pads[0] + attributes.pads[2]})/2, ++ i32(effectiveFilterDims[1]) - 1 - (${ ++ attributes.pads[1] + attributes.pads[3]})/2); ++ const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); ++ const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); ++ const dimAOuter : i32 = ${dimAOuter}; ++ const dimBOuter : i32 = ${dimBOuter}; ++ const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${ +@@ -246,18 +252,6 @@ export const createConv2DTransposeMatMulProgramInfo = + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, +- undefined, sequentialAccessByThreads)}`; +- }; +- +- return { +- name: 'Conv2DTransposeMatMul', +- shaderCache: +- {hint: `${attributes.cacheKey};${elementsPerThread};${workGroupSize};${isVec4}`, inputDependencies}, +- getRunData: () => ({ +- outputs: [{dims: outputShape, dataType: inputs[0].dataType}], +- dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, +- programUniforms +- }), +- getShaderSource ++ undefined, sequentialAccessByThreads)}` + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +index 846ad49c5..50b0841a0 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +@@ -17,22 +17,27 @@ + + // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts + +-import {DataType} from '../../../../wasm-common'; + import {LOG_DEBUG} from '../../../log'; + import {TensorView} from '../../../tensor-view'; + import {ShapeUtil} from '../../../util'; +-import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; ++import {ProgramInfo} from '../../types'; ++import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; + import {ConvTransposeAttributes} from '../conv-transpose'; + + const createConvTranspose2DOpProgramShaderSource = +- (shaderHelper: ShaderHelper, inputs: readonly TensorView[], outputShape: readonly number[], hasBias: boolean, +- is1DimensionDispatch: boolean, isVec4 = false, dataType: string, uniforms: UniformsArrayType, +- isChannelsLast = false): string => { ++ (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, ++ outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, ++ dataType: string): string => { ++ const isChannelsLast = attributes.format === 'NHWC'; + const rowDim = isChannelsLast ? 1 : 2; + const colDim = isChannelsLast ? 2 : 3; + const channelDim = isChannelsLast ? 3 : 1; ++ const outputSize = ShapeUtil.size(outputShape); + const workPerThread = isVec4 ? 2 : 1; ++ const group = attributes.group; ++ const wShape = inputs[1].dims; ++ const inputChannelsPerGroup = wShape[0] / group; ++ const outputChannelsPerGroup = wShape[1]; + + let declareFunctions = ` + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 
`vec4<${dataType}>` : dataType}) { +@@ -45,21 +50,20 @@ const createConvTranspose2DOpProgramShaderSource = + }`; + } + const components = isVec4 ? 4 : 1; +- const w = inputVariable('W', inputs[1].dataType, inputs[1].dims.length, components); +- const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims.length, components); ++ const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); ++ const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const inputVariables = [dy, w]; + if (hasBias) { +- inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]].length, components)); ++ inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + } +- const output = outputVariable('result', inputs[0].dataType, outputShape.length, components); +- ++ const output = outputVariable('result', inputs[0].dataType, outputShape, components); + const codeSnippet4 = `{ +- let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / uniforms.result_shape[1]; +- let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % uniforms.result_shape[1]; ++ let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; ++ let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; + let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; + +- let dyCorner = vec2(i32(r), i32(c)) - vec2(uniforms.pads); ++ let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); + + // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). + // ? = to be determined. : = across all values in that axis. 
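
Across the surrounding conv_backprop hunks, shape- and attribute-dependent values (filterDims, pads, strides, dilations, outBackprop) move back from uniform reads into const declarations emitted directly into the WGSL source. A small sketch of the trade-off, with hypothetical attribute names: baking specializes each shader to one attribute set, while the uniform style being reverted keeps one shader text and supplies the values at dispatch time.

// Editorial sketch, not part of the patch: baked WGSL constants vs. uniform reads.
// Attribute names are hypothetical; the real helpers build much larger shaders.
interface ConvParams { strides: [number, number]; pads: [number, number]; dilations: [number, number] }

// Baked style: the parameter values are part of the shader text itself,
// so the pipeline cache effectively keys on them.
const bakedSnippet = (p: ConvParams): string => `
  const strides : vec2<u32> = vec2<u32>(${p.strides[0]}u, ${p.strides[1]}u);
  const pads : vec2<i32> = vec2<i32>(${p.pads[0]}, ${p.pads[1]});
  const dilations : vec2<u32> = vec2<u32>(${p.dilations[0]}u, ${p.dilations[1]}u);`;

// Uniform style: one shader text for all attribute sets, values read at run time.
const uniformSnippet = `
  let stride = uniforms.strides;
  let pad = uniforms.pads;
  let dilation = uniforms.dilations;`;

// The baked variant yields a distinct string (and therefore a distinct pipeline)
// per attribute set; the uniform variant is a single reusable string.
console.log(bakedSnippet({strides: [2, 2], pads: [1, 1], dilations: [1, 1]}) !== uniformSnippet);
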
+@@ -67,29 +71,29 @@ const createConvTranspose2DOpProgramShaderSource = + for (var i = 0; i < ${workPerThread}; i++) { + dotProd[i] = vec4<${dataType}>(0.0); + } +- for (var wR: u32 = 0; wR < uniforms.filter_dims[0]; wR = wR + 1) { +- var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(uniforms.strides.x); +- let wRPerm = uniforms.filter_dims[0] - 1 - wR; +- if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[1]) || ++ for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { ++ var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); ++ let wRPerm = filterDims[0] - 1 - wR; ++ if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || + fract(dyR) > 0.0 || wRPerm < 0) { + continue; + } + let idyR: u32 = u32(dyR); + +- for (var wC: u32 = 0; wC < uniforms.filter_dims[1]; wC = wC + 1) { +- let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); +- let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); +- let wCPerm = uniforms.filter_dims[1] - 1 - wC; ++ for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { ++ let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); ++ let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); ++ let wCPerm = filterDims[1] - 1 - wC; + if (wCPerm < 0) { + continue; + } + var bDyCVal = true; + var bDyCVal2 = true; +- if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[2]) || ++ if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || + fract(dyC) > 0.0) { + bDyCVal = false; + } +- if (dyC2 < 0.0 || dyC2 >= ${dataType}(uniforms.Dy_shape[2]) || ++ if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || + fract(dyC2) > 0.0) { + bDyCVal2 = false; + } +@@ -97,7 +101,7 @@ const createConvTranspose2DOpProgramShaderSource = + let idyC: u32 = u32(dyC); + let idyC2: u32 = u32(dyC2); + if (bDyCVal && bDyCVal2) { +- let d2Length = uniforms.Dy_shape[3]; ++ let d2Length = outBackprop[3]; + for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; +@@ -119,7 +123,7 @@ const createConvTranspose2DOpProgramShaderSource = + dot(xValue, wValue3)); + } + } else if (bDyCVal) { +- let d2Length = uniforms.Dy_shape[${channelDim}]; ++ let d2Length = outBackprop[${channelDim}]; + for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; +@@ -134,7 +138,7 @@ const createConvTranspose2DOpProgramShaderSource = + dotProd[0] = dotProd[0] + tmpval; + } + } else if (bDyCVal2) { +- let d2Length = uniforms.Dy_shape[3]; ++ let d2Length = outBackprop[3]; + for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; +@@ -163,39 +167,39 @@ const createConvTranspose2DOpProgramShaderSource = + let d1 = ${output.indicesGet('outputIndices', channelDim)}; + let r = ${output.indicesGet('outputIndices', rowDim)}; + let c = ${output.indicesGet('outputIndices', colDim)}; +- let dyCorner = vec2(i32(r), i32(c)) - uniforms.pads; ++ let dyCorner = vec2(i32(r), i32(c)) - pads; + let dyRCorner = dyCorner.x; + let dyCCorner = dyCorner.y; +- let groupId = d1 / uniforms.output_channels_per_group; +- let wOutChannel = d1 - groupId * 
uniforms.output_channels_per_group; ++ let groupId = d1 / ${outputChannelsPerGroup}; ++ let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; + // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). + // ? = to be determined. : = across all values in that axis. + var dotProd = ${dataType}(0.0); +- for (var wR: u32 = 0; wR < uniforms.effective_filter_dims.x; wR = wR + 1) { +- if (wR % uniforms.dilations.x != 0) { ++ for (var wR: u32 = 0; wR < effectiveFilterDims.x; wR = wR + 1) { ++ if (wR % dilations.x != 0) { + continue; + } +- let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(uniforms.strides[0]); +- let wRPerm = uniforms.filter_dims.x - 1 - wR / uniforms.dilations.x; +- if (dyR < 0.0 || dyR >= ${dataType}(uniforms.Dy_shape[${rowDim}]) || fract(dyR) > 0.0 || ++ let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); ++ let wRPerm = filterDims.x - 1 - wR / dilations.x; ++ if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + wRPerm < 0) { + continue; + } + let idyR: u32 = u32(dyR); + +- for (var wC: u32 = 0; wC < uniforms.effective_filter_dims.y; wC = wC + 1) { +- if (wC % uniforms.dilations.y != 0) { ++ for (var wC: u32 = 0; wC < effectiveFilterDims.y; wC = wC + 1) { ++ if (wC % dilations.y != 0) { + continue; + } +- let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(uniforms.strides.y); +- let wCPerm = uniforms.filter_dims.y - 1 - wC / uniforms.dilations.y; +- if (dyC < 0.0 || dyC >= ${dataType}(uniforms.Dy_shape[${colDim}]) || ++ let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); ++ let wCPerm = filterDims.y - 1 - wC / dilations.y; ++ if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || + fract(dyC) > 0.0 || wCPerm < 0) { + continue; + } + let idyC: u32 = u32(dyC); +- var inputChannel = groupId * uniforms.input_channels_per_group; +- for (var d2: u32 = 0; d2 < uniforms.input_channels_per_group; d2 = d2 + 1) { ++ var inputChannel = groupId * ${inputChannelsPerGroup}; ++ for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + let xValue = ${ + isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : + dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; +@@ -210,11 +214,27 @@ const createConvTranspose2DOpProgramShaderSource = + `; + + return ` +- ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} ++ ${shaderHelper.declareVariables(...inputVariables, output)} + ${declareFunctions} +- ++ const outShape : vec4 = vec4(${outputShape.join(',')}); ++ const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); ++ const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); ++ const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ ++ attributes.kernelShape[isChannelsLast ? 2 : 3]}); ++ const dilations : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); ++ const effectiveFilterDims : vec2 = filterDims + vec2( ++ ${ ++ attributes.dilations[0] <= 1 ? ++ 0 : ++ (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, ++ ${ ++ attributes.dilations[1] <= 1 ? ++ 0 : ++ (attributes.kernelShape[isChannelsLast ? 
2 : 3] - 1) * (attributes.dilations[1] - 1)}); ++ const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${attributes.pads[0] + attributes.pads[2]})/2, ++ i32(effectiveFilterDims[1]) - 1 - (${attributes.pads[1] + attributes.pads[3]})/2); + ${shaderHelper.mainStart()} +- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}; ++ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}; + ${isVec4 ? codeSnippet4 : codeSnippet}}`; + }; + +@@ -237,73 +257,19 @@ export const createConvTranspose2DProgramInfo = + ]; + LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + +- const isChannelsLast = attributes.format === 'NHWC'; +- const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; +- const strides = [attributes.strides[0], attributes.strides[1]]; +- const filterDims = +- [attributes.kernelShape[isChannelsLast ? 1 : 2], attributes.kernelShape[isChannelsLast ? 2 : 3]]; +- const dilations = [attributes.dilations[0], attributes.dilations[1]]; +- const effectiveFilterDims = [ +- filterDims[0] + +- (attributes.dilations[0] <= 1 ? +- 0 : +- (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)), +- filterDims[1] + +- (attributes.dilations[1] <= 1 ? +- 0 : +- (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)) +- ]; +- const pads = [ +- effectiveFilterDims[0] - 1 - Math.floor((attributes.pads[0] + attributes.pads[2]) / 2), +- effectiveFilterDims[1] - 1 - Math.floor(attributes.pads[1] + attributes.pads[3]) / 2 +- ]; +- +- const isVec4 = false; +- const group = attributes.group; +- const wShape = inputs[1].dims; +- const inputChannelsPerGroup = wShape[0] / group; +- const outputChannelsPerGroup = wShape[1]; +- +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides}, +- {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, +- {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, +- {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, +- ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims) +- ]; +- if (hasBias) { +- programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); +- inputDependencies.push('rank'); +- } +- programUniforms.push(...createTensorShapeVariables(outputShape)); +- +- const is1DimensionDispatch = dispatch[1] === 1 && dispatch[2] === 1; +- const getShaderSource = (shaderHelper: ShaderHelper) => { +- const uniforms: UniformsArrayType = [ +- {name: 'output_size', type: 'u32'}, {name: 'strides', type: 'u32', length: strides.length}, +- {name: 'filter_dims', type: 'u32', length: filterDims.length}, +- {name: 'dilations', type: 'u32', length: filterDims.length}, +- {name: 'effective_filter_dims', type: 'u32', length: effectiveFilterDims.length}, +- {name: 'pads', type: 'i32', length: pads.length}, {name: 'input_channels_per_group', type: 'u32'}, +- {name: 'output_channels_per_group', type: 'u32'} +- ]; +- const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); +- return `${ +- createConvTranspose2DOpProgramShaderSource( +- shaderHelper, inputs, outputShape, hasBias, is1DimensionDispatch, isVec4, dataType, uniforms, +- isChannelsLast)}`; +- }; ++ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return { + name: 'ConvTranspose2D', +- shaderCache: {hint: `${attributes.cacheKey};`, inputDependencies}, ++ 
shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + dispatchGroup: {x: dispatch[0], y: dispatch[1], z: dispatch[2]}, + outputs: [{ + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType +- }], +- programUniforms ++ }] + }), +- getShaderSource ++ getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( ++ shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, ++ dataType), + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +index 29c7941e6..47ec16a29 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +@@ -19,12 +19,11 @@ + // + // modified to fit the needs of the project + +-import {DataType} from '../../../../wasm-common'; + import {TensorView} from '../../../tensor-view'; + import {ShapeUtil} from '../../../util'; + import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; +-import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; +-import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; ++import {createTensorShapeVariables, enableShapesUniforms, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; ++import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; + + import {typeSnippet} from './activation_util'; + +@@ -113,14 +112,14 @@ fn main(@builtin(local_invocation_id) localId : vec3, + ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} + let globalRowStart = i32(workgroupId.y) * ${tileAOuter}; + +- let num_tiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; ++ let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; + + var acc: array, rowPerThread>; + + // Loop over shared dimension. + let tileRowB = localRow * ${rowPerThreadB}; +- for (var t = 0; t < num_tiles; t = t + 1) { ++ for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { + let inputRow = tileRow + innerRow; +@@ -205,7 +204,7 @@ export const makeMatMulPackedSource = + let globalColStart = i32(workgroupId.x) * ${tileBOuter}; + + // Loop over shared dimension. +- for (var t = 0; t < num_tiles; t = t + 1) { ++ for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var inputRow = localRow; inputRow < ${tileAHight}; inputRow = inputRow + ${workgroupSize[1]}) { + for (var inputCol = localCol; inputCol < ${tileAWidth}; inputCol = inputCol + ${workgroupSize[0]}) { +@@ -261,7 +260,7 @@ let tileRowA = i32(localId.y) * ${rowPerThreadA}; + let tileColA = i32(localId.x) * ${colPerThreadA}; + let tileRowB = i32(localId.y) * ${rowPerThreadB}; + // Loop over shared dimension. 
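
The matmul_packed hunks around this point rename the tile-count variable and compute it as (dimInner - 1) / tileInner + 1, i.e. the ceiling of dimInner / tileInner, then walk the shared dimension one tile at a time. A plain TypeScript analogue of that loop structure follows; the real kernel does this per workgroup with shared-memory tiles:

// Editorial sketch, not part of the patch: CPU analogue of the tiled matmul loop.
const matmulTiled = (a: number[][], b: number[][], tileInner = 4): number[][] => {
  const dimAOuter = a.length, dimInner = a[0].length, dimBOuter = b[0].length;
  const numTiles = Math.floor((dimInner - 1) / tileInner) + 1;  // ceil(dimInner / tileInner)
  const out = Array.from({length: dimAOuter}, () => new Array<number>(dimBOuter).fill(0));
  for (let t = 0; t < numTiles; t++) {
    const kStart = t * tileInner;
    const kEnd = Math.min(kStart + tileInner, dimInner);
    for (let i = 0; i < dimAOuter; i++) {
      for (let j = 0; j < dimBOuter; j++) {
        // Accumulate only the current tile of the shared dimension.
        for (let k = kStart; k < kEnd; k++) {
          out[i][j] += a[i][k] * b[k][j];
        }
      }
    }
  }
  return out;
};

// Usage: a 2x3 by 3x2 product; dimInner = 3 with tileInner = 2 gives numTiles = 2.
console.log(matmulTiled([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]], 2));
// -> [[58, 64], [139, 154]]
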
+-for (var t = 0; t < num_tiles; t = t + 1) { ++for (var t = 0; t < numTiles; t = t + 1) { + // Load one tile of A into local memory. + for (var innerRow = 0; innerRow < ${rowPerThreadA}; innerRow = innerRow + 1) { + for (var innerCol = 0; innerCol < ${colPerThreadA}; innerCol = innerCol + 1) { +@@ -323,8 +322,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, + @builtin(workgroup_id) workgroupId : vec3) { + let batch = ${splitK ? '0' : 'i32(globalId.z)'}; + ${batchDims ? `let batchIndices = ${batchDims.offsetToIndices('u32(batch)')};` : ''} +- let num_tiles = ${ +- splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dim_inner - 1) / tileInner + 1'}; ++ let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(uniforms.dimInner - 1) / tileInner + 1'}; + var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; + + var acc : array, rowPerThread>; +@@ -381,7 +379,7 @@ const matMulReadWriteFnSource = + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); + let col = colIn * ${component}; +- if(row < uniforms.dim_a_outer && col < uniforms.dim_inner) ++ if(row < uniforms.dimAOuter && col < uniforms.dimInner) + { + ${getAIndices()} + value = ${aVariable.getByIndices('aIndices')}; +@@ -393,7 +391,7 @@ const matMulReadWriteFnSource = + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); + let col = colIn * ${component}; +- if(row < uniforms.dim_inner && col < uniforms.dim_b_outer) ++ if(row < uniforms.dimInner && col < uniforms.dimBOuter) + { + ${getBIndices()} + value = ${bVariable.getByIndices('bIndices')}; +@@ -403,7 +401,7 @@ const matMulReadWriteFnSource = + + fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { + let col = colIn * ${component}; +- if (row < uniforms.dim_a_outer && col < uniforms.dim_b_outer) { ++ if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) { + var value = valueIn; + let coords = vec3(batch, row, colIn); + ${ +@@ -424,10 +422,16 @@ export const createMatmulProgramInfo = + isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => { + const aShape = inputs[0].dims; + const bShape = inputs[1].dims; ++ + const outerDimsA = aShape.slice(0, -2); + const outerDimsB = bShape.slice(0, -2); ++ + const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); ++ const enableBatchUniforms = enableShapesUniforms(outerDims.length); ++ const batchShapeOrRank = enableBatchUniforms ? outerDims.length : outerDims; ++ const batchDims = internalVariable('batchDims', inputs[0].dataType, batchShapeOrRank, 1); + const batchSize = ShapeUtil.size(outerDims); ++ + const dimAOuter = aShape[aShape.length - 2]; + const dimInner = aShape[aShape.length - 1]; + const dimBOuter = bShape[bShape.length - 1]; +@@ -442,62 +446,72 @@ export const createMatmulProgramInfo = + Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) + ]; + ++ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + const components = isVec4 ? 4 : 1; ++ + const aShapeTemp = [...outerDimsA, dimAOuter, dimInner / components]; +- const aRank = aShapeTemp.length; ++ const enableAShapesUniforms = enableShapesUniforms(aShapeTemp.length); ++ const aShapeOrRank = enableAShapesUniforms ? 
aShapeTemp.length : aShapeTemp; ++ + const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; +- const bRank = bShapeTemp.length; ++ const enableBShapesUniforms = enableShapesUniforms(bShapeTemp.length); ++ const bShapeOrRank = enableBShapesUniforms ? bShapeTemp.length : bShapeTemp; ++ + const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, +- {type: DataType.int32, data: dimInner} +- ]; +- appendActivationUniformsData(activationAttributes, programUniforms); +- programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); +- const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; ++ ++ const A = inputVariable('a', inputs[0].dataType, aShapeOrRank, components); ++ const B = inputVariable('b', inputs[1].dataType, bShapeOrRank, components); ++ const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); ++ const inputVariables = [A, B]; ++ const programUniforms: ProgramUniform[] = ++ [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; ++ if (enableBatchUniforms) { ++ programUniforms.push(...createTensorShapeVariables(outerDims)); ++ } ++ if (enableAShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(aShapeTemp)); ++ } ++ if (enableBShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(bShapeTemp)); ++ } ++ const inputDependencies: ProgramInputTensorInfoDependency[] = []; ++ inputDependencies.push(enableAShapesUniforms ? 'rank' : 'dims'); ++ inputDependencies.push(enableBShapesUniforms ? 'rank' : 'dims'); + + const hasBias = inputs.length > 2; ++ const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); ++ const declareFunctions = matMulReadWriteFnSource( ++ components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], ++ isChannelsLast); + if (hasBias) { ++ const biasComponents = isChannelsLast ? components : 1; ++ inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); ++ + inputDependencies.push('rank'); + } + programUniforms.push(...createTensorShapeVariables(outputShapeTemp)); + +- const getShaderSource = (shaderHelper: ShaderHelper) => { +- const batchRank = outerDims.length; +- const batchDims = internalVariable('batchDims', inputs[0].dataType, batchRank, 1); +- const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); +- +- const A = inputVariable('a', inputs[0].dataType, aRank, components); +- const B = inputVariable('b', inputs[1].dataType, bRank, components); +- const output = outputVariable('result', inputs[0].dataType, outputShapeTemp.length, components); +- const inputVariables = [A, B]; +- if (hasBias) { +- const biasComponents = isChannelsLast ? 
components : 1; +- inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents)); +- } +- const uniforms: UniformsArrayType = +- [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; +- appendActivationUniforms(activationAttributes, uniforms); +- const baseType = tensorTypeToWsglStorageType(output.type.tensor); +- const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); +- const declareFunctions = matMulReadWriteFnSource( +- components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], +- isChannelsLast); +- return ` ++ const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${ +- shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( +- ...inputVariables, output)} ++ shaderHelper.registerUniform('dimAOuter', 'i32') ++ .registerUniform('dimBOuter', 'i32') ++ .registerUniform('dimInner', 'i32') ++ .registerInternalVariables(batchDims) ++ .declareVariables(...inputVariables, output)} ++ ${activationFunction} + ${declareFunctions} + ${ +- isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : +- makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} ++ isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : ++ makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} + `; +- }; ++ // TODO: turn clipMax and clipMin to uniforms. + return { + name: 'MatMul', + shaderCache: { +- hint: `${elementsPerThread};${activationAttributes.activation};${isVec4};${isChannelsLast}`, ++ hint: activationAttributes.activationCacheKey + `${elementsPerThread}` + ++ `${isVec4}` + ++ `${isChannelsLast}`, + inputDependencies + }, + getRunData: () => ({ +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +index 2cfe6356d..ef8038dff 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +@@ -1,7 +1,7 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. 
+ +-import {DataType} from '../../../wasm-common'; ++import {tensorDataTypeEnumToString} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; + +@@ -241,10 +241,9 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView + WG = Math.ceil(dComp / 8); + } + const elementsPerWG = Math.ceil(d / components / WG); +- const programUniforms: ProgramUniform[] = [ +- {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, +- {type: DataType.uint32, data: elementsPerWG} +- ]; ++ const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; ++ const programUniforms: ProgramUniform[] = ++ [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const dataType = tensorTypeToWsglStorageType(input.dataType, components); + + const getShaderSource = (shaderHelper: ShaderHelper) => { +@@ -298,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView + + if (sum == 0) { + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { +- x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')}; ++ x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')}; + } + } else { + for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) { +@@ -337,10 +336,11 @@ const computeAttentionProbs = + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; ++ const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, +- {type: DataType.uint32, data: parameters.totalSequenceLength}, +- {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} ++ {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, ++ {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, ++ {type: tensorDataType, data: alpha} + ]; + + const inputs = [q, key]; +@@ -430,9 +430,9 @@ const computeVxAttentionScore = + z: params.batchSize * params.numHeads + }; + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, +- {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, +- {type: DataType.uint32, data: params.vHiddenSize} ++ {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, ++ {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, ++ {type: 'uint32', data: params.vHiddenSize} + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { +@@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { + }; + const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, +- {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, +- {type: DataType.uint32, data: parameters.hiddenSize}, +- 
{type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} ++ {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, ++ {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, ++ {type: 'uint32', data: parameters.hiddenSize}, ++ {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +index 39b932375..00a6ca75b 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +@@ -3,13 +3,12 @@ + + import {env} from 'onnxruntime-common'; + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; + import {ComputeContext, ProgramInfo} from '../types'; + +-import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; + + export interface BatchNormAttributes extends AttributeWithCacheKey { + readonly epsilon: number; +@@ -62,7 +61,7 @@ const createBatchNormInferenceProgramInfo = + const cComponents = format === 'NHWC' && yShape.length > 1 ? components : 1; + const outputSize = ShapeUtil.size(yShape) / components; + // Only support uniforms for opset version >= 9 (spatial = true). +- const useShapesUniforms = spatial; ++ const useShapesUniforms = enableShapesUniforms(yShape.length) && spatial; + const shapeOrRank = useShapesUniforms ? yShape.length : yShape; + const x = inputVariable('x', inputs[0].dataType, inputs[0].dims, components); + const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims, cComponents); +@@ -124,11 +123,11 @@ const createBatchNormInferenceProgramInfo = + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? 
+ [ +- {type: DataType.uint32, data: outputSize}, ++ {type: 'uint32', data: outputSize}, + ...createTensorShapeVariables(yShape), + ] : + [ +- {type: DataType.uint32, data: outputSize}, ++ {type: 'uint32', data: outputSize}, + ], + }), + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +index a094fffe2..c033c0ba0 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; + import {BroadcastUtil, ShapeUtil} from '../../util'; + import {ComputeContext, ProgramInfo} from '../types'; + +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; + + type BuiltinFunctionName = string; + type BinaryCustomExpression = (expressionA: string, expressionB: string) => string; +@@ -18,7 +18,8 @@ type BinaryFunctionCall = BuiltinFunctionName|BinaryCustomExpression|{ + const createBinaryOpProgramShader = + (shaderHelper: ShaderHelper, dimsA: readonly number[], dimsB: readonly number[], dimsOutput: readonly number[], + vectorize: boolean, doBroadcast: boolean, sharedDimensionDivisibleBy4: boolean, funcCall: BinaryFunctionCall, +- typeA: number, typeB: number, typeOutput: number, additionalImplementation?: string) => { ++ typeA: number, typeB: number, typeOutput: number, useShapesUniforms: boolean, ++ additionalImplementation?: string) => { + let expressionScalar: BinaryCustomExpression; + let expressionVector: BinaryCustomExpression; + if (typeof funcCall === 'string') { +@@ -30,9 +31,12 @@ const createBinaryOpProgramShader = + expressionVector = funcCall.vector; + } + +- const output = outputVariable('outputData', typeOutput, dimsOutput.length, 4); +- const a = inputVariable('aData', typeA, dimsA.length, 4); +- const b = inputVariable('bData', typeB, dimsB.length, 4); ++ const inputAShapeOrRank = useShapesUniforms ? dimsA.length : dimsA; ++ const inputBShapeOrRank = useShapesUniforms ? dimsB.length : dimsB; ++ const outputShapeOrRank = useShapesUniforms ? dimsOutput.length : dimsOutput; ++ const output = outputVariable('outputData', typeOutput, outputShapeOrRank, 4); ++ const a = inputVariable('aData', typeA, inputAShapeOrRank, 4); ++ const b = inputVariable('bData', typeB, inputBShapeOrRank, 4); + + let assignment: string; + if (vectorize) { +@@ -165,23 +169,30 @@ const createBinaryOpProgramInfo = + vectorize = true; + } + cacheKeyAux.push(vectorize); +- ++ const useShapesUniforms = enableShapesUniforms(a.dims.length) && enableShapesUniforms(b.dims.length) && ++ enableShapesUniforms(outputShape.length); + return { + name, + shaderCache: { + hint: cacheKey + cacheKeyAux.map((x) => x.toString()).join('_'), +- inputDependencies: ['rank', 'rank'], ++ inputDependencies: useShapesUniforms ? 
['rank', 'rank'] : ['dims', 'dims'], + }, + getShaderSource: (shaderHelper) => createBinaryOpProgramShader( + shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, +- a.dataType, b.dataType, outputDataType, additionalImplementation), ++ a.dataType, b.dataType, outputDataType, useShapesUniforms, additionalImplementation), + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, +- programUniforms: [ +- {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, +- ...createTensorShapeVariables(a.dims, b.dims, outputShape) +- ], ++ programUniforms: useShapesUniforms ? ++ [ ++ {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, ++ ...createTensorShapeVariables(a.dims), ++ ...createTensorShapeVariables(b.dims), ++ ...createTensorShapeVariables(outputShape), ++ ] : ++ [ ++ {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, ++ ], + }), + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts +index 516094d0e..bc3265be9 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts +@@ -259,16 +259,8 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = + return typeof mappedType === 'string' ? mappedType : mappedType[1]; + }; + +-export const createTensorShapeVariables = (...dims: ReadonlyArray): ProgramUniform[] => { +- const programUniforms: ProgramUniform[] = []; +- dims.forEach(dim => { +- if (dim.length !== 0) { +- programUniforms.push( +- {type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)}); +- } +- }); +- return programUniforms; +-}; ++export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => ++ dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; + + /** + * A helper function to get maximum vector size for specified data length +@@ -338,28 +330,18 @@ export const sumVector = (name: string, components: number) => { + * @param name - the name of variable. + * @param index - the index of variable element. + * @param length - the length of variable. +- * @param type - the type of variable, optional. + */ +-export const getElementAt = +- (name: string, index: number|string, length: number, type?: UniformDataElementType): string => { +- if (name.startsWith('uniforms.') && length > 4) { +- if (typeof (index) === 'string') { +- if (type === 'f16') { +- return `${name}[(${index}) / 8][(${index}) % 8 / 4][(${index}) % 8 % 4]`; +- } else { +- return `${name}[(${index}) / 4][(${index}) % 4]`; +- } +- } else { +- if (type === 'f16') { +- return `${name}[${Math.floor(index / 8)}][${Math.floor(index % 8 / 4)}][${index % 8 % 4}]`; +- } else { +- return `${name}[${Math.floor(index / 4)}][${index % 4}]`; +- } +- } +- } else { +- return length > 1 ? `${name}[${index}]` : name; +- } +- }; ++export const getElementAt = (name: string, index: number|string, length: number): string => { ++ if (name.startsWith('uniforms.') && length > 4) { ++ if (typeof (index) === 'string') { ++ return `${name}[(${index}) / 4][(${index}) % 4]`; ++ } else { ++ return `${name}[${Math.floor(index / 4)}][${index % 4}]`; ++ } ++ } else { ++ return length > 1 ? 
`${name}[${index}]` : name; ++ } ++}; + + /** + * A helper function to get a IndicesHelper for a given input or output. +@@ -706,7 +688,7 @@ export const internalVariable = + (name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper => + createIndicesHelper(name, type, shapeOrRank, 'internal', components); + +-export type UniformDataElementType = 'u32'|'f16'|'f32'|'i32'; ++export type UniformDataElementType = 'u32'|'f32'|'i32'; + export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>; + + /** +@@ -879,11 +861,7 @@ class ShaderHelperImpl implements ShaderHelper { + const uniformSnippets: string[] = []; + for (const {name, type, length} of this.uniforms) { + if (length && length > 4) { +- if (type === 'f16') { +- uniformSnippets.push(`@align(16) ${name}:array, ${Math.ceil(length / 8)}>`); +- } else { +- uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); +- } ++ uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`); + } else { + const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`; + uniformSnippets.push(`${name}:${typeTemp}`); +@@ -930,3 +908,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly + } + return dims; + }; ++ ++// TODO: remove this when all related uses have been removed. ++export const enableShapesUniforms = (_rank: number): boolean => true; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +index b06c9fb49..43cc4a4c0 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +@@ -1,13 +1,12 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; + import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; + +-import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + + export interface ConcatAttributes extends AttributeWithCacheKey { + readonly axis: number; +@@ -95,22 +94,32 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P + + let previousSum = 0; + const inputDependencies: ProgramInputTensorInfoDependency[] = []; +- const inputRanks = []; +- const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; ++ const inputShapeOrRanks = []; ++ const enableInputShapesUniforms = []; ++ const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + for (let i = 0; i < inputs.length; ++i) { + previousSum += inputs[i].dims[adjustedAxis]; + sizeInConcatAxis[i] = previousSum; +- inputRanks.push(inputs[i].dims.length); +- inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); +- inputDependencies.push('rank'); +- programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); ++ enableInputShapesUniforms.push(enableShapesUniforms(inputs[i].dims.length)); ++ inputShapeOrRanks.push(enableInputShapesUniforms[i] ? 
inputs[i].dims.length : inputs[i].dims); ++ inputVars[i] = inputVariable(`input${i}`, dataType, inputShapeOrRanks[i]); ++ inputDependencies.push(enableInputShapesUniforms[i] ? 'rank' : 'dims'); ++ programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + } + for (let i = 0; i < inputs.length; ++i) { +- programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); ++ if (enableInputShapesUniforms[i]) { ++ programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); ++ } ++ } ++ ++ const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); ++ if (enableOutputShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(outputShape)); + } +- programUniforms.push(...createTensorShapeVariables(outputShape)); + +- const output = outputVariable('output', dataType, outputShape.length); ++ const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; ++ const output = outputVariable('output', dataType, outputShapeOrRank); ++ + const indicesAxis = output.indicesGet('indices', adjustedAxis); + const sizeInConcatAxisStr = + Array.from(Array(sizeInConcatAxis.length).keys()).map(i => `uniforms.sizeInConcatAxis${i}`).join(','); +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +index 7d424305c..21b4953d3 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +@@ -1,14 +1,13 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; +-import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; ++import {ProgramInfo, ProgramUniform} from '../types'; + +-import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; ++import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper} from './common'; + import {calculateOutputShape, ConvAttributes} from './conv'; +-import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; ++import {getActivationSnippet} from './fuse-utils'; + + /** + * naive grouped conv implementation, supports 1d/2d conv +@@ -28,70 +27,52 @@ export const createGroupedConvProgramInfo = + xShape, wShape, attributes.dilations, attributes.pads, attributes.strides, isChannelLast); + const outputSize = ShapeUtil.size(outputShape); + +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, +- {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, +- {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, +- {type: DataType.uint32, data: outputChannelsPerGroup} +- ]; +- appendActivationUniformsData(attributes, programUniforms); +- programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShape)); +- const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; ++ const output = outputVariable('output', inputs[0].dataType, outputShape); ++ const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); ++ const x = inputVariable('x', inputs[0].dataType, xShape); ++ const w = inputVariable('w', inputs[1].dataType, 
wShape); ++ const inputVars = [x, w]; + if (hasBias) { +- programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); +- inputDependencies.push('rank'); ++ inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); + } +- programUniforms.push(...createTensorShapeVariables(outputShape)); + +- const getShaderSource = (shaderHelper: ShaderHelper) => { +- const output = outputVariable('output', inputs[0].dataType, outputShape.length); +- const baseType = tensorTypeToWsglStorageType(output.type.tensor); +- const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); +- const x = inputVariable('x', inputs[0].dataType, xShape.length); +- const w = inputVariable('w', inputs[1].dataType, wShape.length); +- const inputVars = [x, w]; +- if (hasBias) { +- inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims)); +- } ++ const getShaderSource = (shaderHelper: ShaderHelper) => ` ++ const strides: vec2 = vec2(${attributes.strides[0]}u, ${attributes.strides[1]}u); ++ const pads: vec2 = vec2(${attributes.pads[0]}u, ${attributes.pads[1]}u); + +- const uniforms: UniformsArrayType = [ +- {name: 'output_size', type: 'u32'}, {name: 'dilations', type: 'u32', length: attributes.dilations.length}, +- {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, +- {name: 'output_channels_per_group', type: 'u32'} +- ]; +- appendActivationUniforms(attributes, uniforms); +- return ` +- ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ++ ${shaderHelper.declareVariables(...inputVars, output)} ++ ++ ${activationFunction} + + ${shaderHelper.mainStart()} +- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} ++ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + + let outputIndices = ${output.offsetToIndices('global_idx')}; + let batch: u32 = outputIndices[0]; + let output_channel: u32 = outputIndices[${isChannelLast ? 3 : 1}]; + let xRCCorner: vec2 = vec2(outputIndices[${isChannelLast ? 1 : 2}], outputIndices[${ +- isChannelLast ? 2 : 3}]) * uniforms.strides - uniforms.pads; +- let group_id: u32 = output_channel / uniforms.output_channels_per_group; ++ isChannelLast ? 2 : 3}]) * strides - pads; ++ let group_id: u32 = output_channel / ${outputChannelsPerGroup}u; + + var value: ${output.type.value} = ${output.type.value}(0); +- for (var wInChannel: u32 = 0u; wInChannel < uniforms.w_shape[1]; wInChannel++) { +- let input_channel = group_id * uniforms.w_shape[1] + wInChannel; +- for (var wHeight: u32 = 0u; wHeight < uniforms.w_shape[2]; wHeight++) { +- let xHeight = xRCCorner.x + wHeight * uniforms.dilations[0]; ++ for (var wInChannel: u32 = 0u; wInChannel < ${wShape[1]}u; wInChannel++) { ++ let input_channel = group_id * ${wShape[1]}u + wInChannel; ++ for (var wHeight: u32 = 0u; wHeight < ${wShape[2]}u; wHeight++) { ++ let xHeight = xRCCorner.x + wHeight * ${attributes.dilations[0]}u; + +- if (xHeight < 0u || xHeight >= uniforms.x_shape[${isChannelLast ? 1 : 2}]) { ++ if (xHeight < 0u || xHeight >= ${xShape[isChannelLast ? 1 : 2]}u) { + continue; + } + +- for (var wWidth: u32 = 0u; wWidth < uniforms.w_shape[3]; wWidth++) { +- let xWidth = xRCCorner.y + wWidth * uniforms.dilations[1]; +- if (xWidth < 0u || xWidth >= uniforms.x_shape[${isChannelLast ? 2 : 3}]) { ++ for (var wWidth: u32 = 0u; wWidth < ${wShape[3]}u; wWidth++) { ++ let xWidth = xRCCorner.y + wWidth * ${attributes.dilations[1]}u; ++ if (xWidth < 0u || xWidth >= ${xShape[isChannelLast ? 
2 : 3]}u) { + continue; + } + + let xVal = ${ +- isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : +- x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; ++ isChannelLast ? x.get('batch', 'xHeight', 'xWidth', 'input_channel') : ++ x.get('batch', 'input_channel', 'xHeight', 'xWidth')}; + let wVal = ${w.get('output_channel', 'wInChannel', 'wHeight', 'wWidth')}; + value += xVal*wVal; + } +@@ -101,17 +82,15 @@ export const createGroupedConvProgramInfo = + ${applyActivation} + ${output.setByOffset('global_idx', 'value')} + }`; +- }; + return { + name: 'GroupedConv', +- shaderCache: {hint: attributes.cacheKey, inputDependencies}, ++ shaderCache: {hint: attributes.cacheKey}, + getRunData: () => ({ + outputs: [{ + dims: squeezeOutputShapeFunction ? squeezeOutputShapeFunction(outputShape) : outputShape, + dataType: inputs[0].dataType + }], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, +- programUniforms + }), + getShaderSource, + }; +@@ -128,17 +107,14 @@ export const createGroupedConvVectorizeProgramInfo = + const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; + + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, +- {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, +- {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ++ {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, ++ {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), ++ ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + ]; +- appendActivationUniformsData(attributes, programUniforms); +- programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); + const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); +- const baseType = tensorTypeToWsglStorageType(output.type.tensor); +- const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); ++ const {activationFunction, applyActivation} = getActivationSnippet(attributes, output.type.value); + const x = inputVariable('x', inputs[0].dataType, xShape.length, components); + const w = inputVariable('w', inputs[1].dataType, wShape.length, components); + const inputVars = [x, w]; +@@ -146,14 +122,14 @@ export const createGroupedConvVectorizeProgramInfo = + inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); + } + const processBias = hasBias ? 
'value += b[output_channel];' : ''; +- const uniforms: UniformsArrayType = [ +- {name: 'output_size', type: 'u32'}, +- {name: 'strides', type: 'i32', length: 2}, +- {name: 'pads', type: 'i32', length: 2}, +- ]; +- appendActivationUniforms(attributes, uniforms); ++ + return ` +- ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ++ ${ ++ shaderHelper.registerUniform('output_size', 'u32') ++ .registerUniform('strides', 'i32', 2) ++ .registerUniform('pads', 'i32', 2) ++ .declareVariables(...inputVars, output)} ++ ${activationFunction} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + let width0 = uniforms.output_shape[3]; +@@ -173,7 +149,7 @@ export const createGroupedConvVectorizeProgramInfo = + // Use constant instead of uniform can give better performance for w's height/width. + for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { + let x_height = x_corner.x + i32(w_height); +- if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { ++ if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + for (var i = 0; i < ${xNumber}; i++) { + let x_width = x_corner.y + i; + if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { +@@ -185,7 +161,7 @@ export const createGroupedConvVectorizeProgramInfo = + for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { + let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; + for (var i = 0u; i < ${outputNumber}u; i++) { +- values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); ++ values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + } + } + } +@@ -203,7 +179,7 @@ export const createGroupedConvVectorizeProgramInfo = + return { + name: 'GroupedConv-Vectorize', + shaderCache: { +- hint: `${attributes.cacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, ++ hint: `${attributes.activationCacheKey};${components};${outputNumber};${xNumber};${wShape[0]};${wShape[1]}`, + inputDependencies: hasBias ? ['rank', 'rank', 'type'] : ['rank', 'rank'] + }, + getRunData: () => ({ +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +index 33d16754c..32b1d52ed 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +@@ -2,6 +2,7 @@ + // Licensed under the MIT License. 
+ + import {TensorView} from '../../tensor-view'; ++import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; + import {ComputeContext} from '../types'; + + import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; +@@ -58,6 +59,7 @@ export interface ConvTransposeAttributes extends ConvAttributes { + readonly outputShape: readonly number[]; + } + ++ + const getAdjustedConvTransposeAttributes = + (attributes: T, inputs: readonly TensorView[]): T => { + const kernelShape = attributes.kernelShape.slice(); +@@ -94,7 +96,11 @@ const getAdjustedConvTransposeAttributes = + + // always return a new object so does not modify the original attributes + const newAttributes: T = Object.assign({}, attributes); +- Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides}); ++ const cacheKey = attributes.cacheKey + [ ++ kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), ++ dilations.join(',') ++ ].join('_'); ++ Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); + return newAttributes; + }; + +@@ -113,7 +119,7 @@ export const parseConvTransposeAttributes = (attributes: Record + const wIsConst = (attributes.wIsConst as () => boolean)(); + const outputPadding = attributes.outputPadding as [number, number, number, number]; + const outputShape = attributes.outputShape as [number, number]; +- return { ++ return createAttributeWithCacheKey({ + autoPad, + format, + dilations, +@@ -124,9 +130,8 @@ export const parseConvTransposeAttributes = (attributes: Record + pads, + strides, + wIsConst, +- ...activationAttributes, +- cacheKey: `${attributes.format};${activationAttributes.activation};` +- }; ++ ...activationAttributes ++ }); + }; + + const validateInputs = (inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +index 5afec0389..7af2c5db4 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +@@ -3,7 +3,7 @@ + + import {TensorView} from '../../tensor-view'; + import {PoolConvUtil} from '../../util'; +-import {AttributeWithCacheKey} from '../attribute-with-cache-key'; ++import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; + import {ComputeContext} from '../types'; + + import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu'; +@@ -110,7 +110,7 @@ const getAdjustedConvAttributes = (attributes: T, inpu + + // always return a new object so does not modify the original attributes + const newAttributes: T = Object.assign({}, attributes); +- Object.assign(newAttributes, {kernelShape, pads}); ++ Object.assign(newAttributes, {kernelShape, pads, cacheKey: attributes.cacheKey}); + return newAttributes; + }; + +@@ -126,18 +126,8 @@ export const parseConvAttributes = (attributes: Record): ConvAt + const strides = attributes.strides as [number, number]; + const wIsConst = (attributes.w_is_const as () => boolean)(); + +- return { +- autoPad, +- format, +- dilations, +- group, +- kernelShape, +- pads, +- strides, +- wIsConst, +- ...activationAttributes, +- cacheKey: `${attributes.format};${activationAttributes.activation};` +- }; ++ return createAttributeWithCacheKey( ++ {autoPad, format, dilations, group, kernelShape, pads, strides, wIsConst, ...activationAttributes}); + }; + + const conv2d = (context: 
ComputeContext, inputs: readonly TensorView[], attributes: ConvAttributes): void => { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +index 6080301d9..2ff909c30 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +@@ -54,8 +54,8 @@ const createCumsumProgramInfo = + outputs: [{dims: inputShape, dataType: inputType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, +- ...createTensorShapeVariables(inputShape, inputShape) ++ {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, ++ ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) + ] + + }), +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +index 19a009c2e..4db7c04ad 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +@@ -1,13 +1,13 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; + import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; ++ + + export interface EinsumAttributes extends AttributeWithCacheKey { + readonly equation: string; +@@ -181,12 +181,14 @@ class EinsumEquation { + const appendMax = (name: string): string => name + '_max'; + + const createEinsumProgramInfo = +- (inputShapes: Array, dataType: number, einsumEquation: EinsumEquation, +- outputShape: readonly number[]): ProgramInfo => { +- const ranks = inputShapes.map((dims) => dims.length); +- const inputVars = ranks.map((rank, index) => inputVariable(`input${index}`, dataType, rank)); ++ (enableInputShapesUniforms: readonly boolean[], inputShapes: Array, dataType: number, ++ einsumEquation: EinsumEquation, outputShape: readonly number[]): ProgramInfo => { ++ const shapeOrRanks = inputShapes.map((dims, index) => enableInputShapesUniforms[index] ? dims.length : dims); ++ const inputVars = shapeOrRanks.map((shapeOrRank, index) => inputVariable(`input${index}`, dataType, shapeOrRank)); + const outputSize = ShapeUtil.size(outputShape); +- const output = outputVariable('output', dataType, outputShape.length); ++ const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); ++ const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape; ++ const output = outputVariable('output', dataType, outputShapeOrRank); + const uniformsSymbols = + [...einsumEquation.symbolToInfo.keys()].filter((symbol) => !einsumEquation.rhs.symbolToIndices.has(symbol)); + const getShaderSource = (shaderHelper: ShaderHelper) => { +@@ -267,20 +269,24 @@ const createEinsumProgramInfo = + }; + return { + name: 'Einsum', +- shaderCache: {hint: einsumEquation.equation, inputDependencies: inputShapes.map(() => 'rank')}, ++ shaderCache: { ++ hint: einsumEquation.equation, ++ inputDependencies: enableInputShapesUniforms.map((enableShapeUniform) => enableShapeUniform ? 
'rank' : 'dims') ++ }, + getRunData: () => { + // The symbols from uniformSymbols array are guaranteed to exist in einsumEquations.symbolToInfo map. The + // filter is added to make sure that dimValue is never 0. + const programUniformsInit: ProgramUniform[] = + uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) +- .map( +- (symbol) => +- ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); +- programUniformsInit.push({type: DataType.uint32, data: outputSize}); ++ .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); ++ programUniformsInit.push({type: 'uint32', data: outputSize}); + const programUniforms: ProgramUniform[] = +- inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) ++ inputShapes.filter((_, index) => enableInputShapesUniforms[index]) ++ .map((dims, _) => [...createTensorShapeVariables(dims)]) + .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); +- programUniforms.push(...createTensorShapeVariables(outputShape)); ++ if (enableOutputShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(outputShape)); ++ } + return ({ + outputs: [{dims: outputShape, dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, +@@ -293,9 +299,11 @@ const createEinsumProgramInfo = + + export const einsum = (context: ComputeContext, attributes: EinsumAttributes): void => { + const einsumEquation = new EinsumEquation(context.inputs, attributes.equation); ++ const enableInputShapesUniforms = context.inputs.map((input, _) => enableShapesUniforms(input.dims.length)); + const outputShape = einsumEquation.outputDims; + const inputShapes = context.inputs.map((input, _) => input.dims); +- context.compute(createEinsumProgramInfo(inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); ++ context.compute(createEinsumProgramInfo( ++ enableInputShapesUniforms, inputShapes, context.inputs[0].dataType, einsumEquation, outputShape)); + }; + + export const parseEinsumAttributes = (attributes: Record): EinsumAttributes => { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +index 80ee90642..035d89755 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; + + const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length !== 2) { +@@ -49,9 +49,15 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => + const components = dataType === DataType.bool ? 
4 : 1; + const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); + ++ const enableInputShapeUniform = enableShapesUniforms(inputShape.length); ++ const enableOutputShapeUniform = enableShapesUniforms(outputShape.length); ++ ++ + const getShaderSource = (shaderHelper: ShaderHelper) => { +- const input = inputVariable('input', dataType, inputShape.length, components); +- const output = outputVariable('output', dataType, outputShape.length, components); ++ const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape; ++ const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape; ++ const input = inputVariable('input', dataType, inputShapeOrRank, components); ++ const output = outputVariable('output', dataType, outputShapeOrRank, components); + let assignment: string; + if (dataType === DataType.bool) { + const singleAssignment = (resStr: string, x: number, typeCast = '') => ` +@@ -84,11 +90,16 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => + ${assignment}`; + }; + +- const programUniforms: ProgramUniform[] = +- [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]; ++ const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; ++ if (enableInputShapeUniform) { ++ programUniforms.push(...createTensorShapeVariables(inputShape)); ++ } ++ if (enableOutputShapeUniform) { ++ programUniforms.push(...createTensorShapeVariables(outputShape)); ++ } + return { + name: 'Expand', +- shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, ++ shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']}, + getShaderSource, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +index 6e66abacf..0b5c0db2b 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +@@ -1,78 +1,44 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. 
+ +-import {DataType} from '../../../wasm-common'; + import {MAX_CLIP, MIN_CLIP} from '../../util'; +-import {ProgramUniform} from '../types'; +- +-import {UniformsArrayType} from './common'; + + export interface InternalActivationAttributes { + readonly activation: string; + readonly clipMin?: number; + readonly clipMax?: number; +- readonly alpha?: number; +- readonly beta?: number; ++ readonly activationCacheKey: string; + } + +-export const getActivationSnippet = +- (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => { ++export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): ++ {activationFunction: string; applyActivation: string} => { + switch (attributes.activation) { + case 'Relu': +- return `value = max(value, ${valueType}(0.0));`; ++ return {activationFunction: '', applyActivation: `value = max(value, ${valueType}(0.0));`}; + case 'Sigmoid': +- return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; ++ return { ++ activationFunction: '', ++ applyActivation: `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));` ++ }; + case 'Clip': +- return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ +- baseType}(uniforms.clip_max)));`; +- case 'HardSigmoid': +- return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ +- baseType}(uniforms.beta)));`; +- case 'LeakyRelu': +- return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; +- case '': +- return ''; +- // TODO: adding other activations that can be fused. ++ return { ++ activationFunction: `const clip_min_=${valueType}(${attributes.clipMin!});const clip_max_=${valueType}(${ ++ attributes.clipMax!});`, ++ applyActivation: 'value = clamp(value, clip_min_, clip_max_);' ++ }; ++ // TODO: adding other activations that can be fused. 
+ default: +- throw new Error(`Unsupported activation ${attributes.activation}`); +- } +- }; +- +-export const appendActivationUniformsData = +- (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { +- if (attributes.activation === 'Clip') { +- programUniform.push( +- {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); +- } else if (attributes.activation === 'HardSigmoid') { +- programUniform.push( +- {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); +- } else if (attributes.activation === 'LeakyRelu') { +- programUniform.push({type: DataType.float, data: attributes.alpha!}); ++ return {activationFunction: '', applyActivation: ''}; + } + }; + +-export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { +- if (attributes.activation === 'Clip') { +- uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); +- } else if (attributes.activation === 'HardSigmoid') { +- uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); +- } else if (attributes.activation === 'LeakyRelu') { +- uniforms.push({name: 'alpha', type: 'f32'}); +- } +-}; +- + export const parseInternalActivationAttributes = + (attributes: Record|undefined): InternalActivationAttributes => { + const activation = attributes?.activation as string || ''; +- if (activation === 'HardSigmoid') { +- const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; +- return {activation, alpha, beta}; +- } else if (activation === 'Clip') { ++ ++ if (activation === 'Clip') { + const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; +- return {activation, clipMax, clipMin}; +- } else if (activation === 'LeakyRelu') { +- const [alpha] = attributes?.activation_params as [number] || [0.01]; +- return {activation, alpha}; ++ return {activation, clipMax, clipMin, activationCacheKey: `${activation}:${clipMin},${clipMax}`}; + } +- return {activation}; ++ return {activation, activationCacheKey: activation}; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +index 4ab6c175a..a945954ad 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +@@ -1,7 +1,6 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. 
+ +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -47,11 +46,11 @@ const createGatherElementsProgramInfo = + const output = outputVariable('output', inputOutputDataType, outputShape.length); + + +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, +- {type: DataType.uint32, data: axis} +- ]; +- programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape)); ++ const programUniforms: ProgramUniform[] = ++ [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; ++ programUniforms.push(...createTensorShapeVariables(inputShape)); ++ programUniforms.push(...createTensorShapeVariables(indicesShape)); ++ programUniforms.push(...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; + + // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +index 5c31e6dd8..469249f92 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +@@ -5,9 +5,9 @@ import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +-import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; ++import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; + +-import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, inputVariable, outputVariable, ShaderHelper} from './common'; + + export interface GatherAttributes extends AttributeWithCacheKey { + axis: number; +@@ -33,15 +33,33 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath + const components = inputs[0].dataType === DataType.bool ? 4 : 1; + const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); + +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, +- {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape) +- ]; ++ const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length); ++ const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims; ++ const enableIndicesShapesUniforms = enableShapesUniforms(inputs[1].dims.length); ++ const indicesShapeOrRank = enableIndicesShapesUniforms ? inputs[1].dims.length : inputs[1].dims; ++ const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length); ++ const outputShapeOrRank = enableOutputShapesUniforms ? 
outputShape.length : outputShape; ++ ++ const programUniforms: ProgramUniform[] = ++ [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; ++ if (enableInputShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(inputs[0].dims)); ++ } ++ if (enableIndicesShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(inputs[1].dims)); ++ } ++ if (enableOutputShapesUniforms) { ++ programUniforms.push(...createTensorShapeVariables(outputShape)); ++ } ++ ++ const inputDependencies: ProgramInputTensorInfoDependency[] = []; ++ inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims'); ++ inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims'); + + const getShaderSource = (shaderHelper: ShaderHelper) => { +- const data = inputVariable('data', inputs[0].dataType, inputs[0].dims.length, components); +- const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims.length); +- const output = outputVariable('output', inputs[0].dataType, outputShape.length, components); ++ const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components); ++ const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank); ++ const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components); + + const calcDataIndices = (x: number|string): string => { + const indicesRank = indicesShape.length; +@@ -109,7 +127,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath + }; + return { + name: 'Gather', +- shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank', 'rank']}, ++ shaderCache: {hint: attributes.cacheKey, inputDependencies}, + getRunData: () => ({ + outputs: [ + {dims: outputShape, dataType: inputs[0].dataType}, +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +index 76302e1af..a0d402151 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +@@ -1,7 +1,6 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. 
+ +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {GemmUtil, ShapeUtil} from '../../util'; + import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -46,9 +45,8 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt + } + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, +- {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, +- {type: DataType.float, data: attributes.beta} ++ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K}, ++ {type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta} + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + if (inputs.length === 3) { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +index 2f652dbd3..a835c90bd 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +@@ -25,8 +25,8 @@ const createInstanceNormProgramInfo = + const inputShape = [xShape[0], xShape[1], normPackedSize]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; + const programUniforms: ProgramUniform[] = +- [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; +- programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); ++ [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; ++ programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); +@@ -132,9 +132,8 @@ const computeMean = + + const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; + const meanProgramUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, +- {type: DataType.uint32, data: Math.floor(c / components)}, +- {type: DataType.uint32, data: Math.floor(h * c / components)} ++ {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, ++ {type: 'uint32', data: Math.floor(h * c / components)} + ]; + + const getMeanShaderSource = (shaderHelper: ShaderHelper) => { +@@ -183,9 +182,8 @@ const computeMean = + {inputs: [input], outputs: [-1]})[0]; + + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, +- {type: DataType.uint32, data: Math.floor(c / components)}, +- {type: DataType.uint32, data: Math.floor(WG * c / components)} ++ {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, ++ {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + ]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; + const getShaderSource = (shaderHelper: ShaderHelper) => { +@@ -248,7 +246,7 @@ const createInstanceNormNHWCProgramInfo = + const components = getMaxComponents(C); + const outputSize = ShapeUtil.size(outputShape) / components; + const programUniforms: ProgramUniform[] = +- [{type: DataType.uint32, data: H}, {type: 
DataType.uint32, data: Math.floor(C / components)}]; ++ [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + // first compute mean + const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +index 3f73d9cb7..3c9f6ce71 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +@@ -49,9 +49,8 @@ const createLayerNormProgramInfo = + const components = getMaxComponents(normSize); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, +- {type: DataType.uint32, data: Math.floor(normSize / components)}, +- {type: DataType.float, data: attributes.epsilon} ++ {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, ++ {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} + ]; + if (bias) { + inputDependencies.push('type'); +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +index 1a92d8610..de9309d1e 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +@@ -1,14 +1,13 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {BroadcastUtil, ShapeUtil} from '../../util'; + import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + + import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; +-import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; +-import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; ++import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common'; ++import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; + + export const createNaiveMatmulProgramInfo = + (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], +@@ -28,13 +27,11 @@ export const createNaiveMatmulProgramInfo = + const outerDims = reshapedOutputShape ? 
reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2); + const batchSize = ShapeUtil.size(outerDims); + const outputShapeInShader = [batchSize, M, N]; +- + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, +- {type: DataType.uint32, data: K} ++ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, ++ {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), ++ ...createTensorShapeVariables(bShape) + ]; +- appendActivationUniformsData(activationAttributes, programUniforms); +- programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); + if (hasBias) { + programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); + } +@@ -45,8 +42,7 @@ export const createNaiveMatmulProgramInfo = + const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); + const b = inputVariable('b', inputs[1].dataType, bShape.length, components); + const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); +- const baseType = tensorTypeToWsglStorageType(output.type.tensor); +- const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); ++ const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value); + const inputVariables = [a, b]; + let processBias = ''; + if (hasBias) { +@@ -61,12 +57,6 @@ export const createNaiveMatmulProgramInfo = + const outerDimsB = bShape.slice(0, -2); + const broadCastADims = getBroadcastDims(outerDimsA, outerDims); + const broadCastBDims = getBroadcastDims(outerDimsB, outerDims); +- const uniforms: UniformsArrayType = [ +- {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, +- {name: 'K', type: 'u32'} +- ]; +- appendActivationUniforms(activationAttributes, uniforms); +- + const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { + const rank = variable.rank; + const name = variable.name; +@@ -106,10 +96,15 @@ export const createNaiveMatmulProgramInfo = + + return ` + ${ +- shaderHelper.registerUniforms(uniforms).registerInternalVariables(batchDims).declareVariables( +- ...inputVariables, output)} ++ shaderHelper.registerUniform('outputSize', 'u32') ++ .registerUniform('M', 'u32') ++ .registerUniform('N', 'u32') ++ .registerUniform('K', 'u32') ++ .registerInternalVariables(batchDims) ++ .declareVariables(...inputVariables, output)} ++ ${activationFunction} + ${shaderHelper.mainStart()} +- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} ++ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')} + let col = (global_idx % (uniforms.N / ${components})) * ${components}; + var index1 = global_idx / (uniforms.N / ${components}); + let stride1 = uniforms.M / ${outputNumber}; +@@ -139,7 +134,8 @@ export const createNaiveMatmulProgramInfo = + return { + name: 'MatMulNaive', + shaderCache: { +- hint: `${activationAttributes.activation};${components};${aComponents};${outputNumber};${isChannelsLast}`, ++ hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${ ++ isChannelsLast}`, + inputDependencies: hasBias ? 
['rank', 'rank', 'rank'] : ['rank', 'rank'] + }, + getRunData: () => ({ +@@ -170,8 +166,9 @@ export const matMul = (context: ComputeContext): void => { + const N = outputShape[outputShape.length - 1]; + const K = context.inputs[0].dims[context.inputs[0].dims.length - 1]; + if (N < 8 && K < 8) { +- context.compute(createNaiveMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); ++ context.compute( ++ createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + } else { +- context.compute(createMatmulProgramInfo(context.inputs, {activation: ''}, outputShape)); ++ context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape)); + } + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +index 5c5c849d9..6d22e3780 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +@@ -1,7 +1,6 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -239,10 +238,8 @@ const addBiasTranspose = + hiddenSize: number, biasOffset: number) => { + const outputShape = [batchSize, sequenceLength, hiddenSize]; + const outputSize = ShapeUtil.size(outputShape); +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, +- {type: DataType.uint32, data: hiddenSize} +- ]; ++ const programUniforms: ProgramUniform[] = ++ [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; + + const getShaderSource = (shaderHelper: ShaderHelper) => { + const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +index 236fc29fd..eca3fa7d9 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +@@ -1,7 +1,7 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. 
+ +-import {DataType} from '../../../wasm-common'; ++import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; +@@ -19,8 +19,8 @@ const validateInputs = (inputs: readonly TensorView[]): void => { + if (!inputs || inputs.length < 1) { + throw new Error('Too few inputs'); + } +- if (inputs[0].dataType !== DataType.float && inputs[0].dataType !== DataType.float16) { +- throw new Error('Input type must be float or float16.'); ++ if (inputs[0].dataType !== DataType.float) { ++ throw new Error('Input type must be float.'); + } + + if (inputs.length >= 2) { +@@ -153,12 +153,13 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr + const inputDims = inputs[0].dims; + const outputSize = ShapeUtil.size(outputShape); + const programUniforms: ProgramUniform[] = +- [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.pads}]; ++ [{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}]; + if (attributes.mode === 0) { +- programUniforms.push({type: inputs[0].dataType, data: attributes.value}); ++ const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type']; ++ programUniforms.push({type: tensorDataType, data: attributes.value}); + } + +- programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); ++ programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); + const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; + + const getShaderSource = (shaderHelper: ShaderHelper) => { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +index 4e933573b..9e9b361c1 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +@@ -3,7 +3,6 @@ + + import {env} from 'onnxruntime-common'; + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {PoolConvUtil, ShapeUtil} from '../../util'; + import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -57,8 +56,7 @@ const getUniformAndPadInfo = ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: outputSize}, +- programUniforms: [{type: DataType.uint32, data: reduceSize}] ++ programUniforms: [{type: 'uint32', data: reduceSize}] + }), + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +index e8205ba6f..e8851ac54 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +@@ -100,8 +100,10 @@ export const createReduceProgramInfo = + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, +- programUniforms: +- [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)] ++ programUniforms: [ ++ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), ++ ...createTensorShapeVariables(outputShape) ++ ] + }), + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +index 2c6b537de..f68526acc 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts ++++ 
b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +@@ -2,7 +2,6 @@ + // Licensed under the MIT License. + + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -642,8 +641,11 @@ const createResizeProgramInfo = + outputs: [{dims: outputShape, dataType: inputTensor.dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales}, +- {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape) ++ {type: 'uint32', data: outputSize}, ++ {type: 'float32', data: scales}, ++ {type: 'float32', data: roi}, ++ ...createTensorShapeVariables(inputShape), ++ ...createTensorShapeVariables(outputShape), + ] + }) + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +index 7be9ceec6..a2fda9f07 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +@@ -4,10 +4,10 @@ + import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; +-import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +-import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; ++import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; ++import {ComputeContext, ProgramInfo} from '../types'; + +-import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; ++import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; + + export interface SkipLayerNormAttributes extends AttributeWithCacheKey { + epsilon: number; +@@ -86,74 +86,60 @@ const createSkipLayerNormProgramInfo = + const hasInputSkipBiasSumOutput = outputCount > 3; + + const components = getMaxComponents(hiddenSize); +- +- const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, +- {type: DataType.uint32, data: components}, +- {type: DataType.uint32, data: hiddenSize}, +- {type: DataType.float, data: attributes.epsilon}, ++ const variables = [ ++ inputVariable('x', inputs[0].dataType, inputs[0].dims, components), ++ inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), ++ inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; +- const getShaderSource = (shaderHelper: ShaderHelper) => { +- const uniformsArray: UniformsArrayType = [ +- {name: 'output_size', type: 'u32'}, +- {name: 'components', type: 'u32'}, +- {name: 'hidden_size', type: 'u32'}, +- {name: 'epsilon', type: 'f32'}, +- ]; +- const variables = [ +- inputVariable('x', inputs[0].dataType, inputs[0].dims, components), +- inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), +- inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), +- ]; +- if (hasBetaInput) { +- variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); +- } +- if (hasBiasInput) { +- variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); +- } +- variables.push(outputVariable('output', inputs[0].dataType, 
outputShape, components)); +- if (hasMeanOutput) { +- variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); +- } +- if (hasInvStdDevOutput) { +- variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); +- } +- if (hasInputSkipBiasSumOutput) { +- variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); +- } +- const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); +- return ` +- +- ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} ++ if (hasBetaInput) { ++ variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); ++ } ++ if (hasBiasInput) { ++ variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); ++ } ++ variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); ++ if (hasMeanOutput) { ++ variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); ++ } ++ if (hasInvStdDevOutput) { ++ variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); ++ } ++ if (hasInputSkipBiasSumOutput) { ++ variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); ++ } ++ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); ++ const getShaderSource = (shaderHelper: ShaderHelper) => ` ++ const hiddenSize: f32 = ${hiddenSize}; ++ const hiddenSizeVectorized: u32 = ${hiddenSize / components}; ++ const epsilon: f32 = ${attributes.epsilon}; ++ ++ ${shaderHelper.declareVariables(...variables)} + + ${shaderHelper.mainStart()} +- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')} +- let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; +- let offset = global_idx * hidden_size_vectorized; ++ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} ++ let offset = global_idx * hiddenSizeVectorized; + var sum = ${fillVector('f32', components)}; + var squareSum = ${fillVector('f32', components)}; +- for (var i: u32 = 0; i < hidden_size_vectorized; i++) { +- let skip_value = skip[offset + i]; +- let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'}; +- let input_value = x[offset + i]; +- let value = input_value + skip_value + bias_value; +- ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''} ++ for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { ++ let skipValue = skip[offset + i]; ++ let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; ++ let inputValue = x[offset + i]; ++ let value = inputValue + skipValue + biasValue; ++ ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} + output[offset + i] = value; +- let f32_value = ${castToF32(dataType, components, 'value')}; +- sum += f32_value; +- squareSum += f32_value * f32_value; ++ let f32Value = ${castToF32(dataType, components, 'value')}; ++ sum += f32Value; ++ squareSum += f32Value * f32Value; + } +- let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); +- let inv_std_dev = inverseSqrt(${ +- sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); +- ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} +- ${hasInvStdDevOutput ? 
'inv_std_output[global_idx] = inv_std_dev;' : ''} +- for (var i: u32 = 0; i < hidden_size_vectorized; i++) { +- output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${ +- hasBetaInput ? 'beta[i]' : '0.0'}; ++ let mean = ${sumVector('sum', components)} / hiddenSize; ++ let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); ++ ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} ++ ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''} ++ for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { ++ output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i] ++ + ${hasBetaInput ? 'beta[i]' : '0.0'}; + } + }`; +- }; + const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; + if (outputCount > 1) { + outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); +@@ -164,14 +150,12 @@ const createSkipLayerNormProgramInfo = + if (outputCount > 3) { + outputs.push({dims: inputShape, dataType: inputs[0].dataType}); + } ++ + return { + name: 'SkipLayerNormalization', +- shaderCache: { +- hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, +- inputDependencies: inputs.map((_input, _index) => 'type') +- }, ++ shaderCache: {hint: attributes.cacheKey}, + getShaderSource, +- getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}), ++ getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + }; + }; + +@@ -194,3 +178,8 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm + context.compute( + createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); + }; ++ ++export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { ++ const epsilon = attributes.epsilon as number; ++ return createAttributeWithCacheKey({epsilon}); ++}; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +index a5e71f30e..5212c6475 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +@@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice + ]; + + const programUniforms: ProgramUniform[] = [ +- {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, +- {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, +- ...createTensorShapeVariables(inputs[0].dims, outputShape) ++ {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, ++ {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), ++ ...createTensorShapeVariables(outputShape) + ]; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +index 6f8bfa08d..324dc3af1 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +@@ -5,7 +5,6 @@ + // performance limitations when the reduced axis is long. Need to add + // a optimized codepath for this. 
+ +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -137,7 +136,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut + getRunData: () => ({ + outputs: [{dims: shape, dataType: input.dataType}], + dispatchGroup: {x: rows}, +- programUniforms: [{type: DataType.uint32, data: packedCols}] ++ programUniforms: [{type: 'uint32', data: packedCols}] + }), + getShaderSource, + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts +index 14d6f3792..b8582614f 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts +@@ -1,7 +1,6 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +@@ -73,7 +72,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split + const outputsTensorInfo: TensorInfo[] = []; + const outputShapes: number[][] = []; + let previousSum = 0; +- const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; ++ const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; + for (let i = 0; i < attributes.numOutputs; i++) { + previousSum += attributes.splitSizes[i]; + sizeInSplitAxis[i] = previousSum; +@@ -83,8 +82,9 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split + outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); + } +- programUniforms.push( +- {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes)); ++ programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); ++ programUniforms.push(...createTensorShapeVariables(inputShape)); ++ outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${ + shaderHelper.registerUniform('input_size', 'u32') +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +index f9728575f..90a36a7be 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +@@ -79,8 +79,10 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, +- programUniforms: +- [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], ++ programUniforms: [ ++ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ++ ...createTensorShapeVariables(outputShape) ++ ], + }), + getShaderSource, + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +index 7ae801222..c4d43e9f4 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +@@ -1,13 +1,12 @@ + // Copyright (c) Microsoft Corporation. 
All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../../wasm-common'; + import {TensorView} from '../../tensor-view'; + import {ShapeUtil} from '../../util'; + import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; + import {ComputeContext, ProgramInfo} from '../types'; + +-import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; ++import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; + + export interface TransposeAttributes extends AttributeWithCacheKey { + readonly perm: number[]; +@@ -40,9 +39,12 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu + const inputDataType = inputTensor.dataType; + const inputRank = inputTensor.dims.length; + const perm = getAdjustedPerm(inputRank, permAttr); ++ const useShapesUniforms = enableShapesUniforms(inputRank); + const outputShape = getOutputShape(inputTensor.dims, perm); +- const output = outputVariable('output', inputDataType, outputShape.length); +- const input = inputVariable('a', inputDataType, inputRank); ++ const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; ++ const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; ++ const output = outputVariable('output', inputDataType, outShapeOrRank); ++ const input = inputVariable('a', inputDataType, inShapeOrRank); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} +@@ -59,14 +61,21 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu + }`; + return { + name: 'Transpose', +- shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, ++ shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); + return { + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, +- programUniforms: +- [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], ++ programUniforms: useShapesUniforms ? 
++ [ ++ {type: 'uint32', data: outputSize}, ++ ...createTensorShapeVariables(inputs[0].dims), ++ ...createTensorShapeVariables(outputShape), ++ ] : ++ [ ++ {type: 'uint32', data: outputSize}, ++ ], + }; + }, + getShaderSource, +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +index 1accfac18..a25e7fe42 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +@@ -53,7 +53,7 @@ const createElementwiseProgramInfo = + dispatchGroup: + {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, + programUniforms: [ +- {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, ++ {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + ], + }) + }); +@@ -242,26 +242,6 @@ export const sigmoid = (context: ComputeContext): void => { + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); + }; + +-export interface HardSigmoidAttributes extends AttributeWithCacheKey { +- readonly alpha: number; +- readonly beta: number; +-} +- +-export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => +- createAttributeWithCacheKey(attributes as { +- alpha: number; +- beta: number; +- }); +- +-export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { +- const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); +- context.compute(createElementwiseProgramInfo( +- context.inputs[0], 'HardSigmoid', +- a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ +- attributes.beta})))`, +- undefined, attributes.cacheKey)); +-}; +- + export const sin = (context: ComputeContext): void => { + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); + }; +@@ -279,9 +259,7 @@ export const tan = (context: ComputeContext): void => { + }; + + export const tanh = (context: ComputeContext): void => { +- // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved +- context.compute(createElementwiseProgramInfo( +- context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); ++ context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', 'tanh')); + }; + + export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { +diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts +index cfee07a92..2ef9637bc 100644 +--- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts ++++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts +@@ -97,8 +97,10 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: outputDataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, +- programUniforms: +- [{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)], ++ programUniforms: [ ++ {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), ++ ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) ++ ], + }), + }; + }; +diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts +index 9d05f607f..72eb9713e 100644 +--- 
a/js/web/lib/wasm/jsep/webgpu/program-manager.ts ++++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts +@@ -38,6 +38,7 @@ export class ProgramManager { + const device = this.backend.device; + const computePassEncoder = this.backend.getComputePassEncoder(); + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); ++ computePassEncoder.setPipeline(buildArtifact.computePipeline); + const entries = []; + for (const input of inputs) { + entries.push({binding: entries.length, resource: {buffer: input.buffer}}); +@@ -50,20 +51,8 @@ export class ProgramManager { + } + const bindGroup = device.createBindGroup( + {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); +- +- if (this.backend.sessionStatus === 'capturing') { +- const commandInfo = { +- kernelId: this.backend.currentKernelId!, +- computePipeline: buildArtifact.computePipeline, +- bindGroup, +- dispatchGroup +- }; +- const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); +- sessionCommandList!.push(commandInfo); +- } +- +- computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); ++ + computePassEncoder.dispatchWorkgroups(...dispatchGroup); + this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); + this.backend.pendingDispatchNumber++; +diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts +index ba5b84fcf..e55bfb6ba 100644 +--- a/js/web/lib/wasm/jsep/webgpu/types.ts ++++ b/js/web/lib/wasm/jsep/webgpu/types.ts +@@ -1,13 +1,10 @@ + // Copyright (c) Microsoft Corporation. All rights reserved. + // Licensed under the MIT License. + +-import {DataType} from '../../wasm-common'; + import {TensorView} from '../tensor-view'; + + import {ShaderHelper} from './ops/common'; + +-export type SessionState = 'default'|'capturing'|'replaying'; +- + export enum GpuDataType { + default = 0, + upload = 1, +@@ -27,7 +24,7 @@ export interface TensorInfo { + } + + export interface ProgramUniform { +- type: DataType; ++ type: 'int32'|'float32'|'uint32'; + data: number|readonly number[]; + } + +diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts +index 48eac5749..41ab2d52c 100644 +--- a/js/web/lib/wasm/session-options.ts ++++ b/js/web/lib/wasm/session-options.ts +@@ -168,18 +168,6 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n + setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); + } + +- if (sessionOptions.enableGraphCapture !== undefined) { +- if (typeof sessionOptions.enableGraphCapture !== 'boolean') { +- throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); +- } +- const keyDataOffset = allocWasmString('enableGraphCapture', allocs); +- const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); +- if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { +- checkLastError( +- `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); +- } +- } +- + if (sessionOptions.freeDimensionOverrides) { + for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { + if (typeof name !== 'string') { +diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts +index 37b9ed6a1..5821fac3c 100644 +--- a/js/web/lib/wasm/wasm-core-impl.ts ++++ 
b/js/web/lib/wasm/wasm-core-impl.ts +@@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise => { + * @param epName + */ + export const initEp = async(env: Env, epName: string): Promise => { +- if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { ++ if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + // perform WebGPU availability check + if (typeof navigator === 'undefined' || !navigator.gpu) { + throw new Error('WebGPU is not supported in current environment'); +@@ -139,7 +139,7 @@ type IOBindingState = { + */ + type SessionMetadata = [ + inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], +- bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ++ bindingState: IOBindingState|null + ]; + + const activeSessions = new Map(); +@@ -228,15 +228,13 @@ export const createSession = async( + await Promise.all(loadingPromises); + } + +- sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); ++ sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + if (sessionHandle === 0) { + checkLastError('Can\'t create a session.'); + } + + const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + +- const enableGraphCapture = !!options?.enableGraphCapture; +- + const inputNames = []; + const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; +@@ -258,20 +256,12 @@ export const createSession = async( + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { +- if (enableGraphCapture && options?.preferredOutputLocation === undefined) { +- outputPreferredLocations.push('gpu-buffer'); +- continue; +- } + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } +- if (enableGraphCapture && location !== 'gpu-buffer') { +- throw new Error(`Not supported preferred output location: ${ +- location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); +- } + outputPreferredLocations.push(location); + } + } +@@ -291,9 +281,7 @@ export const createSession = async( + }; + } + +- activeSessions.set( +- sessionHandle, +- [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); ++ activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + return [sessionHandle, inputNames, outputNames]; + } catch (e) { + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); +@@ -325,16 +313,13 @@ export const releaseSession = (sessionId: number): void => { + if (!session) { + throw new Error(`cannot release session. 
invalid session id: ${sessionId}`); + } +- const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; ++ const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + if (ioBindingState) { +- if (enableGraphCapture) { +- wasm._OrtClearBoundOutputs(ioBindingState.handle); +- } + wasm._OrtReleaseBinding(ioBindingState.handle); + } + +- wasm.jsepOnReleaseSession?.(sessionId); ++ wasm.jsepUnregisterBuffers?.(sessionId); + + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); +@@ -343,75 +328,70 @@ export const releaseSession = (sessionId: number): void => { + }; + + export const prepareInputOutputTensor = +- (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, +- enableGraphCapture = false): void => { +- if (!tensor) { +- tensorHandles.push(0); +- return; +- } +- +- const wasm = getInstance(); ++ (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): ++ void => { ++ if (!tensor) { ++ tensorHandles.push(0); ++ return; ++ } + +- const dataType = tensor[0]; +- const dims = tensor[1]; +- const location = tensor[3]; ++ const wasm = getInstance(); + +- let rawData: number; +- let dataByteLength: number; ++ const dataType = tensor[0]; ++ const dims = tensor[1]; ++ const location = tensor[3]; + +- if (dataType === 'string' && location === 'gpu-buffer') { +- throw new Error('String tensor is not supported on GPU.'); +- } ++ let rawData: number; ++ let dataByteLength: number; + +- if (enableGraphCapture && location !== 'gpu-buffer') { +- throw new Error( +- `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); +- } ++ if (dataType === 'string' && location === 'gpu-buffer') { ++ throw new Error('String tensor is not supported on GPU.'); ++ } + +- if (location === 'gpu-buffer') { +- const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; +- const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; +- dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; +- rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); +- } else { +- const data = tensor[2]; +- +- if (Array.isArray(data)) { +- // string tensor +- dataByteLength = 4 * data.length; +- rawData = wasm._malloc(dataByteLength); +- allocs.push(rawData); +- let dataIndex = rawData / 4; +- for (let i = 0; i < data.length; i++) { +- if (typeof data[i] !== 'string') { +- throw new TypeError(`tensor data at index ${i} is not a string`); ++ if (location === 'gpu-buffer') { ++ const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; ++ const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; ++ dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; ++ rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); ++ } else { ++ const data = tensor[2]; ++ ++ if (Array.isArray(data)) { ++ // string tensor ++ dataByteLength = 4 * data.length; ++ rawData = wasm._malloc(dataByteLength); ++ allocs.push(rawData); ++ let dataIndex = rawData / 4; ++ for (let i = 0; i < data.length; i++) { ++ if (typeof data[i] !== 'string') { ++ throw new TypeError(`tensor data at index ${i} is not a string`); ++ } ++ wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); ++ } ++ } else { ++ dataByteLength = data.byteLength; ++ rawData = 
wasm._malloc(dataByteLength); ++ allocs.push(rawData); ++ wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } +- wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + } +- } else { +- dataByteLength = data.byteLength; +- rawData = wasm._malloc(dataByteLength); +- allocs.push(rawData); +- wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); +- } +- } + +- const stack = wasm.stackSave(); +- const dimsOffset = wasm.stackAlloc(4 * dims.length); +- try { +- let dimIndex = dimsOffset / 4; +- dims.forEach(d => wasm.HEAP32[dimIndex++] = d); +- const tensor = wasm._OrtCreateTensor( +- tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, +- dataLocationStringToEnum(location)); +- if (tensor === 0) { +- checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); +- } +- tensorHandles.push(tensor); +- } finally { +- wasm.stackRestore(stack); +- } +- }; ++ const stack = wasm.stackSave(); ++ const dimsOffset = wasm.stackAlloc(4 * dims.length); ++ try { ++ let dimIndex = dimsOffset / 4; ++ dims.forEach(d => wasm.HEAP32[dimIndex++] = d); ++ const tensor = wasm._OrtCreateTensor( ++ tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, ++ dataLocationStringToEnum(location)); ++ if (tensor === 0) { ++ checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); ++ } ++ tensorHandles.push(tensor); ++ } finally { ++ wasm.stackRestore(stack); ++ } ++ }; + + /** + * perform inference run +@@ -424,12 +404,7 @@ export const run = async( + if (!session) { + throw new Error(`cannot run inference. invalid session id: ${sessionId}`); + } +- const sessionHandle = session[0]; +- const inputNamesUTF8Encoded = session[1]; +- const outputNamesUTF8Encoded = session[2]; +- const ioBindingState = session[3]; +- const enableGraphCapture = session[4]; +- const inputOutputBound = session[5]; ++ const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + const inputCount = inputIndices.length; + const outputCount = outputIndices.length; +@@ -452,15 +427,13 @@ export const run = async( + + // create input tensors + for (let i = 0; i < inputCount; i++) { +- prepareInputOutputTensor( +- inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); ++ prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + } + + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( +- outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], +- enableGraphCapture); ++ outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; +@@ -476,7 +449,7 @@ export const run = async( + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + } + +- if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { ++ if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + + if (inputNamesUTF8Encoded.length !== inputCount) { +@@ -513,13 +486,10 @@ export const run = async( + } + } + } +- activeSessions.set( +- sessionId, +- [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, 
enableGraphCapture, true]); + } + +- wasm.jsepOnRunStart?.(sessionHandle); + let errorCode: number; ++ + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + errorCode = await wasm._OrtRunWithBinding( + sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); +@@ -625,12 +595,10 @@ export const run = async( + } + } + +- if (ioBindingState && !enableGraphCapture) { ++ if (ioBindingState) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); +- activeSessions.set( +- sessionId, +- [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); + } ++ + return output; + } finally { + wasm.stackRestore(beforeRunStack); +diff --git a/js/web/package-lock.json b/js/web/package-lock.json +index 41c44aaa2..74cd0d81a 100644 +--- a/js/web/package-lock.json ++++ b/js/web/package-lock.json +@@ -1,12 +1,12 @@ + { + "name": "onnxruntime-web", +- "version": "1.18.0", ++ "version": "1.17.0", + "lockfileVersion": 2, + "requires": true, + "packages": { + "": { + "name": "onnxruntime-web", +- "version": "1.18.0", ++ "version": "1.17.0", + "license": "MIT", + "dependencies": { + "flatbuffers": "^1.12.0", +@@ -49,7 +49,7 @@ + }, + "../common": { + "name": "onnxruntime-common", +- "version": "1.18.0", ++ "version": "1.17.0", + "license": "MIT", + "devDependencies": { + "typedoc": "^0.23.22" +diff --git a/js/web/package.json b/js/web/package.json +index a502c2b6b..047de3829 100644 +--- a/js/web/package.json ++++ b/js/web/package.json +@@ -8,7 +8,7 @@ + "type": "git" + }, + "author": "fs-eire", +- "version": "1.18.0", ++ "version": "1.17.0", + "jsdelivr": "dist/ort.min.js", + "dependencies": { + "flatbuffers": "^1.12.0", +diff --git a/js/web/script/build.ts b/js/web/script/build.ts +index d3652f382..ea0c122cb 100644 +--- a/js/web/script/build.ts ++++ b/js/web/script/build.ts +@@ -44,6 +44,7 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/ + const DEFAULT_DEFINE = { + 'BUILD_DEFS.DISABLE_WEBGL': 'false', + 'BUILD_DEFS.DISABLE_WEBGPU': 'false', ++ 'BUILD_DEFS.DISABLE_WEBNN': 'false', + 'BUILD_DEFS.DISABLE_WASM': 'false', + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', + 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', +@@ -363,6 +364,7 @@ async function main() { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', ++ 'BUILD_DEFS.DISABLE_WEBNN': 'true', + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', + 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', + }, +@@ -395,7 +397,7 @@ async function main() { + // ort.webgpu[.min].js + await addAllWebBuildTasks({ + outputBundleName: 'ort.webgpu', +- define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, ++ define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'}, + }); + // ort.wasm[.min].js + await addAllWebBuildTasks({ +@@ -409,6 +411,7 @@ async function main() { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WASM': 'true', ++ 'BUILD_DEFS.DISABLE_WEBNN': 'true', + }, + }); + // ort.wasm-core[.min].js +@@ -418,6 +421,7 @@ async function main() { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', ++ 'BUILD_DEFS.DISABLE_WEBNN': 'true', + 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', + 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', + }, +@@ -430,6 +434,7 @@ async function main() { + 'BUILD_DEFS.DISABLE_TRAINING': 'false', + 'BUILD_DEFS.DISABLE_WEBGPU': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', ++ 'BUILD_DEFS.DISABLE_WEBNN': 'true', 
+ }, + }); + } +diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts +index ed4dd76a6..8f6c5f6f0 100644 +--- a/js/web/script/test-runner-cli-args.ts ++++ b/js/web/script/test-runner-cli-args.ts +@@ -396,6 +396,10 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs + + const globalEnvFlags = parseGlobalEnvFlags(args); + ++ if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { ++ throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); ++ } ++ + // Options: + // --log-verbose=<...> + // --log-info=<...> +diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts +index 9105c0241..d56792c6e 100644 +--- a/js/web/script/test-runner-cli.ts ++++ b/js/web/script/test-runner-cli.ts +@@ -12,7 +12,6 @@ import * as os from 'os'; + import * as path from 'path'; + import {inspect} from 'util'; + +-import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; + import {bufferToBase64} from '../test/test-shared'; + import {Test} from '../test/test-types'; + +@@ -265,12 +264,10 @@ async function main() { + + let modelUrl: string|null = null; + let cases: Test.ModelTestCase[] = []; +- let externalData: Array<{data: string; path: string}>|undefined; + + npmlog.verbose('TestRunnerCli.Init.Model', `Start to prepare test data from folder: ${testDataRootFolder}`); + + try { +- const maybeExternalDataFiles: Array<[fileNameWithoutExtension: string, size: number]> = []; + for (const thisPath of fs.readdirSync(testDataRootFolder)) { + const thisFullPath = path.join(testDataRootFolder, thisPath); + const stat = fs.lstatSync(thisFullPath); +@@ -285,8 +282,6 @@ async function main() { + } else { + throw new Error('there are multiple model files under the folder specified'); + } +- } else { +- maybeExternalDataFiles.push([path.parse(thisPath).name, stat.size]); + } + } else if (stat.isDirectory()) { + const dataFiles: string[] = []; +@@ -312,34 +307,6 @@ async function main() { + if (modelUrl === null) { + throw new Error('there are no model file under the folder specified'); + } +- // for performance consideration, we do not parse every model. when we think it's likely to have external +- // data, we will parse it. We think it's "likely" when one of the following conditions is met: +- // 1. any file in the same folder has the similar file name as the model file +- // (e.g., model file is "model_abc.onnx", and there is a file "model_abc.pb" or "model_abc.onnx.data") +- // 2. the file size is larger than 1GB +- const likelyToHaveExternalData = maybeExternalDataFiles.some( +- ([fileNameWithoutExtension, size]) => +- path.basename(modelUrl!).startsWith(fileNameWithoutExtension) || size >= 1 * 1024 * 1024 * 1024); +- if (likelyToHaveExternalData) { +- const model = onnx.ModelProto.decode(fs.readFileSync(path.join(testDataRootFolder, path.basename(modelUrl!)))); +- const externalDataPathSet = new Set(); +- for (const initializer of model.graph!.initializer!) 
{ +- if (initializer.externalData) { +- for (const data of initializer.externalData) { +- if (data.key === 'location') { +- externalDataPathSet.add(data.value!); +- } +- } +- } +- } +- externalData = []; +- const externalDataPaths = [...externalDataPathSet]; +- for (const dataPath of externalDataPaths) { +- const fullPath = path.resolve(testDataRootFolder, dataPath); +- const url = path.join(TEST_DATA_BASE, path.relative(TEST_ROOT, fullPath)); +- externalData.push({data: url, path: dataPath}); +- } +- } + } catch (e) { + npmlog.error('TestRunnerCli.Init.Model', `Failed to prepare test data. Error: ${inspect(e)}`); + throw e; +@@ -373,23 +340,9 @@ async function main() { + npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); + npmlog.verbose('TestRunnerCli.Init.Model', ` Backend: ${backend}`); + npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); +- if (externalData) { +- npmlog.verbose('TestRunnerCli.Init.Model', ` External data: ${externalData.length}`); +- for (const data of externalData) { +- npmlog.verbose('TestRunnerCli.Init.Model', ` - ${data.path}`); +- } +- } + npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); + +- return { +- name: path.basename(testDataRootFolder), +- platformCondition, +- modelUrl, +- backend, +- cases, +- ioBinding, +- externalData +- }; ++ return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; + } + + function tryLocateModelTestFolder(searchPattern: string): string { +diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc +index 6a10e3b96..812e9d7c2 100644 +--- a/js/web/test/data/ops/fused-conv.jsonc ++++ b/js/web/test/data/ops/fused-conv.jsonc +@@ -108,327 +108,5 @@ + ] + } + ] +- }, +- { +- "name": "fused conv with clip", +- "operator": "FusedConv", +- "attributes": [ +- { "name": "activation", "data": "Clip", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } +- ], +- "opset": { "domain": "com.microsoft", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], +- "dims": [1, 1, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [400, 470, 600, 600], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "fused conv with HardSigmoid", +- "operator": "FusedConv", +- "attributes": [ +- { "name": "activation", "data": "HardSigmoid", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } +- ], +- "opset": { "domain": "com.microsoft", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], +- "dims": [1, 1, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [0, 0, 1, 1], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "NHWC conv with HardSigmoid", +- "operator": "Conv", +- "attributes": [ +- { "name": "activation", "data": "HardSigmoid", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": 
"activation_params", "data": [2.0, 5.0], "type": "floats" } +- ], +- "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], +- "dims": [1, 3, 3, 1], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [0, 0, 1, 1], +- "dims": [1, 2, 2, 1], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "fused group-conv with HardSigmoid", +- "operator": "FusedConv", +- "attributes": [ +- { "name": "activation", "data": "HardSigmoid", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "group", "data": 3, "type": "int" }, +- { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } +- ], +- "opset": { "domain": "com.microsoft", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [ +- 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, +- 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 +- ], +- "dims": [1, 3, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], +- "dims": [3, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1], +- "dims": [1, 3, 2, 2], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "NHWC group-conv with HardSigmoid", +- "operator": "Conv", +- "attributes": [ +- { "name": "activation", "data": "HardSigmoid", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "group", "data": 3, "type": "int" }, +- { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } +- ], +- "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [ +- 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, +- 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 +- ], +- "dims": [1, 3, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], +- "dims": [3, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], +- "dims": [1, 2, 2, 3], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "fused group-conv with LeakyRelu", +- "operator": "FusedConv", +- "attributes": [ +- { "name": "activation", "data": "LeakyRelu", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "group", "data": 3, "type": "int" }, +- { "name": "activation_params", "data": [2.0], "type": "floats" } +- ], +- "opset": { "domain": "com.microsoft", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [ +- 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, +- 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 +- ], +- "dims": [1, 3, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], +- "dims": [3, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015], +- "dims": [1, 3, 2, 2], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "NHWC group-conv with LeakyRelu", +- "operator": "Conv", +- "attributes": [ +- { "name": "activation", "data": "LeakyRelu", 
"type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "group", "data": 3, "type": "int" }, +- { "name": "activation_params", "data": [2.0], "type": "floats" } +- ], +- "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [ +- 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, +- 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 +- ], +- "dims": [1, 3, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], +- "dims": [3, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609], +- "dims": [1, 2, 2, 3], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "fused conv with LeakyRelu", +- "operator": "FusedConv", +- "attributes": [ +- { "name": "activation", "data": "LeakyRelu", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "activation_params", "data": [2.0], "type": "floats" } +- ], +- "opset": { "domain": "com.microsoft", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], +- "dims": [1, 1, 3, 3], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [-540, -860, 390, 430], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ] +- } +- ] +- }, +- { +- "name": "NHWC conv with LeakyRelu", +- "operator": "Conv", +- "attributes": [ +- { "name": "activation", "data": "LeakyRelu", "type": "string" }, +- { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, +- { "name": "activation_params", "data": [2.0], "type": "floats" } +- ], +- "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, +- "cases": [ +- { +- "name": "T[0]", +- "inputs": [ +- { +- "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], +- "dims": [1, 3, 3, 1], +- "type": "float32" +- }, +- { +- "data": [1, 2, 3, 4], +- "dims": [1, 1, 2, 2], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [-540, -860, 390, 430], +- "dims": [1, 2, 2, 1], +- "type": "float32" +- } +- ] +- } +- ] + } + ] +diff --git a/js/web/test/data/ops/tanh.jsonc b/js/web/test/data/ops/tanh.jsonc +deleted file mode 100644 +index f7691535b..000000000 +--- a/js/web/test/data/ops/tanh.jsonc ++++ /dev/null +@@ -1,26 +0,0 @@ +-[ +- { +- "name": "tanh with no attributes", +- "operator": "Tanh", +- "attributes": [], +- "cases": [ +- { +- "name": "T[2,4]", +- "inputs": [ +- { +- "data": [-1000, -1, 0, 0.1, 0.2, 0.3, 0.4, 1000], +- "dims": [2, 4], +- "type": "float32" +- } +- ], +- "outputs": [ +- { +- "data": [-1, -0.761594, 0, 0.099668, 0.197375, 0.291313, 0.379949, 1], +- "dims": [2, 4], +- "type": "float32" +- } +- ] +- } +- ] +- } +-] +diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc +index 56db28b0a..033b3b3f4 100644 +--- a/js/web/test/suite-test-list.jsonc ++++ b/js/web/test/suite-test-list.jsonc +@@ -597,9 +597,9 @@ + // // "test_hardmax_example", + // // "test_hardmax_negative_axis", + // // "test_hardmax_one_hot", +- "test_hardsigmoid_default", +- "test_hardsigmoid_example", +- "test_hardsigmoid", ++ // // "test_hardsigmoid_default", ++ // // "test_hardsigmoid_example", ++ // // "test_hardsigmoid", + // // "test_hardswish_expanded", + // // "test_hardswish", + "test_if", +@@ -1389,7 +1389,6 @@ + 
"sub.jsonc", + "sub_int32.jsonc", + "tan.jsonc", +- "tanh.jsonc", + "tile.jsonc", + "transpose.jsonc", + "transpose_int32_uint32.jsonc", +diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts +index b01d47478..442cb1bcf 100644 +--- a/js/web/test/test-runner.ts ++++ b/js/web/test/test-runner.ts +@@ -138,8 +138,8 @@ async function loadTensors( + + async function initializeSession( + modelFilePath: string, backendHint: ort.InferenceSession.ExecutionProviderConfig, ioBindingMode: Test.IOBindingMode, +- profile: boolean, externalData: ort.InferenceSession.SessionOptions['externalData'], +- sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { ++ profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, ++ fileCache?: FileCacheBuffer): Promise { + const preloadModelData: Uint8Array|undefined = + fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; + Logger.verbose( +@@ -153,8 +153,7 @@ async function initializeSession( + executionProviders: [backendHint], + profiler: profilerConfig, + enableProfiling: profile, +- preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, +- externalData ++ preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + }; + + let session: ort.InferenceSession; +@@ -247,8 +246,8 @@ export class ModelTestContext { + const executionProviderConfig = + modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || 'webnn') : modelTest.backend!; + const session = await initializeSession( +- modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, modelTest.externalData, +- testOptions?.sessionOptions || {}, this.cache); ++ modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, testOptions?.sessionOptions || {}, ++ this.cache); + + const initEnd = now(); + +diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts +index 14b9fd7c0..cd008e82e 100644 +--- a/js/web/test/test-types.ts ++++ b/js/web/test/test-types.ts +@@ -65,7 +65,6 @@ export declare namespace Test { + export interface ModelTest { + name: string; + modelUrl: string; +- externalData?: InferenceSession.SessionOptions['externalData']; + backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; + platformCondition?: PlatformCondition; +diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py +index c3699f0fb..57219c50f 100644 +--- a/onnxruntime/__init__.py ++++ b/onnxruntime/__init__.py +@@ -7,7 +7,7 @@ ONNX Runtime is a performance-focused scoring engine for Open Neural Network Exc + For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ + or the `Github project `_. + """ +-__version__ = "1.18.0" ++__version__ = "1.17.0" + __author__ = "Microsoft" + + # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). 
+diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +index 166f5c8f5..72948c74d 100644 +--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc ++++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +@@ -9,7 +9,6 @@ + #include "core/mlas/inc/mlas_q4.h" + #include "core/providers/cpu/math/matmul_helper.h" + #include "core/providers/common.h" +- + #ifdef ORT_NEURAL_SPEED + #include "contrib_ops/cpu/quantization/neural_speed_gemm.h" + #endif +@@ -17,39 +16,6 @@ + namespace onnxruntime { + namespace contrib { + +-namespace { +-int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { +- const auto accuracy_level = std::clamp(accuracy_level_attr, +- static_cast(CompMostAccurate), +- static_cast(CompLeastAccurate)); +- +-#if defined(ORT_NEURAL_SPEED) +- +- ORT_UNUSED_PARAMETER(nbits); +- ORT_UNUSED_PARAMETER(block_size); +- +- // Neural Speed APIs already expect a minimum accuracy level so just use the given value. +- return accuracy_level; +- +-#else // defined(ORT_NEURAL_SPEED) +- +- // Find a supported accuracy level that is not less accurate than the one given. +- // CompMostAccurate is always supported with the fallback implementation. +- // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. +- int64_t effective_accuracy_level = accuracy_level; +- for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { +- const auto compute_type = static_cast(effective_accuracy_level); +- if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { +- break; +- } +- } +- +- return effective_accuracy_level; +- +-#endif // defined(ORT_NEURAL_SPEED) +-} +-} // namespace +- + class MatMulNBits final : public OpKernel { + public: + MatMulNBits(const OpKernelInfo& info) +@@ -58,7 +24,7 @@ class MatMulNBits final : public OpKernel { + N_{narrow(info.GetAttr("N"))}, + block_size_{narrow(info.GetAttr("block_size"))}, + nbits_{narrow(info.GetAttr("bits"))}, +- accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ++ accuracy_level_{info.GetAttr("accuracy_level")} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + #ifdef ORT_NEURAL_SPEED +@@ -92,22 +58,17 @@ class MatMulNBits final : public OpKernel { + const bool column_wise_quant_{true}; + IAllocatorUniquePtr packed_b_; + size_t packed_b_size_{0}; +- +-#if defined(ORT_NEURAL_SPEED) +- ++#ifdef ORT_NEURAL_SPEED + bool is_asym_{false}; + bool all_constant_{false}; +- +-#endif // defined(ORT_NEURAL_SPEED) ++#endif + }; + + Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; +- +-#if defined(ORT_NEURAL_SPEED) +- ++#ifdef ORT_NEURAL_SPEED + if (!all_constant_) { + return Status::OK(); + } +@@ -155,17 +116,11 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat + #else // defined(ORT_NEURAL_SPEED) + + if (input_idx == 1) { +- const auto compute_type = static_cast(accuracy_level_); +- if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { +- return Status::OK(); +- } +- packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); +- if (packed_b_size_ == 0) { +- return Status::OK(); +- } ++ packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, 
nbits_, block_size_); ++ if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.DataRaw(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); +- MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); ++ MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get()); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); +@@ -181,9 +136,7 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat + Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; +- +-#if defined(ORT_NEURAL_SPEED) +- ++#ifdef ORT_NEURAL_SPEED + // Pack three tensors into one buffer + if (input_idx == 1) { + used_shared_buffers = true; +@@ -206,7 +159,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep + } + + #endif // defined(ORT_NEURAL_SPEED) +- + return Status::OK(); + } + +@@ -215,10 +167,8 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { + + const Tensor* a = ctx->Input(0); + const auto* a_data = a->Data(); +- +-#if defined(ORT_NEURAL_SPEED) +- +- if (packed_b_) { ++#ifdef ORT_NEURAL_SPEED ++ if (packed_b_.get()) { + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; +@@ -284,43 +234,37 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { + const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), + [](size_t offset) { return offset == 0; }); + +- if (has_single_b_matrix) { +- const auto compute_type = static_cast(accuracy_level_); +- +- if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { +- IAllocatorUniquePtr workspace{}; +- if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, +- nbits_, block_size_, compute_type); +- workspace_size > 0) { +- AllocatorPtr allocator; +- ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); +- workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); +- } +- +- const void* b_data = [&]() -> const void* { +- if (packed_b_) { +- return packed_b_.get(); ++ if (has_single_b_matrix && packed_b_) { ++ for (int64_t accuracy_level = accuracy_level_; ++ accuracy_level >= static_cast(CompMostAccurate); ++ --accuracy_level) { ++ const auto compute_type = static_cast(accuracy_level); ++ if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { ++ IAllocatorUniquePtr workspace{}; ++ if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, ++ nbits_, block_size_, compute_type); ++ workspace_size > 0) { ++ AllocatorPtr allocator; ++ ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); ++ workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } + +- const Tensor* b = ctx->Input(1); +- return b->DataRaw(); +- }(); +- +- InlinedVector data(batch_count); +- for (size_t i = 0; i < batch_count; ++i) { +- data[i].A = a_data + helper.LeftOffsets()[i]; +- data[i].lda = lda; +- data[i].QuantBData = b_data; +- data[i].QuantBScale = scales_data; +- data[i].QuantBZeroPoint = zero_points_data; +- data[i].C = y_data + helper.OutputOffsets()[i]; +- data[i].ldc = N; +- } ++ InlinedVector data(batch_count); ++ for (size_t i = 0; i < batch_count; ++i) { ++ data[i].A = a_data + helper.LeftOffsets()[i]; ++ data[i].lda = lda; ++ 
data[i].QuantBData = packed_b_.get(); ++ data[i].QuantBScale = scales_data; ++ data[i].QuantBZeroPoint = zero_points_data; ++ data[i].C = y_data + helper.OutputOffsets()[i]; ++ data[i].ldc = N; ++ } + +- MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), +- thread_pool); ++ MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), ++ thread_pool); + +- return Status::OK(); ++ return Status::OK(); ++ } + } + } + +diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +index b18e12298..56d950ca2 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h ++++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +@@ -258,7 +258,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch + cpu_state.sequences.InitDevice(beam_state.sequences_device); + ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), + cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), +- this->ort_stream_, ++ nullptr, + DeviceCopyDirection::hostToDevice)); + } + +@@ -397,8 +397,12 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch + output_sequences_scores); + + // Output per token scores +- gsl::span per_token_scores = beam_state.scores; +- this->beam_scorer_->OutputScores(per_token_scores, output_scores); ++ if (output_scores) { ++ gsl::span target = output_scores->MutableDataAsSpan(); ++ gsl::span source = beam_state.scores; ++ assert(target.size() == source.size()); ++ ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); ++ } + + return status; + } +diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +index 8f5cdc97f..94547887d 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h ++++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +@@ -214,7 +214,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches + cpu_state.sequences.InitDevice(beam_state.sequences_device); + ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), + cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), +- this->ort_stream_, ++ nullptr, + DeviceCopyDirection::hostToDevice)); + } + +@@ -404,8 +404,12 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches + output_sequences_scores); + + // Output per token scores +- gsl::span per_token_scores = beam_state.scores; +- this->beam_scorer_->OutputScores(per_token_scores, output_scores); ++ if (output_scores) { ++ gsl::span target = output_scores->MutableDataAsSpan(); ++ gsl::span source = beam_state.scores; ++ assert(target.size() == source.size()); ++ ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); ++ } + + return status; + } +diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +index 72e6d3930..91b93a125 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h ++++ 
b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +@@ -226,7 +226,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe + cpu_state.sequences.InitDevice(beam_state.sequences_device); + ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), + cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), +- this->ort_stream_, ++ nullptr, + DeviceCopyDirection::hostToDevice)); + } + +@@ -500,8 +500,12 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe + output_sequences_scores); + + // Output per token scores +- gsl::span per_token_scores = beam_state.scores; +- this->beam_scorer_->OutputScores(per_token_scores, output_scores); ++ if (output_scores) { ++ gsl::span target = output_scores->MutableDataAsSpan(); ++ gsl::span source = beam_state.scores; ++ assert(target.size() == source.size()); ++ ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); ++ } + + return status; + } +diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +index bb6885c32..3962486d5 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc ++++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +@@ -123,20 +123,8 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { + logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; + ORT_ENFORCE(logits_processor >= 0, + "logits_processor shall be a non-negative integer, got ", logits_processor); +- +- if (this->model_type == IGenerationParameters::kModelTypeWhisper) { +- auto* temperature_tensor = context->Input(14); +- if (temperature_tensor) { +- if (temperature_tensor->IsDataType()) { +- temperature = *temperature_tensor->Data(); +- } else { +- temperature = static_cast(*temperature_tensor->Data()); +- } +- } else { +- temperature = 1.0f; +- } +- } + } ++ + void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { + // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF + // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) +diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +index 0eccbe266..7e2e5b212 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc ++++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +@@ -50,12 +50,11 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con + return beams_.back().score < current_score; + } + +-template + void BeamHypotheses::Output( + int top_k, + int max_length, +- gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) +- gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty ++ gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) ++ gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty + { + // Copy the top_k beams into the sequences + ORT_ENFORCE(top_k <= beams_used_); +@@ -68,7 +67,7 @@ void BeamHypotheses::Output( + gsl::copy(item.hypothesis, target); + + if (!sequences_scores.empty()) +- 
sequences_scores[index] = (T)item.score; ++ sequences_scores[index] = item.score; + } + } + +@@ -182,21 +181,21 @@ void BeamSearchScorer::Process(ISequences& sequences, + } + } + +-template +-void OutputSequenceScores(BeamSearchScorer* scorer, +- ISequences& sequences, +- gsl::span& final_beam_scores, +- Tensor* output_sequences, +- Tensor* output_sequence_scores) { ++void BeamSearchScorer::Finalize(ISequences& sequences, ++ gsl::span& final_beam_scores, ++ Tensor* output_sequences, ++ Tensor* output_sequence_scores) { ++ ORT_ENFORCE(output_sequences != nullptr); ++ + // Finalize all open beam hypotheses and add to generated hypotheses. +- for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { +- BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; ++ for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { ++ BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + if (beam_hyp.done_) { + continue; + } + +- for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) { +- size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index; ++ for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) { ++ size_t batch_beam_index = batch_index * num_beams_ + beam_index; + float final_score = final_beam_scores[batch_beam_index]; + auto final_tokens = sequences.GetSequence(narrow(batch_beam_index)); + beam_hyp.Add(final_tokens, final_score); +@@ -207,59 +206,26 @@ void OutputSequenceScores(BeamSearchScorer* scorer, + gsl::span output = output_sequences->MutableDataAsSpan(); + + // Fill output sequences with pad token ID so that we do not need append it later. +- std::fill_n(output.data(), output.size(), scorer->pad_token_id_); ++ std::fill_n(output.data(), output.size(), pad_token_id_); + + // Score of each sequence, with shape (batch_size * num_return_sequences). +- gsl::span sequence_scores; ++ gsl::span sequence_scores; + if (output_sequence_scores) { +- sequence_scores = output_sequence_scores->MutableDataAsSpan(); ++ sequence_scores = output_sequence_scores->MutableDataAsSpan(); + } + + // Select the best hypotheses according to number of sequences to return. 
+- for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) { +- BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index]; ++ for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) { ++ BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; + +- auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_, +- scorer->num_return_sequences_ * scorer->max_length_); +- gsl::span sequence_scores_buffer; ++ auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_, ++ num_return_sequences_ * max_length_); ++ gsl::span sequence_scores_buffer; + if (!sequence_scores.empty()) +- sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_); +- +- beam_hyp.template Output(narrow(scorer->num_return_sequences_), narrow(scorer->max_length_), batch_output, +- sequence_scores_buffer); +- } +-} +- +-void BeamSearchScorer::Finalize(ISequences& sequences, +- gsl::span& final_beam_scores, +- Tensor* output_sequences, +- Tensor* output_sequence_scores) { +- ORT_ENFORCE(output_sequences != nullptr); ++ sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_); + +- if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { +- OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); +- } else { +- ORT_ENFORCE(output_sequence_scores->IsDataType()); +- OutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); +- } +-} +- +-void BeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { +- if (output_scores) { +- if (output_scores->IsDataType()) { +- gsl::span target = output_scores->MutableDataAsSpan(); +- ORT_ENFORCE(target.size() == final_scores.size()); +- std::copy_n(final_scores.data(), final_scores.size(), target.data()); +- } else { +- ORT_ENFORCE(output_scores->IsDataType()); +- gsl::span target = output_scores->MutableDataAsSpan(); +- ORT_ENFORCE(target.size() == final_scores.size()); +- const float* src = final_scores.data(); +- MLFloat16* dst = target.data(); +- for (size_t i = 0; i < target.size(); i++) { +- dst[i] = MLFloat16(src[i]); +- } +- } ++ beam_hyp.Output(narrow(num_return_sequences_), narrow(max_length_), batch_output, ++ sequence_scores_buffer); + } + } + +diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +index dc92e8038..94b6d340d 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h ++++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +@@ -35,11 +35,10 @@ struct BeamHypotheses { + bool CanImprove(float best_sum_logprobs, int current_length) const; + + // Output results +- template +- void Output(int top_k, // number of sequences to return +- int max_length, // max sequence length +- gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) +- gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) ++ void Output(int top_k, // number of sequences to return ++ int max_length, // max sequence length ++ gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) ++ gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + + gsl::span beams_; // Beam width sized array of hypotheses, sorted by 
highest scoring + int beams_used_; // Number of elements used in beams_ +@@ -61,14 +60,13 @@ struct BeamSearchScorer : IBeamScorer { + Tensor* output_sequences, + Tensor* output_sequence_scores) override; + +- void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; +- + bool IsDone() const override { return not_done_count_ == 0; } + + gsl::span GetNextScores() override { return next_beam_scores_; } + gsl::span GetNextTokens() override { return next_beam_tokens_; } + gsl::span GetNextIndicesCPU() override { return next_beam_indices_; } + ++ private: + size_t batch_size_; + size_t num_beams_; + size_t max_length_; +diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +index cb62e2f7b..f6faf2e32 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h ++++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +@@ -120,9 +120,6 @@ struct IBeamScorer { + Tensor* output_sequences, + Tensor* output_sequence_scores) = 0; + +- virtual void OutputScores(gsl::span& final_scores, +- Tensor* output_scores) = 0; +- + virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event + virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchous result to complete here + +diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +index c74e9160c..f39f090c7 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc ++++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +@@ -17,6 +17,14 @@ namespace onnxruntime { + namespace contrib { + namespace transformers { + ++#ifdef DEBUG_GENERATION ++template ++void DumpScores(const char* name, const NextTokenScores& next_token_scores) { ++ std::cout << name << std::endl; ++ ORT_UNUSED_PARAMETER(next_token_scores); ++} ++#endif ++ + // Interface for all scorers for beam search or beam sample. + template + MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) +@@ -28,6 +36,10 @@ void MinLengthLogitsProcessor::Process(const ISequences* sequences, + if (sequences->GetSequenceLength() < min_length_) { + next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("MinLengthLogitsProcessor", next_token_scores); ++#endif + } + + template +@@ -56,6 +68,10 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, + beam_token_scores[word_id] = (score < 0 ? 
score * penalty_ : score / penalty_); + } + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores); ++#endif + } + + template +@@ -93,6 +109,10 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, + beam_token_scores[word_id] = std::numeric_limits::lowest(); + } + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores); ++#endif + } + + template +@@ -116,6 +136,10 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, + } + } + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("VocabMaskLogitsProcessor", next_token_scores); ++#endif + } + + template +@@ -147,6 +171,10 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, + } + } + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores); ++#endif + } + + template +@@ -165,6 +193,10 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, + *p /= temperature_; + ++p; + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("TemperatureLogitsProcessor", next_token_scores); ++#endif + } + + template +@@ -186,6 +218,10 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, + for (size_t i = 0; i < next_token_scores.scores.size(); i++) { + *p -= presence_mask_[i] * presence_penalty_; + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("PresencePenaltyLogitsProcessor", next_token_scores); ++#endif + } + + void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { +diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +index 03d4e89ac..4688ff272 100644 +--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h ++++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +@@ -265,6 +265,10 @@ class TimestampLogitsProcessor : public ILogitsProcessor { + } + } + } ++ ++#ifdef DEBUG_GENERATION ++ DumpScores("TimestampLogitsProcessor", next_token_scores); ++#endif + } + + private: +diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +index dea5391c7..87e88ac31 100644 +--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc ++++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +@@ -24,8 +24,7 @@ namespace { + + template + struct DispatchGroupNorm { +- Status operator()(CudaTuningContext* tuning_ctx, +- Stream* ort_stream, ++ Status operator()(cudaStream_t stream, + Tensor* output, + Tensor* add_out, + const Tensor* input, +@@ -45,8 +44,7 @@ struct DispatchGroupNorm { + int channels_per_block) { + typedef typename ToCudaType::MappedType CudaT; + return LaunchGroupNormKernel( +- tuning_ctx, +- ort_stream, ++ stream, + reinterpret_cast(output->MutableData()), + add_out == nullptr ? 
nullptr : reinterpret_cast(add_out->MutableData()), + reinterpret_cast(input->Data()), +@@ -211,8 +209,7 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { + context->GetComputeStream()); + + utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); +- return dispatcher.InvokeRet(GetTuningContext(), +- context->GetComputeStream(), output, add_out, input, skip, bias, ++ return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), + epsilon_, + batch_size, +diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc +deleted file mode 100644 +index 5dec69052..000000000 +--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.cc ++++ /dev/null +@@ -1,101 +0,0 @@ +-/* +- * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +- * SPDX-License-Identifier: Apache-2.0 +- * +- * Licensed under the Apache License, Version 2.0 (the "License"); +- * you may not use this file except in compliance with the License. +- * You may obtain a copy of the License at +- * +- * http://www.apache.org/licenses/LICENSE-2.0 +- * +- * Unless required by applicable law or agreed to in writing, software +- * distributed under the License is distributed on an "AS IS" BASIS, +- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +- * See the License for the specific language governing permissions and +- * limitations under the License. +- */ +- +-// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +-// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +-// Copyright (c) Microsoft Corporation. All rights reserved. +-// Licensed under the MIT License. +- +-#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" +- +-using namespace onnxruntime::cuda; +- +-namespace onnxruntime { +-namespace contrib { +-namespace cuda { +- +-int NextSize(int x) { +- for (size_t i = 0; i < kNumOfSizes; ++i) { +- if (x <= kSizes[i]) { +- return kSizes[i]; +- } +- } +- +- return x; +-} +- +-int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread) { +- return NextSize(channels_per_block) / channels_per_thread; +-} +- +-int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { +- int32_t max_divisor = -1; +- for (int32_t i = 1; i <= std::sqrt(n); i++) { +- if (n % i == 0) { +- int32_t divisor1 = n / i; +- int32_t divisor2 = i; +- +- if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { +- max_divisor = divisor1; +- } +- if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { +- max_divisor = divisor2; +- } +- } +- } +- return max_divisor; +-} +- +-// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +-// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +-// work to do so it is ideal case. 
+-int FindChannelsPerBlock(int num_channels, int channels_per_group) { +- int min_cost = -1; +- int best_candidate = -1; +- for (size_t i = kNumOfSizes; i > 0; --i) { +- if (kSizes[i - 1] < channels_per_group) { +- break; +- } +- +- int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; +- int blocks = (num_channels + channels_per_block - 1) / channels_per_block; +- int cost = blocks * kSizes[i - 1] - num_channels; +- if (cost == 0) { +- return channels_per_block; +- } +- +- if (min_cost == -1 || cost < min_cost) { +- min_cost = cost; +- best_candidate = channels_per_block; +- } +- } +- +- return best_candidate; +-} +- +-int GetChannelsPerBlock(int num_channels, int num_groups) { +- int32_t channels_per_group = num_channels / num_groups; +- int32_t channels_per_block = channels_per_group; +- if (channels_per_group < kMaxSize / 2) { +- channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); +- } +- return channels_per_block; +-} +- +-} // namespace cuda +-} // namespace contrib +-} // namespace onnxruntime +diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h +deleted file mode 100644 +index ea87d0c29..000000000 +--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h ++++ /dev/null +@@ -1,186 +0,0 @@ +-/* +- * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +- * SPDX-License-Identifier: Apache-2.0 +- * +- * Licensed under the Apache License, Version 2.0 (the "License"); +- * you may not use this file except in compliance with the License. +- * You may obtain a copy of the License at +- * +- * http://www.apache.org/licenses/LICENSE-2.0 +- * +- * Unless required by applicable law or agreed to in writing, software +- * distributed under the License is distributed on an "AS IS" BASIS, +- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +- * See the License for the specific language governing permissions and +- * limitations under the License. +- */ +- +-// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +-// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +-// Copyright (c) Microsoft Corporation. All rights reserved. +-// Licensed under the MIT License. +-#pragma once +-#include "core/providers/cuda/cuda_common.h" +-using namespace onnxruntime::cuda; +- +-namespace onnxruntime { +-namespace contrib { +-namespace cuda { +- +-// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +-constexpr static int32_t CHANNELS_PER_THREAD = 2; +- +-constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +-constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +-constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; +- +-int32_t GetThreadsPerBlock(int32_t channels_per_block, int32_t channels_per_thread); +- +-static inline int32_t DivUp(int32_t m, int32_t n) { +- return (m + n - 1) / n; +-} +- +-int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor); +- +-int GetChannelsPerBlock(int num_channels, int num_groups); +- +-template +-struct GroupNormNHWCParams { +- // The output buffer. Shape is (n, h, w, c). +- T* dst; +- +- // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). +- T* add_out; +- +- // The input buffer. Shape is (n, h, w, c). +- T const* src; +- +- // Optional input buffer for skip tensor. 
Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). +- T const* skip; +- +- // Optional input buffer for bias tensor. Shape is (c). +- T const* bias; +- +- // The gamma scaling factor. +- float const* gamma; +- +- // The beta term to add in GN. +- float const* beta; +- +- // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. +- float* group_sum_buffer; +- +- // The number of instances in the batch. +- int32_t n; +- +- // The height and width of each activation map. +- int32_t h; +- int32_t w; +- +- // Number of channels. +- int32_t c; +- +- // Number of groups. +- int32_t groups; +- +- // Do we apply the SiLU activation function? +- bool use_silu; +- +- // Precomputed values and parameters to control the execution of the kernels. +- +- // Number of activations per instance (h * w) +- int32_t hw; +- +- // Number of activations per block +- int32_t hw_per_block; +- +- // Number of channels per block in the C dimension. +- int32_t channels_per_block; +- +- // Number of channels per group in the C dimension. +- int32_t channels_per_group; +- +- // The precomputed stride between instances. +- int32_t hwc; +- // The inverse of hw*channels_per_group to compute mean of a group. +- float inv_hw_channels_per_group; +- // The precomputed number of groups per block. +- int32_t groups_per_block; +- +- // Number of threads per block +- int32_t threads_per_block; +- +- // Epsilon to get stable variance in normalization. +- float epsilon; +- +- // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. +- bool broadcast_skip; +- +- // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. +- T* skip_workspace; +- +- GroupNormNHWCParams(T* output, +- T* add_out, +- const T* input, +- const T* skip, +- const T* bias, +- const float* gamma, +- const float* beta, +- float* workspace, +- float epsilon, +- int batch_size, +- int num_channels, +- int height, +- int width, +- int num_groups, +- bool use_silu, +- bool broadcast_skip, +- int channels_per_block) { +- int32_t channels_per_group = num_channels / num_groups; +- // channels_per_block is computed in PrePack. +- // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. +- if (channels_per_block < channels_per_group) { +- channels_per_block = GetChannelsPerBlock(num_channels, num_groups); +- } +- +- this->use_silu = use_silu; +- this->dst = output; +- this->add_out = add_out; +- this->src = input; +- this->skip = skip; +- this->bias = bias; +- this->gamma = gamma; +- this->beta = beta; +- this->group_sum_buffer = workspace; +- this->n = batch_size; +- this->h = height; +- this->w = width; +- this->c = num_channels; +- this->groups = num_groups; +- this->hw = this->h * this->w; +- +- // This will allocate as many blocks as possible to partition HW. +- // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. +- // TODO: tune this logic to find proper blocks when hw is small. 
+- constexpr int32_t max_blocks_per_hw = 1024; +- const int32_t blocks_per_hw = FindMaxDivisor(this->hw, max_blocks_per_hw); +- this->hw_per_block = DivUp(this->hw, blocks_per_hw); +- +- this->channels_per_block = channels_per_block; +- this->channels_per_group = channels_per_group; +- this->hwc = this->hw * this->c; +- this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); +- this->groups_per_block = channels_per_block / this->channels_per_group; +- this->epsilon = epsilon; +- this->broadcast_skip = broadcast_skip; +- +- // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. +- this->skip_workspace = (this->add_out != nullptr) ? this->add_out : this->dst; +- +- this->threads_per_block = GetThreadsPerBlock(channels_per_block, CHANNELS_PER_THREAD); +- } +-}; +- +-} // namespace cuda +-} // namespace contrib +-} // namespace onnxruntime +diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +index 4909dc5e3..48b161552 100644 +--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu ++++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +@@ -27,8 +27,6 @@ + #include "core/providers/cuda/cu_inc/common.cuh" + #include "contrib_ops/cuda/diffusion/group_norm_impl.h" + #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +-#include "contrib_ops/cuda/diffusion/group_norm_common_base.h" +-#include "contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh" + + using namespace onnxruntime::cuda; + +@@ -36,6 +34,329 @@ namespace onnxruntime { + namespace contrib { + namespace cuda { + ++namespace { ++ ++// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. ++constexpr static int32_t CHANNELS_PER_THREAD = 2; ++ ++constexpr static int kSizes[] = {128, 256, 320, 384, 512}; ++constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); ++constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; ++ ++int NextSize(int x) { ++ for (size_t i = 0; i < kNumOfSizes; ++i) { ++ if (x <= kSizes[i]) { ++ return kSizes[i]; ++ } ++ } ++ ++ return x; ++} ++} // namespace ++ ++static inline int32_t DivUp(int32_t m, int32_t n) { ++ return (m + n - 1) / n; ++} ++ ++static inline __device__ __host__ float sigmoid(float x) { ++ return 1.F / (1.F + expf(-x)); ++} ++ ++struct GroupSums { ++ // Is it the 1st element of the group? ++ int32_t flag; ++ // The sum. ++ float sum; ++ // The sum of squares. ++ float sum_sq; ++}; ++ ++struct GroupSumsOp { ++ inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { ++ GroupSums dst; ++ dst.sum = b.flag ? b.sum : (a.sum + b.sum); ++ dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); ++ dst.flag = a.flag + b.flag; ++ return dst; ++ } ++}; ++ ++template ++struct GroupNormNHWCParams { ++ // The output buffer. Shape is (n, h, w, c). ++ T* dst; ++ ++ // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). ++ T* add_out; ++ ++ // The input buffer. Shape is (n, h, w, c). ++ T const* src; ++ ++ // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). ++ T const* skip; ++ ++ // Optional input buffer for bias tensor. Shape is (c). ++ T const* bias; ++ ++ // The gamma scaling factor. ++ float const* gamma; ++ ++ // The beta term to add in GN. ++ float const* beta; ++ ++ // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. 
++ float* group_sum_buffer; ++ ++ // The number of instances in the batch. ++ int32_t n; ++ ++ // The height and width of each activation map. ++ int32_t h; ++ int32_t w; ++ ++ // Number of channels. ++ int32_t c; ++ ++ // Number of groups. ++ int32_t groups; ++ ++ // Do we apply the SiLU activation function? ++ bool use_silu; ++ ++ // Precomputed values and parameters to control the execution of the kernels. ++ ++ // Number of activations per instance (h * w) ++ int32_t hw; ++ ++ // Number of activations per block ++ int32_t hw_per_block; ++ ++ // Number of channels per block in the C dimension. ++ int32_t channels_per_block; ++ ++ // Number of channels per group in the C dimension. ++ int32_t channels_per_group; ++ ++ // The precomputed stride between instances. ++ int32_t hwc; ++ // The inverse of hw*channels_per_group to compute mean of a group. ++ float inv_hw_channels_per_group; ++ // The precomputed number of groups per block. ++ int32_t groups_per_block; ++ ++ // Number of threads per block ++ int32_t threads_per_block; ++ ++ // Epsilon to get stable variance in normalization. ++ float epsilon; ++ ++ // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. ++ bool broadcast_skip; ++ ++ // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. ++ T* skip_workspace; ++}; ++ ++template ++inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); ++ ++template <> ++inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { ++ // Fetch two channels per thread. ++ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); ++ ++ float2 f2 = __half22float2(h2); ++ ++ // Update the sum. ++ sum += f2.x + f2.y; ++ ++ // Update the sum of squares. ++ sum_sq += f2.x * f2.x + f2.y * f2.y; ++} ++ ++template <> ++inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { ++ // Fetch two channels per thread. ++ float2 f2 = *reinterpret_cast(&src[offset]); ++ ++ // Update the sum. ++ sum += f2.x + f2.y; ++ ++ // Update the sum of squares. ++ sum_sq += f2.x * f2.x + f2.y * f2.y; ++} ++ ++// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] ++template ++inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, ++ int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); ++ ++template <> ++inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, ++ int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { ++ // Fetch two channels per thread. 
++ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); ++ __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); ++ __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); ++ h2 = h2 + b; ++ h2 = h2 + s; ++ ++ *reinterpret_cast<__half2*>(&add_out[offset]) = h2; ++ ++ float2 f2 = __half22float2(h2); ++ sum += f2.x + f2.y; ++ sum_sq += f2.x * f2.x + f2.y * f2.y; ++} ++ ++template <> ++inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, ++ int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { ++ float2 f2 = *reinterpret_cast(&src[offset]); ++ float2 s = *reinterpret_cast(&skip[skip_offset]); ++ float2 b = *reinterpret_cast(&bias[bias_offset]); ++ f2.x += s.x + b.x; ++ f2.y += s.y + b.y; ++ ++ *reinterpret_cast(&add_out[offset]) = f2; ++ ++ sum += f2.x + f2.y; ++ sum_sq += f2.x * f2.x + f2.y * f2.y; ++} ++ ++// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] ++template ++inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, ++ int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); ++ ++template <> ++inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, ++ int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { ++ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); ++ __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); ++ h2 = h2 + s; ++ ++ *reinterpret_cast<__half2*>(&add_out[offset]) = h2; ++ ++ float2 f2 = __half22float2(h2); ++ sum += f2.x + f2.y; ++ sum_sq += f2.x * f2.x + f2.y * f2.y; ++} ++ ++template <> ++inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, ++ int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { ++ float2 f2 = *reinterpret_cast(&src[offset]); ++ float2 s = *reinterpret_cast(&skip[skip_offset]); ++ f2.x += s.x; ++ f2.y += s.y; ++ *reinterpret_cast(&add_out[offset]) = f2; ++ sum += f2.x + f2.y; ++ sum_sq += f2.x * f2.x + f2.y * f2.y; ++} ++ ++template ++__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { ++ // The object in charge of doing the sums for the different blocks. ++ typedef cub::BlockScan BlockScan; ++ ++ // Allocate shared memory for BlockScan. ++ __shared__ typename BlockScan::TempStorage temp_storage; ++ ++ // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. ++ __shared__ float2 smem[THREADS_PER_BLOCK]; ++ ++ // The instance in the batch. ++ int32_t ni = blockIdx.z; ++ ++ // The channel loaded by that thread. ++ int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; ++ ++ if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { ++ return; ++ } ++ ++ // The first activation loaded by that block. ++ int32_t hw_begin = blockIdx.y * params.hw_per_block; ++ // The last activation loaded by that block. ++ int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); ++ ++ // The sums. ++ float sum = 0.F; ++ float sum_sq = 0.F; ++ ++ // Iterate over the activations to compute the sums. 
++ int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; ++ if (params.skip != nullptr) { ++ // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) ++ const int64_t bias_offset = static_cast(ci); ++ T* add_out = params.skip_workspace; ++ if (params.broadcast_skip) { ++ const int64_t skip_offset = static_cast(ni) * params.c + ci; ++ ++ if (params.bias != nullptr) { ++ for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { ++ AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); ++ } ++ } else { ++ for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { ++ AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); ++ } ++ } ++ } else { ++ if (params.bias != nullptr) { ++ for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { ++ AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); ++ } ++ } else { ++ for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { ++ AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); ++ } ++ } ++ } ++ } else { // GroupNorm ++ for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { ++ UpdateSum(params.src, offset, sum, sum_sq); ++ } ++ } ++ ++ // The group index relative to the first group within the same block. ++ int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; ++ // The channel in the group. ++ int32_t cj = ci % params.channels_per_group; ++ ++ // The data for the summations. ++ GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; ++ ++ // Do the segmented scan. InclusiveScan is not deterministic. ++ GroupSums out; ++ BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); ++ ++ // Store the results for the groups in shared memory (to produce coalesced stores later). ++ // For each group, only the last thread of that group is picked to save sum to shared memory. ++ if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { ++ smem[gi] = make_float2(out.sum, out.sum_sq); ++ } ++ ++ // Make sure the data is in shared memory. ++ __syncthreads(); ++ ++ // Threads that have nothing left to do, exit. ++ if (threadIdx.x >= params.groups_per_block) { ++ return; ++ } ++ ++ // The global group index. ++ // Use neighboring threads for coalesced write. ++ int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; ++ ++ if (gj < params.groups) { ++ float2 sums = smem[threadIdx.x]; ++ const int index = (2 * ni) * params.groups + gj; ++ atomicAdd(¶ms.group_sum_buffer[index], sums.x); ++ atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); ++ } ++} ++ + template + void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { + dim3 grid; +@@ -49,26 +370,119 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) + // The number of instances. + grid.z = params.n; + +-#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ +- GroupNormNHWCSumKernel \ +- <<>>( \ +- params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \ +- params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \ +- params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \ +- break; +- + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params.threads_per_block) { + case 256: +- LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) ++ GroupNormNHWCSumKernel<<>>(params); ++ break; + case 192: +- LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) ++ GroupNormNHWCSumKernel<<>>(params); ++ break; + case 160: +- LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) ++ GroupNormNHWCSumKernel<<>>(params); ++ break; + case 128: +- LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) ++ GroupNormNHWCSumKernel<<>>(params); ++ break; + case 64: +- LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) ++ GroupNormNHWCSumKernel<<>>(params); ++ break; ++ } ++} ++ ++template ++__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, ++ float2& gamma_f2, float2& beta_f2, bool silu); ++ ++template <> ++__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, ++ float2& gamma_f2, float2& beta_f2, bool silu) { ++ // Fetch two channels per thread. ++ __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); ++ ++ // Extract the two half values. ++ float2 f2 = __half22float2(h2); ++ ++ // Normalize the channels. ++ f2.x = (f2.x - mean) * inv_std_dev; ++ f2.y = (f2.y - mean) * inv_std_dev; ++ ++ // Scale by gamma and add beta. ++ f2.x = gamma_f2.x * f2.x + beta_f2.x; ++ f2.y = gamma_f2.y * f2.y + beta_f2.y; ++ ++ // Apply SiLU activation if needed. ++ if (silu) { ++ f2.x = f2.x * sigmoid(f2.x); ++ f2.y = f2.y * sigmoid(f2.y); ++ } ++ ++ *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); ++} ++ ++template <> ++__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, ++ float2& gamma_f2, float2& beta_f2, bool silu) { ++ // Fetch two channels per thread. ++ float2 f2 = *reinterpret_cast(&src[offset]); ++ ++ // Normalize the channels. ++ f2.x = (f2.x - mean) * inv_std_dev; ++ f2.y = (f2.y - mean) * inv_std_dev; ++ ++ // Scale by gamma and add beta. ++ f2.x = gamma_f2.x * f2.x + beta_f2.x; ++ f2.y = gamma_f2.y * f2.y + beta_f2.y; ++ ++ // Apply SiLU activation if needed. ++ if (silu) { ++ f2.x = f2.x * sigmoid(f2.x); ++ f2.y = f2.y * sigmoid(f2.y); ++ } ++ ++ *reinterpret_cast(&dst[offset]) = f2; ++} ++ ++template ++__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { ++ // The channel loaded by that thread. ++ int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; ++ if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { ++ return; ++ } ++ ++ // The instance in the batch. ++ int32_t ni = blockIdx.z; ++ ++ // The group that thread works on. ++ int32_t gi = ci / params.channels_per_group; ++ ++ // Load the sum and sum of squares for the group. ++ float sum = 0.F, sum_sq = 0.F; ++ if (gi < params.groups) { ++ const int index = (2 * ni) * params.groups + gi; ++ sum = params.group_sum_buffer[index]; ++ sum_sq = params.group_sum_buffer[index + params.groups]; ++ } ++ ++ // Load gamma/beta. Fetch two per thread. ++ float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); ++ float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); ++ ++ // Compute the mean. ++ float mean = sum * params.inv_hw_channels_per_group; ++ // Compute the variance. ++ float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); ++ // Compute the inverse of the stddev. 
++ float inv_std_dev = rsqrtf(var + params.epsilon); ++ ++ int32_t hw_begin = blockIdx.y * params.hw_per_block; ++ int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); ++ ++ const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; ++ int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; ++ for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { ++ ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); + } + } + +@@ -83,34 +497,83 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea + // The number of instances. + grid.z = params.n; + +-#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ +- GroupNormNHWCScaleKernel \ +- <<>>( \ +- params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace, \ +- params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group, \ +- params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block, \ +- params.use_silu); \ +- break; +- + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: +- LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) ++ GroupNormNHWCScaleKernel<<>>(params); ++ break; + case 192: +- LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) ++ GroupNormNHWCScaleKernel<<>>(params); ++ break; + case 160: +- LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) ++ GroupNormNHWCScaleKernel<<>>(params); ++ break; + case 128: +- LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) ++ GroupNormNHWCScaleKernel<<>>(params); ++ break; + case 64: +- LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) ++ GroupNormNHWCScaleKernel<<>>(params); ++ break; + } + } + ++int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { ++ int32_t max_divisor = -1; ++ for (int32_t i = 1; i <= std::sqrt(n); i++) { ++ if (n % i == 0) { ++ int32_t divisor1 = n / i; ++ int32_t divisor2 = i; ++ ++ if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { ++ max_divisor = divisor1; ++ } ++ if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { ++ max_divisor = divisor2; ++ } ++ } ++ } ++ return max_divisor; ++} ++ ++// Find proper channels per block based on a cost function: The cost is number of channels corresponding to ++// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has ++// work to do so it is ideal case. 
++int FindChannelsPerBlock(int num_channels, int channels_per_group) { ++ int min_cost = -1; ++ int best_candidate = -1; ++ for (size_t i = kNumOfSizes; i > 0; --i) { ++ if (kSizes[i - 1] < channels_per_group) { ++ break; ++ } ++ ++ int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; ++ int blocks = (num_channels + channels_per_block - 1) / channels_per_block; ++ int cost = blocks * kSizes[i - 1] - num_channels; ++ if (cost == 0) { ++ return channels_per_block; ++ } ++ ++ if (min_cost == -1 || cost < min_cost) { ++ min_cost = cost; ++ best_candidate = channels_per_block; ++ } ++ } ++ ++ return best_candidate; ++} ++ ++int GetChannelsPerBlock(int num_channels, int num_groups) { ++ int32_t channels_per_group = num_channels / num_groups; ++ int32_t channels_per_block = channels_per_group; ++ if (channels_per_group < kMaxSize / 2) { ++ channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); ++ } ++ return channels_per_block; ++} ++ + template + Status LaunchGroupNormKernel( +- CudaTuningContext* tuning_ctx, +- Stream* ort_stream, ++ cudaStream_t stream, + T* output, + T* add_out, + const T* input, +@@ -128,17 +591,19 @@ Status LaunchGroupNormKernel( + bool use_silu, + bool broadcast_skip, + int channels_per_block) { ++ GroupNormNHWCParams params; + +- // tuning_ctx only used for ROCm EP. +- ORT_UNUSED_PARAMETER(tuning_ctx); +- +- GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast(workspace), epsilon, +- batch_size, num_channels, height, width, num_groups, use_silu, +- broadcast_skip, channels_per_block); ++ int32_t channels_per_group = num_channels / num_groups; ++ // channels_per_block is computed in PrePack. ++ // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. ++ if (channels_per_block < channels_per_group) { ++ channels_per_block = GetChannelsPerBlock(num_channels, num_groups); ++ } + +- if (params.channels_per_block % params.channels_per_group != 0 || +- params.channels_per_block > kMaxSize || +- (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { ++ // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases ++ if (channels_per_block % channels_per_group != 0 || ++ channels_per_block > kMaxSize || ++ (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, +@@ -147,7 +612,42 @@ Status LaunchGroupNormKernel( + " groups=", num_groups); + } + +- auto stream = static_cast(ort_stream->GetHandle()); ++ params.use_silu = use_silu; ++ params.dst = output; ++ params.add_out = add_out; ++ params.src = input; ++ params.skip = skip; ++ params.bias = bias; ++ params.gamma = gamma; ++ params.beta = beta; ++ params.group_sum_buffer = reinterpret_cast(workspace); ++ params.n = batch_size; ++ params.h = height; ++ params.w = width; ++ params.c = num_channels; ++ params.groups = num_groups; ++ params.hw = params.h * params.w; ++ ++ // This will allocate as many blocks as possible to partition HW. ++ // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. ++ // TODO: tune this logic to find proper blocks when hw is small. 
++ constexpr int32_t max_blocks_per_hw = 1024; ++ const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); ++ params.hw_per_block = DivUp(params.hw, blocks_per_hw); ++ ++ params.channels_per_block = channels_per_block; ++ params.channels_per_group = channels_per_group; ++ params.hwc = params.hw * params.c; ++ params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); ++ params.groups_per_block = channels_per_block / params.channels_per_group; ++ params.epsilon = epsilon; ++ params.broadcast_skip = broadcast_skip; ++ ++ // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. ++ params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; ++ ++ params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; ++ + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + +@@ -163,14 +663,14 @@ Status LaunchGroupNormKernel( + return Status::OK(); + } + +-template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out, ++template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); + +-template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out, ++template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, + float epsilon, int batch_size, int num_channels, +diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +index 98f38a147..9532aeecb 100644 +--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h ++++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +@@ -8,8 +8,6 @@ + #include + #include + +-#include "core/providers/cuda/tunable/cuda_tunable.h" +- + namespace onnxruntime { + namespace contrib { + namespace cuda { +@@ -23,8 +21,7 @@ int GetChannelsPerBlock(int num_channels, int num_groups); + + template + Status LaunchGroupNormKernel( +- CudaTuningContext* tuning_ctx, +- Stream* ort_stream, ++ cudaStream_t stream, + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) +diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh +deleted file mode 100644 +index ecd06315e..000000000 +--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh ++++ /dev/null +@@ -1,451 +0,0 @@ +-/* +- * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +- * SPDX-License-Identifier: Apache-2.0 +- * +- * Licensed under the Apache License, Version 2.0 (the "License"); +- * you may not use this file except in compliance with the License. 
+- * You may obtain a copy of the License at +- * +- * http://www.apache.org/licenses/LICENSE-2.0 +- * +- * Unless required by applicable law or agreed to in writing, software +- * distributed under the License is distributed on an "AS IS" BASIS, +- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +- * See the License for the specific language governing permissions and +- * limitations under the License. +- */ +- +-// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +-// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +-// Copyright (c) Microsoft Corporation. All rights reserved. +-// Licensed under the MIT License. +-#pragma once +-#include +-#include +-#include "core/providers/cuda/cuda_common.h" +-#include "core/providers/cuda/cu_inc/common.cuh" +- +-using namespace onnxruntime::cuda; +- +-namespace onnxruntime { +-namespace contrib { +-namespace cuda { +- +-static inline __device__ __host__ float sigmoid(float x) { +- return 1.F / (1.F + expf(-x)); +-} +- +-struct GroupSums { +- // Is it the 1st element of the group? +- int32_t flag; +- // The sum. +- float sum; +- // The sum of squares. +- float sum_sq; +-}; +- +-struct GroupSumsOp { +- inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { +- GroupSums dst; +- dst.sum = b.flag ? b.sum : (a.sum + b.sum); +- dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); +- dst.flag = a.flag + b.flag; +- return dst; +- } +-}; +- +-template +-inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq) { +- using VecT = onnxruntime::cuda::aligned_vector; +- const VecT input_v = *reinterpret_cast(src + offset); +- +-#pragma unroll +- for (int i = 0; i < ILP; i++) { +- const float val = static_cast(input_v.val[i]); +- sum += val; +- sum_sq += val * val; +- } +-} +- +-template <> +-inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { +- // Fetch two channels per thread. +- __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); +- +- float2 f2 = __half22float2(h2); +- +- // Update the sum. +- sum += f2.x + f2.y; +- +- // Update the sum of squares. +- sum_sq += f2.x * f2.x + f2.y * f2.y; +-} +- +-template <> +-inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { +- // Fetch two channels per thread. +- float2 f2 = *reinterpret_cast(&src[offset]); +- +- // Update the sum. +- sum += f2.x + f2.y; +- +- // Update the sum of squares. 
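
The UpdateSum helpers in the deleted kernel above all reduce to the same bookkeeping: a running sum and sum of squares per (batch, group), later turned into a mean and an inverse standard deviation. A host-side reference, not part of the patch and with illustrative names, shows that math end to end for one NHWC slice.

    // Host-side reference only; the CUDA kernels accumulate the same two quantities per thread.
    #include <cmath>
    #include <cstdint>
    #include <vector>

    struct GroupStatsSketch {
      float mean;
      float inv_std_dev;
    };

    GroupStatsSketch ComputeGroupStatsSketch(const std::vector<float>& nhwc_slice, int64_t hw, int64_t c,
                                             int64_t group_begin_channel, int64_t channels_per_group,
                                             float epsilon) {
      float sum = 0.f;
      float sum_sq = 0.f;
      for (int64_t hwi = 0; hwi < hw; ++hwi) {
        for (int64_t ci = group_begin_channel; ci < group_begin_channel + channels_per_group; ++ci) {
          const float val = nhwc_slice[hwi * c + ci];
          sum += val;           // what UpdateSum accumulates
          sum_sq += val * val;  // and its sum-of-squares counterpart
        }
      }
      const float inv_count = 1.f / static_cast<float>(hw * channels_per_group);
      const float mean = sum * inv_count;
      const float var = sum_sq * inv_count - mean * mean;
      return {mean, 1.f / std::sqrt(var + epsilon)};
    }

    int main() {
      std::vector<float> x(8 * 4, 1.0f);  // hw = 8, c = 4, constant input
      GroupStatsSketch s = ComputeGroupStatsSketch(x, 8, 4, /*group_begin_channel=*/0,
                                                   /*channels_per_group=*/2, 1e-5f);
      return (s.mean > 0.99f && s.mean < 1.01f) ? 0 : 1;  // mean of a constant input is that constant
    }
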
+- sum_sq += f2.x * f2.x + f2.y * f2.y; +-} +- +-// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +-template +-inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, +- int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +- using VecT = onnxruntime::cuda::aligned_vector; +- const VecT input_v = *reinterpret_cast(src + offset); +- const VecT skip_v = *reinterpret_cast(skip + skip_offset); +- const VecT bias_v = *reinterpret_cast(bias + bias_offset); +- VecT output_v = *reinterpret_cast(add_out + offset); +- +-#pragma unroll +- for (int i = 0; i < ILP; i++) { +- output_v.val[i] = input_v.val[i] + skip_v.val[i] + bias_v.val[i]; +- const float val = static_cast(output_v.val[i]); +- sum += val; +- sum_sq += val * val; +- } +- *(reinterpret_cast(add_out + offset)) = output_v; +-} +- +-template <> +-inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, +- int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +- // Fetch two channels per thread. +- __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); +- __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); +- __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); +- h2 = h2 + b; +- h2 = h2 + s; +- +- *reinterpret_cast<__half2*>(&add_out[offset]) = h2; +- +- float2 f2 = __half22float2(h2); +- sum += f2.x + f2.y; +- sum_sq += f2.x * f2.x + f2.y * f2.y; +-} +- +-template <> +-inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, +- int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +- float2 f2 = *reinterpret_cast(&src[offset]); +- float2 s = *reinterpret_cast(&skip[skip_offset]); +- float2 b = *reinterpret_cast(&bias[bias_offset]); +- f2.x += s.x + b.x; +- f2.y += s.y + b.y; +- +- *reinterpret_cast(&add_out[offset]) = f2; +- +- sum += f2.x + f2.y; +- sum_sq += f2.x * f2.x + f2.y * f2.y; +-} +- +-// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +-template +-inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, +- int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +- using VecT = onnxruntime::cuda::aligned_vector; +- const VecT input_v = *reinterpret_cast(src + offset); +- const VecT skip_v = *reinterpret_cast(skip + skip_offset); +- VecT output_v = *reinterpret_cast(add_out + offset); +- +-#pragma unroll +- for (int i = 0; i < ILP; i++) { +- output_v.val[i] = input_v.val[i] + skip_v.val[i]; +- const float val = static_cast(output_v.val[i]); +- sum += val; +- sum_sq += val * val; +- } +- *(reinterpret_cast(add_out + offset)) = output_v; +-} +- +-template <> +-inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, +- int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +- __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); +- __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); +- h2 = h2 + s; +- +- *reinterpret_cast<__half2*>(&add_out[offset]) = h2; +- +- float2 f2 = __half22float2(h2); +- sum += f2.x + f2.y; +- sum_sq += f2.x * f2.x + f2.y * f2.y; +-} +- +-template <> +-inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, +- int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +- float2 f2 = *reinterpret_cast(&src[offset]); +- float2 s = 
*reinterpret_cast(&skip[skip_offset]); +- f2.x += s.x; +- f2.y += s.y; +- *reinterpret_cast(&add_out[offset]) = f2; +- sum += f2.x + f2.y; +- sum_sq += f2.x * f2.x + f2.y * f2.y; +-} +- +-template +-__global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffer, const T* src, const T* skip, const T* bias, +- int32_t channels_per_block, int32_t hw_per_block, int32_t hw, int32_t hwc, int32_t c, +- int32_t channels_per_group, int32_t groups, int32_t groups_per_block, bool broadcast_skip) { +- // The object in charge of doing the sums for the different blocks. +- typedef cub::BlockScan BlockScan; +- +- // Allocate shared memory for BlockScan. +- __shared__ typename BlockScan::TempStorage temp_storage; +- +- // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. +- __shared__ float2 smem[THREADS_PER_BLOCK]; +- +- // The instance in the batch. +- int32_t ni = blockIdx.z; +- +- // The channel loaded by that thread. +- int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; +- +- if (ci >= c || threadIdx.x * ILP >= channels_per_block) { +- return; +- } +- +- // The first activation loaded by that block. +- int32_t hw_begin = blockIdx.y * hw_per_block; +- // The last activation loaded by that block. +- int32_t hw_end = min(hw_begin + hw_per_block, hw); +- +- // The sums. +- float sum = 0.F; +- float sum_sq = 0.F; +- +- // Iterate over the activations to compute the sums. +- int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; +- if (skip != nullptr) { +- // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) +- const int64_t bias_offset = static_cast(ci); +- T* add_out = skip_workspace; +- if (broadcast_skip) { +- const int64_t skip_offset = static_cast(ni) * c + ci; +- +- if (bias != nullptr) { +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- AddSkipBias(add_out, src, skip, bias, offset, skip_offset, bias_offset, sum, sum_sq); +- } +- } else { +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- AddSkip(add_out, src, skip, offset, skip_offset, sum, sum_sq); +- } +- } +- } else { +- if (bias != nullptr) { +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- AddSkipBias(add_out, src, skip, bias, offset, offset, bias_offset, sum, sum_sq); +- } +- } else { +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- AddSkip(add_out, src, skip, offset, offset, sum, sum_sq); +- } +- } +- } +- } else { // GroupNorm +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- UpdateSum(src, offset, sum, sum_sq); +- } +- } +- +- // The group index relative to the first group within the same block. +- int32_t gi = threadIdx.x * ILP / channels_per_group; +- // The channel in the group. +- int32_t cj = ci % channels_per_group; +- +- // The data for the summations. +- GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; +- +- // Do the segmented scan. InclusiveScan is not deterministic. +- GroupSums out; +- BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); +- +- // Store the results for the groups in shared memory (to produce coalesced stores later). +- // For each group, only the last thread of that group is picked to save sum to shared memory. +- if (cj == channels_per_group - ILP) { +- smem[gi] = make_float2(out.sum, out.sum_sq); +- } +- +- // Make sure the data is in shared memory. +- __syncthreads(); +- +- // Threads that have nothing left to do, exit. 
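
The sum kernel above combines per-thread partial sums with a segmented cub::BlockScan, where flag == 1 marks the first channel of a group and the combine operator restarts accumulation at those boundaries, so the last element of each segment carries that group's totals. A sequential host-side stand-in, not part of the patch, demonstrates the same operator.

    // Host-side sketch only: sequential stand-in for the segmented InclusiveScan.
    #include <cstdio>
    #include <vector>

    struct GroupSumsSketch {
      int flag;    // 1 on the first element of a group
      float sum;
      float sum_sq;
    };

    GroupSumsSketch Combine(const GroupSumsSketch& a, const GroupSumsSketch& b) {
      GroupSumsSketch dst;
      dst.sum = b.flag ? b.sum : (a.sum + b.sum);        // restart the sum at a group boundary
      dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq);
      dst.flag = a.flag + b.flag;
      return dst;
    }

    int main() {
      // Two groups of three "threads" each; values 1..6 so the group sums are 6 and 15.
      std::vector<GroupSumsSketch> in = {
          {1, 1.f, 1.f}, {0, 2.f, 4.f}, {0, 3.f, 9.f},
          {1, 4.f, 16.f}, {0, 5.f, 25.f}, {0, 6.f, 36.f}};
      std::vector<GroupSumsSketch> out(in.size());
      out[0] = in[0];
      for (size_t i = 1; i < in.size(); ++i) out[i] = Combine(out[i - 1], in[i]);
      std::printf("group0 sum=%g group1 sum=%g\n", out[2].sum, out[5].sum);  // prints 6 and 15
      return 0;
    }
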
+- if (threadIdx.x >= groups_per_block) { +- return; +- } +- +- // The global group index. +- // Use neighboring threads for coalesced write. +- int32_t gj = blockIdx.x * groups_per_block + threadIdx.x; +- +- if (gj < groups) { +- float2 sums = smem[threadIdx.x]; +- const int index = (2 * ni) * groups + gj; +- atomicAdd(&group_sum_buffer[index], sums.x); +- atomicAdd(&group_sum_buffer[index + groups], sums.y); +- } +-} +- +-template +-__device__ void computeGroupNormVec(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, +- const float* gamma_v, const float* beta_v, bool silu) { +- using VecT = onnxruntime::cuda::aligned_vector; +- const VecT input_v = *reinterpret_cast(src + offset); +- VecT output_v; +- +-#pragma unroll +- for (int i = 0; i < ILP; i++) { +- float val = static_cast(input_v.val[i]); +- val = (val - mean) * inv_std_dev; +- val = gamma_v[i] * val + beta_v[i]; +- +- if (silu) { +- val = val * sigmoid(val); +- } +- output_v.val[i] = static_cast(val); +- } +- *(reinterpret_cast(dst + offset)) = output_v; +-} +- +-template +-__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, +- float2& gamma_f2, float2& beta_f2, bool silu); +- +-template <> +-__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, +- float2& gamma_f2, float2& beta_f2, bool silu) { +- // Fetch two channels per thread. +- __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); +- +- // Extract the two half values. +- float2 f2 = __half22float2(h2); +- +- // Normalize the channels. +- f2.x = (f2.x - mean) * inv_std_dev; +- f2.y = (f2.y - mean) * inv_std_dev; +- +- // Scale by gamma and add beta. +- f2.x = gamma_f2.x * f2.x + beta_f2.x; +- f2.y = gamma_f2.y * f2.y + beta_f2.y; +- +- // Apply SiLU activation if needed. +- if (silu) { +- f2.x = f2.x * sigmoid(f2.x); +- f2.y = f2.y * sigmoid(f2.y); +- } +- +- *reinterpret_cast<__half2*>(&dst[offset]) = __float22half2_rn(f2); +-} +- +-template <> +-__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, +- float2& gamma_f2, float2& beta_f2, bool silu) { +- // Fetch two channels per thread. +- float2 f2 = *reinterpret_cast(&src[offset]); +- +- // Normalize the channels. +- f2.x = (f2.x - mean) * inv_std_dev; +- f2.y = (f2.y - mean) * inv_std_dev; +- +- // Scale by gamma and add beta. +- f2.x = gamma_f2.x * f2.x + beta_f2.x; +- f2.y = gamma_f2.y * f2.y + beta_f2.y; +- +- // Apply SiLU activation if needed. +- if (silu) { +- f2.x = f2.x * sigmoid(f2.x); +- f2.y = f2.y * sigmoid(f2.y); +- } +- +- *reinterpret_cast(&dst[offset]) = f2; +-} +- +-template +-__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, +- const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { +- using VecF = onnxruntime::cuda::aligned_vector; +- +- const VecF gamma_v = *reinterpret_cast(gamma + ci); +- const VecF beta_v = *reinterpret_cast(beta + ci); +- // Iterate over the activations to compute the sums. +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- // Fetch ILP channels per thread. 
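
The per-element math in the ComputeGroupNorm specializations above is a plain affine normalization followed by an optional SiLU (x * sigmoid(x)). A host-side reference version, not part of the patch, spells it out for a single value.

    // Host-side reference only for the per-element normalize + SiLU step.
    #include <cmath>
    #include <cstdio>

    inline float SigmoidRef(float x) { return 1.f / (1.f + std::exp(-x)); }

    float GroupNormElementRef(float x, float mean, float inv_std_dev,
                              float gamma, float beta, bool use_silu) {
      float val = (x - mean) * inv_std_dev;  // normalize with the group's statistics
      val = gamma * val + beta;              // per-channel scale and shift
      if (use_silu) {
        val = val * SigmoidRef(val);         // optional SiLU activation
      }
      return val;
    }

    int main() {
      // With mean 0, unit inverse std, gamma 1 and beta 0, SiLU(1.0) is roughly 0.731.
      std::printf("%f\n", GroupNormElementRef(1.0f, 0.f, 1.f, 1.f, 0.f, true));
      return 0;
    }
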
+- computeGroupNormVec(input, dst, offset, mean, inv_std_dev, gamma_v.val, beta_v.val, use_silu); +- } +-} +- +-template <> +-__device__ void ComputeGroupNormKernel(const float* input, float* dst, int64_t offset, float mean, float inv_std_dev, +- const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { +- // Load gamma/beta. Fetch two per thread. +- float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); +- float2 beta_f2 = *reinterpret_cast(&beta[ci]); +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); +- } +-} +- +-template <> +-__device__ void ComputeGroupNormKernel(const half* input, half* dst, int64_t offset, float mean, float inv_std_dev, +- const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { +- // Load gamma/beta. Fetch two per thread. +- float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); +- float2 beta_f2 = *reinterpret_cast(&beta[ci]); +- for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { +- ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); +- } +-} +- +-template +-__global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, const float* gamma, const float* beta, +- const T* skip_workspace, const float* group_sum_buffer, float epsilon, +- int32_t c, int32_t channels_per_block, int32_t channels_per_group, +- int32_t groups, int32_t hwc, float inv_hw_channels_per_group, +- int32_t hw, int32_t hw_per_block, bool use_silu) { +- // The channel loaded by that thread. +- int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; +- if (ci >= c || threadIdx.x * ILP >= channels_per_block) { +- return; +- } +- +- // The instance in the batch. +- int32_t ni = blockIdx.z; +- +- // The group that thread works on. +- int32_t gi = ci / channels_per_group; +- +- // Load the sum and sum of squares for the group. +- float sum = 0.F, sum_sq = 0.F; +- if (gi < groups) { +- const int index = (2 * ni) * groups + gi; +- sum = group_sum_buffer[index]; +- sum_sq = group_sum_buffer[index + groups]; +- } +- +- // Compute the mean. +- float mean = sum * inv_hw_channels_per_group; +- // Compute the variance. +- float var = sum_sq * inv_hw_channels_per_group - (mean * mean); +- // Compute the inverse of the stddev. +- float inv_std_dev = rsqrtf(var + epsilon); +- +- int32_t hw_begin = blockIdx.y * hw_per_block; +- int32_t hw_end = min(hw_begin + hw_per_block, hw); +- +- const T* input = (skip != nullptr) ? 
skip_workspace : src; +- int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; +- ComputeGroupNormKernel(input, dst, offset, mean, inv_std_dev, gamma, beta, use_silu, c, ci, hw_begin, hw_end); +-} +- +-} // namespace cuda +-} // namespace contrib +-} // namespace onnxruntime +diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +index 08cbb145a..2a90e4911 100644 +--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc ++++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +@@ -49,7 +49,6 @@ ONNX_OPERATOR_KERNEL_EX( + .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU +- .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), +diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +index 4cfa89a4d..b31f5d243 100644 +--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc ++++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +@@ -203,34 +203,33 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { + DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { ++void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { + if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, true); ++ DumpGpuTensor(name, tensor, dim0, dim1, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { ++void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { + if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, true); ++ DumpGpuTensor(name, tensor, dim0, dim1, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { ++void CudaTensorConsoleDumper::Print(const char* name, const size_t* tensor, int dim0, int dim1) const { + if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); ++ DumpGpuTensor(name, tensor, dim0, dim1, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { +- if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, true); ++void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { ++ Print(name, reinterpret_cast(tensor), dim0, dim1); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { ++void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { + if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); ++ DumpGpuTensor(name, tensor, dim0, dim1, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { ++void CudaTensorConsoleDumper::Print(const char* name, const 
int32_t* tensor, int dim0, int dim1) const { + if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, true); ++ DumpGpuTensor(name, tensor, dim0, dim1, true); + } + + void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { +@@ -243,11 +242,6 @@ void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int d + DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { +- if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, true); +-} +- + void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { + if (is_enabled_) + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); +@@ -258,31 +252,22 @@ void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, i + DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const { +- if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, true); ++void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { ++ Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const { +- if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); ++void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { ++ Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const { ++void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { + if (is_enabled_) +- DumpGpuTensor(name, tensor, dim0, dim1, dim2, dim3, true); +-} +- +-void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1) const { +- Print(name, reinterpret_cast(tensor), dim0, dim1); +-} +- +-void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const { +- Print(name, reinterpret_cast(tensor), dim0, dim1, dim2); ++ DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); + } + +-void CudaTensorConsoleDumper::Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const { +- Print(name, reinterpret_cast(tensor), dim0, dim1, dim2, dim3); ++void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { ++ if (is_enabled_) ++ DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); + } + + void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { +@@ -316,52 +301,43 @@ void CudaTensorConsoleDumper::Print(const char* name, const std::string& value, + } + + #else +-void CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { ++void 
CudaTensorConsoleDumper::Print(const char*, const size_t*, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const float*, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int) const { + } + + void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { +-} +- +-void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int) const { +-} +- + void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { +-} +- +-void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const int64_t*, int, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const BFloat16*, int, int, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const int32_t*, int, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const half*, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const float*, int, int, int, int) const { + } + +-void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int) const { ++void CudaTensorConsoleDumper::Print(const char*, const MLFloat16*, int, int, int, int) const { + } + + void CudaTensorConsoleDumper::Print(const char*, const half*, int, int, int, int) const { +diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +index 773401f79..264ecd7cf 100644 +--- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h ++++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +@@ -16,31 +16,20 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons + public: + CudaTensorConsoleDumper() = default; + virtual ~CudaTensorConsoleDumper() {} +- ++ void Print(const char* name, const float* tensor, int dim0, int dim1) const override; ++ void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const size_t* tensor, int dim0, int dim1) const override; +- +- void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; +- void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; +- ++ void Print(const char* name, const half* tensor, int dim0, int dim1) const; + void Print(const char* name, const int64_t* tensor, int dim0, int dim1) const override; +- void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; +- +- void Print(const char* name, const float* tensor, int dim0, int dim1) const override; ++ void Print(const char* name, const int32_t* tensor, int dim0, int dim1) const override; + void 
Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const float* tensor, int dim0, int dim1, int dim2, int dim3) const; +- +- void Print(const char* name, const half* tensor, int dim0, int dim1) const; +- void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; +- void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; +- +- void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; +- +- void Print(const char* name, const BFloat16* tensor, int dim0, int dim1) const; +- void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2) const; +- void Print(const char* name, const BFloat16* tensor, int dim0, int dim1, int dim2, int dim3) const; +- ++ void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2) const; ++ void Print(const char* name, const half* tensor, int dim0, int dim1, int dim2, int dim3) const; ++ void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const override; ++ void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const override; + void Print(const char* name, const Tensor& value) const override; + void Print(const char* name, const OrtValue& value) const override; + void Print(const char* name, int index, bool end_line) const override; +diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +index a39abefed..dbd7fb010 100644 +--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu ++++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +@@ -307,13 +307,12 @@ __device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_ + return beams_[beams_count_ - 1].score < current_score; + } + +-template + __device__ void BeamHypotheses::Output( + int top_k, + int max_length, + int pad_token_id, + int32_t* sequences, // buffer of shape (num_return_sequences, max_length) +- T* sequences_scores) // buffer of shape (num_return_sequences) or empty ++ float* sequences_scores) // buffer of shape (num_return_sequences) or empty + { + // Copy the top_k beams into the sequences + for (int index = 0; index < top_k; index++) { +@@ -328,7 +327,7 @@ __device__ void BeamHypotheses::Output( + target[i] = pad_token_id; + + if (sequences_scores) +- sequences_scores[index] = (T)item.score; ++ sequences_scores[index] = item.score; + } + } + +@@ -502,14 +501,13 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp + next_beam_tokens.data()); + } + +-template + __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, + const int32_t* sequences_buffer, + int sequence_length, + BeamHypotheses* beam_hyps_, + const float* final_beam_scores, + int32_t* output, +- T* sequence_scores) { ++ float* sequence_scores) { + int batch_index = blockIdx.x * blockDim.x + threadIdx.x; + if (batch_index >= state.batch_size_) + return; +@@ -536,7 +534,6 @@ __global__ void BeamSearchScorer_Finalize(BeamScorerState& state, + sequence_scores ? 
sequence_scores + batch_index * state.num_return_sequences_ : nullptr); + } + +-template + void LaunchBeamSearchScorer_Finalize(int batch_size, + BeamScorerState& state, + gsl::span sequences, +@@ -544,7 +541,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, + gsl::span beam_hyps, + gsl::span final_beam_scores, + gsl::span output, +- gsl::span sequence_scores, ++ gsl::span sequence_scores, + cudaStream_t stream) { + BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state, + sequences.data(), +@@ -555,58 +552,6 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, + sequence_scores.data()); + } + +-template void LaunchBeamSearchScorer_Finalize( +- int batch_size, +- BeamScorerState& state, +- gsl::span sequences, +- int sequence_length, +- gsl::span beam_hyps, +- gsl::span final_beam_scores, +- gsl::span output, +- gsl::span sequence_scores, +- cudaStream_t stream); +- +-template void LaunchBeamSearchScorer_Finalize<__half>( +- int batch_size, +- BeamScorerState& state, +- gsl::span sequences, +- int sequence_length, +- gsl::span beam_hyps, +- gsl::span final_beam_scores, +- gsl::span output, +- gsl::span<__half> sequence_scores, +- cudaStream_t stream); +- +-template +-__global__ void FloatConvertAndCopyKernel(const float* src, T* dst, size_t total_elements) { +- int64_t index = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; +- if (index < total_elements) { +- dst[index] = (T)src[index]; +- } +-} +- +-template +-void LaunchBeamSearchScoreCopy(gsl::span final_scores, +- gsl::span output_scores, +- cudaStream_t stream) { +- ORT_ENFORCE(final_scores.size() == output_scores.size()); +- constexpr unsigned ThreadPerBlock = 256; +- unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock); +- +- typedef typename ToCudaType::MappedType CudaT; +- +- FloatConvertAndCopyKernel<<>>( +- final_scores.data(), (CudaT*)output_scores.data(), final_scores.size()); +-} +- +-template void LaunchBeamSearchScoreCopy(gsl::span final_scores, +- gsl::span output_scores, +- cudaStream_t stream); +- +-template void LaunchBeamSearchScoreCopy(gsl::span final_scores, +- gsl::span output_scores, +- cudaStream_t stream); +- + __global__ void AddProbsKernel(float* log_probs, + float* cum_log_probs, + const int vocab_size, +diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +index 281cb6c72..5ed594919 100644 +--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h ++++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +@@ -65,12 +65,11 @@ struct BeamHypotheses { + __device__ bool CanImprove(float best_sum_logprobs, int current_length) const; + + // Output results +- template +- __device__ void Output(int top_k, // number of sequences to return +- int max_length, // max sequence length +- int pad_token_id, // pad token +- int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) +- T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) ++ __device__ void Output(int top_k, // number of sequences to return ++ int max_length, // max sequence length ++ int pad_token_id, // pad token ++ int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length) ++ float* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) + }; + + struct BeamScorerState { +@@ -111,7 +110,6 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& 
state_cp + gsl::span next_beam_indices, + cudaStream_t stream); + +-template + void LaunchBeamSearchScorer_Finalize(int batch_size, + BeamScorerState& state, + gsl::span sequences, +@@ -119,14 +117,9 @@ void LaunchBeamSearchScorer_Finalize(int batch_size, + gsl::span beam_hyps_, + gsl::span final_beam_scores, + gsl::span output, +- gsl::span sequence_scores, ++ gsl::span sequence_scores, + cudaStream_t stream); + +-template +-void LaunchBeamSearchScoreCopy(gsl::span final_scores, +- gsl::span output_scores, +- cudaStream_t stream); +- + void LaunchNextTokenKernel(const int64_t* next_token_indices, + int32_t* next_indices, + int32_t* next_tokens, +diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +index bba30805a..380d561bb 100644 +--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc ++++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +@@ -620,8 +620,6 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { + Tensor* output_sequences, + Tensor* output_sequence_scores) override; + +- void OutputScores(gsl::span& final_scores, Tensor* output_scores) override; +- + bool IsDone() const override { return false; } // For CUDA we speculatively run the next step while we wait for the GPU to report status. We use 'IsDoneLater()' for this + bool IsDoneLater() const override; + +@@ -634,6 +632,7 @@ struct CudaBeamSearchScorer : transformers::IBeamScorer { + } + gsl::span GetNextIndicesGPU() override { return next_beam_indices_; } + ++ private: + mutable cuda::AutoDestoryCudaEvent event_process_complete_; + IAllocatorUniquePtr state_cpu_; + IAllocatorUniquePtr state_gpu_; +@@ -744,58 +743,22 @@ bool CudaBeamSearchScorer::IsDoneLater() const { + return state_cpu_->not_done_count_ == 0; + } + +-template +-void CudaOutputSequenceScores(CudaBeamSearchScorer* scorer, +- transformers::ISequences& sequences, +- gsl::span& final_beam_scores, +- Tensor* output_sequences, +- Tensor* output_sequence_scores) { +- // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). +- gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; +- +- // Score of each sequence, with shape (batch_size * num_return_sequences). 
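
For the beam search changes above: BeamHypotheses::Output copies each returned hypothesis into a fixed (num_return_sequences, max_length) buffer, pads the unused tail with pad_token_id, and optionally records the score. A host-side sketch, not part of the patch and with illustrative names, shows that padding behaviour.

    // Host-side sketch only of the Output padding logic.
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct HypothesisSketch {
      std::vector<int32_t> tokens;
      float score;
    };

    void OutputSequencesSketch(const std::vector<HypothesisSketch>& top_k, int max_length,
                               int32_t pad_token_id, int32_t* sequences, float* sequences_scores) {
      for (size_t index = 0; index < top_k.size(); ++index) {
        int32_t* target = sequences + index * max_length;
        size_t i = 0;
        for (; i < top_k[index].tokens.size() && i < static_cast<size_t>(max_length); ++i) {
          target[i] = top_k[index].tokens[i];       // copy the hypothesis
        }
        for (; i < static_cast<size_t>(max_length); ++i) {
          target[i] = pad_token_id;                 // pad the remainder of the row
        }
        if (sequences_scores != nullptr) {
          sequences_scores[index] = top_k[index].score;  // optional score output
        }
      }
    }

    int main() {
      std::vector<HypothesisSketch> beams = {{{101, 7, 8}, -0.4f}, {{101, 9}, -1.2f}};
      std::vector<int32_t> sequences(2 * 5);
      std::vector<float> scores(2);
      OutputSequencesSketch(beams, 5, /*pad_token_id=*/0, sequences.data(), scores.data());
      std::printf("%d %d %d %d %d\n", sequences[0], sequences[1], sequences[2], sequences[3], sequences[4]);
      return 0;
    }
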
+- using CudaT = typename ToCudaType::MappedType; +- gsl::span sequence_scores; +- if (output_sequence_scores) { +- sequence_scores = gsl::span{(CudaT*)output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; +- } +- +- cuda::LaunchBeamSearchScorer_Finalize(scorer->state_cpu_->batch_size_, +- *scorer->state_gpu_, +- sequences.GetCurrentDeviceSequences(), +- sequences.GetSequenceLength(), +- scorer->beam_hyps_, +- final_beam_scores, +- output, +- sequence_scores, +- scorer->stream_); +-} +- + void CudaBeamSearchScorer::Finalize(transformers::ISequences& sequences, + gsl::span& final_beam_scores, + Tensor* output_sequences, + Tensor* output_sequence_scores) { + ORT_ENFORCE(output_sequences != nullptr); + +- if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType()) { +- CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); +- } else { +- ORT_ENFORCE(output_sequence_scores->IsDataType()); +- CudaOutputSequenceScores(this, sequences, final_beam_scores, output_sequences, output_sequence_scores); +- } +-} ++ // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length). ++ gsl::span output{output_sequences->MutableData(), static_cast(output_sequences->Shape().Size())}; + +-void CudaBeamSearchScorer::OutputScores(gsl::span& final_scores, Tensor* output_scores) { +- if (output_scores) { +- if (output_scores->IsDataType()) { +- gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); +- cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); +- } else { +- ORT_ENFORCE(output_scores->IsDataType()); +- gsl::span target(output_scores->MutableData(), output_scores->Shape().Size()); +- cuda::LaunchBeamSearchScoreCopy(final_scores, target, stream_); +- } ++ // Score of each sequence, with shape (batch_size * num_return_sequences). 
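
The removed FloatConvertAndCopyKernel / LaunchBeamSearchScoreCopy pair existed only to cast the float beam scores into the requested output element type. A host-side sketch of the same element-wise convert-and-copy pattern, not part of the patch, is below; here DstT is any arithmetic type, whereas the CUDA version would map to __half through ToCudaType and launch one thread per element.

    // Host-side sketch only of the convert-and-copy pattern.
    #include <cstddef>
    #include <vector>

    template <typename DstT>
    void ConvertAndCopyScoresSketch(const float* src, DstT* dst, size_t total_elements) {
      for (size_t index = 0; index < total_elements; ++index) {
        dst[index] = static_cast<DstT>(src[index]);  // one CUDA thread per element in the kernel version
      }
    }

    int main() {
      std::vector<float> final_scores = {-0.25f, -1.5f, -3.0f};
      std::vector<float> as_float(final_scores.size());
      std::vector<double> as_double(final_scores.size());  // stand-in for a second output type
      ConvertAndCopyScoresSketch(final_scores.data(), as_float.data(), final_scores.size());
      ConvertAndCopyScoresSketch(final_scores.data(), as_double.data(), final_scores.size());
      return as_float[1] == -1.5f ? 0 : 1;
    }
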
++ gsl::span sequence_scores; ++ if (output_sequence_scores) { ++ sequence_scores = gsl::span{output_sequence_scores->MutableData(), static_cast(output_sequence_scores->Shape().Size())}; + } ++ ++ cuda::LaunchBeamSearchScorer_Finalize(state_cpu_->batch_size_, *state_gpu_, sequences.GetCurrentDeviceSequences(), sequences.GetSequenceLength(), beam_hyps_, final_beam_scores, output, sequence_scores, stream_); + } + + std::unique_ptr CreateBeamScorer(const transformers::IGenerationParameters& parameters, +diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +index be8508670..0599318a4 100644 +--- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh ++++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +@@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +-using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface ++using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface + using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation + + static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +@@ -141,35 +141,6 @@ std::vector, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); + +-template <> +-std::vector, ck::Tuple<>, +- PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, +- MaskingSpecialization::MaskOutUpperTriangle>>> +-GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); +- +-// fp16, biased, non-masked +-template <> +-std::vector, ck::Tuple<>, +- PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, +- MaskingSpecialization::MaskOutUpperTriangle>>> +-GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); +- +-// fp16, biased, fp16 masked, basically, two bias +-template <> +-std::vector, ck::Tuple<>, +- PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, +- MaskingSpecialization::MaskOutUpperTriangle>>> +-GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); +- + } // namespace internal + } // namespace rocm + } // namespace contrib +diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +index 2e32a6594..181e47f01 100644 +--- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu ++++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +@@ -32,27 +32,6 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + return instances; + } + +-using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< +- 2, 1, 1, 1, 1, +- F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, +- PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, +- MaskingSpecialization::MaskOutUpperTriangle>; +- +-template <> +-std::vector> 
+-GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { +- std::vector> instances; +- ck::tensor_operation::device::instance::add_device_operation_instances( +- instances, +- device_batched_gemm_softmax_gemm_permute_instances< +- 2, 1, 1, 1, 1, +- F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, +- MaskingSpecialization::MaskOutUpperTriangle>{}); +- +- return instances; +-} +- + } // namespace internal + } // namespace rocm + } // namespace contrib +diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +index 91da8d9e1..1577bdf39 100644 +--- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu ++++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +@@ -32,27 +32,6 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + return instances; + } + +-using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< +- 2, 1, 1, 1, 1, +- F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, +- PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, +- MaskingSpecialization::MaskOutUpperTriangle>; +- +-template <> +-std::vector> +-GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { +- std::vector> instances; +- ck::tensor_operation::device::instance::add_device_operation_instances( +- instances, +- device_batched_gemm_softmax_gemm_permute_instances< +- 2, 1, 1, 1, 1, +- F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, +- MaskingSpecialization::MaskOutUpperTriangle>{}); +- +- return instances; +-} +- + } // namespace internal + } // namespace rocm + } // namespace contrib +diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +index b08123be1..14de59234 100644 +--- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu ++++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +@@ -32,27 +32,6 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + return instances; + } + +-using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< +- 2, 1, 1, 1, 1, +- F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, +- PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, +- MaskingSpecialization::MaskOutUpperTriangle>; +- +-template <> +-std::vector> +-GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { +- std::vector> instances; +- ck::tensor_operation::device::instance::add_device_operation_instances( +- instances, +- device_batched_gemm_softmax_gemm_permute_instances< +- 2, 1, 1, 1, 1, +- F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, +- MaskingSpecialization::MaskOutUpperTriangle>{}); +- +- return instances; +-} +- + } // namespace internal + } // namespace rocm + } // namespace contrib +diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +index 
54dda4bfa..78983ac95 100644 +--- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh ++++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +@@ -732,154 +732,122 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp +-auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { ++template ++auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { + constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); + + using Nop = ck::tensor_operation::element_wise::PassThrough; + using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), +- "attention mode is not supported, got ", params->attention->mode); +- if constexpr (USE_BIAS) { +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- params->bias_buffer == nullptr, "biased version only support input with bias"); +- } else { +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- params->bias_buffer != nullptr, "non-biased version only support input without bias"); +- } +- if constexpr (USE_MASK) { +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), +- "mask type is not supported, got ", params->attention->mask_type); +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- params->mask_index_buffer == nullptr, "masked version only support input with mask"); +- } else { +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); +- } +- +- auto attn = params->attention; +- const int& G0 = attn->batch_size; +- const int& G1 = attn->num_heads; +- const int& M = attn->sequence_length; +- const int& N = attn->total_sequence_length; +- const int& K = attn->head_size; +- const int& O = attn->v_head_size; +- { +- auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); +- ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); +- } +- +- auto [qs, ks, vs] = GetQkvStrides(attn); +- std::vector q_buffer_lengths = {G0, G1, M, K}; +- std::vector q_buffer_strides = qs.template ForBNSHCoord>(); +- std::vector k_buffer_lengths = {G0, G1, N, K}; +- std::vector k_buffer_strides = ks.template ForBNSHCoord>(); +- std::vector v_buffer_lengths = {G0, G1, O, N}; +- std::vector v_buffer_strides = vs.template ForBNHSCoord>(); +- std::vector out_buffer_lengths = {G0, G1, M, O}; +- std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 +- +- std::array bias_buffers{}; +- std::array, kNumBiasBuffer> bias_lengths{}; +- std::array, kNumBiasBuffer> bias_strides{}; +- if constexpr (USE_BIAS) { +- bias_buffers[0] = const_cast(params->bias_buffer); +- bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) +- bias_strides[0] = {G1 * M * N, M * N, N, 1}; +- } +- if constexpr (USE_MASK) { +- bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; +- bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) +- if (params->mask_index_dims.size() == 2) { // [B,T] +- bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; +- } else if (params->mask_index_dims.size() == 3) { // [B,S,T] +- bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; +- } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] +- bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; +- } else { +- 
ORT_ENFORCE(false, "Unreachable"); +- } +- } +- +- auto arg = impl->MakeArgumentPointer( +- params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, +- bias_buffers, // Gemm1 bias, as attention mask +- {}, // Gemm2 bias +- q_buffer_lengths, q_buffer_strides, +- k_buffer_lengths, k_buffer_strides, +- v_buffer_lengths, v_buffer_strides, +- out_buffer_lengths, out_buffer_strides, +- bias_lengths, bias_strides, +- {}, +- {}, +- Nop{}, +- Nop{}, +- Acc0ElementOp{params->scale}, +- Nop{}, +- Nop{}); +- +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), +- impl->GetTypeString(), " does not support the params"); +- +- if constexpr (USE_MASK) { +- ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); +- } +- +- invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); +- return Status::OK(); +-} +- +-template +-auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { + using CKDataType = typename CKDataTypeAdaptor::type; + using D0DataType = typename ck::detail::tuple_concat< + std::conditional_t, ck::Tuple<>>, + std::conditional_t, ck::Tuple<>>>::type; + +- constexpr static auto MaskingSpecMaskDisabled = ++ constexpr static auto MaskingSpec = + ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; +- constexpr static auto MaskingSpecMaskOutUpperTriangle = +- ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; +- +- std::vector>>> +- ret; + ++ std::vector>>> ret; + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { ++ CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) { + auto type_string = impl->GetTypeString(); + + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); +- +- return GetArgAndRunInvoker(impl, invoker, params); +- }; +- ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); +- } ++ !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), ++ "attention mode is not supported, got ", params->attention->mode); ++ if constexpr (USE_BIAS) { ++ TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( ++ params->bias_buffer == nullptr, "biased version only support input with bias"); ++ } else { ++ TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( ++ params->bias_buffer != nullptr, "non-biased version only support input without bias"); ++ } ++ if constexpr (USE_MASK) { ++ TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( ++ !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), ++ "mask type is not supported, got ", params->attention->mask_type); ++ TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( ++ params->mask_index_buffer == nullptr, "masked version only support input with mask"); ++ } else { ++ TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( ++ params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); ++ } + +- for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< +- CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { +- auto type_string = impl->GetTypeString(); ++ auto attn = 
params->attention; ++ const int& G0 = attn->batch_size; ++ const int& G1 = attn->num_heads; ++ const int& M = attn->sequence_length; ++ const int& N = attn->total_sequence_length; ++ const int& K = attn->head_size; ++ const int& O = attn->v_head_size; ++ { ++ auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); ++ ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); ++ } + +- auto invoker = impl->MakeInvokerPointer(); +- auto op = [impl = std::move(impl), invoker = std::move(invoker)]( +- const GemmSoftmaxGemmPermuteParams* params) -> Status { +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); +- TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( +- params->attention->sequence_length != params->attention->total_sequence_length, +- "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); ++ auto [qs, ks, vs] = GetQkvStrides(attn); ++ std::vector q_buffer_lengths = {G0, G1, M, K}; ++ std::vector q_buffer_strides = qs.template ForBNSHCoord>(); ++ std::vector k_buffer_lengths = {G0, G1, N, K}; ++ std::vector k_buffer_strides = ks.template ForBNSHCoord>(); ++ std::vector v_buffer_lengths = {G0, G1, O, N}; ++ std::vector v_buffer_strides = vs.template ForBNHSCoord>(); ++ std::vector out_buffer_lengths = {G0, G1, M, O}; ++ std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 ++ ++ std::array bias_buffers{}; ++ std::array, kNumBiasBuffer> bias_lengths{}; ++ std::array, kNumBiasBuffer> bias_strides{}; ++ if constexpr (USE_BIAS) { ++ bias_buffers[0] = const_cast(params->bias_buffer); ++ bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) ++ bias_strides[0] = {G1 * M * N, M * N, N, 1}; ++ } ++ if constexpr (USE_MASK) { ++ bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; ++ bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) ++ if (params->mask_index_dims.size() == 2) { // [B,T] ++ bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; ++ } else if (params->mask_index_dims.size() == 3) { // [B,S,T] ++ bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; ++ } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] ++ bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; ++ } else { ++ ORT_ENFORCE(false, "Unreachable"); ++ } ++ } + +- return GetArgAndRunInvoker(impl, invoker, params); ++ auto arg = impl->MakeArgumentPointer( ++ params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, ++ bias_buffers, // Gemm1 bias, as attention mask ++ {}, // Gemm2 bias ++ q_buffer_lengths, q_buffer_strides, ++ k_buffer_lengths, k_buffer_strides, ++ v_buffer_lengths, v_buffer_strides, ++ out_buffer_lengths, out_buffer_strides, ++ bias_lengths, bias_strides, ++ {}, ++ {}, ++ Nop{}, ++ Nop{}, ++ Acc0ElementOp{params->scale}, ++ Nop{}, ++ Nop{}); ++ ++ TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), ++ impl->GetTypeString(), " does not support the params"); ++ ++ if constexpr (USE_MASK) { ++ ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); ++ } ++ invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); ++ return Status::OK(); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } +- + return ret; + } + #endif // USE_COMPOSABLE_KERNEL +diff --git a/onnxruntime/core/flatbuffers/checkpoint_version.h 
b/onnxruntime/core/flatbuffers/checkpoint_version.h +index e6ee20bf5..6cad27c35 100644 +--- a/onnxruntime/core/flatbuffers/checkpoint_version.h ++++ b/onnxruntime/core/flatbuffers/checkpoint_version.h +@@ -13,9 +13,7 @@ namespace onnxruntime { + // The format includes support for the ModuleState (stores the module parameters), OptimizerGroups + // (stores the optimizer states), and PropertyBag + // (stores custom user properties with support for int64, float and strings). +-// Version 2: Introduces the On-Device Training nominal checkpoint state. +-// Changes include the addition of the is_nominal_state field in the checkpoint's ModuleState. +-constexpr const int kCheckpointVersion = 2; ++constexpr const int kCheckpointVersion = 1; + + /** + * @brief Check if the given checkpoint version is supported in this build +diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py +index 19c6b1b6f..2be826fee 100644 +--- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py ++++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py +@@ -74,17 +74,9 @@ class ModuleState(object): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) + return o == 0 + +- # ModuleState +- def IsNominalState(self): +- o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) +- if o != 0: +- return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) +- return False +- +-def ModuleStateStart(builder): builder.StartObject(3) ++def ModuleStateStart(builder): builder.StartObject(2) + def ModuleStateAddRequiresGradParams(builder, requiresGradParams): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0) + def ModuleStateStartRequiresGradParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) + def ModuleStateAddFrozenParams(builder, frozenParams): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0) + def ModuleStateStartFrozenParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +-def ModuleStateAddIsNominalState(builder, isNominalState): builder.PrependBoolSlot(2, isNominalState, 0) + def ModuleStateEnd(builder): return builder.EndObject() +diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs +index 94757fa6d..c8244b0a4 100644 +--- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs ++++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs +@@ -8,10 +8,6 @@ namespace onnxruntime.fbs; + table ModuleState { + requires_grad_params:[Tensor]; + frozen_params:[Tensor]; +- // Nominal state just means that the Tensors in the ModuleState +- // are empty. i.e. The tensors are treated as named entities +- // without any meaningful data. 
+- is_nominal_state:bool; + } + + table ParameterOptimizerState { +diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h +index d205c5eb8..48feebb19 100644 +--- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h ++++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h +@@ -39,8 +39,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ModuleStateBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_REQUIRES_GRAD_PARAMS = 4, +- VT_FROZEN_PARAMS = 6, +- VT_IS_NOMINAL_STATE = 8 ++ VT_FROZEN_PARAMS = 6 + }; + const flatbuffers::Vector> *requires_grad_params() const { + return GetPointer> *>(VT_REQUIRES_GRAD_PARAMS); +@@ -48,9 +47,6 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + const flatbuffers::Vector> *frozen_params() const { + return GetPointer> *>(VT_FROZEN_PARAMS); + } +- bool is_nominal_state() const { +- return GetField(VT_IS_NOMINAL_STATE, 0) != 0; +- } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_REQUIRES_GRAD_PARAMS) && +@@ -59,7 +55,6 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + VerifyOffset(verifier, VT_FROZEN_PARAMS) && + verifier.VerifyVector(frozen_params()) && + verifier.VerifyVectorOfTables(frozen_params()) && +- VerifyField(verifier, VT_IS_NOMINAL_STATE) && + verifier.EndTable(); + } + }; +@@ -74,9 +69,6 @@ struct ModuleStateBuilder { + void add_frozen_params(flatbuffers::Offset>> frozen_params) { + fbb_.AddOffset(ModuleState::VT_FROZEN_PARAMS, frozen_params); + } +- void add_is_nominal_state(bool is_nominal_state) { +- fbb_.AddElement(ModuleState::VT_IS_NOMINAL_STATE, static_cast(is_nominal_state), 0); +- } + explicit ModuleStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); +@@ -92,27 +84,23 @@ struct ModuleStateBuilder { + inline flatbuffers::Offset CreateModuleState( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> requires_grad_params = 0, +- flatbuffers::Offset>> frozen_params = 0, +- bool is_nominal_state = false) { ++ flatbuffers::Offset>> frozen_params = 0) { + ModuleStateBuilder builder_(_fbb); + builder_.add_frozen_params(frozen_params); + builder_.add_requires_grad_params(requires_grad_params); +- builder_.add_is_nominal_state(is_nominal_state); + return builder_.Finish(); + } + + inline flatbuffers::Offset CreateModuleStateDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *requires_grad_params = nullptr, +- const std::vector> *frozen_params = nullptr, +- bool is_nominal_state = false) { ++ const std::vector> *frozen_params = nullptr) { + auto requires_grad_params__ = requires_grad_params ? _fbb.CreateVector>(*requires_grad_params) : 0; + auto frozen_params__ = frozen_params ? 
_fbb.CreateVector>(*frozen_params) : 0; + return onnxruntime::fbs::CreateModuleState( + _fbb, + requires_grad_params__, +- frozen_params__, +- is_nominal_state); ++ frozen_params__); + } + + struct ParameterOptimizerState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { +diff --git a/onnxruntime/core/framework/execution_provider.cc b/onnxruntime/core/framework/execution_provider.cc +index b39924d4c..7f8009216 100644 +--- a/onnxruntime/core/framework/execution_provider.cc ++++ b/onnxruntime/core/framework/execution_provider.cc +@@ -35,4 +35,77 @@ common::Status IExecutionProvider::Compile(const std::vector& + } + + #endif ++ ++int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, ++ HashValue& model_hash) { ++ model_hash = 0; ++ ++ // find the top level graph ++ const Graph* cur_graph = &graph_viewer.GetGraph(); ++ while (cur_graph->IsSubgraph()) { ++ cur_graph = cur_graph->ParentGraph(); ++ } ++ ++ uint32_t instance_hash[4] = {0, 0, 0, 0}; ++ ++ const Graph& main_graph = *cur_graph; ++ ++ // hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use ++ // the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique ++ // fingerprint for the instance that can use used as the key to the hash of the model path/contents. ++ MurmurHash3::x86_128(&main_graph, gsl::narrow_cast(sizeof(Graph)), instance_hash[0], &instance_hash); ++ HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32); ++ ++ // if we've already hashed this main graph instance use the cached value ++ auto entry = main_graph_hash_.find(graph_instance_hash); ++ if (entry != main_graph_hash_.cend()) { ++ model_hash = entry->second; ++ } else { ++ uint32_t hash[4] = {0, 0, 0, 0}; ++ ++ // prefer path the model was loaded from ++ // this may not be available if the model was loaded from a stream or in-memory bytes ++ const auto& model_path_str = main_graph.ModelPath().ToPathString(); ++ if (!model_path_str.empty()) { ++ MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); ++ } else { ++ auto hash_str = [&hash](const std::string& str) { ++ MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); ++ }; ++ ++ // fingerprint the main graph by hashing graph inputs and the ordered outputs from each node ++ for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) { ++ hash_str(node_arg->Name()); ++ } ++ ++ // note: process nodes in order defined in model to be deterministic ++ for (const auto& node : main_graph.Nodes()) { ++ for (const auto* node_arg : node.OutputDefs()) { ++ if (node_arg->Exists()) { ++ hash_str(node_arg->Name()); ++ } ++ } ++ } ++ } ++ ++ model_hash = hash[0] | (uint64_t(hash[1]) << 32); ++ ++ main_graph_hash_[graph_instance_hash] = model_hash; ++ } ++ ++ // return the current unique id, and increment to update ++ return model_metadef_id_[model_hash]++; ++} ++ ++int IExecutionProvider::GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const { ++ ORT_ENFORCE(metadef_id_generator_, ++ "IExecutionProvider constructor must be called with true for use_metadef_id_creator"); ++ ++ // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. 
++ // use a lock when generating an id to be paranoid ++ static OrtMutex mutex; ++ std::lock_guard lock(mutex); ++ return metadef_id_generator_->GenerateId(graph_viewer, model_hash); ++} ++ + } // namespace onnxruntime +diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc +index 90ee8a46f..07b465c80 100644 +--- a/onnxruntime/core/framework/graph_partitioner.cc ++++ b/onnxruntime/core/framework/graph_partitioner.cc +@@ -645,10 +645,6 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + +- if (all_ep_context_nodes.size() < 1) { +- return Status::OK(); +- } +- + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { +@@ -660,69 +656,75 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); ++ if (all_ep_context_nodes.size() > 0) { ++ if (!ep_context_path.empty()) { ++ context_cache_path = ToPathString(ep_context_path); ++ } else if (!model_pathstring.empty()) { ++ context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); ++ } + +- if (!ep_context_path.empty()) { +- context_cache_path = ToPathString(ep_context_path); +- } else if (!model_pathstring.empty()) { +- context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); +- } +- +- { ++ { + #ifdef _WIN32 +- std::wifstream fs(context_cache_path); ++ std::wifstream fs(context_cache_path); + #else +- std::ifstream fs(context_cache_path); ++ std::ifstream fs(context_cache_path); + #endif +- ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); +- } +- +- Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), +- graph.DomainToVersionMap(), {}, logger); +- auto& ep_graph = ep_context_model.MainGraph(); +- ep_graph.SetDescription(graph.Description()); +- +- // Set inputs outputs explicitly to make sure the order is same as the user model. +- auto inputs = graph.GetInputs(); +- auto outputs = graph.GetOutputs(); ++ ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); ++ } + +- InlinedVector ep_graph_inputs; +- ep_graph_inputs.reserve(inputs.size()); +- for (auto& input : inputs) { +- auto input_arg = graph.GetNodeArg(input->Name()); +- auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); +- ep_graph_inputs.push_back(&ep_graph_input_arg); +- } ++ Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), ++ graph.DomainToVersionMap(), {}, logger); ++ auto& ep_graph = ep_context_model.MainGraph(); ++ ep_graph.SetDescription(graph.Description()); ++ ++ // Set inputs outputs explicitly to make sure the order is same as the user model. 
++ auto inputs = graph.GetInputs(); ++ auto outputs = graph.GetOutputs(); ++ ++ InlinedVector ep_graph_inputs; ++ ep_graph_inputs.reserve(inputs.size()); ++ for (auto& input : inputs) { ++ auto input_arg = graph.GetNodeArg(input->Name()); ++ auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); ++ ep_graph_inputs.push_back(&ep_graph_input_arg); ++ } + +- InlinedVector ep_graph_outputs; +- ep_graph_outputs.reserve(outputs.size()); +- for (auto& output : outputs) { +- auto output_arg = graph.GetNodeArg(output->Name()); +- auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); +- ep_graph_outputs.push_back(&ep_graph_output_arg); +- } ++ InlinedVector ep_graph_outputs; ++ ep_graph_outputs.reserve(outputs.size()); ++ for (auto& output : outputs) { ++ auto output_arg = graph.GetNodeArg(output->Name()); ++ auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); ++ ep_graph_outputs.push_back(&ep_graph_output_arg); ++ } + +- ep_graph.SetInputs(ep_graph_inputs); +- ep_graph.SetOutputs(ep_graph_outputs); ++ ep_graph.SetInputs(ep_graph_inputs); ++ ep_graph.SetOutputs(ep_graph_outputs); + +- for (const auto& node : graph.Nodes()) { +- // the fused node and EPContext node has same node name +- auto ep_context_node = get_ep_context_node(node.Name()); +- // Use EpContext node created by the EPs if name matched, otherwise use node from original model +- if (ep_context_node.first) { +- ep_graph.AddNode(*ep_context_node.second); +- } else { +- ep_graph.AddNode(node); ++ for (const auto& node : graph.Nodes()) { ++ // the fused node and EPContext node has same node name ++ auto ep_context_node = get_ep_context_node(node.Name()); ++ // Use EpContext node created by the EPs if name matched, otherwise use node from original model ++ if (ep_context_node.first) { ++ ep_graph.AddNode(*ep_context_node.second); ++ } else { ++ ep_graph.AddNode(node); ++ } + } +- } + +- // handle initializers +- for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) { +- if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) { +- ep_graph.AddInitializedTensor(*initialized_tensor.second); ++ // handle initializers ++ for (const auto& input : graph.GetInputsIncludingInitializers()) { ++ const ONNX_NAMESPACE::TensorProto* initializer = nullptr; ++ if (graph.GetInitializedTensor(input->Name(), initializer)) { ++ // There initializer could have duplicates so make sure we only add once ++ const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; ++ if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { ++ ep_graph.AddInitializedTensor(*initializer); ++ } ++ } + } +- } + +- ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); ++ ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); ++ } + + return Status::OK(); + } +diff --git a/onnxruntime/core/framework/model_metadef_id_generator.cc b/onnxruntime/core/framework/model_metadef_id_generator.cc +deleted file mode 100644 +index e51c6ebc2..000000000 +--- a/onnxruntime/core/framework/model_metadef_id_generator.cc ++++ /dev/null +@@ -1,75 +0,0 @@ +-// Copyright (c) Microsoft Corporation. All rights reserved. +-// Licensed under the MIT License. 
+-#include +-#include "model_metadef_id_generator.h" +-#include "core/platform/ort_mutex.h" +-#include "core/graph/graph_viewer.h" +-#include "core/framework/murmurhash3.h" +- +-namespace onnxruntime { +-int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer, +- HashValue& model_hash) const { +- // if the EP is shared across multiple sessions there's a very small potential for concurrency issues. +- // use a lock when generating an id to be paranoid +- static OrtMutex mutex; +- std::lock_guard lock(mutex); +- model_hash = 0; +- +- // find the top level graph +- const Graph* cur_graph = &graph_viewer.GetGraph(); +- while (cur_graph->IsSubgraph()) { +- cur_graph = cur_graph->ParentGraph(); +- } +- +- uint32_t instance_hash[4] = {0, 0, 0, 0}; +- +- const Graph& main_graph = *cur_graph; +- +- // hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use +- // the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique +- // fingerprint for the instance that can use used as the key to the hash of the model path/contents. +- MurmurHash3::x86_128(&main_graph, gsl::narrow_cast(sizeof(Graph)), instance_hash[0], &instance_hash); +- HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32); +- +- // if we've already hashed this main graph instance use the cached value +- auto entry = main_graph_hash_.find(graph_instance_hash); +- if (entry != main_graph_hash_.cend()) { +- model_hash = entry->second; +- } else { +- uint32_t hash[4] = {0, 0, 0, 0}; +- +- // prefer path the model was loaded from +- // this may not be available if the model was loaded from a stream or in-memory bytes +- const auto& model_path_str = main_graph.ModelPath().ToPathString(); +- if (!model_path_str.empty()) { +- MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast(model_path_str.size()), hash[0], &hash); +- } else { +- auto hash_str = [&hash](const std::string& str) { +- MurmurHash3::x86_128(str.data(), gsl::narrow_cast(str.size()), hash[0], &hash); +- }; +- +- // fingerprint the main graph by hashing graph inputs and the ordered outputs from each node +- for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) { +- hash_str(node_arg->Name()); +- } +- +- // note: process nodes in order defined in model to be deterministic +- for (const auto& node : main_graph.Nodes()) { +- for (const auto* node_arg : node.OutputDefs()) { +- if (node_arg->Exists()) { +- hash_str(node_arg->Name()); +- } +- } +- } +- } +- +- model_hash = hash[0] | (uint64_t(hash[1]) << 32); +- +- main_graph_hash_[graph_instance_hash] = model_hash; +- } +- +- // return the current unique id, and increment to update +- return model_metadef_id_[model_hash]++; +-} +- +-} // namespace onnxruntime +diff --git a/onnxruntime/core/framework/model_metadef_id_generator.h b/onnxruntime/core/framework/model_metadef_id_generator.h +deleted file mode 100644 +index 82f68c42b..000000000 +--- a/onnxruntime/core/framework/model_metadef_id_generator.h ++++ /dev/null +@@ -1,31 +0,0 @@ +-// Copyright (c) Microsoft Corporation. All rights reserved. +-// Licensed under the MIT License. +- +-#pragma once +-#include +-#include "core/common/basic_types.h" +-namespace onnxruntime { +-class GraphViewer; +- +-/// +-/// helper to generate ids that are unique to model and deterministic, even if the execution provider is shared across +-/// multiple sessions. 
+-/// +-class ModelMetadefIdGenerator { +- public: +- /** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance. +- The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models. +- @param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph. +- @param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model. +- This is created using the model path if available, +- or the model input names and the output names from all nodes in the main graph. +- */ +- int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const; +- +- private: +- // mutable as these are caches so we can minimize the hashing required on each usage of GenerateId +- mutable std::unordered_map main_graph_hash_; // map graph instance hash to model contents hash +- mutable std::unordered_map model_metadef_id_; // current unique id for model +-}; +- +-} // namespace onnxruntime +diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +index 27c968a59..982e8fd83 100644 +--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc ++++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +@@ -1231,7 +1231,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, + "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " + "are treated as stop of the extra_decoding_ids for corresponding batch.", + "I", OpSchema::Optional) +- .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") + .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) + .Output(2, "scores", +diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc +index 902839bee..f71b7eceb 100644 +--- a/onnxruntime/core/graph/graph.cc ++++ b/onnxruntime/core/graph/graph.cc +@@ -2798,13 +2798,12 @@ Status Graph::Resolve(const ResolveOptions& options) { + graph.GraphProtoSyncNeeded(false); + } + +- // set num_resolves_ here so the graph and any subgraphs all have the same value +- ++graph.num_resolves_; +- + return Status::OK(); }; + + ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func)); + ++ ++num_resolves_; ++ + return Status::OK(); + } + +diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc +index 6d7ed94b2..8e9624035 100644 +--- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc ++++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc +@@ -392,14 +392,6 @@ Status LoadOrtTensorOrtFormat(const fbs::Tensor& fbs_tensor, const AllocatorPtr + ort_tensor = onnxruntime::Tensor( + tensor_dtype, TensorShape(tensor_dims->data(), tensor_dims->size()), allocator); + +- if (fbs_tensor.raw_data()->size() == 0U) { +- // Empty tensor. Nothing to unpack. +- // This check is necessary because an empty ort tensor will return a size of 1. +- // As a result, the following call to UnpackTensor will fail since the src and +- // dst sizes do not match (0 and 1 elements). +- return Status::OK(); +- } +- + // The tensor proto is used as a dummy here. 
The actual data is stored in the raw_data field of the flatbuffer. + // The data is copied from the raw_data field to the ort_tensor. + ONNX_NAMESPACE::TensorProto unused_tensor_proto; +diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h +index 32e9cc981..047011e70 100644 +--- a/onnxruntime/core/mlas/inc/mlas_qnbit.h ++++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h +@@ -37,7 +37,9 @@ typedef enum { + + CompMostAccurate = CompUndef, + CompLeastAccurate = CompInt8, +-} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; ++} MLAS_SQNBIT_COMPUTE_TYPE; ++ ++using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these + + /** + * @brief Data parameters for float/n-bit quantized int GEMM routine. +@@ -100,12 +102,18 @@ MlasSQNBitGemmBatch( + /** + * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. + * ++ * @param[in] M row size of matrix A and C ++ * @param[in] N column size of matrix B and C ++ * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + bool MLASCALL + MlasIsSQNBitGemmAvailable( ++ size_t M, ++ size_t N, ++ size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +@@ -145,15 +153,13 @@ MlasSQNBitGemmBatchWorkspaceSize( + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block +- * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + size_t MLASCALL + MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, +- size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ++ size_t BlkLen + ); + + /** +@@ -163,7 +169,6 @@ MlasSQNBitGemmPackQuantBDataSize( + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block +- * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] QuantBData quantized B data + * @param[out] PackedQuantBData packed quantized B data + * @param[in] ThreadPool optional thread pool to use +@@ -174,7 +179,6 @@ MlasSQNBitGemmPackQuantBData( + size_t K, + size_t BlkBitWidth, + size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool = nullptr +diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h +index caf94af02..3eb070093 100644 +--- a/onnxruntime/core/mlas/lib/amx_common.h ++++ b/onnxruntime/core/mlas/lib/amx_common.h +@@ -18,7 +18,7 @@ Abstract: + + #include "mlasi.h" + +-#ifdef _WIN32 ++#ifdef WIN32 + #define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) + + #define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) +diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +index 38c31c884..0d8a56923 100644 +--- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp ++++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +@@ -39,17 +39,23 @@ enum SQNBitGemmVariant { + + SQNBitGemmVariant + GetSQNBitGemmVariant( ++ size_t M, ++ size_t N, 
++ size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ) + { ++ MLAS_UNREFERENCED_PARAMETER(N); ++ MLAS_UNREFERENCED_PARAMETER(K); ++ + if (BlkBitWidth == 4 && + (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (ComputeType == CompFp32 || + ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + return SQNBitGemmVariant_BitWidth4_CompFp32; +- } else if (ComputeType == CompInt8) { ++ } else if (ComputeType == CompInt8 && M == 1) { + return SQNBitGemmVariant_BitWidth4_CompInt8; + } + } +@@ -61,6 +67,9 @@ GetSQNBitGemmVariant( + + bool MLASCALL + MlasIsSQNBitGemmAvailable( ++ size_t M, ++ size_t N, ++ size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +@@ -71,7 +80,7 @@ MlasIsSQNBitGemmAvailable( + return false; + } + +- const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); ++ const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + switch (Variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: { +@@ -155,7 +164,7 @@ MlasSQNBitGemmBatchWorkspaceSize( + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ) + { +- const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); ++ const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + + const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); + if (PerGemmWorkspaceStride == 0) { +@@ -169,24 +178,91 @@ MlasSQNBitGemmBatchWorkspaceSize( + return WorkspaceSize + Alignment - 1; + } + ++namespace ++{ ++ ++void ++SQ4BitGemmPackQuantBData( ++ size_t N, ++ size_t K, ++ size_t BlkLen, ++ const std::byte* QuantBDataBegin, ++ std::byte* PackedQuantBDataBegin, ++ MLAS_THREADPOOL* ThreadPool ++) ++{ ++ constexpr size_t BlkBitWidth = 4; ++ ++ assert(BlkLen % 16 == 0); ++ ++ const size_t BlockCountK = MlasDivRoundup(K, BlkLen); ++ const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); ++ const size_t Iterations = N * BlockCountK; // one iteration per block ++ ++ MlasTrySimpleParallel( ++ ThreadPool, Iterations, ++ [&](ptrdiff_t tid) { ++ const size_t n = tid / BlockCountK; ++ const size_t k_blk = tid % BlockCountK; ++ ++ const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; ++ const std::byte* QuantBData = QuantBDataBegin + data_offset; ++ std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; ++ ++ // ++ // Pack 16 4-bit values (8 bytes) at a time like this: ++ // ++ // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | ++ // => ++ // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | ++ // ++ for (size_t kk = 0; kk < BlkLen; kk += 16) { ++ for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { ++ const std::byte src0 = QuantBData[byte_pair_idx]; ++ const std::byte src1 = QuantBData[byte_pair_idx + 4]; ++ ++ std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; ++ std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; ++ ++ dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); ++ dst1 = (src0 >> 4) | ((src1 >> 4) << 4); ++ } ++ ++ QuantBData += 8; ++ PackedQuantBData += 8; ++ } ++ } ++ ); ++} ++ ++} // namespace ++ + size_t MLASCALL + MlasSQNBitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, +- size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ++ size_t BlkLen + ) + { +- const auto* Dispatch = 
GetMlasPlatform().SQNBitGemmDispatch; +- if (Dispatch == nullptr) { +- return 0; ++ // Ensure that a general implementation is available on this platform. ++ // For now, all implementations share the same packed format. ++ { ++ // Currently, there are implementations specific to M = 1, so pick a more general M > 1. ++ constexpr size_t M = 2; ++ // A CompUndef implementation should be available if any is available. ++ constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef; ++ const bool HasGeneralImplementation = ++ MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType); ++ if (!HasGeneralImplementation) { ++ return 0; ++ } + } + +- if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { +- return Dispatch->SQ4BitGemmPackQuantBDataSize( +- N, K, BlkLen, ComputeType +- ); ++ if (BlkBitWidth == 4) { ++ const size_t BlockCountK = MlasDivRoundup(K, BlkLen); ++ const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); ++ return PackedQuantBDataSize; + } + + return 0; +@@ -198,28 +274,20 @@ MlasSQNBitGemmPackQuantBData( + size_t K, + size_t BlkBitWidth, + size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const void* QuantBData, + void* PackedQuantBData, + MLAS_THREADPOOL* ThreadPool + ) + { +- const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; +- if (Dispatch == nullptr) { +- return; +- } +- +- if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { +- Dispatch->SQ4BitGemmPackQuantBData( ++ if (BlkBitWidth == 4) { ++ SQ4BitGemmPackQuantBData( + N, + K, + BlkLen, +- ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBData), + ThreadPool + ); +- return; + } + } + +@@ -444,37 +512,7 @@ SQ4BitGemm_CompInt8( + return; + } + +- // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. +- // TODO Replace it with an optimized implementation. +- size_t CountN; +- for (size_t n = 0; n < RangeCountN; n += CountN) { +- CountN = std::min(RangeCountN - n, size_t{128}); +- +- const std::byte* a_row = QuantA; +- const std::byte* b_col = QuantBData + n * ldb; +- const float* b_col_scale = QuantBScale + n * k_blks; +- const std::byte* b_col_zp = +- (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; +- float* c_blk = C + n; +- const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; +- +- for (size_t m = 0; m < RangeCountM; ++m) { +- GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( +- BlkLen, +- a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias +- ); +- +- if (DataParams->PostProcessor != nullptr) { +- DataParams->PostProcessor->Process( +- DataParams->C, RangeStartM, RangeStartN + n, +- RangeCountM, CountN, ldc +- ); +- } +- +- c_blk += ldc; +- a_row += lda; +- } +- } ++ assert(false && "not implemented for M > 1"); + } + + typedef void(InitializeWorkspaceFn)( +@@ -556,7 +594,7 @@ MlasSQNBitGemmBatch( + MLAS_THREADPOOL* ThreadPool + ) + { +- const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); ++ const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + assert(Variant != SQNBitGemmVariantInvalid); + + // +diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h +index 3992bc3e4..a66db79dc 100644 +--- a/onnxruntime/core/mlas/lib/sqnbitgemm.h ++++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h +@@ -99,33 +99,6 @@ Q8BlkAlignment() + // + + struct MLAS_SQNBIT_GEMM_DISPATCH { +- // +- // Quantized B data packing function prototypes. +- // +- +- /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ +- typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( +- size_t N, +- size_t K, +- size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +- ); +- +- SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; +- +- /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ +- typedef void(SQ4BitGemmPackQuantBData_Fn)( +- size_t N, +- size_t K, +- size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, +- const std::byte* QuantBDataBegin, +- std::byte* PackedQuantBDataBegin, +- MLAS_THREADPOOL* ThreadPool +- ); +- +- SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; +- + // + // CompFp32 kernel function prototypes. + // +diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +index c4c54a9be..69fd427fa 100644 +--- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp ++++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +@@ -15,115 +15,14 @@ Abstract: + + --*/ + ++#include "sqnbitgemm.h" ++ + #include + + #include + #include + #include + +-#include "sqnbitgemm.h" +- +-// +-// Quantized B data packing function implementation. 
+-// +- +-namespace +-{ +- +-size_t +-SQ4BitGemmPackQuantBDataSize( +- size_t N, +- size_t K, +- size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +-) +-{ +- MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType +- +- constexpr size_t BlkBitWidth = 4; +- +- const size_t BlockCountK = MlasDivRoundup(K, BlkLen); +- const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); +- return PackedQuantBDataSize; +-} +- +-void +-SQ4BitGemmPackQuantBData( +- size_t N, +- size_t K, +- size_t BlkLen, +- MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, +- const std::byte* QuantBDataBegin, +- std::byte* PackedQuantBDataBegin, +- MLAS_THREADPOOL* ThreadPool +-) +-{ +- constexpr size_t BlkBitWidth = 4; +- +- assert(BlkLen >= 16 && BlkLen % 16 == 0); +- +- const size_t BlockCountK = MlasDivRoundup(K, BlkLen); +- const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); +- const size_t Iterations = N * BlockCountK; // one iteration per block +- +- const size_t SubBlkLen = (ComputeType == CompInt8) +- ? ((BlkLen == 16) ? 16 : 32) +- : 16; +- +- const size_t SubBlkDataSize = SubBlkLen / 2; +- const size_t SubBlkBytePairCount = SubBlkLen / 4; +- +- // +- // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: +- // +- // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | +- // => +- // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | +- // +- +- // +- // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: +- // +- // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | +- // => +- // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | +- // +- +- MlasTrySimpleParallel( +- ThreadPool, Iterations, +- [&](ptrdiff_t tid) { +- const size_t n = tid / BlockCountK; +- const size_t k_blk = tid % BlockCountK; +- +- const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; +- const std::byte* QuantBData = QuantBDataBegin + data_offset; +- std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; +- +- for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { +- for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { +- const std::byte src0 = QuantBData[byte_pair_idx]; +- const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; +- +- std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; +- std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; +- +- dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); +- dst1 = (src0 >> 4) | ((src1 >> 4) << 4); +- } +- +- QuantBData += SubBlkDataSize; +- PackedQuantBData += SubBlkDataSize; +- } +- } +- ); +-} +- +-} // namespace +- +-// +-// General helpers. +-// +- + namespace + { + +@@ -196,16 +95,7 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) + } + } + +-} // namespace +- +-// +-// CompFp32 kernel implementation. 
+-// +- +-namespace +-{ +- +-template ++template + MLAS_FORCEINLINE void + ComputeDotProducts_BlkBitWidth4_CompFp32( + size_t BlkLen, +@@ -222,11 +112,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + ) + { + constexpr size_t BlkBitWidth = 4; +- constexpr size_t SubBlkLen = 16; + + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + +- assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); ++ constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration ++ assert(BlkLen % SubBlkLen == 0); + + const uint8x8_t LowMask = vdup_n_u8(0x0F); + +@@ -247,8 +137,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + + const std::byte* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; +- [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer +- // only used if HasZeroPoint == true ++ size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + for (size_t k = 0; k < CountK; k += BlkLen) { + const size_t k_blk_len = std::min(CountK - k, BlkLen); +@@ -258,9 +147,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } + ); + +- [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. +- // only used if HasZeroPoint == true +- if constexpr (HasZeroPoint) { ++ float offset[NCols]; // Includes zero point and float conversion offset of 16. ++ if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; +@@ -269,6 +157,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + : (zp_packed & std::byte{0x0F}); + offset[i] = 16.0f + std::to_integer(zp); + }); ++ } else { ++ UnrolledLoop([&](size_t i) { ++ constexpr float zp = 8.0f; ++ offset[i] = 16.0f + zp; ++ }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { +@@ -294,6 +187,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); + }); + ++ // dequantize B ++ + // shift left 3 and widen to 16 bits + uint16x8_t bv_u16[NCols][2]; + UnrolledLoop([&](size_t i) { +@@ -322,17 +217,10 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + }); + + // subtract float conversion offset (16) and zero point +- if constexpr (HasZeroPoint) { +- UnrolledLoop([&](size_t i) { +- const float32x4_t offset_v = vdupq_n_f32(offset[i]); +- UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); +- }); +- } else { +- const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); +- UnrolledLoop([&](size_t i) { +- UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); +- }); +- } ++ UnrolledLoop([&](size_t i) { ++ const float32x4_t offset_v = vdupq_n_f32(offset[i]); ++ UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); ++ }); + + // multiply by scale + UnrolledLoop([&](size_t i) { +@@ -349,9 +237,7 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + // increment pointers to next block + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; +- if constexpr (HasZeroPoint) { +- QuantBZeroPointIdx += 1; +- } ++ QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { +@@ -372,9 +258,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( + } + } + +-template +-void +-SQ4BitGemmM1Kernel_CompFp32_Impl( ++MLAS_FORCEINLINE void ++SQ4BitGemmM1Kernel_CompFp32( + 
size_t BlkLen, + const float* A, + const std::byte* QuantBData, +@@ -410,7 +295,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl( + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { +- ComputeDotProducts_BlkBitWidth4_CompFp32( ++ ComputeDotProducts_BlkBitWidth4_CompFp32( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, +@@ -421,7 +306,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl( + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; +- if constexpr (HasZeroPoint) { ++ if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + +@@ -434,7 +319,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl( + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { +- ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( ++ ComputeDotProducts_BlkBitWidth4_CompFp32<1>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, +@@ -445,7 +330,7 @@ SQ4BitGemmM1Kernel_CompFp32_Impl( + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; +- if constexpr (HasZeroPoint) { ++ if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + +@@ -454,49 +339,6 @@ SQ4BitGemmM1Kernel_CompFp32_Impl( + } + } + +-MLAS_FORCEINLINE void +-SQ4BitGemmM1Kernel_CompFp32( +- size_t BlkLen, +- const float* A, +- const std::byte* QuantBData, +- const float* QuantBScale, +- const std::byte* QuantBZeroPoint, +- float* C, +- size_t CountN, +- size_t CountK, +- size_t BlockStrideQuantB, +- const float* Bias +-) +-{ +- if (QuantBZeroPoint != nullptr) { +- SQ4BitGemmM1Kernel_CompFp32_Impl( +- BlkLen, +- A, +- QuantBData, +- QuantBScale, +- QuantBZeroPoint, +- C, +- CountN, +- CountK, +- BlockStrideQuantB, +- Bias +- ); +- } else { +- SQ4BitGemmM1Kernel_CompFp32_Impl( +- BlkLen, +- A, +- QuantBData, +- QuantBScale, +- QuantBZeroPoint, +- C, +- CountN, +- CountK, +- BlockStrideQuantB, +- Bias +- ); +- } +-} +- + MLAS_FORCEINLINE void + Q4BitBlkDequantBForSgemm_CompFp32( + size_t BlkLen, +@@ -511,7 +353,6 @@ Q4BitBlkDequantBForSgemm_CompFp32( + { + auto impl0_reference = [&]() { + constexpr size_t BlkBitWidth = 4; +- constexpr size_t SubBlkLen = 16; + + float* Dst = FpData; + +@@ -537,11 +378,11 @@ Q4BitBlkDequantBForSgemm_CompFp32( + : 8; + + for (size_t kk = 0; kk < kklen; ++kk) { +- const size_t packed_idx = kk % SubBlkLen; ++ const size_t packed_idx = kk % 16; + +- const bool is_low_half = packed_idx < (SubBlkLen / 2); +- const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); +- const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); ++ const bool is_low_half = packed_idx < 8; ++ const size_t packed_byte_idx = packed_idx % 8; ++ const size_t packed_range_offset = (kk / 16) * 8; + + const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; + const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); +@@ -574,7 +415,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( + } + + // +-// CompInt8 kernel implementation. ++// CompInt8 kernel implementation and related helpers + // + + template +@@ -590,6 +431,8 @@ QuantizeBlock( + + assert(BlkLen % SubBlkLen == 0); + ++ constexpr size_t VectorCount = SubBlkLen / 4; ++ + // + // Scan block values first to determine scale. 
+ // +@@ -600,16 +443,16 @@ QuantizeBlock( + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + +- float32x4_t a[SubBlkLen / 4]{}; ++ float32x4_t a[VectorCount]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + +- float32x4_t abs_a[SubBlkLen / 4]; +- UnrolledLoop([&](size_t i) { ++ float32x4_t abs_a[VectorCount]; ++ UnrolledLoop([&](size_t i) { + abs_a[i] = vabsq_f32(a[i]); + }); + + // find amax of SubBlkLen elements +- for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { ++ for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t i = 0; i < interval; ++i) { + abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); + } +@@ -634,19 +477,19 @@ QuantizeBlock( + for (k = 0; k < ElementCount; k += SubBlkLen) { + const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); + +- float32x4_t a[SubBlkLen / 4]{}; ++ float32x4_t a[VectorCount]{}; + LoadFloatData(A + k, SubBlkElementCount, a); + +- UnrolledLoop([&](size_t i) { ++ UnrolledLoop([&](size_t i) { + a[i] = vmulq_n_f32(a[i], scale_reciprocal); + }); + +- int32x4_t a_s32[SubBlkLen / 4]; +- UnrolledLoop([&](size_t i) { ++ int32x4_t a_s32[VectorCount]; ++ UnrolledLoop([&](size_t i) { + a_s32[i] = vcvtaq_s32_f32(a[i]); + }); + +- UnrolledLoop([&](size_t i) { ++ UnrolledLoop([&](size_t i) { + QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); + QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); + QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); +@@ -687,7 +530,7 @@ QuantizeARow_CompInt8( + } + } + +-template ++template + MLAS_FORCEINLINE void + ComputeDotProducts_BlkBitWidth4_CompInt8( + size_t BlkLen, +@@ -703,22 +546,20 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( + const float* BiasPtr + ) + { +- constexpr size_t BlkBitWidth = 4; +- + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); +- static_assert(SubBlkLen == 16 || SubBlkLen == 32, "SubBlkLen must be 16 or 32"); + +- assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); ++ constexpr size_t BlkBitWidth = 4; + +- [[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); // only used if SubBlkLen == 16 +- [[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); // only used if SubBlkLen == 32 ++ constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration ++ assert(BlkLen % SubBlkLen == 0); ++ ++ const uint8x8_t LowMask = vdup_n_u8(0x0F); + + const std::byte* QuantA = QuantARowPtr; + + const std::byte* QuantBData = QuantBDataColPtr; + const float* QuantBScale = QuantBScaleColPtr; +- [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer +- // only used if HasZeroPoint == true ++ size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + + float32x4_t acc[NCols]{}; + +@@ -731,8 +572,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( + float b_scale[NCols]; + UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); + +- [[maybe_unused]] int8_t b_zp[NCols]; // only used if HasZeroPoint == true +- if constexpr (HasZeroPoint) { ++ int8_t b_zp[NCols]; ++ if (QuantBZeroPointColPtr != nullptr) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; +@@ -740,73 +581,42 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( + ? 
std::to_integer(zp_packed >> 4) + : std::to_integer(zp_packed & std::byte{0x0F}); + }); ++ } else { ++ UnrolledLoop([&](size_t i) { ++ b_zp[i] = 8; ++ }); + } + + for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { + // load A row vector +- int8x16_t av[SubBlkLen / 16]; +- UnrolledLoop([&](size_t i) { +- av[i] = vld1q_s8(a_data + k_idx_in_blk + i * 16); +- }); ++ int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + + // load B column vectors +- int8x16_t bv[NCols][SubBlkLen / 16]; +- ++ uint8x8_t bv_packed[NCols]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; ++ UnrolledLoop([&](size_t i) { ++ bv_packed[i] = vld1_u8( ++ reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset ++ ); ++ }); + +- if constexpr (SubBlkLen == 16) { +- uint8x8_t bv_packed[NCols]; +- UnrolledLoop([&](size_t i) { +- bv_packed[i] = vld1_u8( +- reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset +- ); +- }); +- +- UnrolledLoop([&](size_t i) { +- const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMaskU8x8)); +- const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); +- bv[i][0] = vcombine_s8(lo, hi); +- }); +- } else { +- static_assert(SubBlkLen == 32); +- +- uint8x16_t bv_packed[NCols]; +- UnrolledLoop([&](size_t i) { +- bv_packed[i] = vld1q_u8( +- reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset +- ); +- }); +- +- UnrolledLoop([&](size_t i) { +- bv[i][0] = vreinterpretq_s8_u8(vandq_u8(bv_packed[i], LowMaskU8x16)); +- bv[i][1] = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed[i], 4)); +- }); +- } ++ int8x16_t bv[NCols]; ++ UnrolledLoop([&](size_t i) { ++ const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask)); ++ const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); ++ bv[i] = vcombine_s8(lo, hi); ++ }); + + // subtract B zero point +- if constexpr (HasZeroPoint) { +- UnrolledLoop([&](size_t i) { +- const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); +- UnrolledLoop([&](size_t j) { +- bv[i][j] = vsubq_s8(bv[i][j], zp_v); +- }); +- }); +- } else { +- const int8x16_t zp_v = vdupq_n_s8(8); +- +- UnrolledLoop([&](size_t i) { +- UnrolledLoop([&](size_t j) { +- bv[i][j] = vsubq_s8(bv[i][j], zp_v); +- }); +- }); +- } ++ UnrolledLoop([&](size_t i) { ++ const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); ++ bv[i] = vsubq_s8(bv[i], zp_v); ++ }); + + // compute quantized dot product + int32x4_t dot[NCols]{}; + UnrolledLoop([&](size_t i) { +- UnrolledLoop([&](size_t j) { +- dot[i] = vdotq_s32(dot[i], av[j], bv[i][j]); +- }); ++ dot[i] = vdotq_s32(dot[i], av, bv[i]); + }); + + // convert dot product result to float +@@ -826,9 +636,7 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( + QuantA += Q8BlkSize(BlkLen); + QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + QuantBScale += 1; +- if constexpr (HasZeroPoint) { +- QuantBZeroPointIdx += 1; +- } ++ QuantBZeroPointIdx += 1; + } + + if constexpr (NCols == 4) { +@@ -849,9 +657,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( + } + } + +-template ++MLAS_FORCEINLINE + void +-SQ4BitGemmM1Kernel_CompInt8_Impl( ++SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, +@@ -865,6 +673,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl( + ) + { + constexpr size_t BlkBitWidth = 4; ++ constexpr size_t NCols = 4; + + const std::byte* QuantARowPtr = QuantA; + float* CRowPtr = C; +@@ -886,7 +695,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl( + int64_t nblk = static_cast(CountN) - NCols; + + while (nblk >= 0) { +- 
ComputeDotProducts_BlkBitWidth4_CompInt8( ++ ComputeDotProducts_BlkBitWidth4_CompInt8( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, +@@ -897,7 +706,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl( + + QuantBDataColPtr += NCols * StrideQuantBData; + QuantBScaleColPtr += NCols * StrideQuantBScale; +- if constexpr (HasZeroPoint) { ++ if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; + } + +@@ -910,7 +719,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl( + // left over columns less than `NCols`? + nblk += NCols; + for (int64_t n = 0; n < nblk; ++n) { +- ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>( ++ ComputeDotProducts_BlkBitWidth4_CompInt8<1>( + BlkLen, + QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, +@@ -921,7 +730,7 @@ SQ4BitGemmM1Kernel_CompInt8_Impl( + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; +- if constexpr (HasZeroPoint) { ++ if (QuantBZeroPointColPtr != nullptr) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + +@@ -930,94 +739,6 @@ SQ4BitGemmM1Kernel_CompInt8_Impl( + } + } + +-template +-MLAS_FORCEINLINE void +-SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( +- size_t BlkLen, +- const std::byte* QuantA, +- const std::byte* QuantBData, +- const float* QuantBScale, +- const std::byte* QuantBZeroPoint, +- float* C, +- size_t CountN, +- size_t CountK, +- size_t BlockStrideQuantB, +- const float* Bias +-) +-{ +- if (BlkLen == 16) { +- SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>( +- BlkLen, +- QuantA, +- QuantBData, +- QuantBScale, +- QuantBZeroPoint, +- C, +- CountN, +- CountK, +- BlockStrideQuantB, +- Bias +- ); +- } else { +- SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>( +- BlkLen, +- QuantA, +- QuantBData, +- QuantBScale, +- QuantBZeroPoint, +- C, +- CountN, +- CountK, +- BlockStrideQuantB, +- Bias +- ); +- } +-} +- +-MLAS_FORCEINLINE +-void +-SQ4BitGemmM1Kernel_CompInt8( +- size_t BlkLen, +- const std::byte* QuantA, +- const std::byte* QuantBData, +- const float* QuantBScale, +- const std::byte* QuantBZeroPoint, +- float* C, +- size_t CountN, +- size_t CountK, +- size_t BlockStrideQuantB, +- const float* Bias +-) +-{ +- if (QuantBZeroPoint != nullptr) { +- SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( +- BlkLen, +- QuantA, +- QuantBData, +- QuantBScale, +- QuantBZeroPoint, +- C, +- CountN, +- CountK, +- BlockStrideQuantB, +- Bias +- ); +- } else { +- SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( +- BlkLen, +- QuantA, +- QuantBData, +- QuantBScale, +- QuantBZeroPoint, +- C, +- CountN, +- CountK, +- BlockStrideQuantB, +- Bias +- ); +- } +-} +- + } // namespace + + // +@@ -1027,12 +748,8 @@ SQ4BitGemmM1Kernel_CompInt8( + const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_SQNBIT_GEMM_DISPATCH d; + +- d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; +- d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; +- + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; + d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; +- + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; + d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; + +diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc +index b7cb3ba48..d27603e4a 100644 
+--- a/onnxruntime/core/optimizer/conv_activation_fusion.cc ++++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc +@@ -111,7 +111,7 @@ class ConvActivationSelector : public NodeSelector { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { + return std::nullopt; + } +- } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { ++ } else if (node_ep.empty() || node_ep == kCpuExecutionProvider) { + if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && + !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { + return std::nullopt; +diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc +index 752b74280..127c37bd8 100644 +--- a/onnxruntime/core/providers/cann/cann_execution_provider.cc ++++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc +@@ -9,6 +9,7 @@ + #include + #include + ++#include "core/providers/shared_library/provider_api.h" + #define ORT_API_MANUAL_INIT + #include "core/session/onnxruntime_cxx_api.h" + #include "core/providers/cann/cann_execution_provider.h" +@@ -1028,14 +1029,13 @@ Status RegisterCANNKernels(KernelRegistry& kernel_registry) { + } // namespace cann + + CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& info) +- : IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info} { ++ : IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, info_{info} { + InitProviderOrtApi(); + + CANN_CALL_THROW(aclrtSetDevice(info_.device_id)); + + soc_name_ = aclrtGetSocName(); + ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr"); +- metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + } + + CANNExecutionProvider::~CANNExecutionProvider() { +@@ -1197,7 +1197,7 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( + + // Generate unique kernel name for CANN subgraph + HashValue model_hash = 0; +- int id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); ++ int id = GenerateMetaDefId(graph_viewer, model_hash); + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->name() = graph_viewer.Name() + "_" + std::to_string(model_hash) + "_" + std::to_string(id); + +diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h +index 63ae98086..76d3d9c33 100644 +--- a/onnxruntime/core/providers/cann/cann_execution_provider.h ++++ b/onnxruntime/core/providers/cann/cann_execution_provider.h +@@ -81,7 +81,6 @@ class CANNExecutionProvider : public IExecutionProvider { + std::unordered_map modelIDs_; + std::unordered_map models_; + std::unordered_map> names_; +- std::unique_ptr metadef_id_generator_; + }; + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h +index e9cd4af94..631bb7e25 100644 +--- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h ++++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h +@@ -9,6 +9,6 @@ + #error "This file should only be included when building on Apple platforms." 
+ #endif + +-#include "coreml_proto/Model.pb.h" ++#include "coreml/Model.pb.h" + + namespace COREML_SPEC = CoreML::Specification; +diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +index ef66e6b87..3b7bd5c18 100644 +--- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc ++++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +@@ -11,7 +11,7 @@ + #include "core/providers/shared/utils/utils.h" + #include "core/optimizer/initializer.h" + +-#include "coreml_proto/NeuralNetwork.pb.h" ++#include "coreml/NeuralNetwork.pb.h" + + namespace onnxruntime { + namespace coreml { +diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +index c133f7b82..c9973671f 100644 +--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc ++++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +@@ -24,7 +24,7 @@ namespace onnxruntime { + constexpr const char* COREML = "CoreML"; + + CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) +- : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, ++ : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider, true}, + coreml_flags_(coreml_flags) { + } + +@@ -54,7 +54,7 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie + + const auto gen_metadef_name = [&]() { + HashValue model_hash; +- int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + return MakeString(COREML, "_", model_hash, "_", metadef_id); + }; + +diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h +index 020173954..67050e807 100644 +--- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h ++++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h +@@ -4,7 +4,6 @@ + #pragma once + + #include "core/framework/execution_provider.h" +-#include "core/framework/model_metadef_id_generator.h" + #include "core/providers/coreml/coreml_provider_factory.h" + + namespace onnxruntime { +@@ -35,6 +34,5 @@ class CoreMLExecutionProvider : public IExecutionProvider { + #ifdef __APPLE__ + std::unordered_map> coreml_models_; + #endif +- ModelMetadefIdGenerator metadef_id_generator_; + }; + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto +new file mode 100644 +index 000000000..2b83ccbe3 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto +@@ -0,0 +1,19 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * An array feature extractor. ++ * ++ * Given an index, extracts the value at that index from its array input. ++ * Indexes are zero-based. 
++ */ ++message ArrayFeatureExtractor { ++ repeated uint64 extractIndex = 1; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto +new file mode 100644 +index 000000000..9688d87ce +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto +@@ -0,0 +1,139 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++* A Bayesian probit regressor. ++* ++* The probit regression model is superficially similar to the more commonly known ++* logistic regression, with sampling distribution of the model given by ++* ++* P(y=+1|x,w) = Φ(/β) ++* ++* where w are the set of weights, ++* x are the set of features for the given event, ++* β is a model hyper-parameter, and ++* Φ is the link function, defined to be the CDF of the normal distribution. ++* The weights w[i,j] are Gaussian distributed, with mean μ[i,j] and precision 1/(σ[i,j])^2 ++* (where i indexes over features and j indexes over the values for the feature). ++* The parameter β scales the steepness of the inverse link function. ++* ++* (see https://en.wikipedia.org/wiki/Probit_model and https://en.wikipedia.org/wiki/Logistic_regression ++* for more details on probit model and logistic regression, respectively) ++* ++* Input: X ++* x represents a set of features, each taking on a discrete value (note that continuous values ++* would first need to be discretized). x can be represented as a vector where the index i is ++* the feature id and x[i] is the feature value. Alternatively, x can be represented as a matrix ++* with 2 columns where the first column indicates the feature id and the second column contains ++* the feature values, i.e. x[i,0] is the feature id and x[i,1] is the feature value. ++* ++* additional input features: ++* - "optimism": apply a mean shift to the probability, i.e. shift regression mean by o*stdev, ++* where o is the "optimism" parameter (see additional output features) ++* - "samplingScale": for sampling from posterior, multiply standard deviation by this factor ++* - "samplingTruncation": for sampling from posterior, truncate sampling distribution at given multiple of std from mean ++* ++* Output: Y ++* probability P(y|x,w) ++* ++* additional output features: ++* - mean (regression output before applying link function) ++* - variance (regression output variance before applying link function) ++* - pessimistic probability: P(y|x,w) with a mean shift parameterized by "optimism" feature ++* - sampled probability: p ~ P(y|x,w) with standard deviation scaling parametrized by "samplingScale" feature ++* and distribution truncated at multiple of standard deviation, ++* where multiple parameterized by "samplingTruncation" feature. 
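The sampling distribution in the comment above reads "Φ(/β)" in this diff, which looks like the inner product was lost between angle brackets; written out as an assumption based on the surrounding definitions of w, x, and β, the intended link is:

P(y = +1 \mid x, w) \;=\; \Phi\!\left(\frac{\langle w, x\rangle}{\beta}\right),
\qquad
w_{i,j} \sim \mathcal{N}\!\left(\mu_{i,j},\, \sigma_{i,j}^{2}\right)

with Φ the CDF of the standard normal distribution, ⟨w, x⟩ the weighted sum of the active feature weights, and each weight Gaussian with mean μ[i,j] and precision 1/σ[i,j]², as described in the comment.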
++* ++*/ ++ ++message BayesianProbitRegressor { ++ ++ /* ++ * Parameterization of a Gaussian distribution ++ */ ++ message Gaussian { ++ double mean = 1; ++ double precision = 2; // inverse of the variance ++ } ++ ++ /* ++ * Weight for a specific feature value ++ * The weight is represented as a Gaussian distribution ++ * with a mean and precision (1/variance) to capture ++ * uncertainty in the weight ++ */ ++ message FeatureValueWeight { ++ uint32 featureValue = 1; ++ Gaussian featureWeight = 2; ++ } ++ ++ /* ++ * Feature with associated weights (for different values) ++ * Each feature has a set of weights for the (discrete) values ++ * it can take ++ */ ++ message FeatureWeight { ++ uint32 featureId = 1; ++ repeated FeatureValueWeight weights = 2; ++ } ++ ++ uint32 numberOfFeatures = 1; ++ ++ Gaussian bias = 2; // bias term ++ ++ /* ++ * Set of features with associated weights ++ */ ++ repeated FeatureWeight features = 3; // feature weights ++ ++ /* ++ * Set this name to be the same as input feature of type multi-array (1D) ++ * in the model description you want to use as the regression input ++ */ ++ string regressionInputFeatureName = 10; ++ ++ /* ++ * Set this name to be the same as optional input feature of type double ++ * in the model description you want to use as the optimism input ++ */ ++ string optimismInputFeatureName = 11; ++ ++ /* ++ * Set this name to be the same as optional input feature of type double ++ * in the model description you want to use as the samplingScale input ++ */ ++ string samplingScaleInputFeatureName = 12; ++ ++ /* ++ * Set this name to be the same as optional input feature of type double ++ * in the model description you want to use as the samplingBounds input ++ */ ++ string samplingTruncationInputFeatureName = 13; ++ ++ /* ++ * name of 'mean' output feature ++ */ ++ string meanOutputFeatureName = 20; ++ ++ /* ++ * name of 'variance' output feature ++ */ ++ string varianceOutputFeatureName = 21; ++ ++ /* ++ * name of 'pessimistic' output feature ++ */ ++ string pessimisticProbabilityOutputFeatureName = 22; ++ ++ /* ++ * name of 'sampled' output feature: samples from the scaled posterior probability distribuiton ++ */ ++ string sampledProbabilityOutputFeatureName = 23; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto b/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto +new file mode 100644 +index 000000000..23112d074 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto +@@ -0,0 +1,38 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A categorical mapping. ++ * ++ * This allows conversion from integers to strings, or from strings to integers. ++ */ ++message CategoricalMapping { ++ oneof MappingType { ++ // Conversion from strings to integers ++ StringToInt64Map stringToInt64Map = 1; ++ ++ // Conversion from integer to string ++ Int64ToStringMap int64ToStringMap = 2; ++ } ++ ++ /** ++ * The value returned if an input is not contained in the map above. ++ * If one of these is not set, then an error is raised on an unknown input. ++ */ ++ oneof ValueOnUnknown { ++ // Default output when converting from an integer to a string. 
++ string strValue = 101; ++ ++ // Default output when converting from a string to an integer. ++ int64 int64Value = 102; ++ } ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto b/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto +new file mode 100644 +index 000000000..9a6d36e00 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto +@@ -0,0 +1,30 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++* A parameterized model whose function is defined in code ++*/ ++message CustomModel { ++ ++ message CustomModelParamValue { ++ oneof value { ++ double doubleValue = 10; ++ string stringValue = 20; ++ int32 intValue = 30; ++ int64 longValue = 40; ++ bool boolValue = 50; ++ bytes bytesValue = 60; ++ } ++ } ++ ++ string className = 10; // The name of the class (conforming to MLCustomModel) corresponding to this model ++ map parameters = 30; ++ string description = 40; // An (optional) description provided by the model creator. This information is displayed when viewing the model, but does not affect the model's execution on device. ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto b/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto +new file mode 100644 +index 000000000..8b120c2d7 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto +@@ -0,0 +1,95 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "FeatureTypes.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A mapping from a string ++ * to a 64-bit integer. ++ */ ++message StringToInt64Map { ++ map map = 1; ++} ++ ++/** ++ * A mapping from a 64-bit integer ++ * to a string. ++ */ ++message Int64ToStringMap { ++ map map = 1; ++} ++ ++/** ++ * A mapping from a string ++ * to a double-precision floating point number. ++ */ ++message StringToDoubleMap { ++ map map = 1; ++} ++ ++/** ++ * A mapping from a 64-bit integer ++ * to a double-precision floating point number. ++ */ ++message Int64ToDoubleMap { ++ map map = 1; ++} ++ ++/** ++ * A vector of strings. ++ */ ++message StringVector { ++ repeated string vector = 1; ++} ++ ++/** ++ * A vector of 64-bit integers. ++ */ ++message Int64Vector { ++ repeated int64 vector = 1; ++} ++ ++/** ++ * A vector of floating point numbers. ++ */ ++message FloatVector { ++ repeated float vector = 1; ++} ++ ++/** ++ * A vector of double-precision floating point numbers. 
++ */ ++message DoubleVector { ++ repeated double vector = 1; ++} ++ ++/** ++ * A range of int64 values ++ */ ++message Int64Range { ++ int64 minValue = 1; ++ int64 maxValue = 2; ++} ++ ++/** ++ * A set of int64 values ++ */ ++message Int64Set { ++ repeated int64 values = 1; ++} ++ ++/** ++ * A range of double values ++ */ ++message DoubleRange { ++ double minValue = 1; ++ double maxValue = 2; ++} ++ +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto +new file mode 100644 +index 000000000..3f94eeec1 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto +@@ -0,0 +1,36 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * Uses an index mapping to convert a dictionary to an array. ++ * ++ * The output array will be equal in length to the index mapping vector parameter. ++ * All keys in the input dictionary must be present in the index mapping vector. ++ * ++ * For each item in the input dictionary, insert its value in the output array. ++ * The position of the insertion is determined by the position of the item's key ++ * in the index mapping. Any keys not present in the input dictionary, will be ++ * zero in the output array. ++ * ++ * For example: if the ``stringToIndex`` parameter is set to ``["a", "c", "b", "z"]``, ++ * then an input of ``{"a": 4, "c": 8}`` will produce an output of ``[4, 8, 0, 0]``. ++ * ++ */ ++message DictVectorizer { ++ oneof Map { ++ /// String keys to indexes ++ StringVector stringToIndex = 1; ++ ++ /// Int keys to indexes ++ Int64Vector int64ToIndex = 2; ++ } ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto +new file mode 100644 +index 000000000..8711ac7de +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto +@@ -0,0 +1,224 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * The 64-bit integer feature type. ++ */ ++message Int64FeatureType {} ++ ++/** ++ * The double-precision floating point number feature type. ++ */ ++message DoubleFeatureType {} ++ ++/** ++ * The string feature type. ++ */ ++message StringFeatureType {} ++ ++ ++message SizeRange { ++ uint64 lowerBound = 1; ++ int64 upperBound = 2; // negative value means unbound otherwise upperbound is included in range ++} ++ ++/** ++ * The image feature type. 
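A small self-contained C++ sketch of the DictVectorizer behavior documented above, reproducing the worked example from the comment (helper name and types are illustrative, not CoreML or ORT API):

#include <cstddef>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Output has one slot per entry of the index mapping; keys that are absent
// from the input dictionary stay at zero.
std::vector<double> DictVectorize(const std::vector<std::string>& string_to_index,
                                  const std::unordered_map<std::string, double>& input) {
    std::vector<double> output(string_to_index.size(), 0.0);
    for (std::size_t i = 0; i < string_to_index.size(); ++i) {
        auto it = input.find(string_to_index[i]);
        if (it != input.end()) {
            output[i] = it->second;
        }
    }
    return output;
}

int main() {
    // Mirrors the comment's example: ["a", "c", "b", "z"] with {"a": 4, "c": 8} -> [4, 8, 0, 0]
    for (double v : DictVectorize({"a", "c", "b", "z"}, {{"a", 4.0}, {"c", 8.0}})) {
        std::cout << v << ' ';
    }
    std::cout << '\n';
    return 0;
}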
++ */ ++message ImageFeatureType { ++ // Assumes raw (decompressed) format ++ enum ColorSpace { ++ INVALID_COLOR_SPACE = 0; ++ GRAYSCALE = 10; // 8 bits per pixel ++ RGB = 20; // 32 bits per pixel: RGBA with A channel ignored ++ BGR = 30; // 32 bits per pixel: BGRA with A channel ignored ++ } ++ ++ message ImageSize { ++ uint64 width = 1; ++ uint64 height = 2; ++ } ++ ++ message EnumeratedImageSizes { ++ repeated ImageSize sizes = 1; ++ } ++ ++ message ImageSizeRange { ++ SizeRange widthRange = 1; ++ SizeRange heightRange = 2; ++ } ++ ++ // The required or default image size is width x height ++ // ++ // If specificationVersion <= 2 or SizeFlexibility is empty, ++ // width x height is the required fixed image size ++ // ++ // If SizeFlexibility is present, width x height indicate a "default" ++ // image size which must be consistent with the flexibilty specified ++ ++ int64 width = 1; ++ int64 height = 2; ++ ++ // For specification version >= 3 you can specify image size flexibility. ++ ++ oneof SizeFlexibility { ++ ++ // Use enumeratedSizes for a set of distinct fixed sizes ++ // e.g. portrait or landscape: [80 x 100, 100 x 8] ++ // ++ // If the width x height fields above are specified then they must be ++ // one of the sizes listed. ++ // ++ // If width and height are not specified above then the default width ++ // and height will be enumeratedSizes[0] ++ // ++ // Must be non-empty ++ ++ EnumeratedImageSizes enumeratedSizes = 21; ++ ++ // Use imageSizeRange to allow for ranges of values ++ // e.g. any image greater than 10 x 20: [10..= 3 you can specify image size flexibility. ++ ++ oneof ShapeFlexibility { ++ ++ // Use enumeratedShapes for a set of distinct fixed shapes ++ // ++ // If the shape field is specified then it must be ++ // one of the enumerated shapes. ++ /// ++ // If shape is not specifed, the "default" shape will be considered ++ // enumeratedShapes[0] ++ // ++ // Must be non-empty ++ ++ EnumeratedShapes enumeratedShapes = 21; ++ ++ // Use shapeRange to allow the size of each dimension vary within ++ // indpendently specified ranges ++ // ++ // If you specify shape above it must fall in the range ++ // specified in shapeRanges. It will be treated as the default shape. ++ // ++ // If you don't specify shape above then the default shape will ++ // have shape[d] = shapeRange.sizeRanges[d].lowerBound ++ ++ ShapeRange shapeRange = 31; ++ ++ } ++ ++ oneof defaultOptionalValue { ++ int32 intDefaultValue = 41; ++ float floatDefaultValue = 51; ++ double doubleDefaultValue = 61; ++ } ++ ++} ++ ++/** ++ * The dictionary feature type. ++ */ ++message DictionaryFeatureType { ++ /** ++ * Key/value type tags, with the following restrictions: ++ * - ``keyType`` must be a hashable type ++ * - ``valueType`` is assumed to be a ``double`` ++ */ ++ oneof KeyType { ++ Int64FeatureType int64KeyType = 1; ++ StringFeatureType stringKeyType = 2; ++ } ++} ++ ++/** ++ * The Sequence feature type. ++ */ ++message SequenceFeatureType { ++ ++ /** ++ * Currently only categorical int64 and String sequences are supported ++ */ ++ oneof Type { ++ Int64FeatureType int64Type = 1; ++ StringFeatureType stringType = 3; ++ } ++ ++ // Range of allowed size/length/count of sequence ++ SizeRange sizeRange = 101; ++} ++ ++/** ++ * A feature, which may be optional. 
++ */ ++message FeatureType { ++ oneof Type { ++ Int64FeatureType int64Type = 1; ++ DoubleFeatureType doubleType = 2; ++ StringFeatureType stringType = 3; ++ ImageFeatureType imageType = 4; ++ ArrayFeatureType multiArrayType = 5; ++ DictionaryFeatureType dictionaryType = 6; ++ SequenceFeatureType sequenceType = 7; ++ } ++ ++ bool isOptional = 1000; ++} ++ +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto +new file mode 100644 +index 000000000..75eaf14b5 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto +@@ -0,0 +1,26 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * A FeatureVectorizer puts one or more features into a single array. ++ * ++ * The ordering of features in the output array is determined by ++ * ``inputList``. ++ * ++ * ``inputDimensions`` is a zero based index. ++ */ ++message FeatureVectorizer { ++ message InputColumn { ++ string inputColumn = 1; ++ uint64 inputDimensions = 2; ++ } ++ ++ repeated InputColumn inputList = 1; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto b/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto +new file mode 100644 +index 000000000..47f6f4a3c +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto +@@ -0,0 +1,43 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A generalized linear model classifier. ++ */ ++message GLMClassifier { ++ message DoubleArray { ++ repeated double value = 1; ++ } ++ ++ enum PostEvaluationTransform { ++ Logit = 0; ++ Probit = 1; /// Only binary classification is supported for probit ++ } ++ ++ enum ClassEncoding { ++ ReferenceClass = 0; /// First class is the reference class ++ OneVsRest = 1; /// Also called One vs All ++ } ++ ++ repeated DoubleArray weights = 1; ++ repeated double offset = 2; ++ PostEvaluationTransform postEvaluationTransform = 3; ++ ClassEncoding classEncoding = 4; ++ ++ /** ++ * Required class label mapping. ++ */ ++ oneof ClassLabels { ++ StringVector stringClassLabels = 100; ++ Int64Vector int64ClassLabels = 101; ++ } ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto +new file mode 100644 +index 000000000..64093c4f1 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto +@@ -0,0 +1,28 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * A generalized linear model regressor. 
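GLMClassifier above stores per-class weight vectors, per-class offsets, a post-evaluation transform, and a class encoding. A hedged sketch of the simplest case, a binary classifier with the Logit transform, computing the positive-class probability from one weight vector and offset (names are illustrative):

#include <cmath>
#include <cstddef>
#include <vector>

// score = <w, x> + offset; the Logit post-evaluation transform then maps the
// score to a probability through the logistic function. Multi-class or
// OneVsRest evaluation repeats this per class.
double GlmBinaryProbability(const std::vector<double>& weights, double offset,
                            const std::vector<double>& features) {
    double score = offset;
    for (std::size_t i = 0; i < weights.size() && i < features.size(); ++i) {
        score += weights[i] * features[i];
    }
    return 1.0 / (1.0 + std::exp(-score));
}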
++ */ ++message GLMRegressor { ++ message DoubleArray { ++ repeated double value = 1; ++ } ++ ++ enum PostEvaluationTransform { ++ NoTransform = 0; ++ Logit = 1; ++ Probit = 2; ++ } ++ ++ repeated DoubleArray weights = 1; ++ repeated double offset = 2; ++ PostEvaluationTransform postEvaluationTransform = 3; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto +new file mode 100644 +index 000000000..6abbffaf6 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto +@@ -0,0 +1,43 @@ ++// Copyright (c) 2019, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification.CoreMLModels; ++ ++/** ++* A model which uses an efficient probabilistic representation ++* for assigning labels to a set of strings. ++*/ ++message Gazetteer { ++ ++ /* ++ * Stores the revision number for the model, revision 2 is available on ++ * iOS, tvOS 13.0+, macOS 10.15+ ++ */ ++ uint32 revision = 1; ++ ++ /* ++ * Stores the language of the model, as specified in BCP-47 format, ++ * e.g. "en-US". See https://tools.ietf.org/html/bcp47 ++ */ ++ string language = 10; ++ ++ /* ++ * Natural Lanaguge framework's efficient representation of a gazetter. ++ */ ++ bytes modelParameterData = 100; ++ ++ /* ++ * Stores the set of output class labels ++ */ ++ oneof ClassLabels { ++ StringVector stringClassLabels = 200; ++ } ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto +new file mode 100644 +index 000000000..123a15e59 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto +@@ -0,0 +1,18 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * An identity model. ++ * ++ * This model returns given inputs as outputs, unchanged. ++ * Intended to be used for testing purposes. ++ */ ++message Identity { ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto +new file mode 100644 +index 000000000..3de280b2f +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto +@@ -0,0 +1,43 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A transformer that replaces missing values with a default value, ++ * such as a statistically-derived value. ++ * ++ * If ``ReplaceValue`` is set, then missing values of that type are ++ * replaced with the corresponding value. 
++ * ++ * For example: if ``replaceDoubleValue`` is set to ``NaN`` ++ * and a single ``NaN`` double value is provided as input, ++ * then it is replaced by ``imputedDoubleValue``. However ++ * if the input is an array of doubles, then any instances ++ * of ``NaN`` in the array is replaced with the corresponding ++ * value in ``imputedDoubleArray``. ++ */ ++message Imputer { ++ oneof ImputedValue { ++ double imputedDoubleValue = 1; ++ int64 imputedInt64Value = 2; ++ string imputedStringValue = 3; ++ DoubleVector imputedDoubleArray = 4; ++ Int64Vector imputedInt64Array = 5; ++ StringToDoubleMap imputedStringDictionary = 6; ++ Int64ToDoubleMap imputedInt64Dictionary = 7; ++ } ++ ++ oneof ReplaceValue { ++ double replaceDoubleValue = 11; ++ int64 replaceInt64Value = 12; ++ string replaceStringValue = 13; ++ } ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto b/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto +new file mode 100644 +index 000000000..a5a8c1109 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto +@@ -0,0 +1,93 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++/** ++ * Each tree is a collection of nodes, ++ * each of which is identified by a unique identifier. ++ * ++ * Each node is either a branch or a leaf node. ++ * A branch node evaluates a value according to a behavior; ++ * if true, the node identified by ``true_child_node_id`` is evaluated next, ++ * if false, the node identified by ``false_child_node_id`` is evaluated next. ++ * A leaf node adds the evaluation value to the base prediction value ++ * to get the final prediction. ++ * ++ * A tree must have exactly one root node, ++ * which has no parent node. ++ * A tree must not terminate on a branch node. ++ * All leaf nodes must be accessible ++ * by evaluating one or more branch nodes in sequence, ++ * starting from the root node. ++ */ ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++ ++/** ++ * Item Similarity Recommender ++ * ++ * The Item Similarity recommender takes as input a list of items and scores, ++ * then uses that information and a table of item similarities to predict similarity ++ * scores for all items. By default, the items predicted are most similar to the given ++ * items but not part of that item set. ++ * ++ * The predicted score for a given item k is ++ * sum_(i in observed items) sim_(k,i) * (score_i - shift_k) ++ * ++ * Because only the most similar scores for each item i are stored, ++ * sim_(k,i) is often zero. ++ * ++ * For many models, the score adjustment parameter shift_j is zero -- it's occasionally used ++ * to counteract global biases for popular items. ++ * ++ * ++ * References: ++ */ ++message ItemSimilarityRecommender { ++ ++ /** The items similar to a given base item. ++ */ ++ message ConnectedItem { ++ uint64 itemId = 1; ++ double similarityScore = 2; ++ } ++ ++ /** The formula for the score of a given model as given above, with shift_k ++ * parameter given by itemScoreAdjustment, and the similar item list filling in ++ * all the known sim(k,i) scores for i given by itemID and k given by the itemID parameter in ++ * the similarItemList. 
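For reference, the predicted score described in the ItemSimilarityRecommender comment above, written out in display form:

\hat{s}_k \;=\; \sum_{i \,\in\, \text{observed items}} \mathrm{sim}(k, i)\,\bigl(\text{score}_i - \text{shift}_k\bigr)

Here sim(k, i) comes from itemItemSimilarities and is zero for item pairs that are not stored, and shift_k is the per-item itemScoreAdjustment, which is often zero and is occasionally used to counteract global popularity bias.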
++ */ ++ message SimilarItems { ++ uint64 itemId = 1; ++ repeated ConnectedItem similarItemList = 2; ++ double itemScoreAdjustment = 3; ++ } ++ ++ repeated SimilarItems itemItemSimilarities = 1; ++ ++ /** One or none of these are given. If none are given, then the items must number 0, 1, ..., num_items - 1. ++ * If either is given, the length must be exactly num_items. ++ */ ++ StringVector itemStringIds = 2; ++ Int64Vector itemInt64Ids = 3; ++ ++ /** Input parameter names specifying different possible inputs to the recommender. ++ */ ++ string itemInputFeatureName = 10; /* Required */ ++ string numRecommendationsInputFeatureName = 11; /* Optional; defaults to all items if not given.*/ ++ string itemRestrictionInputFeatureName = 12; /* Optional. */ ++ string itemExclusionInputFeatureName = 13; /* Optional; defaults to input item list if not given. */ ++ ++ /** The predicted outputs. At least one of these must be specified. ++ */ ++ string recommendedItemListOutputFeatureName = 20; ++ string recommendedItemScoreOutputFeatureName = 21; ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto b/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto +new file mode 100644 +index 000000000..b113000e8 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto +@@ -0,0 +1,42 @@ ++// Copyright (c) 2019, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++import public "Parameters.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A model which wraps another (compiled) model external to this one ++ */ ++message LinkedModel { ++ ++ oneof LinkType { ++ // A model located via a file system path ++ LinkedModelFile linkedModelFile = 1; ++ } ++} ++ ++// Model is referenced by a model file name and search path ++message LinkedModelFile { ++ ++ // Model file name: e.g. "MyFetureExtractor.mlmodelc" ++ StringParameter linkedModelFileName = 1; ++ ++ // Search path to find the linked model file ++ // Multiple paths can be searched using the unix-style path separator ":" ++ // Each path can be relative (to this model) or absolute ++ // ++ // An empty string is the same as teh relative search path "." ++ // which searches in the same location as this model file ++ // ++ // There are some special paths which start with $ ++ // - $BUNDLE_MAIN - Indicates to look in the main bundle ++ // - $BUNDLE_IDENTIFIER(identifier) - Looks in Bunde with given identifer ++ StringParameter linkedModelSearchPath = 2; ++} ++ ++ +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto +new file mode 100644 +index 000000000..737233f2e +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto +@@ -0,0 +1,322 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. 
++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++/** ++ * A Core ML model consists of a specification version ++ * and a model description, ++ * and can be any one of the following types: ++ * ++ * Neural Networks ++ * - `NeuralNetwork` ++ * ++ * Regressors ++ * - ``GLMRegressor`` ++ * - ``SupportVectorRegressor`` ++ * - ``TreeEnsembleRegressor`` ++ * - ``NeuralNetworkRegressor`` ++ * - ``BayesianProbitRegressor`` ++ * ++ * Classifiers ++ * - `NeuralNetworkClassifier` ++ * - `TreeEnsembleClassifier` ++ * - `GLMClassifier` ++ * - `SupportVectorClassifier` ++ * - `KNearestNeighborsClassifier` ++ * ++ * Other models ++ * - `CustomModel` ++ * - `TextClassifier` ++ * - `WordTagger` ++ * - `Gazetteer` ++ * - `WordEmbedding` ++ * - `VisionFeaturePrint` ++ * - `LinkedModel` ++ * - `SoundAnalysisPreprocessing` ++ * - `ItemSimilarityRecommender` ++ * ++ * Feature Engineering ++ * - `Imputer` ++ * - `Scaler` ++ * - `Normalizer` ++ * - `OneHotEncoder` ++ * - `CategoricalMapping` ++ * - `FeatureVectorizer` ++ * - `DictVectorizer` ++ * - `ArrayFeatureExtractor` ++ * - `NonMaximumSuppression` ++ * ++ * Pipelines ++ * - `PipelineClassifier` ++ * - `PipelineRegressor` ++ * - `Pipeline` ++ * ++ * Simple Mathematical Functions ++ * - `Identity` ++ */ ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "VisionFeaturePrint.proto"; ++import public "TextClassifier.proto"; ++import public "WordTagger.proto"; ++import public "Gazetteer.proto"; ++import public "WordEmbedding.proto"; ++import public "ArrayFeatureExtractor.proto"; ++import public "BayesianProbitRegressor.proto"; ++import public "CategoricalMapping.proto"; ++import public "CustomModel.proto"; ++import public "DictVectorizer.proto"; ++import public "FeatureTypes.proto"; ++import public "FeatureVectorizer.proto"; ++import public "GLMRegressor.proto"; ++import public "GLMClassifier.proto"; ++import public "NearestNeighbors.proto"; ++import public "Identity.proto"; ++import public "Imputer.proto"; ++import public "NeuralNetwork.proto"; ++import public "Normalizer.proto"; ++import public "OneHotEncoder.proto"; ++import public "Scaler.proto"; ++import public "NonMaximumSuppression.proto"; ++import public "SVM.proto"; ++import public "TreeEnsemble.proto"; ++import public "Parameters.proto"; ++import public "ItemSimilarityRecommender.proto"; ++import public "SoundAnalysisPreprocessing.proto"; ++import public "LinkedModel.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A pipeline consisting of one or more models. ++ */ ++message Pipeline { ++ repeated Model models = 1; ++ ++ // Optional names given for each model ++ // If not supplied it defaults to ["model0",..., "model"(models.size()-1)] ++ // These names can be used to disambiguate the scope / domain of a parameter ++ repeated string names = 2; ++} ++ ++/** ++ * A classifier pipeline. ++ */ ++message PipelineClassifier { ++ Pipeline pipeline = 1; ++} ++ ++/** ++ * A regressor pipeline. ++ */ ++message PipelineRegressor { ++ Pipeline pipeline = 1; ++} ++ ++/** ++ * A feature description, ++ * consisting of a name, short description, and type. ++ */ ++message FeatureDescription { ++ string name = 1; ++ string shortDescription = 2; ++ FeatureType type = 3; ++} ++ ++/** ++ * Model metadata, ++ * consisting of a short description, a version string, ++ * an author, a license, and any other user defined ++ * key/value meta data. 
++ */ ++message Metadata { ++ string shortDescription = 1; ++ string versionString = 2; ++ string author = 3; ++ string license = 4; ++ map userDefined = 100; ++} ++ ++/** ++ * A description of a model, ++ * consisting of descriptions of its input and output features. ++ * Both regressor and classifier models require the name of the ++ * primary predicted output feature (``predictedFeatureName``). ++ * Classifier models can specify the output feature containing ++ * probabilities for the predicted classes ++ * (``predictedProbabilitiesName``). ++ */ ++message ModelDescription { ++ repeated FeatureDescription input = 1; ++ repeated FeatureDescription output = 10; ++ ++ // [Required for regressor and classifier models]: the name ++ // to give to an output feature containing the prediction. ++ string predictedFeatureName = 11; ++ ++ // [Optional for classifier models]: the name to give to an ++ // output feature containing a dictionary mapping class ++ // labels to their predicted probabilities. If not specified, ++ // the dictionary will not be returned by the model. ++ string predictedProbabilitiesName = 12; ++ ++ repeated FeatureDescription trainingInput = 50; ++ ++ Metadata metadata = 100; ++} ++ ++message SerializedModel { ++ // Identifier whose content describes the model type of the serialized protocol buffer message. ++ string identifier = 1; ++ ++ // Must be a valid serialized protocol buffer of the above specified type. ++ bytes model = 2; ++} ++ ++/** ++ * A Core ML model, ++ * consisting of a specification version, ++ * a model description, and a model type. ++ * ++ * Core ML model compatibility is indicated by ++ * a monotonically increasing specification version number, ++ * which is incremented anytime a backward-incompatible change is made ++ * (this is functionally equivalent to the MAJOR version number ++ * described by `Semantic Versioning 2.0.0 `_). 
++ * ++ * Specification Versions : OS Availability (Core ML Version) ++ * ++ * 1 : iOS 11, macOS 10.13, tvOS 11, watchOS 4 (Core ML 1) ++ * - Feedforward & Recurrent Neural Networks ++ * - General Linear Models ++ * - Tree Ensembles ++ * - Support Vector Machines ++ * - Pipelines ++ * - Feature Engineering ++ * ++ * 2 : iOS 11.2, macOS 10.13.2, tvOS 11.2, watchOS 4.2 (Core ML 1.2) ++ * - Custom Layers for Neural Networks ++ * - Float 16 support for Neural Network layers ++ * ++ * 3 : iOS 12, macOS 10.14, tvOS 12, watchOS 5 (Core ML 2) ++ * - Flexible shapes and image sizes ++ * - Categorical sequences ++ * - Core ML Vision Feature Print, Text Classifier, Word Tagger ++ * - Non Max Suppression ++ * - Crop and Resize Bilinear NN layers ++ * - Custom Models ++ * ++ * 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) ++ * - Updatable models ++ * - Exact shape / general rank mapping for neural networks ++ * - Large expansion of supported neural network layers ++ * - Generalized operations ++ * - Control flow ++ * - Dynamic layers ++ * - See NeuralNetwork.proto ++ * - Nearest Neighbor Classifier ++ * - Sound Analysis Prepreocessing ++ * - Recommender ++ * - Linked Model ++ * - NLP Gazeteer ++ * - NLP WordEmbedding ++ * ++ * 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) ++ * - Model Deployment ++ * - Model Encryption ++ * - Unified converter API with PyTorch and Tensorflow 2 Support in coremltools 4 ++ * - MIL builder for neural networks and composite ops in coremltools 4 ++ * - New layers in neural network: ++ * - CumSum ++ * - OneHot ++ * - ClampedReLu ++ * - ArgSort ++ * - SliceBySize ++ * - Convolution3D ++ * - Pool3D ++ * - Bilinear Upsample with align corners and fractional factors ++ * - PixelShuffle ++ * - MatMul with int8 weights and int8 activations ++ * - Concat interleave ++ * - See NeuralNetwork.proto ++ * - Enhanced Xcode model view with interactive previews ++ * - Enhanced Xcode Playground support for Core ML models ++ * ++ */ ++message Model { ++ int32 specificationVersion = 1; ++ ModelDescription description = 2; ++ ++ /* ++ * Following model types support on-device update: ++ * ++ * - NeuralNetworkClassifier ++ * - NeuralNetworkRegressor ++ * - NeuralNetwork ++ * - KNearestNeighborsClassifier ++ */ ++ bool isUpdatable = 10; ++ ++ // start at 200 here ++ // model specific parameters: ++ oneof Type { ++ // pipeline starts at 200 ++ PipelineClassifier pipelineClassifier = 200; ++ PipelineRegressor pipelineRegressor = 201; ++ Pipeline pipeline = 202; ++ ++ // regressors start at 300 ++ GLMRegressor glmRegressor = 300; ++ SupportVectorRegressor supportVectorRegressor = 301; ++ TreeEnsembleRegressor treeEnsembleRegressor = 302; ++ NeuralNetworkRegressor neuralNetworkRegressor = 303; ++ BayesianProbitRegressor bayesianProbitRegressor = 304; ++ ++ // classifiers start at 400 ++ GLMClassifier glmClassifier = 400; ++ SupportVectorClassifier supportVectorClassifier = 401; ++ TreeEnsembleClassifier treeEnsembleClassifier = 402; ++ NeuralNetworkClassifier neuralNetworkClassifier = 403; ++ KNearestNeighborsClassifier kNearestNeighborsClassifier = 404; ++ ++ // generic models start at 500 ++ NeuralNetwork neuralNetwork = 500; ++ ItemSimilarityRecommender itemSimilarityRecommender = 501; ++ ++ // Custom and linked models ++ CustomModel customModel = 555; ++ LinkedModel linkedModel = 556; ++ ++ // feature engineering starts at 600 ++ OneHotEncoder oneHotEncoder = 600; ++ Imputer imputer = 601; ++ FeatureVectorizer featureVectorizer = 602; ++ DictVectorizer dictVectorizer = 603; ++ 
Scaler scaler = 604; ++ CategoricalMapping categoricalMapping = 606; ++ Normalizer normalizer = 607; ++ ArrayFeatureExtractor arrayFeatureExtractor = 609; ++ NonMaximumSuppression nonMaximumSuppression = 610; ++ ++ ++ // simple mathematical functions used for testing start at 900 ++ Identity identity = 900; ++ ++ // reserved until 1000 ++ ++ // CoreML provided models ++ CoreMLModels.TextClassifier textClassifier = 2000; ++ CoreMLModels.WordTagger wordTagger = 2001; ++ CoreMLModels.VisionFeaturePrint visionFeaturePrint = 2002; ++ CoreMLModels.SoundAnalysisPreprocessing soundAnalysisPreprocessing = 2003; ++ CoreMLModels.Gazetteer gazetteer = 2004; ++ CoreMLModels.WordEmbedding wordEmbedding = 2005; ++ ++ // Reserved private messages start at 3000 ++ // These messages are subject to change with no notice or support. ++ SerializedModel serializedModel = 3000; ++ } ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto +new file mode 100644 +index 000000000..82acd8490 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto +@@ -0,0 +1,132 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++import public "DataStructures.proto"; ++import public "Parameters.proto"; ++ ++/** ++ * A k-Nearest-Neighbor classifier ++ */ ++message KNearestNeighborsClassifier { ++ ++ /** ++ * The "core" nearest neighbor model attributes. ++ */ ++ NearestNeighborsIndex nearestNeighborsIndex = 1; ++ ++ /** ++ * Number of neighbors to use for classification. ++ */ ++ Int64Parameter numberOfNeighbors = 3; ++ ++ /** ++ * Type of labels supported by the model. Currently supports String or Int64 ++ * labels. ++ */ ++ oneof ClassLabels { ++ StringVector stringClassLabels = 100; ++ Int64Vector int64ClassLabels = 101; ++ } ++ ++ /** ++ * Default value of class label (useful when prediction is called on an empty kNN classifier) ++ */ ++ oneof DefaultClassLabel { ++ string defaultStringLabel = 110; ++ int64 defaultInt64Label = 111; ++ } ++ ++ /** ++ * Weighting scheme to be used when computing the majority label of a ++ * new data point. ++ */ ++ oneof WeightingScheme { ++ UniformWeighting uniformWeighting = 200; ++ InverseDistanceWeighting inverseDistanceWeighting = 210; ++ } ++} ++ ++/** ++ * The "core" attributes of a Nearest Neighbors model. ++ */ ++message NearestNeighborsIndex { ++ ++ /** ++ * Number of dimensions of the input data. ++ */ ++ int32 numberOfDimensions = 1; ++ ++ /** ++ * Vector of floating point data that makes up the model. Each data point must have 'numberOfDimensions' ++ * dimensions. ++ */ ++ repeated FloatVector floatSamples = 2; ++ ++ /** ++ * Backing data structure for the Nearest Neighbors Index. Currently supports ++ * a linear index or a kd-tree index. ++ */ ++ oneof IndexType { ++ LinearIndex linearIndex = 100; ++ SingleKdTreeIndex singleKdTreeIndex = 110; ++ } ++ ++ /** ++ * Distance function to be used to find neighbors. Currently only Squared Euclidean ++ * Distance is supported. ++ */ ++ oneof DistanceFunction { ++ SquaredEuclideanDistance squaredEuclideanDistance = 200; ++ } ++ ++} ++ ++/** ++ * Specifies a uniform weighting scheme (i.e. each neighbor receives equal ++ * voting power). 
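KNearestNeighborsClassifier above combines a NearestNeighborsIndex (floatSamples with numberOfDimensions each, a linear or kd-tree index, squared Euclidean distance) with a weighting scheme. A sketch of the simplest combination, a brute-force query with uniform weighting, under illustrative names (this is not the CoreML or ORT implementation):

#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Brute-force (linear index) query: compute squared Euclidean distance to every
// labeled sample, take the k closest, and let each of them cast one vote.
int ClassifyUniform(const std::vector<std::pair<std::vector<float>, int>>& labeled_samples,
                    const std::vector<float>& query, std::size_t k) {
    std::multimap<double, int> by_distance;  // distance -> label, kept sorted
    for (const auto& sample : labeled_samples) {
        double dist = 0.0;
        for (std::size_t d = 0; d < sample.first.size(); ++d) {
            const double diff = sample.first[d] - query[d];
            dist += diff * diff;  // squared Euclidean distance
        }
        by_distance.emplace(dist, sample.second);
    }

    std::map<int, std::size_t> votes;
    std::size_t taken = 0;
    for (const auto& entry : by_distance) {
        if (taken++ == k) break;
        ++votes[entry.second];
    }

    int best_label = -1;
    std::size_t best_votes = 0;
    for (const auto& v : votes) {
        if (v.second > best_votes) {
            best_votes = v.second;
            best_label = v.first;
        }
    }
    return best_label;
}

The inverse-distance scheme defined a few messages below replaces the single vote with a 1/distance weight and picks the label with the largest weighted sum.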
++ */ ++message UniformWeighting { ++} ++ ++ ++/** ++ * Specifies a inverse-distance weighting scheme (i.e. closest neighbors receives higher ++ * voting power). A nearest neighbor with highest sum of (1 / distance) is picked. ++ */ ++message InverseDistanceWeighting { ++} ++ ++ ++/** ++ * Specifies a flat index of data points to be searched by brute force. ++ */ ++message LinearIndex { ++} ++ ++ ++/** ++ * Specifies a kd-tree backend for the nearest neighbors model. ++ */ ++message SingleKdTreeIndex { ++ ++ /** ++ * Number of data points contained within a leaf node of the kd-tree. ++ */ ++ int32 leafSize = 1; ++ ++} ++ ++ ++/** ++ * Specifies the Squared Euclidean Distance function. ++ */ ++message SquaredEuclideanDistance { ++} ++ +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto +new file mode 100644 +index 000000000..44a77c6e7 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto +@@ -0,0 +1,6531 @@ ++// Copyright (c) 2017-2019, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++/** ++ * A neural network is defined through a collection of layers ++ * and represents a directed acyclic graph (DAG). ++ * Each layer has a name, a layer type, ++ * a list of input names, a list of output names, ++ * and a collection of parameters specific to the layer type. ++ * ++ * The graph structure and connectivity of the neural network ++ * is inferred from the input and output names. ++ * A neural network starts with the layer ++ * whose input name is equal to the value specified in ++ * ``Model.description.input.name``, ++ * and ends with the layer ++ * whose output name is equal to the value specified in ++ * ``Model.description.output.name``. ++ * Layers must have unique input and output names, ++ * and a layer may not have input or output names that ++ * refer to layers that are not yet defined. ++ * ++ * For Core ML specification version <=3, ++ * all inputs are mapped to static rank 5 tensors, with axis notations ++ * [Sequence, Batch, Channel, Height, Width]. ++ * ++ * From specification version 4 onwards (iOS >= 13, macOS >= 10.15), more options are available ++ * (see enums ``NeuralNetworkMultiArrayShapeMapping``, ``NeuralNetworkImageShapeMapping``) ++ * to map inputs to generic N-Dimensional (or N rank) tensors, where N >= 1. ++ * ++ * Each layer type may have specific constraints on the ranks of its inputs and outputs. ++ * ++ * Some of the layers (such as softmax, reduce, etc) have parameters that have been described in ++ * terms of notational axis "Channel", "Height", "Width" or "Sequence". They can be re-interpreted easily in ++ * the general ND setting by using the following rule: ++ * "width" is same as axis = -1 (i.e. the last axis from the end) ++ * "height" is same as axis = -2 (i.e. the second last axis from the end) ++ * "channel" is same as axis = -3 (i.e. the third last axis from the end) ++ * "sequence" is same as axis = -5 (i.e. the fifth last axis from the end) ++ * ++ * Several layers are available in 3 different variations, with the names ending ++ * in identifiers: ``like``, ``static`` and ``dynamic``. For instance, ``FillLike``, ++ * ``FillStatic`` and ``FillDynamic``. The ``static`` variation generally will have ++ * a property corresponding to the shape of the output. 
For instance, if the ++ * output of the ``FillStatic`` layer is desired to be of shape (10, 4), the ++ * property ``targetShape`` will have to be set to [10, 4]. In the ``dynamic`` case, ++ * the shape is an input, hence it can be changed at runtime. For instance, for ++ * a ``FillDynamic`` layer, the input would have to be an array containing the ++ * values 10 and 4, if the desired output is of shape (10, 4). Whereas in the ++ * ``like`` case, the additional input's shape is used as the output shape, ignoring ++ * its values. For instance, for a ``FillLike`` layer, for an input with shape ++ * (10, 4), the output generated will also be of shape (10, 4), values of the ++ * input will be ignored. ++ */ ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++import public "Parameters.proto"; ++ ++package CoreML.Specification; ++ ++ ++enum NeuralNetworkMultiArrayShapeMapping { ++ ++ /* ++ * Describes how the MultiArray shape for the inputs, ++ * provided in Features Types proto via model description, ++ * is mapped to construct tensors that are fed into the Neural Network layers. ++ */ ++ ++ /* ++ * Default legacy value. Only supported for Core ML Specification version <= 3. ++ * ++ * The default legacy shape mapping resolves all input shapes to a rank 5 equivalent ++ * with axis notation of [Seq, Batch, Channel, Height, Width]. ++ * ++ * When this enum value is selected, ++ * the repeated shape field in the message "ArrayFeatureType" in feature types proto, ++ * must be either length 1 or length 3. ++ * ++ * The following rule is used to map the values in the shape field to the actual tensor shape: ++ * rank 1 shape is mapped to shape [1,1,C,1,1] ++ * rank 3 shape is mapped to shape [1,1,C,H,W] ++ * At runtime, the first two dimensions (Seq or Batch) can be presented as well, with non-1 values. ++ * ++ * It is invalid to use this enum value if any of the layers added ++ * Specification version 4 (iOS >= 13, macOS >= 10.15) onwards are used in the network. ++ * Validator will raise an error in that case. ++ */ ++ RANK5_ARRAY_MAPPING = 0; ++ ++ /* ++ * The exact shape and rank (i.e. number of dimensions in the shape) of the input, ++ * as specified in the message "ArrayFeatureType", is passed through to the layers. ++ * Supported only for Specification version >= 4 (iOS >= 13, macOS >= 10.15). ++ */ ++ EXACT_ARRAY_MAPPING = 1; ++ ++} ++ ++enum NeuralNetworkImageShapeMapping { ++ ++ /* ++ * Describes how the shape of the input tensors is constructed from image inputs. ++ */ ++ ++ /* ++ * In this case, image input is mapped to a rank 5 tensor. ++ * For Color images, input tensor is shaped as [1,1,3,H,W]. ++ * For Gray images, input tensor is shaped as [1,1,1,H,W]. ++ */ ++ RANK5_IMAGE_MAPPING = 0; ++ ++ /* ++ * For Color images, input tensor is shaped as [1,3,H,W]. ++ * For Gray images, input tensor is shaped as [1,1,H,W]. ++ * Supported only for Specification version >= 4 (iOS >= 13, macOS >= 10.15). ++ */ ++ RANK4_IMAGE_MAPPING = 1; ++ ++} ++ ++/** ++ A neural network. 
++ */ ++message NeuralNetwork { ++ ++ repeated NeuralNetworkLayer layers = 1; ++ repeated NeuralNetworkPreprocessing preprocessing = 2; ++ ++ // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs ++ NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; ++ ++ // use this enum value to determine the input tensor shapes to the neural network, for image inputs ++ NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; ++ ++ ++ NetworkUpdateParameters updateParams = 10; ++ ++} ++ ++/// Preprocessing ++/// ------------- ++ ++/** ++ * A neural network preprocessor that ++ * performs a scalar multiplication of an image ++ * followed by addition of scalar biases to the channels. ++ * ++ * Input: X ++ * An image in BGR or RGB format with shape ``[3, H, W]`` ++ * or in grayscale format with shape ``[1, H, W]``. ++ * Output: Y ++ * An image with format and shape corresponding to the input. ++ * ++ * If the input image is in BGR format: ++ * ++ * .. code:: ++ * ++ * Y[0, :, :] = channelScale * X[0, :, :] + blueBias ++ * Y[1, :, :] = channelScale * X[1, :, :] + greenBias ++ * Y[2, :, :] = channelScale * X[2, :, :] + redBias ++ * ++ * If the input image is in RGB format: ++ * ++ * .. code:: ++ * ++ * Y[0, :, :] = channelScale * X[0, :, :] + redBias ++ * Y[1, :, :] = channelScale * X[1, :, :] + greenBias ++ * Y[2, :, :] = channelScale * X[2, :, :] + blueBias ++ * ++ * If the input image is in grayscale format: ++ * ++ * .. code:: ++ * ++ * Y[0, :, :] = channelScale * X[0, :, :] + grayBias ++ */ ++message NeuralNetworkImageScaler { ++ ++ float channelScale = 10; ///Scalar to be multiplied. ++ float blueBias = 20; ///Scalar blue bias to be added. ++ float greenBias = 21; ///Scalar green bias to be added. ++ float redBias = 22; ///Scalar red bias to be added. ++ float grayBias = 30; ///Scalar bias to be added for grayscale images. ++ ++} ++ ++/** ++ * A neural network preprocessor that ++ * subtracts the provided mean image from the input image. ++ * The mean image is subtracted from the input named ++ * ``NeuralNetworkPreprocessing.featureName``. ++ */ ++message NeuralNetworkMeanImage { ++ ++ /** ++ * Mean image stored as a flattened array of floats, ++ * representing shape [Channel,Height,Width]. ++ */ ++ repeated float meanImage = 1; ++ ++} ++ ++/// Preprocessing parameters for image inputs. ++message NeuralNetworkPreprocessing { ++ ++ string featureName = 1; /// must be equal to the input name to which the preprocessing is applied ++ oneof preprocessor { ++ NeuralNetworkImageScaler scaler = 10; ++ NeuralNetworkMeanImage meanImage = 11; ++ } ++ ++} ++ ++/// Activation Functions ++/// -------------------- ++ ++/** ++ * A rectified linear unit (ReLU) activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \text{max}(0, x) ++ */ ++message ActivationReLU { ++ ++} ++ ++/** ++ * A leaky rectified linear unit (ReLU) activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \begin{cases} ++ * x & \text{if } x \geq 0 \\ ++ * \alpha x & \text{if } x < 0 ++ * \end{cases} ++ */ ++message ActivationLeakyReLU { ++ ++ float alpha = 1; //negative slope value for leakyReLU ++ ++} ++ ++/** ++ * A hyperbolic tangent activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \dfrac{1 - e^{-2x}}{1 + e^{-2x}} ++ */ ++message ActivationTanh { ++ ++} ++ ++/** ++ * A scaled hyperbolic tangent activation function. 
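NeuralNetworkImageScaler above multiplies every pixel by one scalar and adds a per-channel bias. A short sketch of the RGB case from the formulas in that comment, assuming a [3, H, W] float buffer (the layout and function name are illustrative):

#include <cstddef>
#include <vector>

// Y[c, :, :] = channelScale * X[c, :, :] + bias[c], with the bias chosen per channel
// (redBias, greenBias, blueBias for an RGB input; grayBias for a [1, H, W] input).
void ScaleRgbImage(std::vector<float>& image, std::size_t height, std::size_t width,
                   float channel_scale, float red_bias, float green_bias, float blue_bias) {
    const float bias[3] = {red_bias, green_bias, blue_bias};  // RGB channel order
    const std::size_t plane = height * width;
    for (std::size_t c = 0; c < 3; ++c) {
        for (std::size_t i = 0; i < plane; ++i) {
            image[c * plane + i] = channel_scale * image[c * plane + i] + bias[c];
        }
    }
}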
++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \alpha \tanh(\beta x) ++ */ ++message ActivationScaledTanh { ++ ++ float alpha = 1; ++ float beta = 2; ++ ++} ++ ++/** ++ * A sigmoid activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \dfrac{1}{1 + e^{-x}} ++ */ ++message ActivationSigmoid { ++ ++} ++ ++/** ++ * A linear activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \alpha x + \beta ++ */ ++message ActivationLinear { ++ ++ float alpha = 1; ++ float beta = 2; ++ ++} ++ ++/** ++ * A hard sigmoid activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \text{min}(\text{max}(\alpha x + \beta, 0), 1) ++ */ ++message ActivationSigmoidHard { ++ ++ float alpha = 1; ++ float beta = 2; ++ ++} ++ ++/** ++ * A parameterized rectified linear unit (PReLU) activation function. ++ * Input must be at least rank 3. Axis = -3 is denoted by "C", or channels. ++ * "alpha" parameter can be a vector of length C. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x_i) = \begin{cases} ++ * x_i & \text{if } x_i \geq 0 \\ ++ * \alpha_i x_i & \text{if } x_i < 0 ++ * \end{cases} \;,\;i=1,...,C ++ */ ++message ActivationPReLU { ++ ++ // parameter of length C or 1. ++ // If length is 1, same value is used for all channels ++ WeightParams alpha = 1; ++ ++} ++ ++/** ++ * An exponential linear unit (ELU) activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \begin{cases} ++ * x & \text{if } x \geq 0 \\ ++ * \alpha (e^x - 1) & \text{if } x < 0 ++ * \end{cases} ++ */ ++message ActivationELU { ++ ++ float alpha = 1; ++ ++} ++ ++/** ++ * A thresholded rectified linear unit (ReLU) activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \begin{cases} ++ * x & \text{if } x \geq \alpha \\ ++ * 0 & \text{if } x < \alpha ++ * \end{cases} ++ */ ++message ActivationThresholdedReLU { ++ ++ float alpha = 1; ++ ++} ++ ++/** ++ * A softsign activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \dfrac{x}{1 + |x|} ++ */ ++message ActivationSoftsign { ++ ++} ++ ++/** ++ * A softplus activation function. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \text{log}(1 + e^x) ++ */ ++message ActivationSoftplus { ++ ++} ++ ++/** ++ * A parametric softplus activation function. ++ * Input must be at least rank 3. axis = -3 is denoted by "C", or channels. ++ * "alpha"/"beta" parameter can be a vector of length C. ++ * ++ * This function has the following formula: ++ * ++ * .. 
math:: ++ * f(x_i) = \alpha_i \text{log}(1 + e^{\beta_i x_i}) \;,\;i=1,...,C ++ */ ++message ActivationParametricSoftplus { ++ ++ // If length is 1, same value is used for all channels ++ WeightParams alpha = 1; //parameter of length C or 1 ++ WeightParams beta = 2; //parameter of length C or 1 ++ ++} ++ ++message ActivationParams { ++ ++ oneof NonlinearityType { ++ ActivationLinear linear = 5; ++ ++ ActivationReLU ReLU = 10; ++ ActivationLeakyReLU leakyReLU = 15; ++ ActivationThresholdedReLU thresholdedReLU = 20; ++ ActivationPReLU PReLU = 25; ++ ++ ActivationTanh tanh = 30; ++ ActivationScaledTanh scaledTanh = 31; ++ ++ ActivationSigmoid sigmoid = 40; ++ ActivationSigmoidHard sigmoidHard = 41; ++ ++ ActivationELU ELU = 50; ++ ++ ActivationSoftsign softsign = 60; ++ ActivationSoftplus softplus = 70; ++ ActivationParametricSoftplus parametricSoftplus = 71; ++ } ++ ++} ++ ++/** ++ * Representation of the intermediate tensors ++ */ ++message Tensor { ++ ++ // Number of dimensions in the tensor shape ++ uint32 rank = 1; ++ // actual value of the tensor shape. ++ // must be of length "rank". Can contain -1s for unknown dimensions. ++ repeated int64 dimValue = 2; ++ ++} ++ ++/** ++ * A single neural network layer. ++ */ ++message NeuralNetworkLayer { ++ ++ string name = 1; //descriptive name of the layer ++ repeated string input = 2; ++ repeated string output = 3; ++ ++ repeated Tensor inputTensor = 4; // must be the same length as the "input" field ++ repeated Tensor outputTensor = 5; // must be the same length as the "output" field ++ ++ // Must be set to true to mark the layer as updatable. ++ // If true, the weightParams in the layer's properties must also be set to updatable ++ // If false, the value of the isUpdatable parameter within the layer's weights are ignored ++ bool isUpdatable = 10; ++ ++ oneof layer { ++ ++ // Start at 100 here ++ ConvolutionLayerParams convolution = 100; ++ ++ PoolingLayerParams pooling = 120; ++ ++ ActivationParams activation = 130; ++ ++ InnerProductLayerParams innerProduct = 140; ++ EmbeddingLayerParams embedding = 150; ++ ++ // Normalization-related Layers ++ BatchnormLayerParams batchnorm = 160; ++ MeanVarianceNormalizeLayerParams mvn = 165; ++ L2NormalizeLayerParams l2normalize = 170; ++ SoftmaxLayerParams softmax = 175; ++ LRNLayerParams lrn = 180; ++ ++ CropLayerParams crop = 190; ++ PaddingLayerParams padding = 200; ++ UpsampleLayerParams upsample = 210; ++ ++ ResizeBilinearLayerParams resizeBilinear = 211; ++ CropResizeLayerParams cropResize = 212; ++ ++ UnaryFunctionLayerParams unary = 220; ++ ++ // Element-wise Operations ++ AddLayerParams add = 230; ++ MultiplyLayerParams multiply = 231; ++ ++ AverageLayerParams average = 240; ++ ScaleLayerParams scale = 245; ++ ++ BiasLayerParams bias = 250; ++ MaxLayerParams max = 260; ++ MinLayerParams min = 261; ++ ++ DotProductLayerParams dot = 270; ++ ReduceLayerParams reduce = 280; ++ LoadConstantLayerParams loadConstant = 290; ++ ++ // Data Reorganization ++ ReshapeLayerParams reshape = 300; ++ FlattenLayerParams flatten = 301; ++ PermuteLayerParams permute = 310; ++ ConcatLayerParams concat = 320; ++ SplitLayerParams split = 330; ++ SequenceRepeatLayerParams sequenceRepeat = 340; ++ ++ ReorganizeDataLayerParams reorganizeData = 345; ++ SliceLayerParams slice = 350; ++ ++ // Recurrent Layers ++ SimpleRecurrentLayerParams simpleRecurrent = 400; ++ GRULayerParams gru = 410; ++ UniDirectionalLSTMLayerParams uniDirectionalLSTM = 420; ++ BiDirectionalLSTMLayerParams biDirectionalLSTM = 430; ++ ++ // Custom 
(user-implemented) Layer ++ CustomLayerParams custom = 500; ++ ++ // Following layers are available only after Core ML Specification ++ // version >= 4 (iOS >= 13, macOS >= 10.15) ++ ++ // Control Flow related Layers ++ CopyLayerParams copy = 600; ++ BranchLayerParams branch = 605; ++ ++ LoopLayerParams loop = 615; ++ LoopBreakLayerParams loopBreak = 620; ++ LoopContinueLayerParams loopContinue = 625; ++ ++ RangeStaticLayerParams rangeStatic = 635; ++ RangeDynamicLayerParams rangeDynamic = 640; ++ ++ // Element-wise Unary Layers ++ ClipLayerParams clip = 660; ++ CeilLayerParams ceil = 665; ++ FloorLayerParams floor = 670; ++ ++ SignLayerParams sign = 680; ++ RoundLayerParams round = 685; ++ ++ Exp2LayerParams exp2 = 700; ++ ++ SinLayerParams sin = 710; ++ CosLayerParams cos = 715; ++ TanLayerParams tan = 720; ++ ++ AsinLayerParams asin = 730; ++ AcosLayerParams acos = 735; ++ AtanLayerParams atan = 740; ++ ++ SinhLayerParams sinh = 750; ++ CoshLayerParams cosh = 755; ++ TanhLayerParams tanh = 760; ++ ++ AsinhLayerParams asinh = 770; ++ AcoshLayerParams acosh = 775; ++ AtanhLayerParams atanh = 780; ++ ++ ErfLayerParams erf = 790; ++ GeluLayerParams gelu = 795; ++ ++ // Element-wise Binary with Broadcasting Support ++ EqualLayerParams equal = 815; ++ NotEqualLayerParams notEqual = 820; ++ LessThanLayerParams lessThan = 825; ++ LessEqualLayerParams lessEqual = 827; ++ GreaterThanLayerParams greaterThan = 830; ++ GreaterEqualLayerParams greaterEqual = 832; ++ ++ LogicalOrLayerParams logicalOr = 840; ++ LogicalXorLayerParams logicalXor = 845; ++ LogicalNotLayerParams logicalNot = 850; ++ LogicalAndLayerParams logicalAnd = 855; ++ ++ ModBroadcastableLayerParams modBroadcastable = 865; ++ MinBroadcastableLayerParams minBroadcastable = 870; ++ MaxBroadcastableLayerParams maxBroadcastable = 875; ++ AddBroadcastableLayerParams addBroadcastable = 880; ++ PowBroadcastableLayerParams powBroadcastable = 885; ++ DivideBroadcastableLayerParams divideBroadcastable = 890; ++ FloorDivBroadcastableLayerParams floorDivBroadcastable = 895; ++ MultiplyBroadcastableLayerParams multiplyBroadcastable = 900; ++ SubtractBroadcastableLayerParams subtractBroadcastable = 905; ++ ++ // Tensor Manipulations ++ TileLayerParams tile = 920; ++ StackLayerParams stack = 925; ++ GatherLayerParams gather = 930; ++ ScatterLayerParams scatter = 935; ++ GatherNDLayerParams gatherND = 940; ++ ScatterNDLayerParams scatterND = 945; ++ SoftmaxNDLayerParams softmaxND = 950; ++ GatherAlongAxisLayerParams gatherAlongAxis = 952; ++ ScatterAlongAxisLayerParams scatterAlongAxis = 954; ++ ++ ReverseLayerParams reverse = 960; ++ ReverseSeqLayerParams reverseSeq = 965; ++ ++ SplitNDLayerParams splitND = 975; ++ ConcatNDLayerParams concatND = 980; ++ TransposeLayerParams transpose = 985; ++ ++ SliceStaticLayerParams sliceStatic = 995; ++ SliceDynamicLayerParams sliceDynamic = 1000; ++ SlidingWindowsLayerParams slidingWindows = 1005; ++ ++ TopKLayerParams topK = 1015; ++ ArgMinLayerParams argMin = 1020; ++ ArgMaxLayerParams argMax = 1025; ++ ++ EmbeddingNDLayerParams embeddingND = 1040; ++ BatchedMatMulLayerParams batchedMatmul = 1045; ++ ++ // Tensor Allocation / Reshape-related Operations ++ GetShapeLayerParams getShape = 1065; ++ LoadConstantNDLayerParams loadConstantND = 1070; ++ ++ FillLikeLayerParams fillLike = 1080; ++ FillStaticLayerParams fillStatic = 1085; ++ FillDynamicLayerParams fillDynamic = 1090; ++ ++ BroadcastToLikeLayerParams broadcastToLike = 1100; ++ BroadcastToStaticLayerParams broadcastToStatic = 1105; ++ 
BroadcastToDynamicLayerParams broadcastToDynamic = 1110; ++ ++ SqueezeLayerParams squeeze = 1120; ++ ExpandDimsLayerParams expandDims = 1125; ++ FlattenTo2DLayerParams flattenTo2D = 1130; ++ ReshapeLikeLayerParams reshapeLike = 1135; ++ ReshapeStaticLayerParams reshapeStatic = 1140; ++ ReshapeDynamicLayerParams reshapeDynamic = 1145; ++ RankPreservingReshapeLayerParams rankPreservingReshape = 1150; ++ ++ ConstantPaddingLayerParams constantPad = 1155; ++ ++ // Random Distributions ++ RandomNormalLikeLayerParams randomNormalLike = 1170; ++ RandomNormalStaticLayerParams randomNormalStatic = 1175; ++ RandomNormalDynamicLayerParams randomNormalDynamic = 1180; ++ ++ RandomUniformLikeLayerParams randomUniformLike = 1190; ++ RandomUniformStaticLayerParams randomUniformStatic = 1195; ++ RandomUniformDynamicLayerParams randomUniformDynamic = 1200; ++ ++ RandomBernoulliLikeLayerParams randomBernoulliLike = 1210; ++ RandomBernoulliStaticLayerParams randomBernoulliStatic = 1215; ++ RandomBernoulliDynamicLayerParams randomBernoulliDynamic = 1220; ++ ++ CategoricalDistributionLayerParams categoricalDistribution = 1230; ++ ++ // Reduction-related Layers: ++ ReduceL1LayerParams reduceL1 = 1250; ++ ReduceL2LayerParams reduceL2 = 1255; ++ ReduceMaxLayerParams reduceMax = 1260; ++ ReduceMinLayerParams reduceMin = 1265; ++ ReduceSumLayerParams reduceSum = 1270; ++ ReduceProdLayerParams reduceProd = 1275; ++ ReduceMeanLayerParams reduceMean = 1280; ++ ReduceLogSumLayerParams reduceLogSum = 1285; ++ ReduceSumSquareLayerParams reduceSumSquare = 1290; ++ ReduceLogSumExpLayerParams reduceLogSumExp = 1295; ++ ++ // Masking / Selection Layers ++ WhereNonZeroLayerParams whereNonZero = 1313; ++ MatrixBandPartLayerParams matrixBandPart = 1315; ++ LowerTriangularLayerParams lowerTriangular = 1320; ++ UpperTriangularLayerParams upperTriangular = 1325; ++ WhereBroadcastableLayerParams whereBroadcastable = 1330; ++ ++ // Normalization Layers ++ LayerNormalizationLayerParams layerNormalization = 1350; ++ ++ NonMaximumSuppressionLayerParams NonMaximumSuppression = 1400; ++ ++ // Following layers are available only after Core ML Specification ++ // version >= 5 (iOS >= 14, macOS >= 11.0) ++ OneHotLayerParams oneHot = 1450; ++ CumSumLayerParams cumSum = 1455; ++ ClampedReLULayerParams clampedReLU = 1460; ++ ArgSortLayerParams argSort = 1461; ++ Pooling3DLayerParams pooling3d = 1465; ++ GlobalPooling3DLayerParams globalPooling3d = 1466; ++ SliceBySizeLayerParams sliceBySize = 1470; ++ Convolution3DLayerParams convolution3d = 1471; ++ ++ } ++ ++} ++ ++/** ++ * Branching Layer ++ * ++ * A layer that provides the functionality of branching or an If-Else block. ++ * ++ * Must have 1 input. There are no outputs as the execution is transferred to either the ++ * if or the else branch based on the value of the input. ++ * ++ * Input is the condition predicate. Must be a scalar (length 1 tensor). ++ * ++ */ ++message BranchLayerParams { ++ ++ /** ++ * execute this graph if the absolute value of the input Tensor is greater than 1e-6 ++ * This must be present. ++ */ ++ NeuralNetwork ifBranch = 1; ++ /** ++ * execute this graph if the absolute value of the input Tensor is less than 1e-6 ++ * This is optional. ++ */ ++ NeuralNetwork elseBranch = 2; ++ ++} ++ ++/** ++ * Loop Layer ++ * ++ * A layer that provides the functionality of a "for" loop or a "while" loop. ++ * ++ * There are either no inputs or 1 input. 
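As a sketch of the ``BranchLayerParams`` semantics described above (condition is a length-1 tensor; ``ifBranch`` runs when its absolute value exceeds 1e-6, the optional ``elseBranch`` otherwise), assuming the branches are modeled as plain callables; this is an illustration only, not an actual runtime API.

.. code:: python

    def run_branch(condition, if_branch, else_branch=None):
        """Pseudo-interpreter for BranchLayerParams: condition is a scalar tensor."""
        if abs(float(condition)) > 1e-6:
            return if_branch()            # ifBranch must always be present
        if else_branch is not None:
            return else_branch()          # elseBranch is optional
        return None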
When an input is present, it corresponds to the maximum loop count, ++ * in that case the value of the "maxLoopIterations" field is ignored. Input must be a scalar. ++ * (For description below, maxLoopIterations is assumed to be the value of the input, when its present) ++ * ++ * No outputs are produced. Blobs produced by the condition or the body network are visible in the scope of the overall network. ++ * ++ * "conditionNetwork" must produce a tensor with the name specified in the "conditionVar" field. ++ * ++ * There are 3 possible cases for determining the termination condition: ++ * ++ * Case 1: ++ * ++ * If there is no "conditionNetwork", in this case the layer corresponds to a pure for loop, which is run "maxLoopIterations" number of times. ++ * Equivalent pseudo-code: ++ * ++ * for loopIterator = 0 : maxLoopIterations ++ * bodyNetwork() ++ * ++ * ++ * Case 2: ++ * ++ * "conditionNetwork" is present, and "maxLoopIterations" is 0 and there is no input, ++ * in this case the layer corresponds to a while loop. Equivalent pseudo-code: ++ * ++ * conditionVar = conditionNetwork() ++ * while conditionVar: ++ * bodyNetwork() ++ * conditionVar = conditionNetwork() ++ * ++ * ++ * Case 3: ++ * ++ * "conditionNetwork" is provided, and "maxLoopIterations" is positive or there is an input, ++ * in this case the layer corresponds to a while loop with a joint condition. Equivalent pseudo-code: ++ * ++ * loopIterator = 0 ++ * conditionVar = conditionNetwork() ++ * while (conditionVar and loopIterator < maxLoopIterations): ++ * bodyNetwork() ++ * loopIterator = loopIterator + 1 ++ * conditionVar = conditionNetwork() ++ * ++ */ ++message LoopLayerParams { ++ ++ /** ++ * maximum number of iterations. Ignored if input is present. ++ */ ++ uint64 maxLoopIterations = 1; ++ /** ++ * This field provides the name of the tensor which is produced by the conditionNetwork ++ * and whose value is checked to start/continue/terminate the loop. Value close to 0.0f is treated as False. ++ * This field is optional. ++ * Must be a non empty string if and only if "conditionNetwork" is present. ++ */ ++ string conditionVar = 2; ++ /** ++ * Must generate a tensor with the name provided in the "conditionVar" field. ++ * This field is optional. ++ * Must be present if and only if "conditionVar" field is a non empty string. ++ */ ++ NeuralNetwork conditionNetwork = 3; ++ /** ++ * Body of the loop. ++ * This field must be present. ++ */ ++ NeuralNetwork bodyNetwork = 4; ++ ++} ++ ++/** ++ * Loop break Layer ++ * ++ * Terminate the loop that has this layer. ++ * If present, it should always reside in the "bodyNetwork" of the loop layer ++ * ++ * No inputs/outputs ++ * ++ */ ++message LoopBreakLayerParams { ++ ++} ++ ++/** ++ * Loop Continue Layer ++ * ++ * Stop the current loop iteration and continue on the next iteration. ++ * If present, it should always reside in the "bodyNetwork" of the loop layer ++ * ++ * No inputs/outputs ++ * ++ */ ++message LoopContinueLayerParams { ++ ++} ++ ++/** ++ * Copy Layer ++ * ++ * A layer that copies its input tensor to the output tensor. ++ * Must have 1 input and 1 output, with distinct names. ++ * This is the only layer that is allowed to re-generate an output that is already present in the neural network prior to this layer, ++ * in which case it will overwrite the output tensor. ++ * ++ */ ++message CopyLayerParams { ++ ++} ++ ++/** ++ * GreaterThan Layer ++ * ++ * Either 1 or 2 inputs. ++ * Produces 1 output. ++ * Perform elementwise greater than operation. 
++ * ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = x1 > x2 ++ * or ++ * y = x1 > alpha, if only one input is provided ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message GreaterThanLayerParams { ++ ++ /** ++ * Compare to the scalar value provided here if there is 1 input ++ */ ++ float alpha = 2; ++ ++} ++ ++/** ++ * GreaterEqual Layer ++ * ++ * Either 1 or 2 inputs. ++ * Produces 1 output. ++ * Perform elementwise greater equal operation. ++ * ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = x1 >= x2 ++ * or ++ * y = x1 >= alpha, if only one input is provided ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message GreaterEqualLayerParams { ++ ++ /** ++ * Compare to the scalar value provided here if there is 1 input ++ */ ++ float alpha = 2; ++ ++} ++ ++/** ++ * LessThan Layer ++ * ++ * Either 1 or 2 inputs. ++ * Produces 1 output. ++ * Perform elementwise less than operation. ++ * ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = x1 < x2 ++ * or ++ * y = x1 < alpha, if only one input is provided ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message LessThanLayerParams { ++ ++ /** ++ * Compare to the scalar value provided here if there is 1 input ++ */ ++ float alpha = 2; ++ ++} ++ ++/** ++ * LessEqual Layer ++ * ++ * Either 1 or 2 inputs. ++ * Produces 1 output. ++ * Perform elementwise less equal operation. ++ * ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = x1 <= x2 ++ * or ++ * y = x1 <= alpha, if only one input is provided ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message LessEqualLayerParams { ++ ++ /** ++ * Compare to the scalar value provided here if there is 1 input ++ */ ++ float alpha = 2; ++ ++} ++ ++/** ++ * Equal Layer ++ * ++ * Either 1 or 2 inputs. ++ * Produces 1 output. ++ * Perform elementwise equal operation. ++ * ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = x1 == x2 ++ * or ++ * y = x1 == alpha, if only one input is provided ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message EqualLayerParams { ++ ++ /** ++ * Compare to the scalar value provided here if there is 1 input ++ */ ++ float alpha = 1; ++ ++} ++ ++/** ++ * NotEqual Layer ++ * ++ * Either 1 or 2 inputs. ++ * Produces 1 output. ++ * Perform elementwise not equal operation. ++ * ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = x1 != x2 ++ * or ++ * y = x1 != alpha, if only one input is provided ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message NotEqualLayerParams { ++ ++ /** ++ * Compare to the scalar value provided here if there is 1 input ++ */ ++ float alpha = 1; ++ ++} ++ ++/** ++ * LogicalAnd Layer ++ * ++ * Must have 2 inputs, produces 1 output. ++ * Perform elementwise logical AND operation. ++ * ++ * Input is considered False if equal to 0.0f otherwise True. ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = AND(x1, x2) ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message LogicalAndLayerParams { ++ ++} ++ ++/** ++ * LogicalOr Layer ++ * ++ * Must have 2 inputs, produces 1 output. ++ * Perform elementwise logical OR operation. ++ * ++ * Input is considered False if equal to 0.0f otherwise True. ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = OR(x1, x2) ++ * ++ * Broadcasting is supported. 
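The comparison layers above all share the same pattern: with two inputs the comparison is elementwise (with broadcasting), with one input the second operand is the scalar ``alpha``, and the output is 1.0f where the condition holds. A small NumPy sketch, using ``GreaterThanLayerParams`` as the example (the function name is an assumption for illustration):

.. code:: python

    import numpy as np

    def greater_than(x1, x2=None, alpha=0.0):
        """GreaterThanLayerParams semantics: y = x1 > x2, or y = x1 > alpha with one input."""
        rhs = x2 if x2 is not None else alpha
        return (x1 > rhs).astype(np.float32)   # 1.0f where true, 0.0f otherwise; broadcasting applies

    y = greater_than(np.array([[1.0, 5.0], [3.0, -2.0]]), alpha=2.0)
    # array([[0., 1.],
    #        [1., 0.]], dtype=float32)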
++ * ++ */ ++message LogicalOrLayerParams { ++ ++} ++ ++/** ++ * LogicalXor Layer ++ * ++ * Must have 2 inputs, produces 1 output. ++ * Perform elementwise logical XOR operation. ++ * ++ * Input is considered False if equal to 0.0f otherwise True. ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = XOR(x1, x2) ++ * ++ * Broadcasting is supported. ++ * ++ */ ++message LogicalXorLayerParams { ++ ++} ++ ++/** ++ * LogicalNot Layer ++ * ++ * Must have 1 input, produces 1 output. ++ * Perform elementwise logical NOT operation. ++ * ++ * Input is considered False if equal to 0.0f otherwise True. ++ * Output is 1.0f if the condition is true otherwise 0.0f. ++ * ++ * .. code:: ++ * ++ * y = NOT(x) ++ * ++ * ++ */ ++message LogicalNotLayerParams { ++ ++} ++ ++/// Border Amounts ++/// -------------- ++ ++/** ++ * Specifies the amount of spatial border to be either padded or cropped. ++ * ++ * For padding: ++ * ++ * .. code:: ++ * ++ * H_out = borderAmounts[0].startEdgeSize + H_in + borderAmounts[0].endEdgeSize ++ * W_out = borderAmounts[1].startEdgeSize + W_in + borderAmounts[1].endEdgeSize ++ * ++ * topPaddingAmount == Height startEdgeSize ++ * bottomPaddingAmount == Height endEdgeSize ++ * leftPaddingAmount == Width startEdgeSize ++ * rightPaddingAmount == Width endEdgeSize ++ * ++ * For cropping: ++ * ++ * .. code:: ++ * ++ * H_out = (-borderAmounts[0].startEdgeSize) + H_in + (-borderAmounts[0].endEdgeSize) ++ * W_out = (-borderAmounts[1].startEdgeSize) + W_in + (-borderAmounts[1].endEdgeSize) ++ * ++ * topCropAmount == Height startEdgeSize ++ * bottomCropAmount == Height endEdgeSize ++ * leftCropAmount == Width startEdgeSize ++ * rightCropAmount == Width endEdgeSize ++ */ ++message BorderAmounts { ++ ++ message EdgeSizes { ++ /** ++ * The amount to be padded or cropped from the beginning. ++ */ ++ uint64 startEdgeSize = 1; ++ ++ /** ++ * The amount to be padded or cropped from the end. ++ */ ++ uint64 endEdgeSize = 2; ++ } ++ ++ /** ++ * The border amounts. ++ * This must be length 2 in the order ``[H, W]``. ++ */ ++ repeated EdgeSizes borderAmounts = 10; ++ ++} ++ ++/** ++ * Specifies the type of padding to be used with Convolution/Deconvolution and Pooling layers. ++ * After padding, input spatial shape: ``[H_in, W_in]``, gets modified to the ++ * output spatial shape ``[H_out, W_out]``. ++ * ++ * .. code:: ++ * ++ * topPaddingAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize ++ * bottomPaddingAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize ++ * leftPaddingAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize ++ * rightPaddingAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize ++ * ++ * With Convolution or Pooling: ++ * ++ * .. code:: ++ * ++ * H_out = int_division_round_down((H_in + topPaddingAmount + bottomPaddingAmount - KernelSize[0]),stride[0]) + 1 ++ * ++ * which is same as: ++ * ++ * .. code:: ++ * ++ * H_out = int_division_round_up((H_in + topPaddingAmount + bottomPaddingAmount - KernelSize[0] + 1),stride[0]) ++ * ++ * With Deconvolution: ++ * ++ * .. code:: ++ * ++ * H_out = (H_in-1) * stride[0] + kernelSize[0] - (topPaddingAmount + bottomPaddingAmount) ++ * ++ * ++ * The equivalent expressions hold true for ``W_out`` as well. ++ * ++ * ++ * By default, the values of ``paddingAmounts`` are set to ``0``, ++ * which results in a "true" valid padding. ++ * If non-zero values are provided for ``paddingAmounts``, ++ * "valid" convolution/pooling is performed within the spatially expanded input. 
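A short sketch of the output-size arithmetic given above for explicit ("valid") padding amounts, covering both the convolution/pooling case and the deconvolution case; helper names are illustrative.

.. code:: python

    def conv_output_size(h_in, kernel, stride, pad_start=0, pad_end=0):
        """Spatial output size for convolution/pooling with explicit padding amounts."""
        return (h_in + pad_start + pad_end - kernel) // stride + 1

    def deconv_output_size(h_in, kernel, stride, pad_start=0, pad_end=0):
        """Spatial output size for deconvolution with explicit padding amounts."""
        return (h_in - 1) * stride + kernel - (pad_start + pad_end)

    print(conv_output_size(h_in=7, kernel=3, stride=2))    # 3
    print(deconv_output_size(h_in=3, kernel=3, stride=2))  # 7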
++ * ++ */ ++message ValidPadding { ++ ++ BorderAmounts paddingAmounts = 1; ++ ++} ++ ++/** ++ * Specifies the type of padding to be used with Convolution/Deconvolution and pooling layers. ++ * After padding, input spatial shape: ``[H_in, W_in]``, gets modified to the ++ * output spatial shape ``[H_out, W_out]``. ++ * With Convolution or pooling: ++ * ++ * .. code:: ++ * ++ * H_out = int_division_round_up(H_in,stride[0]) ++ * W_out = int_division_round_up(W_in,stride[1]) ++ * ++ * This is achieved by using the following padding amounts: ++ * ++ * .. code:: ++ * ++ * totalPaddingHeight = max(0,(H_out-1) * stride[0] + KernelSize[0] - Hin) ++ * totalPaddingWidth = max(0,(W_out-1) * stride[1] + KernelSize[1] - Win) ++ * ++ * There are two modes of asymmetry: ++ * ``BOTTOM_RIGHT_HEAVY``, and ``TOP_LEFT_HEAVY``. ++ * ++ * If the mode is ``BOTTOM_RIGHT_HEAVY``: ++ * ++ * .. code:: ++ * ++ * topPaddingAmount = floor(totalPaddingHeight / 2) ++ * bottomPaddingAmount = totalPaddingHeight - topPaddingAmount ++ * leftPaddingAmount = floor(totalPaddingWidth / 2) ++ * rightPaddingAmount = totalPaddingWidth - leftPaddingAmount ++ * ++ * If the mode is ``TOP_LEFT_HEAVY``: ++ * ++ * .. code:: ++ * ++ * bottomPaddingAmount = floor(totalPaddingHeight / 2) ++ * topPaddingAmount = totalPaddingHeight - bottomPaddingAmount ++ * rightPaddingAmount = floor(totalPaddingWidth / 2) ++ * leftPaddingAmount = totalPaddingWidth - rightPaddingAmount ++ * ++ * ++ * With Deconvolution: ++ * ++ * .. code:: ++ * ++ * H_out = H_in * stride[0] ++ * W_out = W_in * stride[1] ++ */ ++message SamePadding { ++ ++ enum SamePaddingMode { ++ ++ BOTTOM_RIGHT_HEAVY = 0; ++ TOP_LEFT_HEAVY = 1; ++ ++ } ++ SamePaddingMode asymmetryMode = 1; ++ ++} ++ ++/** ++ * Specifies how grid points are sampled from an interval. ++ * Without the loss of generality, assume the interval to be [0, X-1] from which N points are to be sampled. ++ * Here X may correspond to an input image's height or width. ++ * All the methods can be expressed in terms of numpy's linspace function, along with the constraint that grid points have to lie in the interval [0, X-1]. ++ * Note: numpy.linspace(start = start, end = end, num = N, endpoint = True) corresponds to sampling ++ * N points uniformly from the interval [start, end], endpoints included. ++ * The methods vary in how the ``start`` and ``end`` values are computed. ++ */ ++message SamplingMode { ++ ++ enum Method { ++ ++ /** ++ * start = 0, end = X-1 ++ * grid points = numpy.linspace(start, end) ++ */ ++ STRICT_ALIGN_ENDPOINTS_MODE = 0; ++ ++ /** ++ * if N == 1: start = end = (X-1)/2 ++ * otherwise, start = 0, end = X-1 ++ * grid points = numpy.linspace(start, end) ++ */ ++ ALIGN_ENDPOINTS_MODE = 1; ++ ++ /** ++ * start = 0, end = X - X/N ++ * grid points = min(X-1, numpy.linspace(start, end)) ++ * This is same as the mode used in the upsample layer in this specification, when used with bilinear interpolation. In that case N/X = upsample ratio. ++ */ ++ UPSAMPLE_MODE = 2; ++ ++ /** ++ * spacing = max(1, X-1)/N ++ * start = 0.5 * spacing ++ * end = start + (N-1) * spacing ++ * grid points = min(X-1, numpy.linspace(start, end)) ++ */ ++ ROI_ALIGN_MODE = 3; ++ ++ } ++ ++ Method samplingMethod = 1; ++ ++} ++ ++/** ++ * Specifies the convention used to specify four bounding box coordinates for an image of size (Height, Width). ++ * The (0,0) coordinate corresponds to the top-left corner of the image. 
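The ``SamePadding`` arithmetic above can be condensed into a few lines; the sketch below computes the (start, end) padding pair for one spatial axis under either asymmetry mode, under the assumption of the formulas quoted in the comment.

.. code:: python

    import math

    def same_padding(size_in, kernel, stride, mode="BOTTOM_RIGHT_HEAVY"):
        """(start, end) padding used by SamePadding along one spatial axis."""
        size_out = math.ceil(size_in / stride)
        total = max(0, (size_out - 1) * stride + kernel - size_in)
        if mode == "BOTTOM_RIGHT_HEAVY":
            start = total // 2
            end = total - start
        else:  # TOP_LEFT_HEAVY
            end = total // 2
            start = total - end
        return start, end

    print(same_padding(size_in=8, kernel=3, stride=2))  # (0, 1): output size 4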
++ */ ++message BoxCoordinatesMode { ++ ++ enum Coordinates { ++ ++ /** ++ * [h_start, w_start, h_end, w_end] ++ */ ++ CORNERS_HEIGHT_FIRST = 0; ++ ++ /** ++ * [w_start, h_start, w_end, h_end] ++ */ ++ CORNERS_WIDTH_FIRST = 1; ++ ++ /** ++ * [h_center, w_center, box_height, box_width] ++ */ ++ CENTER_SIZE_HEIGHT_FIRST = 2; ++ ++ /** ++ * [w_center, h_center, box_width, box_height] ++ */ ++ CENTER_SIZE_WIDTH_FIRST = 3; ++ ++ } ++ ++ Coordinates boxMode = 1; ++ ++} ++ ++/** ++ * Weights for layer parameters. ++ * Weights are stored as repeated floating point numbers ++ * using row-major ordering ++ * and can represent 1-, 2-, 3-, or 4-dimensional data. ++ */ ++message WeightParams { ++ ++ /** ++ * Values specified in single / float / FP32 precision. ++ */ ++ repeated float floatValue = 1; ++ ++ /** ++ * Values in 16-bit half precision floating point. ++ */ ++ bytes float16Value = 2; ++ ++ /** ++ * Raw value specification for quantized lower precisions. ++ * ++ * This field is interpreted as uintN, where N is the number of bits in quantization. ++ * E.g. if n=8, the field is interpreted as an array of UINT8. ++ * Use this field for quantized parameters unless specifically noted to use ++ * int8RawValue. ++ */ ++ bytes rawValue = 30; ++ ++ /** ++ * Field to be used if int8DynamicQuantize is set in the parent layer. ++ * Cannot be set if rawValue is also set. ++ * The values in this field are interpreted as INT8. ++ * ++ * If this field is set, following conditions must hold true: ++ * * QuantizationType == LinearQuantizationParams, such that ++ * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" ++ */ ++ bytes int8RawValue = 31; ++ ++ /** ++ * Quantization related parameters. ++ */ ++ QuantizationParams quantization = 40; ++ ++ bool isUpdatable = 50; ++ ++} ++ ++/** ++ * Quantization parameters. ++ */ ++message QuantizationParams { ++ ++ uint64 numberOfBits = 1; ++ oneof QuantizationType { ++ LinearQuantizationParams linearQuantization = 101; ++ LookUpTableQuantizationParams lookupTableQuantization = 102; ++ } ++ ++} ++ ++message LinearQuantizationParams { ++ ++ /** ++ * Stores scale and bias values corresponding to the quantized weights. ++ * Must be an array of 1 element, or an array of C elements, where C ++ * is number of output channels. For recurrent layers it is equal to ++ * the output vector size. ++ * ++ * Relationship between quantized weights, unquantized weights, scale and bias: ++ * ++ * W_unquantized = W_quantized * scale + bias ++ * ++ */ ++ repeated float scale = 1; ++ repeated float bias = 2; ++ ++} ++ ++message LookUpTableQuantizationParams { ++ ++ /* Stores look-up table quantization values. Must be an array of ++ (2^numberOfBits) Elements. ++ */ ++ repeated float floatValue = 1; ++ ++} ++ ++/// Layers ++/// ------ ++ ++/** ++ * A layer that performs spatial convolution or deconvolution. ++ * ++ * .. code:: ++ * ++ * y = ConvolutionLayer(x) ++ * ++ * Requires 1 or 2 inputs and produces 1 output. ++ * ++ * Input ++ * First Input: ++ * A blob with rank greater than or equal to 4. ++ * Rank 4 blob represents [Batch, channels, height, width]. ++ * For ranks greater than 4, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * From Core ML specification version 4 onwards (iOS >= 13, macOS >= 10.15). ++ * convolution layer can have 2 inputs, in which case the second input is ++ * the blob representing the weights. This is allowed when "isDeconvolution" = False. 
++ * The weight blob should have shape ++ * ``[outputChannels, kernelChannels, kernelHeight, kernelWidth]``, ++ * where kernelChannels == inputChannels / nGroups. ++ * ++ * Output ++ * Rank is same as the input. e.g.: for rank 4 input, output shape is [B, C_out, H_out, W_out] ++ * ++ * ++ * If ``dilationFactor`` is not 1, effective kernel size is ++ * modified as follows: ++ * ++ * .. code:: ++ * ++ * KernelSize[0] <-- (kernelSize[0]-1) * dilationFactor[0] + 1 ++ * KernelSize[1] <-- (kernelSize[1]-1) * dilationFactor[1] + 1 ++ * ++ * Type of padding can be ``valid`` or ``same``. Output spatial dimensions depend on the ++ * the type of padding. For details, refer to the descriptions of the messages "ValidPadding" ++ * and "SamePadding". Padded values are all zeros. ++ * ++ * For Deconvolution, ``ConvolutionPaddingType`` (``valid`` or ``same``) is ignored when ``outputShape`` is set. ++ * ++ * ++ */ ++message ConvolutionLayerParams { ++ ++ /** ++ * The number of kernels. ++ * Same as ``C_out`` used in the layer description. ++ */ ++ uint64 outputChannels = 1; ++ ++ /** ++ * Channel dimension of the kernels. ++ * Must be equal to ``inputChannels / nGroups``, if isDeconvolution == False ++ * Must be equal to ``inputChannels``, if isDeconvolution == True ++ */ ++ uint64 kernelChannels = 2; ++ ++ /** ++ * Group convolution, i.e. weight reuse along channel axis. ++ * Input and kernels are divided into g groups ++ * and convolution / deconvolution is applied within the groups independently. ++ * If not set or 0, it is set to the default value 1. ++ */ ++ uint64 nGroups = 10; ++ ++ /** ++ * Must be length 2 in the order ``[H, W]``. ++ * If not set, default value ``[3, 3]`` is used. ++ */ ++ repeated uint64 kernelSize = 20; ++ ++ /** ++ * Must be length 2 in the order ``[H, W]``. ++ * If not set, default value ``[1, 1]`` is used. ++ */ ++ repeated uint64 stride = 30; ++ ++ /** ++ * Must be length 2 in order ``[H, W]``. ++ * If not set, default value ``[1, 1]`` is used. ++ * It is ignored if ``isDeconvolution == true``. ++ */ ++ repeated uint64 dilationFactor = 40; ++ ++ /** ++ * The type of padding. ++ */ ++ oneof ConvolutionPaddingType { ++ ValidPadding valid = 50; ++ SamePadding same = 51; ++ } ++ ++ /** ++ * Flag to specify whether it is a deconvolution layer. ++ */ ++ bool isDeconvolution = 60; ++ ++ /** ++ * Flag to specify whether a bias is to be added or not. ++ */ ++ bool hasBias = 70; ++ ++ /** ++ * Weights associated with this layer. ++ * If convolution (``isDeconvolution == false``), weights have the shape ++ * ``[outputChannels, kernelChannels, kernelHeight, kernelWidth]``, where kernelChannels == inputChannels / nGroups ++ * If deconvolution (``isDeconvolution == true``) weights have the shape ++ * ``[kernelChannels, outputChannels / nGroups, kernelHeight, kernelWidth]``, where kernelChannels == inputChannels ++ */ ++ WeightParams weights = 90; ++ WeightParams bias = 91; /// Must be of size [outputChannels]. ++ ++ /** ++ * The output shape, which has length 2 ``[H_out, W_out]``. ++ * This is used only for deconvolution (``isDeconvolution == true``). ++ * If not set, the deconvolution output shape is calculated ++ * based on ``ConvolutionPaddingType``. ++ */ ++ repeated uint64 outputShape = 100; ++ ++} ++ ++/** ++ * A layer that performs a 3-dimensional convolution. ++ * ++ * .. code:: ++ * ++ * y = Convolution3DLayer(x) ++ * ++ * Input ++ * A blob of rank 5. ++ * The input blob's shape should be ``[batch, channels, depth, height, width]``. 
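The dilation rule quoted above for ``ConvolutionLayerParams`` is simple enough to state as a one-liner; a small sketch (illustrative helper, not part of the spec):

.. code:: python

    def effective_kernel_size(kernel, dilation):
        """Effective kernel extent when dilationFactor != 1: (k - 1) * d + 1."""
        return (kernel - 1) * dilation + 1

    # A 3x3 kernel with dilation 2 covers the same extent as a 5x5 kernel:
    print(effective_kernel_size(3, 2))  # 5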
++ * ++ * Fields ++ * The bias field, if set, should have shape of ``[channelsOut]``. ++ * ++ * Output ++ * A blob of rank 5. ++ * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. ++ * ++ * Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. ++ * Output spatial dimensions depend on the the type of padding. For details, refer to the ++ * descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. ++ * ++ * Example ++ * For example, given an input of size ``[1, 3, 3, 8, 8]``, a stride of 2 in each dimension, ++ * a kernel of 3 in each dimension, 2 output channels, and ``same`` padding, this layer will ++ * compute the total padding applied in the depth, height, and width dimensions to be 2, 1, and 1, ++ * respectively. The depth padding is even and will be applied equally to both sides of the depth ++ * dimension. Since the height and width padding values are odd, they'll be applied to the ++ * bottom/right of the height/width dimensions. Thus, the padding applied to the input will be ++ * ``[1, 1, 0, 1, 0, 1]`` (front, back, top, bottom, left, right). Finally, the output produced ++ * will have size ``[1, 2, 2, 4, 4]``. ++ * ++ */ ++message Convolution3DLayerParams { ++ ++ /** ++ * The number of channels in the output (channelsOut). Must be a positive integer. ++ */ ++ int32 outputChannels = 1; ++ ++ /** ++ * The number of channels in the input (channels). Must be a positive integer. ++ */ ++ int32 inputChannels = 2; ++ ++ /** ++ * Group convolution, i.e., weight reuse along the channel axis. ++ * It must evenly divide both the number of input and output channels and be at most the number ++ * of input channels (a depthwise convolution). ++ * Input and kernels are divided into g groups and convolution is applied within the groups ++ * independently. ++ */ ++ int32 nGroups = 10; ++ ++ /* Depth of the convolution kernel. Must be a positive integer. ++ */ ++ int32 kernelDepth = 20; ++ ++ /* Height of the convolution kernel. Must be a positive integer. ++ */ ++ int32 kernelHeight = 21; ++ ++ /* Width of the convolution kernel. Must be a positive integer. ++ */ ++ int32 kernelWidth = 22; ++ ++ /* Stride along the depth direction. Must be a positive integer. ++ */ ++ int32 strideDepth = 31; ++ ++ /* Stride along the height direction. Must be a positive integer. ++ */ ++ int32 strideHeight = 32; ++ ++ /* Stride along the width direction. Must be a positive integer. ++ */ ++ int32 strideWidth = 33; ++ ++ /* Dilation along the depth direction. Must be a positive integer. ++ */ ++ int32 dilationDepth = 40; ++ ++ /* Dilation along the height direction. Must be a positive integer. ++ */ ++ int32 dilationHeight = 41; ++ ++ /* Dilation along the width direction. Must be a positive integer. ++ */ ++ int32 dilationWidth = 42; ++ ++ /** ++ * Flag to specify whether a bias is to be added or not. ++ * If false, then no bias is added. ++ */ ++ bool hasBias = 50; ++ ++ /** ++ * Weights associated with this layer. ++ * Weights have the shape ++ * if deconvolution == False ++ * ``[outputChannels, kernelChannels, kernelDepth, kernelHeight, kernelWidth]``, where ++ * kernelChannels == inputChannels / nGroups ++ * else if deconvolution == True ++ * ``[outputChannels / nGroups, kernelChannels, kernelDepth, kernelHeight, kernelWidth]``, where ++ */ ++ WeightParams weights = 60; ++ ++ /** ++ * Must be of size ``[outputChannels]``. ++ */ ++ WeightParams bias = 61; ++ ++ ++ /** ++ * The type of padding. 
++ * All padding types pad the input shape with zeros. ++ * CUSTOM padding will add the custom padding values specified below to their respective ++ * dimensions, e.g., `customPaddingFront` number of zeros will be added to one side of the ++ * input's depth dimension and `customPaddingBack` number of zeros will be added to the other ++ * side of the input's depth dimension. ++ * VALID padding adds no padding to any dimension. In this case, the last convolution along ++ * each dimension will be dropped if the input dimension and the kernel size, stride, and ++ * dilation do not match. ++ * SAME padding adds enough padding to each dimension such that the output of the convolution ++ * has size ``Ceiling(inputShape / stride)``. Padding is added evenly to both sides of each ++ * dimension unless the total padding to add is odd, in which case it is added to the ++ * back/bottom/right side of the respective dimension. For example, if the total padding needed ++ * in the depth dimension is 3, 1 zero will be added to the front side of the depth dimension ++ * and 2 zeros will be added to the back side. ++ */ ++ enum PaddingType { ++ CUSTOM = 0; ++ VALID = 1; ++ SAME = 2; ++ } ++ PaddingType paddingType = 70; ++ ++ /* Padding before the input in the depth direction. Must be zero or a positive integer. ++ * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. ++ */ ++ int32 customPaddingFront = 80; ++ ++ /* Padding after the input in the depth direction. Must be zero or a positive integer. ++ * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. ++ */ ++ int32 customPaddingBack = 81; ++ ++ /* Padding before the input in the height direction. Must be zero or a positive integer. ++ * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. ++ */ ++ int32 customPaddingTop = 82; ++ ++ /* Padding after the input in the height direction. Must be zero or a positive integer. ++ * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. ++ */ ++ int32 customPaddingBottom = 83; ++ ++ /* Padding before the input in the width direction. Must be zero or a positive integer. ++ * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. ++ */ ++ int32 customPaddingLeft = 84; ++ ++ /* Padding after the input in the width direction. Must be zero or a positive integer. ++ * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. ++ */ ++ int32 customPaddingRight = 85; ++ ++ /* Flag to specify if this is Convolution Transpose or not. ++ */ ++ bool isDeconvolution = 86; ++ ++ /* ++ * The output shape, which has length 3 ``[D_out, H_out, W_out]``. ++ * This is used only for deconvolution (``isDeconvolution == true``). ++ * If not set, the deconvolution output shape is calculated ++ * based on ``PaddingType``. ++ */ ++ repeated uint64 outputShape = 87; ++ ++} ++ ++/** ++ * A layer that performs a matrix-vector or matrix-matrix product. ++ * This is equivalent to a fully-connected, or dense layer. ++ * The weight parameters correspond to a matrix of dimensions (inputChannels, outputChannels) i.e. (C_in, C_out) ++ * ++ * .. code:: ++ * ++ * y = InnerProductLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * Input can have rank 1 to rank 5. This is how it is reshaped in to the matrix (for rank > 1): ++ * rank 1 (x1) : in this case, the layer corresponds to a matrix-vector product. 
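The worked ``Convolution3DLayerParams`` example given earlier (input ``[1, 3, 3, 8, 8]``, kernel 3 and stride 2 in every dimension, SAME padding) can be reproduced per axis with the sketch below; the helper is an assumption for illustration, and the odd remainder is assigned to the back/bottom/right side as the SAME description states.

.. code:: python

    import math

    def same_padding_3d_axis(size_in, kernel, stride):
        """(before, after) padding for one axis under Convolution3D SAME padding."""
        size_out = math.ceil(size_in / stride)
        total = max(0, (size_out - 1) * stride + kernel - size_in)
        before = total // 2          # odd totals put the extra zero on the back/bottom/right
        return before, total - before

    for size in (3, 8, 8):           # depth, height, width of the worked example
        print(same_padding_3d_axis(size, kernel=3, stride=2))
    # (1, 1)  depth
    # (0, 1)  height
    # (0, 1)  width   -> padding [1, 1, 0, 1, 0, 1], output spatial dims (2, 4, 4)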
x1 must be equal to C_in ++ * rank 2 (x1, x2): x2 must be equal to C_in ++ * rank 3 (x1, x2, x3) --> (x1 * x2, x3). x3 must be equal to C_in ++ * rank 4 (x1, x2, x3, x4) ---> (x1, x2 * x3 * x4). x2 * x3 * x4 must be equal to C_in ++ * rank 5 (x1, x2, x3, x4, x5) ---> (x1 * x2, x3 * x4 * x5). x3 * x4 * x5 must be equal to C_in ++ * ++ * Output ++ * Output rank is same as the input rank ++ * rank 1: (C_out) ++ * rank 2: (x1, C_out) ++ * rank 3: (x1, x2, C_out) ++ * rank 4: (x1, C_out, 1, 1) ++ * rank 5: (x1, x2, C_out, 1, 1) ++ * ++ */ ++message InnerProductLayerParams { ++ ++ uint64 inputChannels = 1; /// Input size: C_in. ++ uint64 outputChannels = 2; /// Output size: C_out. ++ ++ bool hasBias = 10; /// Whether a bias is added or not. ++ ++ WeightParams weights = 20; /// Weight matrix [C_out, C_in]. ++ WeightParams bias = 21; /// Bias vector [C_out]. ++ ++ /** ++ * If set, this layer, at runtime, quantizes the floating point input blob to int8 before applying an ++ * inner product using INT8 weight matrix parameters, as provided in weights->int8RawValue. The ++ * result is then dequantized. ++ * Requires: ++ * * hasBias == false ++ * * QuantizationType == LinearQuantizationParams, such that ++ * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" ++ * * numberOfBits == 8 ++ * * weights->rawValue_size to be empty ++ */ ++ bool int8DynamicQuantize = 22; ++ ++} ++ ++/** ++ * A layer that performs a matrix lookup and optionally adds a bias. ++ * The weights matrix is stored with dimensions [outputChannels, inputDim]. ++ * ++ * .. code:: ++ * ++ * y = EmbeddingLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * Input values must be in the range ``[0, inputDim - 1]``. ++ * ++ * Input must have rank equal to 4 or 5, such that the last 3 dimensions are all 1. ++ * rank 4: shape (x1, 1, 1, 1). x1 is effectively the batch/sequence length. ++ * rank 5: shape (x1, x2 , 1, 1, 1). x1 * x2 is effectively the combined batch/sequence length. ++ * ++ * Output ++ * Output rank is same as the input rank. Please see input description above. ++ * rank 4: shape (x1, outputChannels, 1, 1) ++ * rank 5: shape (x1, x2, outputChannels, 1, 1) ++ * ++ */ ++message EmbeddingLayerParams { ++ ++ uint64 inputDim = 1; /// Size of the input dictionary. ++ uint64 outputChannels = 2; /// Size of the output vectors. ++ ++ bool hasBias = 10; /// Whether a bias is added or not. ++ ++ WeightParams weights = 20; /// 2-D weights of dimensions [outputChannels, inputDim]. ++ WeightParams bias = 21; /// Bias of size [outputChannels]. ++ ++} ++ ++/** ++ * A layer that performs a matrix lookup and optionally adds a bias. ++ * The weights matrix is stored with dimensions [embeddingSize, vocabSize]. ++ * ++ * .. code:: ++ * ++ * y = EmbeddingNDLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * Input values must be in the range ``[0, vocabSize - 1]``. ++ * Input must have rank at least 2. The last dimension must always be 1. ++ * rank 2: shape (x1, 1). x1 is the batch/sequence length. ++ * rank 3: shape (x1, x2, 1). x1 * x2 is effectively the combined batch/sequence length. ++ * rank 4: shape (x1, x2, x3, 1). x1 * x2 * x2 is effectively the combined batch/sequence length. ++ * rank 5: shape (x1, x2 , x3, x4, 1). x1 * x2 * x3 * x4 is effectively the combined batch/sequence length. ++ * ++ * Output ++ * Output rank is same as the input rank. Please see input description above. 
++ * rank 2: shape (x1, embeddingSize) ++ * rank 3: shape (x1, x2, embeddingSize) ++ * rank 4: shape (x1, x2, x3, embeddingSize) ++ * rank 5: shape (x1, x2, x3, x4, embeddingSize) ++ * ++ */ ++message EmbeddingNDLayerParams { ++ ++ uint64 vocabSize = 1; /// Size of the input dictionary. ++ uint64 embeddingSize = 2; /// Size of the output vectors. ++ bool hasBias = 3; /// Whether a bias is added or not. ++ WeightParams weights = 20; /// 2-D weights of dimensions [embeddingSize, vocabSize]. ++ WeightParams bias = 21; /// Bias of size [embeddingSize]. ++ ++} ++ ++/** ++ * A layer that performs batch normalization, ++ * which is performed along axis = -3, ++ * and repeated along the other axes, if present. ++ * ++ * .. code:: ++ * ++ * y = BatchnormLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * This operation is described by the following formula: ++ * ++ * .. math:: ++ * y_i = \gamma_i \dfrac{ (x_i - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i \;,\;i=1,....,C ++ * ++ * Input ++ * A blob with rank greater than equal to 3. ++ * Example: Rank 4 blob represents [Batch, channels, height, width] ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * Output ++ * A blob with the same shape as the input. ++ */ ++message BatchnormLayerParams { ++ ++ uint64 channels = 1; /// Size of the channel dimension in the input. ++ ++ /** ++ * If ``computeMeanVar == true``, ++ * the mean and variance are calculated from either ++ * the single input instance, if ``instanceNormalization == true``, ++ * or the whole batch, if ``instanceNormalization = false``. ++ * and the values provided in parameters "mean" and "variance" are ignored. ++ */ ++ bool computeMeanVar = 5; ++ bool instanceNormalization = 6; ++ ++ /** ++ * A small constant to avoid division by 0 while normalizing by variance. ++ * Defaults to ``1e-5`` if not set or set to ``0``. ++ */ ++ float epsilon = 10; ++ ++ WeightParams gamma = 15; /// Parameter of length [channels] ++ WeightParams beta = 16; /// Parameter of length [channels] ++ WeightParams mean = 17; /// Parameter of length [channels] ++ WeightParams variance = 18; /// Parameter of length [channels] ++ ++} ++ ++/** ++ * A spatial pooling layer. ++ * ++ * .. code:: ++ * ++ * y = PoolingLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank greater than equal to 4. ++ * Rank 4 blob represents [Batch, channels, height, width] ++ * For ranks greater than 4, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * Output ++ * Rank is same as the input. e.g.: for rank 4 input, output shape is [B, C, H_out, W_out] ++ * ++ * Padding options are similar to ``ConvolutionLayerParams`` ++ * with the additional option of ``ValidCompletePadding`` (``includeLastPixel``), ++ * which ensures that the last application of the kernel ++ * always includes the last pixel of the input image, if there is padding. ++ * ++ * .. code:: ++ * ++ * H_out = ceil(float(H_in + 2 * paddingAmounts[0] - kernelSize[0])/float(Stride[0])) + 1 ++ * if (paddingAmounts[0] > 0 or paddingAmounts[1] > 0) ++ * if ((H_out - 1) * Stride >= H_in + paddingAmounts[0]) { ++ * H_out = H_out - 1 ++ * } ++ * } ++ * ++ * The equivalent expressions hold true for ``W_out`` as well. ++ * Only symmetric padding is supported with this option. 
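A minimal NumPy sketch of the ``BatchnormLayerParams`` formula above, assuming precomputed per-channel mean and variance and a rank-3 ``[C, H, W]`` input; names and shapes are illustrative assumptions.

.. code:: python

    import numpy as np

    def batchnorm(x, gamma, beta, mean, variance, epsilon=1e-5):
        """BatchnormLayerParams arithmetic along the channel axis (axis = -3) of [C, H, W]."""
        shape = (x.shape[-3], 1, 1)   # broadcast per-channel parameters over H and W
        return (gamma.reshape(shape) * (x - mean.reshape(shape))
                / np.sqrt(variance.reshape(shape) + epsilon)
                + beta.reshape(shape))

    x = np.random.randn(3, 4, 4).astype(np.float32)
    y = batchnorm(x, gamma=np.ones(3), beta=np.zeros(3),
                  mean=x.mean(axis=(1, 2)), variance=x.var(axis=(1, 2)))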
++ */ ++message PoolingLayerParams { ++ ++ enum PoolingType { ++ ++ MAX = 0; ++ AVERAGE = 1; ++ L2 = 2; ++ ++ } ++ PoolingType type = 1; /// Type of pooling operation. ++ ++ /** ++ * Must be length 2 in the order ``[H, W]``. ++ * If not set, default value ``[3, 3]`` is used. ++ */ ++ repeated uint64 kernelSize = 10; ++ ++ /** ++ * Must be length 2 in the order ``[H, W]``. ++ * If not set, default value ``[1, 1]`` is used. ++ */ ++ repeated uint64 stride = 20; ++ ++ message ValidCompletePadding { ++ ++ /** ++ * Must be length 2 in order ``[H, W]``. ++ * If not set, value ``[0, 0]`` is used. ++ */ ++ repeated uint64 paddingAmounts = 10; ++ ++ } ++ ++ oneof PoolingPaddingType { ++ ValidPadding valid = 30; ++ SamePadding same = 31; ++ ValidCompletePadding includeLastPixel = 32; ++ } ++ ++ /** ++ * If true, padded values are excluded from the count (denominator) ++ * when computing average pooling. ++ */ ++ bool avgPoolExcludePadding = 50; ++ ++ /** ++ * If true, global pooling is performed. ++ * Kernel size is inferred from the input data spatial dimensions. ++ */ ++ bool globalPooling = 60; ++ ++} ++ ++/* ++ * A layer to pool three spatial dimensions ++ * ++ * Input ++ * A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. ++ * ++ * Output ++ * Rank is same as the input: A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * For example, given an input of shape (1,1,2,3,3): ++ * +----+----+----+ ++ * / | 10 | 11 | 12 | ++ * / +----+----+----+ ++ * / | 13 | 14 | 15 | ++ * / +----+----+----+ ++ * / | 16 | 17 | 18 | ++ * / +----+----+----+ ++ * +----+----+----+ / ++ * | 1 | 2 | 3 | / ++ * +----+----+----+ / ++ * | 4 | 5 | 6 | / ++ * +----+----+----+ / ++ * | 7 | 8 | 9 | / ++ * +----+----+----+ ++ * ++ * And applying MAX pooling using: ++ * Kernel: 2x2x2 ++ * Stride: 1x1x1 ++ * Valid Padding ++ * We expect to get an output with shape: (1,1,1,2,2) and value: ++ * +----+----+ ++ * | 14 | 15 | ++ * +----+----+ ++ * | 17 | 18 | ++ * +----+----+ ++ */ ++message Pooling3DLayerParams { ++ ++ enum PoolingType3D { ++ MAX = 0; ++ AVERAGE = 1; ++ } ++ ++ // Whether to use Max or Average ++ PoolingType3D type = 1; ++ ++ // Depth of the pooling region. ++ int32 kernelDepth = 2; ++ ++ // Height of the pooling region. ++ int32 kernelHeight = 3; ++ ++ // Width of the pooling region. ++ int32 kernelWidth = 4; ++ ++ // Stride along the depth direction ++ int32 strideDepth = 5; ++ ++ // Stride along the height direction ++ int32 strideHeight = 6; ++ ++ // Stride along the width direction ++ int32 strideWidth = 7; ++ ++ /** ++ * The type of padding. ++ * All padding types pad the input shape with zeros. ++ * CUSTOM padding will add the custom padding values specified below to their respective ++ * dimensions, e.g., `customPaddingFront` number of zeros will be added to one side of the ++ * input's depth dimension and `customPaddingBack` number of zeros will be added to the other ++ * side of the input's depth dimension. ++ * VALID padding adds no padding to any dimension. In this case, the last pool along ++ * each dimension will be dropped if the input dimension and the kernel size, and stride do not match. ++ * SAME padding adds enough padding to each dimension such that the output ++ * has the same spatial dimensions as the input. 
Padding is added evenly to both ++ * sides of each dimension unless the total padding to add is odd, in which case the extra padding ++ * is added to the back/bottom/right side of the respective dimension. For example, if the the ++ * total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. ++ */ ++ enum Pooling3DPaddingType { ++ CUSTOM = 0; ++ VALID = 1; ++ SAME = 2; ++ } ++ Pooling3DPaddingType paddingType = 15; ++ ++ // Padding before the input in the depth direction. ++ int32 customPaddingFront = 8; ++ ++ // Padding after the input in the depth direction. ++ int32 customPaddingBack = 9; ++ ++ // Padding before the input in the height direction. ++ int32 customPaddingTop = 10; ++ ++ // Padding after the input in the height direction. ++ int32 customPaddingBottom = 11; ++ ++ // Padding before the input in the width direction. ++ int32 customPaddingLeft = 12; ++ ++ // Padding after the input in the width direction. ++ int32 customPaddingRight = 13; ++ ++ // If true, exclude zeros from padding in Average pooling. Meaningless in Max Pooling. ++ bool countExcludePadding = 14; ++} ++ ++/* ++ * A layer to pool three spatial dimensions down to one value. ++ * This behaves like a special case of Pooling3DLayerParams in which ++ * the Kernel is the size of the input and there is no padding. ++ * ++ * Input ++ * A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. ++ * ++ * Output ++ * Rank is same as the input: A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. ++ * Depth, height, and width of the output will always be 1. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * For example, given an input of shape (1,1,2,3,3): ++ * +----+----+----+ ++ * / | 10 | 11 | 12 | ++ * / +----+----+----+ ++ * / | 13 | 14 | 15 | ++ * / +----+----+----+ ++ * / | 16 | 17 | 18 | ++ * / +----+----+----+ ++ * +----+----+----+ / ++ * | 1 | 2 | 3 | / ++ * +----+----+----+ / ++ * | 4 | 5 | 6 | / ++ * +----+----+----+ / ++ * | 7 | 8 | 9 | / ++ * +----+----+----+ ++ * ++ * And applying MAX global 3d pooling, we expect to get an output with shape: (1,1,1,1,1) and value: ++ * +----+ ++ * | 18 | ++ * +----+ ++ */ ++message GlobalPooling3DLayerParams { ++ ++ enum GlobalPoolingType3D { ++ MAX = 0; ++ AVERAGE = 1; ++ } ++ ++ // Whether to use Max or Average ++ GlobalPoolingType3D type = 1; ++} ++ ++/** ++ * A layer that performs padding along spatial dimensions. ++ * ++ * .. code:: ++ * ++ * y = PaddingLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank at least 2. ++ * e.g.: blob with shape ``[H_in, W_in]``. ++ * For ranks greater than 2, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch ++ * i.e. Padding is applied on last two dimensions. ++ * ++ * Output ++ * Same rank as the input. ++ * e.g.: blob with shape ``[H_out, W_out]``. ++ * ++ * Output dimensions are calculated as follows: ++ * ++ * .. 
code:: ++ * ++ * H_out = H_in + topPaddingAmount + bottomPaddingAmount ++ * W_out = W_in + leftPaddingAmount + rightPaddingAmount ++ * ++ * topPaddingAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize ++ * bottomPaddingAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize ++ * leftPaddingAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize ++ * rightPaddingAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize ++ * ++ * There are three types of padding: ++ * ++ * - ``PaddingConstant``, which fills a constant value at the border. ++ * - ``PaddingReflection``, which reflects the values at the border. ++ * - ``PaddingReplication``, which replicates the values at the border. ++ * ++ * Given the following input: ++ * ++ * .. code:: ++ * ++ * [1, 3, 4] : 1 2 3 4 ++ * 5 6 7 8 ++ * 9 10 11 12 ++ * ++ * Here is the output of applying the padding ++ * ``(top=2, left=2, bottom=0, right=0)`` ++ * with each of the supported types: ++ * ++ * - ``PaddingConstant`` (``value = 0``): ++ * .. code:: ++ * ++ * [1, 5, 6] : 0 0 0 0 0 0 ++ * 0 0 0 0 0 0 ++ * 0 0 1 2 3 4 ++ * 0 0 5 6 7 8 ++ * 0 0 9 10 11 12 ++ * ++ * - ``PaddingReflection``: ++ * .. code:: ++ * ++ * [1, 5, 6] : 11 10 9 10 11 12 ++ * 7 6 5 6 7 8 ++ * 3 2 1 2 3 4 ++ * 7 6 5 6 7 8 ++ * 11 10 9 10 11 12 ++ * ++ * - ``PaddingReplication``: ++ * .. code:: ++ * ++ * [1, 5, 6] : 1 1 1 2 3 4 ++ * 1 1 1 2 3 4 ++ * 1 1 1 2 3 4 ++ * 5 5 5 6 7 8 ++ * 9 9 9 10 11 12 ++ */ ++message PaddingLayerParams { ++ ++ /** ++ * Fill a constant value in the padded region. ++ */ ++ message PaddingConstant { ++ float value = 1; ++ } ++ ++ /** ++ * Reflect the values at the border for padding. ++ */ ++ message PaddingReflection { ++ } ++ ++ /** ++ * Replicate the values at the border for padding. ++ */ ++ message PaddingReplication { ++ } ++ ++ oneof PaddingType { ++ PaddingConstant constant = 1; ++ PaddingReflection reflection = 2; ++ PaddingReplication replication = 3; ++ } ++ ++ BorderAmounts paddingAmounts = 10; /// Amounts to be padded to the input. ++ ++} ++ ++/** ++ * A layer that concatenates along the axis = -3 or -5. ++ * For general concatenation along any axis, see ConcatNDLayer. ++ * ++ * .. code:: ++ * ++ * y = ConcatLayer(x1,x2,....) ++ * ++ * Requires more than 1 input and produces 1 output. ++ * ++ * Input ++ * All input blobs must have same rank. ++ * If "sequenceConcat" = False, rank must be greater than equal to 3. In this case concatenation is along axis = -3 ++ * If "sequenceConcat" = True, rank must be greater than equal to 5. In this case concatenation is along axis = -5 ++ * ++ * Output ++ * Same rank as the input. ++ * ++ */ ++message ConcatLayerParams { ++ ++ /** ++ * If true, concatenate along the axis = -5 instead of axis = -3. ++ */ ++ bool sequenceConcat = 100; ++ ++} ++ ++/** ++ * A layer that performs local response normalization (LRN). ++ * ++ * .. code:: ++ * ++ * y = LRNLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank greater than equal to 3. ++ * Example: Rank 4 blob represents [Batch, channels, height, width] ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * Output ++ * A blob with the same shape as the input. ++ * ++ * This layer is described by the following formula: ++ * ++ * .. 
math:: ++ * x_i \leftarrow \dfrac{x_i}{\left ( k + \dfrac{\alpha}{C} \sum_j x_j^2 \right )^\beta} ++ * ++ * where the summation is done over a ``(localSize, 1, 1)`` neighborhood --- ++ * that is, over a window "across" channels in 1x1 spatial neighborhoods. ++ */ ++message LRNLayerParams { ++ ++ float alpha = 1; ++ float beta = 2; ++ uint64 localSize = 3; /// Number of channels in the normalization window. ++ float k = 4; /// Defaults to 1 if not set or 0. Must be strictly positive. ++ ++} ++ ++/** ++ * Softmax Normalization Layer ++ * ++ * A layer that performs softmax normalization. ++ * Normalization is applied along axis = -3 or N-3 (where N is the rank of the input) ++ * For softmax layer that can operate on any axis, see SoftmaxNDLayer. ++ * ++ * ++ * .. code:: ++ * ++ * y = SoftmaxLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * Must be a blob with rank >= 3. ++ * Output ++ * A blob with the same shape as the input. ++ * ++ * This layer is described by the following formula: ++ * ++ * .. math:: ++ * x_i \leftarrow \dfrac{e^{x_i}}{\sum_i{e^{x_i}}} ++ */ ++message SoftmaxLayerParams { ++ ++} ++ ++/** ++ * A layer that uniformly splits across axis = -3 to produce a specified number of outputs. ++ * For general split operation along any axis, see SplitNDLayer. ++ * ++ * .. code:: ++ * ++ * (y1,y2,...yN) = SplitLayer(x), where N = nOutputs ++ * ++ * Requires 1 input and produces multiple outputs. ++ * ++ * Input ++ * A blob with rank at least 3. ++ * e.g.: blob with shape ``[C, H, W]`` ++ * Output ++ * ``nOutputs`` blobs each with same rank as the input. ++ * e.g.: For input that is of shape ``[C, H, W]``, output shapes will be ``[C/nOutputs, H, W]`` ++ */ ++message SplitLayerParams { ++ ++ uint64 nOutputs = 1; /// The number of outputs. ++ ++} ++ ++/** ++ * A layer that performs elementwise addition. ++ * This layer has limited broadcasting support. For general broadcasting see AddBroadcastableLayer. ++ * ++ * .. code:: ++ * ++ * y = AddLayer(x1,x2,...) ++ * ++ * Requires 1 or more than 1 input and produces 1 output. ++ * ++ * Input ++ * In general, there are no rank constraints. ++ * However, only certain set of shapes are broadcastable. For example: ++ * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] ++ * Output ++ * A blob with shape equal to the input blob. ++ * ++ * If only one input is provided, scalar addition is performed: ++ * ++ * .. math:: ++ * y = x + \alpha ++ * ++ */ ++message AddLayerParams { ++ ++ /** ++ * Scalar to be added to the input. ++ * Only used if there is a single input. ++ */ ++ float alpha = 1; ++ ++} ++ ++/** ++ * A layer that performs elementwise multiplication. ++ * This layer has limited broadcasting support. For general broadcasting see MultiplyBroadcastableLayer. ++ * ++ * .. code:: ++ * ++ * y = MultiplyLayer(x1,x2,...) ++ * ++ * Requires 1 or more than 1 input and produces 1 output. ++ * ++ * Input ++ * In general, there are no rank constraints. ++ * However, only certain set of shapes are broadcastable. For example: ++ * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] ++ * Output ++ * A blob with shape equal to the first input blob. ++ * ++ * If only one input is provided, scalar multiplication is performed: ++ * ++ * .. math:: ++ * y = \alpha x ++ * ++ */ ++message MultiplyLayerParams { ++ ++ /** ++ * Scalar to be multiplied with the input. ++ * Only used if there is a single input. ++ */ ++ float alpha = 1; ++ ++} ++ ++/** ++ * A layer that applies a unary function. ++ * ++ * .. 
code:: ++ * ++ * y = UnaryFunctionLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with no rank constraints. ++ * Output ++ * A blob with the same shape as the input. ++ * ++ * The input is first modified by shifting and scaling: ++ * ++ * .. math:: ++ * x \leftarrow \text{scale} \cdot x + \text{shift} ++ */ ++message UnaryFunctionLayerParams { ++ ++ /** ++ * A unary operator. ++ * ++ * The following functions are supported: ++ * ++ * ``SQRT`` ++ * .. math:: f(x) = \sqrt{x} ++ * ++ * ``RSQRT`` ++ * .. math:: f(x) = \dfrac{1}{\sqrt{x + \epsilon}} ++ * ++ * ``INVERSE`` ++ * .. math:: f(x) = \dfrac{1}{x + \epsilon} ++ * ++ * ``POWER`` ++ * .. math:: f(x) = x^\alpha ++ * ++ * ``EXP`` ++ * .. math:: f(x) = e^x ++ * ++ * ``LOG`` ++ * .. math:: f(x) = \log x ++ * ++ * ``ABS`` ++ * .. math:: f(x) = |x| ++ * ++ * ``THRESHOLD`` ++ * .. math:: f(x) = \text{max}(\alpha, x) ++ */ ++ enum Operation { ++ SQRT = 0; ++ RSQRT = 1; ++ INVERSE = 2; ++ POWER = 3; ++ EXP = 4; ++ LOG = 5; ++ ABS = 6; ++ THRESHOLD = 7; ++ } ++ Operation type = 1; /// The type of unary function. ++ ++ /** ++ * A constant used in ``POWER`` and ``THRESHOLD`` functions. ++ */ ++ float alpha = 2; ++ ++ /** ++ * A small constant to avoid division by 0 while normalizing variance. ++ * Defaults to ``1e-6`` if not set or set to ``0``. ++ */ ++ float epsilon = 3; ++ ++ /** ++ * Input is shifted by this amount ++ * before the unary function is applied. ++ * Defaults to ``0.0`` if not set. ++ */ ++ float shift = 4; ++ ++ /** ++ * Input is scaled by this amount ++ * before the unary function is applied. ++ * Defaults to ``1.0`` if not set or set to ``0``. ++ */ ++ float scale = 5; ++ ++} ++ ++/** ++ * A layer that scales up spatial dimensions. ++ * It supports two modes: nearest neighbour (default) and bilinear. ++ * ++ * .. code:: ++ * ++ * y = UpsampleLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank at least 3. ++ * e.g.: blob with shape ``[C, H, W]``. ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * Output ++ * Same rank as the input. ++ * e.g.: blob with shape ``[C, scalingFactor[0] * H, scalingFactor[1] * W]`` ++ */ ++message UpsampleLayerParams { ++ ++ /** ++ * Scaling Factor. Mutually exclusive with fractionalScalingFactor. ++ * Must be length 2 in order ``[H, W]``. ++ * If not set, default value ``[1, 1]`` is used. ++ */ ++ repeated uint64 scalingFactor = 1; ++ ++ /** ++ * Fractional scaling factor. Mutually exclusive with scalingFactor. ++ * Must be length 2 in order ``[H, W]``. ++ * If not set, default value ``[1.0, 1.0]`` is used. ++ */ ++ repeated float fractionalScalingFactor = 7; ++ ++ /* ++ * Overall mode for interpolating new elements when upsampling. ++ * NN - Nearest Neighbors - simply pick the nearest true value for interpolated values. ++ * BILINEAR - Use bilinear interpolation. See LinearUpsamplingMode for behavior. ++ */ ++ enum InterpolationMode { ++ ++ NN = 0; /// Nearest Neighbour ++ BILINEAR = 1; /// Bilinear ++ ++ } ++ ++ InterpolationMode mode = 5; ++ ++ /** ++ * LinearUpsampleMode specifies the behavior for linear upsampling. Only valid when Interpolation Mode is BILINEAR. 
++ * If input grid is [0, Xin-1] (corresponding to an input size of Xin), and if the output size is Xout, ++ * then the grid points are sampled in the following manner: ++ * DEFAULT: ++ * spacing = (Xin-Xin/Xout) / (Xout-1) ++ * grid_point[i] = min(Xin-1, max(0, i * spacing)), for i = 0,1,2,….,Xout-1 ++ * ALIGN_CORNERS_TRUE: ++ * spacing = (Xin-1) / (Xout-1) ++ * grid_point[i] = min(Xin-1, max(0, i * spacing)), for i = 0,1,2,….,Xout-1 ++ * ALIGN_CORNERS_FALSE: ++ * spacing = Xin / Xout ++ * grid_point[i] = min(Xin-1, max(0, i * spacing + 0.5 * spacing - 0.5)), for i = 0,1,2,….,Xout-1 ++ */ ++ enum LinearUpsampleMode { ++ ++ DEFAULT = 0; ++ ALIGN_CORNERS_TRUE = 1; ++ ALIGN_CORNERS_FALSE = 2; ++ ++ } ++ ++ LinearUpsampleMode linearUpsampleMode = 6; ++ ++} ++ ++/** ++* A layer that resizes the input to a pre-specified spatial size using bilinear interpolation. ++* ++* .. code:: ++* ++* y = ResizeBilinearLayer(x) ++* ++* Requires 1 input and produces 1 output. ++* ++* Input ++* A blob with rank at least 3. ++* e.g.: blob with shape ``[C, H_in, W_in]``. ++* For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++* ++* Output ++* Same rank as the input. ++* e.g.: blob with shape ``[C, H_out, W_out]``. ++* ++*/ ++message ResizeBilinearLayerParams { ++ ++ /** ++ * Target Spatial Size. ++ * Must be length 2 in order ``[Height, Width]``, i.e. ``[H_out, W_out]``. ++ * If not set, default value ``[1, 1]`` is used. ++ */ ++ repeated uint64 targetSize = 1; ++ ++ /** ++ * Mode used to compute the grid on which the spatial output values are evaluated. ++ * Same mode is applied to both the height and width axes. ++ */ ++ SamplingMode mode = 2; ++ ++} ++ ++/** ++* A layer that extracts cropped spatial patches or RoIs (regions of interest) from the input and resizes them to a pre-specified size using ++* bilinear interpolation. ++* Note that RoI Align layer can be implemented with this layer followed by a pooling layer. ++* ++* .. code:: ++* ++* y = CropResizeLayer(x) ++* ++* Requires 2 inputs and produces 1 output. ++* ++* Input ++* There are two inputs. ++* First input represents an image feature map. ++* Second input represents the bounding box coordinates for N patches or RoIs (region of interest). ++* ++* First input is rank 5: [1, Batch, C, H_in, W_in]. ++* Second input is rank 5. Its shape can be either [N, 1, 4, 1, 1] or [N, 1, 5, 1, 1]. ++* ++* N: number of patches/RoIs to be extracted ++* ++* If RoI shape = ``[N, 1, 4, 1, 1]`` ++* The axis=-3 corresponds to the four coordinates specifying the bounding box. ++* All the N RoIs are extracted from all the batches of the input. ++* ++* If RoI shape = ``[N, 1, 5, 1, 1]`` ++* The first element of the axis=-3 specifies the input batch id from which to extract the RoI and ++* must be in the interval ``[0, Batch - 1]``. That is, n-th RoI is extracted from the RoI[n,0,0,0,0]-th ++* input batch id. The last four elements of the axis=-3 specify the bounding box coordinates. ++* ++* Output ++* A blob with rank 5. ++* - Shape is [N, Batch, C, H_out, W_out] if input RoI shape is [N, 1, 4, 1, 1] ++* - Shape is [N, 1, C, H_out, W_out] if input RoI shape is [N, 1, 5, 1, 1] ++* ++*/ ++message CropResizeLayerParams { ++ ++ /** ++ * Target Spatial Size. ++ * Must be length 2 in order ``[Height, Width]``, i.e. ``[H_out, W_out]``. ++ * If not set, default value ``[1, 1]`` is used. ++ */ ++ repeated uint64 targetSize = 1; ++ ++ /** ++ * If true the bounding box coordinates must be in the interval [0, 1]. 
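As a quick aside (illustrative only, not part of the patch): the three LinearUpsampleMode grid formulas quoted above can be reproduced with a small numpy sketch; the function name and example sizes are made up.

.. code::

    import numpy as np

    def upsample_grid(x_in: int, x_out: int, mode: str) -> np.ndarray:
        # Sample positions along one axis, per the formulas in the comment above.
        i = np.arange(x_out, dtype=np.float64)
        if mode == "DEFAULT":
            spacing = (x_in - x_in / x_out) / (x_out - 1)
            grid = i * spacing
        elif mode == "ALIGN_CORNERS_TRUE":
            spacing = (x_in - 1) / (x_out - 1)
            grid = i * spacing
        elif mode == "ALIGN_CORNERS_FALSE":
            spacing = x_in / x_out
            grid = i * spacing + 0.5 * spacing - 0.5
        else:
            raise ValueError(mode)
        return np.clip(grid, 0, x_in - 1)

    # For Xin = 2, Xout = 4:
    #   DEFAULT             -> [0.0, 0.5,   1.0,   1.0]
    #   ALIGN_CORNERS_TRUE  -> [0.0, 0.333, 0.667, 1.0]
    #   ALIGN_CORNERS_FALSE -> [0.0, 0.25,  0.75,  1.0]
    print(upsample_grid(2, 4, "DEFAULT"))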
++ * They are scaled by (H_in - 1), (W_in - 1), i.e. based on the input spatial dimensions. ++ * If false the bounding box coordinates must be in the interval ++ * [0, H_in -1] and [0, W_in - 1], respectively for height and width dimensions. ++ */ ++ bool normalizedCoordinates = 2; ++ ++ /** ++ * Mode used to compute the grid on which the spatial output values are evaluated. ++ * Same mode is applied to both the height and width axes. ++ */ ++ SamplingMode mode = 3; ++ ++ /** ++ * Representation used to express the bounding box coordinates. ++ * It determines how the values of the second input are interpreted. ++ */ ++ BoxCoordinatesMode boxIndicesMode = 4; ++ ++ /** ++ * Additional spatial scale that multiplies the bounding box coordinates. ++ * Generally used while implementing the RoI Align layer, ++ * which uses unnormalized RoI coordinates along with a spatial scale less than or equal to 1. ++ */ ++ float spatialScale = 5; ++ ++} ++ ++/** ++ * A layer that performs elementwise addition of a bias, ++ * which is broadcasted to match the input shape. ++ * ++ * .. code:: ++ * ++ * y = BiasLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank at least 3. ++ * e.g.: blob with shape ``[C, H, W]``. ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * Output ++ * A blob with the same shape as the input. ++ */ ++message BiasLayerParams { ++ ++ /** ++ * The shape of the bias. ++ * Must be one of the following: ++ * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. ++ */ ++ repeated uint64 shape = 1; ++ ++ /** ++ * The bias values. ++ * The size must be equal to the product of the ``shape`` dimensions. ++ */ ++ WeightParams bias = 2; ++ ++} ++ ++/** ++ * A layer that performs elmentwise multiplication by a scale factor ++ * and optionally adds a bias; ++ * both the scale and bias are broadcasted to match the input shape. ++ * ++ * .. code:: ++ * ++ * y = ScaleLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank at least 3. ++ * e.g.: blob with shape ``[C, H, W]``. ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * Output ++ * A blob with the same shape as the input. ++ */ ++message ScaleLayerParams { ++ ++ /** ++ * The shape of the scale. ++ * Must be one of the following: ++ * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. ++ */ ++ repeated uint64 shapeScale = 1; ++ ++ /** ++ * The scale values. ++ * The size must be equal to the product of the ``shape`` dimensions. ++ */ ++ WeightParams scale = 2; /// Scale values. Size must be equal to the product of dimensions specified in shapeScale. ++ ++ bool hasBias = 3; /// If true, a bias is added after scaling. ++ ++ /** ++ * The shape of the bias. ++ * Must be one of the following: ++ * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. ++ */ ++ repeated uint64 shapeBias = 4; ++ ++ /** ++ * The bias values. ++ * The size must be equal to the product of the ``shape`` dimensions. ++ */ ++ WeightParams bias = 5; ++ ++} ++ ++/** ++ * A layer that loads data as a parameter and provides it as an output. ++ * The output is rank 5. For general rank, see LoadConstantNDLayer. ++ * ++ * .. code:: ++ * ++ * y = LoadConstantLayer() ++ * ++ * Requires no input and produces 1 output. 
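A brief illustration (numpy sketch, not part of the patch): the parameter shapes allowed for BiasLayerParams and ScaleLayerParams above are exactly those that broadcast against a ``[C, H, W]`` input; a ``[C]`` parameter is applied per channel, which numpy expresses as a reshape to ``[C, 1, 1]``.

.. code::

    import numpy as np

    C, H, W = 3, 4, 5
    x = np.random.rand(C, H, W).astype(np.float32)

    for shape in [(1,), (C,), (1, H, W), (C, H, W)]:
        scale = np.random.rand(*shape).astype(np.float32)
        bias = np.random.rand(*shape).astype(np.float32)
        if shape == (C,):
            # Per-channel parameters broadcast over the spatial dimensions.
            scale, bias = scale.reshape(C, 1, 1), bias.reshape(C, 1, 1)
        y = x * scale + bias      # ScaleLayer with hasBias == true
        assert y.shape == x.shape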
++ * ++ * Output: ++ * A blob with rank 5 and shape ``[1, 1, C, H, W]`` ++ */ ++message LoadConstantLayerParams { ++ ++ /** ++ * The shape of the constant to be loaded, ++ * which must be``[C, H, W]``, that is length 3. ++ */ ++ repeated uint64 shape = 1; ++ ++ /** ++ * The data values, ++ * of size ``C * H * W``. ++ */ ++ WeightParams data = 2; ++ ++} ++ ++/** ++ * A layer that performs L2 normalization, i.e. divides by the ++ * the square root of the sum of squares of all elements of input. ++ * ++ * .. code:: ++ * ++ * y = L2NormalizeLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank greater than equal to 3. ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * Output ++ * A blob with the same shape as the input. ++ * ++ * This layer is described by the following formula: ++ * ++ * .. math:: ++ * x_i \leftarrow \dfrac{x_i}{\sqrt{\sum{x_i^2} + \epsilon}} ++ */ ++message L2NormalizeLayerParams { ++ ++ /** ++ * A small constant to avoid division by 0 while normalizing variance. ++ * Defaults to ``1e-6`` if not set or set to ``0``. ++ */ ++ float epsilon = 1; ++ ++} ++ ++/// Data Reorganization Layers ++/// -------------------------- ++ ++/** ++ * A layer that flattens the input. ++ * ++ * .. code:: ++ * ++ * y = FlattenLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank greater than equal to 3. ++ * e.g.: Rank 4 blob represents [Batch, C, H, W] ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * Output ++ * Same rank as the input, such that last two dimensions are both 1. ++ * e.g.: For rank 4 input, output shape is ``[Batch, C * H * W, 1, 1]`` ++ * ++ * There are two X orders: ``CHANNEL_FIRST`` and ``CHANNEL_LAST``. ++ * ``CHANNEL_FIRST`` does not require data to be rearranged, ++ * because row major ordering is used by internal storage. ++ * ``CHANNEL_LAST`` requires data to be rearranged. ++ */ ++message FlattenLayerParams { ++ ++ enum FlattenOrder { ++ ++ CHANNEL_FIRST = 0; ++ CHANNEL_LAST = 1; ++ ++ } ++ FlattenOrder mode = 1; ++ ++} ++ ++/** ++ * A layer that recasts the input into a new shape. ++ * ++ * .. code:: ++ * ++ * y = ReshapeLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank 5. ++ * e.g.: ``[1, 1, C, H, W]`` or ``[Seq, 1, C, H, W]``. ++ * Output ++ * A blob with rank 5. ++ * e.g.: ``[1, 1, C_out, H_out, W_out]`` or ``[Seq_out, 1, C_out, H_out, W_out]``. ++ * ++ * There are two reshape orders: ``CHANNEL_FIRST`` and ``CHANNEL_LAST``. ++ * ``CHANNEL_FIRST`` is equivalent to ++ * flattening the input to ``[Seq, 1, C * H * W, 1, 1]`` in channel first order ++ * and then reshaping it to the target shape; ++ * no data rearrangement is required. ++ * ``CHANNEL_LAST`` is equivalent to ++ * flattening the input to ``[Seq, 1, H * W * C, 1, 1]`` in channel last order, ++ * reshaping it to ``[Seq_out, 1, H_out, W_out, C_out]`` (it is now in "H_out-major"" order), ++ * and then permuting it to ``[C_out, H_out, W_out]``; ++ * both the flattening and permuting requires the data to be rearranged. ++ */ ++message ReshapeLayerParams { ++ ++ /** ++ * The shape of the output. ++ * Must be of length 3 or 4. ++ * If set to 3, ``targetShape`` is interpreted as ++ * ``[1, 1, C_out, H_out, W_out]``, and sequence length of the input is preserved. 
++ * If set to 4, ``targetShape`` is interpreted as ++ * ``[Seq_out, 1, C_out, H_out, W_out]``, ++ * where ``Seq_out`` is the new sequence length. ++ */ ++ repeated int64 targetShape = 1; ++ ++ enum ReshapeOrder { ++ ++ CHANNEL_FIRST = 0; ++ CHANNEL_LAST = 1; ++ ++ } ++ ReshapeOrder mode = 2; ++ ++} ++ ++/** ++ * A layer that rearranges the dimensions and data of an input. ++ * For generic transpose/permute operation see TransposeLayer. ++ * ++ * .. code:: ++ * ++ * y = PermuteLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * Must be a rank 5 blob. ++ * e.g.: shape ``[Seq, B, C, H, W]``. ++ * Output ++ * Rank 5 blob. Transposed version of the input, such that dimensions at axis=1 or axis=-4 is unchanged. ++ * ++ * ++ * Examples: ++ * ++ * Assume input shape is [Seq, B, C, H, W] ++ * ++ * - If ``axis`` is set to ``[0, 3, 1, 2]``, ++ * then the output has shape ``[Seq, B, W, C, H]`` ++ * ++ * - If ``axis`` is set to ``[3, 1, 2, 0]``, ++ * then the output has shape ``[W, B, C, H, Seq]`` ++ * ++ * - If ``axis`` is set to ``[0, 3, 2, 1]``, ++ * then the output has shape ``[Seq, B, W, H, C]`` ++ * ++ * - If ``axis`` is not set, or is set to ``[0, 1, 2, 3]``, ++ * the output is the same as the input. ++ */ ++message PermuteLayerParams { ++ ++ /** ++ * The order in which to permute the dimensions. ++ * Must have length 4 and a permutation of ``[0, 1, 2, 3]``. ++ */ ++ repeated uint64 axis = 1; ++ ++} ++ ++/** ++ * A layer that reorganizes data in the input in specific ways. ++ * ++ * .. code:: ++ * ++ * y = ReorganizeDataLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank at least 3. ++ * e.g.: blob with shape ``[C, H, W]``. ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * Output ++ * Same rank as the input. ++ * e.g.: blob with shape ``[C_out, H_out, W_out]``. ++ * ++ * mode == SPACE_TO_DEPTH ++ * ``[C_out, H_out, W_out]`` : ``[C * blockSize * blockSize, H/blockSize, W/blockSize]``. ++ * blockSize must divide H and W. ++ * Data is moved from the spatial dimensions to the channel dimension. Input is spatially divided into ++ * non-overlapping blocks of size blockSize X blockSize and data from each block is moved into the ++ * channel dimension. ++ * ++ * mode == DEPTH_TO_SPACE ++ * ``[C_out, H_out, W_out]`` : ``[C/(blockSize * blockSize), H * blockSize, W * blockSize]``. ++ * Square of blockSize must divide C. ++ * Reverse of SPACE_TO_DEPTH. Data is moved from the channel dimension to the spatial dimensions. ++ * ++ * mode == PIXEL_SHUFFLE ++ * ``[C_out, H_out, W_out]`` : ``[C/(blockSize * blockSize), H * blockSize, W * blockSize]``. ++ * Square of blockSize must divide C. ++ * Similar to DEPTH_TO_SPACE, but using the pixel-shuffle semantics for channel order in the output space. ++ * In both modes, elements along the channel dimension are collapsed into ++ * blocks in the spatial dimensions. The difference is in the arrangement of ++ * the input-channels' data in the output space. See below example for more ++ * detail. ++ * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) ++ * ++ * ++ * Examples: ++ * ++ * Assume input is the following [C = 8, H = 1, W = 2] tensor: ++ * ++ * .. code:: ++ * ++ * [[[1 2]] [[3 4]] [[5 6]] [[7 8]] [[9 10]] [[11 12]] [[13 14]] [[15 16]]] ++ * ++ * If block_size == 2 and mode == DEPTH_TO_SPACE, output will be the following ++ * [C = 2, H = 2, W = 4] tensor: ++ * ++ * .. 
code:: ++ * ++ * [[[ 1 5 2 6] ++ * [ 9 13 10 14]] ++ * ++ * [[ 3 7 4 8] ++ * [11 15 12 16]]] ++ * ++ * For mode == SPACE_TO_DEPTH, the behavior is the same as mode == ++ * DEPTH_TO_SPACE, but with the input and output swapped. ++ * ++ * If block_size == 2 and mode == PIXEL_SHUFFLE, output will be the following ++ * [C = 2, H = 2, W = 4] tensor: ++ * ++ * .. code:: ++ * ++ * [[[ 1 3 2 4] ++ * [ 5 7 6 8]] ++ * ++ * [[ 9 11 10 12] ++ * [13 15 14 16]]] ++ * ++ */ ++message ReorganizeDataLayerParams { ++ ++ enum ReorganizationType { ++ ++ SPACE_TO_DEPTH = 0; ++ DEPTH_TO_SPACE = 1; ++ PIXEL_SHUFFLE = 2; ++ ++ } ++ ReorganizationType mode = 1; ++ uint64 blockSize = 2; /// must be greater than 1 ++ ++} ++ ++/** ++ * A layer that slices the input data along axis = -1 or -2 or -3. ++ * For general slice along any axis, please see SliceStaticLayer/SliceDynamicLayer. ++ * ++ * .. code:: ++ * ++ * y = SliceLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob that can, in general, have any rank. However, depending on the value of "axis" , ++ * there may be additional rank constraints. ++ * Output ++ * A blob with the same rank as the input. ++ * ++ * Sliced section is taken from the interval ``[startIndex, endIndex)``, i.e. ++ * startIndex is inclusive while endIndex is exclusive. ++ * stride must be positive and represents the step size for slicing. ++ * Negative indexing is supported for startIndex and endIndex. ++ * -1 denotes N-1, -2 denotes N-2 and so on, where N is the length of the dimension to be sliced. ++ * ++ */ ++message SliceLayerParams { ++ ++ int64 startIndex = 1; /// start of the sliced section. Inclusive. ++ int64 endIndex = 2; /// end of sliced section. Exclusive. ++ uint64 stride = 3; /// The step size. Must be positive. ++ ++ enum SliceAxis { ++ ++ CHANNEL_AXIS = 0; ++ HEIGHT_AXIS = 1; ++ WIDTH_AXIS = 2; ++ ++ } ++ // The following mapping is used for interpreting this parameter: ++ // CHANNEL_AXIS => axis = -3, input must have rank at least 3. ++ // HEIGHT_AXIS => axis = -2, input must have rank at least 2. ++ // WIDTH_AXIS => axis = -1 ++ SliceAxis axis = 4; ++ ++} ++ ++/** ++ * A layer that reduces the input using a specified operation. ++ * ++ * .. code:: ++ * ++ * y = ReduceLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob that can, in general, have any rank. However, depending on the value of "axis" , ++ * there may be additional rank constraints. ++ * Output ++ * A blob with the same rank as the input, which has 1s on the dimensions specified in the parameter "axis" ++ * ++ * Values supported for axis are [-1], [-2], [-3], [-2,-1], [-3,-2,-1] ++ * and the equivalent positive values (depending on the rank of the input) ++ * For mode == 'ArgMax', axis must be [-1] or [-2] or [-3]. ++ */ ++message ReduceLayerParams { ++ ++ /* ++ * The following reduction operations are supported ++ * and are applied on the specified axis of the input array: ++ * ++ * ``SUM`` ++ * Sum of all elements ++ * ++ * .. math:: \sum{x_i} ++ * ++ * ``AVG`` ++ * Sum of all elements divided by the number of elements ++ * ++ * .. math:: \dfrac{\sum^n{x_i}}{n} ++ * ++ * ``PROD`` ++ * Product of all elements ++ * ++ * .. math:: \prod{x_i} ++ * ++ * ``LOGSUM`` ++ * Sum of the natural logarithm of all elements ++ * ++ * .. math:: \sum{\ln{(x_i + \epsilon)}} ++ * ++ * ``SUMSQUARE`` ++ * Sum of squares of all elements ++ * ++ * .. math:: \sum{x^2} ++ * ++ * ``L1`` ++ * L1 normalization of all elements ++ * ++ * .. 
math:: ||x||_1 = \sum{|x_i|} ++ * ++ * ``L2`` ++ * L2 normalization of all elements ++ * ++ * .. math:: ||x||_2 = \sqrt{\sum{x_i^2}} ++ * ++ * ``MAX`` ++ * Maximum of all elements ++ * ++ * .. math:: \text{max}(x_i) ++ * ++ * ``MIN`` ++ * Minumum of all elements ++ * ++ * .. math:: \text{min}(x_i) ++ * ++ * ``ARGMAX`` ++ * Argument of the maximum of all elements ++ * ++ * .. math:: \text{argmax}(x_i) ++ * ++ */ ++ enum ReduceOperation { ++ ++ SUM = 0; ++ AVG = 1; ++ PROD = 2; ++ LOGSUM = 3; ++ SUMSQUARE = 4; ++ L1 = 5; ++ L2 = 6; ++ MAX = 7; ++ MIN = 8; ++ ARGMAX = 9; /// only supported with axis = C, H or W. ++ ++ } ++ ReduceOperation mode = 1; /// Specifies function used to reduce. ++ ++ /** ++ * Used if mode is ``LOGSUM``. ++ * Defaults to ``1e-6`` if not set or is set to ``0``. ++ */ ++ float epsilon = 2; ++ ++ enum ReduceAxis { ++ ++ CHW = 0; ++ HW = 1; ++ C = 2; ++ H = 3; ++ W = 4; ++ ++ } ++ ++ // The following mapping is used for interpreting this parameter: ++ // CHW = axis [-3, -2, -1], input must have rank at least 3. ++ // HW = axis [-2, -1], input must have rank at least 2. ++ // C = axis [-3] ++ // H = axis [-2] ++ // W = axis [-1] ++ ReduceAxis axis = 3; ++ ++} ++ ++/** ++ * A layer that crops the spatial dimensions of an input. ++ * If two inputs are provided, the shape of the second input is used as the reference shape. ++ * ++ * .. code:: ++ * ++ * y = CropLayer(x1) or y = CropLayer(x1,x2) ++ * ++ * Requires 1 or 2 inputs and produces 1 output. ++ * ++ * Input ++ * 1 or 2 tensors, each with rank at least 3, both inputs must have equal rank. ++ * Example: ++ * - 1 input case: A blob with shape ``[C, H_in, W_in]``. ++ * - 2 input case: 1st blob with shape ``[C, H_in, W_in]``, 2nd blob with shape ``[C, H_out, W_out]``. ++ * ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * Output ++ * Same rank as the inputs. ++ * e.g.: A blob with shape ``[C, H_out, W_out]``. ++ * ++ * If one input is used, output is computed as follows: ++ * ++ * .. code:: ++ * ++ * y = x1[:, topCropAmount:H_in - bottomCropAmount, leftCropAmount:W_in - rightCropAmount] ++ * ++ * topCropAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize ++ * bottomCropAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize ++ * leftCropAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize ++ * rightCropAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize ++ * ++ * H_out = H_in - topCropAmount - bottomCropAmount ++ * W_out = W_in - leftCropAmount - rightCropAmount ++ * ++ * If two inputs are used, output is computed as follows: ++ * ++ * .. code:: ++ * ++ * y = x1[:, offset[0]:offset[0] + H_out, offset[1]:offset[1] + W_out] ++ */ ++message CropLayerParams { ++ ++ /** ++ * The amounts to be cropped from the input. ++ * Used only if a single input is provided. ++ */ ++ BorderAmounts cropAmounts = 1; ++ ++ /** ++ * The offset amounts. ++ * Used only if two inputs are provided. ++ * Must be of length 2, in order ``[H, W]``. ++ */ ++ repeated uint64 offset = 5; ++ ++} ++ ++/** ++ * A layer that computes the elementwise average of the inputs. ++ * This layer has limited broadcasting support. For general broadcasting see AddBroadcastableLayer. ++ * ++ * .. code:: ++ * ++ * y = AverageLayer(x1,x2,...) ++ * ++ * Requires multiple inputs and produces 1 output. ++ * ++ * Input ++ * In general, there are no rank constraints. ++ * However, only certain set of shapes are broadcastable. 
For example: ++ * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] ++ * Output ++ * A blob with the same shape as each input. ++ */ ++message AverageLayerParams { ++ ++} ++ ++/** ++ * A layer that computes the elementwise maximum over the inputs. ++ * ++ * .. code:: ++ * ++ * y = MaxLayer(x1,x2,...) ++ * ++ * Requires multiple inputs and produces 1 output. ++ * ++ * Input ++ * In general, there are no rank constraints. ++ * However, only certain set of shapes are broadcastable. For example: ++ * [B, C, 1, 1], [B, C, H, W] ++ * Output ++ * A blob with the same shape as each input. ++ */ ++message MaxLayerParams { ++ ++} ++ ++/** ++ * A layer that computes the elementwise minimum over the inputs. ++ * ++ * .. code:: ++ * ++ * y = MinLayer(x1,x2,...) ++ * ++ * Requires multiple inputs and produces 1 output. ++ * ++ * Input ++ * In general, there are no rank constraints. ++ * However, only certain set of shapes are broadcastable. For example: ++ * [B, C, 1, 1], [B, C, H, W] ++ * Output ++ * A blob with the same shape as each input. ++ */ ++message MinLayerParams { ++ ++} ++ ++/** ++ * A layer that computes the dot product of two vectors. ++ * ++ * .. code:: ++ * ++ * y = DotProductLayer(x1,x2) ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * Input ++ * Two blobs with rank at least 3, such that the last two dimensions must be 1. ++ * e.g.: blobs with shape ``[B, C, 1, 1]``. ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * Output ++ * Same rank as the input. ++ * e.g. for rank 4 inputs, output shape: [B, 1, 1, 1] ++ */ ++message DotProductLayerParams { ++ ++ /** ++ * If true, inputs are normalized first, ++ * thereby computing the cosine similarity. ++ */ ++ bool cosineSimilarity = 1; ++ ++} ++ ++/** ++ * A layer that performs mean variance normalization, along axis = -3. ++ * ++ * .. code:: ++ * ++ * y = MeanVarianceNormalizeLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank greater than equal to 3. ++ * Example: Rank 4 blob represents [Batch, channels, height, width] ++ * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. ++ * ++ * Output ++ * A blob with the same shape as the input. ++ * ++ * If ``acrossChannels == true`` ++ * normalization is performed on flattened input, i.e. the input is reshaped to (Batch,C), where "Batch" contains ++ * all dimensions from 0 to -4 (inclusive), and C contains dimensions -1, -2, -3. ++ * ++ * If ``acrossChannels == false`` ++ * normalization is performed within a channel, ++ * across spatial dimensions (i.e. last two dimensions). ++ */ ++message MeanVarianceNormalizeLayerParams { ++ ++ /** ++ * If true, mean and variance are computed across channels. ++ */ ++ bool acrossChannels = 1; ++ ++ /** ++ * If false, only mean is subtracted. ++ */ ++ bool normalizeVariance = 2; ++ ++ /** ++ * A small constant to avoid division by 0 while normalizing variance. ++ * Defaults to ``1e-6`` if not set or set to ``0``. ++ */ ++ float epsilon = 3; ++ ++} ++ ++/** ++ * A layer that repeats a sequence or the dimension sitting at axis = -5 ++ * ++ * .. code:: ++ * ++ * y = SequenceRepeatLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A blob with rank at least 5. ++ * e.g: shape ``[Seq, B, C, H, W]`` ++ * Output ++ * A blob with the same rank as the input. 
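For reference, a small numpy sketch (illustrative only) of the ``acrossChannels == false`` case described above, where each channel is normalized over its spatial dimensions:

.. code::

    import numpy as np

    def mvn_per_channel(x, normalize_variance=True, epsilon=1e-6):
        # Normalize within each channel, over the last two (spatial) axes.
        mean = x.mean(axis=(-2, -1), keepdims=True)
        y = x - mean
        if normalize_variance:
            y = y / np.sqrt(x.var(axis=(-2, -1), keepdims=True) + epsilon)
        return y

    x = np.random.rand(2, 3, 8, 8)            # [Batch, C, H, W]
    y = mvn_per_channel(x)
    print(np.allclose(y.mean(axis=(-2, -1)), 0.0, atol=1e-6))  # True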
++ * e.g.: for input shape ``[Seq, B, C, H, W]``, output shape is ``[nRepetitions * Seq, B, C, H, W]``. ++ */ ++message SequenceRepeatLayerParams { ++ ++ /** ++ * Number of repetitions. ++ * Defaults to ``1`` if not set or set to ``0``. ++ */ ++ uint64 nRepetitions = 1; ++ ++} ++ ++/// Recurrent Layers ++/// ---------------- ++ ++/* ++ * The following activations are supported with recurrent layers: ++ * - Linear ++ * - Sigmoid ++ * - Tanh ++ * - ReLU ++ * - Scaled Hyperbolic Tangent: alpha * tanh(beta * x), currently only supported for alpha = 1.7159, beta = 2/3 ++ * - Hard Sigmoid: min(max(alpha * x + beta, 0), 1), currently only supported for alpha = 0.2, beta = 0.5 ++ */ ++ ++/** ++ * A simple recurrent layer. ++ * ++ * .. code:: ++ * ++ * y_t = SimpleRecurrentLayer(x_t, y_{t-1}) ++ * ++ * Input ++ * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. ++ * This represents a sequence of vectors of size ``inputVectorSize``. ++ * Output ++ * Same rank as the input. ++ * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. ++ * ++ * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` ++ * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` ++ * ++ * This layer is described by the following equation: ++ * ++ * .. math:: ++ * \boldsymbol{y_t} = f(\mathrm{clip}(W \boldsymbol{x_t} + \ ++ * R \boldsymbol{y_{t-1}} + b)) ++ * ++ * - ``W`` is a 2-dimensional weight matrix ++ * (``[outputVectorSize, inputVectorSize]``, row-major) ++ * - ``R`` is a 2-dimensional recursion matrix ++ * (``[outputVectorSize, outputVectorSize]``, row-major) ++ * - ``b`` is a 1-dimensional bias vector (``[outputVectorSize]``) ++ * - ``f()`` is an activation ++ * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` ++ */ ++message SimpleRecurrentLayerParams { ++ ++ uint64 inputVectorSize = 1; /// The size of the input vectors. ++ uint64 outputVectorSize = 2; /// The size of the output vectors. ++ ++ /** ++ * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) ++ */ ++ ActivationParams activation = 10; /// The activation function. ++ ++ /** ++ If false output is just the result after final state update. ++ If true, output is a sequence, containing outputs at all time steps. ++ */ ++ bool sequenceOutput = 15; ++ ++ bool hasBiasVector = 20; /// If false, no bias is added. ++ ++ WeightParams weightMatrix = 30; /// Weight matrix W. ++ WeightParams recursionMatrix = 31; /// Recursion Weight matrix R. ++ WeightParams biasVector = 32; /// Bias vector b. ++ ++ bool reverseInput = 100; ++ // If true, then the node processes the input sequence from right to left ++ ++} ++ ++/** ++ * Gated-Recurrent Unit (GRU) Layer ++ * ++ * .. code:: ++ * ++ * y_t = GRULayer(x_t, y_{t-1}) ++ * ++ * Input ++ * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. ++ * This represents a sequence of vectors of size ``inputVectorSize``. ++ * Output ++ * Same rank as the input. ++ * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. 
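A compact numpy sketch (illustrative; weight values and sizes are made up) of the simple recurrent update y_t = f(clip(W x_t + R y_{t-1} + b)) given above, using tanh for f():

.. code::

    import numpy as np

    def simple_rnn_step(x_t, y_prev, W, R, b, f=np.tanh):
        # y_t = f(clip(W @ x_t + R @ y_prev + b)), clipped to [-50, 50].
        return f(np.clip(W @ x_t + R @ y_prev + b, -50.0, 50.0))

    input_size, output_size = 4, 3
    rng = np.random.default_rng(0)
    W = rng.standard_normal((output_size, input_size))   # [outputVectorSize, inputVectorSize]
    R = rng.standard_normal((output_size, output_size))  # [outputVectorSize, outputVectorSize]
    b = np.zeros(output_size)

    y = np.zeros(output_size)
    for x_t in rng.standard_normal((5, input_size)):      # a length-5 sequence
        y = simple_rnn_step(x_t, y, W, R, b)
    print(y.shape)                                         # (3,)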
++ * ++ * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` ++ * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` ++ * ++ * This layer is described by the following equations: ++ * ++ * Update Gate ++ * .. math:: ++ * \boldsymbol{z_t} = \ ++ * f(\mathrm{clip}(W_z \boldsymbol{x_t} + \ ++ * R_z \boldsymbol{y_{t-1}} + b_z) ++ * ++ * Reset Gate ++ * .. math:: ++ * \boldsymbol{r_t} = \ ++ * f(\mathrm{clip}(W_r \boldsymbol{x_t} + \ ++ * R_r \boldsymbol{y_{t-1}} + b_r)) ++ * ++ * Cell Memory State ++ * .. math:: ++ * \boldsymbol{c_t} = \ ++ * \boldsymbol{y_{t-1}} \odot \boldsymbol{r_t} ++ * ++ * Output Gate ++ * .. math:: ++ * \boldsymbol{o_t} = \ ++ * g(\mathrm{clip}(W_o \boldsymbol{x_t} + \ ++ * R_o \boldsymbol{c_t} + b_o)) ++ * ++ * Output ++ * .. math:: ++ * \boldsymbol{y_t} = \ ++ * (1 - \boldsymbol{z_t}) \odot \boldsymbol{o_t} + \ ++ * \boldsymbol{z_t} \odot \boldsymbol{y_{t-1}} ++ * ++ * - ``W_z``, ``W_r``, ``W_o`` are 2-dimensional input weight matrices ++ * (``[outputVectorSize, inputVectorSize]``, row-major) ++ * - ``R_z``, ``R_r``, ``R_o`` are 2-dimensional recursion matrices ++ * (``[outputVectorSize, outputVectorSize]``, row-major) ++ * - ``b_z``, ``b_r``, ``b_o`` are 1-dimensional bias vectors ++ * (``[outputVectorSize]``) ++ * - ``f()``, ``g()`` are activations ++ * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` ++ * - ``⊙`` denotes the elementwise product of matrices ++ */ ++message GRULayerParams { ++ ++ uint64 inputVectorSize = 1; /// Size of the input vectors. ++ uint64 outputVectorSize = 2; /// Size of the output vectors. ++ ++ /** ++ * 2 element array representing activations [f(), g()] in that order. ++ * Typical values used = [sigmoid, tanh]. ++ * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) ++ */ ++ repeated ActivationParams activations = 10; ++ ++ /** ++ * If false output is just the result after final state update. ++ * If true, output is a sequence, containing outputs at all time steps. ++ */ ++ bool sequenceOutput = 15; ++ ++ /** ++ * If false, no biases (``b_z``, ``b_r``, ``b_o``) are added. ++ */ ++ bool hasBiasVectors = 20; ++ ++ WeightParams updateGateWeightMatrix = 30; /// Weight Matrix W_z. ++ WeightParams resetGateWeightMatrix = 31; /// Weight Matrix W_r. ++ WeightParams outputGateWeightMatrix = 32; /// Weight Matrix W_o. ++ ++ WeightParams updateGateRecursionMatrix = 50; /// Recursion Weight Matrix R_z. ++ WeightParams resetGateRecursionMatrix = 51; /// Recursion Weight Matrix R_r. ++ WeightParams outputGateRecursionMatrix = 52; /// Recursion Weight Matrix R_o. ++ ++ WeightParams updateGateBiasVector = 70; /// Bias vector b_z. ++ WeightParams resetGateBiasVector = 71; /// Bias vector b_r. ++ WeightParams outputGateBiasVector = 72; /// Bias vector b_o. ++ ++ /// If true, then the node processes the input sequence from right to left ++ bool reverseInput = 100; ++ ++} ++ ++/** ++ * Long short-term memory (LSTM) parameters. ++ * ++ * This is described by the following equations: ++ * ++ * Input Gate ++ * .. math:: ++ * \boldsymbol{i_t} = \ ++ * f(\mathrm{clip}(W_i \boldsymbol{x_t} + \ ++ * R_i \boldsymbol{y_{t-1}} + \ ++ * p_i \odot c_{t-1} + b_i)) ++ * ++ * Forget Gate ++ * .. math:: ++ * \boldsymbol{f_t} = \ ++ * f(\mathrm{clip}(W_f \boldsymbol{x_t} + \ ++ * R_f \boldsymbol{y_{t-1}} + \ ++ * p_f \odot c_{t-1} + b_f)) ++ * ++ * Block Input ++ * .. 
math:: ++ * \boldsymbol{z_t} = \ ++ * g(\mathrm{clip}(W_z \boldsymbol{x_t} + \ ++ * R_z \boldsymbol{y_{t-1}} + b_z)) ++ * ++ * Cell Memory State ++ * .. math:: ++ * \boldsymbol{c_t} = \ ++ * \boldsymbol{c_{t-1}} \odot \boldsymbol{f_t} + \ ++ * \boldsymbol{i_t} \odot \boldsymbol{z_t} ++ * ++ * Output Gate ++ * .. math:: ++ * \boldsymbol{o_t} = \ ++ * f(\mathrm{clip}(W_o \boldsymbol{x_t} + \ ++ * R_o \boldsymbol{y_{t-1}} + \ ++ * p_o \odot c_t + b_o)) ++ * ++ * Output ++ * .. math:: ++ * \boldsymbol{y_t} = \ ++ * h(\boldsymbol{c_t}) \odot \boldsymbol{o_t} ++ * ++ * - ``W_i``, ``W_f``, ``W_z``, ``W_o`` are 2-dimensional input weight matrices ++ * (``[outputVectorSize, inputVectorSize]``, row-major) ++ * - ``R_i``, ``R_f``, ``R_z``, ``R_o`` are 2-dimensional recursion matrices ++ * (``[outputVectorSize, outputVectorSize]``, row-major) ++ * - ``b_i``, ``b_f``, ``b_z``, ``b_o`` are 1-dimensional bias vectors ++ * (``[outputVectorSize]``) ++ * - ``p_``, ``p_f``, ``p_o`` are 1-dimensional peephole vectors ++ * (``[outputVectorSize]``) ++ * - ``f()``, ``g()``, ``h()`` are activations ++ * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` ++ * - ``⊙`` denotes the elementwise product of matrices ++ */ ++message LSTMParams { ++ ++ /** ++ * If true, output is a sequence, containing outputs at all time steps. ++ * If false, output is just the result after final state update. ++ */ ++ bool sequenceOutput = 10; ++ ++ /** ++ * If false, no biases (``b_i``, ``b_f``, ``b_z``, ``b_o``) are added. ++ */ ++ bool hasBiasVectors = 20; ++ ++ /** ++ * If true, a vector of ``1`` values is added to ``b_f``. ++ */ ++ bool forgetBias = 30; ++ ++ /** ++ * If true, peephole vectors are included. ++ */ ++ bool hasPeepholeVectors = 40; ++ ++ /** ++ * If the coupled Input and Forget flag is on, the behaviour of ++ * ``c_t`` is changed to the following (i.e. forget gate is not used): ++ * ++ * .. math:: ++ * \boldsymbol{c_t} = \ ++ * \boldsymbol{c_{t-1}} \odot (1 - \boldsymbol{i_t}) + \ ++ * \boldsymbol{i_t} \odot \boldsymbol{z_t} ++ * ++ */ ++ bool coupledInputAndForgetGate = 50; ++ ++ /** ++ * Places a limit on the maximum and minimum values of ``c_t``. ++ * c_t = min(c_t, cellClipThreshold) ++ * c_t = max(c_t, -cellClipThreshold) ++ * If 0, it is set to its default value = 50.0. ++ */ ++ float cellClipThreshold = 60; ++ ++} ++ ++/** ++ * Weights for long short-term memory (LSTM) layers ++ */ ++message LSTMWeightParams { ++ ++ WeightParams inputGateWeightMatrix = 1; /// Weight Matrix W_i. ++ WeightParams forgetGateWeightMatrix = 2; /// Weight Matrix W_f. ++ WeightParams blockInputWeightMatrix = 3; /// Weight Matrix W_z. ++ WeightParams outputGateWeightMatrix = 4; /// Weight Matrix W_o. ++ ++ WeightParams inputGateRecursionMatrix = 20; /// Recursion Weight Matrix R_i. ++ WeightParams forgetGateRecursionMatrix = 21; /// Recursion Weight Matrix R_f. ++ WeightParams blockInputRecursionMatrix = 22; /// Recursion Weight Matrix R_z. ++ WeightParams outputGateRecursionMatrix = 23; /// Recursion Weight Matrix R_o. ++ ++ //biases: ++ WeightParams inputGateBiasVector = 40; /// Bias vector b_i. ++ WeightParams forgetGateBiasVector = 41; /// Bias vector b_f. ++ WeightParams blockInputBiasVector = 42; /// Bias vector b_z. ++ WeightParams outputGateBiasVector = 43; /// Bias vector b_o. ++ ++ //peepholes: ++ WeightParams inputGatePeepholeVector = 60; /// Peephole vector p_i. ++ WeightParams forgetGatePeepholeVector = 61; /// Peephole vector p_f. 
++ WeightParams outputGatePeepholeVector = 62; /// Peephole vector p_o. ++ ++} ++ ++/** ++ * A unidirectional long short-term memory (LSTM) layer. ++ * ++ * .. code:: ++ * ++ * (y_t, c_t) = UniDirectionalLSTMLayer(x_t, y_{t-1}, c_{t-1}) ++ * ++ * Input ++ * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. ++ * This represents a sequence of vectors of size ``inputVectorSize``. ++ * Output ++ * Same rank as the input. ++ * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. ++ * ++ * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` ++ * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` ++ * ++ */ ++message UniDirectionalLSTMLayerParams { ++ ++ uint64 inputVectorSize = 1; /// Size of the input vectors. ++ uint64 outputVectorSize = 2; /// Size of the output vectors. ++ ++ /** ++ * 3 element array representing activations [f(),g(),h()] in that order. ++ * Typical values used = [sigmoid, tanh, tanh]. ++ * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) ++ */ ++ repeated ActivationParams activations = 10; ++ ++ LSTMParams params = 15; ++ ++ LSTMWeightParams weightParams = 20; /// Weights, biases and peepholes. ++ ++ /// If true, then the node processes the input sequence from right to left ++ bool reverseInput = 100; ++ ++} ++ ++/** ++ * Bidirectional long short-term memory (LSTM) layer ++ * ++ * .. code:: ++ * ++ * (y_t, c_t, y_t_reverse, c_t_reverse) = BiDirectionalLSTMLayer(x_t, y_{t-1}, c_{t-1}, y_{t-1}_reverse, c_{t-1}_reverse) ++ * ++ * Input ++ * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. ++ * This represents a sequence of vectors of size ``inputVectorSize``. ++ * Output ++ * Same rank as the input. ++ * Represents a vector of size ``2 * outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. ++ * ++ * - Output Shape: ``[1, Batch, 2 * outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` ++ * - Output Shape: ``[Seq, Batch, 2 * outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` ++ * ++ * ++ * The first LSTM operates on the input sequence in the forward direction. ++ * The second LSTM operates on the input sequence in the reverse direction. ++ * ++ * Example: given the input sequence ``[x_1, x_2, x_3]``, ++ * where ``x_i`` are vectors at time index ``i``: ++ * ++ * The forward LSTM output is ``[yf_1, yf_2, yf_3]``, ++ * ++ * where ``yf_i`` are vectors of size ``outputVectorSize``: ++ * ++ * - ``yf_1`` is the output at the end of sequence {``x_1``} ++ * - ``yf_2`` is the output at the end of sequence {``x_1``, ``x_2``} ++ * - ``yf_3`` is the output at the end of sequence {``x_1``, ``x_2``, ``x_3``} ++ * ++ * The backward LSTM output: ``[yb_1, yb_2, yb_3]``, ++ * ++ * where ``yb_i`` are vectors of size ``outputVectorSize``: ++ * ++ * - ``yb_1`` is the output at the end of sequence {``x_3``} ++ * - ``yb_2`` is the output at the end of sequence {``x_3``, ``x_2``} ++ * - ``yb_3`` is the output at the end of sequence {``x_3``, ``x_2``, ``x_1``} ++ * ++ * Output of the bi-dir layer: ++ * ++ * - if ``sequenceOutput = True`` : { ``[yf_1, yb_3]``, ``[yf_2, yb_2]``, ``[yf_3, yb_1]`` } ++ * - if ``sequenceOutput = False`` : { ``[yf_3, yb_3]`` } ++ */ ++message BiDirectionalLSTMLayerParams { ++ ++ /** ++ * Size of the input vectors. 
++ */ ++ uint64 inputVectorSize = 1; ++ /** ++ * Size of the outputs vectors. ++ * It is same for both forward and backward LSTMs. ++ */ ++ uint64 outputVectorSize = 2; ++ ++ /** ++ * 3 element array representing activations [f(),g(),h()] in that order. ++ * Typical values used = [sigmoid, tanh, tanh]. ++ * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) ++ */ ++ repeated ActivationParams activationsForwardLSTM = 10; ++ /** ++ * Currently, backward LSTM activations ++ * must be same as the ones for the forward LSTM. ++ */ ++ repeated ActivationParams activationsBackwardLSTM = 11; ++ ++ /** ++ * Common parameters shared by the forward and backward LSTMs. ++ */ ++ LSTMParams params = 15; ++ ++ /** ++ * Weights and biases. ++ * Must be a length 2 message, ++ * for the forward and backward LSTM respectively. ++ */ ++ repeated LSTMWeightParams weightParams = 20; ++ ++} ++ ++message CustomLayerParams { ++ ++ message CustomLayerParamValue { ++ oneof value { ++ double doubleValue = 10; ++ string stringValue = 20; ++ int32 intValue = 30; ++ int64 longValue = 40; ++ bool boolValue = 50; ++ } ++ } ++ ++ string className = 10; // The name of the class (conforming to MLCustomLayer) corresponding to this layer ++ repeated WeightParams weights = 20; // Any weights -- these are serialized in binary format and memmapped at runtime ++ map parameters = 30; // these may be handled as strings, so this should not be large ++ string description = 40; // An (optional) description of the layer provided by the model creator. This information is displayed when viewing the model, but does not affect the model's execution on device. ++ ++} ++ ++/** ++ * A layer that rearranges the dimensions and data of an input. ++ * ++ * .. code:: ++ * ++ * y = TransposeLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * A N-Dimensional tensor. ++ * Output ++ * A N-Dimensional tensor of the same rank but with dimensions and data permuted according to axes. ++ * Shape: ``[InputShape[axis[0]], InputShape[axis[1]], ... , InputShape[axis[N-1]]]`` ++ * ++ * Examples: ++ * ++ * - If ``axes`` is set to ``[3, 1, 2, 0]`` and the input shape is ``[6,7,8,9]``, ++ * then the output has shape ``[9,7,8,6]`` ++ */ ++ ++message TransposeLayerParams { ++ ++ /** ++ * Length of "axes" should match the rank of input & output tensor ++ * "axes" should be a permutation of "[0,1,2,...,N-1]" where N is the rank. ++ */ ++ repeated uint64 axes = 1; // ++ ++} ++ ++/** ++ * A layer that computes the matrix multiplication of two tensors with numpy-like broadcasting ++ * where the matrices reside in the last two indices of the tensor. ++ * ++ * .. code:: ++ * ++ * y = BatchedMatMul(a,b) ++ * ++ * Requires 1 or 2 inputs and produces 1 output. ++ * ++ * The first tensor, "a", must be provided as an input. The second tensor can either be an input or provided as a weight matrix parameter. ++ * ++ * Input ++ * - a: First N-Dimensional tensor ++ * - b: Second N-Dimensional tensor (either a rank-N input or a matrix, i.e. N=2, provided as a layer parameter) ++ * ++ * Output ++ * A tensor containing the matrix product of two tensors. ++ * When there are two inputs: rank is max(2, rank(a), rank(b)) ++ * When there is one input: rank is same as that of the input. ++ * ++ * This operation behaves as following: ++ * ++ * When there are two inputs: ++ * - If N >= 2 for both tensors, it is treated as a batch of matrices residing in the last two indices. 
++ * All the indices, except for the last two, are broadcasted using conventional rules. ++ * - If the first tensor is 1-D, it is converted to a 2-D tensor by prepending a 1 to its shape. Eg. (D) -> (1,D) ++ * - If the second tensor is 1-D, it is converted to a 2-D tensor by appending a 1 to its shape. Eg. (D) -> (D,1) ++ * ++ * When there is one input: ++ * - The weight matrix corresponds to a matrix, of shape (X1, X2). Values of X1, X2 must be provided as layer parameters. ++ * - The input, "a", is reshaped into a matrix by combining all the leading dimensions, except the last, into a batch dimension. eg: ++ * - if "a" is rank 1 (X1,) --> (1, X1). Output shape will be (X2,) ++ * - if "a" is rank 2 (B1, X1) --> no need to reshape. Output shape will be (B1, X2) ++ * - if "a" is rank 3 (B1, B2, X1) --> (B1 * B2, X1). Output shape will be (B1, B2, X2) ++ * - etc ++ */ ++message BatchedMatMulLayerParams { ++ ++ /** ++ * If transposeA is true, it transposes the left matrix on the fly before matrix multiplication. ++ * (is ignored when there is one input) ++ */ ++ bool transposeA = 1; ++ /** ++ * If transposeB is true, it transposes the right matrix on the fly before matrix multiplication. ++ * (is ignored when there is one input) ++ */ ++ bool transposeB = 2; ++ ++ /* ++ * Following parameters are ignored when there are two inputs. ++ */ ++ ++ uint64 weightMatrixFirstDimension = 5; /// X1: same as the last dimension of the input tensor ++ uint64 weightMatrixSecondDimension = 6; /// X2: same as the last dimension of the output tensor ++ ++ bool hasBias = 7; /// Whether a bias is added or not. Supported only when there is one input. ++ ++ /* ++ * Weight matrix representing shape [X1, X2]. ++ * Values are however stored in column major order, ++ * in the "repeated float" or "bytes" fields of the message "WeightParams" ++ */ ++ WeightParams weights = 8; ++ WeightParams bias = 9; /// Bias vector [X2]. Supported only when there is one input. ++ ++ /** ++ * If set, this layer, at runtime, quantizes the floating point input blob to int8 before applying the ++ * matrix multiplication using the INT8 weight parameters provided in weights->int8RawValue. The ++ * result is then dequantized. ++ * Requires: ++ * * number of inputs to be 1 ++ * * hasBias == false ++ * * QuantizationType == LinearQuantizationParams, such that ++ * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" ++ * * numberOfBits == 8 ++ * * weights->rawValue_size to be empty ++ */ ++ bool int8DynamicQuantize = 10; ++ ++} ++ ++/** ++ * A layer that concatenates a list of tensors along a specified axis. ++ * ++ * .. code:: ++ * ++ * y = ConcatNDLayer(x1,x2,....) ++ * ++ * Requires at least 2 input and produces 1 output. ++ * ++ * Input ++ * The rank of the input tensors must match and all dimensions also must match, except for the dimension 'axis'. ++ * ++ * ++ * Output ++ * Same rank as the input. The dimension along "axis", is the sum of the dimensions of the inputs. 
++ * ++ * example: ++ * ++ * in1 : shape (3, 2), value = [[1, 2], [3, 4], [5, 6]] ++ * in2 : shape (3, 2), value = [[7, 8], [9, 10], [11, 12]] ++ * axis = 0 ++ * ++ * if interleave = False (default) ++ * output : shape (6, 2) ++ * output[0:3, :] = in1 ++ * output[3:6, :] = in2 ++ * value = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] ++ * ++ * if interleave = True ++ * output : shape (6, 2) ++ * output[0::2, :] = in1 ++ * output[1::2, :] = in2 ++ * value = [[1, 2], [7, 8], [3, 4], [9, 10], [5, 6], [11, 12]] ++ * ++ */ ++message ConcatNDLayerParams { ++ ++ /** ++ * Dimension along which to concatenate. Supports negative values of the parameter 'axis'. ++ */ ++ int64 axis = 1; ++ ++ /** ++ * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) ++ * Interleave option. If True, concatenation is done via interleaving the inputs. ++ * This requires all inputs to have the exact same shape. ++ */ ++ bool interleave = 2; ++ ++ ++} ++ ++/** ++ * A layer that performs softmax normalization along a specified axis. ++ * ++ * .. code:: ++ * ++ * y = SoftmaxNDLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Output shape is same as the input. ++ */ ++message SoftmaxNDLayerParams { ++ ++ /** ++ * Dimension on which the softmax would be performed. Supports negative values of the parameter 'axis'. ++ */ ++ int64 axis = 1; ++ ++} ++ ++/** ++ * A layer that reverses specific dimensions of the input tensor. ++ * It is similar in functionality to the numpy.flip method. ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ */ ++message ReverseLayerParams { ++ ++ /** ++ * Reverses each dimension of the input tensor for which corresponding reverseDim is set to True. ++ * Requires len(reverseDim) == rank(inputTensor) ++ */ ++ repeated bool reverseDim = 1; ++ ++} ++ ++/** ++ * A layer that reverses variable length slices. ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * 2 inputs, in order are denoted by "data", "seq_lengths". ++ * "seq_lenghts" must be a rank 1 tensor, i.e. seq_lengths.shape = (B,) ++ * which contains the lengths of the amount of sequence to be reversed, for each element of the batch. ++ * Dimension "batchAxis" in "data" must be equal to B, i.e, ++ * data.shape[batchAxis] = B. ++ * ++ * According to the batch axis, input "data" is first divided into a batch of B inputs, ++ * each of which is flipped along the dimension "sequenceAxis", by the amount specified in ++ * "seq_lengths", the second input. ++ * ++ * e.g.: ++ * ++ * data [shape = (2,4)]: ++ * [0 1 2 3] ++ * [4 5 6 7] ++ * seq_lengths [shape = (2,)]: ++ * [3, 0] ++ * batchAxis = 0 ++ * sequenceAxis = 1 ++ * ++ * output [shape = (2,4)]: ++ * [2 1 0 3] ++ * [4 5 6 7] ++ * ++ * ++ * data [shape = (2,3,2)]: ++ * [0 1] ++ * [2 3] ++ * [4 5] (slice = 0) ++ * [6 7] ++ * [8 9] ++ * [10 11] (slice = 1) ++ * seq_lengths [shape = (2,)]: ++ * [2, 3] ++ * batchAxis = 0 ++ * sequenceAxis = 1 ++ * ++ * output [shape = (2,3,2)]: ++ * [2 3] ++ * [0 1] ++ * [4 5] (slice = 0) ++ * [10 11] ++ * [8 9] ++ * [6 7] (slice = 1) ++ * ++ * Output shape is same as the input. ++ */ ++message ReverseSeqLayerParams { ++ ++ int64 batchAxis = 1; // batch axis has to be strictly less than seq_axis ++ int64 sequenceAxis = 2; ++ ++} ++ ++/** ++ * A layer that loads data as a parameter and provides it as an output. ++ * ++ * .. code:: ++ * ++ * y = LoadConstantNDLayer() ++ * ++ * Requires no input and produces 1 output. 
++ * ++ * Output: A tensor with shape as provided in the parameter "shape" ++ */ ++message LoadConstantNDLayerParams { ++ ++ /** ++ * The shape of the constant to be loaded. ++ */ ++ repeated uint64 shape = 1; ++ WeightParams data = 2; ++ ++} ++ ++/** ++ * A layer that generates an output tensor with a constant value. ++ * Input is only used to determine the shape of the output. ++ * This layer is used to allocate a tensor with a dynamic shape (that of the input) and constant value. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * .. code:: ++ * ++ * y = FillLikeLayer(x) ++ * ++ * Input ++ * A N-Dimensional tensor, whose values are ignored. Only the shape is used to ++ * infer the shape of the output. ++ * ++ * Output ++ * A N-Dimensional tensor with the same shape as the input tensor. ++ * ++ */ ++message FillLikeLayerParams { ++ ++ float value = 1; ++ ++} ++ ++/** ++ * A layer that generates an output tensor with a constant value. ++ * This layer is used to allocate a tensor with a static shape and constant value. ++ * ++ * Requires no input and produces 1 output. ++ * ++ * .. code:: ++ * ++ * y = FillStaticLayer(x) ++ * ++ * Output ++ * A N-Dimensional tensor of shape "targetShape". ++ * ++ */ ++message FillStaticLayerParams { ++ ++ float value = 1; ++ repeated uint64 targetShape = 2; ++ ++} ++ ++/** ++ * A layer that generates an output tensor with a constant value. ++ * This layer is used to allocate a tensor with a dynamic shape (as specified by the input) and constant value. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * .. code:: ++ * ++ * y = FillDynamicLayer(x) ++ * ++ * Input ++ * A rank 1 tensor specifying the shape of the output ++ * ++ * Output ++ * An N-Dimensional tensor with the shape specified by the values in the input tensor. ++ * ++ */ ++message FillDynamicLayerParams { ++ ++ float value = 1; ++ ++} ++ ++/** ++ * A layer that returns the elements either from tensor x or tensor y, ++ * depending on the value in the condition tensor. ++ * It is similar in functionality to the numpy.where method with 3 inputs. ++ * ++ * Requires 3 inputs and produces 1 output. ++ * Inputs, in order, are the condition tensor, x and y. ++ * ++ * for each vector index (i,...,j): ++ * output[i,...,j] = x[i,...,j] if condition[i,...,j] = True ++ * y[i,...,j] if condition[i,...,j] = False ++ * ++ * All the 3 inputs are first broadcasted to a common shape. ++ * (the shapes must be broadcastable) ++ * ++ * output.rank = max(input[0].rank, input[1].rank, input[2].rank) ++ * ++ */ ++message WhereBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric sine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = SinLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message SinLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric cosine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = CosLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message CosLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric tangent function. ++ * ++ * ++ * .. code:: ++ * ++ * y = TanLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message TanLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric arcsine function. ++ * ++ * ++ * .. 
code:: ++ * ++ * y = AsinLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message AsinLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric arccosine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = AcosLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message AcosLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric arctangent function. ++ * ++ * ++ * .. code:: ++ * ++ * y = AtanLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message AtanLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric hyperbolic sine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = SinhLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message SinhLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric hyperbolic cosine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = CoshLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message CoshLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric hyperbolic tangent function. ++ * ++ * ++ * .. code:: ++ * ++ * y = TanhLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message TanhLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric hyperbolic arcsine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = AsinhLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message AsinhLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric hyperbolic arccosine function. ++ * ++ * ++ * .. code:: ++ * ++ * y = AcoshLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message AcoshLayerParams { ++ ++} ++ ++/** ++ * A layer that computes elementwise trigonometric hyperbolic arctangent function. ++ * ++ * ++ * .. code:: ++ * ++ * y = AtanhLayer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message AtanhLayerParams { ++ ++} ++/** ++ * A layer that raises each element in first tensor to the power of ++ * corresponding element in the second tensor. ++ * Supports conventional numpy-like broadcasting. ++ * ++ * .. code:: ++ * ++ * y = PowBroadcastableLayer(x) ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * Input ++ * - First N-Dimensional tensor ++ * - Second N-Dimensional tensor ++ * ++ * Output ++ * An N-Dimensional tensor with the broadcast shape. ++ * ++ */ ++message PowBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that computes the exponential of all elements in the input tensor, with the base 2. ++ * ++ * ++ * .. code:: ++ * ++ * y = Exp2Layer(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message Exp2LayerParams { ++ ++} ++ ++/** ++ * A layer that returns a tensor containing the indices of all non-zero ++ * elements of input tensor. ++ * It is similar in functionality to the numpy.where method with 1 input. ++ * ++ * Requires 1 input and produces 1 output. 
++ * Output is of rank 2, of shape (N,R), ++ * where N is the number of non-zero elements in the input and R is the rank of the input. ++ * ++ * Output contains indices represented in the multi-index form ++ * ++ * e.g.: ++ * input {shape = (4,)}: ++ * [0 1 0 2] ++ * output {shape = (2,1)}: ++ * [1] ++ * [3] ++ * ++ * ++ * input {shape = (3, 3)}: ++ * [1 2 1] ++ * [0 2 2] ++ * [2 1 0] ++ * output {shape = (7,1)}: ++ * [0. 0.] ++ * [0. 1.] ++ * [0. 2.] ++ * [1. 1.] ++ * [1. 2.] ++ * [2. 0.] ++ * [2. 1.] ++ * ++ */ ++message WhereNonZeroLayerParams { ++ ++} ++ ++/** ++ * A layer that copies a tensor setting everything outside a central band in ++ * each inner-most matrix to zero. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Parameters for matrix_band_part layer ++ * band(m, n) = (num_lower < 0 || (m-n) <= num_lower) && (num_upper < 0 || (n-m) <= num_upper). ++ * output[i, j, k, ..., m, n] = band(m, n) * input[i, j, k, ..., m, n] ++ * ++ * ++ * Output shape is same as the input shape. ++ * Rank of the input must be at least 2. ++ * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. ++ */ ++message MatrixBandPartLayerParams { ++ ++ int64 numLower = 1; ++ int64 numUpper = 2; ++ ++} ++ ++/** ++ * A layer that copies a tensor setting everything outside upper triangular to zero. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Output shape is same as the input shape. ++ * Rank of the input must be at least 2. ++ * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. ++ */ ++message UpperTriangularLayerParams { ++ ++ int64 k = 1; // Diagonal below which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above ++ ++} ++ ++/** ++ * A layer that copies a tensor setting everything outside lower triangular to zero. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Output shape is same as the input shape. ++ * Rank of the input must be at least 2. ++ * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. ++ */ ++message LowerTriangularLayerParams { ++ ++ int64 k = 1; // Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above ++ ++} ++ ++/** ++ * ++ * A layer that broadcasts a tensor to a new shape. ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * First input is broadcast to produce the output, while the second input is only ++ * used to determine the shape of the output. Values of second input are not used. ++ * ++ * Output is a tensor with the same shape as the second input. ++ * ++ */ ++message BroadcastToLikeLayerParams { ++ ++} ++ ++/** ++ * ++ * A layer that broadcasts a tensor to a new shape. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Output tensor is the broadcasted version of the input and has shape as specified in the ++ * parameter "targetShape". ++ */ ++message BroadcastToStaticLayerParams { ++ ++ repeated uint64 targetShape = 1; ++ ++} ++ ++/** ++ * ++ * A layer that broadcasts a tensor to a new shape. ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * First input is the one that is broadcasted to produce the output. ++ * Second input is a rank 1 tensor specifying the shape of the output. ++ * Output tensor has shape as specified by the values in the 2nd input tensor. 
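++ *
++ * e.g. (an illustrative case, assuming standard broadcasting rules):
++ *
++ *      first input shape   = (1, 5)
++ *      second input values = [3, 1, 5]
++ *      output shape        = (3, 1, 5)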
++ */ ++message BroadcastToDynamicLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise addition operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message AddBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise maximum operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message MaxBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise minimum operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message MinBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise modular operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message ModBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise floor division operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message FloorDivBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise subtract operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message SubtractBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise multiply operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message MultiplyBroadcastableLayerParams { ++ ++} ++ ++/** ++ * A layer that performs element-wise division operation with broadcast support. ++ * ++ * Requires 2 inputs and produces 1 output. ++ */ ++message DivideBroadcastableLayerParams { ++ ++} ++ ++/** ++ * Gather layer that gathers elements from the first input, along a specified axis, ++ * at indices specified in the second input. ++ * It is similar in functionality to the numpy.take method. ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * Given two inputs, 'data' and 'indices', gather the slices of 'data' ++ * and store into output. ++ * e.g. ++ * for i in [0, length(indices) - 1] ++ * output[i] = data[indices[i]] (1-D case, axis=0) ++ * ++ * if axis = 0: ++ * for each vector index (i,...,j) ++ * output[i,...,j,:,..,:] = data[indices[i,...,j],:,..,:] ++ * ++ * output.rank = (data.rank - 1) + indices.rank ++ * ++ * Negative indices and negative axis are supported. ++ * ++ * e.g: ++ * ++ * data shape = (2, 3) ++ * indices shape = (6, 8) ++ * axis = 0 ++ * output shape = (6, 8) + (3,) = (6, 8, 3) ++ * ++ * data shape = (2, 3, 5) ++ * indices shape = (6, 8) ++ * axis = 1 ++ * output shape = (2,) + (6, 8) + (5,) = (2, 6, 8, 5) ++ * ++ */ ++message GatherLayerParams { ++ ++ int64 axis = 1; ++ ++} ++ ++/* ++ * Scatter accumulation mode. ++ */ ++enum ScatterMode { ++ ++ SCATTER_UPDATE = 0; ++ SCATTER_ADD = 1; /// add ++ SCATTER_SUB = 2; /// subtract ++ SCATTER_MUL = 3; /// multiply ++ SCATTER_DIV = 4; /// divide ++ SCATTER_MAX = 5; /// maximum ++ SCATTER_MIN = 6; /// minimum ++ ++} ++ ++/* ++ * A layer that scatters data into a new tensor according to indices from the input. ++ * This is the inverse operation of Gather. ++ * ++ * Requires 3 inputs and produces 1 output. ++ * ++ * Output is initialized with the first input. ++ * Then updated with the values in the third input, at indices specified by the second input. 
++ * ++ * An example when axis=0: ++ * Given three inputs, in order, "container", "indices", "updates", where ++ * ++ * - "container" is a rank R+1 tensor of shape [D_0, D_1, ..., D_R], which ++ * contains D_0 number of tensors, each with shape [D_1, ..., D_R]. ++ * ++ * - "indices" is a rank 1 tensor with shape [N], where N is the number of updates. ++ * The values in this tensor must be in the range [0, D_0 - 1]. (negative indexing is supported) ++ * ++ * - "updates" is a rank R+1 tensor with shape [N, D_1, ..., D_R], which represents ++ * a total number of N tensors, each of shape [D_1, ..., D_R]. ++ * ++ * The effect of this operation is as follows: ++ * ++ * output = container; ++ * For each i in 0, ..., N - 1 ++ * output[indices[i], :, ..., :] = updates[i, :, ..., :] // if mode == "SCATTER_UPDATE" ++ * ++ * or ++ * For each i in 0, ..., N - 1 ++ * output[indices[i], :, ..., :] += updates[i, :, ..., :] // if mode == "SCATTER_ADD" ++ * ++ * etc ++ * ++ * When "indices" is a tensor of rank greater than 1, the equation becomes (for axis=0): ++ * For each vector index (i,...,j) ++ * output[indices[i,...,j],...] -= updates[i,...,j,...] // if mode == "SCATTER_SUB" ++ * ++ * ++ * The output has the same shape as the first input. ++ * "indices" input must have rank less than or equal to the "updates" input and its shape ++ * must be a subset of the the shape of the "updates" input. ++ * ++ * e.g: ++ * ++ * container shape = (4, 3) ++ * indices shape = (5, 2, 3) ++ * updates shape = (4, 5, 2, 3) ++ * axis = 1 ++ * output shape = (4, 3) ++ * ++ * container shape = (4, 4, 3) ++ * indices shape = (6,) ++ * updates shape = (4, 6, 3) ++ * axis = -2 ++ * output shape = (4, 4, 3) ++ * ++ * container shape = (5,) ++ * indices shape = (5, 7, 5, 6) ++ * updates shape = (5, 7, 5, 6) ++ * axis = -1 ++ * output shape = (5,) ++ */ ++ ++message ScatterLayerParams { ++ ++ int64 axis = 1; ++ ScatterMode mode = 2; /// mode of accumulation. ++ ++} ++ ++/** ++ * A layer that gathers elements from the first input, 'params', at the multi-indices specified ++ * by the second input, 'indices'. ++ * ++ * Requires 2 inputs and produces 1 output. ++ * ++ * 'params' = input[0], 'indices' = input[1] ++ * ++ * 'indices' is a rank K+1 tensor of shape [I_0, I_1, .., I_(K-1), I_K] which is viewed as a collection of ++ * indices of (I_0 * I_1 * ... * I_(K-1)) points in the I_K dimensional space. For instance, the multi-index of the first point ++ * is indices[0,0,...,0,:]. ++ * ++ * Here is how the output is constructed: ++ * ++ * for i = 0,1,...,(I_0-1) ++ * ... ++ * for j = 0,1,....,(I_(K-1)-1) ++ * output[i,....,j,:,:,..,:] = params[indices[i,...,j,:], :,:,..,:] ++ * ++ * Hence, output shape is [I_0, I_1,...,I(K-1)] + params.shape[I_K:] ++ * ++ * output.rank = indices.rank - 1 + params.rank - indices.shape[-1] ++ * ++ * e.g: ++ * ++ * input[0] shape = (4, 2, 3, 4) ++ * input[1] shape = (6, 2) ++ * output shape = (6,) + (3, 4) = (6, 3, 4) ++ * ++ * input[0] shape = (3, 3, 3, 4, 7) ++ * input[1] shape = (3, 5) ++ * output shape = (3,) + () = (3,) ++ * ++ * input[0] shape = (5, 3, 2, 5) ++ * input[1] shape = (2, 7, 3, 2) ++ * output shape = (2, 7, 3) + (2, 5) = (2, 7, 3, 2, 5) ++ * ++ */ ++message GatherNDLayerParams { ++ ++} ++ ++/* ++ * A layer that scatters data into a new tensor according to multi-indices from the input. ++ * This is the inverse operation of GatherND. ++ * ++ * Requires 3 inputs and produces 1 output. ++ * 3 inputs, in order are denoted as "container", "indices", "updates". 
++ * ++ * 'indices' is a rank K+1 tensor of shape [I_0, I_1, .., I_(K-1), I_K] which is viewed as a collection of ++ * indices of (I_0 * I_1 * ... * I_(K-1)) points in the I_K dimensional space. For instance, the multi-index of the first point ++ * is indices[0,0,...,0,:]. ++ * ++ * container.rank >= I_K ++ * updates.rank = K + (container.rank - I_K) ++ * shape of 'updates' = [I_0, I_1,...,I(K-1)] + container.shape[I_K:] ++ * ++ * output = container ++ * For each vector index (i,...,j) s.t. 0<=i shape: (3,) ++ * reps = N/A [Ignored] ++ * output shape = (2, 8, 12) ++ * ++ */ ++message TileLayerParams { ++ ++ repeated uint64 reps = 1; ++ ++} ++ ++/** ++ * A layer that returns the shape of an input tensor. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input: a tensor. ++ * Output: a vector of length R, where R is the rank of the input tensor ++ * Output is always a rank 1 tensor. ++ */ ++message GetShapeLayerParams { ++ ++} ++ ++/** ++ * A layer that computes the Gauss error function, ++ * which is defined as: ++ * ++ * .. math:: ++ * f(x) = \dfrac{1}{\sqrt{\pi}}\int_{-x}^{x}{e^{-t^2}dt} ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ */ ++message ErfLayerParams { ++ ++} ++ ++/** ++ * A layer that evaluates the Gaussian Error Linear Unit (GELU) activation. ++ * Following equations are used to compute the activation based on the value of the "mode" parameter: ++ * ++ * mode == 'EXACT': ++ * .. math:: ++ * f(x) = 0.5x\left ( 1+\rm{erf}\left ( \frac{x}{\sqrt{2}} \right ) \right ) ++ * ++ * mode == 'TANH_APPROXIMATION': ++ * .. math:: ++ * f(x) = 0.5x\left ( 1+\rm{tanh}\left ( \sqrt{2/\pi}\left ( x + 0.044715x^3 \right ) \right ) \right ) ++ * ++ * mode == 'SIGMOID_APPROXIMATION': ++ * .. math:: ++ * f(x) = x*\rm{sigmoid}(1.702x) ++ * ++ * Requires 1 input and produces 1 output. ++ * Output shape is same as the input. ++ * ++ */ ++message GeluLayerParams { ++ ++ enum GeluMode { ++ ++ EXACT = 0; ++ TANH_APPROXIMATION = 1; ++ SIGMOID_APPROXIMATION = 2; ++ ++ } ++ ++ GeluMode mode = 1; /// mode of GELU operation. ++ ++} ++ ++/** ++ * RangeStatic layer that returns a tensor that contains evenly spaced values. ++ * It is similar in functionality to the numpy.arange method. ++ * ++ * Requires no input and produces 1 output. ++ * Output is a rank 1 tensor. ++ */ ++message RangeStaticLayerParams { ++ ++ float endValue = 1; ++ float startValue = 2; ++ float stepSizeValue = 3; ++ ++} ++ ++/** ++ * A layer that returns a tensor that contains evenly spaced values. ++ * Its functionality is similar to the numpy.arange method. ++ * ++ * Requires at least 1 input, up to a maximum of 3 inputs. ++ * Produces 1 output, which is a rank 1 tensor. ++ * ++ * Each input must be a scalar, or rank 1 and shape (1,). ++ * ++ * The first input represents the "endValue". ++ * The second input, if present, corresponds to "startValue". In this case the value of the "startValue" parameter is ignored. ++ * The third input, if present, corresponds to "stepSizeValue". In this case the value of the "stepSizeValue" parameter is ignored. ++ * ++ */ ++message RangeDynamicLayerParams { ++ ++ float startValue = 2; ++ float stepSizeValue = 3; ++ ++} ++ ++/** ++ * A layer that returns a tensor containing all windows of size ``windowSize`` ++ * separated by ``step`` along the dimension ``axis``. ++ * ++ * .. code:: ++ * ++ * y = SlidingWindows(x) ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * Input ++ * An N-Dimensional tensor. 
++ * ++ * Output ++ * An (N+1)-Dimensional tensor. ++ * ++ * This operation behaves as following: ++ * - if axis = 0 & input is rank 1 (L,). Output shape will be (M, W). ++ * - if axis = 1 & input is rank 3 (B1, L, C1). Output shape will be (B1, M, W, C1) ++ * - if axis = 2 & input is rank 5 (B1, B2, L, C1, C2) --> (B1 * B2, L, C1 * C2) --> (B1 * B2, M, W, C1 * C2). Output shape will be (B1, B2, M, W, C1, C2) ++ * - etc. ++ * where ++ * - L, C, B refer to input length, feature dimension length & batch size respectively ++ * - W is the window size. ++ * - M is the number of windows/slices calculated as M = (L - W) / step + 1 ++ */ ++message SlidingWindowsLayerParams { ++ ++ int64 axis = 1; ++ uint64 windowSize = 2; ++ uint64 step = 3; ++ ++} ++ ++/** ++ * A layer that applies layer normalization over the input tensor. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * output = gamma * (input - computed_mean) / (sqrt(computed_variance + eps)) + beta ++ * ++ * Parameters ++ * normalizedShape: subset of the input shape, along with layer norm is performed, rest of the input shape is treated as the batch dimension. The mean and variance are computed for the input, over the last few dimensions as specified by the normalizedShape parameter. ++ * gamma: must have shape = "normalizedShape" ++ * beta: must have shape = "normalizedShape" ++ * eps: small constant to avoid division by 0 ++ * ++ * Output shape is same as the input. ++ * ++ * e.g.: ++ * input shape = (10,5) ++ * normalized shape = (5,) or (10,5) ++ * ++ * input shape = (10,5,6,7) ++ * normalized shape = (7,) or (6,7) or (5,6,7) or (10,5,6,7) ++ */ ++message LayerNormalizationLayerParams { ++ ++ repeated int64 normalizedShape = 1; ++ float eps = 2; ++ WeightParams gamma = 3; ++ WeightParams beta = 4; ++ ++} ++ ++/** ++ * Non maximum suppression (NMS) layer. ++ * Applies the non maximum suppression algorithm to input bounding box coordinates. ++ * The effect of this layer is similar to the functionality of the "NonMaximumSuppression" ++ * model type (for details please see NonMaximumSuppression.proto) with a couple of differences. ++ * One, this is a layer in a neural network model, whereas that is a different model type. Second, ++ * this layer supports a batch of bounding boxes. ++ * ++ * The NMS layer requires at least 2 inputs, and up to a maximum of 5 inputs. It produces 4 outputs. ++ * Following is the description of inputs and outputs: ++ * ++ * input 1, shape (B,N,4): coordinates of N boxes, for a batch size B. ++ * input 2, shape (B,N,C): class scores for each box. C can be 1 when there is only 1 score per box, i.e., no class specific score. ++ * ++ * input 3, optional, shape (1,): IoU threshold. When present, it overwrites the value provided in layer parameter "iouThreshold". ++ * input 4, optional, shape (1,): Score threshold. When present, it overwrites the value provided in layer parameter "scoreThreshold". ++ * input 5, optional, shape (1,): Maximum number of boxes. When present, it overwrites the value provided in layer parameter "maxBoxes". ++ * ++ * output 1, shape (B,maxBoxes,4): box coordinates, corresponding to the surviving boxes. ++ * output 2, shape (B,maxBoxes,C): box scores, corresponding to the surviving boxes. ++ * output 3, shape (B,maxBoxes): indices of the surviving boxes. Hence it will have values in the range [0,N-1], except for padding. ++ * output 4, shape (B,): number of boxes selected after the NMS algorithm, for each batch. 
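++ *
++ * e.g. (an illustrative sketch with hypothetical values, showing the padding described below):
++ *
++ *      B = 1, N = 4, C = 1, maxBoxes = 3, and only 2 boxes survive suppression
++ *      output 3 (indices), shape (1, 3): [[0, 2, -1]]   (third slot padded with -1)
++ *      output 4 (count),   shape (1,):   [2]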
++ * ++ * When surviving boxes are less than "maxBoxes", the first 3 outputs are padded. ++ * For the first two outputs, the padding is done using values 0, whereas for the third output the ++ * padding value used is -1, since the output values represent indices. ++ * ++ * If no box survives, that is, all the scores are below the "scoreThreshold", ++ * then for that batch, number of boxes (value of the fourth output) will be 1. The first 3 outputs will ++ * correspond to the box with the highest score. This is to avoid generating an "empty" output. ++ * ++ * The four values that describe the box dimensions are (in order): ++ * ++ * - x (center location of the box along the horizontal axis) ++ * - y (center location of the box along the vertical axis) ++ * - width (size of box along the horizontal axis) ++ * - height (size of box on along the vertical axis) ++ * ++ * In each batch, ++ * the N scores for N boxes, used for suppression, are generated by taking the max of the matrix (N,C) ++ * along the columns. ++ * If "perClassSuppression" flag is false, suppression happens across all classes. ++ * If "perClassSuppression" flag is true, each box is assigned to the class with the highest ++ * score and then the suppression happens separately for boxes within the same class. ++ * ++ * Note that the 4th output can be used to dynamically slice the first 3 outputs, in case ++ * the padded outputs are not required. ++ * ++ */ ++message NonMaximumSuppressionLayerParams { ++ /** ++ * The intersection over union (IoU) threshold over which boxes are suppressed. ++ */ ++ float iouThreshold = 1; ++ ++ /** ++ * Before IoU suppression is performed, boxes with class scores below this threshold are rejected. ++ */ ++ float scoreThreshold = 2; ++ ++ /** ++ * The maximum number of boxes to be given out as output. ++ * If the number of surviving boxes are less, output is padded up to this number. ++ */ ++ uint64 maxBoxes = 3; ++ ++ /** ++ * If true, suppression is performed independently within boxes of each class. ++ */ ++ bool perClassSuppression = 4; ++} ++ ++/** ++ * A layer that performs element-wise clamped ReLU operation. ++ * ++ * Requires 1 input and produces 1 output. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * f(x) = \begin{cases} ++ * \text{min}(\text{beta},x) \;\; \text{if} \;\; x \geq 0\\ ++ * \text{min}(\text{beta} ,\text{alpha}\cdot x) \;\; \text{if} \;\; x<0 ++ * \end{cases} ++ * ++ * Output shape is same as the input. ++ * ++ * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) ++ */ ++message ClampedReLULayerParams { ++ ++ float alpha = 1; ++ float beta = 2; ++ ++} ++ ++/** ++* A layer that returns the indices that would sort the input tensor, along a specified axis. ++* ++* Requires 1 input and produces 1 output. ++* ++* Output has the same rank and shape as the input. ++* ++* Value of "axis" must be positive and less than the rank of the input. 
++* ++* e.g.: ++* ++* input shape = (5,) ++* axis = 0 ++* input values = [3.1, 5.4, 32.9, 3.2, 77.0] ++* output shape = (5,) ++* output values = [0, 3, 1, 2, 4], descending = False ++* output values = [4, 2, 1, 3, 0], descending = True ++* ++* input shape = (2,3) ++* axis = 1 ++* input values = [[3, 5, 32], [3, 77, 6]] ++* output shape = (2,3) ++* output values = [[0, 1, 2], [0, 2, 1]], descending = False ++* output values = [[2, 1, 0], [1, 2, 0]], descending = True ++* ++*/ ++message ArgSortLayerParams { ++ ++ int64 axis = 1; /// must be between [0, input_rank - 1] ++ bool descending = 2; ++ ++} ++ ++/** ++ * A layer that does slice operation by providing size to be extracted ++ * from the given input tensor. ++ * ++ * Requires 2 inputs and produces 1 output. ++ * Rank of the output is same as the rank of the first input. ++ * ++ * The 1st input represents the tensor to be sliced. ++ * The 2nd input represents the beginning index to be sliced from. ++ * ++ * Example: ++ * Input 1: x (x.shape = (2, 3, 4)) ++ * Input 2: begin ++ * size: 2 ++ * axis: 1 ++ * ++ * Output: x[:, begin:begin+2, :] ++ * ++ */ ++message SliceBySizeLayerParams { ++ ++ int64 size = 2; ++ int64 axis = 3; ++ ++} ++ ++ ++/// Neural Network Specializations ++/// ------------------------------ ++ ++/** ++ * A neural network specialized as a classifier. ++ */ ++message NeuralNetworkClassifier { ++ ++ repeated NeuralNetworkLayer layers = 1; ++ repeated NeuralNetworkPreprocessing preprocessing = 2; ++ ++ // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs ++ NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; ++ ++ // use this enum value to determine the input tensor shapes to the neural network, for image inputs ++ NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; ++ ++ NetworkUpdateParameters updateParams = 10; ++ ++ // The set of labels for every possible class. ++ oneof ClassLabels { ++ StringVector stringClassLabels = 100; ++ Int64Vector int64ClassLabels = 101; ++ } ++ ++ // The name of the output blob containing the probability of each class. ++ // In other words, the score vector. Must be a 1-D tensor with the same ++ // number and order of elements as ClassLabels. ++ string labelProbabilityLayerName = 200; ++} ++ ++ ++/** ++ * A layer that computes the one hot representation of the input. ++ * ++ * Requires 1 or 2 inputs and produces 1 output. ++ * Rank of the output is one more than the first input. ++ * If the second input is present, it is used to determine the value of "oneHotVectorSize" and the parameter "oneHotVectorSize" is ignored. ++ * ++ * Input values correspond to indices and should typically be in the range [0,"oneHotVectorSize" -1]. If it is outside this range, a vector of all "offValue" will be chosen. ++ * ++ * Typically one hot vectors contain 0s everywhere, except 1 at the index that the input corresponds to. ++ * However, instead of 0, any float value could be generated by using the "offValue" parameter. ++ * Similarly, instead of 1, any other value can be used by employing the "onValue" parameter. 
++ * ++ * e.g.: ++ * input shape: (10,), "oneHotVectorSize" : 32, axis=-1, then output shape will be (10,32) ++ * input shape: (10,23), "oneHotVectorSize" : 32, axis=1, then output shape will be (10,32,23) ++ * input shape: (10,), "oneHotVectorSize" : 32, axis=0, then output shape will be (32,10) ++ * ++ * input shape: (2,), "oneHotVectorSize" : 4, axis=-1, then output shape will be (2,4) ++ * say input values = [2, 0], and "onValue" = 5, and "offValue" = -1, then output will be: ++ * [-1, -1, 5, -1 ++ * 5, -1, -1, -1] ++ * ++ * say input values = [2, -1], and "onValue" = 5, and "offValue" = -1, then output will be: ++ * [-1, -1, 5, -1 ++ * -1, -1, -1, -1] ++ * ++ * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) ++ */ ++ ++message OneHotLayerParams { ++ ++ uint64 oneHotVectorSize = 1; /// size of the one hot vector ++ int64 axis = 2; /// negative indexing is supported. It refers to the axis in the output tensor. ++ float onValue = 3; ++ float offValue = 4; ++} ++ ++ ++/** ++ * A layer that computes the cumsum values of the input along a given axis. ++ * ++ * Requires 1 or 2 inputs and produces 1 output. ++ * ++ * Output shape and rank is same as the first input. ++ * If the second input is present, it is used to determine the value of "axis" and the parameter "axis" is ignored. ++ * ++ * e.g.: ++ * Input shape = (3,), values it has: [4, 6, 7] ++ * ++ * Then output values will be: ++ * ++ * if "excludeFinalSum" = False and "reverse" = False: ++ * output values : [4, 10, 17] ++ * ++ * if "excludeFinalSum" = True and "reverse" = False: ++ * output values : [0, 4, 10] ++ * ++ * if "excludeFinalSum" = False and "reverse" = True: ++ * output values : [17, 13, 7] ++ * ++ * if "excludeFinalSum" = True and "reverse" = True: ++ * output values : [13, 7, 0] ++ * ++ * ++ * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) ++ */ ++ ++ ++message CumSumLayerParams { ++ ++ int64 axis = 1; /// negative indexing is supported ++ ++ /// if true, the first element of the output is 0, and the last element contains the sum of the input up to the penultimate value ++ /// if false, the first element of the output is same as the input and the last element is the sum of all the input values ++ /// (this behavior is reversed when "reverse" flag is True) ++ bool excludeFinalSum = 2; ++ ++ bool reverse = 3; /// if true, cumsum is performed in the opposite direction ++} ++ ++ ++/** ++ * A neural network specialized as a regressor. ++ */ ++message NeuralNetworkRegressor { ++ ++ repeated NeuralNetworkLayer layers = 1; ++ repeated NeuralNetworkPreprocessing preprocessing = 2; ++ ++ // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs ++ NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; ++ ++ // use this enum value to determine the input tensor shapes to the neural network, for image inputs ++ NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; ++ ++ NetworkUpdateParameters updateParams = 10; ++ ++} ++ ++/// --------------------------------------------------------- ++/// On-device Training related messages ++/// --------------------------------------------------------- ++ ++/** ++ * Details on how the network will be updated ++ */ ++message NetworkUpdateParameters { ++ ++ repeated LossLayer lossLayers = 1; ++ Optimizer optimizer = 2; ++ Int64Parameter epochs = 3; ++ ++ /** ++ * Describes whether to shuffle the batch of data between epochs. 
++ */ ++ BoolParameter shuffle = 10; ++ ++ /** ++ * The seed to be used in an associated random number generator. ++ */ ++ Int64Parameter seed = 20; ++} ++ ++/** ++ * Loss layer - categorical cross entropy and mean squared error are the only supported loss functions currently ++ */ ++message LossLayer { ++ ++ string name = 1; ++ oneof LossLayerType { ++ ++ CategoricalCrossEntropyLossLayer categoricalCrossEntropyLossLayer = 10; ++ MeanSquaredErrorLossLayer meanSquaredErrorLossLayer = 11; ++ ++ } ++ ++} ++ ++/** ++ * Categorical cross entropy loss layer ++ * Categorical cross entropy is used for single label categorization (only one category is applicable for each data point). ++ * ++ * The input is a vector of length N representing the distribution over N categories. It must be the output of a softmax. ++ * ++ * The target is a single value representing the true category or class label. If the target is the predictedFeatureName of a neural network classifier it will be inverse mapped to the corresponding categorical index for you. ++ * ++ * math: ++ * Loss_{CCE}(input, target) = -\sum_{i=1}^{N} (target == i) log( input[i] ) = - log (input[target]) ++ */ ++message CategoricalCrossEntropyLossLayer { ++ ++ string input = 1; ++ string target = 2; ++ ++} ++ ++/** ++ * Mean squared error loss layer, ++ * specifying input and target ++ */ ++message MeanSquaredErrorLossLayer { ++ ++ string input = 1; ++ string target = 2; ++ ++} ++ ++/** ++ * Optimizer - stochastic gradient descent and adam are the only supported optimizers currently ++ */ ++message Optimizer { ++ ++ oneof OptimizerType { ++ ++ SGDOptimizer sgdOptimizer = 10; ++ AdamOptimizer adamOptimizer = 11; ++ ++ } ++ ++} ++ ++/** ++ * Stochastic gradient descent optimizer, ++ * specifying configurable learning rate, mini batch size, and momentum ++ */ ++message SGDOptimizer { ++ ++ DoubleParameter learningRate = 1; ++ Int64Parameter miniBatchSize = 2; ++ DoubleParameter momentum = 3; ++ ++} ++ ++/** ++ * Adam optimizer, ++ * specifying configurable learning rate, mini batch size, betas, and eps ++ */ ++message AdamOptimizer { ++ ++ DoubleParameter learningRate = 1; ++ Int64Parameter miniBatchSize = 2; ++ DoubleParameter beta1 = 3; ++ DoubleParameter beta2 = 4; ++ DoubleParameter eps = 5; ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto +new file mode 100644 +index 000000000..c98949a0c +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto +@@ -0,0 +1,187 @@ ++// Copyright (c) 2018, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/* ++* Non-maximum suppression of axis-aligned bounding boxes. ++* ++* This is used primarily for object detectors that tend to produce multiple ++* boxes around a single object. This is a byproduct of the detector's ++* robustness to spatial translation. If there are two or more bounding boxes ++* that are very similar to one another, the algorithm should return only a ++* single representative. ++* ++* Similarity between two bounding boxes is measured by intersection-over-union ++* (IOU), the fraction between the area of intersection and area of the union. 
++* Here is an example where the areas can be calculated by hand by counting glyphs:: ++* ++* +-------+ +-------+ ++* | | | | ++* | +------+ +--+ | +---+ ++* | | | | | | | | ++* +-------+ | +--+ +----+ | ++* | | | | ++* +------+ +------+ ++* Intersection Union ++* IOU: 0.16 = 12 / 73 ++* ++* All IOU scores are fractions betwen 0.0 (fully disjoint) and 1.0 (perfect ++* overlap). The standard algorithm (PickTop) is defined as follows: ++* ++* 1. Sort boxes by descending order of confidence ++* 2. Take the top one and mark it as keep ++* 3. Suppress (mark it as discard) all boxes within a fixed IOU radius of the ++* keep box ++* 4. Go to 2 and repeat on the subset of boxes not already kept or discarded ++* 5. When all boxes are processed, output only the ones marked as keep ++* ++* Before the algorithm, boxes that fall below the confidence threshold are ++* discarded. ++*/ ++message NonMaximumSuppression { ++ // Suppression methods: ++ /* ++ * Pick the bounding box of the top confidence, suppress all within a radius. ++ */ ++ message PickTop { ++ /* ++ * Suppression is only done among predictions with the same label ++ * (argmax of the confidence). ++ */ ++ bool perClass = 1; ++ } ++ ++ /* ++ * Choose which underlying suppression method to use ++ */ ++ oneof SuppressionMethod { ++ PickTop pickTop = 1; ++ } ++ ++ /* ++ * Optional class label mapping. ++ */ ++ oneof ClassLabels { ++ StringVector stringClassLabels = 100; ++ Int64Vector int64ClassLabels = 101; ++ } ++ ++ /* ++ * This defines the radius of suppression. A box is considered to be within ++ * the radius of another box if their IOU score is less than this value. ++ */ ++ double iouThreshold = 110; ++ ++ /* ++ * Remove bounding boxes below this threshold. The algorithm run-time is ++ * proportional to the square of the number of incoming bounding boxes ++ * (O(N^2)). This threshold is a way to reduce N to make the algorithm ++ * faster. The confidence threshold can be any non-negative value. Negative ++ * confidences are not allowed, since if the output shape is specified to be ++ * larger than boxes after suppression, the unused boxes are filled with ++ * zero confidence. If the prediction is handled by Core Vision, it is also ++ * important that confidences are defined with the following semantics: ++ * ++ * 1. Confidences should be between 0 and 1 ++ * 2. The sum of the confidences for a prediction should not exceed 1, but is ++ * allowed to be less than 1 ++ * 3. The sum of the confidences will be interpreted as the confidence of ++ * any object (e.g. if the confidences for two classes are 0.2 and 0.4, ++ it means there is a 60% (0.2 + 0.4) confidence that an object is ++ present) ++ */ ++ double confidenceThreshold = 111; ++ ++ /* ++ * Set the name of the confidence input. ++ * ++ * The input should be a multi-array of type double and shape N x C. N is ++ * the number of boxes and C the number of classes. Each row describes the ++ * confidences of each object category being present at that particular ++ * location. Confidences should be nonnegative, where 0.0 means the highest ++ * certainty the object is not present. ++ * ++ * Specifying shape is optional. ++ */ ++ string confidenceInputFeatureName = 200; ++ ++ /* ++ * Set the name of the coordinates input. ++ * ++ * The input should be a multi-array of type double and shape N x 4. The ++ * rows correspond to the rows of the confidence matrix. 
The four values ++ * describe (in order): ++ * ++ * - x (center location of the box along the horizontal axis) ++ * - y (center location of the box along the vertical axis) ++ * - width (size of box along the horizontal axis) ++ * - height (size of box on along the vertical axis) ++ * ++ * Specifying shape is optional. ++ */ ++ string coordinatesInputFeatureName = 201; ++ ++ /* ++ * The iouThreshold can be optionally overridden by specifying this string ++ * and providing a corresponding input of type double. This allows changing ++ * the value of the parameter during run-time. ++ * ++ * The input should be a scalar double between 0.0 and 1.0. Setting it to 1.0 ++ * means there will be no suppression based on IOU. ++ */ ++ string iouThresholdInputFeatureName = 202; ++ ++ /* ++ * The confidenceThreshold can be optionally overridden by specifying this ++ * string and providing a corresponding input. This allows changing the ++ * value of the parameter during run-time, which can aid setting it just ++ * right for a particular use case. ++ * ++ * The input should be a scalar double with nonnegative value. ++ */ ++ string confidenceThresholdInputFeatureName = 203; ++ ++ /* ++ * Set the name of the confidence output. The output will be the same type ++ * and shape as the corresponding input. The only difference is that the ++ * number of rows may have been reduced. ++ * ++ * Specifying shape is optional. One reason to specify shape is to limit ++ * the number of output boxes. This can be done is several ways: ++ * ++ * Fixed shape: ++ * The output can be pinned to a fixed set of boxes. If this number is larger ++ * than the number of boxes that would have been returned, the output is padded ++ * with zeros for both confidence and coordinates. Specifying a fixed shape ++ * can be done by setting either shape (deprecated) or allowedShapes set to ++ * fixedsize. ++ * ++ * Min/max: ++ * It is also possible to set both a minimum and a maximum. The same zero-padding ++ * as for fixed shape is applied when necessary. Setting min/max is done by defining ++ * two allowedShapes, where the first dimension uses a rangeofsizes defining lowerbound ++ * and upperbound. ++ */ ++ string confidenceOutputFeatureName = 210; ++ ++ /* ++ * Set the name of the coordinates output. The output will be the same type ++ * and shape as the corresponding input. The only difference is that the ++ * number of rows may have been reduced. ++ * ++ * Specifying shape is optional. See confidence output for a more detailed ++ * description. Note that to achieve either fixed shape output or a ++ * constraint range of boxes, only one of confidence or coordinates need to ++ * set a shape. Both shapes are allowed to be defined, but in such case they ++ * have to be consistent along dimension 0. ++ */ ++ string coordinatesOutputFeatureName = 211; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto +new file mode 100644 +index 000000000..627f7e2e3 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto +@@ -0,0 +1,38 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * A normalization preprocessor. 
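++ *
++ * An illustrative sketch, assuming the input vector is divided by the selected norm:
++ * with the L2 norm, an input of [3.0, 4.0] has norm 5.0 and is normalized to [0.6, 0.8].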
++ */ ++message Normalizer { ++ /** ++ * There are three normalization modes, ++ * which have the corresponding formulas: ++ * ++ * Max ++ * .. math:: ++ * max(x_i) ++ * ++ * L1 ++ * .. math:: ++ * z = ||x||_1 = \sum_{i=1}^{n} |x_i| ++ * ++ * L2 ++ * .. math:: ++ * z = ||x||_2 = \sqrt{\sum_{i=1}^{n} x_i^2} ++ */ ++ enum NormType { ++ LMax = 0; ++ L1 = 1; ++ L2 = 2; ++ } ++ ++ NormType normType = 1; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto b/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto +new file mode 100644 +index 000000000..f47cf2816 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto +@@ -0,0 +1,41 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * Transforms a categorical feature into an array. The array will be all ++ * zeros expect a single entry of one. ++ * ++ * Each categorical value will map to an index, this mapping is given by ++ * either the ``stringCategories`` parameter or the ``int64Categories`` ++ * parameter. ++ */ ++message OneHotEncoder { ++ enum HandleUnknown { ++ ErrorOnUnknown = 0; ++ IgnoreUnknown = 1; // Output will be all zeros for unknown values. ++ } ++ ++ /** ++ * Mapping to be used for the encoding. The position of the category in ++ * the below vector determines where the single one entry will be in the ++ * output. ++ */ ++ oneof CategoryType { ++ StringVector stringCategories = 1; ++ Int64Vector int64Categories = 2; ++ } ++ ++ // Output can be a dictionary with only one entry, instead of an array. ++ bool outputSparse = 10; ++ ++ HandleUnknown handleUnknown = 11; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto +new file mode 100644 +index 000000000..ed1ebe525 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto +@@ -0,0 +1,52 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * Int64 parameter, ++ * consisting of a default int64 value, and allowed range or set of values ++ * value is unbounded if AllowedValues is not set. ++ */ ++message Int64Parameter { ++ int64 defaultValue = 1; ++ oneof AllowedValues { ++ Int64Range range = 10; ++ Int64Set set = 11; ++ } ++} ++ ++/** ++ * Double parameter, ++ * consisting of a default double value, and allowed range of values ++ * value is unbounded if AllowedValues is not set. 
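++ *
++ * e.g. (an illustrative use with hypothetical values): a learning-rate parameter could set
++ * defaultValue = 0.01 and restrict permissible values with a DoubleRange of [0.0, 1.0].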
++ */ ++message DoubleParameter { ++ double defaultValue = 1; ++ oneof AllowedValues { ++ DoubleRange range = 10; ++ } ++} ++ ++/** ++ * String parameter, ++ * A default string value must be provided ++ */ ++message StringParameter { ++ string defaultValue = 1; ++} ++ ++/** ++ * String parameter, ++ * A default bool value must be provided ++ */ ++message BoolParameter { ++ bool defaultValue = 1; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/README.md b/onnxruntime/core/providers/coreml/mlmodel_format/README.md +new file mode 100644 +index 000000000..e5eba65f9 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/README.md +@@ -0,0 +1,16 @@ ++# Core ML Model Format Specification ++This directory contains the protobuf message definitions that comprise the Core ML model document (``.mlmodel``) format. ++ ++The top-level message is ``Model``, which is defined in ``Model.proto``. ++Other message types describe data structures, feature types, feature engineering model types, and predictive model types. ++ ++# Update the Core ML Model Format Specification ++Please do not modify protobuf message definitions, they are copied directly from [Core ML Tools](https://github.com/apple/coremltools) repository. ++ ++To update the Core ML Model Format Schema schema files to a more recent version: ++1. Delete all the protobuf message definitions (`.proto`) from this directory. ++2. Copy the new version of protobuf message definitions (`.proto`) from the `mlmodel/format/` directory of preferred coremltools release branch. ++ ++# Core ML Model Format Schema version history ++## [coremltools 4.0](https://github.com/apple/coremltools/releases/tag/4.0) ++[Core ML Model Format Specification](https://github.com/apple/coremltools/tree/4.0/mlmodel/format) +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto b/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto +new file mode 100644 +index 000000000..932a4ec21 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto +@@ -0,0 +1,195 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/// Kernel Definitions ++/// ------------------ ++ ++/** ++ * A linear kernel. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * K(\boldsymbol{x}, \boldsymbol{x'}) = \boldsymbol{x}^T \boldsymbol{x'} ++ */ ++message LinearKernel { ++} ++ ++/** ++ * A Gaussian radial basis function (RBF) kernel. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * K(\boldsymbol{x}, \boldsymbol{x'}) = \ ++ * \exp(-\gamma || \boldsymbol{x} - \boldsymbol{x'} ||^2 ) ++ * ++ */ ++message RBFKernel { ++ double gamma = 1; ++} ++ ++/** ++ * A polynomial kernel. ++ * ++ * This function has the following formula: ++ * ++ * .. math:: ++ * K(\boldsymbol{x}, \boldsymbol{x'}) = \ ++ * (\gamma \boldsymbol{x}^T \boldsymbol{x'} + c)^{degree} ++ */ ++message PolyKernel { ++ int32 degree = 1; ++ double c = 2; ++ double gamma = 3; ++} ++ ++/** ++ * A sigmoid kernel. ++ * ++ * This function has the following formula: ++ * ++ * .. 
math:: ++ * K(\boldsymbol{x}, \boldsymbol{x'}) = \ ++ * \tanh(\gamma \boldsymbol{x}^T \boldsymbol{x'} + c) ++ */ ++message SigmoidKernel { ++ double gamma = 1; ++ double c = 2; ++} ++ ++/** ++ * A kernel. ++ */ ++message Kernel { ++ oneof kernel { ++ LinearKernel linearKernel = 1; ++ RBFKernel rbfKernel = 2; ++ PolyKernel polyKernel = 3; ++ SigmoidKernel sigmoidKernel = 4; ++ } ++} ++ ++ ++/// Support Vector Definitions ++/// -------------------------- ++ ++/** ++ * A sparse node. ++ */ ++message SparseNode { ++ int32 index = 1; // 1-based indexes, like libsvm ++ double value = 2; ++} ++ ++/** ++ * A sparse vector. ++ */ ++message SparseVector { ++ repeated SparseNode nodes = 1; ++} ++ ++/** ++ * One or more sparse support vectors. ++ */ ++message SparseSupportVectors { ++ repeated SparseVector vectors = 1; ++} ++ ++/** ++ * A dense vector. ++ */ ++message DenseVector { ++ repeated double values = 1; ++} ++ ++/** ++ * One or more dense support vectors. ++ */ ++message DenseSupportVectors { ++ repeated DenseVector vectors = 1; ++} ++ ++/** ++ * One or more coefficients. ++ */ ++message Coefficients { ++ repeated double alpha = 1; ++} ++ ++/** ++ * A support vector regressor. ++ */ ++message SupportVectorRegressor { ++ Kernel kernel = 1; ++ ++ // Support vectors, either sparse or dense format ++ oneof supportVectors { ++ SparseSupportVectors sparseSupportVectors = 2; ++ DenseSupportVectors denseSupportVectors = 3; ++ } ++ ++ // Coefficients, one for each support vector ++ Coefficients coefficients = 4; ++ ++ double rho = 5; ++} ++ ++/** ++ * A support vector classifier ++ */ ++message SupportVectorClassifier { ++ Kernel kernel = 1; ++ ++ /** ++ * The number of support vectors for each class. ++ */ ++ repeated int32 numberOfSupportVectorsPerClass = 2; ++ ++ /** ++ * The support vectors, in either sparse or dense format. ++ */ ++ oneof supportVectors { ++ SparseSupportVectors sparseSupportVectors = 3; ++ DenseSupportVectors denseSupportVectors = 4; ++ } ++ ++ /** ++ * The coefficients, essentially a two dimensional array of ++ * size: (numberOfClasses-1) by (total number of support vectors) ++ */ ++ repeated Coefficients coefficients = 5; ++ ++ /** ++ * Constants for decision function, ++ * with K*(K-1) / 2 elements, ++ * where K is the number of classes. ++ */ ++ repeated double rho = 6; ++ ++ /** ++ * Pairwise probability information for A vs B classifier. ++ * Total of K*(K-1)/2 elements where K is the number of classes. ++ * These fields are optional, ++ * and only required if you want probabilities or multi class predictions. ++ */ ++ repeated double probA = 7; ++ repeated double probB = 8; ++ ++ /** ++ * Class label mapping. ++ */ ++ oneof ClassLabels { ++ StringVector stringClassLabels = 100; ++ Int64Vector int64ClassLabels = 101; ++ } ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto +new file mode 100644 +index 000000000..f0e13d54b +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto +@@ -0,0 +1,34 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification; ++ ++/** ++ * A scaling operation. ++ * ++ * This function has the following formula: ++ * ++ * .. 
math:: ++ * f(x) = scaleValue \cdot (x + shiftValue) ++ * ++ * If the ``scaleValue`` is not given, the default value 1 is used. ++ * If the ``shiftValue`` is not given, the default value 0 is used. ++ * ++ * If ``scaleValue`` and ``shiftValue`` are each a single value ++ * and the input is an array, then the scale and shift are applied ++ * to each element of the array. ++ * ++ * If the input is an integer, then it is converted to a double to ++ * perform the scaling operation. If the output type is an integer, ++ * then it is cast to an integer. If that cast is lossy, then an ++ * error is generated. ++ */ ++message Scaler { ++ repeated double shiftValue = 1; ++ repeated double scaleValue = 2; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto b/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto +new file mode 100644 +index 000000000..05bb744a9 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto +@@ -0,0 +1,60 @@ ++// Copyright (c) 2019, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification.CoreMLModels; ++ ++/** ++* A model which takes audio signal samples as input and outputs an array of ++* preprocessed samples according to the specified preprocessing types ++*/ ++message SoundAnalysisPreprocessing { ++ ++ // Specific preprocessing types for sound analysis ++ ++ /* Vggish preprocesses input audio samples and makes them ready to ++ be fed to Vggish feature extractor. ++ c.f. https://arxiv.org/pdf/1609.09430.pdf ++ ++ The preprocessing takes input a single channel (monophonic) audio samples ++ 975 miliseconds long, sampled at 16KHz, i.e., 15600 samples 1D multiarray ++ and produces preprocessed samples in multiarray of shape [1, 96, 64] ++ ++ (1) Splits the input audio samples into overlapping frames, where each ++ frame is 25 milliseconds long and hops forward by 10 milliseconds. ++ Any partial frames at the end are dropped. ++ ++ (2) Hann window: apply a periodic Hann with a window_length of ++ 25 milliseconds, which translates to 400 samples in 16KHz sampling rate ++ ++ w(n) = 0.5 - 0.5 * cos(2*pi*n/window_length_sample), ++ where 0 <= n <= window_lenth_samples - 1 and window_lenth_samples = 400 ++ ++ Then, the Hann window is applied to each frame as below ++ ++ windowed_frame(n) = frame(n) * w(n) ++ where 0 <= n <= window_lenth_samples - 1 and window_lenth_samples = 400 ++ ++ (3) Power spectrum: calculate short-time Fourier transfor magnitude, with ++ an FFT length of 512 ++ ++ (4) Log Mel filter bank: calculates a log magnitude mel-frequency ++ spectrogram minimum frequency of 125Hz and maximum frequency of 7500Hz, ++ number of mel bins is 64, log_offset is 0.01, number of spectrum bins ++ is 64. ++ */ ++ ++ message Vggish { ++ // no specific parameter ++ } ++ ++ // Vision feature print type ++ oneof SoundAnalysisPreprocessingType { ++ Vggish vggish = 20; ++ } ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto b/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto +new file mode 100644 +index 000000000..bf6d3c7f7 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto +@@ -0,0 +1,43 @@ ++// Copyright (c) 2018, Apple Inc. 
All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification.CoreMLModels; ++ ++/** ++* A model which takes a single input string and outputs a ++* label for the input. ++*/ ++message TextClassifier { ++ ++ /* ++ * Stores the resivion number for the model, revision 1 is available on ++ * iOS, tvOS 12.0+, macoOS 10.14+ ++ */ ++ uint32 revision = 1; ++ ++ /* ++ * Stores the language of the model, as specified in BCP-47 format, ++ * e.g. "en-US". See https://tools.ietf.org/html/bcp47 ++ */ ++ string language = 10; ++ ++ /* ++ * Stores the byte representation of learned model parameters ++ */ ++ bytes modelParameterData = 100; ++ ++ /* ++ * Stores the set of output class labels ++ */ ++ oneof ClassLabels { ++ StringVector stringClassLabels = 200; ++ } ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto b/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto +new file mode 100644 +index 000000000..defebee98 +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto +@@ -0,0 +1,161 @@ ++// Copyright (c) 2017, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++/** ++ * Each tree is a collection of nodes, ++ * each of which is identified by a unique identifier. ++ * ++ * Each node is either a branch or a leaf node. ++ * A branch node evaluates a value according to a behavior; ++ * if true, the node identified by ``true_child_node_id`` is evaluated next, ++ * if false, the node identified by ``false_child_node_id`` is evaluated next. ++ * A leaf node adds the evaluation value to the base prediction value ++ * to get the final prediction. ++ * ++ * A tree must have exactly one root node, ++ * which has no parent node. ++ * A tree must not terminate on a branch node. ++ * All leaf nodes must be accessible ++ * by evaluating one or more branch nodes in sequence, ++ * starting from the root node. ++ */ ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification; ++ ++/** ++ * A tree ensemble post-evaluation transform. ++ */ ++enum TreeEnsemblePostEvaluationTransform { ++ NoTransform = 0; ++ Classification_SoftMax = 1; ++ Regression_Logistic = 2; ++ Classification_SoftMaxWithZeroClassReference = 3; ++} ++ ++/** ++ * Tree ensemble parameters. ++ */ ++message TreeEnsembleParameters { ++ message TreeNode { ++ uint64 treeId = 1; ++ uint64 nodeId = 2; ++ ++ enum TreeNodeBehavior { ++ BranchOnValueLessThanEqual = 0; ++ BranchOnValueLessThan = 1; ++ BranchOnValueGreaterThanEqual = 2; ++ BranchOnValueGreaterThan = 3; ++ BranchOnValueEqual = 4; ++ BranchOnValueNotEqual = 5; ++ LeafNode = 6; ++ } ++ ++ /** ++ * The branch mode parameters. ++ * ++ * If branch is false, ++ * then the parameters in this section must be filled in ++ * to determine how the branching functions. ++ */ ++ TreeNodeBehavior nodeBehavior = 3; ++ ++ /** ++ * If the node behavior mode is a branch mode, ++ * then these values must be filled in. 
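++ *
++ * e.g. (an illustrative reading, with hypothetical values): with nodeBehavior set to
++ * BranchOnValueLessThanEqual, branchFeatureIndex = 2 and branchFeatureValue = 0.5,
++ * evaluation proceeds to the node given by trueChildNodeId when the feature at
++ * index 2 is <= 0.5, and to falseChildNodeId otherwise.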
++ */ ++ uint64 branchFeatureIndex = 10; ++ double branchFeatureValue = 11; ++ uint64 trueChildNodeId = 12; ++ uint64 falseChildNodeId = 13; ++ bool missingValueTracksTrueChild = 14; ++ ++ /** ++ * The leaf mode. ++ * ++ * If ``nodeBahavior`` == ``LeafNode``, ++ * then the evaluationValue is added to the base prediction value ++ * in order to get the final prediction. ++ * To support multiclass classification ++ * as well as regression and binary classification, ++ * the evaluation value is encoded here as a sparse vector, ++ * with evaluationIndex being the index of the base vector ++ * that evaluation value is added to. ++ * In the single class case, ++ * it is expected that evaluationIndex is exactly 0. ++ */ ++ message EvaluationInfo { ++ uint64 evaluationIndex = 1; ++ double evaluationValue = 2; ++ } ++ ++ repeated EvaluationInfo evaluationInfo = 20; ++ ++ /** ++ * The relative hit rate of a node for optimization purposes. ++ * ++ * This value has no effect on the accuracy of the result; ++ * it allows the tree to optimize for frequent branches. ++ * The value is relative, ++ * compared to the hit rates of other branch nodes. ++ * ++ * You typically use a proportion of training samples ++ * that reached this node ++ * or some similar metric to derive this value. ++ */ ++ double relativeHitRate = 30; ++ } ++ ++ repeated TreeNode nodes = 1; ++ ++ /** ++ * The number of prediction dimensions or classes in the model. ++ * ++ * All instances of ``evaluationIndex`` in a leaf node ++ * must be less than this value, ++ * and the number of values in the ``basePredictionValue`` field ++ * must be equal to this value. ++ * ++ * For regression, ++ * this is the dimension of the prediction. ++ * For classification, ++ * this is the number of classes. ++ */ ++ uint64 numPredictionDimensions = 2; ++ ++ /** ++ * The base prediction value. ++ * ++ * The number of values in this must match ++ * the default values of the tree model. ++ */ ++ repeated double basePredictionValue = 3; ++} ++ ++/** ++ * A tree ensemble classifier. ++ */ ++message TreeEnsembleClassifier { ++ TreeEnsembleParameters treeEnsemble = 1; ++ TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2; ++ ++ // Required class label mapping ++ oneof ClassLabels { ++ StringVector stringClassLabels = 100; ++ Int64Vector int64ClassLabels = 101; ++ } ++} ++ ++/** ++ * A tree ensemble regressor. ++ */ ++message TreeEnsembleRegressor { ++ TreeEnsembleParameters treeEnsemble = 1; ++ TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2; ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto b/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto +new file mode 100644 +index 000000000..cd13d290e +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto +@@ -0,0 +1,63 @@ ++// Copyright (c) 2018, Apple Inc. All rights reserved. 
++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++package CoreML.Specification.CoreMLModels; ++ ++/** ++* A model which takes an input image and outputs array(s) of features ++* according to the specified feature types ++*/ ++message VisionFeaturePrint { ++ ++ // Specific vision feature print types ++ ++ // Scene extracts features useful for identifying contents of natural images ++ // in both indoor and outdoor environments ++ message Scene { ++ enum SceneVersion { ++ SCENE_VERSION_INVALID = 0; ++ // VERSION_1 is available on iOS,tvOS 12.0+, macOS 10.14+ ++ // It uses a 299x299 input image and yields a 2048 float feature vector ++ SCENE_VERSION_1 = 1; ++ } ++ ++ SceneVersion version = 1; ++ } ++ ++ // Objects extracts features useful for identifying and localizing ++ // objects in natural images ++ message Objects { ++ enum ObjectsVersion { ++ OBJECTS_VERSION_INVALID = 0; ++ // VERSION_1 is available on iOS,tvOS 14.0+, macOS 11.0+ ++ // It uses a 299x299 input image and yields two multiarray ++ // features: one at high resolution of shape (288, 35, 35) ++ // the other at low resolution of shape (768, 17, 17) ++ OBJECTS_VERSION_1 = 1; ++ } ++ ++ ObjectsVersion version = 1; ++ ++ /* ++ * Stores the names of the output features according to the ++ * order of them being computed from the neural network, i.e., ++ * the first element in the output is the earliest being ++ * computed, while the last is the latest being computed. In ++ * general, the order reflects the resolution of the feature. ++ * The earlier it is computed, the higher the feature resolution. ++ */ ++ repeated string output = 100; ++ } ++ ++ // Vision feature print type ++ oneof VisionFeaturePrintType { ++ Scene scene = 20; ++ Objects objects = 21; ++ } ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto b/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto +new file mode 100644 +index 000000000..ec11a67ca +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto +@@ -0,0 +1,35 @@ ++// Copyright (c) 2019, Apple Inc. All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification.CoreMLModels; ++ ++/** ++* A model which maps a set of strings into a finite-dimensional real vector space. ++*/ ++message WordEmbedding { ++ ++ /* ++ * Stores the revision number for the model, revision 2 is available on ++ * iOS, tvOS 13.0+, macOS 10.15+ ++ */ ++ uint32 revision = 1; ++ ++ /* ++ * Stores the language of the model, as specified in BCP-47 format, ++ * e.g. "en-US". See https://tools.ietf.org/html/bcp47 ++ */ ++ string language = 10; ++ ++ /* ++ * Stores efficient representation of emebedding as encoded by the Natural Language Framework ++ */ ++ bytes modelParameterData = 100; ++ ++} +diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto b/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto +new file mode 100644 +index 000000000..8523e05df +--- /dev/null ++++ b/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto +@@ -0,0 +1,75 @@ ++// Copyright (c) 2018, Apple Inc. 
All rights reserved. ++// ++// Use of this source code is governed by a BSD-3-clause license that can be ++// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause ++ ++syntax = "proto3"; ++option optimize_for = LITE_RUNTIME; ++ ++import public "DataStructures.proto"; ++ ++package CoreML.Specification.CoreMLModels; ++ ++/** ++* A model which takes a single input string and outputs a ++* sequence of tokens, tags for tokens, along with their ++* locations and lengths, in the original string. ++*/ ++message WordTagger { ++ ++ /* ++ * Stores the resivion number for the model, revision 1 is available on ++ * iOS, tvOS 12.0+, macoOS 10.14+ ++ */ ++ uint32 revision = 1; ++ ++ /* ++ * Stores the language of the model, as specified in BCP-47 format, ++ * e.g. "en-US". See https://tools.ietf.org/html/bcp47 ++ */ ++ string language = 10; ++ ++ /* ++ * Stores the name of tokens output. The output will be ++ * a sequence of strings that contains the tokens in the ++ * input string ++ */ ++ string tokensOutputFeatureName = 20; ++ ++ /* ++ * Stores the name of token tags output. The output will be ++ * a sequence of strings that contains the tags for each ++ * token in the input string ++ */ ++ string tokenTagsOutputFeatureName = 21; ++ ++ /* ++ * Stores the name of token locations output. The output will be ++ * a sequence of integers that contains the locations (indices) ++ * for each token in the input string, location starts from 0 ++ */ ++ string tokenLocationsOutputFeatureName = 22; ++ ++ /* ++ * Stores the name of token lengths output. The output will be ++ * a sequence of integers that contains the lengths for each ++ * token in the input string ++ */ ++ string tokenLengthsOutputFeatureName = 23; ++ ++ /* ++ * Stores the byte representation of learned model parameters ++ */ ++ bytes modelParameterData = 100; ++ ++ /* ++ * Stores the set of output tags ++ */ ++ oneof Tags { ++ StringVector stringTags = 200; ++ } ++ ++ ++ ++} ++ +diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +index bf73c59fb..9c55d37f5 100644 +--- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc ++++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +@@ -87,13 +87,7 @@ struct ProviderHostCPUImpl : ProviderHostCPU { + const TensorShape& indice_shape, + const TensorShape& update_shape) override { return ScatterND::ValidateShapes(input_shape, indice_shape, update_shape); } + // From cpu/tensor/padbase.h (direct) +- Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } +- +- void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, +- PadsVector& pads) override { +- PadBase::ComputePads(ctx, data_rank, pads_data, pads); +- } +- ++ Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) override { return PadBase::HandleDimValueZero(mode, input_shape, output_shape); } + // From cpu/tensor/split.h (direct) + Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, + int& after_dims_including_split_axis, int& after_dims_excluding_split, +diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h +index f33eec4b9..8dee1cd62 100644 +--- 
a/onnxruntime/core/providers/cpu/cpu_provider_shared.h ++++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h +@@ -25,8 +25,6 @@ class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Pr + class contrib__AdamWOptimizerBase__Prepare; + class contrib__SGDOptimizerV2Base__Prepare; + +-using PadsVector = InlinedVector; +- + struct ProviderHostCPU { + // From cpu/tensor/gatherbase.h + virtual Status GatherBase__PrepareForCompute(const GatherBase* p, OpKernelContext* context, GatherBase__Prepare& prepare) = 0; +@@ -46,11 +44,7 @@ struct ProviderHostCPU { + const TensorShape& indice_shape, + const TensorShape& update_shape) = 0; + // From cpu/tensor/padbase.h +- virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) = 0; +- +- virtual void PadBase__ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, +- PadsVector& pads) = 0; +- ++ virtual Status PadBase__HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) = 0; + // From cpu/tensor/split.h + virtual Status SplitBase__PrepareForCompute(const SplitBase* p, const TensorShape& input_shape, int num_outputs, int64_t& axis, int& before_dims, + int& after_dims_including_split_axis, int& after_dims_excluding_split, +diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc +index 912280687..fe5267f20 100644 +--- a/onnxruntime/core/providers/cpu/tensor/pad.cc ++++ b/onnxruntime/core/providers/cpu/tensor/pad.cc +@@ -9,8 +9,6 @@ + #include "core/providers/op_kernel_type_control.h" + #include "core/util/math.h" + +-#include +- + // there's no way to use a raw pointer as the copy destination with std::copy_n + // (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset + // without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. +@@ -169,7 +167,47 @@ ONNX_CPU_OPERATOR_KERNEL( + + using PadsVector = PadBase::PadsVector; + +-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { ++// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) ++template ++static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, ++ size_t block_size, size_t block_count) { ++ for (size_t block_index = 0; block_index < block_count; block_index++) { ++ for (size_t i = 0; i < block_size; i++) { ++ *output++ = *input; ++ input += input_delta; ++ } ++ input += input_pitch; ++ } ++} ++ ++// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, ++// and inputPitch and inputDelta are just a single value added each iteration. ++template ++static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { ++ for (size_t block_index = 0; block_index < block_count; block_index++) { ++ *output++ = *input; ++ input += input_delta; ++ } ++} ++ ++// For constant padding, there is no input, just a size to write the constant to ++template ++static void PadAxisConstant(T* output, T constant, size_t size) { ++ if (size == 1) { ++ *output = constant; ++ } else if (size == 2) { ++ *output = constant; ++ *(output + 1) = constant; ++ } else { ++ // This would be faster with SSE instructions. ++ // That would mean to have an implementation for each type (uint8, uint32, uint64). 
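The PadAxis/PadInnermostAxis helpers above express edge and reflection padding purely as a walk over existing elements controlled by input_delta. A small standalone illustration of that idea follows; only PadInnermostAxis is copied from the hunk (lightly adapted), and the driver, the chosen starting pointers, and the delta values are assumptions made for this example, not the kernel's actual call sites:

#include <cstddef>
#include <cstdio>

// Copied from the hunk above (lightly adapted): emit block_count elements, stepping the
// input pointer by input_delta after each one.
template <typename T>
static void PadInnermostAxis(T* output, T* input, std::ptrdiff_t input_delta, std::size_t block_count) {
  for (std::size_t block_index = 0; block_index < block_count; block_index++) {
    *output++ = *input;
    input += input_delta;
  }
}

int main() {
  int row[4] = {1, 2, 3, 4};
  int padded[8] = {};  // room for 2 pad elements on each side of the 4-element row
  // Left reflection pad: start just inside the border (row[2] == 3) and walk backwards -> {3, 2}.
  PadInnermostAxis(padded, row + 2, /*input_delta=*/-1, 2);
  // Copy the body unchanged.
  for (int i = 0; i < 4; ++i) padded[2 + i] = row[i];
  // Right reflection pad, again walking backwards from just inside the border -> {3, 2}.
  PadInnermostAxis(padded + 6, row + 2, /*input_delta=*/-1, 2);
  for (int v : padded) std::printf("%d ", v);  // 3 2 1 2 3 4 3 2
  std::printf("\n");
}

With input_delta = 0 the same source element would simply be repeated, which is how edge padding falls out of the same helper.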
++ T* end = output + size; ++ for (; output != end;) ++ *output++ = constant; ++ } ++} ++ ++Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { + switch (mode) { + case Mode::Constant: { + // default behavior is fine +@@ -204,66 +242,34 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh + return Status::OK(); + } + +-static void ComputePadWithAxes( +- gsl::span pads_tensor_raw_data, +- std::function get_axis, +- size_t axes_size, +- size_t data_rank, +- PadsVector& pads) { +- for (size_t i = 0; i < axes_size; ++i) { +- const size_t axis = onnxruntime::narrow(HandleNegativeAxis(get_axis(i), data_rank)); +- pads[axis] = pads_tensor_raw_data[i]; // xi_begin +- pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end +- } +-} ++// special handling for edge case where the input has one or more dims with value of 0 ++template ++static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, ++ const Mode& mode, ++ const TensorShape& input_shape, ++ TensorShapeVector& output_dims, ++ T value) { ++ TensorShape output_shape(output_dims); ++ ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); + +-void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, +- PadsVector& pads) { +- pads.reserve(2 * data_rank); +- const Tensor* axes_tensor = ctx.Input(3); +- if (axes_tensor) { +- const size_t num_axes_dims = axes_tensor->Shape().NumDimensions(); +- ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor "); +- +- const int64_t num_axes = axes_tensor->Shape().Size(); +- ORT_ENFORCE(pads_data.size() == narrow(2 * num_axes), +- "Pads tensor size should be equal to twice the number of explicitly provided axes."); +- +- pads.resize(2 * data_rank, 0); +- if (axes_tensor->IsDataType()) { +- auto axes_data = axes_tensor->DataAsSpan(); +- ComputePadWithAxes( +- pads_data, +- [axes_data](size_t idx) -> int64_t { +- return axes_data[idx]; +- }, +- axes_data.size(), +- data_rank, +- pads); +- } else if (axes_tensor->IsDataType()) { +- auto axes_data = axes_tensor->DataAsSpan(); +- ComputePadWithAxes( +- pads_data, +- [axes_data](size_t idx) { +- return axes_data[idx]; +- }, +- axes_data.size(), +- data_rank, +- pads); +- } +- } else { +- ORT_ENFORCE(pads_data.size() == 2 * data_rank, +- "Pads tensor size should be equal to twice the input dimension count "); +- pads.assign(pads_data.begin(), pads_data.end()); ++ auto& output_tensor = *ctx->Output(0, output_shape); ++ ++ // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty ++ if (mode == Mode::Constant) { ++ // we add pads with the default value to all dims including those with a value of 0 ++ auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); ++ std::fill_n(output, output_shape.Size(), value); + } ++ ++ return Status::OK(); + } + + // Flatten no padding inner most Axis, so one memcpy cover multiple Axis. + // For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as + // [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. 
+-void PadBase::FlattenInnerShape(gsl::span input_dims, gsl::span pads, +- gsl::span slices, TensorShapeVector& reshaped_dims) { +- const size_t dims_count = input_dims.size(); ++static void FlattenInnerShape(const TensorShapeVector& input_dims, const PadsVector& pads, ++ const PadsVector& slices, TensorShapeVector& reshaped_dims) { ++ size_t dims_count = input_dims.size(); + size_t inner_axis = dims_count - 1; + size_t inner_size = 1; + +@@ -282,14 +288,14 @@ void PadBase::FlattenInnerShape(gsl::span input_dims, gsl::span 0); + + reshaped_dims.reserve(inner_axis + 1); +- std::copy(input_dims.begin(), input_dims.begin() + inner_axis + 1, std::back_inserter(reshaped_dims)); ++ std::copy(input_dims.cbegin(), input_dims.cbegin() + inner_axis + 1, std::back_inserter(reshaped_dims)); + + // Flatten inner axis. + reshaped_dims[inner_axis] = inner_size; + } + +-void PadBase::ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, +- size_t inner_no_pad_size, PadsVector& reshaped_pad) { ++static void ReshapePads(const PadsVector& src_pad, size_t src_dim_count, size_t new_dim_count, ++ size_t inner_no_pad_size, PadsVector& reshaped_pad) { + size_t inner_axis = new_dim_count - 1; + std::copy(src_pad.begin(), src_pad.begin() + inner_axis, reshaped_pad.begin()); + std::copy(src_pad.begin() + src_dim_count, src_pad.begin() + src_dim_count + inner_axis, +@@ -300,68 +306,6 @@ void PadBase::ReshapePads(gsl::span src_pad, size_t src_dim_count + reshaped_pad[inner_axis + new_dim_count] = src_pad[inner_axis + src_dim_count] * inner_no_pad_size; + } + +-// special handling for edge case where the input has one or more dims with value of 0 +-template +-static Status PadInputWithDimValueOfZero(OpKernelContext* ctx, +- const Mode& mode, +- const TensorShape& input_shape, +- TensorShapeVector& output_dims, +- T value) { +- TensorShape output_shape(output_dims); +- ORT_RETURN_IF_ERROR(PadBase::HandleDimValueZero(mode, input_shape, output_shape)); +- +- auto& output_tensor = *ctx->Output(0, output_shape); +- +- // we need to add pads if mode is constant, otherwise the output has one or more dim values of 0 so is empty +- if (mode == Mode::Constant) { +- // we add pads with the default value to all dims including those with a value of 0 +- auto* output = reinterpret_cast(output_tensor.MutableDataRaw()); +- std::fill_n(output, output_shape.Size(), value); +- } +- +- return Status::OK(); +-} +- +-// This is the general padding method to n-dimensionally do edge or reflection padding (based on the inputDelta values) +-template +-static void PadAxis(T* output, T* input, ptrdiff_t input_delta, ptrdiff_t input_pitch, +- size_t block_size, size_t block_count) { +- for (size_t block_index = 0; block_index < block_count; block_index++) { +- for (size_t i = 0; i < block_size; i++) { +- *output++ = *input; +- input += input_delta; +- } +- input += input_pitch; +- } +-} +- +-// These are optimizations of PadAxis. The inner loop is removed since the innermost axis has a blockSize of 1, +-// and inputPitch and inputDelta are just a single value added each iteration. 
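FlattenInnerShape above folds unpadded innermost axes into a single axis so that one contiguous copy can cover them, exactly as the [1,224,224,3] example in the preceding comment describes. The standalone sketch below shows only the dimension-folding part; it uses plain vectors, ignores the slices check the real helper also performs, and FlattenUnpaddedInnerAxes is a hypothetical name invented for illustration:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Fold innermost axes whose begin and end pads are both zero into one axis.
static void FlattenUnpaddedInnerAxes(const std::vector<int64_t>& dims,
                                     const std::vector<int64_t>& pads,  // size 2*rank: begins, then ends
                                     std::vector<int64_t>& flattened) {
  const std::size_t rank = dims.size();
  std::size_t inner_axis = rank - 1;
  int64_t inner_size = 1;
  while (inner_axis > 0 && pads[inner_axis] == 0 && pads[rank + inner_axis] == 0) {
    inner_size *= dims[inner_axis];
    --inner_axis;
  }
  inner_size *= dims[inner_axis];
  flattened.assign(dims.begin(), dims.begin() + inner_axis + 1);
  flattened[inner_axis] = inner_size;
}

int main() {
  std::vector<int64_t> dims = {1, 224, 224, 3};          // NHWC
  std::vector<int64_t> pads = {0, 3, 3, 0, 0, 3, 3, 0};  // pad H and W by 3, N and C untouched
  std::vector<int64_t> flattened;
  FlattenUnpaddedInnerAxes(dims, pads, flattened);
  for (int64_t d : flattened) std::printf("%lld ", (long long)d);  // 1 224 672
  std::printf("\n");
}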
+-template +-static void PadInnermostAxis(T* output, T* input, ptrdiff_t input_delta, size_t block_count) { +- for (size_t block_index = 0; block_index < block_count; block_index++) { +- *output++ = *input; +- input += input_delta; +- } +-} +- +-// For constant padding, there is no input, just a size to write the constant to +-template +-static void PadAxisConstant(T* output, T constant, size_t size) { +- if (size == 1) { +- *output = constant; +- } else if (size == 2) { +- *output = constant; +- *(output + 1) = constant; +- } else { +- // This would be faster with SSE instructions. +- // That would mean to have an implementation for each type (uint8, uint32, uint64). +- T* end = output + size; +- for (; output != end;) +- *output++ = constant; +- } +-} +- + template + static Status PadImpl(OpKernelContext* ctx, + const PadsVector& pads, +@@ -383,7 +327,7 @@ static Status PadImpl(OpKernelContext* ctx, + + // Reshape input dims + TensorShapeVector reshaped_input_dims; +- PadBase::FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); ++ FlattenInnerShape(output_dims, pads, slices, reshaped_input_dims); + + // Reshape padding + size_t new_dims_count = reshaped_input_dims.size(); +@@ -392,8 +336,8 @@ static Status PadImpl(OpKernelContext* ctx, + ? reshaped_input_dims[inner_axis] / output_dims[inner_axis] + : 0); + PadsVector reshaped_pad(2 * new_dims_count), reshaped_slice(2 * new_dims_count); +- PadBase::ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); +- PadBase::ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); ++ ReshapePads(pads, data_rank, new_dims_count, inner_no_pad_size, reshaped_pad); ++ ReshapePads(slices, data_rank, new_dims_count, inner_no_pad_size, reshaped_slice); + + TensorShapeVector reshaped_output_dims = reshaped_input_dims; + TensorShapeVector input_starts; +@@ -631,6 +575,20 @@ static PadValue PadValueFromFloat(float value, MLDataType data_type) { + return result; + } + ++template ++void ComputePadWithAxes( ++ gsl::span pads_tensor_raw_data, ++ gsl::span axes_tensor_raw_data, ++ size_t data_rank, ++ PadsVector& pads) { ++ size_t axes_size = axes_tensor_raw_data.size(); ++ for (size_t i = 0; i < axes_size; ++i) { ++ int64_t axis = HandleNegativeAxis(onnxruntime::narrow(axes_tensor_raw_data[i]), data_rank); ++ pads[onnxruntime::narrow(axis)] = pads_tensor_raw_data[i]; // xi_begin ++ pads[data_rank + onnxruntime::narrow(axis)] = pads_tensor_raw_data[axes_size + i]; // xi_end ++ } ++} ++ + Status Pad::Compute(OpKernelContext* ctx) const { + const Tensor& input_tensor = *ctx->Input(0); + MLDataType data_type = input_tensor.DataType(); +@@ -650,14 +608,48 @@ Status Pad::Compute(OpKernelContext* ctx) const { + ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), + "Pads tensor should be a 1D tensor of shape [2 * num_axes] " + "or a 2D tensor of shape [1, 2 * num_axes]"); +- +- const auto pads_data = pads_tensor.DataAsSpan(); +- +- // Compute Pads by applying axes if specified otherwise copy the supplied pads. 
+- PadBase::ComputePads(*ctx, data_rank, pads_data, pads); ++ const int64_t* pads_tensor_raw_data = pads_tensor.Data(); ++ size_t pads_size = static_cast(pads_tensor.Shape().Size()); ++ pads.reserve(2 * data_rank); ++ ++ const Tensor* axes_tensor = ctx->Input(3); ++ if (axes_tensor) { ++ const auto& axes_tensor_dims = axes_tensor->Shape().GetDims(); ++ ORT_ENFORCE(axes_tensor_dims.size() == 1, "Axes tensor should be a 1D tensor "); ++ int64_t axes_size = axes_tensor_dims[0]; ++ ++ pads.resize(2 * data_rank, 0); ++ if (axes_tensor->IsDataType()) { ++ const int32_t* axes_tensor_raw_data = axes_tensor->Data(); ++ ComputePadWithAxes( ++ {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, ++ {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, ++ data_rank, ++ pads); ++ } else if (axes_tensor->IsDataType()) { ++ const int64_t* axes_tensor_raw_data = axes_tensor->Data(); ++ ComputePadWithAxes( ++ {pads_tensor_raw_data, onnxruntime::narrow(2 * axes_size)}, ++ {axes_tensor_raw_data, onnxruntime::narrow(axes_size)}, ++ data_rank, ++ pads); ++ } ++ } else { ++ ORT_ENFORCE(pads_size == 2 * data_rank, ++ "Pads tensor size should be equal to twice the input dimension count "); ++ for (size_t i = 0; i < pads_size; ++i) { ++ pads.push_back(pads_tensor_raw_data[i]); ++ } ++ } + + // Separate out any negative pads into the slices array +- PadBase::SeparateNegativeToSlices(pads, slices); ++ slices.assign(pads.size(), 0); ++ for (size_t index = 0; index < pads.size(); index++) { ++ if (pads[index] < 0) { ++ slices[index] = pads[index]; ++ pads[index] = 0; ++ } ++ } + + value.u64 = 0U; + const Tensor* value_tensor = ctx->Input(2); +diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h +index 43f9cbfc9..d869ed1a6 100644 +--- a/onnxruntime/core/providers/cpu/tensor/padbase.h ++++ b/onnxruntime/core/providers/cpu/tensor/padbase.h +@@ -19,80 +19,9 @@ class PadBase { + // Pads and slices are usually about twice the shapes involved + using PadsVector = InlinedVector; + +- // The following several functions are shared among the providers +- +- /// +- /// Handle the case when the input shape has zero dim values. +- /// Depending on the mode, the input dim with zero value must match the output dim value. +- /// +- /// +- /// Padding mode enum value +- /// actual input shape +- /// output_shape +- /// Error if current mode padding can not be achieved with zero dim values +- static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape); +- +- /// +- /// Compute Pads by applying axes if specified otherwise copy the supplied pads. +- /// +- /// The function queries optional axes input (since version 18) and if present, +- /// applies it as a mask to the pads. If axes is not present, the pads are copied as is. +- /// If axes are present, they are used as a mask over pads, so only those axes are being padded. +- /// +- /// kernel context to query axes input +- /// input rank +- /// pads data from pads input +- /// resulting pads +- static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, +- PadsVector& pads); +- +- /// +- /// Separates negative pad values to slices and zeros them out in original pads. +- /// Leaving the rest of slices values as zero. +- /// +- /// This function is used inline in the Pad CUDA implementation and is not exposed via a provider +- /// interfaces. 
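Both variants of the pads handling above implement the same opset-18 rule: when the optional axes input is present, the supplied begin/end values are scattered into a full-rank pads vector at pads[axis] and pads[rank + axis], with negative axes counted from the back and all other axes left at zero. A standalone sketch of that expansion (ExpandPadsWithAxes is a hypothetical helper invented for this example):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Expand per-axis begin/end pads (for the listed axes only) into a full-rank pads vector.
static std::vector<int64_t> ExpandPadsWithAxes(const std::vector<int64_t>& pads_data,  // [b0..bk, e0..ek]
                                               const std::vector<int64_t>& axes,
                                               std::size_t rank) {
  std::vector<int64_t> pads(2 * rank, 0);
  const std::size_t axes_size = axes.size();
  for (std::size_t i = 0; i < axes_size; ++i) {
    int64_t axis = axes[i];
    if (axis < 0) axis += static_cast<int64_t>(rank);  // negative axes count from the back
    pads[static_cast<std::size_t>(axis)] = pads_data[i];                     // xi_begin
    pads[rank + static_cast<std::size_t>(axis)] = pads_data[axes_size + i];  // xi_end
  }
  return pads;
}

int main() {
  // Pad only axes 1 and 3 of a rank-4 input: begins {1, 2}, ends {3, 4}.
  const auto pads = ExpandPadsWithAxes({1, 2, 3, 4}, {1, 3}, 4);
  for (int64_t p : pads) std::printf("%lld ", (long long)p);  // 0 1 0 2 0 3 0 4
  std::printf("\n");
}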
+- /// +- /// pad values +- /// slices output +- static void SeparateNegativeToSlices(gsl::span pads, PadsVector& slices) { +- slices.assign(pads.size(), 0); +- for (size_t index = 0, lim = pads.size(); index < lim; index++) { +- if (pads[index] < 0) { +- slices[index] = pads[index]; +- pads[index] = 0; +- } +- } +- } +- +- // End provider shared +- +- /// +- /// Flatten no padding inner most Axis, so one memcpy cover multiple Axis. +- /// For example, for a shape of [1,224,224,3] with padding [0,3,3,0,0,3,3,0], can be flatten as +- /// [1,224,224*3] with padding [0,3,3*3,0,3,3*3]. +- /// +- /// This is a helper function pads are expected to be twice the rank +- /// +- /// original input dims +- /// pad values +- /// slices +- /// result dims +- static void FlattenInnerShape(gsl::span input_dims, gsl::span pads, +- gsl::span slices, TensorShapeVector& reshaped_dims); +- +- /// +- /// Used after the inner shape is flattened, so we can apply this function to pads and slices +- /// to reshape them as well. +- /// +- /// pads +- /// original dim count +- /// expected flattended dim count +- /// is the left most dimension that was flattened. +- /// In the example above, that would be 224, reverse computed from 224*3 +- /// resulting reshaped pads or slices +- static void ReshapePads(gsl::span src_pad, size_t src_dim_count, size_t new_dim_count, +- size_t inner_no_pad_size, PadsVector& reshaped_pad); ++ // Update the output_shape to make it consistent with numpy handling where there are one or more dimensions ++ // in the input_shape with a value of zero. ++ static Status HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape); + + protected: + PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) { +diff --git a/onnxruntime/core/providers/cpu/tensor/scatter.cc b/onnxruntime/core/providers/cpu/tensor/scatter.cc +index c7a200592..8844b7e7a 100644 +--- a/onnxruntime/core/providers/cpu/tensor/scatter.cc ++++ b/onnxruntime/core/providers/cpu/tensor/scatter.cc +@@ -198,6 +198,13 @@ struct Func_Min { + } + }; + ++template <> ++struct Func_Min { ++ void operator()(MLFloat16*, const MLFloat16*) const { ++ ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'min'."); ++ } ++}; ++ + template <> + struct Func_Min { + void operator()(BFloat16*, const BFloat16*) const { +@@ -226,6 +233,13 @@ struct Func_Max { + } + }; + ++template <> ++struct Func_Max { ++ void operator()(MLFloat16*, const MLFloat16*) const { ++ ORT_NOT_IMPLEMENTED("CPU execution provider: MLFloat16 data type is not supported with ScatterElements opset 18 when reduction is 'max'."); ++ } ++}; ++ + template <> + struct Func_Max { + void operator()(BFloat16*, const BFloat16*) const { +diff --git a/onnxruntime/core/providers/cuda/atomic/common.cuh b/onnxruntime/core/providers/cuda/atomic/common.cuh +index 170aa3a2d..14fa2d070 100644 +--- a/onnxruntime/core/providers/cuda/atomic/common.cuh ++++ b/onnxruntime/core/providers/cuda/atomic/common.cuh +@@ -122,316 +122,5 @@ __device__ __forceinline__ void AtomicAdd(half* start_addr, size_t index, + #endif + } + +-// Disable default template instantiation. +-// For every type T, we need to define a specialization +-// to select the right type for calling atomicCAS. 
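The SeparateNegativeToSlices helper removed from padbase.h above splits signed pad amounts into non-negative pads plus a parallel slices vector, which is exactly the inline loop the reverted Pad kernels keep. A tiny standalone illustration of that split on made-up values:

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  std::vector<long long> pads = {-1, 2, 0, 1};  // begin/end pads for a rank-2 tensor
  std::vector<long long> slices(pads.size(), 0);
  for (std::size_t i = 0; i < pads.size(); ++i) {
    if (pads[i] < 0) {
      slices[i] = pads[i];  // negative amounts become a slice of the input...
      pads[i] = 0;          // ...and only non-negative padding remains
    }
  }
  for (auto p : pads) std::printf("%lld ", p);    // 0 2 0 1
  std::printf("| ");
  for (auto s : slices) std::printf("%lld ", s);  // -1 0 0 0
  std::printf("\n");
}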
+-template +-class AtomicCasType; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned short int; +- static const unsigned int mask = 0xffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned short int; +- static const unsigned int mask = 0xffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned long long int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned long long int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +-// +-// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +-// It accumulate `val` into the `address` using the `func`. +-// The accumulation is atomic (i.e., thread-safe). +-// +-// E.g., Assume ValueType is +-// int8_t +-// and BinaryFunc is +-// struct AddFunc { +-// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +-// return a + b; +-// } +-// This function becomes atomic_add for int8_t. +-template +-__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { +- // Assert to ensure the following bit-wise manipulation is correct. +- static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, +- "ValueType must be 1-byte, 2-byte or 4-byte large."); +- // Number of bytes to the lower 4-byte aligned address. +- // If the current address is b1010"10", then offset = b10 = 2, +- // which means the current address is 2 bytes away from +- // the lower 4-byte aligned address b1010"00". +- size_t offset = (size_t)address & 3; +- // Find an new 4-byte aligned address `address_as_ui` lower than +- // or equal to `address`. Lower than `address` so that the actual +- // int8_t byte is in the 4-byte word that we load. +- // +- // This address has the following properties: +- // 1. It is 4-byte aligned. +- // 2. It is lower than or equal to `address`. +- // 3. De-referencing this address may return +- // a uint32_t value that contains the same int8_t +- // value indicated by `address`. +- // +- // E.g., +- // address = b101010 +- // offset = b101010 & b000011 = b10 = 2 +- // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", +- // which is (32-bit aligned). +- uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); +- uint32_t old = *address_as_ui; +- // E.g., offset = 2. +- // address_as_ui is an address 2 bytes lower than `address`. +- // +- // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... +- // ^ ^ ^ +- // | | | +- // | address <--- offset * 8 (bit)-----> address_as_ui +- // | ^ +- // | | +- // ------------------------- *address_as_ui ----------------------- +- // +- // This visualization shows +- // 1. the 32-bit word at address_as_ui. +- // 2. the gap between address_as_ui and address. +- // 3. *address_as_ui contains the int8_t value at `address`. +- uint32_t shift = offset * 8; +- uint32_t old_byte; +- uint32_t newval; +- uint32_t assumed; +- do { +- assumed = old; +- // Select 8-bit value from 32-bit word. 
Assume offset = 2 (byte), so +- // we want to select the 3rd byte (byte 2 below) from the word. +- // +- // Journey of a 32-bit value: +- // +- // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... +- // +- // | +- // | old >> offset * 8, where offset = 2. +- // | Effectively, push lower two bytes +- // | out of the word. +- // V +- // +- // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... +- // +- // | apply bit-wise AND, +- // | & 0xff (i.e., & b11111111), +- // | so that we only keep +- // | the byte of interest. +- // | Otherwise, overflow may +- // | happen when casting this +- // | 32-bit value to int8_t. +- // V +- // +- // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... +- old_byte = (old >> shift) & AtomicCasType::mask; +- // Compute new int8_t value and store it to newrawvalue. +- // Journey of a 32-bit value (cont'd): +- // +- // newrawvalue +- // ... new byte 2 ... +- auto newrawvalue = func(val, reinterpret_cast(old_byte)); +- // Put the new int8_t value back to 32-bit word. +- // Also ensure that bits not occupied by the int8_t value are 0s. +- // +- // Journey of a 32-bit value (cont'd): +- // +- // reinterpret_cast(newrawvalue) +- // random values | random values | random values | ... new byte 2 ... +- // +- // reinterpret_cast(newrawvalue) & AtomicCasType::mask +- // 00000000 | 00000000 | 00000000 | ... new byte 2 ... +- newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; +- // Journey of a 32-bit value (cont'd): +- // +- // old +- // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... +- // +- // 0x000000ff +- // 00000000 | 00000000 | 00000000 | 11111111 +- // +- // 0x000000ff << shift +- // 00000000 | 11111111 | 00000000 | 00000000 +- // +- // ~(0x000000ff << shift) +- // 11111111 | 00000000 | 11111111 | 11111111 +- // +- // old & ~(0x000000ff << shift) +- // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... +- // +- // newval << shift +- // 00000000 | ... new byte 2 ... | 00000000 | 00000000 +- // +- // (old & ~(0x000000ff << shift)) | (newval << shift) +- // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... +- newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); +- old = atomicCAS(address_as_ui, assumed, newval); +- } while (assumed != old); +-} +- +-// It accumulates `val` into the `address` using the `func`. +-// This function is thread-safe (i.e., atomic). +-template +-__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { +- ValueType observed = *address, assumed, new_value; +- using CasType = typename AtomicCasType::type; +- static_assert(sizeof(ValueType) == sizeof(CasType), +- "ValueType and CasType must have the same size for calling atomicCAS."); +- auto address_as_cas_type = reinterpret_cast(address); +- do { +- // Record the value used to compute new value. +- assumed = observed; +- +- // Compute expected new value. +- new_value = func(observed, val); +- +- // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. +- // 4 +- // 8 +- auto observed_as_cas_type = *reinterpret_cast(&observed); +- auto new_value_as_cas_type = *reinterpret_cast(&new_value); +- +- // Call atomicCAS as if the 2-byte type variables are all unsigned short int. 
+- // 4 unsigned int (or int) +- // 8 unsigned long long int +- auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); +- +- // Cast the freshly observed value in memory back to the TwoByteType. +- observed = *reinterpret_cast(&cas_observed_as_cas_type); +- +- // Two cases: +- // 1. compare-and-swap success +- // a. `address` holds `new_value` +- // b. `observed` becomes the new value after the assignment. +- // Thus, the following `observed != new_value` is false, +- // and the loop terminates. +- // 2. compare-and-swap fails +- // a. `address` holds a value different from `observed`, thus, +- // the `new_value` is stale. +- // b. `observed` becomes the fresh value observed in `address`. +- // Thus, the following (observed != new_value) is true, +- // and the loop continues. In the next iteration, the +- // `new_value` is computed again using the fresh `observed`. +- } while (observed != assumed); +-} +- +-struct AddFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return a + b; +- } +-}; +- +-struct MulFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return a * b; +- } +-}; +- +-struct MaxFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return b > a ? b : a; +- } +-}; +- +-struct MinFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return b < a ? b : a; +- } +-}; +- +-__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, AddFunc()); +-} +-__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +-} +- +-__device__ __forceinline__ void atomic_mul(half* address, half value) { +-#if __CUDA_ARCH__ >= 700 +- atomic_binary_func(address, value, MulFunc()); +-#else +- atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +-#endif +-} +-__device__ __forceinline__ void atomic_max(half* address, half value) { +-#if __CUDA_ARCH__ >= 700 +- atomic_binary_func(address, value, MaxFunc()); +-#else +- atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +-#endif +-} +-__device__ __forceinline__ void atomic_min(half* address, half value) { +-#if __CUDA_ARCH__ >= 700 +- atomic_binary_func(address, value, MinFunc()); +-#else +- atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +-#endif +-} +- +-__device__ __forceinline__ void atomic_mul(float* address, float value) { +- atomic_binary_func(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(float* address, float value) { +- atomic_binary_func(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(float* address, float value) { +- atomic_binary_func(address, value, MinFunc()); +-} +- +-__device__ __forceinline__ void atomic_mul(double* address, double value) { +- atomic_binary_func(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(double* address, double value) { +- atomic_binary_func(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(double* address, double value) { +- atomic_binary_func(address, value, 
MinFunc()); +-} +- +- + } // namespace cuda + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/cuda/cuda_call.cc b/onnxruntime/core/providers/cuda/cuda_call.cc +index f60684795..4f223041e 100644 +--- a/onnxruntime/core/providers/cuda/cuda_call.cc ++++ b/onnxruntime/core/providers/cuda/cuda_call.cc +@@ -30,7 +30,6 @@ const char* CudaErrString(cudaError_t x) { + return cudaGetErrorString(x); + } + +-#ifndef USE_CUDA_MINIMAL + template <> + const char* CudaErrString(cublasStatus_t e) { + cudaDeviceSynchronize(); +@@ -77,7 +76,6 @@ const char* CudaErrString(cufftResult e) { + return "Unknown cufft error status"; + } + } +-#endif + + #ifdef ORT_USE_NCCL + template <> +@@ -134,7 +132,6 @@ std::conditional_t CudaCall( + + template Status CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); + template void CudaCall(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line); +-#ifndef USE_CUDA_MINIMAL + template Status CudaCall(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line); + template void CudaCall(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line); + template Status CudaCall(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line); +@@ -143,7 +140,6 @@ template Status CudaCall(curandStatus_t retCode, const ch + template void CudaCall(curandStatus_t retCode, const char* exprString, const char* libName, curandStatus_t successCode, const char* msg, const char* file, const int line); + template Status CudaCall(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line); + template void CudaCall(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line); +-#endif + + #ifdef ORT_USE_NCCL + template Status CudaCall(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line); +diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc +index 65083f89f..33f293894 100644 +--- a/onnxruntime/core/providers/cuda/cuda_common.cc ++++ b/onnxruntime/core/providers/cuda/cuda_common.cc +@@ -14,27 +14,6 @@ namespace cuda { + // 0x04 - pedantic + constexpr const char* kCudaGemmOptions = "ORT_CUDA_GEMM_OPTIONS"; + +-const char* CudaDataTypeToString(cudaDataType_t dt) { +- switch (dt) { +- case CUDA_R_16F: +- return "CUDA_R_16F"; +- case CUDA_R_16BF: +- return "CUDA_R_16BF"; +- case CUDA_R_32F: +- return "CUDA_R_32F"; +-#if !defined(DISABLE_FLOAT8_TYPES) +- // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 +- case CUDA_R_8F_E4M3: +- return "CUDA_R_8F_E4M3"; +- case CUDA_R_8F_E5M2: +- return "CUDA_R_8F_E5M2"; +-#endif +- default: +- return ""; +- } +-} +- +-#ifndef USE_CUDA_MINIMAL + // Initialize the singleton instance + HalfGemmOptions HalfGemmOptions::instance; + +@@ -75,6 +54,26 @@ const char* cublasGetErrorEnum(cublasStatus_t error) { + } + } + ++const char* CudaDataTypeToString(cudaDataType_t dt) { ++ switch (dt) { ++ case CUDA_R_16F: ++ return 
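The comments removed with atomic_byte_func_with_unit32_cas in the atomic/common.cuh hunk above describe emulating an 8-bit atomic update with a 32-bit atomicCAS: locate the enclosing aligned word, extract the target byte, compute its new value, and splice it back without disturbing its neighbours. The bit arithmetic alone can be exercised on the host; the sketch below is deliberately non-atomic, assumes a little-endian host (matching NVIDIA GPUs), and SpliceByte is a made-up name for illustration:

#include <cassert>
#include <cstdint>
#include <cstring>

// Replace the byte at `offset` (0..3) inside a 32-bit word with `new_byte`,
// leaving the other three bytes untouched.
static uint32_t SpliceByte(uint32_t word, unsigned offset, uint8_t new_byte) {
  const uint32_t shift = offset * 8u;
  const uint32_t mask = 0xffu << shift;
  return (word & ~mask) | (static_cast<uint32_t>(new_byte) << shift);
}

int main() {
  // Pretend these four bytes live in one aligned 32-bit word (little-endian layout,
  // which is what the shift arithmetic in the device code relies on).
  uint8_t bytes[4] = {0x11, 0x22, 0x33, 0x44};
  uint32_t word;
  std::memcpy(&word, bytes, sizeof(word));

  // Read-modify-splice the byte at offset 2: the same step the device code retries
  // inside its atomicCAS loop until no other thread has raced it.
  const uint8_t old_byte = static_cast<uint8_t>((word >> 16) & 0xffu);
  word = SpliceByte(word, 2, static_cast<uint8_t>(old_byte + 5));

  std::memcpy(bytes, &word, sizeof(word));
  assert(bytes[0] == 0x11 && bytes[1] == 0x22 && bytes[2] == 0x38 && bytes[3] == 0x44);
  return 0;
}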
"CUDA_R_16F"; ++ case CUDA_R_16BF: ++ return "CUDA_R_16BF"; ++ case CUDA_R_32F: ++ return "CUDA_R_32F"; ++#if !defined(DISABLE_FLOAT8_TYPES) ++ // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 ++ case CUDA_R_8F_E4M3: ++ return "CUDA_R_8F_E4M3"; ++ case CUDA_R_8F_E5M2: ++ return "CUDA_R_8F_E5M2"; ++#endif ++ default: ++ return ""; ++ } ++} ++ + const char* CublasComputeTypeToString(cublasComputeType_t ct) { + switch (ct) { + case CUBLAS_COMPUTE_16F: +@@ -93,7 +92,6 @@ const char* CublasComputeTypeToString(cublasComputeType_t ct) { + return ""; + } + } +-#endif + + // It must exist somewhere already. + cudaDataType_t ToCudaDataType(int32_t element_type) { +diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h +index e9941ce74..707099bac 100644 +--- a/onnxruntime/core/providers/cuda/cuda_common.h ++++ b/onnxruntime/core/providers/cuda/cuda_common.h +@@ -22,14 +22,13 @@ namespace onnxruntime { + namespace cuda { + + #define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr)) +-#ifndef USE_CUDA_MINIMAL + #define CUBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUBLAS_CALL(expr)) + #define CUSPARSE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUSPARSE_CALL(expr)) + #define CURAND_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CURAND_CALL(expr)) + #define CUDNN_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_CALL(expr)) + #define CUDNN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(CUDNN_CALL2(expr, m)) + #define CUFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUFFT_CALL(expr)) +-#endif ++ + // Type mapping for MLFloat16 to half + template + class ToCudaType { +@@ -94,7 +93,7 @@ inline bool CalculateFdmStrides(gsl::span p, const std::vector + KernelCreateInfo BuildKernelCreateInfo() { +@@ -1338,7 +1326,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing + BuildKernelCreateInfo, + BuildKernelCreateInfo, +-#ifndef USE_CUDA_MINIMAL + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +@@ -1939,7 +1926,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, ++ BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +@@ -2014,10 +2001,10 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +@@ -2097,6 +2084,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +@@ -2140,7 +2134,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, + + // Opset 17 + BuildKernelCreateInfo, +@@ -2150,23 +2143,11 @@ static Status 
RegisterCudaKernels(KernelRegistry& kernel_registry) { + + // Opset 18 + BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, + + // Opset 19 + BuildKernelCreateInfo, +@@ -2220,7 +2201,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +-#endif + }; + + for (auto& function_table_entry : function_table) { +@@ -2230,7 +2210,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + } + } + +-#ifndef USE_CUDA_MINIMAL + #ifndef DISABLE_CONTRIB_OPS + ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry)); + #endif +@@ -2241,7 +2220,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { + + #ifdef ENABLE_TRAINING_OPS + ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry)); +-#endif + #endif + + return Status::OK(); +diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +index 7b507296d..daa3b5ff3 100644 +--- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc ++++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +@@ -16,7 +16,6 @@ namespace cuda { + namespace provider_option_names { + constexpr const char* kDeviceId = "device_id"; + constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +-constexpr const char* kUserComputeStream = "user_compute_stream"; + constexpr const char* kMemLimit = "gpu_mem_limit"; + constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; + constexpr const char* kCudnnConvAlgoSearch = "cudnn_conv_algo_search"; +@@ -52,7 +51,6 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P + void* alloc = nullptr; + void* free = nullptr; + void* empty_cache = nullptr; +- void* user_compute_stream = nullptr; + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + .AddValueParser( +@@ -68,14 +66,6 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P + return Status::OK(); + }) + .AddAssignmentToReference(cuda::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) +- .AddValueParser( +- cuda::provider_option_names::kUserComputeStream, +- [&user_compute_stream](const std::string& value_str) -> Status { +- size_t address; +- ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); +- user_compute_stream = reinterpret_cast(address); +- return Status::OK(); +- }) + .AddValueParser( + cuda::provider_option_names::kGpuExternalAlloc, + [&alloc](const std::string& value_str) -> Status { +@@ -136,10 +126,6 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P + + CUDAExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; + info.external_allocator_info = alloc_info; +- +- info.user_compute_stream = user_compute_stream; +- info.has_user_compute_stream = (user_compute_stream != nullptr); +- + return info; + } + +@@ -147,7 +133,6 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const 
CUDAExecution + const ProviderOptions options{ + {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +- {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, + {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, + {cuda::provider_option_names::kGpuExternalAlloc, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, + {cuda::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, +@@ -175,7 +160,6 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid + const ProviderOptions options{ + {cuda::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {cuda::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, +- {cuda::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, + {cuda::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, + {cuda::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, + {cuda::provider_option_names::kCudnnConvAlgoSearch, EnumToName(ort_cudnn_conv_algo_search_mapping, info.cudnn_conv_algo_search)}, +diff --git a/onnxruntime/core/providers/cuda/cuda_pch.h b/onnxruntime/core/providers/cuda/cuda_pch.h +index dfe50fe0a..f48554e8f 100644 +--- a/onnxruntime/core/providers/cuda/cuda_pch.h ++++ b/onnxruntime/core/providers/cuda/cuda_pch.h +@@ -10,19 +10,12 @@ + + #include + #include +-#include +-#ifndef USE_CUDA_MINIMAL + #include + #include + #include + #include + #include + #include +-#else +-typedef void* cudnnHandle_t; +-typedef void* cublasHandle_t; +-typedef void* cublasLtHandle_t; +-#endif + + #ifdef ORT_USE_NCCL + #include +diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +index 0a256394b..7c866395e 100644 +--- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc ++++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +@@ -69,7 +69,6 @@ CudaStream::CudaStream(cudaStream_t stream, + release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), + deferred_cpu_allocator_(*this), + ep_info_(ep_info) { +-#ifndef USE_CUDA_MINIMAL + if (own_flag) { + CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); + CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); +@@ -81,12 +80,10 @@ CudaStream::CudaStream(cudaStream_t stream, + cudnn_handle_ = external_cudnn_handle; + CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream)); + } +-#endif + } + + CudaStream::~CudaStream() { + ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd()); +-#ifndef USE_CUDA_MINIMAL + if (own_stream_) { + cublasDestroy(cublas_handle_); + cudnnDestroy(cudnn_handle_); +@@ -94,7 +91,6 @@ CudaStream::~CudaStream() { + if (handle) + cudaStreamDestroy(static_cast(handle)); + } +-#endif + } + + std::unique_ptr CudaStream::CreateNotification(size_t /*num_consumers*/) { +diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc +index c850f7b58..4df59a98b 100644 +--- a/onnxruntime/core/providers/cuda/cudnn_common.cc ++++ b/onnxruntime/core/providers/cuda/cudnn_common.cc +@@ 
-9,7 +9,7 @@ + #include "core/common/gsl.h" + #include "shared_inc/cuda_call.h" + #include "core/providers/cpu/tensor/utils.h" +-#ifndef USE_CUDA_MINIMAL ++ + namespace onnxruntime { + namespace cuda { + +@@ -222,4 +222,3 @@ const Float8E5M2 Consts::One = Float8E5M2(1.0f, true); + + } // namespace cuda + } // namespace onnxruntime +-#endif +diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h +index fdd14deda..8a94a334e 100644 +--- a/onnxruntime/core/providers/cuda/cudnn_common.h ++++ b/onnxruntime/core/providers/cuda/cudnn_common.h +@@ -7,7 +7,7 @@ + #include + + #include "core/providers/cuda/cuda_common.h" +-#ifndef USE_CUDA_MINIMAL ++ + namespace onnxruntime { + namespace cuda { + +@@ -260,4 +260,3 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc, + + } // namespace cuda + } // namespace onnxruntime +-#endif +diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +index b710e8a1b..10c8625b3 100644 +--- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu ++++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu +@@ -95,37 +95,7 @@ struct OffsetCalculatorFor2D { + + template + struct FuncAssignment { +- __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { +- start_addr[index] = value; +- } +-}; +- +-template +-struct FuncAdd { +- __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { +- atomic_add(start_addr + index, value); +- } +-}; +- +-template +-struct FuncMul { +- __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { +- atomic_mul(start_addr + index, value); +- } +-}; +- +-template +-struct FuncMax { +- __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { +- atomic_max(start_addr + index, value); +- } +-}; +- +-template +-struct FuncMin { +- __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { +- atomic_min(start_addr + index, value); +- } ++ __device__ __inline__ void operator()(T* start_addr, size_t index, T value) const { start_addr[index] = value; } + }; + + template +@@ -268,24 +238,8 @@ Status ScatterElementsImplInternal(cudaStream_t stream, const T* input_data, con + template + Status ScatterElementsImpl(cudaStream_t stream, const T* input_data, const TIndex* indices_data, const T* updates_data, + T* output_data, const GatherScatterElementsArgs& args) { +- if (args.operation == GatherScatterElementsArgs::Operation::NONE) { +- return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, +- FuncAssignment()); +- } else if (args.operation == GatherScatterElementsArgs::Operation::ADD) { +- return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, +- FuncAdd()); +- } else if (args.operation == GatherScatterElementsArgs::Operation::MUL) { +- return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, +- FuncMul()); +- } else if (args.operation == GatherScatterElementsArgs::Operation::MAX) { +- return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, +- FuncMax()); +- } else if (args.operation == GatherScatterElementsArgs::Operation::MIN) { +- return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, +- FuncMin()); +- } else { +- return 
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported reduction operator."); +- } ++ return ScatterElementsImplInternal(stream, input_data, indices_data, updates_data, output_data, args, ++ FuncAssignment()); + } + + #define GATHER_SCATTER_ELEMENTS_SPECIALIZED_TINDEX_IMPL(T, TIndex) \ +diff --git a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h +index 7b1c88f1f..631d0bf04 100644 +--- a/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h ++++ b/onnxruntime/core/providers/cuda/tensor/gather_elements_impl.h +@@ -10,14 +10,6 @@ namespace onnxruntime { + namespace cuda { + + struct GatherScatterElementsArgs { +- enum class Operation { +- NONE, +- ADD, +- MUL, +- MAX, +- MIN +- }; +- + int64_t rank; + int64_t axis; + int64_t input_size; +@@ -27,9 +19,6 @@ struct GatherScatterElementsArgs { + TArray indices_fdms; + TArray indices_strides; + int64_t indices_size; +- // operation used to combine values associated the same +- // memory location in the output tensor. +- Operation operation; + }; + + template +diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc +index bdd6567d2..4584e5fd8 100644 +--- a/onnxruntime/core/providers/cuda/tensor/pad.cc ++++ b/onnxruntime/core/providers/cuda/tensor/pad.cc +@@ -29,27 +29,15 @@ namespace cuda { + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); \ +- ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ +- Pad, \ +- kOnnxDomain, \ +- 13, 17, \ +- T, \ +- kCudaExecutionProvider, \ +- (*KernelDefBuilder::Create()) \ +- .InputMemoryType(OrtMemTypeCPUInput, 1) \ +- .InputMemoryType(OrtMemTypeCPUInput, 2) \ +- .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +- Pad); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Pad, \ + kOnnxDomain, \ +- 18, \ ++ 13, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ +- .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Pad); + +@@ -106,15 +94,28 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { + if (is_dynamic_) { + const Tensor& pads_tensor = *ctx->Input(1); + const auto pads_tensor_dims = pads_tensor.Shape().GetDims(); ++ ORT_ENFORCE(utils::IsPrimitiveDataType(pads_tensor.DataType()), ++ "Pads tensor should be an INT64 tensor"); + ORT_ENFORCE(pads_tensor_dims.size() == 1 || (pads_tensor_dims.size() == 2 && pads_tensor_dims[0] == 1), +- "Pads tensor should be a 1D tensor of shape [2 * num_axes] or a 2D tensor of shape [1, 2 * num_axes]"); ++ "Pads tensor should be a 1D tensor of shape [2 * input_rank] or a 2D tensor of shape [1, 2 * input_rank]"); + +- const auto pads_data = pads_tensor.DataAsSpan(); +- +- PadBase::ComputePads(*ctx, input_shape.NumDimensions(), pads_data, pads); ++ const int64_t* pads_tensor_raw_data = pads_tensor.Data(); ++ size_t pads_size = static_cast(pads_tensor.Shape().Size()); ++ ORT_ENFORCE(pads_size == 2 * static_cast(dimension_count), ++ "Pads tensor size should be equal to twice the input dimension count "); + ++ pads.reserve(2LL * dimension_count); ++ for (size_t i = 0; i < pads_size; ++i) { ++ pads.push_back(pads_tensor_raw_data[i]); ++ } + // Separate out any negative pads into the slices array +- PadBase::SeparateNegativeToSlices(pads, slices); ++ slices.resize(pads.size(), 0); ++ for (size_t index = 0; index < pads.size(); index++) { ++ if 
(pads[index] < 0) { ++ slices[index] = pads[index]; ++ pads[index] = 0; ++ } ++ } + + T raw_value{}; + const Tensor* value_tensor = ctx->Input(2); +diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +index 42a9f5000..e4d145154 100755 +--- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc ++++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc +@@ -27,23 +27,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 11, 12, kCudaExe + DataTypeImpl::GetTensorType()}), + ScatterElements); + +-ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 13, 15, kCudaExecutionProvider, +- (*KernelDefBuilder::Create()) +- .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) +- .TypeConstraint("Tind", +- std::vector{DataTypeImpl::GetTensorType(), +- DataTypeImpl::GetTensorType()}), +- ScatterElements); +- +-ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterElements, kOnnxDomain, 16, 17, kCudaExecutionProvider, +- (*KernelDefBuilder::Create()) +- .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) +- .TypeConstraint("Tind", +- std::vector{DataTypeImpl::GetTensorType(), +- DataTypeImpl::GetTensorType()}), +- ScatterElements); +- +-ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 18, kCudaExecutionProvider, ++ONNX_OPERATOR_KERNEL_EX(ScatterElements, kOnnxDomain, 13, kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()) + .TypeConstraint("Tind", std::vector{DataTypeImpl::GetTensorType(), +@@ -122,20 +106,6 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const { + TensorShapeVector indices_shape_vec = indices_shape.AsShapeVector(); + CoalesceDimensions(input_shape_vec, indices_shape_vec, nullptr, axis, args); + +- if (reduction_ == "none") { +- args.operation = GatherScatterElementsArgs::Operation::NONE; +- } else if (reduction_ == "add") { +- args.operation = GatherScatterElementsArgs::Operation::ADD; +- } else if (reduction_ == "mul") { +- args.operation = GatherScatterElementsArgs::Operation::MUL; +- } else if (reduction_ == "min") { +- args.operation = GatherScatterElementsArgs::Operation::MIN; +- } else if (reduction_ == "max") { +- args.operation = GatherScatterElementsArgs::Operation::MAX; +- } else { +- ORT_THROW("Unsupported reduction type"); +- } +- + // Use element size instead of concrete types so we can specialize less template functions to reduce binary size. 
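+// [Editorial sketch, not part of the patch] The comment directly above refers to dispatching
+// ScatterElements kernels on the element byte size rather than on the concrete ONNX type, so a
+// single template instantiation can serve every type of the same width and the binary stays
+// smaller. A minimal sketch of what a size-to-canonical-type mapping such as GetElementType
+// presumably performs (an assumption, not the actual onnxruntime implementation), assuming the
+// ONNX TensorProto enum is visible:
+inline int GetElementTypeFromSize(size_t element_size) {
+  switch (element_size) {
+    case 1:  return ONNX_NAMESPACE::TensorProto_DataType_INT8;   // also covers uint8/bool
+    case 2:  return ONNX_NAMESPACE::TensorProto_DataType_INT16;  // also covers float16/uint16
+    case 4:  return ONNX_NAMESPACE::TensorProto_DataType_INT32;  // also covers float/uint32
+    case 8:  return ONNX_NAMESPACE::TensorProto_DataType_INT64;  // also covers double/uint64
+    default: return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
+  }
+}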
+ int dtype = GetElementType(input_tensor->DataType()->Size()); + if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) { +diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +index 3884b716d..3e9e0ce04 100755 +--- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.h ++++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.h +@@ -14,12 +14,6 @@ class ScatterElements final : public CudaKernel { + ScatterElements(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), + "Missing/Invalid 'axis' attribute value"); +- reduction_ = info.GetAttrOrDefault("reduction", "none"); +- +- ORT_ENFORCE(reduction_ == "none" || reduction_ == "add" || +- reduction_ == "mul" || reduction_ == "max" || +- reduction_ == "min", +- "Invalid reduction attribute value of ", reduction_); + } + ~ScatterElements() = default; + Status ComputeInternal(OpKernelContext* context) const override; +@@ -29,10 +23,6 @@ class ScatterElements final : public CudaKernel { + struct ComputeImpl; + + int64_t axis_; +- // "reduction" attribute has been defined since opset 13 but +- // we never implemented it. Let's try to support them starting +- // with opset 18. +- std::string reduction_; + }; + + } // namespace cuda +diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +index 45ff25c4f..76b9b308f 100644 +--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp ++++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +@@ -29,7 +29,7 @@ public: + castDesc.OutputTensor = outputDescs.data(); + + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; +- ++ + SetDmlOperatorDesc(opDesc, kernelInfo); + } + +@@ -49,6 +49,5 @@ public: + + DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast); + DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast); +-DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast); + + } // namespace Dml +diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +index 16bb10f00..ab8ddbfe9 100644 +--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp ++++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +@@ -487,7 +487,7 @@ public: + Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + + std::vector inputDescs = GetDmlInputDescs(); +- std::vector outputDescs = GetDmlOutputDescs(); ++ std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; +@@ -497,11 +497,11 @@ public: + SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); + } + else +- { ++ { + Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); + + std::vector inputDescs = GetDmlInputDescs(); +- std::vector outputDescs = GetDmlOutputDescs(); ++ std::vector outputDescs = GetDmlOutputDescs(); + + DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; +@@ -519,16 +519,13 @@ class 
DmlOperatorElementwiseQLinear : public DmlOperator + public: + DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) + { +- +- ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); ++ ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3); + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + +- Initialize(kernelInfo, std::nullopt, std::nullopt); +- + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); +- const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType(); +- bool hasZeroPointTensor = kernelInfo.IsInputValid(2); ++ ++ Initialize(kernelInfo, std::nullopt, std::nullopt); + + uint32_t axis = 0; + +@@ -544,14 +541,9 @@ public: + axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false); + } + +- // Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor). ++ // Explicitly reshape each of the inputs after the first input (scale and zero point tensors). + for (uint32_t index = 1, inputCount = gsl::narrow_cast(m_inputTensorDescs.size()); index < inputCount; ++index) + { +- if (!kernelInfo.IsInputValid(index)) +- { +- continue; +- } +- + auto edgeDesc = kernelInfo.GetInputEdgeDescription(index); + assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); + +@@ -595,8 +587,12 @@ public: + TOperatorDesc opDesc = {}; + opDesc.InputTensor = &inputDescs[0]; + opDesc.ScaleTensor = &inputDescs[1]; +- opDesc.ZeroPointTensor = hasZeroPointTensor ? &inputDescs[2] : nullptr; ++ opDesc.ZeroPointTensor = &inputDescs[2]; + opDesc.OutputTensor = &outputDescs[0]; ++ ++ TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1); ++ TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2); ++ + SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); + } + }; +diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +index b243f7e74..a014db5ad 100644 +--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp ++++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +@@ -51,12 +51,6 @@ public: + { + mode = DML_PADDING_MODE_REFLECTION; + } +-#if DML_TARGET_VERSION >= 0x6300 +- else if (modeString == AttrValue::Wrap) +- { +- mode = DML_PADDING_MODE_WRAP; +- } +-#endif + else + { + ML_INVALID_ARGUMENT("Unknown Pad mode attribute."); +@@ -122,6 +116,5 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel); +-DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel); + + } // namespace Dml +diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +index b7cceb1d1..f332fac9d 100644 +--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp ++++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +@@ -9,12 +9,11 @@ namespace Dml + constexpr NameAndIndex coordinateTransformationModes[] = + { + 
{"half_pixel", 0}, +- {"half_pixel_symmetric", 1}, +- {"pytorch_half_pixel", 2}, +- {"align_corners", 3}, +- {"asymmetric", 4}, +- {"tf_half_pixel_for_nn", 5}, +- {"tf_crop_and_resize", 6}, ++ {"pytorch_half_pixel", 1}, ++ {"align_corners", 2}, ++ {"asymmetric", 3}, ++ {"tf_half_pixel_for_nn", 4}, ++ {"tf_crop_and_resize", 5}, + }; + + constexpr NameAndIndex nearestNeighborRoundingModes[] = +@@ -51,7 +50,7 @@ void ComputePixelOffsetsAndScales( + uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue; + + ML_CHECK_VALID_ARGUMENT( +- !regionOfInterest.empty() || coordinateTransformationModeValue != 6 /*tf_crop_and_resize*/, ++ !regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/, + "Resize expects 'roi' tensor for 'tf_crop_and_resize' mode." + ); + +@@ -89,18 +88,6 @@ void ComputePixelOffsetsAndScales( + break; + + case 1: +- // coordinate_transformation_mode is "half_pixel_symmetric", +- // adjustment = output_width_int / output_width +- // center = input_width / 2 +- // offset = center * (1 - adjustment) +- // x_original = (x + 0.5) / scale - (0.5 - offset) +- // x_original = (x + 0.5) / scale - (0.5 - [(input_width / 2) * (1 - (output_width_int / output_width))]) +- // output_width can be fractional when calculated with scale factor +- inputPixelOffset = 0.5f - float((inputDimensions[i] / 2.0f) * (1.0f - outputDimensions[i] / (scales[i] * inputDimensions[i]))); +- outputPixelOffset = -0.5; +- break; +- +- case 2: + // if coordinate_transformation_mode is "pytorch_half_pixel", + // x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0 + if (inputDimensions[i] <= 1) +@@ -117,7 +104,7 @@ void ComputePixelOffsetsAndScales( + } + break; + +- case 3: ++ case 2: + // if coordinate_transformation_mode is "align_corners", + // x_original = x_resized * (length_original - 1) / (length_resized - 1) + inputPixelOffset = 0.0; +@@ -134,7 +121,7 @@ void ComputePixelOffsetsAndScales( + } + break; + +- case 4: ++ case 3: + // if coordinate_transformation_mode is "asymmetric", + // x_original = x_resized / scale + inputPixelOffset = 0.0; +@@ -142,7 +129,7 @@ void ComputePixelOffsetsAndScales( + // Keep existing scales. + break; + +- case 5: ++ case 4: + // if coordinate_transformation_mode is "tf_half_pixel_for_nn", + // x_original = (x_resized + 0.5) / scale + inputPixelOffset = 0.0; +@@ -150,7 +137,7 @@ void ComputePixelOffsetsAndScales( + // Keep existing scales. + break; + +- case 6: ++ case 5: + // if coordinate_transformation_mode is "tf_crop_and_resize", + // x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) + // : 0.5 * (start_x + end_x) * (length_original - 1) +@@ -190,7 +177,7 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper + public: + // Resample a multidimensional image to a new size. 
+ DmlOperatorResize(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion) +- : DmlOperator(kernelCreationContext), ++ : DmlOperator(kernelCreationContext), + ResizeHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) + { + ML_CHECK_VALID_ARGUMENT(!m_scales.empty(), "Resize/Upsample expect scales, either a 2nd input tensors or 'scales' attribute."); +@@ -263,11 +250,6 @@ public: + std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "NEAREST"); + DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode); + +- +-#if DML_TARGET_VERSION >= 0x6300 +- const int antialiased = kernelCreationContext.GetOptionalAttribute(AttrName::Antialiased, 0); +-#endif +- + // Map ONNX to DML's mode using offsets and rounding direction. + // These offsets are in addition to the coordinate transform offsets. + DML_AXIS_DIRECTION roundingDirection = DML_AXIS_DIRECTION_DECREASING; +@@ -307,12 +289,7 @@ public: + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + +-#if DML_TARGET_VERSION >= 0x6300 +- DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {}; +- operatorDesc.Antialiased = static_cast(antialiased); +-#else + DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {}; +-#endif + operatorDesc.InputTensor = inputDescs.data(); + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InterpolationMode = interpolationMode; +@@ -321,11 +298,8 @@ public: + operatorDesc.DimensionCount = gsl::narrow_cast(paddedScales.size()); + operatorDesc.InputPixelOffsets = inputPixelOffsets.data(); + operatorDesc.OutputPixelOffsets = outputPixelOffsets.data(); +-#if DML_TARGET_VERSION >= 0x6300 +- DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc }; +-#else ++ + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc }; +-#endif + SetDmlOperatorDesc(opDesc, kernelCreationContext); + } + }; +@@ -368,10 +342,6 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* + DML_OP_DEFINE_CREATION_FUNCTION(Resize10, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Resize13, VersionedKernel); +-#if DML_TARGET_VERSION >= 0x6300 +-DML_OP_DEFINE_CREATION_FUNCTION(Resize18, VersionedKernel); +-DML_OP_DEFINE_CREATION_FUNCTION(Resize19, VersionedKernel); +-#endif + DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel); + DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel); +diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +index 9c136ed8c..15a805195 100644 +--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp ++++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +@@ -358,7 +358,6 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad7); + DML_OP_EXTERN_CREATION_FUNCTION(Pad11); + DML_OP_EXTERN_CREATION_FUNCTION(Pad13); + DML_OP_EXTERN_CREATION_FUNCTION(Pad18); +-DML_OP_EXTERN_CREATION_FUNCTION(Pad19); + DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); + DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); + DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); +@@ -437,7 +436,6 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul); + DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation); + 
DML_OP_EXTERN_CREATION_FUNCTION(Cast); + DML_OP_EXTERN_CREATION_FUNCTION(CastLike15); +-DML_OP_EXTERN_CREATION_FUNCTION(CastLike19); + DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost); + DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost); + DML_OP_EXTERN_CREATION_FUNCTION(TopK7); +@@ -508,8 +506,6 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu); + + #if DML_TARGET_VERSION >= 0x6300 + DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); +-DML_OP_EXTERN_CREATION_FUNCTION(Resize18); +-DML_OP_EXTERN_CREATION_FUNCTION(Resize19); + #endif + + DML_OP_EXTERN_CREATION_FUNCTION(Shape); +@@ -602,7 +598,6 @@ constexpr static std::array supportedTypeListSigned + constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32}; + constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; + constexpr static std::array supportedTypeListResize13 = supportedTypeListResize11; +-constexpr static std::array supportedTypeListResize18 = supportedTypeListResize11; + constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; + constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; + constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; +@@ -751,11 +746,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 + {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, +- +-#if DML_TARGET_VERSION >= 0x6300 +- {REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, +-#endif +- + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, +@@ -795,7 +785,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, +- {REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY( 7, 
Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, +@@ -809,7 +798,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, +- {REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + + // Elementwise + {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, +@@ -869,10 +857,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, +- {REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, + {REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, +- {REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, +@@ -957,7 +943,6 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, + {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, +- {REG_INFO( 19, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, + {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, +@@ -976,10 +961,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, + {REG_INFO_VER( 11, Resize, typeNameListTwo, 
supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO_VER( 13, Resize, typeNameListTwo, supportedTypeListResize13, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, +-#if DML_TARGET_VERSION >= 0x6300 +- {REG_INFO_VER( 18, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, +- {REG_INFO_VER( 19, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, +-#endif ++ + // Activation Functions + {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, +@@ -1022,9 +1004,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, +- {REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, +- {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, + {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, + {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)}, +@@ -1035,10 +1015,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation + {REG_INFO( 7, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, +- {REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, +- {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, + {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, + +diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +index 287deaa51..e3df1d00b 100644 +--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h ++++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +@@ -12,7 +12,6 @@ namespace AttrName + static constexpr const char* AllowZero = "allowzero"; + static constexpr const char* Alpha = "alpha"; + static constexpr const char* AlignCorners = "align_corners"; +- static constexpr const char* Antialiased = "antialias"; + static constexpr const 
char* AutoPad = "auto_pad"; + static constexpr const char* Axes = "axes"; + static constexpr const char* Axis = "axis"; +@@ -150,6 +149,5 @@ namespace AttrValue + static constexpr const char* NearestNeighbor = "NN"; + static constexpr const char* NotSet = "NOTSET"; + static constexpr const char* Reflect = "reflect"; +- static constexpr const char* Wrap = "wrap"; + + } // namespace AttrValue +diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +index 317f5ebcb..83c6748fa 100644 +--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp ++++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +@@ -56,18 +56,6 @@ namespace OperatorHelper + } + } + +- template +- void ExpandToAxes(/*inout*/ std::vector& originalValues, gsl::span axes, std::vector expanded) +- { +- assert(originalValues.size() == axes.size()); +- // Fill in roi and scales/sizes +- for (size_t i = 0; i < axes.size(); i++) +- { +- expanded[axes[i]] = originalValues[i]; +- } +- originalValues = std::move(expanded); +- } +- + float CastFloat16ToFloat32(uint16_t input) + { + // Promote float16m10e5s1 to float32m23e8s1. +@@ -156,6 +144,50 @@ namespace OperatorHelper + } + #pragma warning(pop) + ++ void ReadCpuLocalTensorIntoInt32( ++ const MLOperatorTensor& tensor, ++ std::vector& result ++ ) ++ { ++ result.clear(); ++ ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); ++ ++ const std::vector& tensorDimensions = tensor.GetShape(); ++ const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); ++ ++ switch (tensor.GetTensorDataType()) ++ { ++ case MLOperatorTensorDataType::Int32: ++ { ++ const int32_t* data = tensor.GetData(); ++ result.assign(data, data + elementCount); ++ } ++ break; ++ ++ case MLOperatorTensorDataType::Int64: ++ { ++ const int64_t* data = tensor.GetData(); ++ result.reserve(elementCount); ++ ++ // Use clamped cast rather than static_cast/narrow_cast, ++ // because it's not uncommon for a model to specify a ++ // 64-bit INTMAX constant as a sentinel value to mean ++ // the largest possible value (even though the actual ++ // dimension values come nowhere close to that, far ++ // less than 32-bit INTMAX). 
++ for (auto d : gsl::make_span(data, data + elementCount)) ++ { ++ result.push_back(clamp_cast(d)); ++ } ++ } ++ break; ++ ++ default: ++ ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); ++ break; ++ } ++ } ++ + void ReadCpuLocalTensorIntoFloat32( + const MLOperatorTensor& tensor, + std::vector& result +@@ -2429,8 +2461,7 @@ namespace OperatorHelper + { + auto& attributes = kernelInformation.GetAttributes(); + m_inputDimensions = shapeInformation.GetInputTensorShape(0); +- std::vector outputSizes; +- std::vector axes; ++ std::vector outputSizes; + + if (opsetVersion >= 11) + { +@@ -2447,38 +2478,7 @@ namespace OperatorHelper + if (kernelInformation.IsInputValid(3)) + { + MLOperatorTensor outputSizesTensor = kernelInformation.GetConstantInputTensor(3); +- ReadCpuLocalTensorIntoInt32(outputSizesTensor, /*out*/ outputSizes); +- } +- +- axes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Axes); +- // Handle possible axes input +- if (opsetVersion >= 18 && !axes.empty()) +- { +- uint32_t dimCount = gsl::narrow_cast(m_inputDimensions.size()); +- HandleEmptyAxes(/*inout*/ axes, m_inputDimensions, false); +- HandleNegativeAxes(/*inout*/ axes, dimCount); +- +- // Taken from https://github.com/onnx/onnx/blob/3d69db8fd16873d68e7033479467f9478562a12d/onnx/reference/ops/op_resize.py#L303 +- if (!m_scales.empty()) +- { +- std::vector defaultScales(dimCount, 1.0f); +- ExpandToAxes(/*inout*/ m_scales, axes, defaultScales); +- } +- if (!outputSizes.empty()) +- { +- ExpandToAxes(/*inout*/ outputSizes, axes, m_inputDimensions); +- } +- if (!m_regionOfInterest.empty()) +- { +- std::vector defaultRois(dimCount, 0.0f); +- defaultRois.resize(dimCount * 2, 1.0f); +- size_t numAxes = axes.size(); +- for (size_t i = 0; i < axes.size(); i++) +- { +- defaultRois[axes[i]] = m_regionOfInterest[i]; +- defaultRois[axes[i + dimCount]] = m_regionOfInterest[i + numAxes]; +- } +- } ++ ReadCpuLocalTensorIntoInt32(outputSizesTensor, /*out*/ outputSizes); + } + } + else if (opsetVersion >= 9) +diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +index 1b2521a86..0e0e6bb1e 100644 +--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h ++++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +@@ -120,53 +120,9 @@ double CastToFloat64(MLOperatorTensorDataType tensorDataType, const void* p); + void ReadScalarTensorData(const MLOperatorTensor& tensor, /*out*/ void* data, size_t dataByteSize); + int64_t ReadScalarTensorCastToInt64(const MLOperatorTensor& tensor); + double ReadScalarTensorCastToFloat64(const MLOperatorTensor& tensor); +-void ReadCpuLocalTensorIntoFloat32(const MLOperatorTensor& tensor, std::vector& result); +- +-template +-void ReadCpuLocalTensorIntoInt32( +- const MLOperatorTensor& tensor, +- std::vector& result +- ) +-{ +- result.clear(); +- ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); + +- const std::vector& tensorDimensions = tensor.GetShape(); +- const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); +- +- switch (tensor.GetTensorDataType()) +- { +- case MLOperatorTensorDataType::Int32: +- { +- result.resize(elementCount); +- const int32_t* data = tensor.GetData(); +- std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); +- } +- break; +- +- case MLOperatorTensorDataType::Int64: +- { +- const int64_t* data = 
tensor.GetData(); +- result.reserve(elementCount); +- +- // Use clamped cast rather than static_cast/narrow_cast, +- // because it's not uncommon for a model to specify a +- // 64-bit INTMAX constant as a sentinel value to mean +- // the largest possible value (even though the actual +- // dimension values come nowhere close to that, far +- // less than 32-bit INTMAX). +- for (auto d : gsl::make_span(data, data + elementCount)) +- { +- result.push_back(clamp_cast(d)); +- } +- } +- break; +- +- default: +- ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); +- break; +- } +-} ++void ReadCpuLocalTensorIntoInt32(const MLOperatorTensor& tensor, std::vector& result); ++void ReadCpuLocalTensorIntoFloat32(const MLOperatorTensor& tensor, std::vector& result); + + class EdgeShapes + { +@@ -1633,7 +1589,6 @@ using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper; + using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; + using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; + using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper; +-using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper; + + using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; + using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; +@@ -1651,14 +1606,11 @@ using ShapeInferenceHelper_Expand = ExpandHelper; + using ShapeInferenceHelper_Reshape7 = ReshapeHelper; + using ShapeInferenceHelper_Reshape13 = ReshapeHelper; + using ShapeInferenceHelper_Reshape14 = ReshapeHelper; +-using ShapeInferenceHelper_Reshape19 = ReshapeHelper; + using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper; + using ShapeInferenceHelper_Tile = TileHelper; + using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; + using ShapeInferenceHelper_Resize11 = VersionedOpsetHelper; + using ShapeInferenceHelper_Resize13 = VersionedOpsetHelper; +-using ShapeInferenceHelper_Resize18 = VersionedOpsetHelper; +-using ShapeInferenceHelper_Resize19 = VersionedOpsetHelper; + using ShapeInferenceHelper_OneHot = OneHotHelper; + + using ShapeInferenceHelper_Sqrt = GetOutputShapeAsInputShapeHelper; +@@ -1773,7 +1725,6 @@ using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper; + using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; + using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; + using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; +-using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; + using ShapeInferenceHelper_MatMul = MatMulHelper; + using ShapeInferenceHelper_MatMulInteger = MatMulHelper; + using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; +@@ -1799,7 +1750,6 @@ using ShapeInferenceHelper_CumSum14 = GetOutputShapeAsInputShapeHelper; + using ShapeInferenceHelper_Range = RangeHelper; + + using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper; +-using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper; + + using ShapeInferenceHelper_DmlFusedConv = ConvHelper; + using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper; +diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +index e725ba085..8438bc620 100644 +--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h ++++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +@@ -408,24 +408,11 @@ namespace OperatorHelper + static const int sc_sinceVer_Split = 18; + 
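+// [Editorial sketch, not part of the patch] clamp_cast in the removed template above narrows
+// int64 tensor values to int32 by saturating rather than truncating, so an INT64_MAX sentinel
+// (a common "largest possible dimension" marker in models) comes out as INT32_MAX instead of
+// wrapping. A self-contained equivalent (an assumption about clamp_cast's behaviour, not its
+// actual definition):
+#include <algorithm>
+#include <cstdint>
+#include <limits>
+
+inline int32_t ClampCastToInt32(int64_t v) {
+  constexpr int64_t lo = std::numeric_limits<int32_t>::min();
+  constexpr int64_t hi = std::numeric_limits<int32_t>::max();
+  return static_cast<int32_t>(std::clamp(v, lo, hi));  // saturate instead of wrapping
+}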
static const int sc_sinceVer_LpPool = 18; + static const int sc_sinceVer_Col2Im = 18; +- static const int sc_sinceVer_Resize = 18; + } + + namespace OnnxOperatorSet19 + { + static const int sc_sinceVer_AveragePool = 19; +- static const int sc_sinceVer_Resize = 19; +- static const int sc_sinceVer_Pad = 19; +- static const int sc_sinceVer_Cast = 19; +- static const int sc_sinceVer_CastLike = 19; +- static const int sc_sinceVer_Constant = 19; +- static const int sc_sinceVer_Equal = 19; +- static const int sc_sinceVer_Identity = 19; +- static const int sc_sinceVer_QuantizeLinear = 19; +- static const int sc_sinceVer_DequantizeLinear = 19; +- static const int sc_sinceVer_Reshape = 19; +- static const int sc_sinceVer_Shape = 19; +- static const int sc_sinceVer_Size = 19; + } + + namespace MsftOperatorSet1 +diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +index 3271dab13..05eb0091a 100644 +--- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc ++++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +@@ -5,6 +5,8 @@ + #pragma warning(disable : 4996) + #endif + ++#include "core/providers/dnnl/dnnl_execution_provider.h" ++ + #include + #include + #include +@@ -14,7 +16,6 @@ + + #include "core/platform/ort_mutex.h" + #include "core/providers/shared_library/provider_api.h" +-#include "core/providers/dnnl/dnnl_execution_provider.h" + + #include "core/providers/dnnl/dnnl_fwd.h" + #include "core/providers/dnnl/dnnl_node_capability.h" +@@ -29,7 +30,7 @@ constexpr const char* DNNL = "Dnnl"; + constexpr const char* DNNL_CPU = "DnnlCpu"; + + DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& info) +- : IExecutionProvider{onnxruntime::kDnnlExecutionProvider}, ++ : IExecutionProvider{onnxruntime::kDnnlExecutionProvider, true}, + info_(info) { + InitProviderOrtApi(); + +@@ -76,8 +77,8 @@ DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& in + // Log the number of threads used + LOGS_DEFAULT(INFO) << "Allocated " << omp_get_max_threads() << " OpenMP threads for oneDNN ep\n"; + #endif // defined(DNNL_OPENMP) +- metadef_id_generator_ = ModelMetadefIdGenerator::Create(); +-} ++ ++} // namespace onnxruntime + + DnnlExecutionProvider::~DnnlExecutionProvider() { + } +@@ -228,7 +229,7 @@ std::vector> DnnlExecutionProvider::GetCapabi + + // Assign inputs and outputs to subgraph's meta_def + HashValue model_hash; +- int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + auto meta_def = ::onnxruntime::IndexedSubGraph_MetaDef::Create(); + meta_def->name() = "DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); + meta_def->domain() = kMSDomain; +@@ -263,7 +264,7 @@ std::vector> DnnlExecutionProvider::GetCapabi + graph_viewer.ToProto(*model_proto->mutable_graph(), false, true); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + HashValue model_hash; +- int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + std::fstream dump("DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id) + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); + model_proto->SerializeToOstream(dump); + } +diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +index 
b7fcbb776..41062ccb4 100644 +--- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h ++++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +@@ -41,7 +41,6 @@ class DnnlExecutionProvider : public IExecutionProvider { + bool debug_log_ = false; + // enable fusion by default + bool enable_fusion_ = true; +- std::unique_ptr metadef_id_generator_; + }; + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc +index 799d4172f..c2ff2ebc3 100644 +--- a/onnxruntime/core/providers/js/js_execution_provider.cc ++++ b/onnxruntime/core/providers/js/js_execution_provider.cc +@@ -3,7 +3,6 @@ + + #include "js_execution_provider.h" + +-#include + #include + #include + #include +@@ -99,7 +98,6 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); +-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log); + +@@ -394,7 +392,6 @@ std::unique_ptr RegisterKernels() { + KERNEL_CREATE_INFO(13, Erf), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), + KERNEL_CREATE_INFO(13, Sigmoid), +- KERNEL_CREATE_INFO(6, HardSigmoid), + KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), + KERNEL_CREATE_INFO(13, Log), + +@@ -682,13 +679,9 @@ std::unique_ptr RegisterKernels() { + + using namespace js; + +-JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) +- : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, ++JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) ++ : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, + preferred_data_layout_{info.data_layout} { +- if (session_options) { +- enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; +- LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_; +- } + } + + std::vector JsExecutionProvider::CreatePreferredAllocators() { +@@ -756,46 +749,4 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer + JsExecutionProvider::~JsExecutionProvider() { + } + +-Status JsExecutionProvider::OnRunStart() { +- if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { +- LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; +- EM_ASM({ Module.jsepCaptureBegin(); }); +- } +- return Status::OK(); +-} +- +-Status JsExecutionProvider::OnRunEnd(bool sync_stream) { +- if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { +- if (IsGraphCaptureAllowed()) { +- EM_ASM({ Module.jsepCaptureEnd(); }); +- is_graph_captured_ = true; +- } else { +- IncrementRegularRunCountBeforeGraphCapture(); +- } +- } +- +- return Status::OK(); +-} +- +-bool JsExecutionProvider::IsGraphCaptureEnabled() const { +- return enable_graph_capture_; +-} +- +-bool JsExecutionProvider::IsGraphCaptured() const { +- return is_graph_captured_; +-} +- +-Status JsExecutionProvider::ReplayGraph() { +- 
ORT_ENFORCE(IsGraphCaptured()); +- EM_ASM({ Module.jsepReplay(); }); +- return Status::OK(); +-} +- +-bool JsExecutionProvider::IsGraphCaptureAllowed() const { +- return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +-} +- +-void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { +- ++regular_run_count_before_graph_capture_; +-} + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h +index 91a3256ec..39d43498c 100644 +--- a/onnxruntime/core/providers/js/js_execution_provider.h ++++ b/onnxruntime/core/providers/js/js_execution_provider.h +@@ -5,7 +5,6 @@ + #pragma once + + #include "core/framework/execution_provider.h" +-#include "core/framework/session_options.h" + #include "core/graph/constants.h" + #include "core/providers/providers.h" + +@@ -39,7 +38,7 @@ struct JsExecutionProviderInfo { + + class JsExecutionProvider : public IExecutionProvider { + public: +- JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options); ++ JsExecutionProvider(const JsExecutionProviderInfo& info); + ~JsExecutionProvider() override; + + std::vector> GetCapability( +@@ -58,22 +57,7 @@ class JsExecutionProvider : public IExecutionProvider { + bool ConcurrentRunSupported() const override { return false; } + + std::vector CreatePreferredAllocators() override; +- +- Status OnRunStart() override; +- Status OnRunEnd(bool sync_stream) override; +- +- bool IsGraphCaptureEnabled() const override; +- bool IsGraphCaptured() const override; +- Status ReplayGraph() override; +- +- private: +- bool IsGraphCaptureAllowed() const; +- void IncrementRegularRunCountBeforeGraphCapture(); + DataLayout preferred_data_layout_; +- bool enable_graph_capture_ = false; +- bool is_graph_captured_ = false; +- int regular_run_count_before_graph_capture_ = 0; +- const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
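+// [Editorial sketch, not part of the patch] The members removed above implement a warm-up gate
+// for WebGPU graph capture: the graph is only captured after at least
+// min_num_runs_before_graph_capture_ ordinary runs, so the necessary allocations exist before
+// capture, and later runs replay the captured graph. The control flow of the removed
+// OnRunEnd/IsGraphCaptureAllowed pair, reduced to plain C++ (names illustrative):
+struct GraphCaptureGate {
+  int regular_runs = 0;
+  int min_runs_before_capture = 1;  // mirrors min_num_runs_before_cuda_graph_capture_
+  bool captured = false;
+
+  bool CaptureAllowed() const { return regular_runs >= min_runs_before_capture; }
+
+  void OnRunEnd() {            // called when capture is enabled and the graph is not yet captured
+    if (captured) return;
+    if (CaptureAllowed()) {
+      captured = true;         // subsequent runs can replay the captured graph
+    } else {
+      ++regular_runs;          // still warming up; count this regular run
+    }
+  }
+};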
+ }; + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc +index cbdf99f70..5b7329a87 100644 +--- a/onnxruntime/core/providers/js/js_provider_factory.cc ++++ b/onnxruntime/core/providers/js/js_provider_factory.cc +@@ -10,22 +10,21 @@ + namespace onnxruntime { + + struct JsProviderFactory : IExecutionProviderFactory { +- JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options) +- : info_{provider_options}, session_options_(session_options) { ++ JsProviderFactory(const ProviderOptions& provider_options) ++ : info_{provider_options} { + } + + std::unique_ptr CreateProvider() override { +- return std::make_unique(info_, session_options_); ++ return std::make_unique(info_); + } + + private: + JsExecutionProviderInfo info_; +- const SessionOptions* session_options_; + }; + + std::shared_ptr JsProviderFactoryCreator::Create( +- const ProviderOptions& provider_options, const SessionOptions* session_options) { +- return std::make_shared(provider_options, session_options); ++ const ProviderOptions& provider_options) { ++ return std::make_shared(provider_options); + } + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h +index 510b0fb42..dbabe255c 100644 +--- a/onnxruntime/core/providers/js/js_provider_factory_creator.h ++++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h +@@ -9,11 +9,9 @@ + #include "core/providers/providers.h" + + namespace onnxruntime { +-struct SessionOptions; + + struct JsProviderFactoryCreator { +- static std::shared_ptr Create(const ProviderOptions& provider_options, +- const SessionOptions* session_options); ++ static std::shared_ptr Create(const ProviderOptions& provider_options); + }; + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/js/operators/flatten.cc b/onnxruntime/core/providers/js/operators/flatten.cc +index 1aacae819..7e4b4c350 100644 +--- a/onnxruntime/core/providers/js/operators/flatten.cc ++++ b/onnxruntime/core/providers/js/operators/flatten.cc +@@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) +- .TypeConstraint("T", JsepSupportedFloatTypes()), ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + + ONNX_OPERATOR_VERSIONED_KERNEL_EX( +@@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) +- .TypeConstraint("T", JsepSupportedFloatTypes()), ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + + ONNX_OPERATOR_VERSIONED_KERNEL_EX( +@@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) +- .TypeConstraint("T", JsepSupportedFloatTypes()), ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + + ONNX_OPERATOR_KERNEL_EX( +@@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX( + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .Alias(0, 0) +- .TypeConstraint("T", JsepSupportedFloatTypes()), ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()), + Flatten); + + } // namespace js +diff --git a/onnxruntime/core/providers/js/operators/pad.cc b/onnxruntime/core/providers/js/operators/pad.cc +index 83fee3548..24ba85cbf 100644 +--- a/onnxruntime/core/providers/js/operators/pad.cc ++++ b/onnxruntime/core/providers/js/operators/pad.cc +@@ 
-14,7 +14,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + 2, + 10, + kJsExecutionProvider, +- (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), ++ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), + Pad); + + ONNX_OPERATOR_VERSIONED_KERNEL_EX( +@@ -24,7 +24,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + 12, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) +- .TypeConstraint("T", JsepSupportedFloatTypes()) ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), +@@ -37,7 +37,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + 17, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) +- .TypeConstraint("T", JsepSupportedFloatTypes()) ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), +@@ -50,7 +50,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + 18, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) +- .TypeConstraint("T", JsepSupportedFloatTypes()) ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), +@@ -62,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( + 19, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) +- .TypeConstraint("T", JsepSupportedFloatTypes()) ++ .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3), +diff --git a/onnxruntime/core/providers/js/operators/slice.cc b/onnxruntime/core/providers/js/operators/slice.cc +index 869b54505..bbafe40ea 100644 +--- a/onnxruntime/core/providers/js/operators/slice.cc ++++ b/onnxruntime/core/providers/js/operators/slice.cc +@@ -12,7 +12,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + 1, 9, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) +- .TypeConstraint("T", JsepSupportedDataTypes()), ++ .TypeConstraint("T", {DataTypeImpl::GetTensorType(), ++ DataTypeImpl::GetTensorType()}), + Slice_1); + + ONNX_OPERATOR_VERSIONED_KERNEL_EX( +@@ -25,7 +26,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3) + .InputMemoryType(OrtMemTypeCPU, 4) +- .TypeConstraint("T", JsepSupportedDataTypes()), ++ .TypeConstraint("T", {DataTypeImpl::GetTensorType(), ++ DataTypeImpl::GetTensorType()}), + Slice); + + ONNX_OPERATOR_VERSIONED_KERNEL_EX( +@@ -38,7 +40,8 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3) + .InputMemoryType(OrtMemTypeCPU, 4) +- .TypeConstraint("T", JsepSupportedDataTypes()), ++ .TypeConstraint("T", {DataTypeImpl::GetTensorType(), ++ DataTypeImpl::GetTensorType()}), + Slice); + + ONNX_OPERATOR_KERNEL_EX( +@@ -51,7 +54,8 @@ ONNX_OPERATOR_KERNEL_EX( + .InputMemoryType(OrtMemTypeCPU, 2) + .InputMemoryType(OrtMemTypeCPU, 3) + .InputMemoryType(OrtMemTypeCPU, 4) +- .TypeConstraint("T", JsepSupportedDataTypes()), ++ .TypeConstraint("T", {DataTypeImpl::GetTensorType(), ++ DataTypeImpl::GetTensorType()}), + Slice); + + } // namespace js +diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc +index 9082527e3..78563d30b 100644 +--- a/onnxruntime/core/providers/js/operators/unary.cc ++++ b/onnxruntime/core/providers/js/operators/unary.cc +@@ -77,9 +77,6 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) + 
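+// [Editorial sketch, not part of the patch] JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) above registers an
+// element-wise kernel whose math lives on the JS/WebGPU side; the functions involved are the
+// standard ONNX definitions. For reference (plain C++, illustrative only; the HardSigmoid
+// registration with defaults alpha = 0.2, beta = 0.5 is the one removed just below):
+#include <algorithm>
+#include <cmath>
+
+inline float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
+
+inline float HardSigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
+  return std::max(0.0f, std::min(1.0f, alpha * x + beta));
+}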
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) + JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) + +-JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5) +-JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid) +- + JSEP_KERNEL_IMPL(Log, Log) + JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) + JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +index 40e76a0a6..8bfa66710 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +@@ -102,7 +102,7 @@ std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() c + } + + MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) +- : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, device_id_(info.device_id) { ++ : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, device_id_(info.device_id) { + InitProviderOrtApi(); + // Set GPU device to be used + HIP_CALL_THROW(hipSetDevice(device_id_)); +@@ -165,8 +165,6 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv + MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); + MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); + +- metadef_id_generator_ = ModelMetadefIdGenerator::Create(); +- + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " + << "device_id: " << device_id_ + << ", migraphx_fp16_enable: " << fp16_enable_ +@@ -759,7 +757,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st + + // Generate unique kernel name for MIGraphX subgraph + uint64_t model_hash = 0; +- int id = metadef_id_generator_->GenerateId(graph, model_hash); ++ int id = GenerateMetaDefId(graph, model_hash); + std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id); + auto meta_def = IndexedSubGraph_MetaDef::Create(); + const std::string graph_type = graph.IsSubgraph() ? 
"subgraph" : "graph"; +diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +index d582338c7..c094be510 100644 +--- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h ++++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +@@ -98,7 +98,6 @@ class MIGraphXExecutionProvider : public IExecutionProvider { + AllocatorPtr allocator_; + miopenHandle_t external_miopen_handle_ = nullptr; + rocblas_handle external_rocblas_handle_ = nullptr; +- std::unique_ptr metadef_id_generator_; + }; + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +index b04703d76..727917ad9 100644 +--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc ++++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.cc +@@ -50,7 +50,7 @@ std::unordered_set GetPartitioningStopOps(const optional& partitioning_stop_ops_list) +- : IExecutionProvider{onnxruntime::kNnapiExecutionProvider}, ++ : IExecutionProvider{onnxruntime::kNnapiExecutionProvider, true}, + nnapi_flags_(nnapi_flags), + partitioning_stop_ops_(GetPartitioningStopOps(partitioning_stop_ops_list)) { + nnapi_handle_ = NnApiImplementation(); +@@ -176,7 +176,7 @@ NnapiExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view + + const auto gen_metadef_name = [&]() { + HashValue model_hash; +- int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + return MakeString(NNAPI, "_", model_hash, "_", metadef_id); + }; + +diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +index 460616c41..e4911511e 100644 +--- a/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h ++++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h +@@ -6,7 +6,6 @@ + #include "core/common/inlined_containers_fwd.h" + #include "core/common/optional.h" + #include "core/framework/execution_provider.h" +-#include "core/framework/model_metadef_id_generator.h" + #include "core/providers/nnapi/nnapi_builtin/nnapi_api_helper.h" + #include "core/providers/nnapi/nnapi_provider_factory.h" + +@@ -49,6 +48,5 @@ class NnapiExecutionProvider : public IExecutionProvider { + const NnApi* nnapi_handle_ = nullptr; + nnapi::DeviceWrapperVector nnapi_target_devices_; + nnapi::TargetDeviceOption target_device_option_; +- ModelMetadefIdGenerator metadef_id_generator_; + }; + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/partitioning_utils.h b/onnxruntime/core/providers/partitioning_utils.h +index 136725c2f..f9d5f7403 100644 +--- a/onnxruntime/core/providers/partitioning_utils.h ++++ b/onnxruntime/core/providers/partitioning_utils.h +@@ -40,7 +40,7 @@ using OnGroupClosedFn = std::function& group + + /** + Called to create a metadef name. +-Most likely should call ModelMetadefIdGenerator.GenerateId. ++Most likely should call IExecutionProvider::GenerateMetaDefId. + See onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc for example usage. + + @return The metadef name. 
+diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +index c2e71081b..5d3f406f5 100644 +--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc ++++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +@@ -12,60 +12,34 @@ + namespace onnxruntime { + namespace qnn { + +-bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) { +- // It's an Onnx model with Qnn context cache binary if it has a node with EPContext type and the source is QNN or QNNExecutionProvider. +- for (const auto& node : graph_viewer.Nodes()) { +- if (EPCONTEXT_OP == node.OpType()) { +- NodeAttrHelper node_helper(node); +- std::string cache_source = node_helper.Get(SOURCE, ""); +- +- std::transform(cache_source.begin(), +- cache_source.end(), +- cache_source.begin(), +- [](unsigned char c) { return static_cast(std::tolower(c)); }); +- +- if (cache_source == "qnnexecutionprovider" || cache_source == "qnn") { +- return true; ++Status IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, ++ bool& is_qnn_ctx_model) { ++ is_qnn_ctx_model = false; ++ for (const auto& fused_node_graph : fused_nodes_and_graphs) { ++ const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); ++ // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type ++ int count = 0; ++ for (const auto& node : graph_viewer.Nodes()) { ++ if (EPCONTEXT_OP == node.OpType()) { ++ is_qnn_ctx_model = true; + } ++ ++count; + } ++ ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node."); + } +- return false; ++ return Status::OK(); + } + +-bool IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs) { +- for (const auto& fused_node_graph : fused_nodes_and_graphs) { +- const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); +- bool has_qnn_ep_context_node = GraphHasEpContextNode(graph_viewer); +- if (has_qnn_ep_context_node) { ++bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) { ++ // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type ++ for (const auto& node : graph_viewer.Nodes()) { ++ if (EPCONTEXT_OP == node.OpType()) { + return true; + } + } + return false; + } + +-Status GetMainContextNode(const std::vector& fused_nodes_and_graphs, +- QnnBackendManager* qnn_backend_manager, +- const logging::Logger& logger, +- int& main_context_pos, +- std::unordered_map>& qnn_models) { +- main_context_pos = -1; +- for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) { +- const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph); +- const auto& ep_context_node = graph_viewer.Nodes().begin(); +- ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node."); +- qnn_models.emplace(ep_context_node->Name(), +- std::make_unique(logger, qnn_backend_manager)); +- NodeAttrHelper node_helper(*ep_context_node); +- int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast(0)); +- if (1 == is_main_context) { +- main_context_pos = static_cast(i); +- } +- } +- +- ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1"); +- return Status::OK(); +-} +- + Status CreateNodeArgs(const std::vector& names, + const std::unordered_map& tensor_info_table, + std::vector& node_args, +@@ -86,18 +60,32 @@ Status CreateNodeArgs(const std::vector& names, + 
return Status::OK(); + } + +-Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, +- const onnxruntime::PathString& ctx_onnx_model_path, +- QnnBackendManager* qnn_backend_manager, +- std::unordered_map>& qnn_models) { +- ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node."); +- NodeAttrHelper node_helper(main_context_node); ++Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path, ++ QnnBackendManager* qnn_backend_manager, ++ QnnModel& qnn_model, ++ const logging::Logger& logger) { ++ using namespace onnxruntime; ++ std::shared_ptr model; ++ ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); ++ const auto& graph = model->MainGraph(); ++ return GetEpContextFromGraph(GraphViewer(graph), ++ ctx_onnx_model_path, ++ qnn_backend_manager, ++ qnn_model); ++} ++ ++Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, ++ const onnxruntime::PathString& ctx_onnx_model_path, ++ QnnBackendManager* qnn_backend_manager, ++ QnnModel& qnn_model) { ++ const auto& node = graph_viewer.Nodes().begin(); ++ NodeAttrHelper node_helper(*node); + bool is_embed_mode = node_helper.Get(EMBED_MODE, true); + if (is_embed_mode) { + const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); + return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), + static_cast(context_binary.length()), +- qnn_models); ++ qnn_model); + } + + std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); +@@ -145,16 +133,23 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, + cache_file.close(); + return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), + static_cast(buffer_size), +- qnn_models); ++ qnn_model); + } + +-Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, ++Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const onnxruntime::PathString& ctx_onnx_model_path, ++ bool is_qnn_ctx_model, ++ bool is_ctx_cache_file_exist, + QnnBackendManager* qnn_backend_manager, +- std::unordered_map>& qnn_models) { +- Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models); ++ QnnModel& qnn_model, ++ const logging::Logger& logger) { ++ Status status; ++ if (is_qnn_ctx_model) { ++ status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); ++ } else if (is_ctx_cache_file_exist) { ++ status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); ++ } + +- // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model + if (!status.IsOK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. 
", status.ErrorMessage()); + } +@@ -162,37 +157,88 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, + return Status::OK(); + } + +-// Figure out the real context cache file path +-// return true if context cache file exists +-bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, +- const std::string& customer_context_cache_path, +- const onnxruntime::PathString& model_pathstring, +- onnxruntime::PathString& context_cache_path) { +- // always try the path set by user first, it's the only way to set it if load model from memory ++Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, ++ std::string& model_name, ++ std::string& model_description, ++ std::string& graph_partition_name, ++ std::string& cache_source, ++ const logging::Logger& logger) { ++ using namespace onnxruntime; ++ std::shared_ptr model; ++ ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger)); ++ const auto& graph = GraphViewer(model->MainGraph()); ++ const auto& node = graph.Nodes().begin(); ++ NodeAttrHelper node_helper(*node); ++ model_name = graph.Name(); ++ model_description = graph.Description(); ++ graph_partition_name = node_helper.Get(PARTITION_NAME, ""); ++ cache_source = node_helper.Get(SOURCE, ""); ++ ++ return Status::OK(); ++} ++ ++bool IsContextCacheFileExists(const std::string& customer_context_cache_path, ++ const onnxruntime::PathString& model_pathstring, ++ onnxruntime::PathString& context_cache_path) { ++ // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default + if (!customer_context_cache_path.empty()) { + context_cache_path = ToPathString(customer_context_cache_path); +- } else if (!model_pathstring.empty()) { // model loaded from file +- if (is_qnn_ctx_model) { +- // it's a context cache model, just use the model path +- context_cache_path = model_pathstring; +- } else if (!model_pathstring.empty()) { +- // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx +- context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); +- } ++ } else if (!model_pathstring.empty()) { ++ context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); + } + +-Status CreateEPContextNodes(Model* model, +- unsigned char* buffer, +- uint64_t buffer_size, +- const std::string& sdk_build_version, +- const std::vector& fused_nodes_and_graphs, +- const std::unordered_map>& qnn_models, +- const onnxruntime::PathString& context_cache_path, +- bool qnn_context_embed_mode, +- const logging::Logger& logger) { ++Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path, ++ const std::string& model_name, ++ const std::string& model_description, ++ const std::string& graph_partition_name, ++ const logging::Logger& logger) { ++ std::string model_name_from_ctx_cache; ++ std::string model_description_from_ctx_cache; ++ std::string graph_partition_name_from_ctx_cache; ++ std::string cache_source; ++ auto status = GetMetadataFromEpContextModel(context_cache_path, ++ model_name_from_ctx_cache, ++ model_description_from_ctx_cache, ++ graph_partition_name_from_ctx_cache, ++ cache_source, ++ logger); ++ if (!status.IsOK()) { ++ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel."); ++ } ++ ++ // The source attribute from the skeleton onnx file indicate whether 
it's generated from QNN toolchain or ORT ++ if (cache_source != kQnnExecutionProvider) { ++ LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort."; ++ return Status::OK(); ++ } ++ ++ if (model_name != model_name_from_ctx_cache || ++ model_description != model_description_from_ctx_cache || ++ graph_partition_name != graph_partition_name_from_ctx_cache) { ++ std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ", ++ model_name, " ", model_description, " ", graph_partition_name, ++ " vs epcontext: ", ++ model_name_from_ctx_cache, " ", ++ model_description_from_ctx_cache, " ", ++ graph_partition_name_from_ctx_cache); ++ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message); ++ } ++ ++ return Status::OK(); ++} ++ ++Status GenerateCtxCacheOnnxModel(Model* model, ++ unsigned char* buffer, ++ uint64_t buffer_size, ++ const std::string& sdk_build_version, ++ const std::vector& fused_nodes_and_graphs, ++ const std::unordered_map>& qnn_models, ++ const onnxruntime::PathString& context_cache_path, ++ bool qnn_context_embed_mode, ++ const logging::Logger& logger) { + auto& graph = model->MainGraph(); + + using namespace ONNX_NAMESPACE; +diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +index b1360b4e5..ba6fe23ec 100644 +--- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h ++++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +@@ -28,44 +28,59 @@ static const std::string EP_SDK_VER = "ep_sdk_version"; + static const std::string PARTITION_NAME = "partition_name"; + static const std::string SOURCE = "source"; + +-bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer); ++Status IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, ++ bool& is_qnn_ctx_model); + +-bool IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs); +- +-Status GetMainContextNode(const std::vector& fused_nodes_and_graphs, +- QnnBackendManager* qnn_backend_manager, +- const logging::Logger& logger, +- int& main_context_pos, +- std::unordered_map>& qnn_models); ++bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer); + + Status CreateNodeArgs(const std::vector& names, + const std::unordered_map& tensor_info_table, + std::vector& node_args, + onnxruntime::Graph& graph); + +-bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, +- const std::string& customer_context_cache_path, +- const onnxruntime::PathString& model_pathstring, +- onnxruntime::PathString& context_cache_path); ++bool IsContextCacheFileExists(const std::string& customer_context_cache_path, ++ const onnxruntime::PathString& model_pathstring, ++ onnxruntime::PathString& context_cache_path); ++ ++Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path, ++ QnnBackendManager* qnn_backend_manager, ++ QnnModel& qnn_model, ++ const logging::Logger& logger); + +-Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, +- const onnxruntime::PathString& ctx_onnx_model_path, +- QnnBackendManager* qnn_backend_manager, +- std::unordered_map>& qnn_models); ++Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, ++ const onnxruntime::PathString& ctx_onnx_model_path, ++ QnnBackendManager* qnn_backend_manager, ++ QnnModel& qnn_model); + +-Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, ++Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, + const 
onnxruntime::PathString& ctx_onnx_model_path, ++ bool is_qnn_ctx_model, ++ bool is_ctx_cache_file_exist, + QnnBackendManager* qnn_backend_manager, +- std::unordered_map>& qnn_models); +- +-Status CreateEPContextNodes(Model* model, +- unsigned char* buffer, +- uint64_t buffer_size, +- const std::string& sdk_build_version, +- const std::vector& fused_nodes_and_graphs, +- const std::unordered_map>& qnn_models, +- const onnxruntime::PathString& context_cache_path, +- bool qnn_context_embed_mode, +- const logging::Logger& logger); ++ QnnModel& qnn_model, ++ const logging::Logger& logger); ++ ++Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path, ++ const std::string& model_name, ++ const std::string& model_description, ++ const std::string& graph_partition_name, ++ const logging::Logger& logger); ++ ++Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, ++ std::string& model_name, ++ std::string& model_description, ++ std::string& graph_partition_name, ++ std::string& cache_source, ++ const logging::Logger& logger); ++ ++Status GenerateCtxCacheOnnxModel(Model* model, ++ unsigned char* buffer, ++ uint64_t buffer_size, ++ const std::string& sdk_build_version, ++ const std::vector& fused_nodes_and_graphs, ++ const std::unordered_map>& qnn_models, ++ const onnxruntime::PathString& context_cache_path, ++ bool qnn_context_embed_mode, ++ const logging::Logger& logger); + } // namespace qnn + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +index 9849a05db..f4b0d1ff5 100644 +--- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc ++++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +@@ -55,19 +55,6 @@ Status SplitOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + return Status::OK(); + } + +-// Converts an ONNX list of split lengths to a QNN list of split indices. +-// Note that the first split index at 0 is implicit (QNN SDK >= 2.19 will raise a validation error if included). +-static void ConvertSplitLengthsToSplitIndices(gsl::span split_lengths, +- std::vector& split_indices) { +- uint32_t split_it = 0; +- for (size_t i = 0; i < split_lengths.size(); ++i) { +- if (i > 0) { // Do not include the 0th split index. 
+- split_indices.push_back(split_it); +- } +- split_it += SafeInt(split_lengths[i]); +- } +-} +- + Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, +@@ -92,15 +79,22 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr + const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); + size_t tensor_byte_size = unpacked_tensor.size(); + size_t size = tensor_byte_size / sizeof(int64_t); +- ConvertSplitLengthsToSplitIndices({tensor_data, size}, split_index); ++ split_index.push_back(0); // QNN need the start index of each range and starts from 0 ++ std::transform(tensor_data, tensor_data + size, std::back_inserter(split_index), ++ [](int64_t item) { return SafeInt(item); }); ++ split_index.pop_back(); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic split"); + } + } else { + NodeAttrHelper node_helper(node_unit); + if (node_helper.HasAttr("split")) { +- auto split_lengths = node_helper.Get("split", std::vector{0}); +- ConvertSplitLengthsToSplitIndices(split_lengths, split_index); ++ auto split = node_helper.Get("split", std::vector{0}); ++ uint32_t split_it = 0; ++ for (size_t i = 0; i < split.size(); ++i) { ++ split_index.push_back(split_it); ++ split_it += split[i]; ++ } + } + } + +@@ -111,19 +105,11 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr + "Cannot get shape"); + ORT_ENFORCE(static_cast(input_shape.size()) > axis_value, "axis not valid!"); + ORT_RETURN_IF_NOT(input_shape.at(axis_value) > 0, "Shape value not valid!"); +- +- // ONNX spec states that if not evenly divisible by `num_outputs`, the last chunk is smaller. +- // Therefore, we have to use ceil() when computing shape[axis] / num_outputs. 
+- // See: core/providers/cpu/tensor/split.cc::PrepareForCompute() +- const float num_outputs = static_cast(node_unit.Outputs().size()); +- const float split_dim_size = static_cast(input_shape[axis_value]); +- const uint32_t step = SafeInt(std::ceil(split_dim_size / num_outputs)); ++ auto num_outputs = node_unit.Outputs().size(); ++ auto step = SafeInt(input_shape.at(axis_value) / num_outputs); + uint32_t split_it = 0; +- + for (size_t i = 0; i < num_outputs; ++i) { +- if (i > 0) { // 0th split index is implicit (QNN >= 2.19 raises validation error if included) +- split_index.push_back(split_it); +- } ++ split_index.push_back(split_it); + split_it += step; + } + } +diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +index 5f0b87c7c..973b81d33 100644 +--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc ++++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +@@ -517,8 +517,7 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 + return context_buffer; + } + +-Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, +- std::unordered_map>& qnn_models) { ++Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) { + bool result = nullptr == qnn_sys_interface_.systemContextCreate || + nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || + nullptr == qnn_sys_interface_.systemContextFree; +@@ -551,9 +550,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t + graphs_info = binary_info->contextBinaryInfoV2.graphs; + } + +- ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); +- LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count << ", EPContext node count: " << qnn_models.size(); +- ORT_RETURN_IF(graph_count != qnn_models.size(), "Graph count from QNN context not equal to EPContext node count."); ++ ORT_RETURN_IF(graph_count > 1, "Load from Qnn cached context only support 1 sub-graph."); ++ ORT_RETURN_IF(graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); + + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, + "Invalid function pointer for contextCreateFromBinary."); +@@ -573,12 +571,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t + + // More work to support multiple partition, how to map the graph name in compile to qnn graph name + // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile +- for (uint32_t i = 0; i < graph_count; ++i) { +- std::string graph_name(graphs_info[i].graphInfoV1.graphName); +- auto qnn_model_pos = qnn_models.find(graph_name); +- ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); +- ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); +- } ++ ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); + + qnn_sys_interface_.systemContextFree(sys_ctx_handle); + sys_ctx_handle = nullptr; +diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +index 36375522b..f7b8947ab 100644 +--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h ++++ 
b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +@@ -87,8 +87,7 @@ class QnnBackendManager { + + std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); + +- Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, +- std::unordered_map>& qnn_models); ++ Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model); + + Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); + +diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc +index 314cab4a3..869d9326d 100644 +--- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc ++++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc +@@ -97,8 +97,7 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, + std::unordered_map node_unit_map; + std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + +- // This name must be same with the EPContext node name +- const auto& graph_name = fused_node.Name(); ++ const auto& graph_name = graph_viewer.Name(); + ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node)); + + QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_, +diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +index b58f6e10d..0310cc2bc 100644 +--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc ++++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +@@ -129,7 +129,7 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic + + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, + const SessionOptions* session_options) +- : IExecutionProvider{onnxruntime::kQnnExecutionProvider} { ++ : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} { + if (session_options) { + disable_cpu_ep_fallback_ = session_options->config_options.GetConfigOrDefault( + kOrtSessionOptionsDisableCPUEPFallback, "0") == "1"; +@@ -150,7 +150,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio + LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; + + context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); +- LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; + } + + static const std::string BACKEND_PATH = "backend_path"; +@@ -319,27 +318,14 @@ std::unordered_set + QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, + const std::unordered_map& node_unit_map, + const size_t node_unit_size, +- bool is_qnn_ctx_model, ++ bool load_from_cached_context, + const logging::Logger& logger) const { + std::unordered_set supported_nodes{}; +- // Filter in the EPContext node for QNN +- if (is_qnn_ctx_model) { ++ // Enable Qnn context cache requires the whole graph partitioned to Qnn EP ++ // Blindly filter in all nodes if context cache is enabled ++ if (load_from_cached_context) { + for (const auto& node : graph_viewer.Nodes()) { +- NodeAttrHelper node_helper(node); +- std::string cache_source = node_helper.Get(qnn::SOURCE, ""); +- +- std::transform(cache_source.begin(), +- cache_source.end(), +- cache_source.begin(), +- [](unsigned char c) { return static_cast(std::tolower(c)); }); +- +- if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || 
cache_source == "qnn")) { +- LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index() +- << "] name: [" << node.Name() +- << "] Operator type: [EPContext" +- << "] index: [" << node.Index() << "]"; +- supported_nodes.insert(&node); +- } ++ supported_nodes.insert(&node); + } + return supported_nodes; + } +@@ -424,11 +410,22 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer + } + + const auto& logger = *GetLogger(); +- bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer); ++ bool load_from_cached_context = false; ++ bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer); ++ if (is_qnn_ctx_model) { ++ load_from_cached_context = true; ++ } + +- // It will load the QnnSystem lib if is_qnn_ctx_model=true, and +- // delay the Qnn context creation to Compile() using the cached context binary +- auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model); ++ // This is for case: QDQ model + Onnx Qnn context cache model ++ if (context_cache_enabled_ && !is_qnn_ctx_model) { ++ onnxruntime::PathString context_cache_path; ++ load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_, ++ graph_viewer.ModelPath().ToPathString(), ++ context_cache_path); ++ } ++ ++ // Load from cached context will load the QnnSystem lib and skip the Qnn context creation ++ auto rt = qnn_backend_manager_->SetupBackend(logger, load_from_cached_context); + if (Status::OK() != rt) { + LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); + return result; +@@ -446,7 +443,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer + std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); + + const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), +- is_qnn_ctx_model, logger); ++ load_from_cached_context, logger); + + // Helper function that returns a string that lists all unsupported nodes. + // Ex: { name: mul_123, type: Mul }, {}, ... +@@ -475,7 +472,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer + + const auto gen_metadef_name = [&]() { + uint64_t model_hash; +- int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + return MakeString(QNN, "_", model_hash, "_", metadef_id); + }; + +@@ -499,7 +496,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer + if (partition && partition->sub_graph) { + nodes_in_partition = partition->sub_graph->nodes.size(); + +- if (nodes_in_partition == 1 && !is_qnn_ctx_model) { ++ if (nodes_in_partition == 1) { + const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); + + if (!node) { +@@ -519,7 +516,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer + result.push_back(std::move(partition)); + num_of_supported_nodes += nodes_in_partition; + } +- } // for ++ } + } + + const size_t num_of_partitions = result.size(); +@@ -530,7 +527,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer + + // Print list of unsupported nodes to the ERROR logger if the CPU EP + // has been disabled for this inference session. 
+- if (!is_qnn_ctx_model && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) { ++ if (disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) { + LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names(); + } + +@@ -621,76 +618,64 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { + const auto& logger = *GetLogger(); ++ Node& fused_node = fused_nodes_and_graphs[0].fused_node; ++ const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph); + +- bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); ++ bool is_qnn_ctx_model = false; ++ ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); + + onnxruntime::PathString context_cache_path; +- bool is_ctx_file_exist = false; +- if (is_qnn_ctx_model || context_cache_enabled_) { +- const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); +- is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, +- context_cache_path_cfg_, +- graph_viewer_0.ModelPath().ToPathString(), +- context_cache_path); +- } +- +- ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_, +- "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. ", +- "Please remove the EP context model manually if you want to re-generate it."); +- +- if (is_qnn_ctx_model) { +- // Table, the node name is the graph_meta_id (old) created from user model which used to generate the EP context model +- // for this session (created from an EP context model), the graph_meta_id is new +- std::unordered_map> qnn_models; +- +- int main_context_pos = -1; +- ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(), +- logger, main_context_pos, qnn_models)); +- +- const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); +- // Create QNN context from the cached binary, deserialize the QNN graph from the binary +- ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, ++ bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, ++ graph_viewer.ModelPath().ToPathString(), ++ context_cache_path); ++ const std::string& model_name = graph_viewer.GetGraph().Name(); ++ const std::string& model_description = graph_viewer.GetGraph().Description(); ++ const std::string& graph_meta_id = fused_node.Name(); ++ if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) { ++ ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path, ++ model_name, ++ model_description, ++ graph_meta_id, ++ logger)); ++ } ++ ++ if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { ++ ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); ++ std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); ++ // Load and execute from cached context if exist ++ ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer, + context_cache_path, ++ is_qnn_ctx_model, ++ is_ctx_file_exist, + qnn_backend_manager_.get(), +- qnn_models)); +- +- for (auto fused_node_and_graph : fused_nodes_and_graphs) { +- const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); +- const auto& ep_context_node = 
graph_viewer.Nodes().begin(); +- const Node& fused_node = fused_node_and_graph.fused_node; +- const std::string& graph_meta_id = fused_node.Name(); +- std::string key = ep_context_node->Name(); +- ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name not exist in table qnn_models."); +- auto qnn_model = std::move(qnn_models[key]); +- ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); +- ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); +- +- // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] +- // the name here must be same with context->node_name in compute_info +- qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); +- +- ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); +- } ++ *(qnn_model.get()), ++ logger)); ++ ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); ++ ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + ++ // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] ++ // the name here should be same with context->node_name in compute_info ++ qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); ++ ++ ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); + return Status::OK(); + } + + ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); +- // Generate QNN context model if it's QDQ model + context_cache_enabled=true + not exist already +- if (!is_qnn_ctx_model && context_cache_enabled_ && !is_ctx_file_exist) { +- // All partitioned graph share single QNN context, included in the same context binary ++ if (context_cache_enabled_ && !is_qnn_ctx_model) { ++ ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); + uint64_t buffer_size(0); + auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); + qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); +- ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(), +- context_buffer.get(), +- buffer_size, +- qnn_backend_manager_->GetSdkVersion(), +- fused_nodes_and_graphs, +- qnn_models_, +- context_cache_path, +- qnn_context_embed_mode_, +- logger)); ++ ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(), ++ context_buffer.get(), ++ buffer_size, ++ qnn_backend_manager_->GetSdkVersion(), ++ fused_nodes_and_graphs, ++ qnn_models_, ++ context_cache_path, ++ qnn_context_embed_mode_, ++ logger)); + } + return Status::OK(); + } +diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h +index 09bcb24db..3f75be0ef 100644 +--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h ++++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h +@@ -5,7 +5,6 @@ + + #include "core/framework/execution_provider.h" + #include "core/framework/session_options.h" +-#include "core/framework/model_metadef_id_generator.h" + #include "core/graph/model.h" + #include + #include "core/providers/qnn/builder/qnn_backend_manager.h" +@@ -72,7 +71,6 @@ class QNNExecutionProvider : public IExecutionProvider { + bool qnn_context_embed_mode_ = true; + int32_t vtcm_size_in_mb_ = 0; + std::unique_ptr qnn_ep_context_model_; +- ModelMetadefIdGenerator metadef_id_generator_; + }; + + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/rocm/atomic/common.cuh b/onnxruntime/core/providers/rocm/atomic/common.cuh +index b5d01b91c..4e2357020 100644 +--- 
a/onnxruntime/core/providers/rocm/atomic/common.cuh ++++ b/onnxruntime/core/providers/rocm/atomic/common.cuh +@@ -59,304 +59,5 @@ __device__ __forceinline__ void AtomicAdd(T *start_addr, size_t index, const siz + atomic_add(start_addr + index, value); + } + +-// Disable default template instantiation. +-// For every type T, we need to define a specialization +-// to select the right type for calling atomicCAS. +-template +-class AtomicCasType; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned short int; +- static const unsigned int mask = 0xffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned short int; +- static const unsigned int mask = 0xffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned long long int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-template<> +-class AtomicCasType { +- public: +- using type = unsigned long long int; +- static const unsigned int mask = 0xffffffffu; +-}; +- +-// Obtained from pytorch/aten/src/ATen/cuda/Atomic.cuh. +-// +-// This function compute 8-bit atomic binary operation using 32-bit atomicCAS. +-// It accumulate `val` into the `address` using the `func`. +-// The accumulation is atomic (i.e., thread-safe). +-// +-// E.g., Assume ValueType is +-// int8_t +-// and BinaryFunc is +-// struct AddFunc { +-// __device__ __forceinline__ int8_t operator()(int8_t a, int8_t b) const { +-// return a + b; +-// } +-// This function becomes atomic_add for int8_t. +-template +-__device__ __forceinline__ void atomic_byte_func_with_unit32_cas(ValueType* address, ValueType val, BinaryFunc func) { +- // Assert to ensure the following bit-wise manipulation is correct. +- static_assert(sizeof(ValueType) == 1 | sizeof(ValueType) == 2 | sizeof(ValueType) == 4, +- "ValueType must be 1-byte, 2-byte or 4-byte large."); +- // Number of bytes to the lower 4-byte aligned address. +- // If the current address is b1010"10", then offset = b10 = 2, +- // which means the current address is 2 bytes away from +- // the lower 4-byte aligned address b1010"00". +- size_t offset = (size_t)address & 3; +- // Find an new 4-byte aligned address `address_as_ui` lower than +- // or equal to `address`. Lower than `address` so that the actual +- // int8_t byte is in the 4-byte word that we load. +- // +- // This address has the following properties: +- // 1. It is 4-byte aligned. +- // 2. It is lower than or equal to `address`. +- // 3. De-referencing this address may return +- // a uint32_t value that contains the same int8_t +- // value indicated by `address`. +- // +- // E.g., +- // address = b101010 +- // offset = b101010 & b000011 = b10 = 2 +- // (char*)address - offset => (char*)b101010 - b000010 => b1010"00", +- // which is (32-bit aligned). +- uint32_t * address_as_ui = (uint32_t*)((char*)address - offset); +- uint32_t old = *address_as_ui; +- // E.g., offset = 2. +- // address_as_ui is an address 2 bytes lower than `address`. +- // +- // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... 
+- // ^ ^ ^ +- // | | | +- // | address <--- offset * 8 (bit)-----> address_as_ui +- // | ^ +- // | | +- // ------------------------- *address_as_ui ----------------------- +- // +- // This visualization shows +- // 1. the 32-bit word at address_as_ui. +- // 2. the gap between address_as_ui and address. +- // 3. *address_as_ui contains the int8_t value at `address`. +- uint32_t shift = offset * 8; +- uint32_t old_byte; +- uint32_t newval; +- uint32_t assumed; +- do { +- assumed = old; +- // Select 8-bit value from 32-bit word. Assume offset = 2 (byte), so +- // we want to select the 3rd byte (byte 2 below) from the word. +- // +- // Journey of a 32-bit value: +- // +- // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... +- // +- // | +- // | old >> offset * 8, where offset = 2. +- // | Effectively, push lower two bytes +- // | out of the word. +- // V +- // +- // 00000000 | 00000000 | ..... byte 3 ..... | ..... byte 2 ..... +- // +- // | apply bit-wise AND, +- // | & 0xff (i.e., & b11111111), +- // | so that we only keep +- // | the byte of interest. +- // | Otherwise, overflow may +- // | happen when casting this +- // | 32-bit value to int8_t. +- // V +- // +- // 00000000 | 00000000 | 00000000 | ..... byte 2 ..... +- old_byte = (old >> shift) & AtomicCasType::mask; +- // Compute new int8_t value and store it to newrawvalue. +- // Journey of a 32-bit value (cont'd): +- // +- // newrawvalue +- // ... new byte 2 ... +- auto newrawvalue = func(val, reinterpret_cast(old_byte)); +- // Put the new int8_t value back to 32-bit word. +- // Also ensure that bits not occupied by the int8_t value are 0s. +- // +- // Journey of a 32-bit value (cont'd): +- // +- // reinterpret_cast(newrawvalue) +- // random values | random values | random values | ... new byte 2 ... +- // +- // reinterpret_cast(newrawvalue) & AtomicCasType::mask +- // 00000000 | 00000000 | 00000000 | ... new byte 2 ... +- newval = reinterpret_cast(newrawvalue) & AtomicCasType::mask; +- // Journey of a 32-bit value (cont'd): +- // +- // old +- // ..... byte 3 ..... | ..... byte 2 ..... | ..... byte 1 ..... | ..... byte 0 ..... +- // +- // 0x000000ff +- // 00000000 | 00000000 | 00000000 | 11111111 +- // +- // 0x000000ff << shift +- // 00000000 | 11111111 | 00000000 | 00000000 +- // +- // ~(0x000000ff << shift) +- // 11111111 | 00000000 | 11111111 | 11111111 +- // +- // old & ~(0x000000ff << shift) +- // ..... byte 3 ..... | 00000000 | ..... byte 1 ..... | ..... byte 0 ..... +- // +- // newval << shift +- // 00000000 | ... new byte 2 ... | 00000000 | 00000000 +- // +- // (old & ~(0x000000ff << shift)) | (newval << shift) +- // ..... byte 3 ..... | ... new byte 2 ... | ..... byte 1 ..... | ..... byte 0 ..... +- newval = (old & ~(AtomicCasType::mask << shift)) | (newval << shift); +- old = atomicCAS(address_as_ui, assumed, newval); +- } while (assumed != old); +-} +- +-// It accumulates `val` into the `address` using the `func`. +-// This function is thread-safe (i.e., atomic). +-template +-__device__ __forceinline__ void atomic_binary_func(ValueType* address, ValueType val, BinaryFunc func) { +- ValueType observed = *address, assumed, new_value; +- using CasType = typename AtomicCasType::type; +- static_assert(sizeof(ValueType) == sizeof(CasType), +- "ValueType and CasType must have the same size for calling atomicCAS."); +- auto address_as_cas_type = reinterpret_cast(address); +- do { +- // Record the value used to compute new value. +- assumed = observed; +- +- // Compute expected new value. 
+- new_value = func(observed, val); +- +- // Cast to aribitrary 2-byte type to desired integer type supported by atomicCAS. +- // 4 +- // 8 +- auto observed_as_cas_type = *reinterpret_cast(&observed); +- auto new_value_as_cas_type = *reinterpret_cast(&new_value); +- +- // Call atomicCAS as if the 2-byte type variables are all unsigned short int. +- // 4 unsigned int (or int) +- // 8 unsigned long long int +- auto cas_observed_as_cas_type = atomicCAS(address_as_cas_type, observed_as_cas_type, new_value_as_cas_type); +- +- // Cast the freshly observed value in memory back to the TwoByteType. +- observed = *reinterpret_cast(&cas_observed_as_cas_type); +- +- // Two cases: +- // 1. compare-and-swap success +- // a. `address` holds `new_value` +- // b. `observed` becomes the new value after the assignment. +- // Thus, the following `observed != new_value` is false, +- // and the loop terminates. +- // 2. compare-and-swap fails +- // a. `address` holds a value different from `observed`, thus, +- // the `new_value` is stale. +- // b. `observed` becomes the fresh value observed in `address`. +- // Thus, the following (observed != new_value) is true, +- // and the loop continues. In the next iteration, the +- // `new_value` is computed again using the fresh `observed`. +- } while (observed != assumed); +-} +- +-struct AddFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return a + b; +- } +-}; +- +-struct MulFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return a * b; +- } +-}; +- +-struct MaxFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return b > a ? b : a; +- } +-}; +- +-struct MinFunc { +- template +- __device__ __forceinline__ T operator()(T a, T b) const { +- return b < a ? 
b : a; +- } +-}; +- +-__device__ __forceinline__ void atomic_add(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, AddFunc()); +-} +-__device__ __forceinline__ void atomic_mul(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(int8_t* address, int8_t value) { +- atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +-} +- +-__device__ __forceinline__ void atomic_mul(half* address, half value) { +- atomic_byte_func_with_unit32_cas(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(half* address, half value) { +- atomic_byte_func_with_unit32_cas(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(half* address, half value) { +- atomic_byte_func_with_unit32_cas(address, value, MinFunc()); +-} +- +-__device__ __forceinline__ void atomic_mul(float* address, float value) { +- atomic_binary_func(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(float* address, float value) { +- atomic_binary_func(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(float* address, float value) { +- atomic_binary_func(address, value, MinFunc()); +-} +- +-__device__ __forceinline__ void atomic_mul(double* address, double value) { +- atomic_binary_func(address, value, MulFunc()); +-} +-__device__ __forceinline__ void atomic_max(double* address, double value) { +- atomic_binary_func(address, value, MaxFunc()); +-} +-__device__ __forceinline__ void atomic_min(double* address, double value) { +- atomic_binary_func(address, value, MinFunc()); +-} +- +- + } // namespace rocm + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +index ee3578326..d7c5098d9 100644 +--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc ++++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +@@ -170,8 +170,6 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de + + MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); + MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); +- +- hip_graph_.SetStream(stream); + } + + ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { +@@ -179,33 +177,6 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { + ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); + } + +-bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { +- return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; +-} +- +-void ROCMExecutionProvider::PerThreadContext::CaptureBegin() { +- hip_graph_.Reset(); +- hip_graph_.CaptureBegin(); +-} +- +-void ROCMExecutionProvider::PerThreadContext::CaptureEnd() { +- hip_graph_.CaptureEnd(); +- is_graph_captured_ = true; +-} +- +-bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const { +- return is_graph_captured_; +-} +- +-Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() { +- ORT_ENFORCE(IsGraphCaptured()); +- return hip_graph_.Replay(); +-} +- +-void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { +- ++regular_run_count_before_graph_capture_; +-} +- + void 
OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { + if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( + "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); +@@ -248,11 +219,6 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in + if (info.external_allocator_info.UseExternalAllocator()) { + use_ep_level_unified_stream_ = true; + stream_ = nullptr; +- } else if (info.enable_hip_graph) { +- // current hip graph implementation only works with single stream +- // use EP level unified stream for all the reqeust +- HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); +- use_ep_level_unified_stream_ = true; + } else { + stream_ = nullptr; + } +@@ -356,58 +322,25 @@ Status ROCMExecutionProvider::Sync() const { + Status ROCMExecutionProvider::OnRunStart() { + // always set ROCM device when session::Run() in case it runs in a worker thread + HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); +- if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { +- LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; +- GetPerThreadContext().CaptureBegin(); +- } + return Status::OK(); + } + + Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { +- if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { +- if (GetPerThreadContext().IsGraphCaptureAllowed()) { +- GetPerThreadContext().CaptureEnd(); +- // HIP work issued to a capturing stream doesn’t actually run on the GPU, +- // so run the captured graph here to actually execute the work. +- ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); +- } else { +- GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); +- } +- } +- + if (sync_stream) { + HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream_))); + } + +- // The reason of !IsGraphCaptureEnabled(): +- // If hip graph is enabled, the per thread context will not be released +- // because the per thread hip graph needs to be maintained and replayed for +- // the next run. +- // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): +- // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), +- // PerThreadContext won't be created and there is nothing to +- // release. This didn't happen before because we always call +- // GetPerThreadContext in OnRunStart. +- if (!IsGraphCaptureEnabled() && +- PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { ++ // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), ++ // PerThreadContext won't be created and there is nothing to ++ // release. This didn't happen before because we always call ++ // GetPerThreadContext in OnRunStart. 
++ if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { + ReleasePerThreadContext(); + } + + return Status::OK(); + } + +-bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { +- return info_.enable_hip_graph; +-} +- +-bool ROCMExecutionProvider::IsGraphCaptured() const { +- return GetPerThreadContext().IsGraphCaptured(); +-} +- +-Status ROCMExecutionProvider::ReplayGraph() { +- return GetPerThreadContext().ReplayGraph(); +-} +- + namespace rocm { + // opset 1 to 9 + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); +@@ -1069,7 +1002,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDom + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 14, Shape); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Size); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Transpose); +-class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); ++class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterElements); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Slice); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, Slice); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Softmax); +@@ -1158,10 +1091,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND); +-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad); +-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad); +-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad); +-class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, bool, Pad); ++class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Pad); ++class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Pad); ++class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Pad); ++class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, bool, Pad); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, SpaceToDepth); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, DepthToSpace); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int8_t, Sign); +@@ -1290,7 +1223,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, float, LessOrEqual); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); +-class 
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); + + // Opset 17 + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); +@@ -1299,11 +1231,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, MLFloat16, LayerNormalization); + + // Opset 18 +-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Pad); +-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Pad); +-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); +-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); +-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); + class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); + + // Opset 19 +@@ -2005,7 +1932,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, ++ BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +@@ -2094,10 +2021,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, ++ BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +@@ -2226,7 +2153,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +- BuildKernelCreateInfo, + + // Opset 17 + BuildKernelCreateInfo, +@@ -2235,11 +2161,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { + BuildKernelCreateInfo, + + // Opset 18 +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, +- BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Opset 19 +diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h +index 37d5f7b42..c4945b9ac 100644 +--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h ++++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h +@@ -10,7 +10,6 @@ + #include "core/framework/execution_provider.h" + #include "core/platform/ort_mutex.h" + #include "core/providers/rocm/rocm_execution_provider_info.h" +-#include "core/providers/rocm/rocm_graph.h" + #include "core/providers/rocm/rocm_pch.h" + #include "core/providers/rocm/shared_inc/rocm_utils.h" + #include "core/providers/rocm/shared_inc/rocm_call.h" +@@ -74,9 +73,6 @@ class ROCMExecutionProvider : public IExecutionProvider { + + std::unique_ptr GetProfiler() override; + +- bool IsGraphCaptureEnabled() const override; +- bool IsGraphCaptured() const override; +- Status ReplayGraph() override; + void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; + OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + std::vector CreatePreferredAllocators() override; +@@ -85,7 +81,6 @@ 
class ROCMExecutionProvider : public IExecutionProvider { + ROCMExecutionProviderInfo info_; + hipDeviceProp_t device_prop_; + bool external_stream_ = false; +- // only used when set user external stream or hip graph + hipStream_t stream_ = nullptr; + + bool use_ep_level_unified_stream_ = false; +@@ -138,13 +133,6 @@ class ROCMExecutionProvider : public IExecutionProvider { + } + } + +- bool IsGraphCaptureAllowed() const; +- void CaptureBegin(); +- void CaptureEnd(); +- bool IsGraphCaptured() const; +- Status ReplayGraph(); +- void IncrementRegularRunCountBeforeGraphCapture(); +- + private: + rocblas_handle rocblas_handle_ = nullptr; + miopenHandle_t miopen_handle_ = nullptr; +@@ -153,18 +141,6 @@ class ROCMExecutionProvider : public IExecutionProvider { + std::unique_ptr> constant_ones_double_; + std::unique_ptr> constant_ones_half_; + std::unique_ptr> constant_ones_bfloat16_; +- +- // Hip graph with multi threads will be supported in the future, so hip_graph_ +- // is put under PerThreadContext. +- ROCMGraph hip_graph_; +- bool is_graph_captured_ = false; +- int regular_run_count_before_graph_capture_ = 0; +- +- // There is chance that the second regular run allocates GPU memory for causes like: +- // (1) memory pattern is enabled. (2) arena allocation for stream. +- // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs +- // to allocate enough memory in Arena before graph capturing. +- const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. + }; + + using PerThreadContextMap = std::unordered_map>; +diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +index b557f9228..650635c15 100644 +--- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc ++++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +@@ -21,7 +21,6 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; + constexpr const char* kGpuExternalFree = "gpu_external_free"; + constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; + constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace"; +-constexpr const char* kEnableHipGraph = "enable_hip_graph"; + constexpr const char* kTunableOpEnable = "tunable_op_enable"; + constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; + constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; +@@ -85,7 +84,6 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P + info.miopen_conv_exhaustive_search) + .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) + .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) +- .AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph) + .AddValueParser( + rocm::provider_option_names::kTunableOpEnable, + [&info](const std::string& value_str) -> Status { +@@ -123,7 +121,6 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution + {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, + {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, + 
{rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)}, +- {rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)}, + {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, + {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, + {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, +diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +index 2f549cc1a..e35c0cc0a 100644 +--- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h ++++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +@@ -63,8 +63,6 @@ struct ROCMExecutionProviderInfo { + // If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best. + bool miopen_conv_use_max_workspace{true}; + +- bool enable_hip_graph{false}; +- + rocm::TunableOpInfo tunable_op{}; + + static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); +diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +index 88ef66667..4d88c2546 100644 +--- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc ++++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +@@ -185,7 +185,6 @@ struct ROCM_Provider : Provider { + info.has_user_compute_stream = params->has_user_compute_stream != 0; + info.user_compute_stream = params->user_compute_stream; + info.default_memory_arena_cfg = params->default_memory_arena_cfg; +- info.enable_hip_graph = params->enable_hip_graph; + info.tunable_op.enable = params->tunable_op_enable; + info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; + info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; +@@ -216,7 +215,6 @@ struct ROCM_Provider : Provider { + rocm_options.user_compute_stream = internal_options.user_compute_stream; + } + rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; +- rocm_options.enable_hip_graph = internal_options.enable_hip_graph; + rocm_options.tunable_op_enable = internal_options.tunable_op.enable; + rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; + rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; +diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h +index b78279040..53ba4874c 100644 +--- a/onnxruntime/core/providers/shared_library/provider_api.h ++++ b/onnxruntime/core/providers/shared_library/provider_api.h +@@ -95,15 +95,12 @@ enum OperatorStatus : int { + }; + + // onnx Protobuf types (All of these are direct mappings to the onnx types except for the Repeated*Field ones which map to a Repeated*Field type) +-struct int64s; // RepeatedField +-struct float32s; // RepeatedField ++struct int64s; // RepeatedField + struct AttributeProto; + struct GraphProto; + struct ModelProto; + struct NodeProto; + struct SparseTensorProto; +-struct StringStringEntryProto; +-struct StringStringEntryProtos; // RepeatedPtrField + struct TensorProto; + struct TensorProtos; // RepeatedPtrField + struct TensorShapeProto_Dimension; +@@ -116,9 +113,6 @@ 
struct TypeProto_Sequence; + struct TypeProto; + struct ValueInfoProto; + struct ValueInfoProtos; // RepeatedPtrField +-struct InferenceContext; +-class GraphInferencer; +-using InferenceFunction = std::function; + } // namespace ONNX_NAMESPACE + + namespace onnxruntime { +@@ -148,7 +142,7 @@ struct KernelDefBuilder; + struct KernelRegistry; + struct Function; + struct Graph; +-class GraphViewer; ++struct GraphViewer; + enum class DataLayout; + struct Model; + struct Path; +@@ -163,7 +157,6 @@ struct Tensor; + struct SparseTensor; + class TensorSeq; + class SessionState; +-class ModelMetadefIdGenerator; + + class If; + class Loop; +@@ -255,7 +248,6 @@ constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; + constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; + constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; + constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; +-constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; + constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; + constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; + constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; +diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +index da1713587..a3155fe6b 100644 +--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc ++++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +@@ -329,6 +329,10 @@ common::Status IExecutionProvider::Compile(const std::vector& + return g_host->IExecutionProvider__Compile(this, fused_nodes_and_graphs, node_compute_funcs); + } + ++int IExecutionProvider::GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const { ++ return g_host->IExecutionProvider__GenerateMetaDefId(this, graph_viewer, model_hash); ++} ++ + #ifdef USE_TENSORRT + std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name) { + return g_host->CreateCUDAAllocator(device_id, name); +@@ -492,10 +496,6 @@ template <> + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } + template <> + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } +-Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, +- /*out*/ std::vector& unpacked_tensor) { +- return g_host->UnpackInitializerData(tensor, model_path, unpacked_tensor); +-} + + } // namespace utils + +@@ -547,14 +547,7 @@ Status ScatterND::ValidateShapes(const TensorShape& input_shape, + const TensorShape& indice_shape, + const TensorShape& update_shape) { return g_host_cpu.ScatterNDBase__ValidateShapes(input_shape, indice_shape, update_shape); } + +-Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, const TensorShape& output_shape) { +- return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); +-} +- +-void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, +- PadsVector& pads) { +- 
g_host_cpu.PadBase__ComputePads(ctx, data_rank, pads_data, pads); +-} ++Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_shape, TensorShape& output_shape) { return g_host_cpu.PadBase__HandleDimValueZero(mode, input_shape, output_shape); } + + Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const ConcatBase::InlinedTensorsVector& input_tensors, + Prepare& p) const { +diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h +index f5a832744..21c14ce78 100644 +--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h ++++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h +@@ -91,7 +91,6 @@ using HashValue = uint64_t; + using NodeIndex = size_t; + // We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types) + // using NodeAttributes = std::unordered_map; +-using ModelMetaData = std::unordered_map; + + using InitializedTensorSet = std::unordered_map; + +@@ -202,8 +201,6 @@ struct ProviderHost { + virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) = 0; + virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) = 0; + virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) = 0; +- virtual Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, +- /*out*/ std::vector& unpacked_tensor) = 0; + + virtual uint16_t math__floatToHalf(float f) = 0; + virtual float math__halfToFloat(uint16_t h) = 0; +@@ -232,6 +229,8 @@ struct ProviderHost { + + virtual common::Status IExecutionProvider__Compile(IExecutionProvider* p, const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) = 0; + ++ virtual int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) = 0; ++ + // Status + virtual std::string Status__ToString(const Status* p) = 0; + +@@ -264,32 +263,12 @@ struct ProviderHost { + virtual void logging__Capture__operator_delete(logging::Capture* p) noexcept = 0; + virtual std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept = 0; + +- // Env +- virtual Env& Env__Default() = 0; +- + // Utils::DataTypeUtils + virtual const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) = 0; + + // int64s + virtual int int64s__size(const ONNX_NAMESPACE::int64s* p) = 0; + virtual const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) = 0; +- virtual void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) = 0; +- virtual const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) = 0; +- +- // float32s +- virtual void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) = 0; +- virtual const float* float32s__data(const ONNX_NAMESPACE::float32s* p) = 0; +- virtual int float32s__size(const ONNX_NAMESPACE::float32s* p) = 0; +- +- // StringStringEntryProto +- virtual std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; +- virtual std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; +- +- // 
StringStringEntryProtos +- virtual void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; +- virtual ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; +- virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; +- virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0; + + #if !defined(DISABLE_OPTIONAL_TYPE) + // TypeProto_Optional +@@ -306,7 +285,6 @@ struct ProviderHost { + virtual const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; + virtual ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; + virtual int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; +- virtual void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) = 0; + + #if !defined(DISABLE_SPARSE_TENSORS) + // TypeProto_SparseTensor +@@ -351,17 +329,9 @@ struct ProviderHost { + virtual float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; + virtual const ::std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; + virtual const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) = 0; +- virtual const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) = 0; +- virtual ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) = 0; +- virtual ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) = 0; +- virtual void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t size) = 0; +- virtual void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float size) = 0; +- virtual void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& size) = 0; + virtual int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0; +- virtual const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; +- virtual void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) = 0; + virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0; + virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; +@@ -384,7 +354,6 @@ struct ProviderHost { + virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0; + virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0; + virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0; +- virtual std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) = 0; + virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0; + + // ModelProto +@@ -400,7 +369,6 @@ struct ProviderHost { + virtual ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) = 0; + 
+ virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0; +- virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0; + + // NodeProto + virtual std::unique_ptr NodeProto__construct() = 0; +@@ -415,33 +383,19 @@ struct ProviderHost { + virtual void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0; + virtual bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) = 0; +- virtual const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) = 0; + virtual bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) = 0; + virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0; +- virtual ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) = 0; +- virtual void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) = 0; + + virtual bool TensorProto_DataType_IsValid(int value) = 0; + + // TensorProtos + virtual ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) = 0; +- virtual int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) = 0; +- virtual ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) = 0; + + // TensorShapeProto_Dimension + virtual int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; +@@ -451,8 +405,6 @@ struct ProviderHost { + virtual bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; + virtual bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; + virtual void TensorShapeProto_Dimension__clear_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; +- virtual const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const = 0; +- virtual void 
TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) = 0; + + // TensorShapeProto_Dimensions + virtual std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; +@@ -476,8 +428,6 @@ struct ProviderHost { + + virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + +- virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; +- + // ConfigOptions + virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; + +@@ -703,7 +653,6 @@ struct ProviderHost { + virtual void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) = 0; + + virtual const NodeAttributes& Node__GetAttributes(const Node* p) noexcept = 0; +- virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) = 0; + virtual size_t Node__GetInputEdgesCount(const Node* p) noexcept = 0; + virtual size_t Node__GetOutputEdgesCount(const Node* p) noexcept = 0; + +@@ -713,13 +662,10 @@ struct ProviderHost { + virtual std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept = 0; + +- virtual std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept = 0; +- virtual std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept = 0; + + virtual void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) = 0; +- virtual int Node__NodeType(const Node* p) const noexcept = 0; + virtual const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) = 0; + virtual std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const = 0; + +@@ -730,7 +676,6 @@ struct ProviderHost { + virtual const ONNX_NAMESPACE::NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept = 0; + virtual bool NodeArg__Exists(const NodeArg* p) const noexcept = 0; + virtual const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept = 0; +- virtual Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) = 0; + + // NodeAttributes + virtual std::unique_ptr NodeAttributes__construct() = 0; +@@ -748,18 +693,12 @@ struct ProviderHost { + virtual std::unique_ptr NodeAttributes__find(const NodeAttributes* p, const std::string& key) = 0; + virtual void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) = 0; + virtual void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; +- virtual void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; + virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0; + + // Model +- virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, +- const PathString& model_path, const logging::Logger& logger) = 0; + virtual void Model__operator_delete(Model* p) = 0; + virtual Graph& Model__MainGraph(Model* p) = 0; + virtual std::unique_ptr Model__ToProto(Model* p) = 0; +- 
virtual std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) = 0; +- virtual const ModelMetaData& Model__MetaData(const Model* p) const noexcept = 0; +- virtual Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) = 0; + + // Graph + virtual std::unique_ptr Graph__CreateGraphViewer(const Graph* p) = 0; +@@ -777,7 +716,6 @@ struct ProviderHost { + virtual void Graph__SetOutputs(Graph* p, gsl::span outputs) = 0; + + virtual const std::vector& Graph__GetInputs(const Graph* p) noexcept = 0; +- virtual std::vector Graph__Nodes(const Graph* p) = 0; + virtual bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) = 0; + + virtual const Node* Graph__ParentNode(const Graph* p) const = 0; +@@ -787,26 +725,6 @@ struct ProviderHost { + virtual const Path& Graph__ModelPath(const Graph* p) const = 0; + virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0; + virtual bool Graph__IsSubgraph(const Graph* p) = 0; +- virtual const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const = 0; +- virtual const Model& Graph__GetModel(const Graph* p) = 0; +- virtual void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, +- const std::function& enter, +- const std::function& leave, +- const std::function& comp, +- const std::function& stop) const = 0; +- virtual Graph& Graph__SetGraphResolveNeeded(Graph* p) = 0; +- virtual void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) = 0; +- +- virtual std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const = 0; +- virtual void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, +- int dst_arg_index) = 0; +- virtual void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, +- int dst_arg_index) = 0; +- virtual void Graph__RemoveNode(Graph* p, NodeIndex index) = 0; +- virtual Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) = 0; +- virtual void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) = 0; +- virtual const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const = 0; +- virtual const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) = 0; + virtual int Graph__MaxNodeIndex(const Graph* p) const noexcept = 0; + virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0; + virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0; +@@ -841,14 +759,11 @@ struct ProviderHost { + virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; + + virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; +- virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; + + // Path + virtual PathString Path__ToPathString(const Path* p) noexcept = 0; + virtual const std::vector& Path__GetComponents(const Path* p) noexcept = 0; + virtual bool Path__IsEmpty(const Path* p) noexcept = 0; +- 
virtual std::unique_ptr Path__construct() = 0; +- virtual void Path__operator_delete(ONNX_NAMESPACE::Path* p) = 0; + + // OpKernel + virtual const Node& OpKernel__Node(const OpKernel* p) = 0; +@@ -1057,11 +972,6 @@ struct ProviderHost { + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + virtual Status LoadDynamicLibrary(onnxruntime::PathString library_name) = 0; + #endif +- +- // ModelMetadefIdGenerator +- virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; +- virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; +- virtual int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) = 0; + }; + + #if defined(_MSC_VER) && !defined(__clang__) +diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +index dde4005c8..eaf8ef459 100644 +--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h ++++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +@@ -52,34 +52,11 @@ namespace ONNX_NAMESPACE { + struct int64s final { + int size() const { return g_host->int64s__size(this); } + const int64_t& Get(int index) const { return g_host->int64s__Get(this, index); } +- const int64_t* data() const { return g_host->int64s__data(this); } + const int64_t& operator[](int index) const { return Get(index); } +- void Reserve(int size) { g_host->int64s__Reserve(this, size); } +- PROVIDER_DISALLOW_ALL(int64s) +-}; +- +-struct float32s final { +- void Reserve(int size) { g_host->float32s__Reserve(this, size); } +- const float* data() const { return g_host->float32s__data(this); } +- int size() const { return g_host->float32s__size(this); } +- PROVIDER_DISALLOW_ALL(float32s) +-}; + +-struct StringStringEntryProto final { +- std::string* mutable_key() { return g_host->StringStringEntryProto__mutable_key(this); } +- std::string* mutable_value() { return g_host->StringStringEntryProto__mutable_value(this); } +- +- PROVIDER_DISALLOW_ALL(StringStringEntryProto) ++ PROVIDER_DISALLOW_ALL(int64s) + }; + +-struct StringStringEntryProtos final { +- void Clear() { g_host->StringStringEntryProtos__Clear(this); } +- StringStringEntryProto* Add() { return g_host->StringStringEntryProtos__Add(this); } +- int size() { return g_host->StringStringEntryProtos__size(this); } +- StringStringEntryProto& at(int index) { return g_host->StringStringEntryProtos__at(this, index); } +- +- PROVIDER_DISALLOW_ALL(StringStringEntryProtos) +-}; + struct AttributeProto final { + static std::unique_ptr Create() { return g_host->AttributeProto__construct(); } + void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); } +@@ -94,18 +71,9 @@ struct AttributeProto final { + float floats(int i) const { return g_host->AttributeProto__floats(this, i); } + const std::string& strings(int i) const { return g_host->AttributeProto__strings(this, i); } + const int64s& ints() const { return g_host->AttributeProto__ints(this); } +- const float32s& floats() const { return g_host->AttributeProto__floats(this); } +- int64s* mutable_ints() { return g_host->AttributeProto__mutable_ints(this); } +- float32s* mutable_floats() { return g_host->AttributeProto__mutable_floats(this); } +- void add_ints(int64_t value) { g_host->AttributeProto__add_ints(this, value); } +- void add_floats(float value) { g_host->AttributeProto__add_floats(this, value); } +- void add_strings(const 
::std::string& value) { g_host->AttributeProto__add_strings(this, value); } +- + int64_t i() const { return g_host->AttributeProto__i(this); } + float f() const { return g_host->AttributeProto__f(this); } +- const ONNX_NAMESPACE::TensorProto& t() const { return g_host->AttributeProto__t(this); } + void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); } +- void set_f(const float& value) { return g_host->AttributeProto__set_f(this, value); } + void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); } + const ::std::string& s() const { return g_host->AttributeProto__s(this); } + void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); } +@@ -153,8 +121,6 @@ struct GraphProto final { + NodeProto* add_node() { return g_host->GraphProto__add_node(this); } + NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); } + +- std::string* mutable_name() { return g_host->GraphProto__mutable_name(this); } +- + GraphProto() = delete; + GraphProto(const GraphProto&) = delete; + }; +@@ -167,7 +133,7 @@ struct ModelProto final { + bool SerializeToOstream(std::ostream& output) const { return g_host->ModelProto__SerializeToOstream(this, output); } + bool ParseFromString(const std::string& data) { return g_host->ModelProto__ParseFromString(this, data); } + std::string SerializeAsString() const { return g_host->ModelProto__SerializeAsString(this); } +- StringStringEntryProtos* mutable_metadata_props() { return g_host->ModelProto__mutable_metadata_props(this); }; ++ + const GraphProto& graph() const { return g_host->ModelProto__graph(this); } + GraphProto* mutable_graph() { return g_host->ModelProto__mutable_graph(this); } + +@@ -196,22 +162,17 @@ struct TensorProto final { + void operator=(const TensorProto& v) { g_host->TensorProto__operator_assign(this, v); } + + bool has_name() const { return g_host->TensorProto__has_name(this); } +- void set_name(const ::std::string& name) { return g_host->TensorProto__set_name(this, name); } +- const ::std::string& name() const { return g_host->TensorProto__name(this); } + + int dims_size() const { return g_host->TensorProto__dims_size(this); } + const int64s& dims() const { return g_host->TensorProto__dims(this); } +- void add_dims(int64_t value) { g_host->TensorProto__add_dims(this, value); } + + bool has_data_location() const { return g_host->TensorProto__has_data_location(this); } + TensorProto_DataLocation data_location() const { return TensorProto_DataLocation(g_host->TensorProto__data_location(this)); } + + bool has_raw_data() const { return g_host->TensorProto__has_raw_data(this); } + const std::string& raw_data() const { return g_host->TensorProto__raw_data(this); } +- std::string* mutable_raw_data() { return g_host->TensorProto__mutable_raw_data(this); } + + int32_t data_type() const { return g_host->TensorProto__data_type(this); } +- void set_data_type(int32_t type) { return g_host->TensorProto__set_data_type(this, type); } + + typedef TensorProto_DataType DataType; + static constexpr DataType UNDEFINED = TensorProto_DataType_UNDEFINED; +@@ -219,13 +180,6 @@ struct TensorProto final { + static bool DataType_IsValid(int value) { return g_host->TensorProto_DataType_IsValid(value); } + + void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); } +- StringStringEntryProtos* mutable_external_data() { return g_host->TensorProto__mutable_external_data(this); }; +- void clear_float_data() { return 
g_host->TensorProto__clear_float_data(this); } +- void clear_int32_data() { return g_host->TensorProto__clear_int32_data(this); } +- void clear_string_data() { return g_host->TensorProto__clear_string_data(this); } +- void clear_int64_data() { return g_host->TensorProto__clear_int64_data(this); } +- void clear_double_data() { return g_host->TensorProto__clear_double_data(this); } +- void clear_uint64_data() { return g_host->TensorProto__clear_uint64_data(this); } + + TensorProto() = delete; + TensorProto(const TensorProto&) = delete; +@@ -233,8 +187,6 @@ struct TensorProto final { + + struct TensorProtos final { + TensorProto* Add() { return g_host->TensorProtos__Add(this); } +- int size() { return g_host->TensorProtos__size(this); } +- TensorProto& at(int index) { return g_host->TensorProtos__at(this, index); } + + PROVIDER_DISALLOW_ALL(TensorProtos) + }; +@@ -253,8 +205,6 @@ struct TensorShapeProto_Dimension final { + bool has_dim_value() const { return g_host->TensorShapeProto_Dimension__has_dim_value(this); } + bool has_dim_param() const { return g_host->TensorShapeProto_Dimension__has_dim_param(this); } + void clear_dim_value() { return g_host->TensorShapeProto_Dimension__clear_dim_value(this); } +- const std::string& denotation() const { return g_host->TensorShapeProto_Dimension__denotation(this); } +- void set_denotation(const std::string& value) { g_host->TensorShapeProto_Dimension__set_denotation(this, value); } + + PROVIDER_DISALLOW_ALL(TensorShapeProto_Dimension) + }; +@@ -282,7 +232,6 @@ struct TypeProto_Tensor final { + const TensorShapeProto& shape() const { return g_host->TypeProto_Tensor__shape(this); } + TensorShapeProto* mutable_shape() { return g_host->TypeProto_Tensor__mutable_shape(this); } + int32_t elem_type() const { return g_host->TypeProto_Tensor__elem_type(this); } +- void set_elem_type(int32_t value) { g_host->TypeProto_Tensor__set_elem_type(this, value); } + + PROVIDER_DISALLOW_ALL(TypeProto_Tensor) + }; +@@ -366,6 +315,7 @@ struct ValueInfoProtos final { + + PROVIDER_DISALLOW_ALL(ValueInfoProtos) + }; ++ + } // namespace ONNX_NAMESPACE + + namespace onnxruntime { +@@ -653,10 +603,6 @@ struct Function final { + }; + + struct Node final { +- enum class Type { +- Primitive = 0, +- Fused = 1, +- }; + const std::string& Name() const noexcept { return g_host->Node__Name(this); } + const std::string& Description() const noexcept { return g_host->Node__Description(this); } + const std::string& Domain() const noexcept { return g_host->Node__Domain(this); } +@@ -680,10 +626,6 @@ struct Node final { + void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const { return g_host->Node__ToProto(this, proto, update_subgraphs); } + + const NodeAttributes& GetAttributes() const noexcept { return g_host->Node__GetAttributes(this); } +- void AddAttribute(const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) { +- g_host->Node__AddAttribute(this, attr_name, value); +- } +- + size_t GetInputEdgesCount() const noexcept { return g_host->Node__GetInputEdgesCount(this); } + size_t GetOutputEdgesCount() const noexcept { return g_host->Node__GetOutputEdgesCount(this); } + +@@ -719,15 +661,12 @@ struct Node final { + std::unique_ptr impl_; + }; + +- EdgeConstIterator InputEdgesBegin() const noexcept { return g_host->Node__InputEdgesBegin(this); } +- EdgeConstIterator InputEdgesEnd() const noexcept { return g_host->Node__InputEdgesEnd(this); } + EdgeConstIterator OutputEdgesBegin() const noexcept { return g_host->Node__OutputEdgesBegin(this); } + 
EdgeConstIterator OutputEdgesEnd() const noexcept { return g_host->Node__OutputEdgesEnd(this); } + + void ForEachDef(std::function func, bool include_missing_optional_defs = false) const { g_host->Node__ForEachDef(this, func, std::move(include_missing_optional_defs)); } + const std::unordered_map>& GetAttributeNameToMutableSubgraphMap() { return g_host->Node__GetAttributeNameToMutableSubgraphMap(this); } + std::unordered_map> GetAttributeNameToSubgraphMap() const { return g_host->Node__GetAttributeNameToSubgraphMap(this); } +- Type NodeType() const noexcept { return Type(g_host->Node__NodeType(this)); } + + PROVIDER_DISALLOW_ALL(Node) + }; +@@ -739,7 +678,6 @@ struct NodeArg final { + const NodeArgInfo& ToProto() const noexcept { return g_host->NodeArg__ToProto(this); } + bool Exists() const noexcept { return g_host->NodeArg__Exists(this); } + const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept { return g_host->NodeArg__TypeAsProto(this); } +- Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) { return g_host->NodeArg__OverrideTypesHelper(this, input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); } + + PROVIDER_DISALLOW_ALL(NodeArg) + }; +@@ -760,8 +698,6 @@ struct NodeAttributes final { + IteratorHolder> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); } + void insert(const NodeAttributes& v) { return g_host->NodeAttributes__insert(this, v); } + void emplace(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__emplace(this, k, v); } +- void insert_or_assign(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__insert_or_assign(this, k, v); } +- + void reserve(size_t size) { g_host->NodeAttributes__reserve(this, size); } + + NodeAttributes() = delete; +@@ -769,18 +705,11 @@ struct NodeAttributes final { + }; + + struct Model final { +- static std::unique_ptr Create(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, +- const logging::Logger& logger) { +- return g_host->Model__construct(std::move(model_proto), model_path, logger); +- } + static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast(p)); } +- static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); } + + Graph& MainGraph() { return g_host->Model__MainGraph(this); } + + std::unique_ptr ToProto() { return g_host->Model__ToProto(this); } +- std::unique_ptr ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) { return g_host->Model__ToGraphProtoWithExternalInitializers(this, external_file_name, file_path, initializer_size_threshold); } +- const ModelMetaData& MetaData() const noexcept { return g_host->Model__MetaData(this); } + + Model() = delete; + Model(const Model&) = delete; +@@ -803,7 +732,6 @@ struct Graph final { + void SetOutputs(gsl::span outputs) { return g_host->Graph__SetOutputs(this, outputs); } + + const std::vector& GetInputs() const noexcept { return g_host->Graph__GetInputs(this); } +- std::vector Nodes() const noexcept { return g_host->Graph__Nodes(this); } + + bool GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { return g_host->Graph__GetInitializedTensor(this, tensor_name, value); } + 
+@@ -814,37 +742,6 @@ struct Graph final { + const Path& ModelPath() const { return g_host->Graph__ModelPath(this); } + const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); } + bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); } +- const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->Graph__GetProducerNode(this, node_arg_name); } +- const Model& GetModel() const { return g_host->Graph__GetModel(this); } +- void ReverseDFSFrom(gsl::span from, const std::function& enter, +- const std::function& leave, +- const std::function& comp, +- const std::function& stop) const { +- g_host->Graph__ReverseDFSFrom(this, from, enter, leave, comp, stop); +- } +- Graph& SetGraphResolveNeeded() { return g_host->Graph__SetGraphResolveNeeded(this); } +- void RemoveInitializedTensor(const std::string& tensor_name) { g_host->Graph__RemoveInitializedTensor(this, tensor_name); } +- +- std::vector GetConsumerNodes(const std::string& node_arg_name) const { +- return g_host->Graph__GetConsumerNodes(this, node_arg_name); +- } +- void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { +- g_host->Graph__AddEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); +- } +- void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { +- g_host->Graph__RemoveEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); +- } +- void RemoveNode(NodeIndex index) { g_host->Graph__RemoveNode(this, index); } +- Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { +- return g_host->Graph__FuseSubGraph(this, sub_graph, fused_node_name); +- } +- void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) { +- g_host->Graph__UpdateProducerNode(this, node_arg_name, node_index); +- } +- const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const { +- return g_host->Graph__GetConstantInitializer(this, name, check_outer_scope); +- } +- const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return g_host->Graph__GetAllInitializedTensors(this); } + int MaxNodeIndex() const noexcept { return g_host->Graph__MaxNodeIndex(this); } + const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); } + Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); } +@@ -853,8 +750,7 @@ struct Graph final { + PROVIDER_DISALLOW_ALL(Graph) + }; + +-class GraphViewer final { +- public: ++struct GraphViewer final { + static void operator delete(void* p) { g_host->GraphViewer__operator_delete(reinterpret_cast(p)); } + + std::unique_ptr CreateModel(const logging::Logger& logger) const { return g_host->GraphViewer__CreateModel(this, logger); } +@@ -886,7 +782,6 @@ class GraphViewer final { + const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } + + void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } +- const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } + + GraphViewer() = delete; 
+ GraphViewer(const GraphViewer&) = delete; +@@ -894,16 +789,11 @@ class GraphViewer final { + }; + + struct Path final { +- static std::unique_ptr Create() { return g_host->Path__construct(); } +- static void operator delete(void* p) { g_host->Path__operator_delete(reinterpret_cast(p)); } +- + PathString ToPathString() const noexcept { return g_host->Path__ToPathString(this); } + const std::vector& GetComponents() const noexcept { return g_host->Path__GetComponents(this); } + bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); } + +- Path() = delete; +- Path(const Path&) = delete; +- void operator=(const Path&) = delete; ++ PROVIDER_DISALLOW_ALL(Path) + }; + + struct OpKernelContext final { +@@ -1262,13 +1152,6 @@ class TensorSeq final { + void Reserve(size_t capacity) { g_host->TensorSeq__Reserve(this, capacity); } + }; + +-class ModelMetadefIdGenerator { +- public: +- static std::unique_ptr Create() { return g_host->ModelMetadefIdGenerator__construct(); } +- static void operator delete(void* p) { g_host->ModelMetadefIdGenerator__operator_delete(reinterpret_cast(p)); } +- int GenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) const { return g_host->ModelMetadefIdGenerator__GenerateId(this, graph_viewer, model_hash); } +-}; +- + template <> + inline gsl::span Tensor::DataAsSpan() const { return g_host->Tensor__DataAsSpan_int64(this); } + +diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +index c0bf29e48..1c9340dfd 100644 +--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc ++++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +@@ -1310,7 +1310,7 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh + } + + TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) +- : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info), device_id_(info.device_id) { ++ : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, info_(info), device_id_(info.device_id) { + InitProviderOrtApi(); + + CUDA_CALL_THROW(cudaSetDevice(device_id_)); +diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +index 92cce0c20..a8e3ae3dd 100644 +--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h ++++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h +@@ -497,15 +497,7 @@ void RemoveCachesByType(const std::string& root, std::string file_extension) { + } + } + +-/** +- * +- * Helper class to generate engine id via model name/model content/env metadata +- * +- * +- * The TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches +- * compiled kernels, so the name must be unique and deterministic across models and sessions. 
+- * +- */ ++// Helper class to generate engine id via model name/model content/env metadata + HashValue TRTGenerateId(const GraphViewer& graph_viewer) { + HashValue model_hash = 0; + +diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +index 1392ecef1..29bc886fb 100644 +--- a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc ++++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +@@ -2,106 +2,126 @@ + // Licensed under the MIT License. + #include "./attr_proto.h" + ++#include "./vai_assert.h" ++ + #include + #include + #include + #include + +-#include "core/providers/shared_library/provider_api.h" +- +-#include "./vai_assert.h" +- + namespace vaip { +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value) { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ++ ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, ++ int64_t value) { ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); ++ ret->set_type(onnx::AttributeProto_AttributeType_INT); + ret->set_i(value); +- return ret.release(); ++ return ret; + } +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value) { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, ++ float value) { ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); ++ ret->set_type(onnx::AttributeProto_AttributeType_FLOAT); + ret->set_f(value); +- return ret.release(); ++ return ret; + } +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value) { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_string( ++ const std::string& name, const std::string& value) { ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); ++ ret->set_type(onnx::AttributeProto_AttributeType_STRING); + ret->set_s(value); +- return ret.release(); ++ return ret; + } + ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( + const std::string& name, const ONNX_NAMESPACE::TensorProto& value) { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR); +- *ret->add_tensors() = value; +- return ret.release(); ++ ret->set_type(onnx::AttributeProto_AttributeType_TENSOR); ++ *ret->mutable_t() = value; ++ return ret; + } +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value) { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( ++ const std::string& name, const std::vector& value) { ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS); ++ ret->set_type(onnx::AttributeProto_AttributeType_INTS); + ret->mutable_ints()->Reserve((int)value.size()); + for (auto v : value) { + ret->add_ints(v); + } +- return ret.release(); ++ return ret; + } ++ + ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( + const std::string& name, const std::vector& value) { +- auto ret = 
ONNX_NAMESPACE::AttributeProto::Create(); ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS); ++ ret->set_type(onnx::AttributeProto_AttributeType_FLOATS); + ret->mutable_floats()->Reserve((int)value.size()); + for (auto v : value) { + ret->add_floats(v); + } +- return ret.release(); ++ return ret; + } +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value) { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ++ ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( ++ const std::string& name, const std::vector& value) { ++ auto ret = new onnx::AttributeProto(); + ret->set_name(name); +- ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS); ++ ret->set_type(onnx::AttributeProto_AttributeType_STRINGS); ++ ret->mutable_strings()->Reserve((int)value.size()); + for (auto& v : value) { + ret->add_strings(v); + } +- return ret.release(); ++ return ret; + } +-int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT, attr.name()); ++ ++int64_t attr_proto_get_int(const onnx::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INT, attr.DebugString()); + return attr.i(); + } +-float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT, attr.name()); ++ ++float attr_proto_get_float(const onnx::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOAT, attr.DebugString()); + return attr.f(); + } +-const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRING, attr.name()); ++ ++const std::string& attr_proto_get_string(const onnx::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRING, attr.DebugString()); + return attr.s(); + } +-const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR, attr.name()); ++ ++const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( ++ const onnx::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_TENSOR, attr.DebugString()); + return attr.t(); + } +-gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INTS, attr.name()); ++ ++gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INTS, attr.DebugString()); + return gsl::span(attr.ints()); + } +-gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS, attr.name()); ++ ++gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOATS, attr.DebugString()); + return gsl::span(attr.floats()); + } +-std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr) { +- vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS, attr.name()); +- std::vector ret; +- ret.reserve(attr.strings_size()); +- for (int i = 0; i < 
attr.strings_size(); i++) { +- ret.push_back(attr.strings(i)); +- } ++ ++std::vector attr_proto_get_strings( ++ const ONNX_NAMESPACE::AttributeProto& attr) { ++ vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRINGS, attr.DebugString()); ++ return std::vector(attr.strings().begin(), attr.strings().end()); ++} ++ ++ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, ++ int64_t value) { ++ ONNX_NAMESPACE::AttributeProto ret; ++ ret.set_name(name); ++ ret.set_i(value); + return ret; + } ++ + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h +index f4d56dd61..32ba8fa67 100644 +--- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h ++++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h +@@ -2,26 +2,46 @@ + // Licensed under the MIT License. + #pragma once + #include +-#include "vaip/my_ort.h" ++ + #include "core/common/gsl.h" ++#include "onnx/onnx_pb.h" + + namespace vaip { + +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value); +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value); +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value); +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor(const std::string& name, const ONNX_NAMESPACE::TensorProto& value); +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value); +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats(const std::string& name, const std::vector& value); +-ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, ++ int64_t value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, ++ float value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, ++ const std::string& value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( ++ const std::string& name, const ONNX_NAMESPACE::TensorProto& value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( ++ const std::string& name, const std::vector& value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( ++ const std::string& name, const std::vector& value); ++ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( ++ const std::string& name, const std::vector& value); + + /// attr_proto getters + int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr); + float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr); +-const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr); +-const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr); +-gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr); +-gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr); +-std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr); ++const std::string& attr_proto_get_string( ++ const ONNX_NAMESPACE::AttributeProto& attr); ++ ++const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( ++ const onnx::AttributeProto& attr); ++gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr); ++gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr); ++std::vector attr_proto_get_strings( ++ const ONNX_NAMESPACE::AttributeProto& attr); ++ 
++/// attr_proto makers ++ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, ++ int64_t); ++ ++/// ++using attr_proto_func_t = std::function; + + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/imp/capability.cc b/onnxruntime/core/providers/vitisai/imp/capability.cc +index 58522a45a..a55180bd2 100644 +--- a/onnxruntime/core/providers/vitisai/imp/capability.cc ++++ b/onnxruntime/core/providers/vitisai/imp/capability.cc +@@ -3,10 +3,15 @@ + #include "vaip/capability.h" + #include "./vai_assert.h" + ++#include "core/graph/basic_types.h" ++ ++#include "./attr_proto.h" ++ + namespace vaip { + using namespace ::onnxruntime; + +-static std::vector node_names_to_nodes(const GraphViewer& graph, const std::vector& node_names) { ++static std::vector node_names_to_nodes(const GraphViewer& graph, ++ const std::vector& node_names) { + auto ret = std::vector(); + ret.reserve(node_names.size()); + for (auto& onnx_node_name : node_names) { +@@ -19,45 +24,53 @@ static std::vector node_names_to_nodes(const GraphViewer& graph, cons + } + + std::unique_ptr XirSubgraphToComputeCapability1(const onnxruntime::GraphViewer& graph, vaip_core::ExecutionProvider* ep, size_t index) { +- auto meta_def = IndexedSubGraph_MetaDef::Create(); +- meta_def->constant_initializers() = *ep->get_meta_def_constant_initializer(); +- meta_def->inputs() = *ep->get_meta_def_inputs(); +- meta_def->outputs() = *ep->get_meta_def_outputs(); +- auto indexed_subgraph = IndexedSubGraph::Create(); +- indexed_subgraph->Nodes() = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); ++ auto meta_def = std::make_unique(); ++ meta_def->constant_initializers = *ep->get_meta_def_constant_initializer(); ++ meta_def->inputs = *ep->get_meta_def_inputs(); ++ meta_def->outputs = *ep->get_meta_def_outputs(); ++ auto indexed_subgraph = std::make_unique(); ++ auto indexed_subgraph_ptr = indexed_subgraph.get(); ++ indexed_subgraph_ptr->nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + static auto g_counter = 1; +- meta_def->name() = std::string("vitis_ai_ep_") + std::to_string(g_counter++); +- meta_def->domain() = "com.xilinx"; +- meta_def->since_version() = 1; +- meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; +- auto index_proto = ONNX_NAMESPACE::AttributeProto::Create(); +- index_proto->set_name("index"); +- index_proto->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); +- index_proto->set_i(index); +- meta_def->attributes()["index"] = *index_proto; ++ meta_def->name = std::string("vitis_ai_ep_") + std::to_string(g_counter++); ++ meta_def->domain = "com.xilinx"; ++ meta_def->since_version = 1; ++ meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; ++ auto index_proto = std::unique_ptr(vaip::attr_proto_new_int("index", (int64_t)index)); ++ meta_def->attributes["index"] = *index_proto; + indexed_subgraph->SetMetaDef(std::move(meta_def)); +- return ComputeCapability::Create(std::move(indexed_subgraph)); ++ return std::make_unique(std::move(indexed_subgraph)); + } + + std::vector> + GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph, + vaip_core::DllSafe>>* eps, +- const std::set& all_support_optypes_by_eps) { +- std::set all_nodes_included_eps; ++ const std::set& all_not_support_optypes) { ++ std::set all_compute_capability_nodes; + for (auto& ep : **eps) { +- auto nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); +- all_nodes_included_eps.insert(nodes.begin(), nodes.end()); ++ auto nodes = *ep->get_meta_def_nodes(); ++ for (auto n : nodes) ++ 
all_compute_capability_nodes.insert(n); + } +- +- std::vector node_indexs = graph.GetNodesInTopologicalOrder(); +- node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end()); +- node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end()); +- + std::vector> result; +- for (auto& n : node_indexs) { +- auto indexed_subgraph = IndexedSubGraph::Create(); +- indexed_subgraph->Nodes() = {n}; +- result.emplace_back(ComputeCapability::Create(std::move(indexed_subgraph))); ++ for (auto& n : graph.Nodes()) { ++ if ((!all_compute_capability_nodes.count(n.Name())) && all_not_support_optypes.count(n.OpType())) { ++ auto meta_def = std::make_unique(); ++ meta_def->name = n.OpType(); ++ meta_def->domain = n.Domain(); ++ meta_def->since_version = 1; ++ meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; ++ auto indexed_subgraph = std::make_unique(); ++ indexed_subgraph->nodes.push_back(n.Index()); ++ for (auto i : n.InputDefs()) { ++ meta_def->inputs.push_back(i->Name()); ++ } ++ for (auto i : n.OutputDefs()) { ++ meta_def->outputs.push_back(i->Name()); ++ } ++ indexed_subgraph->SetMetaDef(std::move(meta_def)); ++ result.emplace_back(std::make_unique(std::move(indexed_subgraph))); ++ } + } + return result; + } +diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc +index f609d40f4..b629c8eff 100644 +--- a/onnxruntime/core/providers/vitisai/imp/global_api.cc ++++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc +@@ -1,18 +1,20 @@ ++ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. 
+- + #include "vaip/global_api.h" + + #include +-#include +-#include + #include + + #include "./vai_assert.h" +- + #include "core/common/exceptions.h" ++#include "core/common/logging/logging.h" ++ + #include "core/framework/error_code_helper.h" +-#include "core/providers/shared/common.h" ++ ++#include "core/graph/model.h" ++#include "core/session/ort_env.h" ++#include "core/session/onnxruntime_cxx_api.h" + + #include + +@@ -53,14 +55,16 @@ struct OrtVitisAIEpAPI { + std::vector>* (*compile_onnx_model_with_options)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + void Ensure() { +- if (handle_) +- return; +- auto& env = Provider_GetHost()->Env__Default(); +- auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); +- ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); +- ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); +- auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); +- auto status2 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", (void**)&compile_onnx_model_3); ++ if (handle_) return; ++ auto full_path = Env::Default().GetRuntimePath() + ++ PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); ++ ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); ++ ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( ++ handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); ++ auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", ++ reinterpret_cast(&compile_onnx_model_with_options)); ++ auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", ++ reinterpret_cast(&compile_onnx_model_3)); + if (!status1.IsOK() && !status2.IsOK()) { + ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); + ORT_THROW(status1); +@@ -72,12 +76,6 @@ struct OrtVitisAIEpAPI { + }; + + static OrtVitisAIEpAPI s_library_vitisaiep; +-static std::shared_ptr s_kernel_registry_vitisaiep; +-static std::vector s_domains_vitisaiep; +-static vaip_core::OrtApiForVaip the_global_api; +-std::shared_ptr get_kernel_registry_vitisaiep() { return s_kernel_registry_vitisaiep; } +-const std::vector& get_domains_vitisaiep() { return s_domains_vitisaiep; } +- + static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { + auto iter = config.find("config_file"); + if (iter == config.end()) { +@@ -107,142 +105,121 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config + return ""; + } + } +- +-vaip_core::DllSafe>> compile_onnx_model( +- const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { +-#ifndef _WIN32 +- auto model_path = graph_viewer.ModelPath().ToPathString(); +-#else +- using convert_t = std::codecvt_utf8; +- std::wstring_convert strconverter; +- auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); +-#endif ++vaip_core::DllSafe>> compile_onnx_model_with_options( ++ const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { + if (s_library_vitisaiep.compile_onnx_model_with_options) { +- 
return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); ++ return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + } else { + auto json_str = config_to_json_str(options); +- return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph_viewer.GetGraph(), json_str.c_str())); ++ return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + } + } + +-struct MyCustomOpKernel : OpKernel { +- MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { +- op_kernel_ = +- op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); ++std::vector initialize_vitisai_ep() { ++ s_library_vitisaiep.Ensure(); ++ Status status = Status::OK(); ++ try { ++ OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, ++ "onnxruntime-vitisai-ep"}; ++ std::ignore = OrtEnv::GetInstance(lm_info, status); ++ } catch (onnxruntime::OnnxRuntimeException& /*e*/) { + } +- +- ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } +- +- Status Compute(OpKernelContext* ctx) const override { +- op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); +- return Status::OK(); ++ auto domains = std::vector(); ++ domains.reserve(100); ++ s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); ++ auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); ++ if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { ++ vaip::register_xir_ops(domains); + } + +- private: +- ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); +- +- const OrtCustomOp& op_; +- void* op_kernel_; +-}; +- +-void create_kernel_registry(std::vector domains) { +- s_kernel_registry_vitisaiep = KernelRegistry::Create(); +- for (const auto& domain : domains) { +- for (const auto* op : domain->custom_ops_) { +- auto def_builder = KernelDefBuilder::Create(); +- def_builder->SetName(op->GetName(op)); +- def_builder->SetDomain(domain->domain_.c_str()); +- def_builder->SinceVersion(1); +- if (op->version > 12) { +- auto input_count = op->GetInputTypeCount(op); +- for (auto i = 0u; i < input_count; i++) { +- def_builder->InputMemoryType(op->GetInputMemoryType(op, i), i); +- } +- } +- def_builder->Provider(onnxruntime::kVitisAIExecutionProvider); +- KernelCreateFn kernel_create_fn = +- [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { +- // out = std::make_unique(info, *op); +- return Status::OK(); +- }; +- std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn)); +- } +- } +-} +-void initialize_vitisai_ep() { +- s_library_vitisaiep.Ensure(); +- s_domains_vitisaiep.reserve(100); +- s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), s_domains_vitisaiep); +- vaip::register_xir_ops(s_domains_vitisaiep); +- create_kernel_registry(s_domains_vitisaiep); ++ return domains; + } + ++static vaip_core::OrtApiForVaip the_global_api; + vaip_core::OrtApiForVaip* create_org_api_hook() { +- InitProviderOrtApi(); +- the_global_api.host_ = Provider_GetHost(); + assert(Ort::Global::api_ != nullptr); + the_global_api.ort_api_ = Ort::Global::api_; + the_global_api.model_load = [](const std::string& filename) -> Model* { +- auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); ++ 
ONNX_NAMESPACE::ModelProto model_proto; + auto& logger = logging::LoggingManager::DefaultLogger(); + auto file_path = ToPathString(filename); +- auto status = Model::Load(file_path, *model_proto); ++ auto status = Model::Load(file_path, model_proto); + vai_assert(status.IsOK(), "load model proto error"); +- auto model = Model::Create(std::move(*model_proto), file_path, logger); ++ auto model = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + return model.release(); + }; + the_global_api.model_delete = [](Model* model) { delete model; }; +- +- the_global_api.model_clone = [](const Model& const_model) -> Model* { ++ the_global_api.model_clone = [](const Model& model) -> Model* { + auto& logger = logging::LoggingManager::DefaultLogger(); +- auto& model = const_cast(const_model); +- auto model_proto = model.ToProto(); +- auto file_path = model.MainGraph().ModelPath().ToPathString(); +- auto ret = Model::Create(std::move(*model_proto), file_path, logger); ++ auto model_proto = const_cast(model).ToProto(); ++ auto file_path = model.ModelPath().ToPathString(); ++ auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto status = ret->MainGraph().Resolve(); + vai_assert(status.IsOK(), status.ErrorMessage()); + return ret.release(); + }; +- the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) { ++ the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { + const_cast(model.MetaData())[key] = value; + }; +- the_global_api.model_get_meta_data = +- [](const Model& model, const std::string& key) -> vaip_core::DllSafe { +- if (model.MetaData().count(key)) { +- return vaip_core::DllSafe(model.MetaData().at(key)); ++ the_global_api.model_get_meta_data = [](const Model& model, ++ const std::string& key) -> vaip_core::DllSafe { ++ auto& m = model.MetaData(); ++ auto it = m.find(key); ++ auto ret = std::string(); ++ if (it != m.end()) { ++ ret = it->second; + } +- return vaip_core::DllSafe(std::string()); ++ return vaip_core::DllSafe(ret); + }; ++ + the_global_api.model_has_meta_data = [](const Model& model, const std::string& key) -> int { +- return int(model.MetaData().count(key)); ++ auto& m = model.MetaData(); ++ return m.find(key) != m.end() ? 
1 : 0; + }; ++ + the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; + the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; +- the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> auto { +- return vaip_core::DllSafe(graph.GetInputs()); ++ the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { ++ auto ret = std::vector(); ++ auto inputs = graph.GetInputs(); ++ for (auto input : inputs) { ++ vai_assert(input->Exists(), input->Name()); ++ ret.push_back(input); ++ } ++ return vaip_core::DllSafe(std::move(ret)); + }; +- the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> auto { ++ the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { + return vaip_core::DllSafe(graph.GetOutputs()); + }; +- the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) { +- graph.SetOutputs(outputs); ++ ++ the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { ++ return graph.SetOutputs(outputs); + }; ++ + the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { + return graph.GetNodeArg(name); + }; + the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { + return graph.GetProducerNode(name); + }; +- the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { +- return graph.GetNode(index); +- }; ++ ++ the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; ++ + the_global_api.graph_save = vaip::graph_save; + the_global_api.graph_fuse = vaip::graph_fuse; + the_global_api.graph_remove_node = vaip::graph_remove_node; +- the_global_api.graph_add_node = vaip::graph_add_node; ++ the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, ++ const std::string& description, const std::vector& input_args, ++ const std::vector& output_args, ++ vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { ++ return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, ++ std::move(reinterpret_cast(attributes)), domain); ++ }; ++ + the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { + return graph.GetAllInitializedTensors(); + }; ++ + the_global_api.graph_resolve = [](Graph& graph, bool force) { + if (force) { + graph.SetGraphResolveNeeded(); +@@ -250,57 +227,129 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { + auto status = graph.Resolve(); + return status.Code(); + }; +- the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto { ++ ++ the_global_api.graph_get_consumer_nodes_unsafe = ++ [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { + return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); + }; +- the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); }; ++ the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { ++ auto& node_refererence = graph.Nodes(); ++ std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); ++ std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); ++ return vaip_core::DllSafe(std::move(nodes)); ++ }; + 
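// Sketch of a caller for the reverse-DFS hook wired up a few lines below (illustrative only;
// the function name is hypothetical). The hook forwards plain std::function callbacks straight
// to onnxruntime::Graph::ReverseDFSFrom, passing nullptr for the ordering comparator.
inline size_t count_reachable_nodes_sketch(const onnxruntime::Graph& graph,
                                           gsl::span<const onnxruntime::Node* const> from) {
  size_t visited = 0;
  graph.ReverseDFSFrom(
      from,
      [&visited](const onnxruntime::Node*) { ++visited; },  // enter: count every node reached
      [](const onnxruntime::Node*) {},                       // leave: nothing to do
      nullptr,                                                // no node ordering comparator
      [](const onnxruntime::Node*, const onnxruntime::Node*) { return false; });  // never prune
  return visited;
}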
the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; + the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, +- const auto& enter, const auto& leave, const auto& stop) { ++ const std::function& enter, ++ const std::function& leave, ++ const std::function& stop) { + graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); + }; + // node + the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; + the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; ++ + the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; + the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; +- the_global_api.node_get_index = [](const Node& node) -> size_t { return node.Index(); }; ++ the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; + the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; +- the_global_api.node_get_attributes = [](Node& node) -> NodeAttributes& { +- return const_cast(node.GetAttributes()); ++ ++ the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { ++ return reinterpret_cast(node.GetMutableAttributes()); ++ }; ++ ++ the_global_api.node_type_is_fused = [](const Node& node) { ++ return node.NodeType() == onnxruntime::Node::Type::Fused; + }; +- the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == Node::Type::Fused; }; +- the_global_api.node_get_function_body = [](const Node& node) -> const auto& { ++ the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { + assert(node.GetFunctionBody() != nullptr); + return node.GetFunctionBody()->Body(); + }; + + // node_arg +- the_global_api.node_arg_get_name_unsafe = +- [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; ++ the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { ++ return node_arg.Name(); ++ }; + the_global_api.node_arg_clone = vaip::node_arg_clone; + the_global_api.node_arg_new = vaip::node_arg_new; +- the_global_api.node_arg_is_exists = [](const NodeArg& node_arg) { return node_arg.Exists(); }; ++ the_global_api.node_arg_is_exists = vaip::node_arg_is_exists; + the_global_api.node_arg_is_constant = vaip::node_arg_is_constant; + the_global_api.node_arg_get_shape_i64_unsafe = vaip::node_arg_get_shape_i64; + the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; + the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; +- + the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; + the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; + + the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; +- the_global_api.node_arg_set_element_type = vaip::node_arg_set_element_type; ++ the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { ++ auto data_type = ONNX_NAMESPACE::TensorProto::UNDEFINED; ++ switch (type) { ++ case 1: ++ data_type = ONNX_NAMESPACE::TensorProto::FLOAT; ++ break; ++ case 2: ++ data_type = ONNX_NAMESPACE::TensorProto::UINT8; ++ break; ++ case 3: ++ data_type = ONNX_NAMESPACE::TensorProto::INT8; ++ break; ++ ++ case 4: ++ data_type = 
ONNX_NAMESPACE::TensorProto::UINT16; ++ break; ++ case 5: ++ data_type = ONNX_NAMESPACE::TensorProto::INT16; ++ break; ++ case 6: ++ data_type = ONNX_NAMESPACE::TensorProto::INT32; ++ break; ++ case 7: ++ data_type = ONNX_NAMESPACE::TensorProto::INT64; ++ break; ++ case 8: ++ data_type = ONNX_NAMESPACE::TensorProto::STRING; ++ break; ++ case 9: ++ data_type = ONNX_NAMESPACE::TensorProto::BOOL; ++ break; ++ case 10: ++ data_type = ONNX_NAMESPACE::TensorProto::FLOAT16; ++ break; ++ case 11: ++ data_type = ONNX_NAMESPACE::TensorProto::DOUBLE; ++ break; ++ case 12: ++ data_type = ONNX_NAMESPACE::TensorProto::UINT32; ++ break; ++ case 13: ++ data_type = ONNX_NAMESPACE::TensorProto::UINT64; ++ break; ++ case 14: ++ data_type = ONNX_NAMESPACE::TensorProto::COMPLEX64; ++ break; ++ case 15: ++ data_type = ONNX_NAMESPACE::TensorProto::COMPLEX128; ++ break; ++ case 16: ++ data_type = ONNX_NAMESPACE::TensorProto::BFLOAT16; ++ break; ++ default: ++ vai_assert(false, "TensorProto::DataType not supoort"); ++ } ++ return vaip::node_arg_set_element_type(node_arg, data_type); ++ }; + /// attr proto +- the_global_api.attr_proto_delete = [](ONNX_NAMESPACE::AttributeProto* v) { delete v; }; +- the_global_api.attr_proto_clone = [](const ONNX_NAMESPACE::AttributeProto& v) -> ONNX_NAMESPACE::AttributeProto* { +- auto ret = ONNX_NAMESPACE::AttributeProto::Create(); +- *ret = v; +- return ret.release(); ++ the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; ++ the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { ++ return new onnx::AttributeProto(v); ++ }; ++ the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { ++ return attr_proto.name(); ++ }; ++ the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { ++ attr_proto->set_name(name); + }; +- the_global_api.attr_proto_get_name = [](const auto& attr_proto) -> const std::string& { return attr_proto.name(); }; +- the_global_api.attr_proto_set_name = [](auto* attr_proto, const auto& name) { attr_proto->set_name(name); }; + the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; + the_global_api.attr_proto_new_float = vaip::attr_proto_new_float; + the_global_api.attr_proto_new_string = vaip::attr_proto_new_string; +@@ -315,24 +364,31 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { + the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; + the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; + the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; +- the_global_api.attr_proto_get_type = [](const ONNX_NAMESPACE::AttributeProto& attr) -> int { return attr.type(); }; ++ the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + + /// node attributes +- the_global_api.node_attributes_new = []() { return NodeAttributes::Create().release(); }; +- the_global_api.node_attributes_add = [](NodeAttributes& p, ONNX_NAMESPACE::AttributeProto&& attr) { +- p.insert_or_assign(attr.name(), std::move(attr)); ++ the_global_api.node_attributes_new = []() { ++ return reinterpret_cast(new NodeAttributes()); + }; +- +- the_global_api.node_attributes_delete = [](NodeAttributes* p) { delete p; }; +- the_global_api.node_attributes_get = +- [](const NodeAttributes& attr, const std::string& name) -> const ONNX_NAMESPACE::AttributeProto* { +- if (attr.count(name)) { +- return &attr.at(name); ++ the_global_api.node_attributes_add = 
[](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { ++ reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); ++ }; ++ the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { ++ delete reinterpret_cast(p); ++ }; ++ the_global_api.node_attributes_get = [](vaip_core::NodeAttributes& p, ++ const std::string& name) -> ONNX_NAMESPACE::AttributeProto* { ++ auto& attr = reinterpret_cast(p); ++ auto it = attr.find(name); ++ if (it == attr.end()) { ++ return nullptr; + } +- return nullptr; ++ return &it->second; + }; +- the_global_api.node_attributes_get_keys = [](NodeAttributes& attr) -> vaip_core::DllSafe> { ++ the_global_api.node_attributes_get_keys = ++ [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + auto ret = std::vector(); ++ auto& attr = reinterpret_cast(p); + ret.reserve(attr.size()); + for (auto& it : attr) { + ret.push_back(it.first); +@@ -340,16 +396,35 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { + return vaip_core::DllSafe(std::move(ret)); + }; + /// tensor proto +- the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape; +- the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); }; +- the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; }; +- the_global_api.tensor_proto_new_floats = vaip::tensor_proto_new_floats; +- the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32; +- the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64; +- the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8; +- the_global_api.tensor_proto_raw_data_size = [](const auto& tensor) { return tensor.raw_data().size(); }; ++ the_global_api.tensor_proto_get_shape_unsafe = ++ [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { ++ return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); ++ }; ++ ++ the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; ++ ++ the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; ++ ++ the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, ++ const std::vector& data) -> onnx::TensorProto* { ++ return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; ++ }; ++ the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, ++ const std::vector& data) -> onnx::TensorProto* { ++ return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; ++ }; ++ the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, ++ const std::vector& data) -> onnx::TensorProto* { ++ return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; ++ }; ++ the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, ++ const std::vector& data) -> onnx::TensorProto* { ++ return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; ++ }; ++ the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; ++ + the_global_api.tensor_proto_as_raw = vaip::tensor_proto_as_raw; +- the_global_api.tensor_proto_get_name = [](const auto& tensor) -> const std::string& { return tensor.name(); }; ++ the_global_api.tensor_proto_get_name = vaip::tensor_proto_get_name; + + the_global_api.get_lib_name = []() -> vaip_core::DllSafe { + return vaip_core::DllSafe(std::string("onnxruntime.") + std::string(ORT_VERSION)); +diff --git 
a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc +index 061bc414f..cca680baf 100644 +--- a/onnxruntime/core/providers/vitisai/imp/graph.cc ++++ b/onnxruntime/core/providers/vitisai/imp/graph.cc +@@ -2,15 +2,27 @@ + // Licensed under the MIT License. + #include "vaip/graph.h" + ++#include ++ ++#include "./vai_assert.h" + #include + #include + #include + #include + #include + #include +- +-#include "core/providers/shared_library/provider_api.h" +-#include "./vai_assert.h" ++#include "onnx/onnx-ml.pb.h" ++#ifdef _MSC_VER ++#pragma warning(push) ++// 'type' : forcing value to bool 'true' or 'false' (performance warning) ++#pragma warning(disable : 4800) ++#endif ++#include ++#ifdef _MSC_VER ++#pragma warning(pop) ++#endif ++using convert_t = std::codecvt_utf8; ++std::wstring_convert strconverter; + + #include "vaip/node.h" + #include "vaip/node_arg.h" +@@ -26,14 +38,23 @@ struct NodeEdgeT { + + static void graph_remove_node(Graph& graph, const Node& node) { + auto remove_edges = std::vector(); +- for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); ++it) { +- remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); ++ auto begin = node.InputEdgesBegin(); ++ auto end = node.InputEdgesEnd(); ++ for (auto it = begin; it != end; ++it) { ++ remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), ++ it->GetSrcArgIndex(), ++ it->GetDstArgIndex()}); + } +- for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); ++it) { +- remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); ++ begin = node.OutputEdgesBegin(); ++ end = node.OutputEdgesEnd(); ++ for (auto it = begin; it != end; ++it) { ++ remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), ++ it->GetSrcArgIndex(), ++ it->GetDstArgIndex()}); + } + for (auto it : remove_edges) { +- graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, it.dst_arg_index); ++ graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, ++ it.dst_arg_index); + } + graph.RemoveNode(node.Index()); + } +@@ -47,9 +68,13 @@ static std::vector node_get_implicit_input_node_args(const Node& + } + return ret; + } +-Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, +- const std::vector& input_args, const std::vector& output_args, +- const NodeAttributes& attributes, const std::string& domain) { ++ ++Node& graph_add_node(Graph& graph, const std::string& name, ++ const std::string& op_type, const std::string& description, ++ const std::vector& input_args, ++ const std::vector& output_args, ++ const NodeAttributes& attributes, ++ const std::string& domain) { + std::vector inputs; + inputs.reserve(input_args.size()); + for (auto i : input_args) { +@@ -60,7 +85,8 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o + for (auto i : output_args) { + outputs.push_back(const_cast(i)); + } +- auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, &attributes, domain); ++ auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, ++ &attributes, domain); + auto src_arg_index = 0; + for (auto& o : outputs) { + auto consumers = graph.GetConsumerNodes(o->Name()); +@@ -70,7 +96,8 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o + for (auto ni : *tmp_inputs) { + auto name1 = 
ni.node_arg->Name(); + if (name1 == o->Name()) { +- graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); ++ graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, ++ dst_arg_index); + } + dst_arg_index = dst_arg_index + 1; + } +@@ -78,7 +105,8 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o + for (auto implicit_node_arg : node_get_implicit_input_node_args(*consumer)) { + auto name1 = implicit_node_arg->Name(); + if (name1 == o->Name()) { +- graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); ++ graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, ++ dst_arg_index); + } + dst_arg_index = dst_arg_index + 1; + } +@@ -104,39 +132,44 @@ void graph_remove_node(Graph& graph, const NodeInput& node_input) { + + void graph_save(const Graph& graph, const std::string& filename, const std::string& filename_dat, size_t initializer_size_threshold) { + auto& model = const_cast(graph.GetModel()); +- std::unique_ptr model_proto; ++ auto model_proto = ONNX_NAMESPACE::ModelProto(); + + if (initializer_size_threshold == std::numeric_limits::max()) { + model_proto = model.ToProto(); + } else { +- model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, graph.ModelPath().ToPathString(), initializer_size_threshold); ++ model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, ++ ToPathString(filename), ++ initializer_size_threshold); + } + auto& metadata = model.MetaData(); + if (!metadata.empty()) { +- auto metadata_props = model_proto->mutable_metadata_props(); +- metadata_props->Clear(); ++ model_proto.mutable_metadata_props()->Clear(); + for (auto& m : metadata) { +- auto prop = metadata_props->Add(); ++ auto prop = model_proto.mutable_metadata_props()->Add(); + *prop->mutable_key() = m.first; + *prop->mutable_value() = m.second; + } + } + // use relative path as data storage. 
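// Worked example of the "location" rewrite performed just below (illustrative values only):
// an initializer whose external_data entry is {key: "location", value: "/work/run1/weights.dat"}
// ends up as {key: "location", value: "weights.dat"}, i.e.
// std::filesystem::path("/work/run1/weights.dat").filename().u8string() == "weights.dat",
// so the serialized model refers to its external-data file relative to the output directory.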
+- auto graph_proto = model_proto->mutable_graph(); +- *graph_proto = *graph.ToGraphProto(); +- for (int i = 0; i < graph_proto->mutable_initializer()->size(); i++) { +- auto mutable_external_data = graph_proto->mutable_initializer()->at(i).mutable_external_data(); +- for (int j = 0; j < mutable_external_data->size(); j++) { +- auto& external_data = mutable_external_data->at(j); +- if (*external_data.mutable_key() == "location") +- *external_data.mutable_value() = std::filesystem::path(*external_data.mutable_value()).filename().u8string(); ++ auto graph_proto = model_proto.mutable_graph(); ++ *graph_proto = graph.ToGraphProto(); ++ for (auto i = 0; i < graph_proto->initializer_size(); ++i) { ++ auto initializer = graph_proto->mutable_initializer(i); ++ for (auto j = 0; j < initializer->external_data_size(); ++j) { ++ auto external_data = initializer->mutable_external_data(j); ++ if (external_data->key() == "location") { ++ *external_data->mutable_value() = std::filesystem::path(external_data->value()).filename().u8string(); ++ } + } + } +- +- std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary); +- bool result = model_proto->SerializeToOstream(output); +- output << std::flush; +- vai_assert(result, "model serialize to ostream error"); ++ int fd = -1; ++ Status status = Env::Default().FileOpenWr(filename, fd); ++ vai_assert(status.IsOK(), status.ErrorMessage()); ++ google::protobuf::io::FileOutputStream output(fd); ++ const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush(); ++ vai_assert(result, "model serialize to zero cipy stream error"); ++ status = Env::Default().FileClose(fd); ++ vai_assert(status.IsOK(), status.ErrorMessage()); + } + + Node& graph_fuse(Graph& graph, const std::string& name, +@@ -145,25 +178,25 @@ Node& graph_fuse(Graph& graph, const std::string& name, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& constant_initializers) { +- auto meta_def = IndexedSubGraph_MetaDef::Create(); +- meta_def->inputs() = inputs; +- meta_def->outputs() = outputs; +- meta_def->constant_initializers() = constant_initializers; +- meta_def->name() = "super_layer"; +- meta_def->domain() = "com.xilinx"; +- meta_def->since_version() = 1; +- meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; +- +- auto indexed_subgraph = IndexedSubGraph::Create(); +- indexed_subgraph->Nodes() = nodes; ++ auto meta_def = std::make_unique(); ++ auto indexed_subgraph = std::make_unique(); ++ indexed_subgraph->nodes = nodes; ++ meta_def->inputs = inputs; ++ meta_def->outputs = outputs; ++ meta_def->constant_initializers = constant_initializers; ++ meta_def->name = "super_layer"; ++ meta_def->domain = "com.xilinx"; ++ meta_def->since_version = 1; ++ meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + indexed_subgraph->SetMetaDef(std::move(meta_def)); +- + auto& fused_node = graph.FuseSubGraph(*indexed_subgraph, name); + auto function_body = fused_node.GetFunctionBody(); + if (function_body) { +- auto proto = function_body->Body().ToGraphProto(); +- *proto->mutable_name() = name; +- fused_node.AddAttribute("body", *proto); ++ auto& mygraph = function_body->Body(); ++ // auto proto = graph.ToGraphProtoWithExternal("exteranl.dat", 128); ++ auto proto = mygraph.ToGraphProto(); ++ *proto.mutable_name() = name; ++ fused_node.AddAttribute("body", proto); + } + for (auto&& o : fused_node.OutputDefs()) { + graph.UpdateProducerNode(o->Name(), fused_node.Index()); +diff --git a/onnxruntime/core/providers/vitisai/imp/node.cc 
b/onnxruntime/core/providers/vitisai/imp/node.cc +index 0565171fb..6d65ad4e8 100644 +--- a/onnxruntime/core/providers/vitisai/imp/node.cc ++++ b/onnxruntime/core/providers/vitisai/imp/node.cc +@@ -4,8 +4,9 @@ + #include "./vai_assert.h" + + #include "attr_proto.h" ++#include "core/graph/graph_utils.h" ++#include "core/graph/node_arg.h" + #include "vaip/node_arg.h" +-#include "core/providers/shared_library/provider_api.h" + + namespace vaip { + +@@ -28,6 +29,7 @@ vaip_core::DllSafe> node_get_inputs(const Node& node) { + } + return vaip_core::DllSafe(ret); + } ++ + vaip_core::DllSafe> node_get_output_node_args(const Node& node) { + auto outputs = node.OutputDefs(); + auto size = outputs.size(); +@@ -40,4 +42,11 @@ vaip_core::DllSafe> node_get_output_node_args(const + } + return vaip_core::DllSafe(ret); + } ++ ++vaip_core::DllSafe> node_get_output_shape(const Node& node, int index) { ++ auto outputs = node.OutputDefs(); ++ assert((size_t)index < outputs.size()); ++ return node_arg_get_shape_i64(*outputs[index]); ++} ++ + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/imp/node_arg.cc b/onnxruntime/core/providers/vitisai/imp/node_arg.cc +index a54cbef91..3bdeb0969 100644 +--- a/onnxruntime/core/providers/vitisai/imp/node_arg.cc ++++ b/onnxruntime/core/providers/vitisai/imp/node_arg.cc +@@ -2,16 +2,25 @@ + // Licensed under the MIT License. + #include "vaip/node_arg.h" + #include "./vai_assert.h" +-#include "core/providers/shared_library/provider_api.h" ++ ++#include + + #include "./tensor_proto.h" ++#include "core/graph/node_arg.h" + + namespace vaip { ++ ++bool node_arg_is_exists(const NodeArg& node_arg) { ++ return node_arg.Exists(); ++} + bool node_arg_is_constant(const Graph& graph, const NodeArg& node_arg) { + assert(node_arg.Exists()); + assert(!node_arg.Name().empty()); +- return graph.GetConstantInitializer(node_arg.Name(), true) != nullptr; ++ auto constant_tensor_proto = ++ graph.GetConstantInitializer(node_arg.Name(), true); ++ return constant_tensor_proto != nullptr; + } ++ + vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& node_arg) { + auto shape = node_arg.Shape(); + if (nullptr == shape) return vaip_core::DllSafe>(); +@@ -23,42 +32,104 @@ vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& n + } + return vaip_core::DllSafe(shape_vector); + } +-void node_arg_set_shape_i64(const NodeArg& node_arg, const std::vector& shape) { +- auto shape_proto = const_cast(node_arg.Shape()); +- assert(shape_proto != nullptr); +- assert(shape.size() == static_cast(shape_proto->dim_size())); +- auto rank = shape_proto->dim_size(); ++ ++static void LayoutTransformRule_set_shape(onnx::TensorShapeProto& shape_proto, ++ const std::vector& shape) { ++ assert(shape.size() == static_cast(shape_proto.dim_size())); ++ auto rank = shape_proto.dim_size(); + for (auto i = 0; i < rank; ++i) { +- shape_proto->mutable_dim(i)->set_dim_value(shape[i]); ++ shape_proto.mutable_dim(i)->set_dim_value(shape[i]); + } + } +-vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { +- auto shape = node_arg.Shape(); +- if (shape == nullptr) { +- return vaip_core::DllSafe>(); +- } ++ ++static void LayoutTransformRule_set_shape(onnx::TypeProto& type_proto, ++ const std::vector& shape) { ++ assert(type_proto.value_case() == onnx::TypeProto::kTensorType); ++ //<< type_proto.DebugString(); ++ auto& tensor_type = *type_proto.mutable_tensor_type(); ++ auto& shape_prot = *tensor_type.mutable_shape(); ++ return LayoutTransformRule_set_shape(shape_prot, shape); ++} ++ ++static void 
LayoutTransformRule_set_shape(NodeArg* node_arg, ++ const std::vector& shape) { ++ assert(node_arg != nullptr); ++ auto* type_proto = node_arg->TypeAsProto(); ++ assert(type_proto != nullptr); ++ return LayoutTransformRule_set_shape( ++ *const_cast(type_proto), shape); ++} ++ ++void node_arg_set_shape_i64(const NodeArg& node_arg, ++ const std::vector& shape) { ++ LayoutTransformRule_set_shape(const_cast(&node_arg), shape); ++} ++ ++static std::vector LayoutTransformRule_get_denotation( ++ const onnx::TensorShapeProto& shape) { + auto ret = std::vector(); +- auto rank = shape->dim_size(); ++ auto rank = shape.dim_size(); ++ ret.reserve(rank); + for (auto i = 0; i < rank; ++i) { +- ret.push_back(shape->dim(i).denotation()); ++ auto& d = shape.dim(i).denotation(); ++ ret.push_back(d); + } +- return vaip_core::DllSafe>(ret); ++ return ret; + } +-void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation) { +- auto shape_proto = const_cast(node_arg.Shape()); +- assert(shape_proto != nullptr); +- assert(denotation.size() == static_cast(shape_proto->dim_size())); +- auto rank = shape_proto->dim_size(); +- for (auto i = 0; i < rank; ++i) { +- shape_proto->mutable_dim(i)->set_denotation(denotation[i]); ++ ++static vaip_core::DllSafe> LayoutTransformRule_get_denotation( ++ const onnx::TypeProto& type_proto) { ++ vai_assert(type_proto.value_case() == onnx::TypeProto::kTensorType, type_proto.DebugString()); ++ auto& tensor_type = type_proto.tensor_type(); ++ if (!tensor_type.has_shape()) { ++ return vaip_core::DllSafe>(); + } ++ auto& shape = tensor_type.shape(); ++ auto denotation = LayoutTransformRule_get_denotation(shape); ++ return vaip_core::DllSafe>(denotation); + } +-void node_arg_set_element_type(NodeArg& node_arg, int type) { +- if (type < 0 || type > 16) { +- vai_assert(false, "TensorProto::DataType not supoort"); ++ ++static vaip_core::DllSafe> LayoutTransformRule_get_denotation( ++ const NodeArg* node_arg) { ++ assert(node_arg != nullptr); ++ auto* type_proto = node_arg->TypeAsProto(); ++ assert(type_proto != nullptr); ++ return LayoutTransformRule_get_denotation(*type_proto); ++} ++ ++vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { ++ return LayoutTransformRule_get_denotation(&node_arg); ++} ++ ++static onnx::TensorShapeProto* node_arg_get_tensor_mutable_shape( ++ NodeArg* node_arg) { ++ assert(node_arg != nullptr); ++ auto type_proto = const_cast(node_arg->TypeAsProto()); ++ assert(type_proto != nullptr); ++ vai_assert(type_proto->value_case() == onnx::TypeProto::kTensorType, ++ type_proto->DebugString()); ++ return type_proto->mutable_tensor_type()->mutable_shape(); ++} ++ ++static void LayoutTransformRule_set_denotation( ++ onnx::TensorShapeProto& shape, const std::vector& denotation) { ++ assert(denotation.size() == static_cast(shape.dim_size())); ++ auto rank = shape.dim_size(); ++ for (auto i = 0; i < rank; ++i) { ++ shape.mutable_dim(i)->set_denotation(denotation[i]); + } +- auto data_type = static_cast(type); +- auto type_proto = const_cast(node_arg.TypeAsProto()); ++} ++void node_arg_set_denotation(const NodeArg& node_arg, ++ const std::vector& denotation) { ++ auto mutable_shape = ++ node_arg_get_tensor_mutable_shape(const_cast(&node_arg)); ++ ++ return LayoutTransformRule_set_denotation(*mutable_shape, denotation); ++} ++ ++void node_arg_set_element_type(NodeArg& node_arg, ++ onnx::TensorProto::DataType data_type) { ++ auto type_proto = const_cast(node_arg.TypeAsProto()); + assert(type_proto != nullptr); + auto current_elem_type = 
type_proto->mutable_tensor_type()->elem_type(); + auto input_elem_type = data_type; +@@ -67,12 +138,24 @@ void node_arg_set_element_type(NodeArg& node_arg, int type) { + current_elem_type, true); + vai_assert(status.IsOK(), status.ErrorMessage()); + } ++void node_arg_set_shape(NodeArg& node_arg, std::vector shape) { ++ auto type_proto = const_cast(node_arg.TypeAsProto()); ++ assert(type_proto != nullptr); ++ for (auto i = 0u; i < shape.size(); i++) { ++ type_proto->mutable_tensor_type() ++ ->mutable_shape() ++ ->mutable_dim(i) ++ ->set_dim_value(shape[i]); ++ } ++} ++ + const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor( + const Graph& graph, const NodeArg& node_arg) { + auto tensor_proto = graph.GetConstantInitializer(node_arg.Name(), true); + assert(tensor_proto != nullptr); + return *tensor_proto; + } ++ + int node_arg_get_element_type(const NodeArg& node_arg) { + auto type_proto = node_arg.TypeAsProto(); + assert(type_proto != nullptr); +@@ -81,7 +164,9 @@ int node_arg_get_element_type(const NodeArg& node_arg) { + } + return type_proto->tensor_type().elem_type(); + } +-NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, const std::string& name) { ++ ++NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, ++ const std::string& name) { + vai_assert(name != node_arg.Name(), "node arg must have a new unique name"); + vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. ")); + auto type_proto = node_arg.TypeAsProto(); +@@ -89,10 +174,12 @@ NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, const std::string + auto& ret = graph.GetOrCreateNodeArg(name, type_proto); + return ret; + } +-NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vector* shape, int element_type) { ++ ++NodeArg& node_arg_new(Graph& graph, ++ const std::string& name, const std::vector* shape, int element_type) { + vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. ")); +- auto type_proto = ONNX_NAMESPACE::TypeProto::Create(); +- auto tensor_type = type_proto->mutable_tensor_type(); ++ auto type_proto = onnx::TypeProto(); ++ auto tensor_type = type_proto.mutable_tensor_type(); + tensor_type->set_elem_type(element_type); + if (shape != nullptr) { + auto shape_proto = tensor_type->mutable_shape(); +@@ -102,6 +189,8 @@ NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vectorhas_shape() == false); + } +- return graph.GetOrCreateNodeArg(name, type_proto.release()); ++ auto& ret = graph.GetOrCreateNodeArg(name, &type_proto); ++ return ret; + } ++ + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc b/onnxruntime/core/providers/vitisai/imp/node_attrs.cc +new file mode 100644 +index 000000000..e438266e2 +--- /dev/null ++++ b/onnxruntime/core/providers/vitisai/imp/node_attrs.cc +@@ -0,0 +1,114 @@ ++// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. ++// Licensed under the MIT License. 
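// Orientation note for this new translation unit (editorial summary of the code that follows):
// it provides small value-type wrappers around onnx::AttributeProto — a set of
// make_attribute() overloads for INT/INTS/STRING/STRINGS/FLOATS/TENSOR payloads, a NodeAttr
// holder, and a NodeAttributesBuiler that collects NodeAttr objects and either builds a fresh
// NodeAttributes map or merges them into an existing node's mutable attributes.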
++#include "vaip/node_attrs.h" ++#include "./vai_assert.h" ++ ++namespace vaip { ++static onnx::AttributeProto make_attribute(const std::string& name, ++ int64_t value) { ++ auto ret = onnx::AttributeProto(); ++ ret.set_name(name); ++ ret.set_type(onnx::AttributeProto::INT); ++ ret.set_i(value); ++ return ret; ++} ++ ++static onnx::AttributeProto make_attribute(const std::string& name, ++ const std::vector value) { ++ auto ret = onnx::AttributeProto(); ++ ret.set_name(name); ++ ret.set_type(onnx::AttributeProto::INTS); ++ for (auto v : value) { ++ ret.add_ints(v); ++ } ++ return ret; ++} ++ ++static onnx::AttributeProto make_attribute(const std::string& name, ++ const std::string& value) { ++ auto ret = onnx::AttributeProto(); ++ ret.set_name(name); ++ ret.set_type(onnx::AttributeProto::STRING); ++ ret.set_s(value); ++ return ret; ++} ++static onnx::AttributeProto make_attribute( ++ const std::string& name, const std::vector& value) { ++ auto ret = onnx::AttributeProto(); ++ ret.set_name(name); ++ ret.set_type(onnx::AttributeProto::STRINGS); ++ for (auto v : value) { ++ ret.add_strings(v); ++ } ++ return ret; ++} ++ ++static onnx::AttributeProto make_attribute(const std::string& name, ++ const std::vector& value) { ++ auto ret = onnx::AttributeProto(); ++ ret.set_name(name); ++ ret.set_type(onnx::AttributeProto::FLOATS); ++ for (auto v : value) { ++ ret.add_floats(v); ++ } ++ return ret; ++} ++ ++static onnx::AttributeProto make_attribute(const std::string& name, ++ const onnx::TensorProto& value) { ++ auto ret = onnx::AttributeProto(); ++ ret.set_name(name); ++ ret.set_type(onnx::AttributeProto::TENSOR); ++ *(ret.mutable_t()) = std::move(value); ++ return ret; ++} // namespace vaip ++ ++NodeAttr::NodeAttr(const std::string& name, int64_t value) ++ : attribute_proto_{make_attribute(name, value)} {} ++ ++NodeAttr::NodeAttr(const std::string& name, const std::vector& value) ++ : attribute_proto_{make_attribute(name, value)} {} ++ ++NodeAttr::NodeAttr(const std::string& name, const std::string& value) ++ : attribute_proto_{make_attribute(name, value)} {} ++ ++NodeAttr::NodeAttr(const std::string& name, ++ const std::vector& value) ++ : attribute_proto_{make_attribute(name, value)} {} ++ ++NodeAttr::NodeAttr(const std::string& name, const std::vector& value) ++ : attribute_proto_{make_attribute(name, value)} {} ++ ++NodeAttr::NodeAttr(const std::string& name, const onnx::TensorProto& value) ++ : attribute_proto_{make_attribute(name, value)} {} ++ ++onnx::AttributeProto& NodeAttr::get() { return attribute_proto_; } ++ ++NodeAttributesBuiler::NodeAttributesBuiler(size_t capacity) : attrs_{} { ++ attrs_.reserve(capacity); ++} ++ ++NodeAttributes NodeAttributesBuiler::build() { ++ auto ret = NodeAttributes(); ++ ret.reserve(attrs_.size()); ++ for (auto& node_attr : attrs_) { ++ onnx::AttributeProto& attr_proto = node_attr.get(); ++ auto name = attr_proto.name(); ++ ret.insert(std::make_pair(name, std::move(attr_proto))); ++ } ++ attrs_.clear(); ++ return ret; ++} ++ ++void NodeAttributesBuiler::merge_into(Node& node) { ++ merge_into(node.GetMutableAttributes()); ++} ++ ++void NodeAttributesBuiler::merge_into(NodeAttributes& attrs) { ++ for (auto& attr : attrs_) { ++ vai_assert(attr.get().has_name(), std::string("attr must has name " + attr.get().DebugString())); ++ auto name = attr.get().name(); ++ attrs.insert_or_assign(std::move(name), std::move(attr.get())); ++ } ++} ++} // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc 
b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +index 97ed2d3b4..ee8dfc6d0 100644 +--- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc ++++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +@@ -1,25 +1,130 @@ ++ ++ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. +- + #include "./register_xir_ops.h" + #include "./vai_assert.h" +-#include "core/providers/shared_library/provider_api.h" ++ ++#include "core/common/logging/logging.h" ++#include "core/common/status.h" ++ ++#include "core/framework/customregistry.h" ++ + #include "core/session/onnxruntime_c_api.h" ++#include "core/session/custom_ops.h" ++#include "core/session/inference_session.h" ++#include "onnx/defs/schema.h" ++#include "onnx/defs/shape_inference.h" + + using namespace onnxruntime; +- + namespace vaip { ++ ++static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { ++ auto* shape = ctx.getAttribute("shape"); ++ auto* data_type = ctx.getAttribute("data_type"); ++ if (data_type->s() == "float32") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT); ++ } else if (data_type->s() == "int8") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8); ++ } else if (data_type->s() == "uint8") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8); ++ } else if (data_type->s() == "int32") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); ++ } else if (data_type->s() == "int64") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); ++ } else if (data_type->s() == "int1") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); ++ } else if (data_type->s() == "bfloat16") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); ++ } else if (data_type->s() == "float16") { ++ updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); ++ } else { ++ vai_assert(false, ", not supported data_type: " + data_type->s()); ++ } ++ if (shape != nullptr) { ++ for (auto i = 0; i < shape->ints_size(); ++i) { ++ ONNX_NAMESPACE::appendDim(ONNX_NAMESPACE::getOutputShape(ctx, 0), shape->ints(i)); ++ } ++ } else { ++ // set scalar type. 
++ auto* output_shape = ONNX_NAMESPACE::getOutputShape(ctx, 0); ++ output_shape->clear_dim(); ++ } ++ return; ++} ++ ++static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { ++ ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); ++ ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); ++} ++ ++static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { ++ auto num_inputs = ctx.getNumInputs(); ++ ++ // Run inferencing on the subgraph ++ ONNX_NAMESPACE::GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body"); ++ if (!graphInferencer) { ++ fail_type_inference("body is missing."); ++ } ++ ++ std::vector input_data; ++ std::vector subgraph_input_types; ++ for (size_t i = 0; i < num_inputs; ++i) { ++ input_data.push_back(ctx.getInputData(i)); ++ subgraph_input_types.push_back(ctx.getInputType(i)); ++ } ++ std::vector output_types; ++ output_types = ++ graphInferencer->doInferencing(subgraph_input_types, input_data); ++ ++ auto num_outputs = ctx.getNumOutputs(); ++ auto num_of_the_subgraph_outputs = output_types.size(); ++ if (num_outputs != num_of_the_subgraph_outputs) { ++ fail_type_inference("super layer has ", num_outputs, ++ " but subgraphs produce ", num_of_the_subgraph_outputs); ++ } ++ for (size_t i = 0, end = output_types.size(); i < end; ++i) { ++ auto subgraph_output = output_types[i]; ++ auto* super_layer_output = ctx.getOutputType(i); ++ *super_layer_output = *subgraph_output; ++ } ++} ++ + void register_xir_ops(const std::vector& domains) { ++ std::shared_ptr custom_registry; ++ auto status = CreateCustomRegistry(gsl::span(domains), custom_registry); ++ vai_assert(status.IsOK(), status.ErrorMessage()); + for (auto domain : domains) { + for (auto op : domain->custom_ops_) { + auto name = op->GetName(op); ++ auto schema1 = custom_registry->GetOpschemaRegistry()->GetSchema(name, ORT_API_VERSION, domain->domain_); ++ auto schema2 = ::ONNX_NAMESPACE::OpSchema(); ++ schema2.SetName(schema1->Name()); ++ schema2.SetDomain(schema1->domain()); ++ auto n = 0; ++ for (auto input : schema1->inputs()) { ++ schema2.Input(n, input.GetName(), input.GetDescription(), std::string("T") + std::to_string(n), input.GetOption(), false, input.GetMinArity(), input.GetDifferentiationCategory()); ++ schema2.TypeConstraint(std::string("T") + std::to_string(n), DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); ++ n = n + 1; ++ } ++ auto m = n; ++ n = 0; ++ for (auto output : schema1->outputs()) { ++ auto type_str = std::string("T") + std::to_string(n + m); ++ schema2.Output(n, output.GetName(), output.GetDescription(), type_str, output.GetOption(), false, output.GetMinArity(), output.GetDifferentiationCategory()); ++ schema2.TypeConstraint(type_str, DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); ++ n = n + 1; ++ } ++ schema2.SinceVersion(1); ++ schema2.AllowUncheckedAttributes(); + if ((std::string)name == "super_layer") { +- Provider_GetHost()->RegisterSchema(domain->domain_, op, 1); ++ schema2.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + } else if ((std::string)name == "FixNeuron") { +- Provider_GetHost()->RegisterSchema(domain->domain_, op, 2); ++ schema2.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + } else { +- Provider_GetHost()->RegisterSchema(domain->domain_, op, 3); ++ schema2.TypeAndShapeInferenceFunction(xir_shape_infer); + } ++ ONNX_NAMESPACE::RegisterSchema(schema2, ORT_API_VERSION); + } + } + } +diff --git 
a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +index 48dcd220a..db03354bf 100644 +--- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc ++++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +@@ -1,19 +1,20 @@ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. + #include "./tensor_proto.h" ++#include "./vai_assert.h" ++#include "core/framework/tensorprotoutils.h" + + #include + #include + +-#include "./vai_assert.h" +-#include "core/providers/shared_library/provider_api.h" + namespace vaip { +-gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor) { ++ ++gsl::span tensor_proto_as_raw( ++ const ONNX_NAMESPACE::TensorProto& tensor) { + auto& mut_tensor = const_cast(tensor); + if (!tensor.has_raw_data()) { + std::vector unpacked_tensor; +- auto path = onnxruntime::Path::Create(); +- auto s = onnxruntime::utils::UnpackInitializerData(tensor, *path, unpacked_tensor); ++ auto s = onnxruntime::utils::UnpackInitializerData(tensor, onnxruntime::Path(), unpacked_tensor); + mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); + mut_tensor.clear_float_data(); + mut_tensor.clear_int32_data(); +@@ -26,51 +27,78 @@ gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& ten + return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); + } + +-vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor_proto) { ++size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) { ++ return tensor.raw_data().size(); ++} ++ ++std::vector tensor_proto_get_shape( ++ const onnx::TensorProto& tensor_proto) { + auto ret = std::vector(); + int rank = tensor_proto.dims_size(); + if (rank > 0) { +- auto& dims = tensor_proto.dims(); +- for (auto i = 0; i < dims.size(); ++i) { +- ret.push_back(dims[i]); ++ ret.reserve((size_t)rank); ++ for (auto i = 0; i < rank; ++i) { ++ ret.push_back(tensor_proto.dims(i)); + } + } +- return vaip_core::DllSafe(ret); ++ return ret; + } +-static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, const std::vector& shape, +- int data_type, const char* data, size_t data_size) { +- auto tensor_proto = ONNX_NAMESPACE::TensorProto::Create(); +- tensor_proto->set_name(name); +- for (auto s : shape) { +- tensor_proto->add_dims(s); +- } +- tensor_proto->set_data_type(data_type); +- tensor_proto->mutable_raw_data()->assign(data, data_size); +- return tensor_proto.release(); ++ ++const std::string& tensor_proto_get_name( ++ const ONNX_NAMESPACE::TensorProto& tensor) { ++ return tensor.name(); + } + +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, +- const std::vector& data) { +- return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT32, +- reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data) { ++ auto tensor_proto = ONNX_NAMESPACE::TensorProto(); ++ tensor_proto.set_name(name); ++ tensor_proto.mutable_dims()->Clear(); ++ tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); ++ tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT32); ++ tensor_proto.mutable_raw_data()->assign( ++ reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); ++ return tensor_proto; + } + +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const 
std::string& name, const std::vector& shape, +- const std::vector& data) { +- return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64, +- reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data) { ++ auto tensor_proto = ONNX_NAMESPACE::TensorProto(); ++ tensor_proto.set_name(name); ++ tensor_proto.mutable_dims()->Clear(); ++ tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); ++ tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT64); ++ tensor_proto.mutable_raw_data()->assign( ++ reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); ++ return tensor_proto; + } + +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, +- const std::vector& data) { +- return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8, +- reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data) { ++ auto tensor_proto = ONNX_NAMESPACE::TensorProto(); ++ tensor_proto.set_name(name); ++ tensor_proto.mutable_dims()->Clear(); ++ tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); ++ tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT8); ++ tensor_proto.mutable_raw_data()->assign( ++ reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); ++ return tensor_proto; + } + +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, +- const std::vector& data) { +- return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, +- reinterpret_cast(&data[0]), data.size() * sizeof(float)); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data) { ++ auto tensor_proto = ONNX_NAMESPACE::TensorProto(); ++ tensor_proto.set_name(name); ++ tensor_proto.mutable_dims()->Clear(); ++ tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); ++ tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::FLOAT); ++ tensor_proto.mutable_raw_data()->assign( ++ reinterpret_cast(&data[0]), data.size() * sizeof(float)); ++ return tensor_proto; + } + + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +index 292905ca7..00aa388c8 100644 +--- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h ++++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +@@ -1,20 +1,31 @@ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. 
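For orientation, every tensor_proto_new_* helper above follows the same protobuf recipe: set the name, the dims, the element type, then assign the raw bytes. A minimal standalone sketch of the int32 case, with a hypothetical function name (not part of the patch):

#include <cstdint>
#include <string>
#include <vector>
#include "onnx/onnx_pb.h"

// Build a named INT32 initializer, mirroring vaip::tensor_proto_new_i32 above.
ONNX_NAMESPACE::TensorProto make_i32_initializer(const std::string& name,
                                                 const std::vector<int64_t>& shape,
                                                 const std::vector<int32_t>& data) {
  ONNX_NAMESPACE::TensorProto t;
  t.set_name(name);
  for (int64_t d : shape) t.add_dims(d);
  t.set_data_type(ONNX_NAMESPACE::TensorProto::INT32);
  t.mutable_raw_data()->assign(reinterpret_cast<const char*>(data.data()),
                               data.size() * sizeof(int32_t));
  return t;
}
// e.g. make_i32_initializer("pads", {4}, {0, 1, 0, 1}) yields a 1-D tensor holding four int32 values.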
+ #pragma once +-#include "vaip/my_ort.h" +-#include "vaip/vaip_gsl.h" +-#include "vaip/dll_safe.h" +- ++// ++#include "core/common/gsl.h" ++#include "onnx/onnx_pb.h" + namespace vaip { +-gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor); +-vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor); +-const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor); +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, +- const std::vector& data); +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, +- const std::vector& data); +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, +- const std::vector& data); +-ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, +- const std::vector& data); ++ ++gsl::span tensor_proto_as_raw( ++ const ONNX_NAMESPACE::TensorProto& tensor); ++size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor); ++ ++std::vector tensor_proto_get_shape( ++ const ONNX_NAMESPACE::TensorProto& tensor); ++const std::string& tensor_proto_get_name( ++ const ONNX_NAMESPACE::TensorProto& tensor); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data); ++ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data); ++ ++ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( ++ const std::string& name, const std::vector& shape, ++ const std::vector& data); ++ + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/capability.h b/onnxruntime/core/providers/vitisai/include/vaip/capability.h +index e7644dbe8..d6b5ae34d 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/capability.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/capability.h +@@ -2,7 +2,8 @@ + // Licensed under the MIT License. + #pragma once + +-#include "core/providers/shared_library/provider_api.h" ++#include "core/framework/compute_capability.h" ++#include "core/graph/graph_viewer.h" + #include "vaip/custom_op.h" + namespace vaip { + using namespace ::onnxruntime; +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +index 1f8b8802e..c446ab3ae 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +@@ -2,15 +2,16 @@ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. 
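As a usage note on tensor_proto_as_raw declared in the header above: callers receive a byte span and reinterpret it according to the proto's data_type. A small sketch under that assumption (the function name is hypothetical):

#include <cstdint>
#include <cstring>
#include <vector>
#include "core/common/gsl.h"

// Decode a raw-byte view as int32 values; assumes the owning TensorProto's
// data_type is INT32 and the bytes came from a helper like tensor_proto_as_raw.
std::vector<int32_t> raw_as_i32(gsl::span<const char> raw) {
  std::vector<int32_t> out(raw.size() / sizeof(int32_t));
  std::memcpy(out.data(), raw.data(), out.size() * sizeof(int32_t));
  return out;
}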
+ #pragma once +-#include "core/providers/shared_library/provider_api.h" +-#define ORT_API_MANUAL_INIT ++#include ++#include ++#include ++ + #include "core/session/onnxruntime_cxx_api.h" + #include "core/framework/provider_options.h" + #include "vaip/my_ort.h" + #include "vaip/dll_safe.h" + #include "vaip/custom_op.h" + +-void initialize_vitisai_ep(); +-vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); +-std::shared_ptr get_kernel_registry_vitisaiep(); +-const std::vector& get_domains_vitisaiep(); ++std::vector initialize_vitisai_ep(); ++vaip_core::DllSafe>> compile_onnx_model_with_options( ++ const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h +index 292fb2bb3..9def86457 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h +@@ -1,19 +1,25 @@ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. + #pragma once ++#include + #include "./node.h" +-#include "vaip/my_ort.h" + namespace vaip { + using namespace onnxruntime; + + void graph_remove_node(Graph& graph, const NodeInput& node_input); +-Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, +- const std::vector& input_args, const std::vector& output_args, +- const NodeAttributes& attributes, const std::string& domain); +-void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, +- size_t initializer_size_threshold); +-Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, +- const std::vector& inputs, const std::vector& outputs, ++Node& graph_add_node(Graph& graph, const std::string& name, ++ const std::string& op_type, const std::string& description, ++ const std::vector& input_args, ++ const std::vector& output_args, ++ const NodeAttributes& attributes, ++ const std::string& domain); ++ ++void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, size_t initializer_size_threshold); ++Node& graph_fuse(Graph& graph, const std::string& name, ++ const std::string& op_type, ++ const std::vector& nodes, ++ const std::vector& inputs, ++ const std::vector& outputs, + const std::vector& constant_initializers); + + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +index 46fc4ac9b..d43ef1253 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +@@ -9,17 +9,15 @@ + #include + + namespace onnxruntime { +-struct Model; +-struct Graph; +-struct GraphViewer; +-struct Node; +-struct NodeArg; +-struct ProviderHost; +-struct NodeAttributes; ++class Model; ++class Graph; ++class GraphViewer; ++class Node; ++class NodeArg; + } // namespace onnxruntime + namespace ONNX_NAMESPACE { +-struct AttributeProto; +-struct TensorProto; ++class AttributeProto; ++class TensorProto; + #ifndef USE_VITISAI + enum TensorProto_DataType : int { + TensorProto_DataType_UNDEFINED = 0, +@@ -70,7 +68,6 @@ using onnxruntime::GraphViewer; + using 
onnxruntime::Model; + using onnxruntime::Node; + using onnxruntime::NodeArg; +-using onnxruntime::NodeAttributes; + struct ModelDeleter { + VAIP_DLL_SPEC void operator()(Model* tp) const; + }; +@@ -78,17 +75,22 @@ using ModelPtr = std::unique_ptr; + struct AttributeProtoDeleter { + VAIP_DLL_SPEC void operator()(AttributeProto* p) const; + }; +-using AttributeProtoPtr = std::unique_ptr; ++using AttributeProtoPtr = ++ std::unique_ptr; + + struct TensorProtoDeleter { + VAIP_DLL_SPEC void operator()(TensorProto* tp) const; + }; + using TensorProtoPtr = std::unique_ptr; + ++/// I cannot forward declare a using directive, because ++/// std::unordered_map requires AttributeProto to be defined. ++class NodeAttributes; + struct NodeAttributesDeleter { + VAIP_DLL_SPEC void operator()(NodeAttributes* p) const; + }; +-using NodeAttributesPtr = std::unique_ptr; ++using NodeAttributesPtr = ++ std::unique_ptr; + /// get node's input + /// when Node* is nullptr, it is a tensor in the initializer. + /// node_arg is always non-null. +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node.h b/onnxruntime/core/providers/vitisai/include/vaip/node.h +index 31d9d4bd7..bad7660f6 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/node.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/node.h +@@ -2,6 +2,10 @@ + // Licensed under the MIT License. + + #pragma once ++ ++#include ++ ++#include "core/graph/node_arg.h" + #include "vaip/dll_safe.h" + #include "vaip/my_ort.h" + namespace vaip { +@@ -13,4 +17,8 @@ vaip_core::DllSafe> node_get_inputs(const Node& node); + + /// to support multiple outputs + vaip_core::DllSafe> node_get_output_node_args(const Node& node); ++/// get output shape ++/// index is usually zero, because most operators only have a single output. ++vaip_core::DllSafe> node_get_output_shape(const Node& node, int index = 0); ++ + } // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h +index fca641c5e..76432fc5b 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h +@@ -2,8 +2,9 @@ + // Licensed under the MIT License. + + #pragma once ++#include + #include "vaip/dll_safe.h" +-#include "vaip/my_ort.h" ++#include + namespace vaip { + using namespace onnxruntime; + +@@ -25,7 +26,9 @@ void node_arg_set_shape_i64(const NodeArg& node_arg, + void node_arg_set_denotation(const NodeArg& node_arg, + const std::vector& denotation); + void node_arg_set_element_type(NodeArg& node_arg, +- int data_type); ++ ONNX_NAMESPACE::TensorProto::DataType data_type); ++void node_arg_set_shape(NodeArg& node_arg, std::vector shape); ++ + const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(const Graph& graph, + const NodeArg& node_arg); + +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h b/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h +new file mode 100644 +index 000000000..49cd1aad8 +--- /dev/null ++++ b/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h +@@ -0,0 +1,46 @@ ++// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. ++// Licensed under the MIT License.
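The Deleter/Ptr pairs above follow a common idiom: because the call operator is only declared in the header (and exported via VAIP_DLL_SPEC), std::unique_ptr can hold a forward-declared type and destruction runs inside the library that defined the deleter. A self-contained sketch of the pattern with placeholder names (Widget is purely illustrative):

#include <memory>

// Placeholder type standing in for AttributeProto / TensorProto / NodeAttributes above.
struct Widget { int value = 0; };

struct WidgetDeleter {
  // In the real headers this is only declared and is defined in the exporting
  // library; it is inlined here just to keep the sketch complete and runnable.
  void operator()(Widget* p) const { delete p; }
};
using WidgetPtr = std::unique_ptr<Widget, WidgetDeleter>;

int main() {
  WidgetPtr w{new Widget{}};
  w->value = 42;  // released through WidgetDeleter when w goes out of scope
  return 0;
}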
++ ++#pragma once ++#include ++ ++#include ++ ++#include "core/graph/basic_types.h" ++namespace vaip { ++using namespace onnxruntime; ++class NodeAttr { ++ public: ++ NodeAttr(const std::string& name, int64_t value); ++ NodeAttr(const std::string& name, const std::vector& value); ++ NodeAttr(const std::string& name, const std::string& value); ++ NodeAttr(const std::string& name, const std::vector& value); ++ NodeAttr(const std::string& name, const std::vector& value); ++ NodeAttr(const std::string& name, const onnx::TensorProto& value); ++ ++ onnx::AttributeProto& get(); ++ ++ private: ++ onnx::AttributeProto attribute_proto_; ++}; ++ ++class NodeAttributesBuiler { ++ public: ++ explicit NodeAttributesBuiler(size_t capacity = 10); ++ NodeAttributesBuiler(const NodeAttributesBuiler&) = delete; ++ NodeAttributesBuiler(NodeAttributesBuiler&&) = default; ++ /// after build, all attrs_ are cleared. ++ NodeAttributes build(); ++ /// for efficiency reason, after merge_into, all attrs_ are moved. ++ void merge_into(Node& node); ++ void merge_into(NodeAttributes& attrs); ++ template ++ NodeAttributesBuiler& add(const std::string& name, T&& value) { ++ attrs_.emplace_back(name, std::forward(value)); ++ return *this; ++ } ++ ++ private: ++ std::vector attrs_; ++}; ++} // namespace vaip +diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +index ae5f71d66..0d7d5f622 100644 +--- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h ++++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +@@ -13,7 +13,6 @@ struct OrtApi; + namespace vaip_core { + + struct OrtApiForVaip { +- onnxruntime::ProviderHost* host_; + const OrtApi* ort_api_; + // model + Model* (*model_load)(const std::string& file); // [0] +@@ -50,7 +49,7 @@ struct OrtApiForVaip { + const std::string& description, + const std::vector& input_args, + const std::vector& output_args, +- const NodeAttributes& attributes, ++ NodeAttributes& attributes, + const std::string& domain); // [18] + void (*graph_save)(const Graph& graph, const std::string& filename, + const std::string& dat_filename, +@@ -120,8 +119,8 @@ struct OrtApiForVaip { + NodeAttributes* (*node_attributes_new)(); // [46] + void (*node_attributes_delete)(NodeAttributes* p); // [47] + void (*node_attributes_add)(NodeAttributes& p, AttributeProto&& attr); // [48] +- const AttributeProto* (*node_attributes_get)(const NodeAttributes& p, +- const std::string& name); // [49] ++ AttributeProto* (*node_attributes_get)(NodeAttributes& p, ++ const std::string& name); // [49] + DllSafe> (*node_attributes_get_keys)( + NodeAttributes& p); // [50] + /// attr proto +@@ -195,4 +194,5 @@ VAIP_DLL_SPEC const OrtApiForVaip* api(); + ? 
::vaip_core::api()->name \ + : (assert(false && #name " is not set"), nullptr)) + #endif ++VAIP_DLL_SPEC void initialize_ort(); + } // namespace vaip_core +diff --git a/onnxruntime/core/providers/vitisai/symbols.def b/onnxruntime/core/providers/vitisai/symbols.def +deleted file mode 100644 +index 4ec2f7914..000000000 +--- a/onnxruntime/core/providers/vitisai/symbols.def ++++ /dev/null +@@ -1,2 +0,0 @@ +-EXPORTS +- GetProvider +diff --git a/onnxruntime/core/providers/vitisai/version_script.lds b/onnxruntime/core/providers/vitisai/version_script.lds +deleted file mode 100644 +index 2c8e9c4b3..000000000 +--- a/onnxruntime/core/providers/vitisai/version_script.lds ++++ /dev/null +@@ -1,9 +0,0 @@ +-#_init and _fini should be local +-VERS_1.0 { +- global: +- GetProvider; +- +- # Hide everything else. +- local: +- *; +-}; +diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +index 6fc09f349..5f20b32cd 100644 +--- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc ++++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +@@ -1,34 +1,91 @@ + // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. + // Licensed under the MIT License. ++#include "core/graph/graph_utils.h" + #include "vitisai_execution_provider.h" + + #include ++#include + #include + #include + ++#include "core/common/common.h" ++ + #include "vaip/capability.h" + #include "vaip/global_api.h" ++#include "core/session/custom_ops.h" ++#include "core/session/inference_session.h" + + using namespace ONNX_NAMESPACE; + + namespace onnxruntime { ++ + constexpr const char* VITISAI = "VITISAI"; + +-VitisAIExecutionProvider::VitisAIExecutionProvider( +- const ProviderOptions& info) ++static vaip_core::DllSafe>> compile_onnx_model( ++ const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { ++#ifndef _WIN32 ++ auto model_path = graph_viewer.ModelPath().ToPathString(); ++#else ++ using convert_t = std::codecvt_utf8; ++ std::wstring_convert strconverter; ++ auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); ++#endif ++ return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); ++} ++ ++struct MyCustomOpKernel : OpKernel { ++ MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { ++ op_kernel_ = ++ op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); ++ } ++ ++ ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } ++ ++ Status Compute(OpKernelContext* ctx) const override { ++ op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); ++ return Status::OK(); ++ } ++ ++ private: ++ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); ++ ++ const OrtCustomOp& op_; ++ void* op_kernel_; ++}; ++ ++VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) + : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { ++ custom_op_domains_ = initialize_vitisai_ep(); ++ registry_ = std::make_shared(); + CreateKernelRegistry(); + } + + void VitisAIExecutionProvider::CreateKernelRegistry() { +- for (const auto& domain : get_domains_vitisaiep()) { ++ for (const auto& domain : custom_op_domains_) { + for (const auto* op : domain->custom_ops_) { ++ KernelDefBuilder def_builder; ++ def_builder.SetName(op->GetName(op)); ++ def_builder.SetDomain(domain->domain_); ++ def_builder.SinceVersion(1); ++ if 
(op->version > 12) { ++ auto input_count = op->GetInputTypeCount(op); ++ for (auto i = 0u; i < input_count; i++) { ++ def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i); ++ } ++ } ++ def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); ++ KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, ++ std::unique_ptr& out) -> Status { ++ out = std::make_unique(info, *op); ++ return Status::OK(); ++ }; ++ std::ignore = registry_->Register(def_builder, kernel_create_fn); + vitisai_optypes_.insert(op->GetName(op)); + } + } + } + +-std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } ++std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } + + std::vector> VitisAIExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { +@@ -54,9 +111,9 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector& node_compute_funcs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + NodeComputeInfo compute_info; +- auto& attrs = fused_node_graph.fused_node.get().GetAttributes(); +- assert(attrs.count("index")); +- size_t index = attrs.at("index").i(); ++ const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); ++ assert(attr != nullptr); ++ size_t index = (size_t)attr->i(); + compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { + auto* p = (**this->execution_providers_)[index]->compile().release(); + *state = p; +diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +index 186427be4..e86b53339 100644 +--- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h ++++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +@@ -9,7 +9,8 @@ + #include + #include + +-#include "core/providers/shared_library/provider_api.h" ++#include "core/framework/execution_provider.h" ++#include "core/framework/customregistry.h" + #include "core/session/onnxruntime_c_api.h" + + // we cannot include vaip/vaip.hpp here because header file referred by +@@ -20,6 +21,7 @@ class DllSafe; + class ExecutionProvider; + } // namespace vaip_core + namespace onnxruntime { ++ + // Logical device representation. + class VitisAIExecutionProvider : public IExecutionProvider { + public: +diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +index dc34419ef..4c416124c 100755 +--- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc ++++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +@@ -10,6 +10,9 @@ + #include "./vitisai_execution_provider.h" + #include "core/framework/execution_provider.h" + ++#include "core/session/abi_session_options_impl.h" ++#include "core/providers/shared_library/provider_host_api.h" ++ + using namespace onnxruntime; + namespace onnxruntime { + +@@ -27,37 +30,10 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider() { + return std::make_unique(info_); + } + +-struct VitisAI_Provider : Provider { +- // Takes a pointer to a provider specific structure to create the factory. 
For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure +- std::shared_ptr +- CreateExecutionProviderFactory(const void* options) override { +- return std::make_shared(GetProviderOptions(options)); +- } +- // Convert provider options struct to ProviderOptions which is a map +- ProviderOptions GetProviderOptions(const void* options) override { +- auto vitisai_options = reinterpret_cast(options); +- return *vitisai_options; +- } +- // Update provider options from key-value string configuration +- void UpdateProviderOptions(void* options, const ProviderOptions& provider_options) override { +- auto vitisai_options = reinterpret_cast(options); +- for (const auto& entry : provider_options) { +- vitisai_options->insert_or_assign(entry.first, entry.second); +- } +- }; +- // Get provider specific custom op domain list. Provider has the resposibility to release OrtCustomOpDomain instances it creates. +- void GetCustomOpDomainList(IExecutionProviderFactory*, std::vector&) override{}; +- // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded +- void Initialize() override { initialize_vitisai_ep(); } +- // Called right before unloading the shared library +- void Shutdown() override {} +-} g_provider; ++std::shared_ptr VitisAIProviderFactoryCreator::Create( ++ const ProviderOptions& provider_options) { ++ initialize_vitisai_ep(); ++ return std::make_shared(provider_options); ++} + + } // namespace onnxruntime +- +-extern "C" { +- +-ORT_API(onnxruntime::Provider*, GetProvider) { +- return &onnxruntime::g_provider; +-} +-} +diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h +index d94729e60..85dafcaf6 100644 +--- a/onnxruntime/core/providers/webnn/builders/helper.h ++++ b/onnxruntime/core/providers/webnn/builders/helper.h +@@ -54,19 +54,6 @@ std::string GetShapeString(std::vector& shape) { + return shape_info.str(); + } + +-inline std::string GetTensorName(const ConstPointerContainer>& input_defs, const size_t index) { +- return (input_defs.size() > index) ? 
std::string(input_defs[index]->Name()) : ""; +-} +- +-inline std::vector GetVecUint32FromVecInt64(const std::vector& int64_vec) { +- std::vector uint32_vec; +- uint32_vec.reserve(int64_vec.size()); +- std::transform(int64_vec.begin(), int64_vec.end(), +- std::back_inserter(uint32_vec), +- [](int64_t val) -> uint32_t { return SafeInt(val); }); +- return uint32_vec; +-} +- + template + bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector& array, const logging::Logger& logger) { + std::vector unpacked_tensor; +diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +index c74545479..ceacb7c2b 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +@@ -42,61 +42,72 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod + // Helper functions + common::Status SetConvBaseOptions(ModelBuilder& model_builder, + const Node& node, emscripten::val& options, +- const std::vector input_shape, +- const std::vector weight_shape, +- const std::vector& strides, +- const std::vector& dilations, +- std::vector& pads, +- const bool is_nhwc, +- const bool is_conv1d, ++ const std::vector& strides, ++ const std::vector& dilations, ++ std::vector& pads, + const logging::Logger& logger) { + NodeAttrHelper helper(node); ++ const auto group = helper.Get("group", static_cast(1)); + const auto& input_defs = node.InputDefs(); +- ++ std::vector weight_shape; ++ ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); ++ options.set("strides", emscripten::val::array(strides)); ++ options.set("dilations", emscripten::val::array(dilations)); ++ options.set("groups", group); + // Add Padding. ++ std::vector input_shape; ++ ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + if (node.OpType() == "Conv") { + // Calculate explicit padding for autoPad. + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + std::vector pads_out; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], +- pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); +- pads = pads_out; ++ helper.Get("pads", std::vector{0, 0, 0, 0}), ++ helper.Get("strides", std::vector{1, 1}), ++ helper.Get("dilations", std::vector{1, 1}), ++ auto_pad_type, ++ pads_out, ++ model_builder.GetPreferredLayout() == DataLayout::NCHW)); ++ std::transform(pads_out.begin(), pads_out.end(), pads.begin(), ++ [](int64_t pad) -> int32_t { return static_cast(pad); }); + } + } else if (node.OpType() == "ConvTranspose") { + // When the 'output_shape' is specificed, the 'output_padding' values + // in options.outputPadding are ignored. +- std::vector dims; +- std::vector output_padding{0, 0}; ++ std::vector dim; ++ std::vector output_padding{0, 0}; + if (helper.HasAttr("output_shape")) { +- // Default value of 'output_shape' will be ignored as we already check if it existed. +- dims = helper.Get("output_shape", std::vector{-1, -1}); ++ // Default value of 'output_shape' will be ignore as we already check if ++ // it's existed. ++ dim = helper.Get("output_shape", std::vector{-1, -1}); + // Extract the height and width. 
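To make the SAME_UPPER/SAME_LOWER branch above concrete: the explicit pads are chosen so the output spatial size is ceil(input/stride). A standalone worked sketch of that arithmetic (this is the textbook formula, with a hypothetical function name, not the HandleAutoPad implementation itself):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// SAME padding for one spatial dimension: SAME_UPPER puts the extra pixel at
// the end, SAME_LOWER puts it at the beginning.
void same_pad(int64_t in, int64_t stride, int64_t kernel, int64_t dilation,
              bool same_upper, int64_t& pad_head, int64_t& pad_tail) {
  const int64_t out = (in + stride - 1) / stride;            // ceil(in / stride)
  const int64_t window = (kernel - 1) * dilation + 1;        // effective kernel extent
  const int64_t total = std::max<int64_t>(0, (out - 1) * stride + window - in);
  pad_head = same_upper ? total / 2 : total - total / 2;
  pad_tail = total - pad_head;
}

int main() {
  int64_t head = 0, tail = 0;
  same_pad(5, 2, 3, 1, true, head, tail);  // in=5, stride=2, kernel=3, dilation=1
  std::printf("head=%lld tail=%lld\n", static_cast<long long>(head), static_cast<long long>(tail));
  // out = ceil(5/2) = 3, total = (3-1)*2 + 3 - 5 = 2, so head = 1 and tail = 1.
  return 0;
}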
+- std::vector output_shape; +- if (dims.size() == 1 && is_conv1d) { // ConvTranspose 1d +- output_shape = {dims[0], 1}; +- } else if (dims.size() == 2 && !is_conv1d) { +- output_shape = dims; ++ std::vector output_shape; ++ if (dim.size() == 2) { ++ output_shape = dim; ++ } else if (dim.size() == 4) { ++ output_shape = {dim[2], dim[3]}; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); + } + // Padding values are auto generated. + if (helper.HasAttr("kernel_shape")) { +- std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); +- if (is_conv1d) { // ConvTranspose 1d +- kernel_shape.push_back(1); +- } +- std::vector total_padding(2); ++ std::vector kernel_shape = helper.Get("kernel_shape", std::vector{-1, -1}); ++ std::vector total_padding(2); ++ std::vector input_shape; ++ ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + for (size_t i = 0; i < 2; i++) { + // Get the dimensions of H and W. + // For NHWC layout, the dimensions of H and W correspond to index 1 and 2. + // For NCHW layout, the dimensions of H and W correspond to index 2 and 3. +- if (is_nhwc) { +- total_padding[i] = strides[i] * (input_shape[i + 1] - 1) + output_padding[i] + +- ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; ++ if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { ++ total_padding[i] = strides[i] * (narrow(input_shape[i + 1]) - 1) + ++ output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } else { +- total_padding[i] = strides[i] * (input_shape[i + 2] - 1) + output_padding[i] + +- ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; ++ ORT_RETURN_IF_NOT(model_builder.GetPreferredLayout() == DataLayout::NCHW, ++ "WebNN GPU backend preferred layout should be NCHW."); ++ total_padding[i] = strides[i] * (narrow(input_shape[i + 2]) - 1) + ++ output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i]; + } + } + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); +@@ -111,27 +122,18 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, + } + } + } +- options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); ++ options.set("outputSizes", emscripten::val::array(output_shape)); + } else { +- output_padding = helper.Get("output_padding", std::vector{0, 0}); +- if (output_padding.size() == 1 && is_conv1d) { // ConvTranspose 1d +- output_padding.push_back(0); +- } +- options.set("outputPadding", emscripten::val::array(GetVecUint32FromVecInt64(output_padding))); ++ output_padding = helper.Get("output_padding", std::vector{0, 0}); ++ options.set("outputPadding", emscripten::val::array(output_padding)); + } + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "conv_op_builder only supports Op Conv and ConvTranspose."); + } +- +- const auto group = helper.Get("group", static_cast(1)); +- options.set("groups", group); +- options.set("strides", emscripten::val::array(GetVecUint32FromVecInt64(strides))); +- options.set("dilations", emscripten::val::array(GetVecUint32FromVecInt64(dilations))); +- + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], + // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. 
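A tiny worked example of the reordering described in the comment above, with made-up values: ONNX pads {begin_h, begin_w, end_h, end_w} = {1, 2, 3, 4} become WebNN padding {begin_h, end_h, begin_w, end_w} = {1, 3, 2, 4}, which is exactly the {pads[0], pads[2], pads[1], pads[3]} shuffle used just below (helper name hypothetical):

#include <cstdint>
#include <vector>

// ONNX pad order -> WebNN pad order for 2-D operations.
std::vector<int64_t> onnx_pads_to_webnn(const std::vector<int64_t>& pads) {
  return {pads[0], pads[2], pads[1], pads[3]};
}
// onnx_pads_to_webnn({1, 2, 3, 4}) returns {1, 3, 2, 4}.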
+- const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; +- options.set("padding", emscripten::val::array(GetVecUint32FromVecInt64(padding))); ++ const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; ++ options.set("padding", emscripten::val::array(padding)); + + // Add bias if present. + if (input_defs.size() > 2) { +@@ -149,8 +151,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, + // Both depthwise Conv and ConvTranspose share the same logic to add the layout. + Status AddInitializerInNewLayout(ModelBuilder& model_builder, + const std::string& name, +- bool is_conv, +- bool is_conv1d) { ++ bool is_conv) { + const auto& tensor = *model_builder.GetInitializerTensors().at(name); + auto data_type = tensor.data_type(); + if (!IsSupportedDataType(data_type, model_builder.GetWebnnDeviceType())) { +@@ -160,13 +161,13 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, + } + + const auto& shape = tensor.dims(); +- std::vector dims = GetVecUint32FromVecInt64(std::vector(std::begin(shape), std::end(shape))); +- +- if (is_conv1d) { +- // Support conv1d by prepending a 1 size dimension. +- dims.push_back(1); +- } ++ std::vector dims; ++ std::transform(shape.cbegin(), shape.cend(), ++ std::back_inserter(dims), ++ [](int64_t dim) -> int32_t { return SafeInt(dim); }); + ++ ORT_RETURN_IF_NOT(dims.size() == 4, ++ "The initializer is not 4D: ", name, " actual dim ", dims.size()); + const uint8_t* src = nullptr; + Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); + src = unpacked_tensor.DataAsByteSpan().data(); +@@ -256,101 +257,57 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val output = emscripten::val::object(); + +- std::vector input_shape; +- ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); +- std::vector weight_shape; +- ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape"); +- const auto& weight_name = input_defs[1]->Name(); +- + NodeAttrHelper helper(node); +- auto strides = helper.Get("strides", std::vector{1, 1}); +- auto dilations = helper.Get("dilations", std::vector{1, 1}); +- auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); +- +- const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; +- const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; +- // Support conv1d by prepending a 1 or 2 size dimensions. +- if (is_conv1d) { +- // Reshape input. +- if (is_nhwc) { +- // For NHWC preferred layout, the input has been transposed. +- // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2. +- input_shape.insert(input_shape.begin() + 2, 1); +- } else { +- input_shape.push_back(1); +- } +- std::vector new_shape = GetVecUint32FromVecInt64(input_shape); +- input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); +- +- weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed. +- strides.resize(2, 1); // Ensure 2D by appending 1's if needed. +- dilations.resize(2, 1); // Ensure 2D by appending 1's if needed. 
+- if (pads.size() == 2) { +- pads.insert(pads.begin() + 1, 0); +- pads.push_back(0); +- } +- } +- ++ const auto strides = helper.Get("strides", std::vector{1, 1}); ++ const auto dilations = helper.Get("dilations", std::vector{1, 1}); ++ auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); ++ const auto& weight_name = input_defs[1]->Name(); + emscripten::val options = emscripten::val::object(); +- ORT_RETURN_IF_ERROR(SetConvBaseOptions( +- model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); ++ ORT_RETURN_IF_ERROR(SetConvBaseOptions(model_builder, node, options, strides, dilations, pads, logger)); + if (op_type == "Conv" || op_type == "ConvInteger") { + int groups = options["groups"].as(); +- if (is_nhwc) { ++ std::vector input_shape; ++ ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); ++ if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + bool depthwise = (groups == input_shape[3] && groups != 1); + options.set("inputLayout", emscripten::val("nhwc")); +- ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); ++ ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise)); + if (!depthwise) { + options.set("filterLayout", emscripten::val("ohwi")); + } else { + options.set("filterLayout", emscripten::val("ihwo")); + } + } +- } else { // ConvTranspose +- if (is_nhwc) { +- options.set("inputLayout", emscripten::val("nhwc")); +- options.set("filterLayout", emscripten::val("ohwi")); +- ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d)); +- } +- } +- +- emscripten::val filter = model_builder.GetOperand(weight_name); +- if (!is_nhwc && is_conv1d) { +- // Reshape weight to 4D for conv1d with NCHW preferred layout. 
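The is_conv1d handling in the surrounding hunk promotes every 1-D quantity to its 2-D counterpart by adding unit entries; a standalone sketch with illustrative values (the real code performs the equivalent reshape on the WebNN operands):

#include <cstdint>
#include <vector>

int main() {
  // Promote conv1d parameters to conv2d form, mirroring the resize/insert calls above.
  std::vector<int64_t> input_shape = {1, 8, 100};  // N, C, W  -> N, C, W, 1
  input_shape.push_back(1);
  std::vector<int64_t> weight_shape = {16, 8, 3};  // O, I, K  -> O, I, K, 1
  weight_shape.resize(4, 1);
  std::vector<int64_t> strides = {2};              // {s}      -> {s, 1}
  strides.resize(2, 1);
  std::vector<int64_t> dilations = {1};            // {d}      -> {d, 1}
  dilations.resize(2, 1);
  std::vector<int64_t> pads = {1, 1};              // {b, e}   -> {b, 0, e, 0}
  pads.insert(pads.begin() + 1, 0);
  pads.push_back(0);
  return 0;
}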
+- std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); +- filter = model_builder.GetBuilder().call("reshape", filter, emscripten::val::array(new_shape)); +- } +- +- if (op_type == "Conv") { +- output = model_builder.GetBuilder().call("conv2d", input, filter, options); +- } else if (op_type == "ConvInteger") { +- emscripten::val x_zero_point = emscripten::val::null(); +- emscripten::val w_zero_point = emscripten::val::null(); +- if (input_defs.size() >= 3) { +- x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); +- } else { +- x_zero_point = model_builder.GetZeroConstant("uint8"); +- } +- if (input_defs.size() >= 4) { +- w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); ++ emscripten::val filter = model_builder.GetOperand(weight_name); ++ if (op_type == "Conv") { ++ output = model_builder.GetBuilder().call("conv2d", input, filter, options); + } else { +- w_zero_point = model_builder.GetZeroConstant("uint8"); ++ emscripten::val x_zero_point = emscripten::val::null(); ++ emscripten::val w_zero_point = emscripten::val::null(); ++ if (input_defs.size() >= 3) { ++ x_zero_point = model_builder.GetOperand(node.InputDefs()[2]->Name()); ++ } else { ++ x_zero_point = model_builder.GetZeroConstant("uint8"); ++ } ++ if (input_defs.size() >= 4) { ++ w_zero_point = model_builder.GetOperand(node.InputDefs()[3]->Name()); ++ } else { ++ w_zero_point = model_builder.GetZeroConstant("uint8"); ++ } ++ output = model_builder.GetBuilder().call("conv2dInteger", ++ input, x_zero_point, filter, w_zero_point, options); + } +- output = model_builder.GetBuilder().call("conv2dInteger", +- input, x_zero_point, filter, w_zero_point, options); ++ + } else { ++ if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { ++ options.set("inputLayout", emscripten::val("nhwc")); ++ options.set("filterLayout", emscripten::val("ohwi")); ++ ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false)); ++ } ++ emscripten::val filter = model_builder.GetOperand(input_defs[1]->Name()); + output = model_builder.GetBuilder().call("convTranspose2d", input, filter, options); + } + +- // If it's a conv1d, reshape it back. +- if (is_conv1d) { +- const auto& output_defs = node.OutputDefs(); +- std::vector output_shape; +- ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); +- std::vector new_shape = GetVecUint32FromVecInt64(output_shape); +- output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array(new_shape)); +- } +- + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); + return Status::OK(); + } +@@ -372,9 +329,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + } + + const auto input_size = input_shape.size(); +- if (input_size != 4 && input_size != 3) { ++ if (input_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s input dimension: " << input_size +- << ". Only conv 1d / 2d is supported."; ++ << ". Only conv 2d is supported."; + return false; + } + +@@ -385,9 +342,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + } + + const auto weight_size = weight_shape.size(); +- if (weight_size != 4 && weight_size != 3) { ++ if (weight_size != 4) { + LOGS(logger, VERBOSE) << op_type << " [" << name << "]'s weight dimension: " << weight_size +- << ". Only conv 1d / 2d is supported."; ++ << ". 
Only conv 2d is supported."; + return false; + } + +diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +index 50e04df4f..4d2470dfe 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +@@ -125,7 +125,10 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder + output = model_builder.GetBuilder().call("instanceNormalization", input, options); + // Reshape back to the original output shape for 3D input. + if (input_shape.size() != 4) { +- std::vector output_shape = GetVecUint32FromVecInt64(input_shape); ++ std::vector output_shape; ++ std::transform(input_shape.begin(), input_shape.end(), ++ std::back_inserter(output_shape), ++ [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + output = model_builder.GetBuilder().call( + "reshape", output, emscripten::val::array(output_shape)); + } +diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +index 52b551885..a2a1e2f2e 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +@@ -178,10 +178,8 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + return false; + } + for (size_t i = 1; i < input_defs.size(); i++) { +- // Optional tensors (constant_value, axes) can be indicated by an empty name, just ignore it. +- const std::string input_name = GetTensorName(input_defs, i); +- if (!input_name.empty() && !Contains(initializers, input_name)) { +- LOGS(logger, VERBOSE) << "Input [" << input_name << "] must be known as initializer"; ++ if (!Contains(initializers, input_defs[i]->Name())) { ++ LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] must be known as initializer"; + return false; + } + } +diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +index 8b3eecf35..739c3b3f3 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +@@ -81,7 +81,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + const auto onnx_kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); + const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); +- auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); ++ auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); +@@ -94,11 +94,12 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + auto_pad_type, + pads_out, + model_builder.GetPreferredLayout() == DataLayout::NCHW)); +- pads = GetVecUint32FromVecInt64(pads_out); ++ std::transform(pads_out.begin(), pads_out.end(), pads.begin(), ++ [](int64_t pad) -> int32_t { return static_cast(pad); }); + } + // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], + // while WebNN's padding is [beginning_height, ending_height, 
beginning_width, ending_width]. +- const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; ++ const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + options.set("padding", emscripten::val::array(padding)); + + const auto ceil_mode = helper.Get("ceil_mode", 0); +diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +index f446a7b81..1a702649b 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +@@ -134,9 +134,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ + return false; + + const auto& op_type = node.OpType(); +- const std::string axes_name = GetTensorName(input_defs, 1); + // If the optional input 'axes' is provided, it must be an initializer. +- if (!axes_name.empty() && !Contains(initializers, axes_name)) { ++ if (input_defs.size() > 1 && !Contains(initializers, input_defs[1]->Name())) { + LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; + return false; + } +diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +index 9018f8c96..186d1e7c1 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +@@ -120,9 +120,8 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + std::vector scales_hw; + std::vector sizes_hw; + std::vector axes; +- std::string scales_name = GetTensorName(input_defs, 2); + const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; +- if (!scales_name.empty()) { // Use scales. ++ if (input_defs.size() == 3) { // Use scales. + ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); + if (is_nhwc) { + scales_hw = {scales[1], scales[2]}; +@@ -130,7 +129,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + scales_hw = {scales[2], scales[3]}; + } + options.set("scales", emscripten::val::array(scales_hw)); +- } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. ++ } else { // We already checked number of inputs in IsOpSupportedImpl. + std::vector output_sizes; + ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), + "Error getting resize output_sizes"); +@@ -204,31 +203,26 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers + } + + { // scales and sizes (if present) must be initializers. 
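As a concrete reading of the Resize scales handling above: for a 4-D tensor only H and W may be resized, and which indices carry them depends on the layout. A small sketch with hypothetical values and a hypothetical helper name:

#include <vector>

// Pick the H/W factors out of a full 4-D ONNX Resize 'scales' input.
std::vector<float> spatial_scales(const std::vector<float>& scales, bool is_nhwc) {
  // NCHW: {n, c, h, w} -> indices 2 and 3; NHWC: {n, h, w, c} -> indices 1 and 2.
  return is_nhwc ? std::vector<float>{scales[1], scales[2]}
                 : std::vector<float>{scales[2], scales[3]};
}
// spatial_scales({1.f, 1.f, 2.f, 2.f}, /*is_nhwc=*/false) returns {2.f, 2.f}.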
+- const std::string scales_name = GetTensorName(input_defs, 2); +- const std::string sizes_name = GetTensorName(input_defs, 3); +- +- // scales (scales may be empty tensor) +- bool has_scales = !scales_name.empty(); +- if ((has_scales && !Contains(initializers, scales_name)) || (!has_scales && node.SinceVersion() == 11)) { +- LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; ++ if (input_defs.size() < 3) { ++ LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known"; + return false; + } + +- // sizes (sizes may be empty tensor) +- bool has_sizes = !sizes_name.empty(); +- if (has_sizes && !Contains(initializers, sizes_name)) { +- LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; ++ // scales ++ if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { ++ LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + return false; + } + +- if (has_scales && has_sizes) { +- LOGS(logger, VERBOSE) << "Only one of 'scales' and 'sizes' can be specified"; ++ // sizes ++ if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { ++ LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + return false; + } + + const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; + // We want to check if the scales or sizes are not trying to resize on N/C channels here. +- if (has_scales) { // We are using scales. ++ if (input_defs.size() == 3) { // We are using scales. + std::vector scales; + if (!GetResizeScales(initializers, node, scales, logger)) + return false; +@@ -257,9 +251,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers + LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; + return false; + } +- } +- +- if (has_sizes) { ++ } else { + // We are using sizes. + std::vector output_sizes; + if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) +diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +index 4e0628581..e48cf3501 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +@@ -123,10 +123,8 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + + // Inputs: starts, ends, axes, and steps must be constant initializers if present. + for (size_t i = 1; i < input_defs.size(); i++) { +- // Optional tensors (axes, steps) can be indicated by an empty name, just ignore it. 
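The empty-name convention mentioned above is what the removed GetTensorName helper (see the webnn helper.h hunk earlier) captured: an optional input is absent either because the defs list is short or because the NodeArg carries an empty name. A rough, container-agnostic sketch plus typical usage (names hypothetical):

#include <cstddef>
#include <string>

// Returns "" when the optional input at `index` is not provided.
template <typename InputDefs>
std::string optional_input_name(const InputDefs& input_defs, size_t index) {
  return input_defs.size() > index ? std::string(input_defs[index]->Name()) : std::string();
}

// Typical check (sketch): skip when empty, otherwise require a constant initializer.
// const std::string axes_name = optional_input_name(node.InputDefs(), 1);
// if (!axes_name.empty() && !Contains(initializers, axes_name)) { /* reject the node */ }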
+- const std::string input_name = GetTensorName(input_defs, i); +- if (!input_name.empty() && !Contains(initializers, input_name)) { +- LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type ++ if (!Contains(initializers, input_defs[i]->Name())) { ++ LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] of " << op_type + << " [" << name << "] must be known as initializer"; + return false; + } +diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +index 91f21b196..d568d4e62 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +@@ -83,7 +83,10 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + std::vector mapping_split; + mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs); + mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs); +- std::vector converted_splits = GetVecUint32FromVecInt64(mapping_split); ++ std::vector converted_splits; ++ std::transform(mapping_split.cbegin(), mapping_split.cend(), ++ std::back_inserter(converted_splits), ++ [](int64_t dim) -> int32_t { return SafeInt(dim); }); + output_array = model_builder.GetBuilder().call("split", + input, + emscripten::val::array(converted_splits), +@@ -133,9 +136,9 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + int32_t axis = helper.Get("axis", 0); + axis = SafeInt(HandleNegativeAxis(axis, rank)); + +- const std::string split_name = GetTensorName(input_defs, 1); +- // Inputs contain optional 'split' input. +- if (!split_name.empty()) { ++ if (input_defs.size() == 2) { ++ // Inputs contains optional 'split' input ++ const auto& split_name = input_defs[1]->Name(); + if (!Contains(initializers, split_name)) { + LOGS(logger, VERBOSE) << "The split must be a constant initializer."; + return false; +@@ -163,7 +166,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, + LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified."; + return false; + } +- } else { ++ } else if (input_defs.size() == 1) { + if (helper.HasAttr("num_outputs")) { + // Split has 'num_outputs' attribute when opset is 18. + const int32_t num_outputs = helper.Get("num_outputs", 1); +diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +index 15149bd8f..2a1672c00 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +@@ -87,7 +87,10 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil + + emscripten::val output = emscripten::val::undefined(); + // Use WebNN's reshape to implement Squeeze/Unsqueeze. +- std::vector new_shape = GetVecUint32FromVecInt64(input_shape); ++ std::vector new_shape; ++ std::transform( ++ input_shape.begin(), input_shape.end(), std::back_inserter(new_shape), ++ [](int64_t data) -> uint32_t { return SafeInt(data); }); + // Sort axes_data in ascending order. 
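For reference, the reshape-based Squeeze/Unsqueeze above reduces to editing the shape vector at the requested axes; one self-contained way to express the Squeeze half (the builder itself sorts ascending and adjusts indices while erasing, this sketch erases from the back instead):

#include <algorithm>
#include <cstdint>
#include <vector>

// Drop the listed (unit) dimensions; erasing from the back keeps indices valid.
std::vector<uint32_t> squeeze_shape(std::vector<uint32_t> shape, std::vector<size_t> axes) {
  std::sort(axes.rbegin(), axes.rend());
  for (size_t axis : axes) {
    shape.erase(shape.begin() + axis);
  }
  return shape;
}
// squeeze_shape({1, 3, 1, 5}, {0, 2}) returns {3, 5}.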
+ std::sort(axes_data.begin(), axes_data.end()); + if (op_type == "Squeeze") { +@@ -135,8 +138,8 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& in + + // Squeeze/Unsqueeze opset 13 uses input 1 as axes, it needs to be an initializer. + if (node.SinceVersion() >= 13) { +- const std::string axes_name = GetTensorName(input_defs, 1); +- if (!axes_name.empty()) { ++ if (input_defs.size() > 1) { ++ const auto& axes_name = input_defs[1]->Name(); + if (!Contains(initializers, axes_name)) { + LOGS(logger, ERROR) << "Input axes of " << op_type << " is not present and constant"; + return false; +diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +index 79f60c51a..eca152138 100644 +--- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +@@ -40,7 +40,10 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, + + emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val options = emscripten::val::object(); +- std::vector permutation = GetVecUint32FromVecInt64(perm); ++ std::vector permutation; ++ std::transform(perm.cbegin(), perm.cend(), ++ std::back_inserter(permutation), ++ [](int64_t dim) -> int32_t { return SafeInt(dim); }); + options.set("permutation", emscripten::val::array(permutation)); + emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); +diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc +index ef807a8c4..eaf549ef4 100644 +--- a/onnxruntime/core/providers/webnn/builders/model.cc ++++ b/onnxruntime/core/providers/webnn/builders/model.cc +@@ -70,13 +70,22 @@ Status Model::Predict(const InlinedHashMap& inputs, + "The input of graph has unsupported type, name: ", + name, " type: ", tensor.tensor_info.data_type); + } +- // Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer. +- // As Wasm ArrayBuffer is not detachable. ++#ifdef ENABLE_WEBASSEMBLY_THREADS ++ // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. + wnn_inputs_[name].call("set", view); ++#else ++ wnn_inputs_.set(name, view); ++#endif + } + ++#ifdef ENABLE_WEBASSEMBLY_THREADS ++ // This vector uses for recording output buffers from WebNN graph compution when WebAssembly ++ // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, ++ // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews ++ // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory ++ // address is different from the non-shared one, additional memory copy is required here. 
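For context on the threaded path above: moving data between the shared Wasm-heap view and a non-shared buffer is just a TypedArray set() issued through emscripten::val, the same call pattern as the set() calls in this hunk. A minimal sketch with hypothetical function names and a Float32Array chosen only for illustration:

#include <emscripten/val.h>

// Copy a (possibly shared) Wasm-heap Float32Array view into a fresh non-shared
// Float32Array that WebNN will accept, and copy results back afterwards.
emscripten::val to_non_shared(const emscripten::val& src_view) {
  emscripten::val dst = emscripten::val::global("Float32Array").new_(src_view["length"]);
  dst.call<void>("set", src_view);
  return dst;
}

void copy_back(emscripten::val& wasm_view, const emscripten::val& result_view) {
  wasm_view.call<void>("set", result_view);  // element-wise copy back into the Wasm heap
}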
+ InlinedHashMap output_views; +- ++#endif + for (const auto& output : outputs) { + const std::string& name = output.first; + const struct OnnxTensorData tensor = output.second; +@@ -122,23 +131,21 @@ Status Model::Predict(const InlinedHashMap& inputs, + name, " type: ", tensor.tensor_info.data_type); + } + ++#ifdef ENABLE_WEBASSEMBLY_THREADS + output_views.insert({name, view}); ++#else ++ wnn_outputs_.set(name, view); ++#endif + } +- emscripten::val results = wnn_context_.call( +- "compute", wnn_graph_, wnn_inputs_, wnn_outputs_) +- .await(); +- +- // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer. ++ wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); ++#ifdef ENABLE_WEBASSEMBLY_THREADS ++ // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. + for (const auto& output : outputs) { + const std::string& name = output.first; + emscripten::val view = output_views.at(name); +- view.call("set", results["outputs"][name]); ++ view.call("set", wnn_outputs_[name]); + } +- // WebNN compute() method would return the input and output buffers via the promise +- // resolution. Reuse the buffers to avoid additional allocation. +- wnn_inputs_ = results["inputs"]; +- wnn_outputs_ = results["outputs"]; +- ++#endif + return Status::OK(); + } + +diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc +index 56f7ead8c..cf8a0e23d 100644 +--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc ++++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc +@@ -386,8 +386,7 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { + for (auto& name : output_names_) { + named_operands.set(name, wnn_operands_.at(name)); + } +- +- emscripten::val wnn_graph = wnn_builder_.call("build", named_operands).await(); ++ emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + if (!wnn_graph.as()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); + } +@@ -396,10 +395,13 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { + model->SetOutputs(std::move(output_names_)); + model->SetScalarOutputs(std::move(scalar_outputs_)); + model->SetInputOutputInfo(std::move(input_output_info_)); +- // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews +- // for inputs and outputs because they will be transferred after compute() done. +- // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution ++#ifdef ENABLE_WEBASSEMBLY_THREADS ++ // Pre-allocate the input and output tensors for the WebNN graph ++ // when WebAssembly multi-threads is enabled since WebNN API only ++ // accepts non-shared ArrayBufferView. 
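// --- Illustrative sketch (not part of the patch): with WebAssembly threads enabled the Wasm heap
// is backed by a SharedArrayBuffer, but WebNN's computeSync only accepts non-shared
// ArrayBufferViews, so the patched Model::Predict copies inputs into pre-allocated buffers, runs
// computeSync, and copies outputs back. A minimal Embind version of that flow, assuming a single
// input/output named "input"/"output" (hypothetical names) and pre-allocated wnn_inputs /
// wnn_outputs MLNamedArrayBufferViews; requires an Emscripten build to run.
#include <emscripten/val.h>
#include <cstddef>
using emscripten::val;

void RunGraphOnce(val wnn_context, val wnn_graph, val wnn_inputs, val wnn_outputs,
                  float* input_data, size_t input_len,
                  float* output_data, size_t output_len) {
  // Copy from the (shared) Wasm view into the pre-allocated non-shared Float32Array.
  val input_view = val(emscripten::typed_memory_view(input_len, input_data));
  wnn_inputs["input"].call<void>("set", input_view);

  // Synchronous execution, as used by the 1.17 WebNN EP that this patch restores.
  wnn_context.call<void>("computeSync", wnn_graph, wnn_inputs, wnn_outputs);

  // Copy the result back into Wasm memory.
  val output_view = val(emscripten::typed_memory_view(output_len, output_data));
  output_view.call<void>("set", wnn_outputs["output"]);
}
// --- end sketch ---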
++ // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + model->AllocateInputOutputBuffers(); ++#endif + return Status::OK(); + } + +diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +index 29c8ca91f..2922cf954 100644 +--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc ++++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +@@ -19,7 +19,7 @@ namespace onnxruntime { + + WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags, + const std::string& webnn_threads_number, const std::string& webnn_power_flags) +- : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { ++ : IExecutionProvider{onnxruntime::kWebNNExecutionProvider, true} { + // Create WebNN context and graph builder. + const emscripten::val ml = emscripten::val::global("navigator")["ml"]; + if (!ml.as()) { +@@ -42,8 +42,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f + if (webnn_power_flags.compare("default") != 0) { + context_options.set("powerPreference", emscripten::val(webnn_power_flags)); + } +- +- wnn_context_ = ml.call("createContext", context_options).await(); ++ wnn_context_ = ml.call("createContextSync", context_options); + if (!wnn_context_.as()) { + ORT_THROW("Failed to create WebNN context."); + } +@@ -169,7 +168,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view + + // Assign inputs and outputs to subgraph's meta_def. + uint64_t model_hash; +- int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); + meta_def->name = "WEBNN_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); + meta_def->domain = kMSDomain; +@@ -282,6 +281,9 @@ common::Status WebNNExecutionProvider::Compile(const std::vector temp(shape.size()); ++ transform(shape.begin(), shape.end(), temp.begin(), ++ [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + const void* inputBuffer = const_cast(input_tensor.GetTensorRawData()); + inputs.emplace( + input_name, +diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h +index d9cfa5f17..13a475327 100644 +--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h ++++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h +@@ -6,7 +6,6 @@ + + #include "core/common/inlined_containers.h" + #include "core/framework/execution_provider.h" +-#include "core/framework/model_metadef_id_generator.h" + #include "core/providers/webnn/builders/helper.h" + + #include +@@ -49,6 +48,5 @@ class WebNNExecutionProvider : public IExecutionProvider { + DataLayout preferred_layout_; + webnn::WebnnDeviceType wnn_device_type_; + InlinedHashMap> models_; +- ModelMetadefIdGenerator metadef_id_generator_; + }; + } // namespace onnxruntime +diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +index eafbfae6f..a2a776df4 100644 +--- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc ++++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +@@ -155,7 +155,7 @@ std::unique_ptr RegisterKernels() { + using namespace xnnpack; + + XnnpackExecutionProvider::XnnpackExecutionProvider(const XnnpackExecutionProviderInfo& 
info) +- : IExecutionProvider{kXnnpackExecutionProvider} { ++ : IExecutionProvider{kXnnpackExecutionProvider, true} { + int xnn_thread_pool_size = info.xnn_thread_pool_size; + int ort_thread_pool_size = info.session_options ? info.session_options->intra_op_param.thread_pool_size : 1; + bool allow_intra_op_spinning = (info.session_options == nullptr) || +diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc +index 7a233c57c..4bae42f4b 100644 +--- a/onnxruntime/core/session/custom_ops.cc ++++ b/onnxruntime/core/session/custom_ops.cc +@@ -26,15 +26,8 @@ + #include "core/session/ort_apis.h" + #include "core/platform/threadpool.h" + +-// NOTE: OrtKernelContext is used by both custom ops and compiled kernels. +-// In a minimal build, ORT_EXTENDED_MINIMAL_BUILD is used to enable EPs like CoreML/NNAPI which use compiled kernels, +-// and ORT_MINIMAL_BUILD_CUSTOM_OPS is used to allow external custom op libraries to be used. +-#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +-#define ENABLE_ORT_KERNEL_CONTEXT_API 1 +-#endif +- + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +-#define ENABLE_CUSTOM_OP_API 1 ++#define ENABLE_CUSTOM_OP_API + #endif + + #if !defined(ORT_MINIMAL_BUILD) +@@ -43,7 +36,7 @@ static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; + static constexpr uint32_t min_ort_version_with_custom_version = 17; + #endif + +-#if ENABLE_CUSTOM_OP_API ++#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + static constexpr uint32_t min_ort_version_with_compute_v2_support = 16; + static constexpr uint32_t min_ort_version_with_shape_inference = 17; + #endif +@@ -59,8 +52,7 @@ struct OrtShapeInferContext { + size_t GetInputCount() const { return 0; } + OrtTensorTypeAndShapeInfo* GetInputTypeShape(size_t) const { return {}; } + onnxruntime::Status SetOutputTypeShape(size_t, const OrtTensorTypeAndShapeInfo*) const { +- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, +- "OrtShapeInferContext::SetOutputTypeShape not implemented for minimal build"); ++ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtShapeInferContext::SetOutputTypeShape not implemented for minimal build"); + } + const ONNX_NAMESPACE::AttributeProto* GetAttr(const char*) const { return {}; } + }; +@@ -71,15 +63,13 @@ struct OrtShapeInferContext { + for (size_t ith_input = 0; ith_input < num_inputs; ++ith_input) { + const auto* input_type = ctx_.getInputType(ith_input); + const auto& value_case = input_type->value_case(); +- ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType, +- "shape inference not yet supported for non-tensor types"); ++ ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType, "shape inference not yet supported for non-tensor types"); + const auto& shape_proto = input_type->tensor_type().shape(); + const auto& type_proto = input_type->tensor_type(); + auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type()); + auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); + auto symbolic_dims = GetSymbolicDims(shape_proto); +- input_type_shapes_.emplace_back( +- OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release()); ++ input_type_shapes_.emplace_back(OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release()); + } + } + +@@ -131,392 +121,304 @@ struct 
OrtShapeInferContext { + }; + #endif + +-#if ENABLE_ORT_KERNEL_CONTEXT_API +-template +-static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const T& fn) { ++ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out) { + API_IMPL_BEGIN +- return fn(); ++ *out = context->GetInputCount(); ++ return nullptr; + API_IMPL_END + } +-#else +-template +-static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const T&) { +- return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "OrtKernelContext API is not enabled in this build"); +-} +-#endif + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- *out = reinterpret_cast(context)->InputCount(); ++ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info) { ++ API_IMPL_BEGIN ++ *info = context->GetInputTypeShape(index); ++ if (*info) { + return nullptr; +- }); +-}; ++ } else { ++ return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Failed to fetch type shape info for the index."); ++ } ++ API_IMPL_END ++} + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- *out = reinterpret_cast(context)->OutputCount(); ++ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr) { ++ API_IMPL_BEGIN ++ *attr = reinterpret_cast(context->GetAttr(attr_name)); ++ if (*attr) { + return nullptr; +- }); +-}; ++ } else { ++ return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); ++ } ++ API_IMPL_END ++} + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, +- _Out_ const OrtValue** out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- const auto* ctx = reinterpret_cast(context); +- *out = reinterpret_cast(ctx->GetInputMLValue(onnxruntime::narrow(index))); +- return nullptr; +- }); +-}; ++ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, ++ _In_ const OrtOpAttr* op_attr, ++ _In_ OrtOpAttrType type, ++ _Inout_ void* data, ++ _In_ size_t len, ++ _Out_ size_t* out) { ++ API_IMPL_BEGIN + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, +- _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- onnxruntime::TensorShape shape(dim_values, dim_count); +- auto* ctx = reinterpret_cast(context); +- *out = reinterpret_cast(ctx->OutputMLValue(onnxruntime::narrow(index), shape)); +- return nullptr; +- }); +-}; ++ if (!op_attr) { ++ return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Invalid attribute."); ++ } + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, +- _Outptr_ void** out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- auto* stream = reinterpret_cast(context)->GetComputeStream(); +- if (stream) +- *out = stream->GetHandle(); +- else +- *out = nullptr; +- return nullptr; +- }); +-}; ++ auto attr = reinterpret_cast(op_attr); ++ OrtStatusPtr ret = nullptr; ++ *out = 0; + 
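// --- Illustrative sketch (not part of the patch): ReadOpAttr writes the required byte count into
// *out even when the caller's buffer is too small, which allows a query-then-read pattern from
// custom-op code. Hedged example for an ORT_OP_ATTR_INTS attribute; the OrtOpAttr* would
// typically come from ShapeInferContext_GetAttribute, and error handling is deliberately minimal.
#include <onnxruntime_c_api.h>
#include <cstdint>
#include <vector>

std::vector<int64_t> ReadIntsAttr(const OrtApi* api, const OrtOpAttr* attr) {
  size_t bytes = 0;
  // First call: zero-length buffer, used only to learn the size (a non-null status is expected).
  OrtStatus* status = api->ReadOpAttr(attr, ORT_OP_ATTR_INTS, nullptr, 0, &bytes);
  if (status) api->ReleaseStatus(status);

  std::vector<int64_t> values(bytes / sizeof(int64_t));
  status = api->ReadOpAttr(attr, ORT_OP_ATTR_INTS, values.data(), bytes, &bytes);
  if (status) {
    api->ReleaseStatus(status);
    values.clear();  // treat any remaining error as "no attribute values"
  }
  return values;
}
// --- end sketch ---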
+-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, +- _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- const auto* ctx = reinterpret_cast(context); +- onnxruntime::AllocatorPtr allocator = ctx->GetAllocator(mem_info->device); +- if (!allocator) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); ++ if (type == OrtOpAttrType::ORT_OP_ATTR_FLOAT) { ++ if (len < sizeof(float)) { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold a float."); ++ } else { ++ if (attr->has_f()) { ++ auto output_f = reinterpret_cast(data); ++ *output_f = attr->f(); ++ } else { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no float value."); ++ } + } ++ *out = sizeof(float); + +- auto p = std::make_unique(std::move(allocator)); +- *out = p.release(); +- return nullptr; +- }); +-}; ++ } else if (type == OrtOpAttrType::ORT_OP_ATTR_FLOATS) { ++ const auto& floats = attr->floats(); ++ auto num_floats = floats.size(); + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelContext* context, +- _In_ int resource_version, _In_ int resource_id, _Outptr_ void** resource) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- *resource = {}; +- const auto* ctx = reinterpret_cast(context); +- auto* stream = reinterpret_cast(ctx->GetComputeStream()); +- if (!stream) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource"); ++ if (len < sizeof(float) * num_floats) { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the array of floats."); ++ } else { ++ auto output_f = reinterpret_cast(data); ++ for (auto f : floats) { ++ *output_f = f; ++ output_f++; ++ } + } +- *resource = stream->GetResource(resource_version, resource_id); +- return nullptr; +- }); +-}; ++ *out = num_floats * sizeof(float); + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context, +- _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- if (!context) { +- return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context"); +- } +- if (fn && total) { +- const auto* ctx = reinterpret_cast(context); +- auto* tp = ctx->GetOperatorThreadPool(); +- if (num_batch) { +- onnxruntime::concurrency::ThreadPool::TryBatchParallelFor( +- tp, +- static_cast(total), +- [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }, +- static_cast(num_batch)); ++ } else if (type == OrtOpAttrType::ORT_OP_ATTR_INT) { ++ if (len < sizeof(int)) { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold an int64."); ++ } else { ++ if (attr->has_i()) { ++ auto output_i = reinterpret_cast(data); ++ *output_i = attr->i(); + } else { +- onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( +- tp, +- static_cast(total), +- [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }); ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no int64 value."); + } + } +- return nullptr; +- }); +-}; ++ *out = sizeof(int64_t); + +-ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContext* context, +- 
_Outptr_ const OrtLogger** logger) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- const auto& kernel_ctx_logger = reinterpret_cast(context)->Logger(); ++ } else if (type == OrtOpAttrType::ORT_OP_ATTR_INTS) { ++ const auto& ints = attr->ints(); ++ auto num_ints = ints.size(); + +- *logger = reinterpret_cast(&kernel_ctx_logger); +- return nullptr; +- }); +-} ++ if (len < sizeof(int64_t) * num_ints) { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the array of int64."); ++ } else { ++ auto output_i = reinterpret_cast(data); ++ for (auto i : ints) { ++ *output_i = i; ++ output_i++; ++ } ++ } ++ *out = num_ints * sizeof(int64_t); + +-// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger +-ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, +- _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, +- _In_z_ const char* func_name) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- const auto& actual_logger = *reinterpret_cast(logger); +- const auto severity = static_cast(log_severity_level); +- const auto log_data_type = onnxruntime::logging::DataType::SYSTEM; ++ } else if (type == OrtOpAttrType::ORT_OP_ATTR_STRING) { ++ const auto& s = attr->s(); ++ if (len < s.size() + 1) { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the string."); ++ } else { ++ char* output_c = reinterpret_cast(data); ++ for (char c : s) { ++ *output_c++ = c; ++ } ++ *output_c = '\0'; ++ } ++ *out = s.size() + 1; + +- if (actual_logger.OutputIsEnabled(severity, log_data_type)) { +-#ifdef _WIN32 +- const std::string file_path_str = onnxruntime::ToUTF8String(file_path); +- onnxruntime::CodeLocation location(file_path_str.c_str(), line_number, func_name); +-#else +- onnxruntime::CodeLocation location(file_path, line_number, func_name); +-#endif ++ } else if (type == OrtOpAttrType::ORT_OP_ATTR_STRINGS) { ++ const auto& ss = attr->strings(); ++ size_t num_bytes = 0; ++ for_each(ss.begin(), ss.end(), [&num_bytes](const std::string& s) { num_bytes += s.size() + 1; }); + +- onnxruntime::logging::Capture( +- actual_logger, +- severity, +- onnxruntime::logging::Category::onnxruntime, +- log_data_type, +- location) +- .Stream() +- << message; ++ if (len < num_bytes) { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the array of strings."); ++ } else { ++ char* output_c = reinterpret_cast(data); ++ for (const auto& s : ss) { ++ for (char c : s) { ++ *output_c++ = c; ++ } ++ *output_c++ = '\0'; ++ } + } ++ *out = num_bytes; + +- return nullptr; +- }); ++ } else { ++ ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unknown attribute type."); ++ } ++ ++ return ret; ++ API_IMPL_END + } + +-// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger +-ORT_API_STATUS_IMPL(OrtApis::Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, +- _Out_ OrtLoggingLevel* out) { +- return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { +- const auto& actual_logger = *reinterpret_cast(logger); +- *out = static_cast(actual_logger.GetSeverity()); ++ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info) { ++ API_IMPL_BEGIN ++ auto status = 
context->SetOutputTypeShape(index, info); ++ if (status.IsOK()) { + return nullptr; +- }); ++ } else { ++ return OrtApis::CreateStatus(static_cast(status.Code()), status.ErrorMessage().c_str()); ++ } ++ API_IMPL_END + } + +-#if ENABLE_CUSTOM_OP_API +-template +-static OrtStatusPtr ExecuteIfCustomOpsApiEnabled(const T& fn) { ++ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { + API_IMPL_BEGIN +- return fn(); ++ auto status = reinterpret_cast(info)->GetAttr(name, out); ++ if (status.IsOK()) ++ return nullptr; ++ return onnxruntime::ToOrtStatus(status); + API_IMPL_END + } +-#else +-template +-static OrtStatusPtr ExecuteIfCustomOpsApiEnabled(const T&) { +- return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Custom operator API is not enabled in this build"); +-} +-#endif + +-ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, +- _Out_ size_t* out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- *out = context->GetInputCount(); ++ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { ++ API_IMPL_BEGIN ++ auto status = reinterpret_cast(info)->GetAttr(name, out); ++ if (status.IsOK()) + return nullptr; +- }); ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + +-ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, +- _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- *info = context->GetInputTypeShape(index); +- if (*info) { +- return nullptr; +- } else { +- return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Failed to fetch type shape info for the index."); +- } +- }); +-} ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { ++ API_IMPL_BEGIN ++ *out = reinterpret_cast(context)->InputCount(); ++ return nullptr; ++ API_IMPL_END ++}; + +-ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, +- _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- *attr = reinterpret_cast(context->GetAttr(attr_name)); +- if (*attr) { +- return nullptr; +- } else { +- return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); +- } +- }); +-} ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { ++ API_IMPL_BEGIN ++ *out = reinterpret_cast(context)->OutputCount(); ++ return nullptr; ++ API_IMPL_END ++}; + +-ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, +- _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- auto status = context->SetOutputTypeShape(index, info); +- if (status.IsOK()) { ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out) { ++ API_IMPL_BEGIN ++ *out = reinterpret_cast(reinterpret_cast(context)->GetInputMLValue(gsl::narrow_cast(index))); ++ return nullptr; ++ API_IMPL_END ++}; ++ ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const 
int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) { ++ API_IMPL_BEGIN ++ onnxruntime::TensorShape shape(dim_values, dim_count); ++ *out = reinterpret_cast(reinterpret_cast(context)->OutputMLValue(gsl::narrow_cast(index), shape)); ++ return nullptr; ++ API_IMPL_END ++}; ++ ++ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size) { ++ API_IMPL_BEGIN ++ std::string value; ++ auto status = reinterpret_cast(info)->GetAttr(name, &value); ++ if (status.IsOK()) { ++ if (out == nullptr) { // User is querying the true size of the attribute ++ *size = value.size() + 1; + return nullptr; +- } else { +- return OrtApis::CreateStatus(static_cast(status.Code()), status.ErrorMessage().c_str()); ++ } else if (*size >= value.size() + 1) { ++ std::memcpy(out, value.data(), value.size()); ++ out[value.size()] = '\0'; ++ *size = value.size() + 1; ++ return nullptr; ++ } else { // User has provided a buffer that is not large enough ++ *size = value.size() + 1; ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough"); + } +- }); ++ } ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + +-ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, +- _In_ size_t len, _Out_ size_t* out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- if (!op_attr) { +- return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Invalid attribute."); +- } +- +- auto attr = reinterpret_cast(op_attr); +- OrtStatusPtr ret = nullptr; +- *out = 0; ++#ifdef _WIN32 ++#pragma warning(push) ++#pragma warning(disable : 28196 6387) ++#endif + +- switch (type) { +- case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { +- if (len < sizeof(float)) { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Size of data not large enough to hold a float."); +- } else { +- if (attr->has_f()) { +- auto output_f = reinterpret_cast(data); +- *output_f = attr->f(); +- } else { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no float value."); +- } +- } +- *out = sizeof(float); ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out) { ++ API_IMPL_BEGIN ++ auto* stream = reinterpret_cast(context)->GetComputeStream(); ++ if (stream) ++ *out = stream->GetHandle(); ++ else ++ *out = nullptr; ++ return nullptr; ++ API_IMPL_END ++}; + +- break; +- } +- case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { +- const auto& floats = attr->floats(); +- auto num_floats = floats.size(); ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { ++ API_IMPL_BEGIN ++ onnxruntime::AllocatorPtr allocator = reinterpret_cast(context)->GetAllocator(mem_info->device); ++ if (!allocator) { ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); ++ } ++ std::unique_ptr p = std::make_unique(std::move(allocator)); ++ *out = p.release(); ++ return nullptr; ++ API_IMPL_END ++}; + +- if (len < sizeof(float) * num_floats) { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Size of data not large enough to hold the array of floats."); +- } else { +- auto output_f = reinterpret_cast(data); +- for (auto f : floats) { +- *output_f = f; +- output_f++; +- } +- } +- *out = num_floats * sizeof(float); +- 
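// --- Illustrative sketch (not part of the patch): KernelInfoGetAttribute_string supports the
// usual two-call protocol visible in the hunk above — pass out == nullptr to query the size
// (including the terminating '\0'), then call again with a buffer of that size. Hedged
// caller-side example:
#include <onnxruntime_c_api.h>
#include <string>

std::string GetStringAttr(const OrtApi* api, const OrtKernelInfo* info, const char* name) {
  size_t size = 0;
  OrtStatus* status = api->KernelInfoGetAttribute_string(info, name, nullptr, &size);
  if (status) { api->ReleaseStatus(status); return {}; }  // attribute missing or other error

  std::string value(size, '\0');
  status = api->KernelInfoGetAttribute_string(info, name, value.data(), &size);
  if (status) { api->ReleaseStatus(status); return {}; }
  if (size > 0) value.resize(size - 1);  // drop the trailing '\0' counted in size
  return value;
}
// --- end sketch ---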
break; +- } +- case OrtOpAttrType::ORT_OP_ATTR_INT: { +- if (len < sizeof(int)) { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Size of data not large enough to hold an int64."); +- } else { +- if (attr->has_i()) { +- auto output_i = reinterpret_cast(data); +- *output_i = attr->i(); +- } else { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no int64 value."); +- } +- } +- *out = sizeof(int64_t); +- break; +- } +- case OrtOpAttrType::ORT_OP_ATTR_INTS: { +- const auto& ints = attr->ints(); +- auto num_ints = ints.size(); ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, _In_ int resource_id, _Outptr_ void** resource) { ++ API_IMPL_BEGIN ++ *resource = {}; ++ const auto* ctx = reinterpret_cast(context); ++ auto* stream = reinterpret_cast(ctx->GetComputeStream()); ++ if (!stream) { ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource"); ++ } ++ *resource = stream->GetResource(resource_version, resource_id); ++ return nullptr; ++ API_IMPL_END ++}; + +- if (len < sizeof(int64_t) * num_ints) { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Size of data not large enough to hold the array of int64."); +- } else { +- auto output_i = reinterpret_cast(data); +- for (auto i : ints) { +- *output_i = i; +- output_i++; +- } +- } +- *out = num_ints * sizeof(int64_t); +- break; +- } +- case OrtOpAttrType::ORT_OP_ATTR_STRING: { +- const auto& s = attr->s(); +- if (len < s.size() + 1) { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Size of data not large enough to hold the string."); +- } else { +- char* output_c = reinterpret_cast(data); +- for (char c : s) { +- *output_c++ = c; +- } +- *output_c = '\0'; +- } +- *out = s.size() + 1; +- break; +- } +- case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { +- const auto& ss = attr->strings(); +- size_t num_bytes = 0; +- for_each(ss.begin(), ss.end(), [&num_bytes](const std::string& s) { num_bytes += s.size() + 1; }); +- +- if (len < num_bytes) { +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, +- "Size of data not large enough to hold the array of strings."); +- } else { +- char* output_c = reinterpret_cast(data); +- for (const auto& s : ss) { +- for (char c : s) { +- *output_c++ = c; +- } +- *output_c++ = '\0'; +- } +- } +- *out = num_bytes; +- break; +- } +- default: +- ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type. 
"); ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) { ++#ifdef ENABLE_CUSTOM_OP_API ++ API_IMPL_BEGIN ++ if (!context) { ++ return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context"); ++ } ++ if (fn && total) { ++ const auto* ctx = reinterpret_cast(context); ++ auto* tp = ctx->GetOperatorThreadPool(); ++ if (num_batch) { ++ onnxruntime::concurrency::ThreadPool::TryBatchParallelFor( ++ tp, ++ static_cast(total), ++ [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }, ++ static_cast(num_batch)); ++ } else { ++ onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( ++ tp, ++ static_cast(total), ++ [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }); + } ++ } ++ return nullptr; ++ API_IMPL_END ++#else ++ ORT_UNUSED_PARAMETER(context); ++ ORT_UNUSED_PARAMETER(fn); ++ ORT_UNUSED_PARAMETER(total); ++ ORT_UNUSED_PARAMETER(num_batch); ++ ORT_UNUSED_PARAMETER(usr_data); ++ return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "ParallelFor API not implemented for this build"); ++#endif ++}; + +- return ret; +- }); +-} +- +-ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, +- _Out_ float* out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- auto status = reinterpret_cast(info)->GetAttr(name, out); +- if (status.IsOK()) +- return nullptr; +- return onnxruntime::ToOrtStatus(status); +- }); +-} +- +-ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, +- _Out_ int64_t* out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- auto status = reinterpret_cast(info)->GetAttr(name, out); +- if (status.IsOK()) +- return nullptr; +- return onnxruntime::ToOrtStatus(status); +- }); +-} +- +-ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, +- _Out_ char* out, _Inout_ size_t* size) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- std::string value; +- auto status = reinterpret_cast(info)->GetAttr(name, &value); +- if (status.IsOK()) { +- if (out == nullptr) { // User is querying the true size of the attribute +- *size = value.size() + 1; +- return nullptr; +- } else if (*size >= value.size() + 1) { +- std::memcpy(out, value.data(), value.size()); +- out[value.size()] = '\0'; +- *size = value.size() + 1; +- return nullptr; +- } else { // User has provided a buffer that is not large enough +- *size = value.size() + 1; +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough"); +- } +- } +- return onnxruntime::ToOrtStatus(status); +- }); +-} ++#ifdef _WIN32 ++#pragma warning(pop) ++#endif + + template ::value, int>::type = 0> + static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, size_t* size) { +@@ -536,209 +438,256 @@ static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, s + + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out, _Inout_ size_t* size) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- std::vector values; +- auto status = reinterpret_cast(info)->GetAttrs(name, values); +- if (status.IsOK()) { +- status = CopyDataFromVectorToMemory(values, out, size); +- } +- return 
onnxruntime::ToOrtStatus(status); +- }); ++ API_IMPL_BEGIN ++ std::vector values; ++ auto status = reinterpret_cast(info)->GetAttrs(name, values); ++ if (status.IsOK()) { ++ status = CopyDataFromVectorToMemory(values, out, size); ++ } ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out, _Inout_ size_t* size) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- std::vector values; +- auto status = reinterpret_cast(info)->GetAttrs(name, values); +- if (status.IsOK()) { +- status = CopyDataFromVectorToMemory(values, out, size); +- } +- return onnxruntime::ToOrtStatus(status); +- }); ++ API_IMPL_BEGIN ++ std::vector values; ++ auto status = reinterpret_cast(info)->GetAttrs(name, values); ++ if (status.IsOK()) { ++ status = CopyDataFromVectorToMemory(values, out, size); ++ } ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_kinfo = reinterpret_cast(info); +- +- // Get TensorProto attribute +- onnx::TensorProto tensor_proto; +- auto status = op_kinfo->GetAttr(name, &tensor_proto); +- if (!status.IsOK()) { +- return onnxruntime::ToOrtStatus(status); +- } ++ API_IMPL_BEGIN ++ const auto* op_kinfo = reinterpret_cast(info); + +- // Determine the tensor's size in bytes. +- size_t req_size = 0; +- status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &req_size); +- if (!status.IsOK()) { +- return onnxruntime::ToOrtStatus(status); +- } ++ // Get TensorProto attribute ++ onnx::TensorProto tensor_proto; ++ auto status = op_kinfo->GetAttr(name, &tensor_proto); ++ if (!status.IsOK()) { ++ return onnxruntime::ToOrtStatus(status); ++ } + +- // Create Tensor that owns buffer memory that will be allocated with the provided OrtAllocator. +- onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); +- const auto* type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); +- onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); +- auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); ++ // Determine the tensor's size in bytes. ++ size_t req_size = 0; ++ status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &req_size); ++ if (!status.IsOK()) { ++ return onnxruntime::ToOrtStatus(status); ++ } + +- // Deserialize TensorProto into pre-allocated, empty Tensor. +- status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp); +- if (!status.IsOK()) { +- return onnxruntime::ToOrtStatus(status); +- } ++ // Create Tensor that owns buffer memory that will be allocated with the provided OrtAllocator. ++ onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); ++ const auto* const type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); ++ onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); ++ auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); ++ ++ // Deserialize TensorProto into pre-allocated, empty Tensor. 
++ status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp); ++ if (!status.IsOK()) { ++ return onnxruntime::ToOrtStatus(status); ++ } + +- // Initialize OrtValue from Tensor. +- auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); +- auto value = std::make_unique(); +- value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); ++ // Initialize OrtValue from Tensor. ++ auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); ++ auto value = std::make_unique(); ++ value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + +- *out = value.release(); +- return nullptr; +- }); ++ *out = value.release(); ++ return nullptr; ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- *out = reinterpret_cast(info)->GetInputCount(); +- return nullptr; +- }); ++ API_IMPL_BEGIN ++ *out = reinterpret_cast(info)->GetInputCount(); ++ return nullptr; ++ API_IMPL_END + }; + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- *out = reinterpret_cast(info)->GetOutputCount(); +- return nullptr; +- }); ++ API_IMPL_BEGIN ++ *out = reinterpret_cast(info)->GetOutputCount(); ++ return nullptr; ++ API_IMPL_END + }; + +-ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, +- _Out_ char* out, _Inout_ size_t* size) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_info = reinterpret_cast(info); +- const auto input_defs = op_info->node().InputDefs(); ++ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, ++ _Inout_ size_t* size) { ++ API_IMPL_BEGIN ++ const auto* op_info = reinterpret_cast(info); ++ const auto input_defs = op_info->node().InputDefs(); + +- if (index >= input_defs.size()) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); +- } ++ if (index >= input_defs.size()) { ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); ++ } + +- auto status = CopyStringToOutputArg(input_defs[index]->Name(), +- "Output buffer is not large enough for ::OrtKernelInfo input name", out, size); ++ auto status = CopyStringToOutputArg(input_defs[index]->Name(), ++ "Output buffer is not large enough for ::OrtKernelInfo input name", out, size); + +- return onnxruntime::ToOrtStatus(status); +- }); ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + _Inout_ size_t* size) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_info = reinterpret_cast(info); +- const auto output_defs = op_info->node().OutputDefs(); ++ API_IMPL_BEGIN ++ const auto* op_info = reinterpret_cast(info); ++ const auto output_defs = op_info->node().OutputDefs(); + +- if (index >= output_defs.size()) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); +- } ++ if (index >= output_defs.size()) { ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); ++ } + +- auto status = CopyStringToOutputArg(output_defs[index]->Name(), 
+- "Output buffer is not large enough for ::OrtKernelInfo output name", +- out, size); ++ auto status = CopyStringToOutputArg(output_defs[index]->Name(), ++ "Output buffer is not large enough for ::OrtKernelInfo output name", out, size); + +- return onnxruntime::ToOrtStatus(status); +- }); ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_info = reinterpret_cast(info); +- const auto input_defs = op_info->node().InputDefs(); ++ API_IMPL_BEGIN ++ const auto* op_info = reinterpret_cast(info); ++ const auto input_defs = op_info->node().InputDefs(); + +- if (index >= input_defs.size()) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); +- } ++ if (index >= input_defs.size()) { ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); ++ } + +- const onnxruntime::NodeArg* node_arg = input_defs[index]; +- const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); ++ const onnxruntime::NodeArg* node_arg = input_defs[index]; ++ const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); + +- if (type_proto == nullptr) { +- return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type"); +- } ++ if (type_proto == nullptr) { ++ return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type"); ++ } + +- auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); +- *type_info = type_info_ret.release(); +- return nullptr; +- }); ++ auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); ++ *type_info = type_info_ret.release(); ++ return nullptr; ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_info = reinterpret_cast(info); +- const auto output_defs = op_info->node().OutputDefs(); ++ API_IMPL_BEGIN ++ const auto* op_info = reinterpret_cast(info); ++ const auto output_defs = op_info->node().OutputDefs(); + +- if (index >= output_defs.size()) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); +- } ++ if (index >= output_defs.size()) { ++ return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); ++ } + +- const onnxruntime::NodeArg* node_arg = output_defs[index]; +- const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); ++ const onnxruntime::NodeArg* node_arg = output_defs[index]; ++ const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); + +- if (type_proto == nullptr) { +- return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type"); +- } ++ if (type_proto == nullptr) { ++ return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type"); ++ } + +- auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); +- *type_info = type_info_ret.release(); +- return nullptr; +- }); ++ auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); ++ *type_info = type_info_ret.release(); ++ return nullptr; ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetConstantInput_tensor, _In_ const 
OrtKernelInfo* info, _In_ size_t index, + _Out_ int* is_constant, _Outptr_ const OrtValue** out) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_info = reinterpret_cast(info); +- *is_constant = static_cast(op_info->TryGetConstantInput(gsl::narrow_cast(index), out)); +- return nullptr; +- }); ++ API_IMPL_BEGIN ++ const auto* op_info = reinterpret_cast(info); ++ *is_constant = static_cast(op_info->TryGetConstantInput(gsl::narrow_cast(index), out)); ++ return nullptr; ++ API_IMPL_END + }; + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, + _Inout_ size_t* size) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* op_info = reinterpret_cast(info); ++ API_IMPL_BEGIN ++ const auto* op_info = reinterpret_cast(info); + +- auto status = CopyStringToOutputArg(op_info->node().Name(), +- "Output buffer is not large enough for ::OrtKernelInfo node name", out, size); ++ auto status = CopyStringToOutputArg(op_info->node().Name(), ++ "Output buffer is not large enough for ::OrtKernelInfo node name", out, size); + +- return onnxruntime::ToOrtStatus(status); +- }); ++ return onnxruntime::ToOrtStatus(status); ++ API_IMPL_END + } + + ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger) { +- return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { +- const auto* ep = reinterpret_cast(info)->GetExecutionProvider(); ++ API_IMPL_BEGIN ++ const auto* ep = reinterpret_cast(info)->GetExecutionProvider(); + +- if (ep == nullptr) { +- return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo does not have an execution provider"); +- } ++ if (ep == nullptr) { ++ return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo does not have an execution provider"); ++ } + +- const auto* ep_logger = ep->GetLogger(); ++ const auto* ep_logger = ep->GetLogger(); + +- if (ep_logger == nullptr) { +- return OrtApis::CreateStatus(ORT_INVALID_GRAPH, +- "::OrtKernelInfo cannot get a valid logger from " +- "its execution provider"); +- } ++ if (ep_logger == nullptr) { ++ return OrtApis::CreateStatus(ORT_INVALID_GRAPH, ++ "::OrtKernelInfo cannot get a valid logger from " ++ "its execution provider"); ++ } + +- *logger = reinterpret_cast(ep_logger); +- return nullptr; +- }); ++ *logger = reinterpret_cast(ep_logger); ++ return nullptr; ++ API_IMPL_END + } + +-#if ENABLE_CUSTOM_OP_API ++ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger) { ++ API_IMPL_BEGIN ++ const auto& kernel_ctx_logger = reinterpret_cast(context)->Logger(); ++ ++ *logger = reinterpret_cast(&kernel_ctx_logger); ++ return nullptr; ++ API_IMPL_END ++} ++ ++ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, ++ _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, ++ _In_z_ const char* func_name) { ++ API_IMPL_BEGIN ++ const auto& actual_logger = *reinterpret_cast(logger); ++ const auto severity = static_cast(log_severity_level); ++ const auto log_data_type = onnxruntime::logging::DataType::SYSTEM; ++ ++ if (actual_logger.OutputIsEnabled(severity, log_data_type)) { ++#ifdef _WIN32 ++ const std::string file_path_str = onnxruntime::ToUTF8String(file_path); ++ onnxruntime::CodeLocation location(file_path_str.c_str(), line_number, func_name); ++#else ++ onnxruntime::CodeLocation location(file_path, line_number, func_name); 
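// --- Illustrative sketch (not part of the patch): the Logger_* entry points restored above are
// what custom-op code uses to log through the session logger. Hedged example of fetching the
// logger from the kernel context and emitting a warning; the file-path macro below is a local
// stand-in because the ORTCHAR_T parameter is wide on Windows.
#include <onnxruntime_c_api.h>

#ifdef _WIN32
#define SKETCH_FILE __FILEW__  // ORTCHAR_T is wchar_t on Windows
#else
#define SKETCH_FILE __FILE__
#endif

void LogFromKernel(const OrtApi* api, const OrtKernelContext* context) {
  const OrtLogger* logger = nullptr;
  if (OrtStatus* status = api->KernelContext_GetLogger(context, &logger)) {
    api->ReleaseStatus(status);
    return;  // no logger available in this build/context
  }
  OrtStatus* status = api->Logger_LogMessage(logger, ORT_LOGGING_LEVEL_WARNING,
                                             "custom op fallback path taken",
                                             SKETCH_FILE, __LINE__, __func__);
  if (status) api->ReleaseStatus(status);
}
// --- end sketch ---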
++#endif ++ ++ onnxruntime::logging::Capture( ++ actual_logger, ++ severity, ++ onnxruntime::logging::Category::onnxruntime, ++ log_data_type, ++ location) ++ .Stream() ++ << message; ++ } ++ ++ return nullptr; ++ API_IMPL_END ++} ++ ++ORT_API_STATUS_IMPL(OrtApis::Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out) { ++ API_IMPL_BEGIN ++ const auto& actual_logger = *reinterpret_cast(logger); ++ *out = static_cast(actual_logger.GetSeverity()); ++ return nullptr; ++ API_IMPL_END ++} ++ ++#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) + #include "core/framework/customregistry.h" + namespace onnxruntime { ++ + struct CustomOpKernel : OpKernel { + CustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { + if (op_.version > ORT_API_VERSION) { +@@ -817,8 +766,7 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust + if (input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + def_builder.TypeConstraint(input_name, SUPPORTED_TENSOR_TYPES); + } else { +- def_builder.TypeConstraint(input_name, +- DataTypeImpl::TensorTypeFromONNXEnum(static_cast(input_type))->AsTensorType()); ++ def_builder.TypeConstraint(input_name, DataTypeImpl::TensorTypeFromONNXEnum(static_cast(input_type))->AsTensorType()); + } + } + +@@ -828,8 +776,7 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust + if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + def_builder.TypeConstraint(output_name, SUPPORTED_TENSOR_TYPES); + } else { +- def_builder.TypeConstraint(output_name, +- DataTypeImpl::TensorTypeFromONNXEnum(static_cast(output_type))->AsTensorType()); ++ def_builder.TypeConstraint(output_name, DataTypeImpl::TensorTypeFromONNXEnum(static_cast(output_type))->AsTensorType()); + } + } + +@@ -839,8 +786,7 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust + def_builder.Provider(onnxruntime::kCpuExecutionProvider); + } + +- KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, +- std::unique_ptr& out) -> Status { ++ KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + out = std::make_unique(info, *op); + return Status::OK(); + }; +@@ -945,8 +891,8 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect + "There must be one (and only one) dynamic typed input to the custom op. " + "Its type info at runtime will be used to infer the type info of this dynamic typed output " + "which is required for the success of the model loading step. 
" +- "More than one dynamic typed inputs are currently not supported as differing types at runtime " +- "means the output type cannot be inferred without which model loading cannot proceed."); ++ "More than one dynamic typed inputs are currently not supported as differing types at runtime means the output type " ++ "cannot be inferred without which model loading cannot proceed."); + } + } + create_type_constraint(op, static_cast(output_count), static_cast(i), false); +@@ -1057,8 +1003,7 @@ void InferOutputTypes(const InlinedVector& kernel_defs, + if (tc_iter != type_constraints.end()) { + if (tc_iter->second.size() > 1) { + undef = elem_type; +- } else if (tc_iter->second.size() != 1 || +- tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) { ++ } else if (tc_iter->second.size() != 1 || tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) { + matched = false; + } + } else { +@@ -1085,8 +1030,7 @@ void InferOutputTypes(const InlinedVector& kernel_defs, + if (tc_iter->second.size() > 1) { + output_type->mutable_tensor_type()->set_elem_type(undef); + } else { +- output_type->mutable_tensor_type()->set_elem_type( +- tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type()); ++ output_type->mutable_tensor_type()->set_elem_type(tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type()); + } + } + break; +@@ -1108,8 +1052,7 @@ common::Status CreateCustomRegistry(gsl::span op_domai + // If domain is empty, it is assumed to be part of the ONNX domain + if (!domain->domain_.empty()) { + // Add it to the DomainToVersion ONNX map if it doesn't already exist +- // For example, two sessions using the same session_options should not add the same custom op domain +- // to the version map twice ++ // For example, two sessions using the same session_options should not add the same custom op domain to the version map twice + auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + const auto& domain_to_version_map = domain_to_version_range_instance.Map(); + +@@ -1156,13 +1099,12 @@ common::Status CreateCustomRegistry(gsl::span op_domai + schemas.push_back(schema_iter.second); + InlinedVector kernel_defs = std::move(kernel_def_map[schema_iter.first]); + auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction(); +- ONNX_NAMESPACE::InferenceFunction extended_infer_fn = +- [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) { +- InferOutputTypes(kernel_defs, infer_ctx); +- if (infer_fn) { +- infer_fn(infer_ctx); +- } +- }; ++ ONNX_NAMESPACE::InferenceFunction extended_infer_fn = [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) { ++ InferOutputTypes(kernel_defs, infer_ctx); ++ if (infer_fn) { ++ infer_fn(infer_ctx); ++ } ++ }; + schemas.back().TypeAndShapeInferenceFunction(extended_infer_fn); + } + +@@ -1221,4 +1163,4 @@ common::Status CreateCustomRegistry(gsl::span op_domai + } + + } // namespace onnxruntime +-#endif // ENABLE_CUSTOM_OP_API ++#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc +index cae714954..e8853c882 100644 +--- a/onnxruntime/core/session/inference_session.cc ++++ b/onnxruntime/core/session/inference_session.cc +@@ -145,30 +145,28 @@ static bool HasMemcpyNodes(const Graph& graph) { + return false; + } + +-static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) { +- bool 
nodes_on_cpu_and_cuda_and_js_eps_only = true; ++static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { ++ bool nodes_on_cpu_and_cuda_eps_only = true; + + for (const auto& node : graph.Nodes()) { + const auto& node_provider = node.GetExecutionProviderType(); + + // Empty node provider means CPU EP + if (!node_provider.empty() && +- !(node_provider == kCudaExecutionProvider || +- node_provider == kRocmExecutionProvider || +- node_provider == kJsExecutionProvider) && ++ node_provider != kCudaExecutionProvider && + node_provider != kCpuExecutionProvider) { +- nodes_on_cpu_and_cuda_and_js_eps_only = false; ++ nodes_on_cpu_and_cuda_eps_only = false; + break; + } + } + +- // If we see nodes assigned to EPs other than CPU, or CUDA/JS ++ // If we see nodes assigned to EPs other than CPU or CUDA + // (or) if there are Memcpy nodes, then all compute nodes have +- // not been parititoned to the CUDA/JS EP. ++ // not been parititoned to the CUDA EP. + // We allow CPU EPs to show up in the EP list as long as thre is no Memcpy + // involved as shape subgraphs will be forced onto CPU and these will not have + // Memcpy nodes involved. +- return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph); ++ return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph); + } + + static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { +@@ -1717,7 +1715,7 @@ common::Status InferenceSession::Initialize() { + // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. + ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); + +- // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP. ++ // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // + // Check for CUDA EP: + // If the CUDA EP is part of the providers list for this session AND +@@ -1730,70 +1728,47 @@ common::Status InferenceSession::Initialize() { + // The TRT EP is configured to do a graph capture AND + // All the graph nodes have been assigned to the TRT EP, + // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). +- // +- // Check for JS EP: +- // If the JS EP is part of the providers list for this session AND +- // The JS EP is configured to do a graph capture AND +- // All the "compute" graph nodes have been assigned to the JS EP, +- // Then the JS EP is cached for triggering a ReplayGraph() in Run(). +- // +- // Check for ROCM EP: +- // If the ROCM EP is part of the providers list for this session AND +- // The ROCM EP is configured to do a graph capture AND +- // All the "compute" graph nodes have been assigned to the ROCM EP, +- // Then the ROCM EP is cached for triggering a ReplayGraph() in Run(). 
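// --- Illustrative sketch (not part of the patch): the partitioning checks above only run when
// the user has asked the CUDA EP to capture a graph. Hedged example of requesting that through
// the CUDA provider options using the documented "enable_cuda_graph" key; status return values
// are ignored here purely for brevity.
#include <onnxruntime_c_api.h>

void EnableCudaGraph(const OrtApi* api, OrtSessionOptions* session_options) {
  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
  api->CreateCUDAProviderOptions(&cuda_options);

  const char* keys[] = {"enable_cuda_graph"};
  const char* values[] = {"1"};
  api->UpdateCUDAProviderOptions(cuda_options, keys, values, 1);

  api->SessionOptionsAppendExecutionProvider_CUDA_V2(session_options, cuda_options);
  api->ReleaseCUDAProviderOptions(cuda_options);
}
// --- end sketch ---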
+- // +- std::vector graph_support_ep_list = { +- onnxruntime::kTensorrtExecutionProvider, +- onnxruntime::kCudaExecutionProvider, +- onnxruntime::kRocmExecutionProvider, +- onnxruntime::kJsExecutionProvider}; ++ std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + +- for (auto& it : graph_support_ep_list) { ++ for (auto& it : cuda_graph_support_ep_list) { + auto* target_ep = execution_providers_.Get(it); + + if (target_ep && target_ep->IsGraphCaptureEnabled()) { +- // Graphs capture can't work with control flow nodes ++ // CUDA Graphs can't work with control flow nodes + if (HasControlflowNodes(graph)) { +- LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " +- << "as the model has control flow nodes which can't be supported by " +- << target_ep->Type(); ++ LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " ++ << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, +- "This session cannot use the graph capture feature as requested by the user " +- "as the model has control flow nodes which can't be supported by" + +- target_ep->Type())); ++ "This session cannot use the CUDA Graph feature as requested by the user " ++ "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + } + +- if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || +- strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0 || +- strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { +- // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes ++ if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { ++ // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes + // The reasoning behind this logic is that certain shape nodes will be forced onto CPU + // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP + // which is all we care about. +- if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) { +- LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " +- << " as all compute graph nodes have not been partitioned to the " +- << target_ep->Type(); ++ if (!AreAllComputeNodesAssignedToCudaEp(graph)) { ++ LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " ++ << " as all compute graph nodes have not been partitioned to the CUDA EP."; + + ORT_RETURN_IF_ERROR_SESSIONID_( + ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, +- "This session cannot use the graph capture feature as requested by the user " +- " as all compute graph nodes have not been partitioned to the " + +- target_ep->Type())); ++ "This session cannot use the CUDA Graph feature as requested by the user " ++ " as all compute graph nodes have not been partitioned to the CUDA EP.")); + } + + // Log a warning for the user to know that there are shape subgraphs that will execute on CPU + if (HasShapeSubgraphNodes(graph)) { + LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " +- << "Use the graph capture feature with caution. " ++ << "Use the CUDA Graph feature with caution. 
" + << "As long as the intermediate shapes produced in the model " +- << "using the representative input used to capture the graph, " ++ << "using the representative input used to capture the CUDA graph, " + << "will match the shapes produced in the model for other inputs " + << "of the same shape as the representative input (common case), " +- << "it is safe to use the graph capture feature."; ++ << "it is safe to use the CUDA Graph feature."; + } + } else { + // Following code path is for TRT EP currently. +@@ -1812,7 +1787,7 @@ common::Status InferenceSession::Initialize() { + } + } + +- LOGS(*session_logger_, INFO) << "This session will use the CUDA/HIP Graph feature as requested by the user."; ++ LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); + break; // Make sure only one ep can run CUDA graph. + } +@@ -2502,9 +2477,7 @@ Status InferenceSession::Run(const RunOptions& run_options, + // As N+1 inference runs (N for memory allocation and 1 for graph capturing) + // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, + // so that users just need one session run to capture the graph. +- // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, +- // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, +- // and the value could be different for other EP. ++ // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP. + if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && + !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { + LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; +diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc +index dec8754ea..d77c188f8 100644 +--- a/onnxruntime/core/session/onnxruntime_c_api.cc ++++ b/onnxruntime/core/session/onnxruntime_c_api.cc +@@ -2397,7 +2397,7 @@ Second example, if we wanted to add and remove some members, we'd do this: + In GetApi we now make it return ort_api_3 for version 3. + */ + +-static constexpr OrtApi ort_api_1_to_18 = { ++static constexpr OrtApi ort_api_1_to_17 = { + // NOTE: The ordering of these fields MUST not change after that version has shipped since existing binaries depend on this ordering. + + // Shipped as version 1 - DO NOT MODIFY (see above text for more information) +@@ -2724,7 +2724,6 @@ static constexpr OrtApi ort_api_1_to_18 = { + &OrtApis::SetDeterministicCompute, + &OrtApis::KernelContext_ParallelFor, + &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, +- &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + }; + + // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. +@@ -2757,16 +2756,16 @@ static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265 + static_assert(offsetof(OrtApi, SetUserLoggingFunction) / sizeof(void*) == 266, "Size of version 17 API cannot change"); + + // So that nobody forgets to finish an API version, this check will serve as a reminder: +-static_assert(std::string_view(ORT_VERSION) == "1.18.0", ++static_assert(std::string_view(ORT_VERSION) == "1.17.0", + "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); + // 1. 
Update the hardcoded version string in above static_assert to silence it +-// 2. If there were any APIs added to ort_api_1_to_18 above: ++// 2. If there were any APIs added to ort_api_1_to_17 above: + // a. Add the 'End of version #' markers (pattern above should be obvious) + // b. Add a static_assert in the directly above list of version sizes to ensure nobody adds any more functions to the just shipped API version + + ORT_API(const OrtApi*, OrtApis::GetApi, uint32_t version) { + if (version >= 1 && version <= ORT_API_VERSION) +- return &ort_api_1_to_18; ++ return &ort_api_1_to_17; + + fprintf(stderr, + "The requested API version [%u] is not available, only API versions [1, %u] are supported in this build." +diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h +index 9ce94ba89..c1caafa4d 100644 +--- a/onnxruntime/core/session/ort_apis.h ++++ b/onnxruntime/core/session/ort_apis.h +@@ -509,8 +509,4 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); +- +-ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, +- _In_reads_(num_keys) const char* const* provider_options_keys, +- _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + } // namespace OrtApis +diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc +index 32ae15e71..29c2c6b0c 100644 +--- a/onnxruntime/core/session/provider_bridge_ort.cc ++++ b/onnxruntime/core/session/provider_bridge_ort.cc +@@ -30,7 +30,6 @@ + #include "core/framework/sparse_utils.h" + #include "core/graph/graph_proto_serializer.h" + #include "core/framework/murmurhash3.h" +-#include "core/framework/model_metadef_id_generator.h" + + #include "core/session/onnxruntime_c_api.h" + #include "core/common/string_helper.h" +@@ -57,8 +56,6 @@ + namespace ONNX_NAMESPACE { + // We use these names in the provider API because we don't have the protobuf definitions of the RepeatedField* types + using int64s = google::protobuf::RepeatedField; +-using float32s = google::protobuf::RepeatedField; +-using StringStringEntryProtos = google::protobuf::RepeatedPtrField; + using TensorProtos = google::protobuf::RepeatedPtrField; + using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; + using ValueInfoProtos = google::protobuf::RepeatedPtrField; +@@ -79,7 +76,6 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; + #include "core/providers/migraphx/migraphx_provider_factory_creator.h" + #include "core/providers/openvino/openvino_provider_factory_creator.h" + #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" +-#include "core/providers/vitisai/vitisai_provider_factory_creator.h" + + #include "core/providers/cuda/cuda_provider_factory.h" + #include "core/providers/cann/cann_provider_factory.h" +@@ -126,7 +122,6 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); + ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); + ProviderInfo_ROCM& GetProviderInfo_ROCM(); + ProviderHostCPU& GetProviderHostCPU(); +-ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); + struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { + TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} + +@@ -278,10 
+273,7 @@ struct ProviderHostImpl : ProviderHost { + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } +- Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, +- /*out*/ std::vector& unpacked_tensor) override { +- return utils::UnpackInitializerData(tensor, model_path, unpacked_tensor); +- } ++ + uint16_t math__floatToHalf(float f) override { return math::floatToHalf(f); } + float math__halfToFloat(uint16_t h) override { return math::halfToFloat(h); } + +@@ -325,6 +317,10 @@ struct ProviderHostImpl : ProviderHost { + return p->IExecutionProvider::Compile(fused_nodes_and_graphs, node_compute_funcs); + } + ++ int IExecutionProvider__GenerateMetaDefId(const IExecutionProvider* p, const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) override { ++ return p->IExecutionProvider::GenerateMetaDefId(graph_viewer, model_hash); ++ } ++ + // Status (direct) + std::string Status__ToString(const Status* p) override { return p->Status::ToString(); } + +@@ -359,32 +355,12 @@ struct ProviderHostImpl : ProviderHost { + void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; } + std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } + +- // Env +- Env& Env__Default() override { return Env::Default(); } +- + // Utils::DataTypeUtils (wrapped) + const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) override { return ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(type_proto); } + + // int64s (wrapped) + int int64s__size(const ONNX_NAMESPACE::int64s* p) override { return p->size(); } + const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) override { return p->Get(index); } +- void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) override { p->Reserve(size); }; +- const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) override { return p->data(); } +- +- // float32s +- void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) override { p->Reserve(size); }; +- const float* float32s__data(const ONNX_NAMESPACE::float32s* p) override { return p->data(); } +- int float32s__size(const ONNX_NAMESPACE::float32s* p) override { return p->size(); } +- +- // StringStringEntryProto +- std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_key(); } +- std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_value(); } +- +- // StringStringEntryProtos +- void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) override { p->Clear(); }; +- ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->Add(); } +- 
int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); } +- ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); }; + + #if !defined(DISABLE_OPTIONAL_TYPE) + // TypeProto_Optional (wrapped) +@@ -401,7 +377,6 @@ struct ProviderHostImpl : ProviderHost { + const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->shape(); } + ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->mutable_shape(); } + int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->elem_type(); } +- void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) override { p->set_elem_type(value); }; + + // TypeProto_SparseTensor (wrapped) + #if !defined(DISABLE_SPARSE_TENSORS) +@@ -454,18 +429,9 @@ struct ProviderHostImpl : ProviderHost { + float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->floats(i); } + const std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->strings(i); } + const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) override { return p->ints(); } +- const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) override { return p->floats(); } +- ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_ints(); } +- ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_floats(); } +- void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { p->add_ints(value); }; +- void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float value) override { p->add_floats(value); }; +- void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { p->add_strings(value); }; +- + int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) override { return p->i(); } + float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) override { return p->f(); } +- const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) override { return p->t(); } + void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_s(value); } +- void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) override { return p->set_f(value); } + void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { return p->set_i(value); } + const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) override { return p->s(); } + void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); } +@@ -487,7 +453,6 @@ struct ProviderHostImpl : ProviderHost { + ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_value_info(); } + ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_initializer(); } + ONNX_NAMESPACE::NodeProto* 
GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) override { return p->add_node(); } +- std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_name(); } + ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); } + + void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; } +@@ -505,7 +470,6 @@ struct ProviderHostImpl : ProviderHost { + ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_graph(); } + + void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); } +- ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); }; + + // NodeProto (wrapped) + std::unique_ptr NodeProto__construct() override { return std::make_unique(); } +@@ -520,34 +484,19 @@ struct ProviderHostImpl : ProviderHost { + void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) override { delete p; } + void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { *p = v; } + bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_name(); } +- void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) override { p->set_name(name); } +- const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) override { return p->name(); } + int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims_size(); } + const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims(); } +- void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) override { p->add_dims(value); } + bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_location(); } + int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_location(); } + bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_raw_data(); } + const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); } +- std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); } +- + int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); } +- void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) override { p->set_data_type(type); } + + bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); } + void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); } +- ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_external_data(); }; +- void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_float_data(); } +- void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int32_data(); } +- void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_string_data(); } +- void 
TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int64_data(); } +- void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_double_data(); } +- void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_uint64_data(); } + + // TensorProtos (wrapped) + ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); } +- int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); } +- ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) override { return p->at(index); }; + + // TensorShapeProto_Dimension (wrapped) + int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->value_case(); } +@@ -557,8 +506,6 @@ struct ProviderHostImpl : ProviderHost { + void TensorShapeProto_Dimension__set_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, int64_t value) override { return p->set_dim_value(value); } + bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_value(); } + bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_param(); } +- const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const override { return p->denotation(); } +- void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) override { return p->set_denotation(value); } + + // TensorShapeProto_Dimensions (wrapped) + std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) override { +@@ -587,90 +534,6 @@ struct ProviderHostImpl : ProviderHost { + + const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } + +- static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { +- auto* shape = ctx.getAttribute("shape"); +- auto* data_type = ctx.getAttribute("data_type"); +- int32_t elemType = 0; +- if (data_type->s() == "float32") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; +- } else if (data_type->s() == "int8") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_INT8; +- } else if (data_type->s() == "uint8") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; +- } else if (data_type->s() == "int32") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; +- } else if (data_type->s() == "int64") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; +- } else if (data_type->s() == "int1") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; +- } else if (data_type->s() == "bfloat16") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; +- } else if (data_type->s() == "float16") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +- } else if (data_type->s() == "uint16") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; +- } else if (data_type->s() == "int16") { +- elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; +- } else { +- return; +- } +- ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); +- if (shape != nullptr) { +- for (auto i = 0; i < shape->ints_size(); ++i) { +- ONNX_NAMESPACE::getOutputShape(ctx, 0, 
ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); +- } +- } else { +- // set scalar type. +- ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); +- } +- } +- +- static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { +- ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); +- ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); +- } +- +- static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { +- auto num_inputs = ctx.getNumInputs(); +- +- // Run inferencing on the subgraph +- auto* graphInferencer = ctx.getGraphAttributeInferencer("body"); +- +- std::vector input_data; +- std::vector subgraph_input_types; +- for (size_t i = 0; i < num_inputs; ++i) { +- input_data.push_back(ctx.getInputData(i)); +- subgraph_input_types.push_back(ctx.getInputType(i)); +- } +- +- auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data); +- for (size_t i = 0, end = output_types.size(); i < end; ++i) { +- *ctx.getOutputType(i) = *output_types[i]; +- } +- } +- void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override { +- auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); +- const auto& domain_to_version_map = domain_instance.Map(); +- if (domain_to_version_map.find(domain) == domain_to_version_map.end()) { +- domain_instance.AddDomainToVersion(domain, 1, 1000); +- } +- auto schema = CreateSchema(domain, {op}); +- switch (type) { +- case 1: +- schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); +- break; +- case 2: +- schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); +- break; +- case 3: +- schema.TypeAndShapeInferenceFunction(xir_shape_infer); +- break; +- default: +- break; +- } +- ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION); +- } +- + // ConfigOptions (wrapped) + std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override { + return p->GetConfigEntry(config_key); +@@ -902,9 +765,6 @@ struct ProviderHostImpl : ProviderHost { + void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) override { p->ToProto(proto, update_subgraphs); } + + const NodeAttributes& Node__GetAttributes(const Node* p) noexcept override { return p->GetAttributes(); } +- void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) override { +- p->AddAttribute(attr_name, value); +- } + size_t Node__GetInputEdgesCount(const Node* p) noexcept override { return p->GetInputEdgesCount(); } + size_t Node__GetOutputEdgesCount(const Node* p) noexcept override { return p->GetOutputEdgesCount(); } + +@@ -913,19 +773,13 @@ struct ProviderHostImpl : ProviderHost { + + std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputNodesBegin()); } + std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputNodesEnd()); } +- std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept override { +- return std::make_unique(p->InputEdgesBegin()); +- } +- std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept override { +- return std::make_unique(p->InputEdgesEnd()); +- } ++ + std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesBegin()); } + std::unique_ptr Node__OutputEdgesEnd(const 
Node* p) noexcept override { return std::make_unique(p->OutputEdgesEnd()); } + + void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) override { p->ForEachDef(func, std::move(include_missing_optional_defs)); } + const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) noexcept override { return p->GetAttributeNameToMutableSubgraphMap(); } + std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const override { return p->GetAttributeNameToSubgraphMap(); } +- int Node__NodeType(const Node* p) const noexcept override { return int(p->NodeType()); } + + // NodeArg (wrapped) + const std::string& NodeArg__Name(const NodeArg* p) noexcept override { return p->Name(); } +@@ -934,7 +788,6 @@ struct ProviderHostImpl : ProviderHost { + const NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept override { return p->ToProto(); } + bool NodeArg__Exists(const NodeArg* p) const noexcept override { return p->Exists(); } + const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept override { return p->TypeAsProto(); } +- Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) override { return p->OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); }; + + // NodeAttributes (wrapped) + std::unique_ptr NodeAttributes__construct() override { return std::make_unique(); } +@@ -957,20 +810,12 @@ struct ProviderHostImpl : ProviderHost { + } + void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) override { return p->insert(v.begin(), v.end()); } + void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->emplace(k, v); } +- void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->insert_or_assign(k, v); } + void NodeAttributes__reserve(NodeAttributes* p, size_t size) override { p->reserve(size); } + + // Model (wrapped) +- std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, +- const logging::Logger& logger) override { +- return std::make_unique(model_proto, model_path, nullptr, logger); +- } + void Model__operator_delete(Model* p) override { delete p; } + Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); } + std::unique_ptr Model__ToProto(Model* p) override { return std::make_unique(p->ToProto()); } +- std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) override { return std::make_unique(p->ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold)); }; +- const ModelMetaData& Model__MetaData(const Model* p) const noexcept override { return p->MetaData(); }; +- Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) override { return Model::Load(file_path, model_proto); } + + // Graph (wrapped) + std::unique_ptr Graph__CreateGraphViewer(const Graph* p) override { return std::make_unique(*p); } +@@ -990,12 +835,6 @@ struct ProviderHostImpl : ProviderHost { + void Graph__SetOutputs(Graph* p, gsl::span outputs) override { p->SetOutputs(outputs); } + + const std::vector& Graph__GetInputs(const Graph* p) noexcept override { 
return p->GetInputs(); } +- std::vector Graph__Nodes(const Graph* p) override { +- auto& node_refererence = p->Nodes(); +- std::vector nodes(p->NumberOfNodes(), nullptr); +- std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); +- return nodes; +- } + bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) override { return p->GetInitializedTensor(tensor_name, value); } + + const Node* Graph__ParentNode(const Graph* p) const override { return p->ParentNode(); } +@@ -1005,40 +844,6 @@ struct ProviderHostImpl : ProviderHost { + const Path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } + const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); } + bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); } +- const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } +- const Model& Graph__GetModel(const Graph* p) override { return p->GetModel(); } +- void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, +- const std::function& enter, +- const std::function& leave, +- const std::function& comp, +- const std::function& stop) const override { +- p->ReverseDFSFrom(from, enter, leave, comp, stop); +- } +- Graph& Graph__SetGraphResolveNeeded(Graph* p) override { return p->SetGraphResolveNeeded(); } +- void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) override { p->RemoveInitializedTensor(tensor_name); } +- +- std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const override { +- return p->GetConsumerNodes(node_arg_name); +- } +- void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, +- int dst_arg_index) override { +- p->AddEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); +- } +- void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, +- int dst_arg_index) override { +- p->RemoveEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); +- } +- void Graph__RemoveNode(Graph* p, NodeIndex index) override { p->RemoveNode(index); } +- Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) override { +- return p->FuseSubGraph(sub_graph, fused_node_name); +- } +- void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) override { +- p->UpdateProducerNode(node_arg_name, node_index); +- } +- const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const override { +- return p->GetConstantInitializer(name, check_outer_scope); +- } +- const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) override { return p->GetAllInitializedTensors(); } + int Graph__MaxNodeIndex(const Graph* p) const noexcept override { return p->MaxNodeIndex(); } + Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept override { return p->GetNode(node_index); } + const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); } +@@ -1083,14 +888,11 @@ struct ProviderHostImpl : ProviderHost { + void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& 
graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { + GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); + } +- const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } + + // Path (wrapped) + PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } + const std::vector& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); } + bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); } +- std::unique_ptr Path__construct() override { return std::make_unique(); } +- void Path__operator_delete(ONNX_NAMESPACE::Path* p) override { delete p; }; + + // OpKernel (direct) + const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } +@@ -1281,11 +1083,6 @@ struct ProviderHostImpl : ProviderHost { + void TensorSeq__Add(TensorSeq* p, Tensor&& tensor) override { p->Add(std::move(tensor)); } + void TensorSeq__Reserve(TensorSeq* p, size_t capacity) override { p->Reserve(capacity); } + +- // ModelMetadefIdGenerator(wrapped) +- std::unique_ptr ModelMetadefIdGenerator__construct() override { return std::make_unique(); } +- void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) override { delete p; } +- int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) override { return p->GenerateId(graph_viewer, model_hash); } +- + #if defined(ENABLE_TRAINING) && defined(ORT_USE_NCCL) + training::DistributedRunContext& GetDistributedRunContextInstance() override { return training::DistributedRunContext::GetInstance(); } + #endif +@@ -1481,7 +1278,6 @@ static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_provi + #endif + ); + static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION); +-static ProviderLibrary s_library_vitisai(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_vitisai") LIBRARY_EXTENSION); + static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION); + static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION + #ifndef _WIN32 +@@ -1510,7 +1306,6 @@ static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_p + + void UnloadSharedProviders() { + s_library_dnnl.Unload(); +- s_library_vitisai.Unload(); + s_library_openvino.Unload(); + s_library_tensorrt.Unload(); + s_library_cuda.Unload(); +@@ -1727,10 +1522,6 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(co + return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options); + } + +-std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { +- return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options); +-} +- + ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() { + return reinterpret_cast(s_library_openvino.Get().GetInfo()); + } +@@ -2556,7 +2347,6 @@ ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProvider + options->has_user_compute_stream = 0; + options->user_compute_stream = nullptr; + options->default_memory_arena_cfg = nullptr; +- options->enable_hip_graph = false; + options->tunable_op_enable = 0; + options->tunable_op_tuning_enable = 0; + 
options->tunable_op_max_tuning_duration_ms = 0; +@@ -2623,34 +2413,3 @@ ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProvid + ORT_UNUSED_PARAMETER(ptr); + #endif + } +- +-ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, +- _In_reads_(num_keys) const char* const* provider_options_keys, +- _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { +- API_IMPL_BEGIN +- onnxruntime::ProviderOptions provider_options; +- for (size_t i = 0; i != num_keys; ++i) { +- if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || +- provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Provider options key/value cannot be empty"); +- } +- +- // arbitrary length to validate the key/value. adjust if/when needed. +- // TODO: are any other input validation checks required here (and in the other functions that process +- // provider options)? +- if (strlen(provider_options_keys[i]) > 1024 || strlen(provider_options_values[i]) > 1024) { +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, +- "Maximum string length for a provider options key/value is 1024."); +- } +- +- provider_options[provider_options_keys[i]] = provider_options_values[i]; +- } +- auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); +- if (!factory) { +- return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); +- } +- +- options->provider_factories.push_back(factory); +- return nullptr; +- API_IMPL_END +-} +diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc +index ade1d96d6..86b3d01c6 100644 +--- a/onnxruntime/core/session/provider_registration.cc ++++ b/onnxruntime/core/session/provider_registration.cc +@@ -145,7 +145,13 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, + if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { + provider_options["preferred_layout"] = preferred_layout; + } +- options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); ++ options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); ++#else ++ status = create_not_supported_status(); ++#endif ++ } else if (strcmp(provider_name, "VitisAI") == 0) { ++#if defined(USE_VITISAI) ++ options->provider_factories.push_back(VitisAIProviderFactoryCreator::Create(provider_options)); + #else + status = create_not_supported_status(); + #endif +@@ -493,14 +499,4 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, + ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { + ORT_UNUSED_PARAMETER(ptr); + } +- +-ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, +- _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, +- _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { +- ORT_UNUSED_PARAMETER(options); +- ORT_UNUSED_PARAMETER(provider_options_keys); +- ORT_UNUSED_PARAMETER(provider_options_values); +- ORT_UNUSED_PARAMETER(num_keys); +- return CreateNotEnabledStatus("VitisAI"); +-} + #endif +diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc +index a5a165e15..48f58add8 
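For illustration only (not part of the patch): after this change the VitisAI EP is appended through the generic, name-based SessionOptionsAppendExecutionProvider path rather than a dedicated C-API entry point. A hedged Python sketch of selecting the EP by name is below; the model path and the "config_file" option key/value are placeholder assumptions, since the exact options depend on the installed VitisAI package.

import onnxruntime as ort

so = ort.SessionOptions()
sess = ort.InferenceSession(
    "model.onnx",                                             # placeholder model path
    sess_options=so,
    providers=["VitisAIExecutionProvider"],                   # name-based EP selection
    provider_options=[{"config_file": "vaip_config.json"}],   # assumed option key/value
)
print(sess.get_providers())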
100644 +--- a/onnxruntime/core/util/thread_utils.cc ++++ b/onnxruntime/core/util/thread_utils.cc +@@ -7,7 +7,6 @@ + + #ifdef _WIN32 + #include +-#include + #endif + #include + #include "core/session/ort_apis.h" +@@ -99,16 +98,7 @@ CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { + } + options.thread_pool_size = static_cast(default_affinities.size()); + if (options.auto_set_affinity) { +-#ifdef _WIN32 +- // Only set thread affinity on Server with auto affinity. +- // On client best to let OS scheduler handle. +- // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage +- if (IsWindowsServer()) { +- to.affinities = std::move(default_affinities); +- } +-#else + to.affinities = std::move(default_affinities); +-#endif + } + } + if (options.thread_pool_size <= 1) { +diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +index 0bbcee12e..f470e9f6b 100644 +--- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc ++++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +@@ -659,12 +659,7 @@ static bool CheckIfInputIsSequenceType(const std::string& name_input, + if (!temp) { + throw std::runtime_error("Corresponding type_proto is null"); + } else { +- if (temp->has_optional_type()) { +- const ::onnx::TypeProto_Optional& optional_type_proto = temp->optional_type(); +- type_proto = optional_type_proto.elem_type(); +- } else { +- type_proto = *temp; +- } ++ type_proto = *temp; + } + + return type_proto.has_sequence_type(); +diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc +index 9c36eb635..8e13982ca 100644 +--- a/onnxruntime/python/onnxruntime_pybind_state.cc ++++ b/onnxruntime/python/onnxruntime_pybind_state.cc +@@ -982,7 +982,7 @@ std::unique_ptr CreateExecutionProviderInstance( + return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider(); + #endif + } else if (type == kVitisAIExecutionProvider) { +-#ifdef USE_VITISAI ++#if USE_VITISAI + const auto it = provider_options_map.find(type); + if (it == provider_options_map.end()) { + LOGS_DEFAULT(FATAL) << "cannot find provider options for VitisAIExecutionProvider"; +diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +index 802d924c2..6e1e43184 100644 +--- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py ++++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +@@ -44,7 +44,6 @@ total_seqlens = [128, 512] + num_heads = [8, 12] + head_sizes = [64] + biaseds = [False, True] +-causals = [False] + mask_dims = [0, 2, 3, 4] + + +@@ -82,57 +81,8 @@ def maybe_pack_q_k_v_bnsh_for_device_on_host(q, k, v, dtype, qkv_format): + raise NotImplementedError + + +-def _make_causal_mask( +- seqence_length, +- total_sequence_length, +- dtype: np.dtype, +-): +- """ +- Make causal mask used for Attention with attribute unidirectional == 1. +- The mask is a upper triangular matrix with shape [sequence_length, total_sequence_length]. +- Putting a 1 indicates that the token at this position should be masked. +- For Example: +- sequence_length = 5, total_sequence_length = 5, +- mask: [[0. 1. 1. 1. 1.] +- [0. 0. 1. 1. 1.] +- [0. 0. 0. 1. 1.] +- [0. 0. 0. 0. 1.] +- [0. 0. 0. 0. 0.]] +- seqence_length = 5, total_seqence_length = 3, +- mask: [[1. 1. 1.] +- [1. 1. 1.] +- [0. 1. 1.] +- [0. 0. 1.] +- [0. 
0. 0.]] +- seqence_length = 5, total_seqence_length = 7, +- mask: [[0. 0. 0. 1. 1. 1. 1.] +- [0. 0. 0. 0. 1. 1. 1.] +- [0. 0. 0. 0. 0. 1. 1.] +- [0. 0. 0. 0. 0. 0. 1.] +- [0. 0. 0. 0. 0. 0. 0.]] +- """ +- mask = np.full((seqence_length, seqence_length), 1) +- mask_cond = np.arange(mask.shape[-1]) +- mask = np.where(mask_cond < (mask_cond + 1).reshape(mask.shape[-1], 1), 0, mask) +- +- mask = mask.astype(dtype) +- +- if total_sequence_length - seqence_length > 0: +- mask = np.concatenate( +- [np.zeros((seqence_length, total_sequence_length - seqence_length), dtype=dtype), mask], axis=-1 +- ) +- +- if total_sequence_length - seqence_length < 0: +- mask = mask[:, -total_sequence_length:] +- +- correct_mask = np.full((seqence_length, total_sequence_length), 1) +- for i in range(seqence_length): +- correct_mask[i][:] = sum(mask[i]) != total_sequence_length +- return mask, correct_mask +- +- + def _test_gemm_softmax_gemm_permute( +- f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ++ f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format + ): + v_head_size = head_size + q_shape = [batch, num_heads, seqlen, head_size] +@@ -173,8 +123,6 @@ def _test_gemm_softmax_gemm_permute( + pre_softmax_attn_scores = pre_softmax_attn_scores * scale + if attn_bias is not None: + pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias +- +- correct_causal_mask = np.full((seqlen, total_seqlen), 1) + if attn_mask is not None: + filter_value = -10000.0 + if mask_dim == 4: +@@ -183,18 +131,7 @@ def _test_gemm_softmax_gemm_permute( + else: + converted_mask = (1 - attn_mask.reshape(mask_shape_broadcasted)) * filter_value + pre_softmax_attn_scores = pre_softmax_attn_scores + converted_mask +- if causal: +- filter_value = np.finfo(dtype).min +- causal_mask, correct_causal_mask = _make_causal_mask(seqlen, total_seqlen, pre_softmax_attn_scores.dtype) +- causal_mask = np.broadcast_to(causal_mask, pre_softmax_attn_scores.shape) * filter_value +- pre_softmax_attn_scores = pre_softmax_attn_scores + causal_mask + attn_scores = softmax(pre_softmax_attn_scores, axis=-1) +- +- # apply mask to attn_scores to correct softmax result, in c++ implementation, if all values in a row are masked, +- # the softmax result in this row will be filled with 0. 
+- correct_causal_mask = np.broadcast_to(correct_causal_mask, pre_softmax_attn_scores.shape) +- attn_scores = attn_scores * correct_causal_mask +- + attn = matmul(attn_scores, v) + ref = np.swapaxes(attn, 2, 1) # permute 0213 + +@@ -217,7 +154,6 @@ def _test_gemm_softmax_gemm_permute( + head_size, + mask_dim, + scale, +- causal, + qkv_format, + dev_q, + dev_k, +@@ -266,26 +202,12 @@ def _test_gemm_softmax_gemm_permute( + @pytest.mark.parametrize("total_seqlen", total_seqlens) + @pytest.mark.parametrize("seqlen", seqlens) + @pytest.mark.parametrize("batch", [16]) +-@pytest.mark.parametrize("causal", [False, True]) + @pytest.mark.parametrize("dtype", ["float16", "float32"]) +-def test_gemm_softmax_gemm_permute_generic( +- dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim +-): ++def test_gemm_softmax_gemm_permute_generic(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): + f = getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)) + scale = 1.0 / np.sqrt(head_size) + _test_gemm_softmax_gemm_permute( +- f, +- dtype, +- batch, +- seqlen, +- total_seqlen, +- nhead, +- head_size, +- biased, +- mask_dim, +- scale, +- causal, +- ke.qkv_format.Q_K_V_BNSH, ++ f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + ) + + +@@ -296,26 +218,14 @@ def test_gemm_softmax_gemm_permute_generic( + @pytest.mark.parametrize("total_seqlen", [128]) + @pytest.mark.parametrize("seqlen", [64]) + @pytest.mark.parametrize("batch", [16]) +-@pytest.mark.parametrize("causal", [True, False]) + @pytest.mark.parametrize("dtype", ["float16", "float32"]) + def test_gemm_softmax_gemm_permute_generic_nested_tunable( +- dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim ++ dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim + ): + f = getattr(ke, "GemmSoftmaxGemmPermuteGenericNestedTunable_" + dtype_to_suffix(dtype)) + scale = 1.0 / np.sqrt(head_size) + _test_gemm_softmax_gemm_permute( +- f, +- dtype, +- batch, +- seqlen, +- total_seqlen, +- nhead, +- head_size, +- biased, +- mask_dim, +- scale, +- causal, +- ke.qkv_format.Q_K_V_BNSH, ++ f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + ) + + +@@ -327,24 +237,12 @@ def test_gemm_softmax_gemm_permute_generic_nested_tunable( + @pytest.mark.parametrize("total_seqlen", total_seqlens) + @pytest.mark.parametrize("seqlen", seqlens) + @pytest.mark.parametrize("batch", batches) +-@pytest.mark.parametrize("causal", [False, True]) + @pytest.mark.parametrize("dtype", dtypes) +-def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim): ++def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): + f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) + scale = 1.0 / np.sqrt(head_size) + _test_gemm_softmax_gemm_permute( +- f, +- dtype, +- batch, +- seqlen, +- total_seqlen, +- nhead, +- head_size, +- biased, +- mask_dim, +- scale, +- causal, +- ke.qkv_format.Q_K_V_BNSH, ++ f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + ) + + +@@ -355,26 +253,12 @@ def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, + @pytest.mark.parametrize("total_seqlen", [128]) + @pytest.mark.parametrize("seqlen", [64]) + @pytest.mark.parametrize("batch", [16]) 
+-@pytest.mark.parametrize("causal", [False, True]) + @pytest.mark.parametrize("dtype", ["float16"]) +-def test_gemm_softmax_gemm_permute_tunable( +- dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim +-): ++def test_gemm_softmax_gemm_permute_tunable(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): + f = getattr(ke, "GemmSoftmaxGemmPermuteTunable_" + dtype_to_suffix(dtype)) + scale = 1.0 / np.sqrt(head_size) + _test_gemm_softmax_gemm_permute( +- f, +- dtype, +- batch, +- seqlen, +- total_seqlen, +- nhead, +- head_size, +- biased, +- mask_dim, +- scale, +- causal, +- ke.qkv_format.Q_K_V_BNSH, ++ f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + ) + + +@@ -394,17 +278,16 @@ stabel_diffusion_configs = [ + @pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") + @pytest.mark.parametrize("mask_dim", [0], ids=get_mask_dim_id) + @pytest.mark.parametrize("biased", [False], ids=get_biased_id) +-@pytest.mark.parametrize("causal", [False, True]) + @pytest.mark.parametrize("batch, seqlen, total_seqlen, nhead, head_size, qkv_format_name", stabel_diffusion_configs) + @pytest.mark.parametrize("dtype", dtypes) + def test_gemm_softmax_gemm_permute_ck_sd( +- dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim, qkv_format_name ++ dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, qkv_format_name + ): + qkv_format = getattr(ke.qkv_format, qkv_format_name) + f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) + scale = 1.0 / np.sqrt(head_size) + _test_gemm_softmax_gemm_permute( +- f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, causal, qkv_format ++ f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, qkv_format + ) + + +@@ -433,7 +316,7 @@ class GemmSoftmaxGemmPermuteMetric(ke.ComputeMetric): + + + def profile_gemm_softmax_gemm_permute_func( +- f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ++ f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format + ): + v_head_size = head_size + q_shape = [batch, num_heads, seqlen, head_size] +@@ -486,7 +369,6 @@ def profile_gemm_softmax_gemm_permute_func( + head_size, + mask_dim, + scale, +- causal, + qkv_format, + dev_q, + dev_k, +@@ -520,10 +402,10 @@ def profile_gemm_softmax_gemm_permute_func( + + + def profile_with_args( +- dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, causal, mask_dim, scale, qkv_format, *, sort=False ++ dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format, *, sort=False + ): + with ke.benchmark(sort): +- args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format) ++ args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format) + if qkv_format == ke.qkv_format.Q_K_V_BNSH: + profile_gemm_softmax_gemm_permute_func( + getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)), *args +@@ -547,7 +429,6 @@ def profile(): + nhead, + head_size, + biased=False, +- causal=False, + mask_dim=0, + qkv_format=getattr(ke.qkv_format, qkv_format_name), + scale=0.125, +@@ -555,7 +436,7 @@ def profile(): + ) + print() + +- for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, causals, mask_dims): ++ for 
args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, mask_dims): + profile_with_args(*args, qkv_format=ke.qkv_format.Q_K_V_BNSH, scale=0.125, sort=True) + print() + +@@ -574,7 +455,6 @@ if __name__ == "__main__": + group.add_argument("head_size", type=int) + group.add_argument("biased", type=int, choices=[0, 1], default=0) + group.add_argument("mask_dim", type=int, choices=[0, 2, 3, 4], default=2, help="0 for mask disabled") +- group.add_argument("causal", type=int, choices=[0, 1], default=0) + group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)") + group.add_argument( + "--qkv_format", +@@ -591,7 +471,6 @@ if __name__ == "__main__": + profile() + else: + args = parser.parse_args() +- print(args) + profile_with_args( + args.dtype, + args.batch, +@@ -600,7 +479,6 @@ if __name__ == "__main__": + args.num_heads, + args.head_size, + args.biased, +- args.causal, + args.mask_dim, + args.scale, + getattr(ke.qkv_format, args.qkv_format), +diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +index 7068fc8fd..5e60bad77 100644 +--- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu ++++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +@@ -28,7 +28,6 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { + int64_t head_size, + int64_t mask_dim, + double scale, +- bool causal, + contrib::AttentionQkvFormat qkv_format, + DeviceArray& Q, + std::optional& K, +@@ -52,7 +51,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { + attn_.v_hidden_size = attn_.hidden_size; // Q,K,V hidden size must agree now + attn_.v_head_size = attn_.head_size; // Q,K,V hidden size must agree now + attn_.num_heads = num_heads; +- attn_.is_unidirectional = causal; ++ attn_.is_unidirectional = false; + attn_.past_present_share_buffer = false; + attn_.do_rotary = false; + attn_.mask_filter_value = -10000.0f; +@@ -149,7 +148,6 @@ class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplor + int64_t head_size, + int64_t mask_dim, + double scale, +- bool causal, + contrib::AttentionQkvFormat qkv_format, + DeviceArray& Q, + std::optional& K, +@@ -158,7 +156,7 @@ class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplor + std::optional& attn_mask, + DeviceArray& out) + : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, +- num_heads, head_size, mask_dim, scale, causal, qkv_format, ++ num_heads, head_size, mask_dim, scale, qkv_format, + Q, K, V, attn_bias, attn_mask, out) { + this->SetWorkspace(GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_)); + } +@@ -189,7 +187,6 @@ class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermute + int64_t head_size, + int64_t mask_dim, + double scale, +- bool causal, + contrib::AttentionQkvFormat qkv_format, + DeviceArray& Q, + std::optional& K, +@@ -198,7 +195,7 @@ class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermute + std::optional& attn_mask, + DeviceArray& out) + : GemmSoftmaxGemmPermuteGeneric(batch, seqlen, total_seqlen, max_seqlen, +- num_heads, head_size, mask_dim, scale, causal, qkv_format, ++ num_heads, head_size, mask_dim, scale, qkv_format, + Q, K, V, attn_bias, attn_mask, out) { + 
this->params_.TuningContext()->EnableTunableOpAndTuning(); + } +@@ -217,7 +214,6 @@ class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer + int64_t head_size, + int64_t mask_dim, + double scale, +- bool causal, + contrib::AttentionQkvFormat qkv_format, + DeviceArray& Q, + std::optional& K, +@@ -226,7 +222,7 @@ class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer + std::optional& attn_mask, + DeviceArray& out) + : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, +- num_heads, head_size, mask_dim, scale, causal, qkv_format, ++ num_heads, head_size, mask_dim, scale, qkv_format, + Q, K, V, attn_bias, attn_mask, out) { + this->SetWorkspace(GemmSoftmaxGemmPermuteTunableOp::GetWorkspaceNumBytes(&this->attn_)); + +@@ -279,7 +275,6 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor + int64_t head_size, + int64_t mask_dim, + double scale, +- bool causal, + contrib::AttentionQkvFormat qkv_format, + DeviceArray& Q, + std::optional& K, +@@ -288,7 +283,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor + std::optional& attn_mask, + DeviceArray& out) + : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, +- num_heads, head_size, mask_dim, scale, causal, qkv_format, ++ num_heads, head_size, mask_dim, scale, qkv_format, + Q, K, V, attn_bias, attn_mask, out) { + this->SetWorkspace(std::max( + GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_), +@@ -316,7 +311,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor + #define REGISTER_COMMON(name, type, ...) \ + py::class_>(m, name) \ + .def(py::init, int64_t, int64_t, int64_t, \ +- float, bool, contrib::AttentionQkvFormat, \ ++ float, contrib::AttentionQkvFormat, \ + DeviceArray&, \ + std::optional&, \ + std::optional&, \ +diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py +index 123cfe913..b0153aed7 100644 +--- a/onnxruntime/python/tools/quantization/qdq_quantizer.py ++++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py +@@ -270,8 +270,6 @@ class QDQQuantizer(ONNXQuantizer): + + self.model.model.producer_name = __producer__ + self.model.model.producer_version = __version__ +- if self.qdq_op_domain == ms_domain: +- self.model.set_opset_import(ms_domain, 1) + + return self.model.model + +diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py +index 9823e8264..ef4c4ae90 100755 +--- a/onnxruntime/python/tools/symbolic_shape_infer.py ++++ b/onnxruntime/python/tools/symbolic_shape_infer.py +@@ -197,7 +197,6 @@ class SymbolicShapeInference: + "BiasGelu": self._infer_BiasGelu, + "BiasSplitGelu": self._infer_BiasSplitGelu, + "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, +- "DequantizeLinear": self._infer_DequantizeLinear, + "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, + "FastGelu": self._infer_FastGelu, + "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, +@@ -213,7 +212,6 @@ class SymbolicShapeInference: + "PackedAttention": self._infer_PackedAttention, + "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, + "PythonOp": self._infer_PythonOp, +- "QuantizeLinear": self._infer_QuantizeLinear, + "QuickGelu": self._infer_FastGelu, + "RelativePositionBias": self._infer_RelativePositionBias, + "RemovePadding": 
self._infer_RemovePadding, +@@ -459,8 +457,6 @@ class SymbolicShapeInference: + "GemmFastGelu", + "LayerNormalization", + "LongformerAttention", +- "DequantizeLinear", +- "QuantizeLinear", + "RelativePositionBias", + "RemovePadding", + "RestorePadding", +@@ -983,29 +979,6 @@ class SymbolicShapeInference: + ) + ) + +- def _infer_DequantizeLinear(self, node): # noqa: N802 +- # Get the output data type from the scale input (index 1, required). +- output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type +- +- # Get the output shape from the first input. +- output_shape = self._get_shape(node, 0) +- +- vi = self.known_vi_[node.output[0]] +- vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) +- +- def _infer_QuantizeLinear(self, node): # noqa: N802 +- # Get the output data type from the zero-point input (index 2, optional). +- # Otherwise, default to uint8 +- output_dtype = onnx.TensorProto.UINT8 +- if len(node.input) > 2 and node.input[2]: +- output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type +- +- # Get the output shape from the first input. +- output_shape = self._get_shape(node, 0) +- +- vi = self.known_vi_[node.output[0]] +- vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) +- + def _infer_Einsum(self, node): # noqa: N802 + # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 + equation = get_attribute(node, "equation") +diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py +index a2cdd17e1..17f0dd0bc 100644 +--- a/onnxruntime/python/tools/transformers/convert_generation.py ++++ b/onnxruntime/python/tools/transformers/convert_generation.py +@@ -55,6 +55,10 @@ import onnx + import torch + from benchmark_helper import Precision, setup_logger + from fusion_utils import NumpyHelper ++from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx ++from models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS ++from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models ++from models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS + from onnx import GraphProto, ModelProto, TensorProto + from onnx_model import OnnxModel + from transformers import ( +@@ -69,10 +73,6 @@ from transformers import ( + ) + + from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers +-from onnxruntime.transformers.models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx +-from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS +-from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models +-from onnxruntime.transformers.models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS + + logger = logging.getLogger("") + +@@ -372,7 +372,7 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: + type=int, + required=False, + default=1, +- help="Minimum number of tokens we keep per batch example in the output.", ++ help="Minimumber of tokens we keep per batch example in the output.", + ) + + beam_parameters_group.add_argument( +@@ -466,7 +466,7 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: + "--save_test_data", + required=False, + action="store_true", +- help="save test data for onnxruntime_perf_test tool", ++ help="save test data for 
onnxruntimer_perf_test tool", + ) + test_group.set_defaults(save_test_data=False) + +@@ -1225,7 +1225,7 @@ def find_past_seq_len_usage(subg: GraphProto): + tensor_names_to_rename = set() + nodes_to_remove = [] + +- graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)} ++ graph_intput_names = {inp.name: index for index, inp in enumerate(subg.input)} + + input_name_to_nodes = {} + output_name_to_node = {} +@@ -1259,7 +1259,7 @@ def find_past_seq_len_usage(subg: GraphProto): + if ( + shape_node.op_type == "Shape" + and shape_node.input[0] +- and shape_node.input[0] in graph_input_names ++ and shape_node.input[0] in graph_intput_names + and ( + shape_node.input[0].startswith("past_key_self_") + or shape_node.input[0].startswith("past_value_self_") +@@ -1423,7 +1423,7 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelP + if node.op_type == "MultiHeadAttention": + old_nodes.extend([node]) + +- # If not all the MultiHeadAttention nodes are fused, this optimization is not applicable ++ # If not all the MultiheadAttention nodes are fused, this optimization is not applicable + if len(old_nodes) < num_layers: + return False + +diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +index c6d550d47..a8b84729b 100644 +--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py ++++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +@@ -253,7 +253,7 @@ def save_results(results, filename): + # Save results to csv with standard format + records = [] + for _, row in df.iterrows(): +- if row["Engine"] in ["optimum-ort", "onnxruntime"]: ++ if row["Engine"] == "optimum-ort": + record = BenchmarkRecord( + row["Model Name"], row["Precision"], "onnxruntime", row["Device"], ort_pkg_name, ort_pkg_version + ) +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +index 2cd64e878..40692701c 100644 +--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py ++++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +@@ -61,7 +61,6 @@ if __name__ == "__main__": + controlnet_scales=controlnet_scale, + show_latency=not warmup, + output_type="pil", +- deterministic=args.deterministic, + ) + + if not args.disable_cuda_graph: +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +index 369f31511..965a2598a 100644 +--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py ++++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +@@ -239,11 +239,8 @@ def parse_arguments(is_xl: bool, parser): + ) + parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.") + parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") +- parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.") + parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + +- parser.add_argument("--framework-model-dir", default=None, help="framework model directory") +- + group = parser.add_argument_group("Options for ORT_CUDA engine only") + 
group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") + +@@ -408,7 +405,6 @@ def initialize_pipeline( + lora_scale=1.0, + use_fp16_vae=True, + use_vae=True, +- framework_model_dir=None, + ): + pipeline_info = PipelineInfo( + version, +@@ -428,7 +424,7 @@ def initialize_pipeline( + input_engine_dir = engine_dir + + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( +- work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir ++ work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type + ) + + pipeline = StableDiffusionPipeline( +@@ -561,7 +557,6 @@ def load_pipelines(args, batch_size=None): + "lora_scale": args.lora_scale, + "use_fp16_vae": "xl" in args.version, + "use_vae": True, +- "framework_model_dir": args.framework_model_dir, + } + + if "xl" in args.version: +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +index c03c6f0b2..46a83f5dc 100644 +--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py ++++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +@@ -5,7 +5,6 @@ + import hashlib + import os + from enum import Enum +-from typing import Optional + + import torch + from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL +@@ -274,9 +273,7 @@ class EngineBuilder: + return self._vae_decode(latents) + + +-def get_engine_paths( +- work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: Optional[str] = None +-): ++def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): + root_dir = work_dir or "." + short_name = pipeline_info.short_name() + +@@ -290,7 +287,6 @@ def get_engine_paths( + + # Shared among ORT_CUDA, ORT_TRT and TRT engines, and need use load_model(..., always_download_fp16=True) + # So that the shared model is always fp16. +- if framework_model_dir is None: +- framework_model_dir = os.path.join(root_dir, "torch_model") ++ framework_model_dir = os.path.join(root_dir, "torch_model") + + return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +index 0ad8b13b6..104ce984b 100644 +--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py ++++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +@@ -754,7 +754,6 @@ class StableDiffusionPipeline: + controlnet_scales: Optional[torch.Tensor] = None, + show_latency: bool = False, + output_type: str = "pil", +- deterministic: bool = False, + ): + """ + Run the diffusion pipeline. +@@ -784,9 +783,6 @@ class StableDiffusionPipeline: + output_type (str): + It can be "latent", "pt" or "pil". 
+ """ +- if deterministic: +- torch.use_deterministic_algorithms(True) +- + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/astronaut_riding_txt2image-DDIM-50.png b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/astronaut_riding_txt2image-DDIM-50.png +deleted file mode 100644 +index 9d20ce550..000000000 +Binary files a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/astronaut_riding_txt2image-DDIM-50.png and /dev/null differ +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py +deleted file mode 100644 +index da7f47b14..000000000 +--- a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py ++++ /dev/null +@@ -1,73 +0,0 @@ +-import argparse +-import os +-from typing import Optional +- +-import cv2 +-import open_clip +-import torch +-from PIL import Image +-from sentence_transformers import util +- +- +-def arg_parser(): +- parser = argparse.ArgumentParser(description="Options for Compare 2 image") +- parser.add_argument("--image1", type=str, help="Path to image 1") +- parser.add_argument("--image2", type=str, help="Path to image 2") +- parser.add_argument("--cache_dir", type=str, help="Path to model cache directory") +- args = parser.parse_args() +- return args +- +- +-def image_encoder(img: Image.Image, cache_dir: Optional[str] = None): # -> torch.Tensor: +- device = "cuda" if torch.cuda.is_available() else "cpu" +- model, _, preprocess = open_clip.create_model_and_transforms( +- "ViT-B-16-plus-240", pretrained="laion400m_e32", cache_dir=cache_dir +- ) +- model.to(device) +- +- img1 = Image.fromarray(img).convert("RGB") +- img1 = preprocess(img1).unsqueeze(0).to(device) +- img1 = model.encode_image(img1) +- return img1 +- +- +-def load_image(image_path: str): # -> Image.Image: +- # cv2.imread() can silently fail when the path is too long +- # https://stackoverflow.com/questions/68716321/how-to-use-absolute-path-in-cv2-imread +- if os.path.isabs(image_path): +- directory = os.path.dirname(image_path) +- current_directory = os.getcwd() +- os.chdir(directory) +- img = cv2.imread(os.path.basename(image_path), cv2.IMREAD_UNCHANGED) +- os.chdir(current_directory) +- else: +- img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) +- return img +- +- +-def generate_score(image1: str, image2: str, cache_dir: Optional[str] = None): # -> float: +- test_img = load_image(image1) +- data_img = load_image(image2) +- img1 = image_encoder(test_img, cache_dir) +- img2 = image_encoder(data_img, cache_dir) +- cos_scores = util.pytorch_cos_sim(img1, img2) +- score = round(float(cos_scores[0][0]) * 100, 2) +- return score +- +- +-def main(): +- args = arg_parser() +- image1 = args.image1 +- image2 = args.image2 +- cache_dir = args.cache_dir +- score = round(generate_score(image1, image2, cache_dir), 2) +- print("similarity Score: ", {score}) +- if score < 97: +- print(f"{image1} and {image2} are different") +- raise SystemExit(1) +- else: +- print(f"{image1} and {image2} are same") +- +- +-if __name__ == "__main__": +- main() +diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt +deleted file mode 100644 +index e51ffb395..000000000 +--- 
a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/requirements.txt ++++ /dev/null +@@ -1,4 +0,0 @@ +-git+https://github.com/openai/CLIP.git +-open_clip_torch +-sentence_transformers +-pillow +diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +index a74666b7a..33958e55f 100644 +--- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py ++++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +@@ -53,9 +53,9 @@ def chain_model(args): + + beam_outputs = ["sequences"] + if args.output_sequence_scores: +- beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") ++ beam_outputs.append("sequence_scores") + if args.output_scores: +- beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") ++ beam_outputs.append("scores") + + if args.use_whisper_beamsearch: + assert len(beam_inputs) == 12 +@@ -75,7 +75,6 @@ def chain_model(args): + beam_outputs.extend(["no_speech_probs_beam"]) + + input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None +- output_scores_cast_node = output_sequence_scores_cast_node = None + if args.precision == Precision.FLOAT16: + input_features_cast_node = helper.make_node( + "Cast", +@@ -98,22 +97,6 @@ def chain_model(args): + name="CastRepetitionPenaltyToFp16", + to=TensorProto.FLOAT16, + ) +- if args.output_sequence_scores: +- output_sequence_scores_cast_node = helper.make_node( +- "Cast", +- inputs=["sequence_scores_fp16"], +- outputs=["sequence_scores"], +- name="CastOutputSequenceScoresToFp32", +- to=TensorProto.FLOAT, +- ) +- if args.output_scores: +- output_scores_cast_node = helper.make_node( +- "Cast", +- inputs=["scores_fp16"], +- outputs=["scores"], +- name="CastScoresToFp32", +- to=TensorProto.FLOAT, +- ) + + operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" + node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") +@@ -231,18 +214,10 @@ def chain_model(args): + opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] + + graph_nodes = ( +- [ +- input_features_cast_node, +- len_pen_cast_node, +- rep_pen_cast_node, +- node, +- output_sequence_scores_cast_node, +- output_scores_cast_node, +- ] ++ [input_features_cast_node, len_pen_cast_node, rep_pen_cast_node, node] + if args.precision == Precision.FLOAT16 + else [node] + ) +- graph_nodes = [node for node in graph_nodes if node is not None] + if args.output_no_speech_probs: + prob_cast_node = helper.make_node( + "Cast", +diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +index 2ad20eafc..e0ed32630 100644 +--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc ++++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +@@ -14,8 +14,6 @@ + #include "test/optimizer/graph_transform_test_builder.h" + #include "test/providers/provider_test_utils.h" + #include "test/util/include/default_providers.h" +-#include "core/session/onnxruntime_cxx_api.h" +-#include "core/session/ort_env.h" + #include "core/util/qmath.h" + + #include +@@ -23,13 +21,12 @@ + + #include "gtest/gtest.h" + #include "gmock/gmock.h" +-extern std::unique_ptr ort_env; + + namespace onnxruntime { +- + namespace test { + + static constexpr int QBits = 4; ++ + void QuantizeDequantize(std::vector& 
raw_vals, + std::vector& quant_vals, + std::vector& scales, +@@ -37,8 +34,9 @@ void QuantizeDequantize(std::vector& raw_vals, + int32_t N, + int32_t K, + int32_t block_size) { +- auto& ortenv = **ort_env.get(); +- onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); ++ OrtThreadPoolParams to; ++ auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, ++ concurrency::ThreadPoolType::INTRA_OP); + + MlasQuantizeBlockwise( + quant_vals.data(), +@@ -50,7 +48,7 @@ void QuantizeDequantize(std::vector& raw_vals, + K, + N, + N, +- tp); ++ tp.get()); + + // Note that input1_f_vals is NxK after dequant + MlasDequantizeBlockwise( +@@ -62,7 +60,7 @@ void QuantizeDequantize(std::vector& raw_vals, + true, // columnwise quantization + K, // number of rows + N, // number of columns +- tp); ++ tp.get()); + } + + void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, +diff --git a/onnxruntime/test/framework/execution_provider_test.cc b/onnxruntime/test/framework/execution_provider_test.cc +index 390fda7bf..5a7351a76 100644 +--- a/onnxruntime/test/framework/execution_provider_test.cc ++++ b/onnxruntime/test/framework/execution_provider_test.cc +@@ -6,7 +6,6 @@ + #include "test_utils.h" + #include "test/test_environment.h" + #include "test/util/include/asserts.h" +-#include "core/framework/model_metadef_id_generator.h" + + #include "gtest/gtest.h" + +@@ -19,14 +18,11 @@ class TestEP : public IExecutionProvider { + static constexpr const char* kEPType = "TestEP"; + + public: +- TestEP() : IExecutionProvider{kEPType} {} ++ TestEP() : IExecutionProvider{kEPType, true} {} + + int GetId(const GraphViewer& viewer, HashValue& model_hash) { +- return metadef_id_generator_.GenerateId(viewer, model_hash); ++ return GenerateMetaDefId(viewer, model_hash); + } +- +- private: +- ModelMetadefIdGenerator metadef_id_generator_; + }; + + TEST(ExecutionProviderTest, MetadefIdGeneratorUsingModelPath) { +diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc +index 6fe0754db..19253e1a5 100644 +--- a/onnxruntime/test/framework/tunable_op_test.cc ++++ b/onnxruntime/test/framework/tunable_op_test.cc +@@ -82,7 +82,7 @@ class TestEP : public IExecutionProvider { + TestTuningContext tuning_ctx_{this}; + + public: +- TestEP() : IExecutionProvider{kEPType} {} ++ TestEP() : IExecutionProvider{kEPType, true} {} + + ITuningContext* GetTuningContext() const override { + return const_cast(&tuning_ctx_); +diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc +index f55368297..4772e7de2 100644 +--- a/onnxruntime/test/global_thread_pools/test_inference.cc ++++ b/onnxruntime/test/global_thread_pools/test_inference.cc +@@ -55,15 +55,9 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, + // size_t total_len = type_info.GetElementCount(); + ASSERT_EQ(values_y.size(), static_cast(5)); + +-// test inference is using onnxruntime_shared_lib_test_LIBS, so HasCudaEnvironment(800) isn't available +-#ifdef USE_CUDA +- const float tolerance = 1e-5f; +-#else +- const float tolerance = 1e-6f; +-#endif + OutT* f = output_tensor->GetTensorMutableData(); + for (size_t i = 0; i != static_cast(5); ++i) { +- ASSERT_NEAR(values_y[i], f[i], tolerance); ++ ASSERT_NEAR(values_y[i], f[i], 1e-6f); + } + } + +diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +index 
b7b453415..668d7a061 100644 +--- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp ++++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +@@ -61,11 +61,10 @@ void SQNBITGEMM(benchmark::State& state) { + } + + std::unique_ptr PackedQuantBData; +- if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); ++ if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + PackedQuantBDataSize > 0) { + PackedQuantBData = std::make_unique(PackedQuantBDataSize); +- MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), +- tp.get()); ++ MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get()); + } + + MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; +@@ -88,9 +87,7 @@ void SQNBITGEMM(benchmark::State& state) { + } + } + +-static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) { +- constexpr size_t BlkBitWidth = 4; +- ++static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); + + ArgsProductWithFilter(b, +@@ -99,17 +96,19 @@ static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) { + {1, 1024, 2048}, // M + {4096, 11008}, // N + {4096, 11008}, // K +- {1, 8}, // Threads ++ {8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType + +- [&](const std::vector& args) { ++ [](const std::vector& args) { + return MlasIsSQNBitGemmAvailable( ++ // M, N, K ++ narrow(args[1]), narrow(args[2]), narrow(args[3]), + // BlkBitWidth, BlkLen +- BlkBitWidth, narrow(args[0]), ++ 4, narrow(args[0]), + // ComputeType + static_cast(args[6])); + }); + } + +-BENCHMARK(SQNBITGEMM<4>)->Apply(SQ4BitGemmArgs)->UseRealTime(); ++BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +index ed09d7ee9..4fb8ab417 100644 +--- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp ++++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +@@ -259,11 +259,10 @@ class MlasSQNBitGemmTest : public MlasTestBase { + } + + void* PackedQuantBData = nullptr; +- if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); ++ if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + PackedQuantBDataSize > 0) { + PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); +- MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, +- GetMlasThreadPool()); ++ MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool()); + } + + if (ComputeType == CompFp32) { +@@ -331,7 +330,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture shape; +- gsl::span expected; ++ std::string equation; ++ std::vector shape; ++ std::vector expected; ++ EinsumTestCase(const std::string& eq, const std::vector& sh, const std::vector& exp) : equation(eq), shape(sh), expected(exp) {} + }; +-static constexpr std::string_view equation0 = "abc,cd->abc"; +-static constexpr std::array shape0{2, 2, 2}; +-static constexpr std::array expected0{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}; +-static constexpr std::string_view equation1 = "abc,cd->abd"; +-static constexpr std::array shape1{2, 2, 2}; +-static 
constexpr std::array expected1{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}; +-static constexpr std::string_view equation2 = "abc,cd->acd"; +-static constexpr std::array shape2{2, 2, 2}; +-static constexpr std::array expected2{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}; +-static constexpr std::string_view equation3 = "abc,dc->abd"; +-static constexpr std::array shape3{2, 2, 2}; +-static constexpr std::array expected3{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}; +-static constexpr std::string_view equation4 = "abc,dc->abc"; +-static constexpr std::array shape4{2, 2, 2}; +-static constexpr std::array expected4{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}; +-static constexpr std::string_view equation5 = "abc,dc->acd"; +-static constexpr std::array shape5{2, 2, 2}; +-static constexpr std::array expected5{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}; +-static constexpr std::string_view equation6 = "acb,cd->acd"; +-static constexpr std::array shape6{2, 2, 2}; +-static constexpr std::array expected6{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}; +-static constexpr std::string_view equation7 = "acb,cd->abc"; +-static constexpr std::array shape7{2, 2, 2}; +-static constexpr std::array expected7{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}; +-static constexpr std::string_view equation8 = "acb,cd->abd"; +-static constexpr std::array shape8{2, 2, 2}; +-static constexpr std::array expected8{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}; +-static constexpr std::string_view equation9 = "acb,dc->acd"; +-static constexpr std::array shape9{2, 2, 2}; +-static constexpr std::array expected9{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}; +-static constexpr std::string_view equation10 = "acb,dc->abd"; +-static constexpr std::array shape10{2, 2, 2}; +-static constexpr std::array expected10{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}; +-static constexpr std::string_view equation11 = "acb,dc->abc"; +-static constexpr std::array shape11{2, 2, 2}; +-static constexpr std::array expected11{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}; +-static constexpr std::string_view equation12 = "bac,cd->bac"; +-static constexpr std::array shape12{2, 2, 2}; +-static constexpr std::array expected12{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}; +-static constexpr std::string_view equation13 = "bac,cd->bad"; +-static constexpr std::array shape13{2, 2, 2}; +-static constexpr std::array expected13{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}; +-static constexpr std::string_view equation14 = "bac,cd->bcd"; +-static constexpr std::array shape14{2, 2, 2}; +-static constexpr std::array expected14{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}; +-static constexpr std::string_view equation15 = "bac,dc->bad"; +-static constexpr std::array shape15{2, 2, 2}; +-static constexpr std::array expected15{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}; +-static constexpr std::string_view equation16 = "bac,dc->bac"; +-static constexpr std::array shape16{2, 2, 2}; +-static constexpr std::array expected16{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}; +-static constexpr std::string_view equation17 = "bac,dc->bcd"; +-static constexpr std::array shape17{2, 2, 2}; +-static constexpr std::array expected17{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}; +-static constexpr std::string_view equation18 = "bca,cd->bcd"; +-static constexpr std::array shape18{2, 2, 2}; +-static constexpr std::array expected18{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}; +-static constexpr std::string_view equation19 = "bca,cd->bac"; +-static constexpr std::array shape19{2, 2, 2}; +-static constexpr 
std::array expected19{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}; +-static constexpr std::string_view equation20 = "bca,cd->bad"; +-static constexpr std::array shape20{2, 2, 2}; +-static constexpr std::array expected20{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}; +-static constexpr std::string_view equation21 = "bca,dc->bcd"; +-static constexpr std::array shape21{2, 2, 2}; +-static constexpr std::array expected21{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}; +-static constexpr std::string_view equation22 = "bca,dc->bad"; +-static constexpr std::array shape22{2, 2, 2}; +-static constexpr std::array expected22{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}; +-static constexpr std::string_view equation23 = "bca,dc->bac"; +-static constexpr std::array shape23{2, 2, 2}; +-static constexpr std::array expected23{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}; +-static constexpr std::string_view equation24 = "cab,cd->cad"; +-static constexpr std::array shape24{2, 2, 2}; +-static constexpr std::array expected24{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}; +-static constexpr std::string_view equation25 = "cab,cd->cbd"; +-static constexpr std::array shape25{2, 2, 2}; +-static constexpr std::array expected25{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}; +-static constexpr std::string_view equation26 = "cab,dc->cad"; +-static constexpr std::array shape26{2, 2, 2}; +-static constexpr std::array expected26{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}; +-static constexpr std::string_view equation27 = "cab,dc->cbd"; +-static constexpr std::array shape27{2, 2, 2}; +-static constexpr std::array expected27{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}; +-static constexpr std::string_view equation28 = "cba,cd->cbd"; +-static constexpr std::array shape28{2, 2, 2}; +-static constexpr std::array expected28{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}; +-static constexpr std::string_view equation29 = "cba,cd->cad"; +-static constexpr std::array shape29{2, 2, 2}; +-static constexpr std::array expected29{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}; +-static constexpr std::string_view equation30 = "cba,dc->cbd"; +-static constexpr std::array shape30{2, 2, 2}; +-static constexpr std::array expected30{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}; +-static constexpr std::string_view equation31 = "cba,dc->cad"; +-static constexpr std::array shape31{2, 2, 2}; +-static constexpr std::array expected31{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}; +-static constexpr std::array case0 = {{ +- {equation0, shape0, expected0}, +- {equation1, shape1, expected1}, +- {equation2, shape2, expected2}, +- {equation3, shape3, expected3}, +- {equation4, shape4, expected4}, +- {equation5, shape5, expected5}, +- {equation6, shape6, expected6}, +- {equation7, shape7, expected7}, +- {equation8, shape8, expected8}, +- {equation9, shape9, expected9}, +- {equation10, shape10, expected10}, +- {equation11, shape11, expected11}, +- {equation12, shape12, expected12}, +- {equation13, shape13, expected13}, +- {equation14, shape14, expected14}, +- {equation15, shape15, expected15}, +- {equation16, shape16, expected16}, +- {equation17, shape17, expected17}, +- {equation18, shape18, expected18}, +- {equation19, shape19, expected19}, +- {equation20, shape20, expected20}, +- {equation21, shape21, expected21}, +- {equation22, shape22, expected22}, +- {equation23, shape23, expected23}, +- {equation24, shape24, expected24}, +- {equation25, shape25, expected25}, +- {equation26, shape26, expected26}, +- {equation27, shape27, expected27}, +- {equation28, shape28, 
expected28}, +- {equation29, shape29, expected29}, +- {equation30, shape30, expected30}, +- {equation31, shape31, expected31}, +-}}; +- +-static constexpr std::string_view equation32 = "abc,cd,def->abd"; +-static constexpr std::array shape32{2, 2, 2}; +-static constexpr std::array expected32{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +-static constexpr std::string_view equation33 = "abc,cd,def->abe"; +-static constexpr std::array shape33{2, 2, 2}; +-static constexpr std::array expected33{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +-static constexpr std::string_view equation34 = "abc,cd,def->acd"; +-static constexpr std::array shape34{2, 2, 2}; +-static constexpr std::array expected34{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +-static constexpr std::string_view equation35 = "abc,cd,def->ace"; +-static constexpr std::array shape35{2, 2, 2}; +-static constexpr std::array expected35{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +-static constexpr std::string_view equation36 = "abc,cd,dfe->abd"; +-static constexpr std::array shape36{2, 2, 2}; +-static constexpr std::array expected36{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +-static constexpr std::string_view equation37 = "abc,cd,dfe->abf"; +-static constexpr std::array shape37{2, 2, 2}; +-static constexpr std::array expected37{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +-static constexpr std::string_view equation38 = "abc,cd,dfe->acd"; +-static constexpr std::array shape38{2, 2, 2}; +-static constexpr std::array expected38{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +-static constexpr std::string_view equation39 = "abc,cd,dfe->acf"; +-static constexpr std::array shape39{2, 2, 2}; +-static constexpr std::array expected39{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +-static constexpr std::string_view equation40 = "abc,cd,edf->abe"; +-static constexpr std::array shape40{2, 2, 2}; +-static constexpr std::array expected40{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +-static constexpr std::string_view equation41 = "abc,cd,edf->abd"; +-static constexpr std::array shape41{2, 2, 2}; +-static constexpr std::array expected41{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +-static constexpr std::string_view equation42 = "abc,cd,edf->ace"; +-static constexpr std::array shape42{2, 2, 2}; +-static constexpr std::array expected42{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +-static constexpr std::string_view equation43 = "abc,cd,edf->acd"; +-static constexpr std::array shape43{2, 2, 2}; +-static constexpr std::array expected43{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +-static constexpr std::string_view equation44 = "abc,cd,efd->abe"; +-static constexpr std::array shape44{2, 2, 2}; +-static constexpr std::array expected44{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +-static constexpr std::string_view equation45 = "abc,cd,efd->abf"; +-static constexpr std::array shape45{2, 2, 2}; +-static constexpr std::array expected45{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +-static constexpr std::string_view equation46 = "abc,cd,efd->ace"; +-static constexpr std::array shape46{2, 2, 2}; +-static constexpr std::array expected46{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +-static constexpr std::string_view equation47 = "abc,cd,efd->acf"; +-static constexpr std::array shape47{2, 2, 2}; +-static constexpr std::array expected47{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +-static constexpr std::string_view equation48 = 
"abc,cd,fde->abf"; +-static constexpr std::array shape48{2, 2, 2}; +-static constexpr std::array expected48{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +-static constexpr std::string_view equation49 = "abc,cd,fde->abd"; +-static constexpr std::array shape49{2, 2, 2}; +-static constexpr std::array expected49{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +-static constexpr std::string_view equation50 = "abc,cd,fde->acf"; +-static constexpr std::array shape50{2, 2, 2}; +-static constexpr std::array expected50{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +-static constexpr std::string_view equation51 = "abc,cd,fde->acd"; +-static constexpr std::array shape51{2, 2, 2}; +-static constexpr std::array expected51{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +-static constexpr std::string_view equation52 = "abc,cd,fed->abf"; +-static constexpr std::array shape52{2, 2, 2}; +-static constexpr std::array expected52{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +-static constexpr std::string_view equation53 = "abc,cd,fed->abe"; +-static constexpr std::array shape53{2, 2, 2}; +-static constexpr std::array expected53{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +-static constexpr std::string_view equation54 = "abc,cd,fed->acf"; +-static constexpr std::array shape54{2, 2, 2}; +-static constexpr std::array expected54{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +-static constexpr std::string_view equation55 = "abc,cd,fed->ace"; +-static constexpr std::array shape55{2, 2, 2}; +-static constexpr std::array expected55{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +-static constexpr std::string_view equation56 = "abc,dc,def->abd"; +-static constexpr std::array shape56{2, 2, 2}; +-static constexpr std::array expected56{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +-static constexpr std::string_view equation57 = "abc,dc,def->abe"; +-static constexpr std::array shape57{2, 2, 2}; +-static constexpr std::array expected57{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +-static constexpr std::string_view equation58 = "abc,dc,def->acd"; +-static constexpr std::array shape58{2, 2, 2}; +-static constexpr std::array expected58{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +-static constexpr std::string_view equation59 = "abc,dc,def->ace"; +-static constexpr std::array shape59{2, 2, 2}; +-static constexpr std::array expected59{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +-static constexpr std::string_view equation60 = "abc,dc,dfe->abd"; +-static constexpr std::array shape60{2, 2, 2}; +-static constexpr std::array expected60{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +-static constexpr std::string_view equation61 = "abc,dc,dfe->abf"; +-static constexpr std::array shape61{2, 2, 2}; +-static constexpr std::array expected61{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +-static constexpr std::string_view equation62 = "abc,dc,dfe->acd"; +-static constexpr std::array shape62{2, 2, 2}; +-static constexpr std::array expected62{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +-static constexpr std::string_view equation63 = "abc,dc,dfe->acf"; +-static constexpr std::array shape63{2, 2, 2}; +-static constexpr std::array expected63{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +-static constexpr std::string_view equation64 = "abc,dc,edf->abe"; +-static constexpr std::array shape64{2, 2, 2}; +-static constexpr std::array expected64{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +-static constexpr 
std::string_view equation65 = "abc,dc,edf->abd"; +-static constexpr std::array shape65{2, 2, 2}; +-static constexpr std::array expected65{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +-static constexpr std::string_view equation66 = "abc,dc,edf->ace"; +-static constexpr std::array shape66{2, 2, 2}; +-static constexpr std::array expected66{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +-static constexpr std::string_view equation67 = "abc,dc,edf->acd"; +-static constexpr std::array shape67{2, 2, 2}; +-static constexpr std::array expected67{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +-static constexpr std::string_view equation68 = "abc,dc,efd->abe"; +-static constexpr std::array shape68{2, 2, 2}; +-static constexpr std::array expected68{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +-static constexpr std::string_view equation69 = "abc,dc,efd->abf"; +-static constexpr std::array shape69{2, 2, 2}; +-static constexpr std::array expected69{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +-static constexpr std::string_view equation70 = "abc,dc,efd->ace"; +-static constexpr std::array shape70{2, 2, 2}; +-static constexpr std::array expected70{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +-static constexpr std::string_view equation71 = "abc,dc,efd->acf"; +-static constexpr std::array shape71{2, 2, 2}; +-static constexpr std::array expected71{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +-static constexpr std::string_view equation72 = "abc,dc,fde->abf"; +-static constexpr std::array shape72{2, 2, 2}; +-static constexpr std::array expected72{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +-static constexpr std::string_view equation73 = "abc,dc,fde->abd"; +-static constexpr std::array shape73{2, 2, 2}; +-static constexpr std::array expected73{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +-static constexpr std::string_view equation74 = "abc,dc,fde->acf"; +-static constexpr std::array shape74{2, 2, 2}; +-static constexpr std::array expected74{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +-static constexpr std::string_view equation75 = "abc,dc,fde->acd"; +-static constexpr std::array shape75{2, 2, 2}; +-static constexpr std::array expected75{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +-static constexpr std::string_view equation76 = "abc,dc,fed->abf"; +-static constexpr std::array shape76{2, 2, 2}; +-static constexpr std::array expected76{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +-static constexpr std::string_view equation77 = "abc,dc,fed->abe"; +-static constexpr std::array shape77{2, 2, 2}; +-static constexpr std::array expected77{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +-static constexpr std::string_view equation78 = "abc,dc,fed->acf"; +-static constexpr std::array shape78{2, 2, 2}; +-static constexpr std::array expected78{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +-static constexpr std::string_view equation79 = "abc,dc,fed->ace"; +-static constexpr std::array shape79{2, 2, 2}; +-static constexpr std::array expected79{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +-static constexpr std::string_view equation80 = "acb,cd,def->acd"; +-static constexpr std::array shape80{2, 2, 2}; +-static constexpr std::array expected80{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +-static constexpr std::string_view equation81 = "acb,cd,def->ace"; +-static constexpr std::array shape81{2, 2, 2}; +-static constexpr std::array expected81{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; 
+-static constexpr std::string_view equation82 = "acb,cd,def->abd"; +-static constexpr std::array shape82{2, 2, 2}; +-static constexpr std::array expected82{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +-static constexpr std::string_view equation83 = "acb,cd,def->abe"; +-static constexpr std::array shape83{2, 2, 2}; +-static constexpr std::array expected83{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +-static constexpr std::string_view equation84 = "acb,cd,dfe->acd"; +-static constexpr std::array shape84{2, 2, 2}; +-static constexpr std::array expected84{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +-static constexpr std::string_view equation85 = "acb,cd,dfe->acf"; +-static constexpr std::array shape85{2, 2, 2}; +-static constexpr std::array expected85{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +-static constexpr std::string_view equation86 = "acb,cd,dfe->abd"; +-static constexpr std::array shape86{2, 2, 2}; +-static constexpr std::array expected86{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +-static constexpr std::string_view equation87 = "acb,cd,dfe->abf"; +-static constexpr std::array shape87{2, 2, 2}; +-static constexpr std::array expected87{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +-static constexpr std::string_view equation88 = "acb,cd,edf->ace"; +-static constexpr std::array shape88{2, 2, 2}; +-static constexpr std::array expected88{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +-static constexpr std::string_view equation89 = "acb,cd,edf->acd"; +-static constexpr std::array shape89{2, 2, 2}; +-static constexpr std::array expected89{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +-static constexpr std::string_view equation90 = "acb,cd,edf->abe"; +-static constexpr std::array shape90{2, 2, 2}; +-static constexpr std::array expected90{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +-static constexpr std::string_view equation91 = "acb,cd,edf->abd"; +-static constexpr std::array shape91{2, 2, 2}; +-static constexpr std::array expected91{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +-static constexpr std::string_view equation92 = "acb,cd,efd->ace"; +-static constexpr std::array shape92{2, 2, 2}; +-static constexpr std::array expected92{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +-static constexpr std::string_view equation93 = "acb,cd,efd->acf"; +-static constexpr std::array shape93{2, 2, 2}; +-static constexpr std::array expected93{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +-static constexpr std::string_view equation94 = "acb,cd,efd->abe"; +-static constexpr std::array shape94{2, 2, 2}; +-static constexpr std::array expected94{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +-static constexpr std::string_view equation95 = "acb,cd,efd->abf"; +-static constexpr std::array shape95{2, 2, 2}; +-static constexpr std::array expected95{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +-static constexpr std::string_view equation96 = "acb,cd,fde->acf"; +-static constexpr std::array shape96{2, 2, 2}; +-static constexpr std::array expected96{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +-static constexpr std::string_view equation97 = "acb,cd,fde->acd"; +-static constexpr std::array shape97{2, 2, 2}; +-static constexpr std::array expected97{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +-static constexpr std::string_view equation98 = "acb,cd,fde->abf"; +-static constexpr std::array shape98{2, 2, 2}; +-static constexpr std::array expected98{34.f, 114.f, 56.f, 184.f, 122.f, 
394.f, 144.f, 464.f}; +-static constexpr std::string_view equation99 = "acb,cd,fde->abd"; +-static constexpr std::array shape99{2, 2, 2}; +-static constexpr std::array expected99{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +-static constexpr std::string_view equation100 = "acb,cd,fed->acf"; +-static constexpr std::array shape100{2, 2, 2}; +-static constexpr std::array expected100{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +-static constexpr std::string_view equation101 = "acb,cd,fed->ace"; +-static constexpr std::array shape101{2, 2, 2}; +-static constexpr std::array expected101{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +-static constexpr std::string_view equation102 = "acb,cd,fed->abf"; +-static constexpr std::array shape102{2, 2, 2}; +-static constexpr std::array expected102{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +-static constexpr std::string_view equation103 = "acb,cd,fed->abe"; +-static constexpr std::array shape103{2, 2, 2}; +-static constexpr std::array expected103{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +-static constexpr std::string_view equation104 = "acb,dc,def->acd"; +-static constexpr std::array shape104{2, 2, 2}; +-static constexpr std::array expected104{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +-static constexpr std::string_view equation105 = "acb,dc,def->ace"; +-static constexpr std::array shape105{2, 2, 2}; +-static constexpr std::array expected105{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +- +-static constexpr std::string_view equation106 = "acb,dc,def->abd"; +-static constexpr std::array shape106{2, 2, 2}; +-static constexpr std::array expected106{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +-static constexpr std::string_view equation107 = "acb,dc,def->abe"; +-static constexpr std::array shape107{2, 2, 2}; +-static constexpr std::array expected107{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +-static constexpr std::string_view equation108 = "acb,dc,dfe->acd"; +-static constexpr std::array shape108{2, 2, 2}; +-static constexpr std::array expected108{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +-static constexpr std::string_view equation109 = "acb,dc,dfe->acf"; +-static constexpr std::array shape109{2, 2, 2}; +-static constexpr std::array expected109{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +-static constexpr std::string_view equation110 = "acb,dc,dfe->abd"; +-static constexpr std::array shape110{2, 2, 2}; +-static constexpr std::array expected110{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +-static constexpr std::string_view equation111 = "acb,dc,dfe->abf"; +-static constexpr std::array shape111{2, 2, 2}; +-static constexpr std::array expected111{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +-static constexpr std::string_view equation112 = "acb,dc,edf->ace"; +-static constexpr std::array shape112{2, 2, 2}; +-static constexpr std::array expected112{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +-static constexpr std::string_view equation113 = "acb,dc,edf->acd"; +-static constexpr std::array shape113{2, 2, 2}; +-static constexpr std::array expected113{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +-static constexpr std::string_view equation114 = "acb,dc,edf->abe"; +-static constexpr std::array shape114{2, 2, 2}; +-static constexpr std::array expected114{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +-static constexpr std::string_view equation115 = "acb,dc,edf->abd"; +-static constexpr std::array shape115{2, 2, 2}; 
+-static constexpr std::array expected115{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +-static constexpr std::string_view equation116 = "acb,dc,efd->ace"; +-static constexpr std::array shape116{2, 2, 2}; +-static constexpr std::array expected116{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +-static constexpr std::string_view equation117 = "acb,dc,efd->acf"; +-static constexpr std::array shape117{2, 2, 2}; +-static constexpr std::array expected117{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +-static constexpr std::string_view equation118 = "acb,dc,efd->abe"; +-static constexpr std::array shape118{2, 2, 2}; +-static constexpr std::array expected118{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +-static constexpr std::string_view equation119 = "acb,dc,efd->abf"; +-static constexpr std::array shape119{2, 2, 2}; +-static constexpr std::array expected119{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +-static constexpr std::string_view equation120 = "acb,dc,fde->acf"; +-static constexpr std::array shape120{2, 2, 2}; +-static constexpr std::array expected120{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +-static constexpr std::string_view equation121 = "acb,dc,fde->acd"; +-static constexpr std::array shape121{2, 2, 2}; +-static constexpr std::array expected121{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +-static constexpr std::string_view equation122 = "acb,dc,fde->abf"; +-static constexpr std::array shape122{2, 2, 2}; +-static constexpr std::array expected122{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +-static constexpr std::string_view equation123 = "acb,dc,fde->abd"; +-static constexpr std::array shape123{2, 2, 2}; +-static constexpr std::array expected123{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +-static constexpr std::string_view equation124 = "acb,dc,fed->acf"; +-static constexpr std::array shape124{2, 2, 2}; +-static constexpr std::array expected124{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +-static constexpr std::string_view equation125 = "acb,dc,fed->ace"; +-static constexpr std::array shape125{2, 2, 2}; +-static constexpr std::array expected125{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +-static constexpr std::string_view equation126 = "acb,dc,fed->abf"; +-static constexpr std::array shape126{2, 2, 2}; +-static constexpr std::array expected126{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +-static constexpr std::string_view equation127 = "acb,dc,fed->abe"; +-static constexpr std::array shape127{2, 2, 2}; +-static constexpr std::array expected127{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +-static constexpr std::string_view equation128 = "bac,cd,def->bad"; +-static constexpr std::array shape128{2, 2, 2}; +-static constexpr std::array expected128{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +-static constexpr std::string_view equation129 = "bac,cd,def->bae"; +-static constexpr std::array shape129{2, 2, 2}; +-static constexpr std::array expected129{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +-static constexpr std::string_view equation130 = "bac,cd,def->bcd"; +-static constexpr std::array shape130{2, 2, 2}; +-static constexpr std::array expected130{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +-static constexpr std::string_view equation131 = "bac,cd,def->bce"; +-static constexpr std::array shape131{2, 2, 2}; +-static constexpr std::array expected131{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +-static constexpr std::string_view 
equation132 = "bac,cd,dfe->bad"; +-static constexpr std::array shape132{2, 2, 2}; +-static constexpr std::array expected132{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +-static constexpr std::string_view equation133 = "bac,cd,dfe->baf"; +-static constexpr std::array shape133{2, 2, 2}; +-static constexpr std::array expected133{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +-static constexpr std::string_view equation134 = "bac,cd,dfe->bcd"; +-static constexpr std::array shape134{2, 2, 2}; +-static constexpr std::array expected134{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +-static constexpr std::string_view equation135 = "bac,cd,dfe->bcf"; +-static constexpr std::array shape135{2, 2, 2}; +-static constexpr std::array expected135{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +-static constexpr std::string_view equation136 = "bac,cd,edf->bae"; +-static constexpr std::array shape136{2, 2, 2}; +-static constexpr std::array expected136{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +-static constexpr std::string_view equation137 = "bac,cd,edf->bad"; +-static constexpr std::array shape137{2, 2, 2}; +-static constexpr std::array expected137{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +-static constexpr std::string_view equation138 = "bac,cd,edf->bce"; +-static constexpr std::array shape138{2, 2, 2}; +-static constexpr std::array expected138{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +-static constexpr std::string_view equation139 = "bac,cd,edf->bcd"; +-static constexpr std::array shape139{2, 2, 2}; +-static constexpr std::array expected139{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +-static constexpr std::string_view equation140 = "bac,cd,efd->bae"; +-static constexpr std::array shape140{2, 2, 2}; +-static constexpr std::array expected140{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +-static constexpr std::string_view equation141 = "bac,cd,efd->baf"; +-static constexpr std::array shape141{2, 2, 2}; +-static constexpr std::array expected141{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +-static constexpr std::string_view equation142 = "bac,cd,efd->bce"; +-static constexpr std::array shape142{2, 2, 2}; +-static constexpr std::array expected142{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +-static constexpr std::string_view equation143 = "bac,cd,efd->bcf"; +-static constexpr std::array shape143{2, 2, 2}; +-static constexpr std::array expected143{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +-static constexpr std::string_view equation144 = "bac,cd,fde->baf"; +-static constexpr std::array shape144{2, 2, 2}; +-static constexpr std::array expected144{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +-static constexpr std::string_view equation145 = "bac,cd,fde->bad"; +-static constexpr std::array shape145{2, 2, 2}; +-static constexpr std::array expected145{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +-static constexpr std::string_view equation146 = "bac,cd,fde->bcf"; +-static constexpr std::array shape146{2, 2, 2}; +-static constexpr std::array expected146{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +-static constexpr std::string_view equation147 = "bac,cd,fde->bcd"; +-static constexpr std::array shape147{2, 2, 2}; +-static constexpr std::array expected147{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +-static constexpr std::string_view equation148 = "bac,cd,fed->baf"; +-static constexpr std::array shape148{2, 2, 2}; +-static constexpr std::array expected148{16.f, 56.f, 56.f, 
192.f, 96.f, 328.f, 136.f, 464.f}; +-static constexpr std::string_view equation149 = "bac,cd,fed->bae"; +-static constexpr std::array shape149{2, 2, 2}; +-static constexpr std::array expected149{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +-static constexpr std::string_view equation150 = "bac,cd,fed->bcf"; +-static constexpr std::array shape150{2, 2, 2}; +-static constexpr std::array expected150{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +-static constexpr std::string_view equation151 = "bac,cd,fed->bce"; +-static constexpr std::array shape151{2, 2, 2}; +-static constexpr std::array expected151{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +-static constexpr std::string_view equation152 = "bac,dc,def->bad"; +-static constexpr std::array shape152{2, 2, 2}; +-static constexpr std::array expected152{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +-static constexpr std::string_view equation153 = "bac,dc,def->bae"; +-static constexpr std::array shape153{2, 2, 2}; +-static constexpr std::array expected153{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +-static constexpr std::string_view equation154 = "bac,dc,def->bcd"; +-static constexpr std::array shape154{2, 2, 2}; +-static constexpr std::array expected154{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +-static constexpr std::string_view equation155 = "bac,dc,def->bce"; +-static constexpr std::array shape155{2, 2, 2}; +-static constexpr std::array expected155{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +-static constexpr std::string_view equation156 = "bac,dc,dfe->bad"; +-static constexpr std::array shape156{2, 2, 2}; +-static constexpr std::array expected156{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +-static constexpr std::string_view equation157 = "bac,dc,dfe->baf"; +-static constexpr std::array shape157{2, 2, 2}; +-static constexpr std::array expected157{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +-static constexpr std::string_view equation158 = "bac,dc,dfe->bcd"; +-static constexpr std::array shape158{2, 2, 2}; +-static constexpr std::array expected158{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +-static constexpr std::string_view equation159 = "bac,dc,dfe->bcf"; +-static constexpr std::array shape159{2, 2, 2}; +-static constexpr std::array expected159{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +-static constexpr std::string_view equation160 = "bac,dc,edf->bae"; +-static constexpr std::array shape160{2, 2, 2}; +-static constexpr std::array expected160{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +-static constexpr std::string_view equation161 = "bac,dc,edf->bad"; +-static constexpr std::array shape161{2, 2, 2}; +-static constexpr std::array expected161{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +-static constexpr std::string_view equation162 = "bac,dc,edf->bce"; +-static constexpr std::array shape162{2, 2, 2}; +-static constexpr std::array expected162{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +-static constexpr std::string_view equation163 = "bac,dc,edf->bcd"; +-static constexpr std::array shape163{2, 2, 2}; +-static constexpr std::array expected163{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +-static constexpr std::string_view equation164 = "bac,dc,efd->bae"; +-static constexpr std::array shape164{2, 2, 2}; +-static constexpr std::array expected164{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +-static constexpr std::string_view equation165 = "bac,dc,efd->baf"; +-static constexpr std::array shape165{2, 
2, 2}; +-static constexpr std::array expected165{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +-static constexpr std::string_view equation166 = "bac,dc,efd->bce"; +-static constexpr std::array shape166{2, 2, 2}; +-static constexpr std::array expected166{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +-static constexpr std::string_view equation167 = "bac,dc,efd->bcf"; +-static constexpr std::array shape167{2, 2, 2}; +-static constexpr std::array expected167{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +-static constexpr std::string_view equation168 = "bac,dc,fde->baf"; +-static constexpr std::array shape168{2, 2, 2}; +-static constexpr std::array expected168{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +-static constexpr std::string_view equation169 = "bac,dc,fde->bad"; +-static constexpr std::array shape169{2, 2, 2}; +-static constexpr std::array expected169{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +-static constexpr std::string_view equation170 = "bac,dc,fde->bcf"; +-static constexpr std::array shape170{2, 2, 2}; +-static constexpr std::array expected170{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +-static constexpr std::string_view equation171 = "bac,dc,fde->bcd"; +-static constexpr std::array shape171{2, 2, 2}; +-static constexpr std::array expected171{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +-static constexpr std::string_view equation172 = "bac,dc,fed->baf"; +-static constexpr std::array shape172{2, 2, 2}; +-static constexpr std::array expected172{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +-static constexpr std::string_view equation173 = "bac,dc,fed->bae"; +-static constexpr std::array shape173{2, 2, 2}; +-static constexpr std::array expected173{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +-static constexpr std::string_view equation174 = "bac,dc,fed->bcf"; +-static constexpr std::array shape174{2, 2, 2}; +-static constexpr std::array expected174{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +-static constexpr std::string_view equation175 = "bac,dc,fed->bce"; +-static constexpr std::array shape175{2, 2, 2}; +-static constexpr std::array expected175{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +-static constexpr std::string_view equation176 = "bca,cd,def->bcd"; +-static constexpr std::array shape176{2, 2, 2}; +-static constexpr std::array expected176{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +-static constexpr std::string_view equation177 = "bca,cd,def->bce"; +-static constexpr std::array shape177{2, 2, 2}; +-static constexpr std::array expected177{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +-static constexpr std::string_view equation178 = "bca,cd,def->bad"; +-static constexpr std::array shape178{2, 2, 2}; +-static constexpr std::array expected178{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +-static constexpr std::string_view equation179 = "bca,cd,def->bae"; +-static constexpr std::array shape179{2, 2, 2}; +-static constexpr std::array expected179{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +-static constexpr std::string_view equation180 = "bca,cd,dfe->bcd"; +-static constexpr std::array shape180{2, 2, 2}; +-static constexpr std::array expected180{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +-static constexpr std::string_view equation181 = "bca,cd,dfe->bcf"; +-static constexpr std::array shape181{2, 2, 2}; +-static constexpr std::array expected181{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +-static constexpr std::string_view 
equation182 = "bca,cd,dfe->bad"; +-static constexpr std::array shape182{2, 2, 2}; +-static constexpr std::array expected182{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +-static constexpr std::string_view equation183 = "bca,cd,dfe->baf"; +-static constexpr std::array shape183{2, 2, 2}; +-static constexpr std::array expected183{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +-static constexpr std::string_view equation184 = "bca,cd,edf->bce"; +-static constexpr std::array shape184{2, 2, 2}; +-static constexpr std::array expected184{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +-static constexpr std::string_view equation185 = "bca,cd,edf->bcd"; +-static constexpr std::array shape185{2, 2, 2}; +-static constexpr std::array expected185{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +-static constexpr std::string_view equation186 = "bca,cd,edf->bae"; +-static constexpr std::array shape186{2, 2, 2}; +-static constexpr std::array expected186{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +-static constexpr std::string_view equation187 = "bca,cd,edf->bad"; +-static constexpr std::array shape187{2, 2, 2}; +-static constexpr std::array expected187{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +-static constexpr std::string_view equation188 = "bca,cd,efd->bce"; +-static constexpr std::array shape188{2, 2, 2}; +-static constexpr std::array expected188{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +-static constexpr std::string_view equation189 = "bca,cd,efd->bcf"; +-static constexpr std::array shape189{2, 2, 2}; +-static constexpr std::array expected189{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +-static constexpr std::string_view equation190 = "bca,cd,efd->bae"; +-static constexpr std::array shape190{2, 2, 2}; +-static constexpr std::array expected190{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +-static constexpr std::string_view equation191 = "bca,cd,efd->baf"; +-static constexpr std::array shape191{2, 2, 2}; +-static constexpr std::array expected191{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +-static constexpr std::string_view equation192 = "bca,cd,fde->bcf"; +-static constexpr std::array shape192{2, 2, 2}; +-static constexpr std::array expected192{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +-static constexpr std::string_view equation193 = "bca,cd,fde->bcd"; +-static constexpr std::array shape193{2, 2, 2}; +-static constexpr std::array expected193{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +-static constexpr std::string_view equation194 = "bca,cd,fde->baf"; +-static constexpr std::array shape194{2, 2, 2}; +-static constexpr std::array expected194{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +-static constexpr std::string_view equation195 = "bca,cd,fde->bad"; +-static constexpr std::array shape195{2, 2, 2}; +-static constexpr std::array expected195{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +-static constexpr std::string_view equation196 = "bca,cd,fed->bcf"; +-static constexpr std::array shape196{2, 2, 2}; +-static constexpr std::array expected196{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +-static constexpr std::string_view equation197 = "bca,cd,fed->bce"; +-static constexpr std::array shape197{2, 2, 2}; +-static constexpr std::array expected197{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +-static constexpr std::string_view equation198 = "bca,cd,fed->baf"; +-static constexpr std::array shape198{2, 2, 2}; +-static constexpr std::array expected198{32.f, 112.f, 52.f, 
180.f, 112.f, 384.f, 132.f, 452.f}; +-static constexpr std::string_view equation199 = "bca,cd,fed->bae"; +-static constexpr std::array shape199{2, 2, 2}; +-static constexpr std::array expected199{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +-static constexpr std::string_view equation200 = "bca,dc,def->bcd"; +-static constexpr std::array shape200{2, 2, 2}; +-static constexpr std::array expected200{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +-static constexpr std::string_view equation201 = "bca,dc,def->bce"; +-static constexpr std::array shape201{2, 2, 2}; +-static constexpr std::array expected201{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +-static constexpr std::string_view equation202 = "bca,dc,def->bad"; +-static constexpr std::array shape202{2, 2, 2}; +-static constexpr std::array expected202{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +-static constexpr std::string_view equation203 = "bca,dc,def->bae"; +-static constexpr std::array shape203{2, 2, 2}; +-static constexpr std::array expected203{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +-static constexpr std::string_view equation204 = "bca,dc,dfe->bcd"; +-static constexpr std::array shape204{2, 2, 2}; +-static constexpr std::array expected204{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +-static constexpr std::string_view equation205 = "bca,dc,dfe->bcf"; +-static constexpr std::array shape205{2, 2, 2}; +-static constexpr std::array expected205{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +-static constexpr std::string_view equation206 = "bca,dc,dfe->bad"; +-static constexpr std::array shape206{2, 2, 2}; +-static constexpr std::array expected206{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +-static constexpr std::string_view equation207 = "bca,dc,dfe->baf"; +-static constexpr std::array shape207{2, 2, 2}; +-static constexpr std::array expected207{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +-static constexpr std::string_view equation208 = "bca,dc,edf->bce"; +-static constexpr std::array shape208{2, 2, 2}; +-static constexpr std::array expected208{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +-static constexpr std::string_view equation209 = "bca,dc,edf->bcd"; +-static constexpr std::array shape209{2, 2, 2}; +-static constexpr std::array expected209{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +-static constexpr std::string_view equation210 = "bca,dc,edf->bae"; +-static constexpr std::array shape210{2, 2, 2}; +-static constexpr std::array expected210{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +-static constexpr std::string_view equation211 = "bca,dc,edf->bad"; +-static constexpr std::array shape211{2, 2, 2}; +-static constexpr std::array expected211{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +-static constexpr std::string_view equation212 = "bca,dc,efd->bce"; +-static constexpr std::array shape212{2, 2, 2}; +-static constexpr std::array expected212{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +-static constexpr std::string_view equation213 = "bca,dc,efd->bcf"; +-static constexpr std::array shape213{2, 2, 2}; +-static constexpr std::array expected213{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +-static constexpr std::string_view equation214 = "bca,dc,efd->bae"; +-static constexpr std::array shape214{2, 2, 2}; +-static constexpr std::array expected214{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +-static constexpr std::string_view equation215 = "bca,dc,efd->baf"; +-static constexpr std::array 
shape215{2, 2, 2}; +-static constexpr std::array expected215{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +-static constexpr std::string_view equation216 = "bca,dc,fde->bcf"; +-static constexpr std::array shape216{2, 2, 2}; +-static constexpr std::array expected216{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +-static constexpr std::string_view equation217 = "bca,dc,fde->bcd"; +-static constexpr std::array shape217{2, 2, 2}; +-static constexpr std::array expected217{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +-static constexpr std::string_view equation218 = "bca,dc,fde->baf"; +-static constexpr std::array shape218{2, 2, 2}; +-static constexpr std::array expected218{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +-static constexpr std::string_view equation219 = "bca,dc,fde->bad"; +-static constexpr std::array shape219{2, 2, 2}; +-static constexpr std::array expected219{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +-static constexpr std::string_view equation220 = "bca,dc,fed->bcf"; +-static constexpr std::array shape220{2, 2, 2}; +-static constexpr std::array expected220{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +-static constexpr std::string_view equation221 = "bca,dc,fed->bce"; +-static constexpr std::array shape221{2, 2, 2}; +-static constexpr std::array expected221{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +-static constexpr std::string_view equation222 = "bca,dc,fed->baf"; +-static constexpr std::array shape222{2, 2, 2}; +-static constexpr std::array expected222{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +-static constexpr std::string_view equation223 = "bca,dc,fed->bae"; +-static constexpr std::array shape223{2, 2, 2}; +-static constexpr std::array expected223{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +-static constexpr std::string_view equation224 = "cab,cd,def->cad"; +-static constexpr std::array shape224{2, 2, 2}; +-static constexpr std::array expected224{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +-static constexpr std::string_view equation225 = "cab,cd,def->cae"; +-static constexpr std::array shape225{2, 2, 2}; +-static constexpr std::array expected225{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +-static constexpr std::string_view equation226 = "cab,cd,def->cbd"; +-static constexpr std::array shape226{2, 2, 2}; +-static constexpr std::array expected226{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +-static constexpr std::string_view equation227 = "cab,cd,def->cbe"; +-static constexpr std::array shape227{2, 2, 2}; +-static constexpr std::array expected227{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +-static constexpr std::string_view equation228 = "cab,cd,dfe->cad"; +-static constexpr std::array shape228{2, 2, 2}; +-static constexpr std::array expected228{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +-static constexpr std::string_view equation229 = "cab,cd,dfe->caf"; +-static constexpr std::array shape229{2, 2, 2}; +-static constexpr std::array expected229{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +-static constexpr std::string_view equation230 = "cab,cd,dfe->cbd"; +-static constexpr std::array shape230{2, 2, 2}; +-static constexpr std::array expected230{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +-static constexpr std::string_view equation231 = "cab,cd,dfe->cbf"; +-static constexpr std::array shape231{2, 2, 2}; +-static constexpr std::array expected231{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +-static constexpr std::string_view 
equation232 = "cab,cd,edf->cae"; +-static constexpr std::array shape232{2, 2, 2}; +-static constexpr std::array expected232{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +-static constexpr std::string_view equation233 = "cab,cd,edf->cad"; +-static constexpr std::array shape233{2, 2, 2}; +-static constexpr std::array expected233{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +-static constexpr std::string_view equation234 = "cab,cd,edf->cbe"; +-static constexpr std::array shape234{2, 2, 2}; +-static constexpr std::array expected234{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +-static constexpr std::string_view equation235 = "cab,cd,edf->cbd"; +-static constexpr std::array shape235{2, 2, 2}; +-static constexpr std::array expected235{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +-static constexpr std::string_view equation236 = "cab,cd,efd->cae"; +-static constexpr std::array shape236{2, 2, 2}; +-static constexpr std::array expected236{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +-static constexpr std::string_view equation237 = "cab,cd,efd->caf"; +-static constexpr std::array shape237{2, 2, 2}; +-static constexpr std::array expected237{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +-static constexpr std::string_view equation238 = "cab,cd,efd->cbe"; +-static constexpr std::array shape238{2, 2, 2}; +-static constexpr std::array expected238{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +-static constexpr std::string_view equation239 = "cab,cd,efd->cbf"; +-static constexpr std::array shape239{2, 2, 2}; +-static constexpr std::array expected239{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +-static constexpr std::string_view equation240 = "cab,cd,fde->caf"; +-static constexpr std::array shape240{2, 2, 2}; +-static constexpr std::array expected240{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +-static constexpr std::string_view equation241 = "cab,cd,fde->cad"; +-static constexpr std::array shape241{2, 2, 2}; +-static constexpr std::array expected241{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +-static constexpr std::string_view equation242 = "cab,cd,fde->cbf"; +-static constexpr std::array shape242{2, 2, 2}; +-static constexpr std::array expected242{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +-static constexpr std::string_view equation243 = "cab,cd,fde->cbd"; +-static constexpr std::array shape243{2, 2, 2}; +-static constexpr std::array expected243{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +-static constexpr std::string_view equation244 = "cab,cd,fed->caf"; +-static constexpr std::array shape244{2, 2, 2}; +-static constexpr std::array expected244{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +-static constexpr std::string_view equation245 = "cab,cd,fed->cae"; +-static constexpr std::array shape245{2, 2, 2}; +-static constexpr std::array expected245{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +-static constexpr std::string_view equation246 = "cab,cd,fed->cbf"; +-static constexpr std::array shape246{2, 2, 2}; +-static constexpr std::array expected246{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +-static constexpr std::string_view equation247 = "cab,cd,fed->cbe"; +-static constexpr std::array shape247{2, 2, 2}; +-static constexpr std::array expected247{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +-static constexpr std::string_view equation248 = "cab,dc,def->cad"; +-static constexpr std::array shape248{2, 2, 2}; +-static constexpr std::array expected248{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 
78.f, 858.f}; +-static constexpr std::string_view equation249 = "cab,dc,def->cae"; +-static constexpr std::array shape249{2, 2, 2}; +-static constexpr std::array expected249{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +-static constexpr std::string_view equation250 = "cab,dc,def->cbd"; +-static constexpr std::array shape250{2, 2, 2}; +-static constexpr std::array expected250{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +- +-static constexpr std::string_view equation251 = "cab,dc,def->cbe"; +-static constexpr std::array shape251{2, 2, 2}; +-static constexpr std::array expected251{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +-static constexpr std::string_view equation252 = "cab,dc,dfe->cad"; +-static constexpr std::array shape252{2, 2, 2}; +-static constexpr std::array expected252{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +-static constexpr std::string_view equation253 = "cab,dc,dfe->caf"; +-static constexpr std::array shape253{2, 2, 2}; +-static constexpr std::array expected253{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +-static constexpr std::string_view equation254 = "cab,dc,dfe->cbd"; +-static constexpr std::array shape254{2, 2, 2}; +-static constexpr std::array expected254{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +-static constexpr std::string_view equation255 = "cab,dc,dfe->cbf"; +-static constexpr std::array shape255{2, 2, 2}; +-static constexpr std::array expected255{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +-static constexpr std::string_view equation256 = "cab,dc,edf->cae"; +-static constexpr std::array shape256{2, 2, 2}; +-static constexpr std::array expected256{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +-static constexpr std::string_view equation257 = "cab,dc,edf->cad"; +-static constexpr std::array shape257{2, 2, 2}; +-static constexpr std::array expected257{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +-static constexpr std::string_view equation258 = "cab,dc,edf->cbe"; +-static constexpr std::array shape258{2, 2, 2}; +-static constexpr std::array expected258{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +-static constexpr std::string_view equation259 = "cab,dc,edf->cbd"; +-static constexpr std::array shape259{2, 2, 2}; +-static constexpr std::array expected259{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +-static constexpr std::string_view equation260 = "cab,dc,efd->cae"; +-static constexpr std::array shape260{2, 2, 2}; +-static constexpr std::array expected260{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +-static constexpr std::string_view equation261 = "cab,dc,efd->caf"; +-static constexpr std::array shape261{2, 2, 2}; +-static constexpr std::array expected261{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +-static constexpr std::string_view equation262 = "cab,dc,efd->cbe"; +-static constexpr std::array shape262{2, 2, 2}; +-static constexpr std::array expected262{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +-static constexpr std::string_view equation263 = "cab,dc,efd->cbf"; +-static constexpr std::array shape263{2, 2, 2}; +-static constexpr std::array expected263{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +-static constexpr std::string_view equation264 = "cab,dc,fde->caf"; +-static constexpr std::array shape264{2, 2, 2}; +-static constexpr std::array expected264{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +-static constexpr std::string_view equation265 = "cab,dc,fde->cad"; +-static constexpr std::array shape265{2, 2, 2}; +-static 
constexpr std::array expected265{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +-static constexpr std::string_view equation266 = "cab,dc,fde->cbf"; +-static constexpr std::array shape266{2, 2, 2}; +-static constexpr std::array expected266{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +-static constexpr std::string_view equation267 = "cab,dc,fde->cbd"; +-static constexpr std::array shape267{2, 2, 2}; +-static constexpr std::array expected267{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +-static constexpr std::string_view equation268 = "cab,dc,fed->caf"; +-static constexpr std::array shape268{2, 2, 2}; +-static constexpr std::array expected268{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +-static constexpr std::string_view equation269 = "cab,dc,fed->cae"; +-static constexpr std::array shape269{2, 2, 2}; +-static constexpr std::array expected269{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +-static constexpr std::string_view equation270 = "cab,dc,fed->cbf"; +-static constexpr std::array shape270{2, 2, 2}; +-static constexpr std::array expected270{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +-static constexpr std::string_view equation271 = "cab,dc,fed->cbe"; +-static constexpr std::array shape271{2, 2, 2}; +-static constexpr std::array expected271{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +-static constexpr std::string_view equation272 = "cba,cd,def->cbd"; +-static constexpr std::array shape272{2, 2, 2}; +-static constexpr std::array expected272{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +-static constexpr std::string_view equation273 = "cba,cd,def->cbe"; +-static constexpr std::array shape273{2, 2, 2}; +-static constexpr std::array expected273{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +-static constexpr std::string_view equation274 = "cba,cd,def->cad"; +-static constexpr std::array shape274{2, 2, 2}; +-static constexpr std::array expected274{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +-static constexpr std::string_view equation275 = "cba,cd,def->cae"; +-static constexpr std::array shape275{2, 2, 2}; +-static constexpr std::array expected275{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +-static constexpr std::string_view equation276 = "cba,cd,dfe->cbd"; +-static constexpr std::array shape276{2, 2, 2}; +-static constexpr std::array expected276{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +-static constexpr std::string_view equation277 = "cba,cd,dfe->cbf"; +-static constexpr std::array shape277{2, 2, 2}; +-static constexpr std::array expected277{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +-static constexpr std::string_view equation278 = "cba,cd,dfe->cad"; +-static constexpr std::array shape278{2, 2, 2}; +-static constexpr std::array expected278{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +-static constexpr std::string_view equation279 = "cba,cd,dfe->caf"; +-static constexpr std::array shape279{2, 2, 2}; +-static constexpr std::array expected279{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +-static constexpr std::string_view equation280 = "cba,cd,edf->cbe"; +-static constexpr std::array shape280{2, 2, 2}; +-static constexpr std::array expected280{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +-static constexpr std::string_view equation281 = "cba,cd,edf->cbd"; +-static constexpr std::array shape281{2, 2, 2}; +-static constexpr std::array expected281{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +-static constexpr std::string_view equation282 = "cba,cd,edf->cae"; 
+-static constexpr std::array shape282{2, 2, 2}; +-static constexpr std::array expected282{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +-static constexpr std::string_view equation283 = "cba,cd,edf->cad"; +-static constexpr std::array shape283{2, 2, 2}; +-static constexpr std::array expected283{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +-static constexpr std::string_view equation284 = "cba,cd,efd->cbe"; +-static constexpr std::array shape284{2, 2, 2}; +-static constexpr std::array expected284{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +-static constexpr std::string_view equation285 = "cba,cd,efd->cbf"; +-static constexpr std::array shape285{2, 2, 2}; +-static constexpr std::array expected285{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +-static constexpr std::string_view equation286 = "cba,cd,efd->cae"; +-static constexpr std::array shape286{2, 2, 2}; +-static constexpr std::array expected286{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +-static constexpr std::string_view equation287 = "cba,cd,efd->caf"; +-static constexpr std::array shape287{2, 2, 2}; +-static constexpr std::array expected287{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +-static constexpr std::string_view equation288 = "cba,cd,fde->cbf"; +-static constexpr std::array shape288{2, 2, 2}; +-static constexpr std::array expected288{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +-static constexpr std::string_view equation289 = "cba,cd,fde->cbd"; +-static constexpr std::array shape289{2, 2, 2}; +-static constexpr std::array expected289{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +-static constexpr std::string_view equation290 = "cba,cd,fde->caf"; +-static constexpr std::array shape290{2, 2, 2}; +-static constexpr std::array expected290{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +-static constexpr std::string_view equation291 = "cba,cd,fde->cad"; +-static constexpr std::array shape291{2, 2, 2}; +-static constexpr std::array expected291{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +-static constexpr std::string_view equation292 = "cba,cd,fed->cbf"; +-static constexpr std::array shape292{2, 2, 2}; +-static constexpr std::array expected292{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +-static constexpr std::string_view equation293 = "cba,cd,fed->cbe"; +-static constexpr std::array shape293{2, 2, 2}; +-static constexpr std::array expected293{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +-static constexpr std::string_view equation294 = "cba,cd,fed->caf"; +-static constexpr std::array shape294{2, 2, 2}; +-static constexpr std::array expected294{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +-static constexpr std::string_view equation295 = "cba,cd,fed->cae"; +-static constexpr std::array shape295{2, 2, 2}; +-static constexpr std::array expected295{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +-static constexpr std::string_view equation296 = "cba,dc,def->cbd"; +-static constexpr std::array shape296{2, 2, 2}; +-static constexpr std::array expected296{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +-static constexpr std::string_view equation297 = "cba,dc,def->cbe"; +-static constexpr std::array shape297{2, 2, 2}; +-static constexpr std::array expected297{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +-static constexpr std::string_view equation298 = "cba,dc,def->cad"; +-static constexpr std::array shape298{2, 2, 2}; +-static constexpr std::array expected298{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +-static constexpr 
std::string_view equation299 = "cba,dc,def->cae"; +-static constexpr std::array shape299{2, 2, 2}; +-static constexpr std::array expected299{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +-static constexpr std::string_view equation300 = "cba,dc,dfe->cbd"; +-static constexpr std::array shape300{2, 2, 2}; +-static constexpr std::array expected300{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +-static constexpr std::string_view equation301 = "cba,dc,dfe->cbf"; +-static constexpr std::array shape301{2, 2, 2}; +-static constexpr std::array expected301{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +-static constexpr std::string_view equation302 = "cba,dc,dfe->cad"; +-static constexpr std::array shape302{2, 2, 2}; +-static constexpr std::array expected302{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +-static constexpr std::string_view equation303 = "cba,dc,dfe->caf"; +-static constexpr std::array shape303{2, 2, 2}; +-static constexpr std::array expected303{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +-static constexpr std::string_view equation304 = "cba,dc,edf->cbe"; +-static constexpr std::array shape304{2, 2, 2}; +-static constexpr std::array expected304{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +-static constexpr std::string_view equation305 = "cba,dc,edf->cbd"; +-static constexpr std::array shape305{2, 2, 2}; +-static constexpr std::array expected305{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +-static constexpr std::string_view equation306 = "cba,dc,edf->cae"; +-static constexpr std::array shape306{2, 2, 2}; +-static constexpr std::array expected306{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +-static constexpr std::string_view equation307 = "cba,dc,edf->cad"; +-static constexpr std::array shape307{2, 2, 2}; +-static constexpr std::array expected307{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +-static constexpr std::string_view equation308 = "cba,dc,efd->cbe"; +-static constexpr std::array shape308{2, 2, 2}; +-static constexpr std::array expected308{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +-static constexpr std::string_view equation309 = "cba,dc,efd->cbf"; +-static constexpr std::array shape309{2, 2, 2}; +-static constexpr std::array expected309{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +-static constexpr std::string_view equation310 = "cba,dc,efd->cae"; +-static constexpr std::array shape310{2, 2, 2}; +-static constexpr std::array expected310{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +-static constexpr std::string_view equation311 = "cba,dc,efd->caf"; +-static constexpr std::array shape311{2, 2, 2}; +-static constexpr std::array expected311{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +-static constexpr std::string_view equation312 = "cba,dc,fde->cbf"; +-static constexpr std::array shape312{2, 2, 2}; +-static constexpr std::array expected312{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +-static constexpr std::string_view equation313 = "cba,dc,fde->cbd"; +-static constexpr std::array shape313{2, 2, 2}; +-static constexpr std::array expected313{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +-static constexpr std::string_view equation314 = "cba,dc,fde->caf"; +-static constexpr std::array shape314{2, 2, 2}; +-static constexpr std::array expected314{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +-static constexpr std::string_view equation315 = "cba,dc,fde->cad"; +-static constexpr std::array shape315{2, 2, 2}; +-static constexpr std::array expected315{0.f, 
72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +-static constexpr std::string_view equation316 = "cba,dc,fed->cbf"; +-static constexpr std::array shape316{2, 2, 2}; +-static constexpr std::array expected316{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +-static constexpr std::string_view equation317 = "cba,dc,fed->cbe"; +-static constexpr std::array shape317{2, 2, 2}; +-static constexpr std::array expected317{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +-static constexpr std::string_view equation318 = "cba,dc,fed->caf"; +-static constexpr std::array shape318{2, 2, 2}; +-static constexpr std::array expected318{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +-static constexpr std::string_view equation319 = "cba,dc,fed->cae"; +-static constexpr std::array shape319{2, 2, 2}; +-static constexpr std::array expected319{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +-static constexpr std::array case1 = {{{equation32, shape32, expected32}, +- {equation33, shape33, expected33}, +- {equation34, shape34, expected34}, +- {equation35, shape35, expected35}, +- {equation36, shape36, expected36}, +- {equation37, shape37, expected37}, +- {equation38, shape38, expected38}, +- {equation39, shape39, expected39}, +- {equation40, shape40, expected40}, +- {equation41, shape41, expected41}, +- {equation42, shape42, expected42}, +- {equation43, shape43, expected43}, +- {equation44, shape44, expected44}, +- {equation45, shape45, expected45}, +- {equation46, shape46, expected46}, +- {equation47, shape47, expected47}, +- {equation48, shape48, expected48}, +- {equation49, shape49, expected49}, +- {equation50, shape50, expected50}, +- {equation51, shape51, expected51}, +- {equation52, shape52, expected52}, +- {equation53, shape53, expected53}, +- {equation54, shape54, expected54}, +- {equation55, shape55, expected55}, +- {equation56, shape56, expected56}, +- {equation57, shape57, expected57}, +- {equation58, shape58, expected58}, +- {equation59, shape59, expected59}, +- {equation60, shape60, expected60}, +- {equation61, shape61, expected61}, +- {equation62, shape62, expected62}, +- {equation63, shape63, expected63}, +- {equation64, shape64, expected64}, +- {equation65, shape65, expected65}, +- {equation66, shape66, expected66}, +- {equation67, shape67, expected67}, +- {equation68, shape68, expected68}, +- {equation69, shape69, expected69}, +- {equation70, shape70, expected70}, +- {equation71, shape71, expected71}, +- {equation72, shape72, expected72}, +- {equation73, shape73, expected73}, +- {equation74, shape74, expected74}, +- {equation75, shape75, expected75}, +- {equation76, shape76, expected76}, +- {equation77, shape77, expected77}, +- {equation78, shape78, expected78}, +- {equation79, shape79, expected79}, +- {equation80, shape80, expected80}, +- {equation81, shape81, expected81}, +- {equation82, shape82, expected82}, +- {equation83, shape83, expected83}, +- {equation84, shape84, expected84}, +- {equation85, shape85, expected85}, +- {equation86, shape86, expected86}, +- {equation87, shape87, expected87}, +- {equation88, shape88, expected88}, +- {equation89, shape89, expected89}, +- {equation90, shape90, expected90}, +- {equation91, shape91, expected91}, +- {equation92, shape92, expected92}, +- {equation93, shape93, expected93}, +- {equation94, shape94, expected94}, +- {equation95, shape95, expected95}, +- {equation96, shape96, expected96}, +- {equation97, shape97, expected97}, +- {equation98, shape98, expected98}, +- {equation99, shape99, expected99}, +- {equation100, shape100, 
expected100}, +- {equation101, shape101, expected101}, +- {equation102, shape102, expected102}, +- {equation103, shape103, expected103}, +- {equation104, shape104, expected104}, +- {equation105, shape105, expected105}, +- {equation106, shape106, expected106}, +- {equation107, shape107, expected107}, +- {equation108, shape108, expected108}, +- {equation109, shape109, expected109}, +- {equation110, shape110, expected110}, +- {equation111, shape111, expected111}, +- {equation112, shape112, expected112}, +- {equation113, shape113, expected113}, +- {equation114, shape114, expected114}, +- {equation115, shape115, expected115}, +- {equation116, shape116, expected116}, +- {equation117, shape117, expected117}, +- {equation118, shape118, expected118}, +- {equation119, shape119, expected119}, +- {equation120, shape120, expected120}, +- {equation121, shape121, expected121}, +- {equation122, shape122, expected122}, +- {equation123, shape123, expected123}, +- {equation124, shape124, expected124}, +- {equation125, shape125, expected125}, +- {equation126, shape126, expected126}, +- {equation127, shape127, expected127}, +- {equation128, shape128, expected128}, +- {equation129, shape129, expected129}, +- {equation130, shape130, expected130}, +- {equation131, shape131, expected131}, +- {equation132, shape132, expected132}, +- {equation133, shape133, expected133}, +- {equation134, shape134, expected134}, +- {equation135, shape135, expected135}, +- {equation136, shape136, expected136}, +- {equation137, shape137, expected137}, +- {equation138, shape138, expected138}, +- {equation139, shape139, expected139}, +- {equation140, shape140, expected140}, +- {equation141, shape141, expected141}, +- {equation142, shape142, expected142}, +- {equation143, shape143, expected143}, +- {equation144, shape144, expected144}, +- {equation145, shape145, expected145}, +- {equation146, shape146, expected146}, +- {equation147, shape147, expected147}, +- {equation148, shape148, expected148}, +- {equation149, shape149, expected149}, +- {equation150, shape150, expected150}, +- {equation151, shape151, expected151}, +- {equation152, shape152, expected152}, +- {equation153, shape153, expected153}, +- {equation154, shape154, expected154}, +- {equation155, shape155, expected155}, +- {equation156, shape156, expected156}, +- {equation157, shape157, expected157}, +- {equation158, shape158, expected158}, +- {equation159, shape159, expected159}, +- {equation160, shape160, expected160}, +- {equation161, shape161, expected161}, +- {equation162, shape162, expected162}, +- {equation163, shape163, expected163}, +- {equation164, shape164, expected164}, +- {equation165, shape165, expected165}, +- {equation166, shape166, expected166}, +- {equation167, shape167, expected167}, +- {equation168, shape168, expected168}, +- {equation169, shape169, expected169}, +- {equation170, shape170, expected170}, +- {equation171, shape171, expected171}, +- {equation172, shape172, expected172}, +- {equation173, shape173, expected173}, +- {equation174, shape174, expected174}, +- {equation175, shape175, expected175}, +- {equation176, shape176, expected176}, +- {equation177, shape177, expected177}, +- {equation178, shape178, expected178}, +- {equation179, shape179, expected179}, +- {equation180, shape180, expected180}, +- {equation181, shape181, expected181}, +- {equation182, shape182, expected182}, +- {equation183, shape183, expected183}, +- {equation184, shape184, expected184}, +- {equation185, shape185, expected185}, +- {equation186, shape186, expected186}, +- 
{equation187, shape187, expected187}, +- {equation188, shape188, expected188}, +- {equation189, shape189, expected189}, +- {equation190, shape190, expected190}, +- {equation191, shape191, expected191}, +- {equation192, shape192, expected192}, +- {equation193, shape193, expected193}, +- {equation194, shape194, expected194}, +- {equation195, shape195, expected195}, +- {equation196, shape196, expected196}, +- {equation197, shape197, expected197}, +- {equation198, shape198, expected198}, +- {equation199, shape199, expected199}, +- {equation200, shape200, expected200}, +- {equation201, shape201, expected201}, +- {equation202, shape202, expected202}, +- {equation203, shape203, expected203}, +- {equation204, shape204, expected204}, +- {equation205, shape205, expected205}, +- {equation206, shape206, expected206}, +- {equation207, shape207, expected207}, +- {equation208, shape208, expected208}, +- {equation209, shape209, expected209}, +- {equation210, shape210, expected210}, +- {equation211, shape211, expected211}, +- {equation212, shape212, expected212}, +- {equation213, shape213, expected213}, +- {equation214, shape214, expected214}, +- {equation215, shape215, expected215}, +- {equation216, shape216, expected216}, +- {equation217, shape217, expected217}, +- {equation218, shape218, expected218}, +- {equation219, shape219, expected219}, +- {equation220, shape220, expected220}, +- {equation221, shape221, expected221}, +- {equation222, shape222, expected222}, +- {equation223, shape223, expected223}, +- {equation224, shape224, expected224}, +- {equation225, shape225, expected225}, +- {equation226, shape226, expected226}, +- {equation227, shape227, expected227}, +- {equation228, shape228, expected228}, +- {equation229, shape229, expected229}, +- {equation230, shape230, expected230}, +- {equation231, shape231, expected231}, +- {equation232, shape232, expected232}, +- {equation233, shape233, expected233}, +- {equation234, shape234, expected234}, +- {equation235, shape235, expected235}, +- {equation236, shape236, expected236}, +- {equation237, shape237, expected237}, +- {equation238, shape238, expected238}, +- {equation239, shape239, expected239}, +- {equation240, shape240, expected240}, +- {equation241, shape241, expected241}, +- {equation242, shape242, expected242}, +- {equation243, shape243, expected243}, +- {equation244, shape244, expected244}, +- {equation245, shape245, expected245}, +- {equation246, shape246, expected246}, +- {equation247, shape247, expected247}, +- {equation248, shape248, expected248}, +- {equation249, shape249, expected249}, +- {equation250, shape250, expected250}, +- {equation251, shape251, expected251}, +- {equation252, shape252, expected252}, +- {equation253, shape253, expected253}, +- {equation254, shape254, expected254}, +- {equation255, shape255, expected255}, +- {equation256, shape256, expected256}, +- {equation257, shape257, expected257}, +- {equation258, shape258, expected258}, +- {equation259, shape259, expected259}, +- {equation260, shape260, expected260}, +- {equation261, shape261, expected261}, +- {equation262, shape262, expected262}, +- {equation263, shape263, expected263}, +- {equation264, shape264, expected264}, +- {equation265, shape265, expected265}, +- {equation266, shape266, expected266}, +- {equation267, shape267, expected267}, +- {equation268, shape268, expected268}, +- {equation269, shape269, expected269}, +- {equation270, shape270, expected270}, +- {equation271, shape271, expected271}, +- {equation272, shape272, expected272}, +- {equation273, shape273, 
expected273}, +- {equation274, shape274, expected274}, +- {equation275, shape275, expected275}, +- {equation276, shape276, expected276}, +- {equation277, shape277, expected277}, +- {equation278, shape278, expected278}, +- {equation279, shape279, expected279}, +- {equation280, shape280, expected280}, +- {equation281, shape281, expected281}, +- {equation282, shape282, expected282}, +- {equation283, shape283, expected283}, +- {equation284, shape284, expected284}, +- {equation285, shape285, expected285}, +- {equation286, shape286, expected286}, +- {equation287, shape287, expected287}, +- {equation288, shape288, expected288}, +- {equation289, shape289, expected289}, +- {equation290, shape290, expected290}, +- {equation291, shape291, expected291}, +- {equation292, shape292, expected292}, +- {equation293, shape293, expected293}, +- {equation294, shape294, expected294}, +- {equation295, shape295, expected295}, +- {equation296, shape296, expected296}, +- {equation297, shape297, expected297}, +- {equation298, shape298, expected298}, +- {equation299, shape299, expected299}, +- {equation300, shape300, expected300}, +- {equation301, shape301, expected301}, +- {equation302, shape302, expected302}, +- {equation303, shape303, expected303}, +- {equation304, shape304, expected304}, +- {equation305, shape305, expected305}, +- {equation306, shape306, expected306}, +- {equation307, shape307, expected307}, +- {equation308, shape308, expected308}, +- {equation309, shape309, expected309}, +- {equation310, shape310, expected310}, +- {equation311, shape311, expected311}, +- {equation312, shape312, expected312}, +- {equation313, shape313, expected313}, +- {equation314, shape314, expected314}, +- {equation315, shape315, expected315}, +- {equation316, shape316, expected316}, +- {equation317, shape317, expected317}, +- {equation318, shape318, expected318}, +- {equation319, shape319, expected319}}}; + + TEST(Einsum, EinsumTransposeMatMulTwoInputsTestSuite) { ++ std::vector test_cases{ ++ EinsumTestCase("abc,cd->abc", std::vector{2, 2, 2}, std::vector{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}), ++ EinsumTestCase("abc,cd->abd", std::vector{2, 2, 2}, std::vector{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}), ++ EinsumTestCase("abc,cd->acd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}), ++ EinsumTestCase("abc,dc->abd", std::vector{2, 2, 2}, std::vector{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}), ++ EinsumTestCase("abc,dc->abc", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}), ++ EinsumTestCase("abc,dc->acd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}), ++ EinsumTestCase("acb,cd->acd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}), ++ EinsumTestCase("acb,cd->abc", std::vector{2, 2, 2}, std::vector{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}), ++ EinsumTestCase("acb,cd->abd", std::vector{2, 2, 2}, std::vector{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}), ++ EinsumTestCase("acb,dc->acd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}), ++ EinsumTestCase("acb,dc->abd", std::vector{2, 2, 2}, std::vector{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}), ++ EinsumTestCase("acb,dc->abc", std::vector{2, 2, 2}, std::vector{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}), ++ EinsumTestCase("bac,cd->bac", std::vector{2, 2, 2}, std::vector{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}), ++ EinsumTestCase("bac,cd->bad", std::vector{2, 2, 2}, std::vector{2.f, 3.f, 6.f, 11.f, 10.f, 
19.f, 14.f, 27.f}),
++      EinsumTestCase("bac,cd->bcd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}),
++      EinsumTestCase("bac,dc->bad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}),
++      EinsumTestCase("bac,dc->bac", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}),
++      EinsumTestCase("bac,dc->bcd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}),
++      EinsumTestCase("bca,cd->bcd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}),
++      EinsumTestCase("bca,cd->bac", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}),
++      EinsumTestCase("bca,cd->bad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}),
++      EinsumTestCase("bca,dc->bcd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}),
++      EinsumTestCase("bca,dc->bad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}),
++      EinsumTestCase("bca,dc->bac", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}),
++      EinsumTestCase("cab,cd->cad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}),
++      EinsumTestCase("cab,cd->cbd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}),
++      EinsumTestCase("cab,dc->cad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}),
++      EinsumTestCase("cab,dc->cbd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}),
++      EinsumTestCase("cba,cd->cbd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}),
++      EinsumTestCase("cba,cd->cad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}),
++      EinsumTestCase("cba,dc->cbd", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}),
++      EinsumTestCase("cba,dc->cad", std::vector<int64_t>{2, 2, 2}, std::vector<float>{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f})};
++
+   std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
+   std::vector m2{0.f, 1.f, 2.f, 3.f};
+-  for (const auto& tst : case0) {
++  for (const auto& tst : test_cases) {
+     OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
+-    std::string s(tst.equation);
+-    test.AddAttribute("equation", s);
++    test.AddAttribute("equation", tst.equation);
+     test.AddInput("x", {2, 2, 2}, m1);
+     test.AddInput("y", {2, 2}, m2);
+-
+
+-    std::vector v1(tst.shape.begin(), tst.shape.end());
+-    std::vector v2(tst.expected.begin(), tst.expected.end());
+-    test.AddOutput("o", v1, v2);
++    test.AddOutput("o", tst.shape, tst.expected);
+     test.Run();
+   }
+ }
+
+-class EinsumTransposeMatMulThreeInputsTest : public testing::TestWithParam<EinsumTestCase> {
+-};
+-
+-TEST_P(EinsumTransposeMatMulThreeInputsTest, EinsumTransposeMatMulThreeInputsTestSuite) {
+-  const auto& tst = GetParam();
+-  std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
+-  std::vector m2{0.f, 1.f, 2.f, 3.f};
+-  std::vector m3{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f};
+-  OpTester test("Einsum", 12, onnxruntime::kOnnxDomain);
+-  std::string s(tst.equation);
+-  test.AddAttribute("equation", s);
+-  test.AddInput("x", {2, 2, 2}, m1);
+-  test.AddInput("y", {2, 2}, m2);
+-  test.AddInput("z", {2, 2, 2}, m3);
+-  std::vector v1(tst.shape.begin(), tst.shape.end());
+-  std::vector v2(tst.expected.begin(), tst.expected.end());
+-  test.AddOutput("o", v1, v2);
+-  test.Run();
+-}
++TEST(Einsum, EinsumTransposeMatMulThreeInputsTestSuite) {
++  std::vector<EinsumTestCase>
test_cases_set_1{ ++ EinsumTestCase("abc,cd,def->abd", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), ++ EinsumTestCase("abc,cd,def->abe", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), ++ EinsumTestCase("abc,cd,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), ++ EinsumTestCase("abc,cd,def->ace", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), ++ EinsumTestCase("abc,cd,dfe->abd", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), ++ EinsumTestCase("abc,cd,dfe->abf", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), ++ EinsumTestCase("abc,cd,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), ++ EinsumTestCase("abc,cd,dfe->acf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), ++ EinsumTestCase("abc,cd,edf->abe", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), ++ EinsumTestCase("abc,cd,edf->abd", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), ++ EinsumTestCase("abc,cd,edf->ace", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), ++ EinsumTestCase("abc,cd,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), ++ EinsumTestCase("abc,cd,efd->abe", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), ++ EinsumTestCase("abc,cd,efd->abf", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), ++ EinsumTestCase("abc,cd,efd->ace", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), ++ EinsumTestCase("abc,cd,efd->acf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), ++ EinsumTestCase("abc,cd,fde->abf", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), ++ EinsumTestCase("abc,cd,fde->abd", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), ++ EinsumTestCase("abc,cd,fde->acf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), ++ EinsumTestCase("abc,cd,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), ++ EinsumTestCase("abc,cd,fed->abf", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), ++ EinsumTestCase("abc,cd,fed->abe", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), ++ EinsumTestCase("abc,cd,fed->acf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), ++ EinsumTestCase("abc,cd,fed->ace", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), ++ EinsumTestCase("abc,dc,def->abd", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), ++ EinsumTestCase("abc,dc,def->abe", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), ++ EinsumTestCase("abc,dc,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), ++ EinsumTestCase("abc,dc,def->ace", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), ++ 
EinsumTestCase("abc,dc,dfe->abd", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), ++ EinsumTestCase("abc,dc,dfe->abf", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), ++ EinsumTestCase("abc,dc,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), ++ EinsumTestCase("abc,dc,dfe->acf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), ++ EinsumTestCase("abc,dc,edf->abe", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), ++ EinsumTestCase("abc,dc,edf->abd", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), ++ EinsumTestCase("abc,dc,edf->ace", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), ++ EinsumTestCase("abc,dc,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), ++ EinsumTestCase("abc,dc,efd->abe", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), ++ EinsumTestCase("abc,dc,efd->abf", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), ++ EinsumTestCase("abc,dc,efd->ace", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), ++ EinsumTestCase("abc,dc,efd->acf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), ++ EinsumTestCase("abc,dc,fde->abf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), ++ EinsumTestCase("abc,dc,fde->abd", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), ++ EinsumTestCase("abc,dc,fde->acf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), ++ EinsumTestCase("abc,dc,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), ++ EinsumTestCase("abc,dc,fed->abf", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), ++ EinsumTestCase("abc,dc,fed->abe", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), ++ EinsumTestCase("abc,dc,fed->acf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), ++ EinsumTestCase("abc,dc,fed->ace", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), ++ EinsumTestCase("acb,cd,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), ++ EinsumTestCase("acb,cd,def->ace", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), ++ EinsumTestCase("acb,cd,def->abd", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), ++ EinsumTestCase("acb,cd,def->abe", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), ++ EinsumTestCase("acb,cd,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), ++ EinsumTestCase("acb,cd,dfe->acf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), ++ EinsumTestCase("acb,cd,dfe->abd", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), ++ EinsumTestCase("acb,cd,dfe->abf", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), ++ EinsumTestCase("acb,cd,edf->ace", std::vector{2, 2, 
2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), ++ EinsumTestCase("acb,cd,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), ++ EinsumTestCase("acb,cd,edf->abe", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), ++ EinsumTestCase("acb,cd,edf->abd", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), ++ EinsumTestCase("acb,cd,efd->ace", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), ++ EinsumTestCase("acb,cd,efd->acf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), ++ EinsumTestCase("acb,cd,efd->abe", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), ++ EinsumTestCase("acb,cd,efd->abf", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), ++ EinsumTestCase("acb,cd,fde->acf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), ++ EinsumTestCase("acb,cd,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), ++ EinsumTestCase("acb,cd,fde->abf", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), ++ EinsumTestCase("acb,cd,fde->abd", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), ++ EinsumTestCase("acb,cd,fed->acf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), ++ EinsumTestCase("acb,cd,fed->ace", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), ++ EinsumTestCase("acb,cd,fed->abf", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), ++ EinsumTestCase("acb,cd,fed->abe", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), ++ EinsumTestCase("acb,dc,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), ++ EinsumTestCase("acb,dc,def->ace", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f})}; ++ ++ std::vector test_cases_set_2{ ++ EinsumTestCase("acb,dc,def->abd", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), ++ EinsumTestCase("acb,dc,def->abe", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), ++ EinsumTestCase("acb,dc,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), ++ EinsumTestCase("acb,dc,dfe->acf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), ++ EinsumTestCase("acb,dc,dfe->abd", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), ++ EinsumTestCase("acb,dc,dfe->abf", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), ++ EinsumTestCase("acb,dc,edf->ace", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), ++ EinsumTestCase("acb,dc,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), ++ EinsumTestCase("acb,dc,edf->abe", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), ++ EinsumTestCase("acb,dc,edf->abd", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), ++ EinsumTestCase("acb,dc,efd->ace", std::vector{2, 2, 2}, 
std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), ++ EinsumTestCase("acb,dc,efd->acf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), ++ EinsumTestCase("acb,dc,efd->abe", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), ++ EinsumTestCase("acb,dc,efd->abf", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), ++ EinsumTestCase("acb,dc,fde->acf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), ++ EinsumTestCase("acb,dc,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), ++ EinsumTestCase("acb,dc,fde->abf", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), ++ EinsumTestCase("acb,dc,fde->abd", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), ++ EinsumTestCase("acb,dc,fed->acf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), ++ EinsumTestCase("acb,dc,fed->ace", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), ++ EinsumTestCase("acb,dc,fed->abf", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), ++ EinsumTestCase("acb,dc,fed->abe", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), ++ EinsumTestCase("bac,cd,def->bad", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), ++ EinsumTestCase("bac,cd,def->bae", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), ++ EinsumTestCase("bac,cd,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), ++ EinsumTestCase("bac,cd,def->bce", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), ++ EinsumTestCase("bac,cd,dfe->bad", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), ++ EinsumTestCase("bac,cd,dfe->baf", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), ++ EinsumTestCase("bac,cd,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), ++ EinsumTestCase("bac,cd,dfe->bcf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), ++ EinsumTestCase("bac,cd,edf->bae", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), ++ EinsumTestCase("bac,cd,edf->bad", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), ++ EinsumTestCase("bac,cd,edf->bce", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), ++ EinsumTestCase("bac,cd,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), ++ EinsumTestCase("bac,cd,efd->bae", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), ++ EinsumTestCase("bac,cd,efd->baf", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), ++ EinsumTestCase("bac,cd,efd->bce", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), ++ EinsumTestCase("bac,cd,efd->bcf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), ++ EinsumTestCase("bac,cd,fde->baf", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 
337.f, 149.f, 477.f}), ++ EinsumTestCase("bac,cd,fde->bad", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), ++ EinsumTestCase("bac,cd,fde->bcf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), ++ EinsumTestCase("bac,cd,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), ++ EinsumTestCase("bac,cd,fed->baf", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), ++ EinsumTestCase("bac,cd,fed->bae", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), ++ EinsumTestCase("bac,cd,fed->bcf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), ++ EinsumTestCase("bac,cd,fed->bce", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), ++ EinsumTestCase("bac,dc,def->bad", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), ++ EinsumTestCase("bac,dc,def->bae", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), ++ EinsumTestCase("bac,dc,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), ++ EinsumTestCase("bac,dc,def->bce", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), ++ EinsumTestCase("bac,dc,dfe->bad", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), ++ EinsumTestCase("bac,dc,dfe->baf", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), ++ EinsumTestCase("bac,dc,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), ++ EinsumTestCase("bac,dc,dfe->bcf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), ++ EinsumTestCase("bac,dc,edf->bae", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), ++ EinsumTestCase("bac,dc,edf->bad", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), ++ EinsumTestCase("bac,dc,edf->bce", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), ++ EinsumTestCase("bac,dc,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), ++ EinsumTestCase("bac,dc,efd->bae", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), ++ EinsumTestCase("bac,dc,efd->baf", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), ++ EinsumTestCase("bac,dc,efd->bce", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), ++ EinsumTestCase("bac,dc,efd->bcf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), ++ EinsumTestCase("bac,dc,fde->baf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), ++ EinsumTestCase("bac,dc,fde->bad", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), ++ EinsumTestCase("bac,dc,fde->bcf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), ++ EinsumTestCase("bac,dc,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), ++ EinsumTestCase("bac,dc,fed->baf", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), ++ 
EinsumTestCase("bac,dc,fed->bae", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), ++ EinsumTestCase("bac,dc,fed->bcf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), ++ EinsumTestCase("bac,dc,fed->bce", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), ++ EinsumTestCase("bca,cd,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), ++ EinsumTestCase("bca,cd,def->bce", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), ++ EinsumTestCase("bca,cd,def->bad", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), ++ EinsumTestCase("bca,cd,def->bae", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), ++ EinsumTestCase("bca,cd,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), ++ EinsumTestCase("bca,cd,dfe->bcf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), ++ EinsumTestCase("bca,cd,dfe->bad", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), ++ EinsumTestCase("bca,cd,dfe->baf", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), ++ EinsumTestCase("bca,cd,edf->bce", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), ++ EinsumTestCase("bca,cd,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), ++ EinsumTestCase("bca,cd,edf->bae", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), ++ EinsumTestCase("bca,cd,edf->bad", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), ++ EinsumTestCase("bca,cd,efd->bce", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), ++ EinsumTestCase("bca,cd,efd->bcf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), ++ EinsumTestCase("bca,cd,efd->bae", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), ++ EinsumTestCase("bca,cd,efd->baf", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), ++ EinsumTestCase("bca,cd,fde->bcf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), ++ EinsumTestCase("bca,cd,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), ++ EinsumTestCase("bca,cd,fde->baf", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), ++ EinsumTestCase("bca,cd,fde->bad", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), ++ EinsumTestCase("bca,cd,fed->bcf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), ++ EinsumTestCase("bca,cd,fed->bce", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), ++ EinsumTestCase("bca,cd,fed->baf", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), ++ EinsumTestCase("bca,cd,fed->bae", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), ++ EinsumTestCase("bca,dc,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), ++ EinsumTestCase("bca,dc,def->bce", std::vector{2, 2, 
2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), ++ EinsumTestCase("bca,dc,def->bad", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), ++ EinsumTestCase("bca,dc,def->bae", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), ++ EinsumTestCase("bca,dc,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), ++ EinsumTestCase("bca,dc,dfe->bcf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), ++ EinsumTestCase("bca,dc,dfe->bad", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), ++ EinsumTestCase("bca,dc,dfe->baf", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), ++ EinsumTestCase("bca,dc,edf->bce", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), ++ EinsumTestCase("bca,dc,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), ++ EinsumTestCase("bca,dc,edf->bae", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), ++ EinsumTestCase("bca,dc,edf->bad", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), ++ EinsumTestCase("bca,dc,efd->bce", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), ++ EinsumTestCase("bca,dc,efd->bcf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), ++ EinsumTestCase("bca,dc,efd->bae", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), ++ EinsumTestCase("bca,dc,efd->baf", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), ++ EinsumTestCase("bca,dc,fde->bcf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), ++ EinsumTestCase("bca,dc,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), ++ EinsumTestCase("bca,dc,fde->baf", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), ++ EinsumTestCase("bca,dc,fde->bad", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), ++ EinsumTestCase("bca,dc,fed->bcf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), ++ EinsumTestCase("bca,dc,fed->bce", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), ++ EinsumTestCase("bca,dc,fed->baf", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), ++ EinsumTestCase("bca,dc,fed->bae", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), ++ EinsumTestCase("cab,cd,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), ++ EinsumTestCase("cab,cd,def->cae", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), ++ EinsumTestCase("cab,cd,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), ++ EinsumTestCase("cab,cd,def->cbe", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), ++ EinsumTestCase("cab,cd,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), ++ EinsumTestCase("cab,cd,dfe->caf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 
441.f, 377.f, 637.f}), ++ EinsumTestCase("cab,cd,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), ++ EinsumTestCase("cab,cd,dfe->cbf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), ++ EinsumTestCase("cab,cd,edf->cae", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), ++ EinsumTestCase("cab,cd,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), ++ EinsumTestCase("cab,cd,edf->cbe", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), ++ EinsumTestCase("cab,cd,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), ++ EinsumTestCase("cab,cd,efd->cae", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), ++ EinsumTestCase("cab,cd,efd->caf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), ++ EinsumTestCase("cab,cd,efd->cbe", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), ++ EinsumTestCase("cab,cd,efd->cbf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), ++ EinsumTestCase("cab,cd,fde->caf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), ++ EinsumTestCase("cab,cd,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), ++ EinsumTestCase("cab,cd,fde->cbf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), ++ EinsumTestCase("cab,cd,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), ++ EinsumTestCase("cab,cd,fed->caf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), ++ EinsumTestCase("cab,cd,fed->cae", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), ++ EinsumTestCase("cab,cd,fed->cbf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), ++ EinsumTestCase("cab,cd,fed->cbe", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), ++ EinsumTestCase("cab,dc,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), ++ EinsumTestCase("cab,dc,def->cae", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), ++ EinsumTestCase("cab,dc,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f})}; ++ ++ std::vector test_cases_set_3{ ++ EinsumTestCase("cab,dc,def->cbe", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), ++ EinsumTestCase("cab,dc,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), ++ EinsumTestCase("cab,dc,dfe->caf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), ++ EinsumTestCase("cab,dc,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), ++ EinsumTestCase("cab,dc,dfe->cbf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), ++ EinsumTestCase("cab,dc,edf->cae", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), ++ EinsumTestCase("cab,dc,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), ++ 
EinsumTestCase("cab,dc,edf->cbe", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), ++ EinsumTestCase("cab,dc,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), ++ EinsumTestCase("cab,dc,efd->cae", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), ++ EinsumTestCase("cab,dc,efd->caf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), ++ EinsumTestCase("cab,dc,efd->cbe", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), ++ EinsumTestCase("cab,dc,efd->cbf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), ++ EinsumTestCase("cab,dc,fde->caf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), ++ EinsumTestCase("cab,dc,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), ++ EinsumTestCase("cab,dc,fde->cbf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), ++ EinsumTestCase("cab,dc,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), ++ EinsumTestCase("cab,dc,fed->caf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), ++ EinsumTestCase("cab,dc,fed->cae", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), ++ EinsumTestCase("cab,dc,fed->cbf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), ++ EinsumTestCase("cab,dc,fed->cbe", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), ++ EinsumTestCase("cba,cd,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), ++ EinsumTestCase("cba,cd,def->cbe", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), ++ EinsumTestCase("cba,cd,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), ++ EinsumTestCase("cba,cd,def->cae", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), ++ EinsumTestCase("cba,cd,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), ++ EinsumTestCase("cba,cd,dfe->cbf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), ++ EinsumTestCase("cba,cd,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), ++ EinsumTestCase("cba,cd,dfe->caf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), ++ EinsumTestCase("cba,cd,edf->cbe", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), ++ EinsumTestCase("cba,cd,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), ++ EinsumTestCase("cba,cd,edf->cae", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), ++ EinsumTestCase("cba,cd,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), ++ EinsumTestCase("cba,cd,efd->cbe", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), ++ EinsumTestCase("cba,cd,efd->cbf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), ++ EinsumTestCase("cba,cd,efd->cae", std::vector{2, 2, 2}, 
std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), ++ EinsumTestCase("cba,cd,efd->caf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), ++ EinsumTestCase("cba,cd,fde->cbf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), ++ EinsumTestCase("cba,cd,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), ++ EinsumTestCase("cba,cd,fde->caf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), ++ EinsumTestCase("cba,cd,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), ++ EinsumTestCase("cba,cd,fed->cbf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), ++ EinsumTestCase("cba,cd,fed->cbe", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), ++ EinsumTestCase("cba,cd,fed->caf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), ++ EinsumTestCase("cba,cd,fed->cae", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), ++ EinsumTestCase("cba,dc,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), ++ EinsumTestCase("cba,dc,def->cbe", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), ++ EinsumTestCase("cba,dc,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), ++ EinsumTestCase("cba,dc,def->cae", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), ++ EinsumTestCase("cba,dc,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), ++ EinsumTestCase("cba,dc,dfe->cbf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), ++ EinsumTestCase("cba,dc,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), ++ EinsumTestCase("cba,dc,dfe->caf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), ++ EinsumTestCase("cba,dc,edf->cbe", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), ++ EinsumTestCase("cba,dc,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), ++ EinsumTestCase("cba,dc,edf->cae", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), ++ EinsumTestCase("cba,dc,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), ++ EinsumTestCase("cba,dc,efd->cbe", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), ++ EinsumTestCase("cba,dc,efd->cbf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), ++ EinsumTestCase("cba,dc,efd->cae", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), ++ EinsumTestCase("cba,dc,efd->caf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), ++ EinsumTestCase("cba,dc,fde->cbf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), ++ EinsumTestCase("cba,dc,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), ++ EinsumTestCase("cba,dc,fde->caf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), ++ 
EinsumTestCase("cba,dc,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), ++ EinsumTestCase("cba,dc,fed->cbf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), ++ EinsumTestCase("cba,dc,fed->cbe", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), ++ EinsumTestCase("cba,dc,fed->caf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), ++ EinsumTestCase("cba,dc,fed->cae", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f})}; ++ ++ auto test_lambda = [](const std::vector& test_cases_set) { ++ std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; ++ std::vector m2{0.f, 1.f, 2.f, 3.f}; ++ std::vector m3{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; ++ for (const auto& tst : test_cases_set) { ++ OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); ++ test.AddAttribute("equation", tst.equation); ++ test.AddInput("x", {2, 2, 2}, m1); ++ test.AddInput("y", {2, 2}, m2); ++ test.AddInput("z", {2, 2, 2}, m3); ++ test.AddOutput("o", tst.shape, tst.expected); ++ test.Run(); ++ } ++ }; ++ ++ test_lambda(test_cases_set_1); ++ test_lambda(test_cases_set_2); ++ test_lambda(test_cases_set_3); + +-INSTANTIATE_TEST_SUITE_P(EinsumTransposeMatMulThreeInputsTests, EinsumTransposeMatMulThreeInputsTest, testing::ValuesIn(case1)); ++} // namespace test + + } // namespace test + } // namespace onnxruntime +diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc +index 8128c170c..859e08271 100644 +--- a/onnxruntime/test/providers/cpu/model_tests.cc ++++ b/onnxruntime/test/providers/cpu/model_tests.cc +@@ -39,8 +39,6 @@ + #include "core/providers/armnn/armnn_provider_factory.h" + #endif + +-#include "test/common/cuda_op_test_utils.h" +- + // test infrastructure + #include "test/onnx/testenv.h" + #include "test/onnx/TestCase.h" +@@ -96,21 +94,6 @@ TEST_P(ModelTest, Run) { + + std::unique_ptr model_info = std::make_unique(model_path.c_str()); + +-#if defined(__linux__) +- // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. +- if (HasCudaEnvironment(800) && provider_name == "cuda") { +- per_sample_tolerance = 1e-1; +- if (model_path.find(ORT_TSTR("SSD")) > 0 || +- model_path.find(ORT_TSTR("ssd")) > 0 || +- model_path.find(ORT_TSTR("yolov3")) > 0 || +- model_path.find(ORT_TSTR("mask_rcnn")) > 0 || +- model_path.find(ORT_TSTR("FNS")) > 0) { +- SkipTest("Skipping SSD test for big tolearance failure or other errors"); +- return; +- } +- } +-#endif +- + if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || + model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { + SkipTest("it has the training domain. 
No pipeline should need to run these tests."); +diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +index 0c8d6c46d..026bb07ed 100644 +--- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc ++++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +@@ -34,6 +34,11 @@ TEST(DequantizeLinearOpTest, Int8) { + + // scalar zero & scale with int8 + TEST(DequantizeLinearOpTest, Int32) { ++ // TODO: Unskip when fixed #41968513 ++ if (DefaultDmlExecutionProvider().get() != nullptr) { ++ GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; ++ } ++ + OpTester test("DequantizeLinear", 10); + std::vector dims{4}; + test.AddInput("x", dims, {-30, -3, 100, 127}); +@@ -93,6 +98,11 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) { + + // dequantize without zero point + TEST(DequantizeLinearOpTest, Without_Zero_Point) { ++ // TODO: Unskip when fixed #41968513 ++ if (DefaultDmlExecutionProvider().get() != nullptr) { ++ GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; ++ } ++ + OpTester test("DequantizeLinear", 10); + test.AddInput("x", {}, {100}); + test.AddInput("x_scale", {}, {2.0f}); +diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +index 10f02349a..f473c98ca 100644 +--- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc ++++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +@@ -1870,8 +1870,6 @@ void TestAntialiasing(std::map attributes, + test.AddAttribute("extrapolation_value", std::stof(v)); + } else if (k == "roi") { + roi = parse_attr(v, 0.0f); +- } else if (k == "antialias") { +- test.AddAttribute("antialias", std::stoll(v)); + } else { + throw std::invalid_argument("Unknown attribute"); + } +@@ -1896,9 +1894,6 @@ void TestAntialiasing(std::map attributes, + } + + TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { +- if (DefaultDmlExecutionProvider().get() != nullptr) { +- GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; +- } + std::vector X(16); + std::iota(X.begin(), X.end(), 1.f); + +@@ -1917,6 +1912,7 @@ TEST(ResizeOpTest, Antialias_Bilinear_ExcludeOutside) { + 12.1f, 13.3f, 14.5f}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + } ++ + TEST(ResizeOpTest, Antialias_Bilinear_Scale_Is_All_1) { + std::vector X(3 * 4 * 5 * 6); + std::iota(X.begin(), X.end(), 1.f); +@@ -2013,9 +2009,6 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { + } + + TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) { +- if (DefaultDmlExecutionProvider().get() != nullptr) { +- GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; +- } + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {5.7272725f, 6.9545455f, 8.181818f, 10.636364f, 11.863636f, +@@ -2037,9 +2030,6 @@ TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) { + } + + TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { +- if (DefaultDmlExecutionProvider().get() != nullptr) { +- GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; +- } + std::vector X(16 * 4 * 4); + std::iota(X.begin(), X.end(), 0.f); + { +@@ 
-2128,9 +2118,6 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { + } + + TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { +- if (DefaultDmlExecutionProvider().get() != nullptr) { +- GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; +- } + std::vector X(256); + std::iota(X.begin(), X.end(), 0.0f); + +@@ -2244,87 +2231,5 @@ TEST(ResizeOpTest, Antialias_Use_Extrapolation) { + }, + {4, 4, 4}, X, {3, 3, 3}, Y); + } +- +-TEST(ResizeOpTest, Antialias_Large_half_pixel) { +- std::vector X{0.f, 1.f, 2.f, 3.f, 4.f, 5.f}; +- std::vector Y = {1.f, 4.f}; +- std::vector roi{}; +- std::vector scales{}; +- std::vector output_shape{1, 1, 2, 1}; +- +- OpTester test("Resize", 18); +- +- test.AddAttribute("exclude_outside", 0LL); +- test.AddAttribute("antialias", 1LL); +- test.AddAttribute("mode", "linear"); +- +- test.AddInput("X", {1, 1, 6, 1}, X); +- test.AddInput("roi", {int64_t(roi.size())}, roi); +- test.AddInput("", {0}, scales); +- test.AddInput("sizes", {4}, output_shape); +- +- // Have absolute tolerance because ort is slightly different results. +- // DML implementation is equivalent to resize with variable input window size while ORT using a convolution approach. +- // Absolute error is for ORT CPU. +- test.AddOutput("Y", output_shape, Y, false, /*rel_error*/ 0.0f, /*abs_error*/ 0.12f); +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); +-} +- +-// Test without anti-aliasing for better comparison with DirectML +-TEST(ResizeOpTest, Axes_and_Scale_18) { +- std::vector X(16 * 4); +- std::iota(X.begin(), X.end(), 0.f); +- std::vector Y = {3.5f, 4.8333335f, 6.1666665f, 8.833333f, 10.166667f, 11.5f, 14.166667f, +- 15.5f, 16.833334f, 24.833334f, 26.166666f, 27.5f, 30.166666f, 31.5f, +- 32.833332f, 35.5f, 36.833332f, 38.166668f, 46.166668f, 47.5f, 48.833332f, +- 51.5f, 52.833332f, 54.166668f, 56.833332f, 58.166668f, 59.5}; +- std::vector roi{}; +- std::vector scales{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}; +- std::vector output_shape{1, 1, 3, 3, 3}; +- std::vector axes{2, 3, 4}; +- +- OpTester test("Resize", 18); +- +- test.AddAttribute("exclude_outside", 0LL); +- test.AddAttribute>("axes", axes); +- test.AddAttribute("antialias", 0LL); +- test.AddAttribute("mode", "linear"); +- +- test.AddInput("X", {1, 1, 4, 4, 4}, X); +- test.AddInput("roi", {int64_t(roi.size())}, roi); +- test.AddInput("scales", {int64_t(scales.size())}, scales, true); +- +- test.AddOutput("Y", output_shape, Y); +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); +-} +- +-TEST(ResizeOpTest, Axes_and_Size_18) { +- std::vector X(16 * 4); +- std::iota(X.begin(), X.end(), 0.f); +- std::vector Y = {3.5f, 4.8333335f, 6.1666665f, 8.833333f, 10.166667f, 11.5f, 14.166667f, +- 15.5f, 16.833334f, 24.833334f, 26.166666f, 27.5f, 30.166666f, 31.5f, +- 32.833332f, 35.5f, 36.833332f, 38.166668f, 46.166668f, 47.5f, 48.833332f, +- 51.5f, 52.833332f, 54.166668f, 56.833332f, 58.166668f, 59.5}; +- std::vector roi{}; +- std::vector scales{}; +- std::vector output_shape{1, 1, 3, 3, 3}; +- std::vector axes{2, 3, 4}; +- +- OpTester test("Resize", 18); +- +- test.AddAttribute("exclude_outside", 0LL); +- test.AddAttribute>("axes", axes); +- test.AddAttribute("antialias", 0LL); +- test.AddAttribute("mode", "linear"); +- +- test.AddInput("X", {1, 1, 4, 4, 4}, X); +- test.AddInput("roi", {int64_t(roi.size())}, roi); +- test.AddInput("", {0}, scales); +- 
test.AddInput("sizes", {3}, {3, 3, 3}); +- +- test.AddOutput("Y", output_shape, Y); +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); +-} +- + } // namespace test + } // namespace onnxruntime +diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +index 30e27bb15..9b44bf400 100644 +--- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc ++++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +@@ -302,137 +302,5 @@ TEST(Scatter, BoolInputWithAxis) { + scatter_bool_with_axis_tests("ScatterElements", 11); + } + +-TEST(ScatterElements, AddReduction) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "add"); +- +- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); +- test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f}); +- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, AddReductionAxis1) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 1); +- test.AddAttribute("reduction", "add"); +- +- // update's slice shape is {2, 1} +- test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); +- test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 4}, {2.f, 5.f, 3.f, 6.f, 7.f, 9.f, 8.f, 10.f}); +- test.AddOutput("y", {2, 3}, {9.f, 4.f + (2.f + 5.f + 3.f + 6.f), 1.f, 7.f, 3.f + (7.f + 9.f + 8.f + 10.f), 6.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MulReduction) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "mul"); +- +- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 3}, {7.f, 3.f, 6.f, 7.f, 3.f, 6.f}); +- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f * 7.f * 7.f, -3.f * 3.f * 3.f, -6.f * 6.f * 6.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MulReductionAxis1) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 1); +- test.AddAttribute("reduction", "mul"); +- +- // update's slice shape is {2, 1} +- test.AddInput("data", {2, 3}, {9.f, 4.f, 1.f, 7.f, 3.f, 6.f}); +- test.AddInput("indices", {2, 4}, {1, 1, 1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 4}, {2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); +- test.AddOutput("y", {2, 3}, {9.f, 4.f * (2.f * 3.f * 4.f * 5.f), 1.f, 7.f, 3.f * (6.f * 7.f * 8.f * 9.f), 6.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MaxReduction_MLFloat16) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "max"); +- +- test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f})); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- 
test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); +- test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 7.f, 5.f, 6.f})); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MaxReduction_Float) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "max"); +- +- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); +- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MaxReduction_Double) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "max"); +- +- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f}); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); +- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 7.f, 5.f, 6.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MinReduction_MLFloat16) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "min"); +- +- test.AddInput("data", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 8.f, -3.f, 5.f})); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 3}, ToFloat16({1.f, 5.f, 3.f, 7.f, 3.f, 6.f})); +- test.AddOutput("y", {2, 3}, ToFloat16({-9.f, -4.f, -1.f, 1.f, -3.f, 3.f})); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MinReduction_Float) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "min"); +- +- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); +- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- +-TEST(ScatterElements, MinReduction_Double) { +- OpTester test("ScatterElements", 18); +- test.AddAttribute("axis", 0); +- test.AddAttribute("reduction", "min"); +- +- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, 8.f, -3.f, 5.f}); +- test.AddInput("indices", {2, 3}, {1, 1, 1, 1, 1, 1}); +- test.AddInput("updates", {2, 3}, {1.f, 5.f, 3.f, 7.f, 3.f, 6.f}); +- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, 1.f, -3.f, 3.f}); +- +- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); +-} +- + } // namespace test + } // namespace onnxruntime +diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +index 6514feadf..06da2a530 100644 +--- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc ++++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +@@ -70,11 +70,7 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { + auto op = + 
ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; + +- if (HasCudaEnvironment(800)) { +- MAKE_PROVIDERS_EPS(1e-2) +- } else { +- MAKE_PROVIDERS_EPS_TYPE(TypeParam) +- } ++ MAKE_PROVIDERS_EPS_TYPE(TypeParam) + } + + TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { +diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +index 0167f7a77..957443c23 100644 +--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc ++++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.cc +@@ -85,7 +85,7 @@ constexpr const char* INTERNAL_TESTING_EP = "InternalTestingEP"; + InternalTestingExecutionProvider::InternalTestingExecutionProvider(const std::unordered_set& ops, + const std::unordered_set& stop_ops, + DataLayout preferred_layout) +- : IExecutionProvider{utils::kInternalTestingExecutionProvider}, ++ : IExecutionProvider{utils::kInternalTestingExecutionProvider, true}, + ep_name_{INTERNAL_TESTING_EP}, + ops_{ops}, + stop_ops_{stop_ops}, +@@ -212,7 +212,7 @@ InternalTestingExecutionProvider::GetCapability(const onnxruntime::GraphViewer& + // create functor to generate a guaranteed unique metadef id + auto generate_metadef_name = [this, &graph_viewer]() { + HashValue model_hash; +- int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); ++ int metadef_id = GenerateMetaDefId(graph_viewer, model_hash); + auto meta_def = std::make_unique<::onnxruntime::IndexedSubGraph::MetaDef>(); + return ep_name_ + "_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); + }; +diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +index 6615eb82f..610335262 100644 +--- a/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h ++++ b/onnxruntime/test/providers/internal_testing/internal_testing_execution_provider.h +@@ -4,7 +4,6 @@ + #pragma once + #include + #include "core/framework/execution_provider.h" +-#include "core/framework/model_metadef_id_generator.h" + + namespace onnxruntime { + namespace internal_testing_ep { +@@ -83,7 +82,6 @@ class InternalTestingExecutionProvider : public IExecutionProvider { + // per-instance kernel registry so tests using static kernels don't clash. 
+ // shared_ptr as required by IExecutionProvider::GetKernelRegistry
+ std::shared_ptr<KernelRegistry> kernel_registry_;
+- ModelMetadefIdGenerator metadef_id_generator_;
+ };
+
+ } // namespace internal_testing_ep
+diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+index 4e1aef2c4..c50b1002f 100644
+--- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc
++++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc
+@@ -613,6 +613,94 @@ static GetTestModelFn BuildCastAddTestCase() {
+ };
+ }
+
++// Test that models with 2 inputs which have different data types can still generate the context binary
++TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) {
++ ProviderOptions provider_options;
++#if defined(_WIN32)
++ provider_options["backend_path"] = "QnnHtp.dll";
++#else
++ provider_options["backend_path"] = "libQnnHtp.so";
++#endif
++
++ // Add kMSDomain to cover contrib op like Gelu
++ const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
++
++ auto& logging_manager = DefaultLoggingManager();
++ logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
++
++ onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(),
++ IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {},
++ logging_manager.DefaultLogger());
++ Graph& graph = model.MainGraph();
++ ModelTestBuilder helper(graph);
++ BuildCastAddTestCase()(helper);
++ helper.SetGraphOutputs();
++ ASSERT_STATUS_OK(model.MainGraph().Resolve());
++
++ // Serialize the model to a string.
++ std::string model_data;
++ model.ToProto().SerializeToString(&model_data);
++
++ const auto model_data_span = AsByteSpan(model_data.data(), model_data.size());
++
++ const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx";
++ Ort::SessionOptions so;
++ so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
++ so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
++
++ so.AppendExecutionProvider("QNN", provider_options);
++
++ Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so);
++
++ // Make sure the Qnn context cache binary file is generated
++ EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
++
++ // clean up
++ ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
++}
++
++// Generate context cache model from the ONNX models with 2 inputs.
++// The generated model should have the same input order.
++// The input ONNX model is created in such a way that the order of the model inputs
++// differs from the order in the graph (topological order).
++// It causes issues if the generated model doesn't set the inputs/outputs explicitly.
++TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) {
++ ProviderOptions provider_options;
++#if defined(_WIN32)
++ provider_options["backend_path"] = "QnnHtp.dll";
++#else
++ provider_options["backend_path"] = "libQnnHtp.so";
++#endif
++
++ // Add kMSDomain to cover contrib op like Gelu
++ const std::unordered_map<std::string, int> domain_to_version = {{"", 13}, {kMSDomain, 1}};
++
++ auto& logging_manager = DefaultLoggingManager();
++ logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR);
++
++ const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx";
++ Ort::SessionOptions so;
++ so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1");
++ so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str());
++
++ so.AppendExecutionProvider("QNN", provider_options);
++
++ Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so);
++
++ // Make sure the Qnn context cache binary file is generated
++ EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
++
++ std::shared_ptr<Model> model;
++ ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger()));
++ auto inputs = model->MainGraph().GetInputs();
++ EXPECT_TRUE(inputs.size() == 2);
++ EXPECT_TRUE(inputs[0]->Name() == "attention_mask");
++ EXPECT_TRUE(inputs[1]->Name() == "Add_input_0");
++
++ // clean up
++ ASSERT_EQ(std::remove(context_binary_file.c_str()), 0);
++}
++
+ // A repro of QC case 06838696, accuracy issue for Cast + Op (quantized)
+ // the value pair(1, 0.00392156886) at index #1 don't match,
+ // which is -0.996078 from 1
+diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+deleted file mode 100644
+index b1f3b52e7..000000000
+--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
++++ /dev/null
+@@ -1,657 +0,0 @@
+-// Copyright (c) Microsoft Corporation. All rights reserved.
+-// Licensed under the MIT License.
+- +-#include +-#include +- +-#include "core/session/onnxruntime_cxx_api.h" +-#include "core/session/onnxruntime_session_options_config_keys.h" +-#include "core/session/inference_session.h" +- +-#include "test/providers/qnn/qnn_test_utils.h" +- +-#include "gtest/gtest.h" +-#include "gmock/gmock.h" +- +-using namespace ONNX_NAMESPACE; +-using namespace onnxruntime::logging; +- +-// in test_main.cc +-extern std::unique_ptr ort_env; +- +-namespace onnxruntime { +-namespace test { +- +-#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +- +-// Create a model with Case + Add (quantized) +-// input1 -> Add -> Q -> DQ \ +-// FusedMatMul -> Q -> DQ -> output +-// input2 -> Q -> DQ / +-static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) { +- return [single_ep_node](ModelTestBuilder& builder) { +- // Creat non-quantized Add node1 +- NodeArg* input1 = MakeTestInput(builder, TestInputDef({2, 2}, false, {0, 1, 0, 1})); +- NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, {0, 0, 0, 0})); +- +- auto* add1_output = builder.MakeIntermediate(); +- builder.AddNode("FusedMatMul", {input1, add1_ini_input2}, {add1_output}, kMSDomain); +- +- // Create quantized Add node2 +- std::vector data = {0.0f, 0.0f, 1.0f, 0.0f}; +- gsl::span data_range = gsl::make_span(data); +- QuantParams q_parameter = GetDataQuantParams(data_range); +- auto* add2_input1_qdq = AddQDQNodePair(builder, add1_output, q_parameter.scale, q_parameter.zero_point); +- +- NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, data)); +- auto* add2_input2_qdq = AddQDQNodePair(builder, add2_input2, q_parameter.scale, q_parameter.zero_point); +- +- auto* add2_output = builder.MakeIntermediate(); +- +- builder.AddNode("Add", {add2_input1_qdq, add2_input2_qdq}, {add2_output}); +- +- if (single_ep_node) { +- // add_output -> Q -> DQ -> output +- AddQDQNodePairWithOutputAsGraphOutput(builder, add2_output, q_parameter.scale, q_parameter.zero_point); +- } else { +- auto* add3_input1_qdq = AddQDQNodePair(builder, add2_output, q_parameter.scale, q_parameter.zero_point); +- NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, {0, 0, 0, 0})); +- +- auto* add3_output = builder.MakeIntermediate(); +- builder.AddNode("FusedMatMul", {add3_input1_qdq, add3_ini_input2}, {add3_output}, kMSDomain); +- +- // Create quantized Add node4 +- auto* add4_input1_qdq = AddQDQNodePair(builder, add3_output, q_parameter.scale, q_parameter.zero_point); +- +- NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, data)); +- auto* add4_input2_qdq = AddQDQNodePair(builder, add4_input2, q_parameter.scale, q_parameter.zero_point); +- +- auto* add4_output = builder.MakeIntermediate(); +- +- builder.AddNode("Add", {add4_input1_qdq, add4_input2_qdq}, {add4_output}); +- // add_output -> Q -> DQ -> output +- AddQDQNodePairWithOutputAsGraphOutput(builder, add4_output, q_parameter.scale, q_parameter.zero_point); +- } +- }; +-} +- +-void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; +- +- auto& logging_manager = DefaultLoggingManager(); +- logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); +- +- onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), 
PathString(), +- IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, +- logging_manager.DefaultLogger()); +- Graph& graph = model.MainGraph(); +- ModelTestBuilder helper(graph); +- BuildGraphWithQAndNonQ(single_ep_node)(helper); +- helper.SetGraphOutputs(); +- ASSERT_STATUS_OK(model.MainGraph().Resolve()); +- +- // Serialize the model to a string. +- std::string model_data; +- model.ToProto().SerializeToString(&model_data); +- +- const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); +- +- const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; +- std::remove(context_binary_file.c_str()); +- Ort::SessionOptions so; +- so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); +- so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); +- so.AppendExecutionProvider("QNN", provider_options); +- +- Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); +- +- // Make sure the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- +- int ep_context_node_count = 0; +- int non_ep_context_node_count = 0; +- std::shared_ptr ctx_model; +- ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); +- auto& ctx_graph = ctx_model->MainGraph(); +- for (auto& node : ctx_graph.Nodes()) { +- if (node.OpType() == "EPContext") { +- ++ep_context_node_count; +- } else { +- ++non_ep_context_node_count; +- } +- } +- +- int expected_node_count = single_ep_node ? 1 : 2; +- ASSERT_EQ(ep_context_node_count, expected_node_count); +- ASSERT_EQ(non_ep_context_node_count, expected_node_count); +- +- Ort::SessionOptions so2; +- // context file path is required if it's non-embed mode and the model is loaded from memory +- so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); +- so2.AppendExecutionProvider("QNN", provider_options); +- +- std::string ctx_model_data; +- ctx_model->ToProto().SerializeToString(&ctx_model_data); +- Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); +- +- // clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +-} +- +-// Test that models with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary +-// The generated Onnx model has 1 Add node and 1 EPContext node +-TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { +- bool single_ep_node = true; +- QnnContextBinaryMultiPartitionTestBody(single_ep_node); +-} +- +-// Test that models with 2 non-quantized Add nodes and 2 quantized Add nodes can still generate the context binary +-// The generated Onnx model has 2 Add nodes and 1 EPContext nodes +-TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) { +- bool single_ep_node = false; +- QnnContextBinaryMultiPartitionTestBody(single_ep_node); +-} +- +-// Create a model with Case + Add (quantized) +-// cast_input -> Cast -> Q -> DQ \ +-// Add -> Q -> DQ -> output +-// input2 -> Q -> DQ / +-static GetTestModelFn BuildCastAddTestCase() { +- return [](ModelTestBuilder& builder) { +- // Creat Cast node int32 -> float32 +- NodeArg* cast_input = MakeTestInput(builder, TestInputDef({2, 3}, false, {0, 1, 0, 1, 0, 1})); +- +- auto* cast_output = builder.MakeIntermediate(); +- Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output}); +- cast_node.AddAttribute("to", 
static_cast(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); +- +- // Create Add node +- std::vector data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; +- gsl::span data_range = gsl::make_span(data); +- QuantParams q_parameter = GetDataQuantParams(data_range); +- auto* add_input1_qdq = AddQDQNodePair(builder, cast_output, q_parameter.scale, q_parameter.zero_point); +- +- NodeArg* add_input2 = MakeTestInput(builder, TestInputDef({2, 3}, false, data)); +- auto* add_input2_qdq = AddQDQNodePair(builder, add_input2, q_parameter.scale, q_parameter.zero_point); +- +- auto* add_output = builder.MakeIntermediate(); +- +- builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output}); +- +- // add_output -> Q -> DQ -> output +- AddQDQNodePairWithOutputAsGraphOutput(builder, add_output, q_parameter.scale, q_parameter.zero_point); +- }; +-} +- +-// Test that models with 2 inputs which has different data type can still generate the context binary +-TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; +- +- auto& logging_manager = DefaultLoggingManager(); +- logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); +- +- onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), +- IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, +- logging_manager.DefaultLogger()); +- Graph& graph = model.MainGraph(); +- ModelTestBuilder helper(graph); +- BuildCastAddTestCase()(helper); +- helper.SetGraphOutputs(); +- ASSERT_STATUS_OK(model.MainGraph().Resolve()); +- +- // Serialize the model to a string. +- std::string model_data; +- model.ToProto().SerializeToString(&model_data); +- +- const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); +- +- const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; +- std::remove(context_binary_file.c_str()); +- Ort::SessionOptions so; +- so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); +- so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); +- +- so.AppendExecutionProvider("QNN", provider_options); +- +- Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); +- +- // Make sure the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- +- // clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +-} +- +-// Generate context cache model from the ONNX models with 2 inputs. +-// The generated model should have same input order. +-// The input ONNX model is created in the way that the model inputs order +-// is different with the order in the graph (topological order). +-// It cause issue if the generated model doesn't set the inputs/outputs explicitly. 
+-TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- // Add kMSDomain to cover contrib op like Gelu +- const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; +- +- auto& logging_manager = DefaultLoggingManager(); +- logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); +- +- const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; +- Ort::SessionOptions so; +- so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); +- so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); +- so.AppendExecutionProvider("QNN", provider_options); +- +- Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); +- +- // Make sure the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- +- std::shared_ptr model; +- ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); +- auto inputs = model->MainGraph().GetInputs(); +- EXPECT_TRUE(inputs.size() == 2); +- EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); +- EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); +- +- // clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +-} +- +-// Run QDQ model on HTP 3 times +-// 1st run will generate the Qnn context cache onnx file +-// 2nd run directly loads and run from Qnn context cache model +-TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- const std::string context_binary_file = "./qnn_context_binary_test.onnx"; +- std::remove(context_binary_file.c_str()); +- +- std::unordered_map session_option_pairs; +- session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); +- session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); +- +- const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); +- const std::string op_type = "Atan"; +- +- // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+- // 1st run will generate the Qnn context cache binary file +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- "", // context model file path, not required for this inference +- session_option_pairs); +- +- // Make sure the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- +- // 2nd run directly loads and run from Qnn context cache model +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- context_binary_file); +- // Clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +-} +- +-// Run QDQ model on HTP 3 times +-// 1st run will generate the Onnx skeleton file + Qnn context cache binary file +-// 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file +-TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx"; +- std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; +- +- std::unordered_map session_option_pairs; +- session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); +- session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); +- session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); +- +- std::remove(context_binary_file.c_str()); +- std::remove(qnn_ctx_bin.c_str()); +- +- const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); +- const std::string op_type = "Atan"; +- +- // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+- // 1st run will generate the Onnx skeleton file + Qnn context cache binary file +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- "", // context model file path, not required for this inference +- session_option_pairs); +- +- // Check the Onnx skeleton file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- // Check the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); +- +- std::unordered_map session_option_pairs2; +- // Need to set the context file path since TestQDQModelAccuracy load the model from memory +- session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); +- // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- context_binary_file, +- session_option_pairs2); +- +- // load the model from file +- std::vector buffer; +- { +- std::ifstream file(context_binary_file, std::ios::binary | std::ios::ate); +- if (!file) +- ORT_THROW("Error reading model"); +- buffer.resize(narrow(file.tellg())); +- file.seekg(0, std::ios::beg); +- if (!file.read(buffer.data(), buffer.size())) +- ORT_THROW("Error reading model"); +- } +- +- Ort::SessionOptions so; // No need to set the context file path in so since it's load from file +- so.AppendExecutionProvider("QNN", provider_options); +-#ifdef _WIN32 +- std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end()); +-#else +- std::string ctx_model_file(context_binary_file.begin(), context_binary_file.end()); +-#endif +- Ort::Session session(*ort_env.get(), ctx_model_file.c_str(), so); +- +- // Clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +- ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0); +-} +- +-// Run QDQ model on HTP 2 times +-// 1st run will generate the Onnx skeleton file + Qnn context cache binary file +-// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH +-TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; +- std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; +- std::remove(context_binary_file.c_str()); +- std::remove(context_bin.string().c_str()); +- +- std::unordered_map session_option_pairs; +- session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); +- session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); +- session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); +- +- const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); +- const std::string op_type = "Atan"; +- +- // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+- // 1st run will generate the Onnx skeleton file + Qnn context cache binary file +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- "", // context model file path, not required for this inference +- session_option_pairs); +- +- // Check the Onnx skeleton file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- // Check the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(context_bin)); +- // Delete the Qnn context cache binary file +- EXPECT_TRUE(std::filesystem::remove(context_bin)); +- +- // loads and run from Onnx skeleton file + Qnn context cache binary file +- onnx::ModelProto model_proto; +- onnxruntime::Model qnn_ctx_model; +- // Load the QNN context cache model from path specified +- ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto)); +- std::string qnn_ctx_model_data; +- model_proto.SerializeToString(&qnn_ctx_model_data); +- +- SessionOptions so; +- so.session_logid = "qnn_ctx_model_logger"; +- RunOptions run_options; +- run_options.run_tag = so.session_logid; +- +- InferenceSessionWrapper session_object{so, GetEnvironment()}; +- +- std::string provider_type = kCpuExecutionProvider; +- ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); +- ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); +- // Verify the return status with code INVALID_GRAPH +- ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +- +- // Clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +-} +- +-std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { +- const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; +- auto& logging_manager = DefaultLoggingManager(); +- onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), +- IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, +- logging_manager.DefaultLogger()); +- Graph& graph = model.MainGraph(); +- ModelTestBuilder helper(graph); +- std::vector shape = {2, 3}; +- NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); +- auto* graph_output = helper.MakeOutput(shape); +- Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); +- ep_context_node.AddAttribute("embed_mode", static_cast(0)); +- // The .. in the path will cause INVALID_GRAPH +- ep_context_node.AddAttribute("ep_cache_context", external_bin_path); +- ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); +- ep_context_node.AddAttribute("source", "QNN"); +- helper.SetGraphOutputs(); +- std::string model_data; +- model.ToProto().SerializeToString(&model_data); +- +- return model_data; +-} +- +-// Create a model with EPContext node. Set the node property ep_cache_context has ".." 
+-// Verify that it return INVALID_GRAPH status +-TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { +- std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); +- +- SessionOptions so; +- so.session_logid = "qnn_ctx_model_logger"; +- RunOptions run_options; +- run_options.run_tag = so.session_logid; +- +- InferenceSessionWrapper session_object{so, GetEnvironment()}; +- +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); +- ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); +- // Verify the return status with code INVALID_GRAPH +- ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +-} +- +-// Create a model with EPContext node. Set the node property ep_cache_context has absolute path +-// Verify that it return INVALID_GRAPH status +-TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { +-#if defined(_WIN32) +- std::string external_ctx_bin_path = "D:/qnn_context.bin"; +-#else +- std::string external_ctx_bin_path = "/data/qnn_context.bin"; +-#endif +- std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); +- +- SessionOptions so; +- so.session_logid = "qnn_ctx_model_logger"; +- RunOptions run_options; +- run_options.run_tag = so.session_logid; +- +- InferenceSessionWrapper session_object{so, GetEnvironment()}; +- +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); +- ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); +- // Verify the return status with code INVALID_GRAPH +- ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +-} +- +-// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist +-// Verify that it return INVALID_GRAPH status +-TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { +- std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); +- +- SessionOptions so; +- so.session_logid = "qnn_ctx_model_logger"; +- RunOptions run_options; +- run_options.run_tag = so.session_logid; +- +- InferenceSessionWrapper session_object{so, GetEnvironment()}; +- +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); +- ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); +- // Verify the return status with code INVALID_GRAPH +- ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +-} +- +-// Create a model with EPContext node. 
Set the node property ep_cache_context to empty string +-// Verify that it return INVALID_GRAPH status +-TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { +- std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); +- +- SessionOptions so; +- so.session_logid = "qnn_ctx_model_logger"; +- RunOptions run_options; +- run_options.run_tag = so.session_logid; +- +- InferenceSessionWrapper session_object{so, GetEnvironment()}; +- +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- +- ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); +- ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); +- // Verify the return status with code INVALID_GRAPH +- ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +-} +- +-// Run QDQ model on HTP with 2 inputs +-// 1st run will generate the Qnn context cache onnx file +-// 2nd run directly loads and run from Qnn context cache model +-TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { +- ProviderOptions provider_options; +-#if defined(_WIN32) +- provider_options["backend_path"] = "QnnHtp.dll"; +-#else +- provider_options["backend_path"] = "libQnnHtp.so"; +-#endif +- const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; +- std::remove(context_binary_file.c_str()); +- +- std::unordered_map session_option_pairs; +- session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); +- session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); +- +- const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); +- const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); +- const std::string op_type = "Add"; +- +- // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs. 
+- // 1st run will generate the Qnn context cache binary file +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- "", // context model file path, not required for this inference +- session_option_pairs); +- +- // Make sure the Qnn context cache binary file is generated +- EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); +- +- // 2nd run directly loads and run from Qnn context cache model +- TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), +- BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), +- provider_options, +- 14, +- ExpectedEPNodeAssignment::All, +- QDQTolerance(), +- logging::Severity::kERROR, +- context_binary_file); +- // Clean up +- ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +-} +- +-#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +- +-} // namespace test +-} // namespace onnxruntime +diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h +index f4febd99d..bfe5bab31 100644 +--- a/onnxruntime/test/providers/qnn/qnn_test_utils.h ++++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h +@@ -361,7 +361,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe + model_proto.SerializeToString(&qnn_ctx_model_data); + // Run QNN context cache model on QNN EP and collect outputs. + InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options, +- expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs); ++ expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep); + } else { + // Run QDQ model on QNN EP and collect outputs. + // Only need to apply the extra session options to this QDQ model inference on QNN EP +diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +index 2f3b0e84a..556b579e9 100644 +--- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc ++++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +@@ -723,6 +723,371 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) { + true); // Use com.microsoft domain for Q/DQ ops + } + ++// Run QDQ model on HTP 3 times ++// 1st run will generate the Qnn context cache onnx file ++// 2nd run will load and run from QDQ model + Qnn context cache model ++// 3rd run directly loads and run from Qnn context cache model ++TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ const std::string context_binary_file = "./qnn_context_binary_test.onnx"; ++ ++ std::unordered_map session_option_pairs; ++ session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); ++ session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); ++ ++ const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); ++ const std::string op_type = "Atan"; ++ ++ // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
++ // 1st run will generate the Qnn context cache binary file ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // Make sure the Qnn context cache binary file is generated ++ EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); ++ ++ // 2nd run loads and run from QDQ model + Qnn context cache model ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // 3rd run directly loads and run from Qnn context cache model ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ context_binary_file); ++ // Clean up ++ ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); ++} ++ ++// Run QDQ model on HTP 3 times ++// 1st run will generate the Onnx skeleton file + Qnn context cache binary file ++// 2nd run will loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file ++// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file ++TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; ++ std::unordered_map session_option_pairs; ++ session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); ++ session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); ++ session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); ++ ++ const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); ++ const std::string op_type = "Atan"; ++ ++ // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
++ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // Check the Onnx skeleton file is generated ++ EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); ++ // Check the Qnn context cache binary file is generated ++ EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin")); ++ ++ // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ context_binary_file); ++} ++ ++// Run QDQ model on HTP 2 times ++// 1st run will generate the Onnx skeleton file + Qnn context cache binary file ++// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH ++TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; ++ std::unordered_map session_option_pairs; ++ session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); ++ session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); ++ session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); ++ ++ const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); ++ const std::string op_type = "Atan"; ++ ++ // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
++ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // Check the Onnx skeleton file is generated ++ EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); ++ // Check the Qnn context cache binary file is generated ++ std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; ++ EXPECT_TRUE(std::filesystem::exists(context_bin)); ++ // Delete the Qnn context cache binary file ++ EXPECT_TRUE(std::filesystem::remove(context_bin)); ++ ++ // loads and run from Onnx skeleton file + Qnn context cache binary file ++ onnx::ModelProto model_proto; ++ onnxruntime::Model qnn_ctx_model; ++ // Load the QNN context cache model from path specified ++ ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto)); ++ std::string qnn_ctx_model_data; ++ model_proto.SerializeToString(&qnn_ctx_model_data); ++ ++ SessionOptions so; ++ so.session_logid = "qnn_ctx_model_logger"; ++ RunOptions run_options; ++ run_options.run_tag = so.session_logid; ++ ++ InferenceSessionWrapper session_object{so, GetEnvironment()}; ++ ++ std::string provider_type = kCpuExecutionProvider; ++ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ++ ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); ++ // Verify the return status with code INVALID_GRAPH ++ ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); ++} ++ ++std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { ++ const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; ++ auto& logging_manager = DefaultLoggingManager(); ++ onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), ++ IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, ++ logging_manager.DefaultLogger()); ++ Graph& graph = model.MainGraph(); ++ ModelTestBuilder helper(graph); ++ std::vector shape = {2, 3}; ++ NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); ++ auto* graph_output = helper.MakeOutput(shape); ++ Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); ++ ep_context_node.AddAttribute("embed_mode", static_cast(0)); ++ // The .. in the path will cause INVALID_GRAPH ++ ep_context_node.AddAttribute("ep_cache_context", external_bin_path); ++ ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); ++ ep_context_node.AddAttribute("source", "QNN"); ++ helper.SetGraphOutputs(); ++ std::string model_data; ++ model.ToProto().SerializeToString(&model_data); ++ ++ return model_data; ++} ++ ++// Create a model with EPContext node. Set the node property ep_cache_context has ".." 
++// Verify that it return INVALID_GRAPH status ++TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { ++ std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); ++ ++ SessionOptions so; ++ so.session_logid = "qnn_ctx_model_logger"; ++ RunOptions run_options; ++ run_options.run_tag = so.session_logid; ++ ++ InferenceSessionWrapper session_object{so, GetEnvironment()}; ++ ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ ++ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ++ ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); ++ // Verify the return status with code INVALID_GRAPH ++ ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); ++} ++ ++// Create a model with EPContext node. Set the node property ep_cache_context has absolute path ++// Verify that it return INVALID_GRAPH status ++TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { ++#if defined(_WIN32) ++ std::string external_ctx_bin_path = "D:/qnn_context.bin"; ++#else ++ std::string external_ctx_bin_path = "/data/qnn_context.bin"; ++#endif ++ std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); ++ ++ SessionOptions so; ++ so.session_logid = "qnn_ctx_model_logger"; ++ RunOptions run_options; ++ run_options.run_tag = so.session_logid; ++ ++ InferenceSessionWrapper session_object{so, GetEnvironment()}; ++ ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ ++ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ++ ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); ++ // Verify the return status with code INVALID_GRAPH ++ ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); ++} ++ ++// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist ++// Verify that it return INVALID_GRAPH status ++TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { ++ std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); ++ ++ SessionOptions so; ++ so.session_logid = "qnn_ctx_model_logger"; ++ RunOptions run_options; ++ run_options.run_tag = so.session_logid; ++ ++ InferenceSessionWrapper session_object{so, GetEnvironment()}; ++ ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ ++ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ++ ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); ++ // Verify the return status with code INVALID_GRAPH ++ ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); ++} ++ ++// Create a model with EPContext node. 
Set the node property ep_cache_context to empty string ++// Verify that it return INVALID_GRAPH status ++TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { ++ std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); ++ ++ SessionOptions so; ++ so.session_logid = "qnn_ctx_model_logger"; ++ RunOptions run_options; ++ run_options.run_tag = so.session_logid; ++ ++ InferenceSessionWrapper session_object{so, GetEnvironment()}; ++ ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ ++ ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ++ ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); ++ // Verify the return status with code INVALID_GRAPH ++ ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); ++} ++ ++// Run QDQ model on HTP with 2 inputs ++// 1st run will generate the Qnn context cache onnx file ++// 2nd run will load and run from QDQ model + Qnn context cache model ++// 3rd run directly loads and run from Qnn context cache model ++TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { ++ ProviderOptions provider_options; ++#if defined(_WIN32) ++ provider_options["backend_path"] = "QnnHtp.dll"; ++#else ++ provider_options["backend_path"] = "libQnnHtp.so"; ++#endif ++ const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; ++ std::unordered_map session_option_pairs; ++ session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); ++ session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); ++ ++ const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); ++ const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); ++ const std::string op_type = "Add"; ++ ++ // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs. 
++ // 1st run will generate the Qnn context cache binary file ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // Make sure the Qnn context cache binary file is generated ++ EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); ++ ++ // 2nd run loads and run from QDQ model + Qnn context cache model ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ "", // context model file path, not required for this inference ++ session_option_pairs); ++ ++ // 3rd run directly loads and run from Qnn context cache model ++ TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), ++ BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), ++ provider_options, ++ 14, ++ ExpectedEPNodeAssignment::All, ++ QDQTolerance(), ++ logging::Severity::kERROR, ++ context_binary_file); ++} ++ + TEST_F(QnnHTPBackendTests, QuantAccuracyTest) { + ProviderOptions provider_options; + +diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc +index 6dc721edb..57e4b2117 100644 +--- a/onnxruntime/test/providers/qnn/split_op_test.cc ++++ b/onnxruntime/test/providers/qnn/split_op_test.cc +@@ -302,46 +302,19 @@ TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { + // Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute + // and 'split' input. + TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { +- // Split 6 into 3 outputs of lengths [2, 2, 2] +- TestInputDef input_def({6, 2}, false, +- {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f}); +- + // Use 'split' input (initializer). +- RunQDQSplitOpTestOnHTP(input_def, +- {2, 2, 2}, // split +- 0, // axis +- -1, // num_outputs +- 18, // opset ++ RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), ++ {2, 2}, // split ++ 0, // axis ++ -1, // num_outputs ++ 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. +- RunQDQSplitOpTestOnHTP(input_def, +- {}, // split (use num_outputs instead) +- 0, // axis +- 3, // num_outputs +- 18, // opset +- ExpectedEPNodeAssignment::All); +-} +- +-// Test 8-bit QDQ Split opset 18 on HTP backend. Use an uneven split (last chunk should be smaller). +-TEST_F(QnnHTPBackendTests, Split_NonEqual_Axis0_Opset18) { +- // Split 7 into 3 outputs of lengths [3, 3, 1] +- TestInputDef input_def({7, 2}, false, +- {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f}); +- +- // Use a `split` input with uneven split lengths. +- RunQDQSplitOpTestOnHTP(input_def, +- {3, 3, 1}, // split +- 0, // axis +- -1, // num_outputs +- 18, // opset +- ExpectedEPNodeAssignment::All); +- +- // Use a `num_outputs` attribute that does not evenly divide into shape[axis]. 
+- RunQDQSplitOpTestOnHTP(input_def, ++ RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis +- 3, // num_outputs ++ 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + } +diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py +index 68e441c87..8c23286e4 100644 +--- a/onnxruntime/test/python/onnxruntime_test_python.py ++++ b/onnxruntime/test/python/onnxruntime_test_python.py +@@ -434,25 +434,6 @@ class TestInferenceSession(unittest.TestCase): + self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0") + self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0") + self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0") +- +- option["user_compute_stream"] = "0" +- sess.set_providers(["CUDAExecutionProvider"], [option]) +- options = sess.get_provider_options() +- self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], "0") +- +- try: +- import torch +- +- if torch.cuda.is_available(): +- s = torch.cuda.Stream() +- option["user_compute_stream"] = str(s.cuda_stream) +- sess.set_providers(["CUDAExecutionProvider"], [option]) +- options = sess.get_provider_options() +- self.assertEqual(options["CUDAExecutionProvider"]["user_compute_stream"], str(s.cuda_stream)) +- self.assertEqual(options["CUDAExecutionProvider"]["has_user_compute_stream"], "1") +- except ImportError: +- print("torch is not installed, skip testing setting user_compute_stream from torch cuda stream") +- + # + # Note: Tests that throw an exception leave an empty session due to how set_providers currently works, + # so run them last. Each set_providers call will attempt to re-create a session, so it's +@@ -650,14 +631,6 @@ class TestInferenceSession(unittest.TestCase): + if "ROCMExecutionProvider" in onnxrt.get_available_providers(): + do_test_get_and_set_tuning_results("ROCMExecutionProvider") + +- def test_run_model_with_optional_sequence_input(self): +- sess = onnxrt.InferenceSession(get_name("identity_opt.onnx")) +- x = [np.array([1, 2, 3, 4, 5]).astype(np.float32)] +- input_name = sess.get_inputs()[0].name +- output_name = sess.get_outputs()[0].name +- res = sess.run([output_name], {input_name: x}) +- np.testing.assert_allclose(res[0], x) +- + def test_run_model(self): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) +diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +index eca143044..67db411dd 100644 +--- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py ++++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +@@ -392,208 +392,6 @@ class TestSymbolicShapeInferenceForOperators(unittest.TestCase): + self.assertEqual(len(output_dims), 1) + self.assertEqual(output_dims[0].dim_value, 512) + +- def test_quantize_linear(self): +- """ +- Test ONNX QuantizeLinear op. +- Check that the output shape is propagated from the first input and that the output data +- type comes from the zero-point input. 
+- """ +- initializers = [ +- helper.make_tensor( +- "scale", +- TensorProto.FLOAT, +- [], +- [1.0], +- ), +- helper.make_tensor( +- "zero_point", +- TensorProto.INT8, +- [], +- [16], +- ), +- ] +- +- nodes = [ +- helper.make_node( +- "QuantizeLinear", +- inputs=[ +- "input_f32", +- "scale", +- "zero_point", +- ], +- outputs=["output_s8"], +- ), +- ] +- +- inputs = [ +- helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), +- ] +- +- outputs = [ +- helper.make_tensor_value_info("output_s8", TensorProto.UNDEFINED, None), +- ] +- +- graph = helper.make_graph(nodes, "QuantizeLinear_Test", inputs, outputs, initializers) +- model = helper.make_model(graph) +- +- inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) +- +- expected_shapes = [ +- helper.make_tensor_value_info("output_s8", TensorProto.INT8, ["b", 2, 3, 4]), +- ] +- self._check_shapes(graph, inferred.graph, expected_shapes) +- +- def test_quantize_linear_ms_domain(self): +- """ +- Test QuantizeLinear op ('com.microsoft' domain). +- Check that the output shape is propagated from the first input and that the output data +- type comes from the zero-point input. +- """ +- initializers = [ +- helper.make_tensor( +- "scale", +- TensorProto.FLOAT, +- [], +- [1.0], +- ), +- helper.make_tensor( +- "zero_point", +- TensorProto.UINT16, +- [], +- [16], +- ), +- ] +- +- nodes = [ +- helper.make_node( +- "QuantizeLinear", +- inputs=[ +- "input_f32", +- "scale", +- "zero_point", +- ], +- outputs=["output_u16"], +- domain="com.microsoft", +- ), +- ] +- +- inputs = [ +- helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), +- ] +- +- outputs = [ +- helper.make_tensor_value_info("output_u16", TensorProto.UNDEFINED, None), +- ] +- +- graph = helper.make_graph(nodes, "QuantizeLinear_MSDomain_Test", inputs, outputs, initializers) +- model = helper.make_model(graph) +- +- inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) +- +- expected_shapes = [ +- helper.make_tensor_value_info("output_u16", TensorProto.UINT16, ["b", 2, 3, 4]), +- ] +- self._check_shapes(graph, inferred.graph, expected_shapes) +- +- def test_quantize_linear_no_zp_input(self): +- """ +- Test QuantizeLinear op ('com.microsoft' domain). +- Check that the output shape is propagated from the first input. +- The zero-point input is missing, so the output data type should default to uint8. +- """ +- initializers = [ +- helper.make_tensor( +- "scale", +- TensorProto.FLOAT, +- [], +- [1.0], +- ), +- ] +- +- nodes = [ +- helper.make_node( +- "QuantizeLinear", +- inputs=[ +- "input_f32", +- "scale", +- ], +- outputs=["output_u8"], +- domain="com.microsoft", +- ), +- ] +- +- inputs = [ +- helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), +- ] +- +- outputs = [ +- helper.make_tensor_value_info("output_u8", TensorProto.UNDEFINED, None), +- ] +- +- graph = helper.make_graph(nodes, "QuantizeLinear_NoZP_Test", inputs, outputs, initializers) +- model = helper.make_model(graph) +- +- inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) +- +- # Check that the output shape is propagated from the first input and that the +- # output data type comes from the zero-point input. +- expected_shapes = [ +- helper.make_tensor_value_info("output_u8", TensorProto.UINT8, ["b", 2, 3, 4]), +- ] +- self._check_shapes(graph, inferred.graph, expected_shapes) +- +- def test_dequantize_linear_ms_domain(self): +- """ +- Test DequantizeLinear operator ('com.microsoft' domain). 
+- Check that the output shape is propagated from the first input and that the output data +- type comes from the scale input. +- """ +- initializers = [ +- helper.make_tensor( +- "scale", +- TensorProto.FLOAT, +- [], +- [1.0], +- ), +- helper.make_tensor( +- "zero_point", +- TensorProto.UINT16, +- [], +- [16], +- ), +- ] +- +- nodes = [ +- helper.make_node( +- "DequantizeLinear", +- inputs=[ +- "input_u16", +- "scale", +- "zero_point", +- ], +- outputs=["output_f32"], +- domain="com.microsoft", +- ), +- ] +- +- inputs = [ +- helper.make_tensor_value_info("input_u16", TensorProto.UINT16, ["b", 2, 3, 4]), +- ] +- +- outputs = [ +- helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), +- ] +- +- graph = helper.make_graph(nodes, "DequantizeLinear_MSDomain_Test", inputs, outputs, initializers) +- model = helper.make_model(graph) +- +- inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) +- +- expected_shapes = [ +- helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), +- ] +- self._check_shapes(graph, inferred.graph, expected_shapes) +- + + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): + def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): +diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py +index 223f405e8..4de797400 100644 +--- a/onnxruntime/test/python/quantization/test_qdq.py ++++ b/onnxruntime/test/python/quantization/test_qdq.py +@@ -601,13 +601,6 @@ class TestQDQFormatConvRelu(TestQDQFormat): + ) + check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + +- # If the model uses Q/DQ ops with "com.microsoft" domain (e.g., for int16 support), +- # then ensure the model has the appropriate opset import. 
+- if extra_options and extra_options.get("UseQDQContribOps", False): +- qdq_model = onnx.load_model(model_qdq_path) +- ms_opset = next((opset for opset in qdq_model.opset_import if opset.domain == "com.microsoft"), None) +- self.assertIsNot(ms_opset, None) +- + def verify_qop(self, per_channel, is_quant_type_int8): + np.random.seed(1) + model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") +diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc +index 8dad2c8e2..6ffe72f81 100644 +--- a/onnxruntime/test/shared_lib/test_inference.cc ++++ b/onnxruntime/test/shared_lib/test_inference.cc +@@ -43,10 +43,6 @@ + #include + #endif + +-#ifdef USE_ROCM +-#include +-#endif +- + // Once we use C++17 this could be replaced with std::size + template + constexpr size_t countof(T (&)[N]) { return N; } +@@ -1766,27 +1762,6 @@ TEST(CApiTest, get_allocator_cuda) { + } + #endif + +-#ifdef USE_ROCM +-TEST(CApiTest, get_allocator_rocm) { +- Ort::SessionOptions session_options; +- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); +- Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); +- +- Ort::MemoryInfo info_rocm("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +- Ort::Allocator rocm_allocator(session, info_rocm); +- +- auto allocator_info = rocm_allocator.GetInfo(); +- ASSERT_TRUE(info_rocm == allocator_info); +- void* p = rocm_allocator.Alloc(1024); +- ASSERT_NE(p, nullptr); +- rocm_allocator.Free(p); +- +- auto mem_allocation = rocm_allocator.GetAllocation(1024); +- ASSERT_NE(nullptr, mem_allocation.get()); +- ASSERT_EQ(1024U, mem_allocation.size()); +-} +-#endif +- + TEST(CApiTest, io_binding) { + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); +@@ -1962,7 +1937,7 @@ TEST(CApiTest, io_binding_cuda) { + } + #endif + +-#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) ++#if defined(USE_CUDA) || defined(USE_TENSORRT) + TEST(CApiTest, basic_cuda_graph) { + const auto& api = Ort::GetApi(); + Ort::SessionOptions session_options; +@@ -1980,7 +1955,7 @@ TEST(CApiTest, basic_cuda_graph) { + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options), + rel_trt_options.get()) == nullptr); +-#elif defined(USE_CUDA) ++#else + // Enable cuda graph in cuda provider option. + OrtCUDAProviderOptionsV2* cuda_options = nullptr; + ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); +@@ -1993,55 +1968,34 @@ TEST(CApiTest, basic_cuda_graph) { + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( + static_cast(session_options), + rel_cuda_options.get()) == nullptr); +-#elif defined(USE_ROCM) +- // Enable hip graph in rocm provider option. 
+- OrtROCMProviderOptions* rocm_options = nullptr; +- ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); +- std::unique_ptr +- rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); +- std::vector keys{"enable_hip_graph"}; +- std::vector values{"1"}; +- ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); +- +- ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( +- static_cast(session_options), +- rel_rocm_options.get()) == nullptr); + #endif + + Ort::Session session(*ort_env, MODEL_URI, session_options); +-#if defined(USE_ROCM) +-// local hipify +-#define cudaMemcpy hipMemcpy +-#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +-#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +- Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +-#else +- Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +-#endif ++ Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + +- Ort::Allocator allocator(session, info_mem); +- auto allocator_info = allocator.GetInfo(); +- ASSERT_TRUE(info_mem == allocator_info); ++ Ort::Allocator cuda_allocator(session, info_cuda); ++ auto allocator_info = cuda_allocator.GetInfo(); ++ ASSERT_TRUE(info_cuda == allocator_info); + + const std::array x_shape = {3, 2}; + std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; +- auto input_data = allocator.GetAllocation(x_values.size() * sizeof(float)); ++ auto input_data = cuda_allocator.GetAllocation(x_values.size() * sizeof(float)); + + ASSERT_NE(input_data.get(), nullptr); +- (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); ++ cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + + // Create an OrtValue tensor backed by data on CUDA memory +- Ort::Value bound_x = Ort::Value::CreateTensor(info_mem, reinterpret_cast(input_data.get()), x_values.size(), ++ Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data.get()), x_values.size(), + x_shape.data(), x_shape.size()); + + const std::array expected_y_shape = {3, 2}; + std::array expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; +- auto output_data = allocator.GetAllocation(expected_y.size() * sizeof(float)); ++ auto output_data = cuda_allocator.GetAllocation(expected_y.size() * sizeof(float)); + + ASSERT_NE(output_data.get(), nullptr); + + // Create an OrtValue tensor backed by data on CUDA memory +- Ort::Value bound_y = Ort::Value::CreateTensor(info_mem, reinterpret_cast(output_data.get()), ++ Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), + expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); + + // Create IoBinding for inputs and outputs. 
+@@ -2054,37 +2008,31 @@ TEST(CApiTest, basic_cuda_graph) { + + // Check the values against the bound raw memory (needs copying from device to host first) + std::array y_values; +- (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ++ cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); + + // Replay the captured CUDA graph + session.Run(Ort::RunOptions(), binding); +- (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ++ cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); + + // Change the input and replay the CUDA graph again. + x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; +- (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); ++ cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + binding.SynchronizeInputs(); + + session.Run(Ort::RunOptions(), binding); +- (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ++ cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f}; + ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); + + // Clean up + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); +-#if defined(USE_ROCM) +-#undef cudaMemcpy +-#undef cudaMemcpyHostToDevice +-#undef cudaMemcpyDeviceToHost +-#endif + } + +-// The following test uses some ops not supported in the reduced ops build + #ifndef REDUCED_OPS_BUILD +-#if defined(USE_CUDA) || defined(USE_TENSORRT) ++// The following test uses some ops not supported in the reduced ops build + TEST(CApiTest, cuda_graph_with_shape_nodes) { + const auto& api = Ort::GetApi(); + +@@ -2105,34 +2053,10 @@ TEST(CApiTest, cuda_graph_with_shape_nodes) { + // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled + Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); + } +-#endif // defined(USE_CUDA) || defined(USE_TENSORRT) + +-#if defined(USE_ROCM) +-TEST(CApiTest, hip_graph_with_shape_nodes) { +- const auto& api = Ort::GetApi(); +- +- // Enable hip graph in rocm provider option. 
+- OrtROCMProviderOptions* rocm_options = nullptr; +- ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); +- std::unique_ptr +- rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); +- std::vector keys{"enable_hip_graph"}; +- std::vector values{"1"}; +- ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); +- +- Ort::SessionOptions session_options; +- ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( +- static_cast(session_options), +- rel_rocm_options.get()) == nullptr); +- +- // Successful loading of the ONNX model with shape nodes with hip graph feature enabled +- Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); +-} +-#endif // defined(USE_ROCM) +- +-#endif // REDUCED_OPS_BUILD ++#endif + +-#endif // defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) ++#endif + + TEST(CApiTest, create_tensor) { + const char* s[] = {"abc", "kmp"}; +diff --git a/onnxruntime/test/testdata/identity_opt.onnx b/onnxruntime/test/testdata/identity_opt.onnx +deleted file mode 100644 +index 24c05f7b7..000000000 +Binary files a/onnxruntime/test/testdata/identity_opt.onnx and /dev/null differ +diff --git a/onnxruntime/test/testdata/training_api/checkpoint.ckpt b/onnxruntime/test/testdata/training_api/checkpoint.ckpt +index d1bc1f121..d0b7d0deb 100644 +Binary files a/onnxruntime/test/testdata/training_api/checkpoint.ckpt and b/onnxruntime/test/testdata/training_api/checkpoint.ckpt differ +diff --git a/onnxruntime/test/testdata/training_api/custom_ops/checkpoint b/onnxruntime/test/testdata/training_api/custom_ops/checkpoint +index ce23d149e..753b24af6 100644 +Binary files a/onnxruntime/test/testdata/training_api/custom_ops/checkpoint and b/onnxruntime/test/testdata/training_api/custom_ops/checkpoint differ +diff --git a/onnxruntime/test/testdata/training_api/nominal_checkpoint b/onnxruntime/test/testdata/training_api/nominal_checkpoint +deleted file mode 100644 +index 2eadfeece..000000000 +Binary files a/onnxruntime/test/testdata/training_api/nominal_checkpoint and /dev/null differ +diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint +index 83ef6aa4c..ab35c9ad5 100644 +Binary files a/onnxruntime/test/testdata/training_api/ort_format/checkpoint and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ +diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc +index 4c38c90c2..97169df36 100644 +--- a/onnxruntime/test/unittest_main/test_main.cc ++++ b/onnxruntime/test/unittest_main/test_main.cc +@@ -59,8 +59,8 @@ int TEST_MAIN(int argc, char** argv) { + int status = 0; + + ORT_TRY { +- ortenv_setup(); + ::testing::InitGoogleTest(&argc, argv); ++ ortenv_setup(); + + // allow verbose logging to be enabled by setting this environment variable to a numeric log level + constexpr auto kLogLevelEnvironmentVariableName = "ORT_UNIT_TEST_MAIN_LOG_LEVEL"; +diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js +index cbc60c70b..25ece9c70 100644 +--- a/onnxruntime/wasm/js_internal_api.js ++++ b/onnxruntime/wasm/js_internal_api.js +@@ -24,7 +24,7 @@ Module['unmountExternalData'] = () => { + /** + * init JSEP + */ +-Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { ++Module['jsepInit'] = (backend, alloc, free, 
copy, copyAsync, createKernel, releaseKernel, runKernel) => { + Module.jsepBackend = backend; + Module.jsepAlloc = alloc; + Module.jsepFree = free; +@@ -33,9 +33,6 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea + Module.jsepCreateKernel = createKernel; + Module.jsepReleaseKernel = releaseKernel; + Module.jsepRunKernel = runKernel; +- Module.jsepCaptureBegin = captureBegin; +- Module.jsepCaptureEnd = captureEnd; +- Module.jsepReplay = replay; + + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwarp() and ccall() that we don't need. +@@ -163,10 +160,6 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea + }; + + // replace the original functions with asyncified versions +- Module['_OrtCreateSession'] = jsepWrapAsync( +- Module['_OrtCreateSession'], +- () => Module['_OrtCreateSession'], +- v => Module['_OrtCreateSession'] = v); + Module['_OrtRun'] = runAsync(jsepWrapAsync( + Module['_OrtRun'], + () => Module['_OrtRun'], +@@ -184,16 +177,13 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); + }; ++ Module['jsepUnregisterBuffers'] = sessionId => { ++ backend['unregisterBuffers'](sessionId); ++ }; + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); + }; + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); + }; +- Module['jsepOnReleaseSession'] = sessionId => { +- backend['onReleaseSession'](sessionId); +- }; +- Module['jsepOnRunStart'] = sessionId => { +- return backend['onRunStart'](sessionId); +- }; + }; +diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc +index 4ab8db856..0c2bfa19e 100644 +--- a/orttraining/orttraining/python/orttraining_pybind_state.cc ++++ b/orttraining/orttraining/python/orttraining_pybind_state.cc +@@ -802,9 +802,6 @@ void addObjectMethodsForTraining(py::module& m) { + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { +- if (state->module_checkpoint_state.is_nominal_state) { +- ORT_THROW("Cannot copy parameter to a nominal state. Please load all the parameter states first"); +- } + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); +@@ -814,9 +811,6 @@ void addObjectMethodsForTraining(py::module& m) { + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { +- if (state->module_checkpoint_state.is_nominal_state) { +- ORT_THROW("Cannot get parameter from a nominal state. 
Please load the parameter states first"); +- } + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); +@@ -857,9 +851,6 @@ void addObjectMethodsForTraining(py::module& m) { + return std::make_unique(optimizer_model_uri, state, providers, session_options); + })) + .def("optimizer_step", [](PyOptimizer* optimizer) -> void { +- // In case the optimizer was constructed using a nominal checkpoint, +- // the optimizer state construction is delayed until the first call to Optimizer::Step(). +- // It is expected that the model parameter state is available at this point. + ORT_THROW_IF_ERROR(optimizer->optimizer_->Step()); + }) + .def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void { +@@ -902,7 +893,7 @@ void addObjectMethodsForTraining(py::module& m) { + "save_checkpoint", + [](const std::vector& trainable_tensor_protos_pybytes, + const std::vector& non_trainable_tensor_protos_pybytes, +- const std::string& checkpoint_path, const bool nominal_checkpoint) { ++ const std::string& checkpoint_path) { + std::vector trainable_tensor_protos(trainable_tensor_protos_pybytes.size()); + std::vector non_trainable_tensor_protos(non_trainable_tensor_protos_pybytes.size()); + +@@ -923,8 +914,7 @@ void addObjectMethodsForTraining(py::module& m) { + + ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos, + non_trainable_tensor_protos, +- ToPathString(checkpoint_path), +- nominal_checkpoint)); ++ ToPathString(checkpoint_path))); + }); + + m.def("save_checkpoint", +diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py +index cc4e84111..ba95cd04f 100644 +--- a/orttraining/orttraining/python/training/api/checkpoint_state.py ++++ b/orttraining/orttraining/python/training/api/checkpoint_state.py +@@ -222,8 +222,6 @@ class CheckpointState: + def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: + """Loads the checkpoint state from the checkpoint file + +- The checkpoint file can either be the complete checkpoint or the nominal checkpoint. +- + Args: + checkpoint_uri: The path to the checkpoint file. + +diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py +index a87cd6fdd..f8f6b4322 100644 +--- a/orttraining/orttraining/python/training/api/module.py ++++ b/orttraining/orttraining/python/training/api/module.py +@@ -178,9 +178,6 @@ class Module: + def copy_buffer_to_parameters(self, buffer: OrtValue, trainable_only: bool = True) -> None: + """Copies the OrtValue buffer to the training session parameters. + +- In case the module was loaded from a nominal checkpoint, invoking this function is required +- to load the updated parameters onto the checkpoint to complete it. +- + Args: + buffer: The OrtValue buffer to copy to the training session parameters. 
+ """ +diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py +index 7a4eb251b..a57105545 100644 +--- a/orttraining/orttraining/python/training/artifacts.py ++++ b/orttraining/orttraining/python/training/artifacts.py +@@ -43,11 +43,7 @@ def generate_artifacts( + loss: Optional[Union[LossType, onnxblock.Block]] = None, + optimizer: Optional[OptimType] = None, + artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None, +- prefix: str = "", +- ort_format: bool = False, +- custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, +- additional_output_names: Optional[List[str]] = None, +- nominal_checkpoint: bool = False, ++ **extra_options, + ) -> None: + """Generates artifacts required for training with ORT training api. + +@@ -67,16 +63,11 @@ def generate_artifacts( + optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated. + artifact_directory: The directory to save the generated artifacts. + If None, the current working directory is used. +- prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used. +- ort_format: Whether to save the generated artifacts in ORT format or not. Default is False. +- custom_op_library: The path to the custom op library. +- If not specified, no custom op library is used. +- additional_output_names: List of additional output names to be added to the training/eval model in addition +- to the loss output. Default is None. +- nominal_checkpoint: Whether to generate the nominal checkpoint in addition to the complete checkpoint. +- Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model +- parameters. It can be used on the device to reduce overhead while constructing the training model +- as well as to reduce the size of the checkpoint packaged with the on-device application. ++ prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used. ++ ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False. ++ custom_op_library (str | os.PathLike): The path to the custom op library. ++ If not specified, no custom op library is used. ++ additional_output_names (List[str]): List of additional output names to be added to the training/eval model. + + Raises: + RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` +@@ -116,19 +107,19 @@ def generate_artifacts( + self._loss = _loss + + def build(self, *inputs_to_loss): +- if additional_output_names: ++ if "additional_output_names" in extra_options: + # If additional output names is not a list, raise an error +- if not isinstance(additional_output_names, list): ++ if not isinstance(extra_options["additional_output_names"], list): + raise RuntimeError( +- f"Unknown type provided for additional output names {type(additional_output_names)}. " ++ f"Unknown type provided for additional output names {type(extra_options['additional_output_names'])}. " + "Expected additional output names to be a list of strings." 
+ ) + + loss_output = self._loss(*inputs_to_loss) + if isinstance(loss_output, tuple): +- return (*loss_output, *tuple(additional_output_names)) ++ return (*loss_output, *tuple(extra_options["additional_output_names"])) + else: +- return (loss_output, *tuple(additional_output_names)) ++ return (loss_output, *tuple(extra_options["additional_output_names"])) + + return self._loss(*inputs_to_loss) + +@@ -152,57 +143,58 @@ def generate_artifacts( + eval_model = None + model_params = None + +- custom_op_library_path = None ++ custom_op_library = extra_options.get("custom_op_library", None) + if custom_op_library is not None: + logging.info("Custom op library provided: %s", custom_op_library) +- custom_op_library_path = pathlib.Path(custom_op_library) ++ custom_op_library = pathlib.Path(custom_op_library) + + with onnxblock.base(model), onnxblock.custom_op_library( +- custom_op_library_path ++ custom_op_library + ) if custom_op_library is not None else contextlib.nullcontext(): + _ = training_block(*[output.name for output in model.graph.output]) + training_model, eval_model = training_block.to_model_proto() + model_params = training_block.parameters() + +- def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_path): +- if ort_format: ++ def _export_to_ort_format(model_path, output_dir, extra_options): ++ if extra_options.get("ort_format", False): ++ custom_op_library = extra_options.get("custom_op_library", None) ++ if custom_op_library is not None: ++ custom_op_library = pathlib.Path(custom_op_library) + convert_onnx_models_to_ort( + model_path, + output_dir=output_dir, +- custom_op_library_path=custom_op_library_path, ++ custom_op_library_path=custom_op_library, + optimization_styles=[OptimizationStyle.Fixed], + ) + + if artifact_directory is None: + artifact_directory = pathlib.Path.cwd() +- artifact_directory = pathlib.Path(artifact_directory) +- +- if prefix: ++ prefix = "" ++ if "prefix" in extra_options: ++ prefix = extra_options["prefix"] + logging.info("Using prefix %s for generated artifacts.", prefix) + ++ artifact_directory = pathlib.Path(artifact_directory) ++ + training_model_path = artifact_directory / f"{prefix}training_model.onnx" + if os.path.exists(training_model_path): + logging.info("Training model path %s already exists. Overwriting.", training_model_path) + onnx.save(training_model, training_model_path) +- _export_to_ort_format(training_model_path, artifact_directory, ort_format, custom_op_library_path) ++ _export_to_ort_format(training_model_path, artifact_directory, extra_options) + logging.info("Saved training model to %s", training_model_path) + + eval_model_path = artifact_directory / f"{prefix}eval_model.onnx" + if os.path.exists(eval_model_path): + logging.info("Eval model path %s already exists. Overwriting.", eval_model_path) + onnx.save(eval_model, eval_model_path) +- _export_to_ort_format(eval_model_path, artifact_directory, ort_format, custom_op_library_path) ++ _export_to_ort_format(eval_model_path, artifact_directory, extra_options) + logging.info("Saved eval model to %s", eval_model_path) + + checkpoint_path = artifact_directory / f"{prefix}checkpoint" + if os.path.exists(checkpoint_path): + logging.info("Checkpoint path %s already exists. 
Overwriting.", checkpoint_path) +- onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path, nominal_checkpoint=False) ++ onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path) + logging.info("Saved checkpoint to %s", checkpoint_path) +- if nominal_checkpoint: +- nominal_checkpoint_path = artifact_directory / f"{prefix}nominal_checkpoint" +- onnxblock.save_checkpoint(training_block.parameters(), nominal_checkpoint_path, nominal_checkpoint=True) +- logging.info("Saved nominal checkpoint to %s", nominal_checkpoint_path) + + # If optimizer is not specified, skip creating the optimizer model + if optimizer is None: +@@ -233,5 +225,5 @@ def generate_artifacts( + + optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx" + onnx.save(optim_model, optimizer_model_path) +- _export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path) ++ _export_to_ort_format(optimizer_model_path, artifact_directory, extra_options) + logging.info("Saved optimizer model to %s", optimizer_model_path) +diff --git a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +index de3453c63..bc50d4afa 100644 +--- a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py ++++ b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +@@ -6,21 +6,18 @@ from typing import List, Tuple, Union + + import onnx + +-from onnxruntime.capi._pybind_state import get_model_after_loading_checkpoint as _load_checkpoint_to_model +-from onnxruntime.capi._pybind_state import save_checkpoint as _save_checkpoint ++from onnxruntime.capi._pybind_state import get_model_after_loading_checkpoint as _internal_load_checkpoint_to_model ++from onnxruntime.capi._pybind_state import save_checkpoint as _internal_save_checkpoint + + + def save_checkpoint( +- parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], +- path_to_checkpoint: Union[str, os.PathLike], +- nominal_checkpoint: bool = False, ++ parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], path_to_checkpoint: Union[str, os.PathLike] + ) -> None: + """Saves the parameters to the checkpoint directory path_to_checkpoint. + + Args: + parameters tuple(trainable_params, non_trainable_params): The parameters to save to the checkpoint file. +- path_to_checkpoint: The path to the checkpoint directory. +- nominal_checkpoint: If True, the checkpoint is saved as a nominal checkpoint. Default is False. ++ path_to_checkpoint (str): The path to the checkpoint directory. + """ + + if parameters is None: +@@ -29,7 +26,7 @@ def save_checkpoint( + trainable_params, non_trainable_params = parameters + trainable_params = [param.SerializeToString() for param in trainable_params] + non_trainable_params = [param.SerializeToString() for param in non_trainable_params] +- _save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint), nominal_checkpoint) ++ _internal_save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint)) + + + def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: onnx.ModelProto) -> None: +@@ -40,4 +37,4 @@ def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: + model (onnx.ModelProto): The model to load the checkpoint to. 
+ """ + +- model.ParseFromString(_load_checkpoint_to_model(os.fspath(path_to_checkpoint), model.SerializeToString())) ++ model.ParseFromString(_internal_load_checkpoint_to_model(os.fspath(path_to_checkpoint), model.SerializeToString())) +diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py +index 9c7214f46..e0f65ed27 100644 +--- a/orttraining/orttraining/python/training/ort_triton/_codegen.py ++++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py +@@ -37,7 +37,7 @@ from ._ir import ( + from ._lowering import lower + from ._sorted_graph import SortedGraph + from ._sympy_utils import parse_shape, sympy_dot +-from ._utils import is_number, may_add_brackets ++from ._utils import may_add_brackets + + + class TritonCodegen(NodeVisitor): +@@ -318,7 +318,7 @@ class TritonCodegen(NodeVisitor): + if op_type == "Cast": + from_dtype = node.inputs[0].dtype.type + to_dtype = node.outputs[0].dtype.type +- if from_dtype == to_dtype or is_number(kwargs["i0"]): ++ if from_dtype == to_dtype: + op_type = "Identity" + elif to_dtype == np.bool_: + op_type = "CastBool" +diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py +index 95e6703be..c80e28f6f 100644 +--- a/orttraining/orttraining/python/training/ort_triton/_utils.py ++++ b/orttraining/orttraining/python/training/ort_triton/_utils.py +@@ -150,11 +150,3 @@ def next_power_of_2(n: int) -> int: + n |= n >> 16 + n += 1 + return n +- +- +-def is_number(name: str) -> bool: +- try: +- float(name) +- return True +- except ValueError: +- return name.startswith("float(") and name.endswith(")") +diff --git a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py +index a3681a136..ed9292358 100644 +--- a/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py ++++ b/orttraining/orttraining/python/training/ort_triton/kernel/_mm.py +@@ -11,7 +11,7 @@ from typing import Tuple + import torch + + from .._cache import ModuleCache, PyCodeCache +-from .._utils import gen_unique_name, next_power_of_2 ++from .._utils import next_power_of_2 + + _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 + +@@ -305,18 +305,18 @@ def _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name): + + + def _gen_mm_key(dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float) -> int: +- return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") ++ return hash(f"mm|{dtype}|{m}|{n}|{k}|{trans_a}|{trans_b}|{alpha}") % (10**8) + + + def _gen_mm_module( + dtype: torch.dtype, m: int, n: int, k: int, trans_a: bool, trans_b: bool, alpha: float + ) -> Tuple[str, ModuleType]: +- func_name = gen_unique_name("mm") ++ func_name = f"mm_{_gen_mm_key(dtype, m, n, k, trans_a, trans_b, alpha)}" + kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) + src_code = _MM_TEMPLATE.format(**kwargs) + if _DEBUG_MODE: + os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) +- with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: ++ with open(f"triton_debug/{func_name}.py", "w") as f: + f.write(src_code) + return func_name, PyCodeCache().load(src_code) + +@@ -333,7 +333,7 @@ def _gen_gemm_key( + alpha: float, + beta: float, + ) -> int: +- return 
hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") ++ return hash(f"gemm|{dtype}|{m}|{n}|{k}|{stride_cm}|{stride_cn}|{trans_a}|{trans_b}|{alpha}|{beta}") % (10**8) + + + def _gen_gemm_module( +@@ -348,7 +348,7 @@ def _gen_gemm_module( + alpha: float, + beta: float, + ) -> Tuple[str, ModuleType]: +- func_name = gen_unique_name("gemm") ++ func_name = f"gemm_{_gen_gemm_key(dtype, m, n, k, stride_cm, stride_cn, trans_a, trans_b, alpha, beta)}" + kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) + kwargs["stride_cm"] = stride_cm + kwargs["stride_cn"] = stride_cn +@@ -356,7 +356,7 @@ def _gen_gemm_module( + src_code = _GEMM_TEMPLATE.format(**kwargs) + if _DEBUG_MODE: + os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) +- with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: ++ with open(f"triton_debug/{func_name}.py", "w") as f: + f.write(src_code) + return func_name, PyCodeCache().load(src_code) + +@@ -364,13 +364,13 @@ def _gen_gemm_module( + def _gen_bmm_key( + dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float + ) -> int: +- return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") ++ return hash(f"bmm|{dtype}|{m}|{n}|{k}|{batch_a}|{batch_b}|{trans_a}|{trans_b}|{alpha}") % (10**8) + + + def _gen_bmm_module( + dtype: torch.dtype, m: int, n: int, k: int, batch_a: int, batch_b: int, trans_a: bool, trans_b: bool, alpha: float + ) -> Tuple[str, ModuleType]: +- func_name = gen_unique_name("bmm") ++ func_name = f"bmm_{_gen_bmm_key(dtype, m, n, k, batch_a, batch_b, trans_a, trans_b, alpha)}" + kwargs = _mm_configs(dtype, m, n, k, trans_a, trans_b, alpha, func_name) + batch = batch_a if batch_a >= batch_b else batch_b + kwargs["stride_aq"] = m * k if batch_a == batch else 0 +@@ -379,7 +379,7 @@ def _gen_bmm_module( + src_code = _BMM_TEMPLATE.format(**kwargs) + if _DEBUG_MODE: + os.makedirs(os.path.dirname("triton_debug/"), exist_ok=True) +- with open(f"triton_debug/{func_name}.py", "w", encoding="utf-8") as f: ++ with open(f"triton_debug/{func_name}.py", "w") as f: + f.write(src_code) + return func_name, PyCodeCache().load(src_code) + +diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +index f16abc712..1fe61750e 100644 +--- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py ++++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +@@ -67,7 +67,7 @@ class _ShapeCache: + + def _gen_key(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> int: + # pylint: disable=unused-argument +- return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") ++ return hash(f"{onnx_key}|{str(shapes).replace(' ', '')}") % (10**8) + + + def _gen_module(onnx_key: int, onnx_str: bytes, shapes: List[List[Union[int, str]]]) -> Tuple[str, ModuleType]: +diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py +index 539859a0d..df3b07878 100644 +--- a/orttraining/orttraining/python/training/ortmodule/options.py ++++ b/orttraining/orttraining/python/training/ortmodule/options.py +@@ -379,9 +379,6 @@ class _RuntimeOptions: + import triton # noqa: F401 + except ImportError: + pass +- self._logger.warning( +- "triton library missing. Please install triton with `pip install triton`. 
Triton feature will be off." +- ) + else: + self.enable_triton = True + +diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py +index 573ec85d7..f0b6b9c5f 100644 +--- a/orttraining/orttraining/test/python/orttraining_test_dort.py ++++ b/orttraining/orttraining/test/python/orttraining_test_dort.py +@@ -216,12 +216,7 @@ class TestTorchDynamoOrt(unittest.TestCase): + tensor_q = tensor_p.relu() + return tensor_q + +- # TODO: Set use_aot_autograd=False. In order to decompose torch +- # function calls to aten ops, we need to set +- # user_aot_autograd=True because there is no decomposition in DORT +- # anymore. A long-term fix will be brining # decomposition pass back +- # into DORT. +- local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) ++ local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) + + def run(fun, list_x): +diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +index 3d41c8678..910ddb34e 100644 +--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py ++++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +@@ -1047,26 +1047,3 @@ def test_custom_loss_function(): + + with tempfile.TemporaryDirectory() as temp_dir: + artifacts.generate_artifacts(onnx_model, loss=CustomLossBlock(), artifact_directory=temp_dir) +- +- +-def test_save_nominal_checkpoint(): +- device = "cpu" +- batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 +- _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size) +- +- with tempfile.TemporaryDirectory() as temp_dir: +- artifacts.generate_artifacts( +- base_model, +- requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"], +- loss=artifacts.LossType.CrossEntropyLoss, +- optimizer=artifacts.OptimType.AdamW, +- artifact_directory=temp_dir, +- nominal_checkpoint=True, +- ) +- +- assert os.path.exists(os.path.join(temp_dir, "checkpoint")) +- assert os.path.exists(os.path.join(temp_dir, "nominal_checkpoint")) +- assert ( +- os.stat(os.path.join(temp_dir, "checkpoint")).st_size +- > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size +- ) +diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py +index ce251b984..34d8c24cc 100644 +--- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py ++++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py +@@ -6,7 +6,6 @@ from __future__ import annotations + import os + import pathlib + import tempfile +-from dataclasses import dataclass + + import numpy as np + import onnx +@@ -29,22 +28,11 @@ class SimpleModelWithCrossEntropyLoss(onnxblock.TrainingBlock): + return self.loss(output_name) + + +-@dataclass +-class Artifacts: +- checkpoint_file_path: str +- training_model_file_path: str +- eval_model_file_path: str +- optimizer_model_file_path: str +- pt_model: torch.nn.Module +- nominal_checkpoint_file_path: str | None = None +- +- + def _create_training_artifacts( + artifact_directory: str | os.PathLike, + requires_grad: list[str] | None = None, + frozen_params: list[str] | None = None, + optimizer_type=artifacts.OptimType.AdamW, 
+- nominal_checkpoint: bool = False, + ): + device = "cpu" + batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 +@@ -63,20 +51,14 @@ def _create_training_artifacts( + requires_grad=requires_grad, + frozen_params=frozen_params, + artifact_directory=artifact_directory, +- nominal_checkpoint=nominal_checkpoint, + ) + + training_model_file = os.path.join(artifact_directory, "training_model.onnx") + eval_model_file = os.path.join(artifact_directory, "eval_model.onnx") + optimizer_model_file = os.path.join(artifact_directory, "optimizer_model.onnx") + checkpoint_file = os.path.join(artifact_directory, "checkpoint") +- nominal_checkpoint_file = None +- if nominal_checkpoint: +- nominal_checkpoint_file = os.path.join(artifact_directory, "nominal_checkpoint") + +- return Artifacts( +- checkpoint_file, training_model_file, eval_model_file, optimizer_model_file, pt_model, nominal_checkpoint_file +- ) ++ return checkpoint_file, training_model_file, eval_model_file, optimizer_model_file, pt_model + + + def test_train_step(): +@@ -85,16 +67,22 @@ def test_train_step(): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ _, ++ pt_model, ++ ) = _create_training_artifacts(temp_dir) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module. +- model = Module(artifacts.training_model_file_path, state) ++ model = Module(training_model_file_path, state) + model.train() + ort_loss = model(inputs, labels) + + # Calculate loss using pytorch model to compare it with Module's output. +- pt_outputs = artifacts.pt_model(torch.from_numpy(inputs)) ++ pt_outputs = pt_model(torch.from_numpy(inputs)) + loss_fn = torch.nn.CrossEntropyLoss() + pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) + +@@ -107,11 +95,17 @@ def test_eval_step(): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ eval_model_file_path, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module. +- model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) ++ model = Module(training_model_file_path, state, eval_model_file_path) + model.train() + model(inputs, labels) + +@@ -127,12 +121,18 @@ def test_optimizer_step(optimizer_type): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ optimizer_model_file_path, ++ _, ++ ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module and Optimizer. 
+- model = Module(artifacts.training_model_file_path, state) +- optimizer = Optimizer(artifacts.optimizer_model_file_path, model) ++ model = Module(training_model_file_path, state) ++ optimizer = Optimizer(optimizer_model_file_path, model) + + model.train() + old_flatten_params = model.get_contiguous_parameters() +@@ -147,12 +147,18 @@ def test_optimizer_step(optimizer_type): + @pytest.mark.parametrize("optimizer_type", [artifacts.OptimType.SGD, artifacts.OptimType.AdamW]) + def test_get_and_set_lr(optimizer_type): + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ optimizer_model_file_path, ++ _, ++ ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module and Optimizer. +- model = Module(artifacts.training_model_file_path, state) +- optimizer = Optimizer(artifacts.optimizer_model_file_path, model) ++ model = Module(training_model_file_path, state) ++ optimizer = Optimizer(optimizer_model_file_path, model) + + # Test get and set learning rate. + lr = optimizer.get_learning_rate() +@@ -172,11 +178,18 @@ def test_scheduler_step(optimizer_type): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ optimizer_model_file_path, ++ _, ++ ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) ++ # Create Checkpoint State. ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module and Optimizer. +- model = Module(artifacts.training_model_file_path, state) +- optimizer = Optimizer(artifacts.optimizer_model_file_path, model) ++ model = Module(training_model_file_path, state) ++ optimizer = Optimizer(optimizer_model_file_path, model) + scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2) + + # Test get and set learning rate. +@@ -199,11 +212,17 @@ def test_training_module_checkpoint(): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Training Module and Training Optimizer. +- model = Module(artifacts.training_model_file_path, state) ++ model = Module(training_model_file_path, state) + + model.train() + model(inputs, labels) +@@ -218,7 +237,7 @@ def test_training_module_checkpoint(): + + # Assert the checkpoint parameters remain after saving. 
+ new_state = CheckpointState.load_checkpoint(checkpoint_save_path) +- new_model = Module(artifacts.training_model_file_path, new_state) ++ new_model = Module(training_model_file_path, new_state) + + new_params = new_model.get_contiguous_parameters() + +@@ -233,17 +252,23 @@ def test_copy_buffer_to_parameters(trainable_only, optimizer_type): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts( ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ optimizer_model_file_path, ++ _, ++ ) = _create_training_artifacts( + temp_dir, + requires_grad=["fc2.weight", "fc2.bias"], + frozen_params=["fc1.weight", "fc1.bias"], + optimizer_type=optimizer_type, + ) +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + + # Create a Module and Optimizer. +- model = Module(artifacts.training_model_file_path, state) +- optimizer = Optimizer(artifacts.optimizer_model_file_path, model) ++ model = Module(training_model_file_path, state) ++ optimizer = Optimizer(optimizer_model_file_path, model) + + # Keep a copy of the parameters. + old_output_params = model.get_contiguous_parameters(trainable_only=trainable_only) +@@ -270,13 +295,19 @@ def test_copy_buffer_to_parameters(trainable_only, optimizer_type): + + def test_export_model_for_inferencing(): + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ eval_model_file_path, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + + # Create a Module. +- model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) ++ model = Module(training_model_file_path, state, eval_model_file_path) + + # Export inference model + inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx") +@@ -286,12 +317,18 @@ def test_export_model_for_inferencing(): + + def test_cuda_execution_provider(): + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module. +- model = Module(artifacts.training_model_file_path, state, device="cuda") ++ model = Module(training_model_file_path, state, device="cuda") + params = model.get_contiguous_parameters() + + # Check if parameters are moved to cuda. +@@ -304,13 +341,19 @@ def test_cuda_execution_provider(): + ) + def test_add_get_property(property_value): + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + + # Create a Module. 
+- _ = Module(artifacts.training_model_file_path, state) ++ _ = Module(training_model_file_path, state) + + # Float values in python are double precision. + # Convert to float32 to match the type of the property. +@@ -324,8 +367,8 @@ def test_add_get_property(property_value): + assert state.properties["property"] == property_value + assert len(state.properties) == 1 + +- CheckpointState.save_checkpoint(state, artifacts.checkpoint_file_path) +- new_state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ CheckpointState.save_checkpoint(state, checkpoint_file_path) ++ new_state = CheckpointState.load_checkpoint(checkpoint_file_path) + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 +@@ -333,15 +376,21 @@ def test_add_get_property(property_value): + + def test_get_input_output_names(): + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ eval_model_file_path, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + + # Create a Module. +- model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) ++ model = Module(training_model_file_path, state, eval_model_file_path) + +- training_model = onnx.load(artifacts.training_model_file_path) ++ training_model = onnx.load(training_model_file_path) + assert model.input_names() == [input.name for input in training_model.graph.input][:2] + assert model.output_names() == [output.name for output in training_model.graph.output][:1] + +@@ -469,18 +518,23 @@ def test_train_step_with_ort_values(): + labels = OrtValue.ortvalue_from_numpy(labels_np) + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) +- ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ _, ++ _, ++ pt_model, ++ ) = _create_training_artifacts(temp_dir) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module. +- model = Module(artifacts.training_model_file_path, state) ++ model = Module(training_model_file_path, state) + model.train() + ort_loss = model(inputs, labels) + assert isinstance(ort_loss, OrtValue) + + # Calculate loss using pytorch model to compare it with Module's output. +- pt_outputs = artifacts.pt_model(torch.from_numpy(inputs_np)) ++ pt_outputs = pt_model(torch.from_numpy(inputs_np)) + loss_fn = torch.nn.CrossEntropyLoss() + pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels_np).long()) + +@@ -495,11 +549,17 @@ def test_eval_step_with_ort_values(): + labels = OrtValue.ortvalue_from_numpy(labels_np) + + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir) ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ eval_model_file_path, ++ _, ++ _, ++ ) = _create_training_artifacts(temp_dir) + # Create Checkpoint State. +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + # Create a Module. 
+- model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) ++ model = Module(training_model_file_path, state, eval_model_file_path) + model.train() + model(inputs, labels) + +@@ -512,20 +572,26 @@ def test_eval_step_with_ort_values(): + @pytest.mark.parametrize("device", ["cpu", "cuda"]) + def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts( ++ ( ++ checkpoint_file_path, ++ training_model_file_path, ++ eval_model_file_path, ++ _, ++ pt_model, ++ ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + +- state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) ++ state = CheckpointState.load_checkpoint(checkpoint_file_path) + +- model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path, device=device) ++ model = Module(training_model_file_path, state, eval_model_file_path, device=device) + +- state_dict = artifacts.pt_model.state_dict() ++ state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + +- for name, pt_param in artifacts.pt_model.named_parameters(): ++ for name, pt_param in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) +@@ -546,7 +612,7 @@ def test_get_and_set_parameter_values(device): + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None +- for name, _ in artifacts.pt_model.named_parameters(): ++ for name, _ in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: +@@ -558,111 +624,3 @@ def test_get_and_set_parameter_values(device): + + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) +- +- +-def test_model_construction_with_nominal_checkpoint(): +- with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir, nominal_checkpoint=True) +- +- nominal_state = CheckpointState.load_checkpoint(artifacts.nominal_checkpoint_file_path) +- model_with_nominal_state = Module( +- artifacts.training_model_file_path, nominal_state, artifacts.eval_model_file_path +- ) +- optimizer_with_nominal_state = Optimizer(artifacts.optimizer_model_file_path, model_with_nominal_state) +- +- inputs = torch.randn(64, 784).numpy() +- labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() +- +- err_msg = "Please load the parameter states first" +- +- # Accessing the checkpoint parameter raises +- state_dict = artifacts.pt_model.state_dict() +- for param_name in state_dict: +- assert param_name in nominal_state.parameters +- with pytest.raises(Exception) as exc_info: +- _ = nominal_state.parameters["fc1.weight"] +- +- assert err_msg in str(exc_info.value) +- +- err_msg = "Please load all the parameter states first" +- with pytest.raises(Exception) as exc_info: +- nominal_state.parameters["fc1.weight"] = np.ones((10, 10), dtype=np.float32) +- +- assert err_msg in str(exc_info.value) +- +- err_msg = "Please load the model parameters first." 
+- +- # Getting contiguous parameters raises +- with pytest.raises(Exception) as exc_info: +- _ = model_with_nominal_state.get_contiguous_parameters() +- +- assert err_msg in str(exc_info.value) +- +- # Train step raises +- with pytest.raises(Exception) as exc_info: +- model_with_nominal_state.train() +- model_with_nominal_state(inputs, labels) +- +- assert err_msg in str(exc_info.value) +- +- # Optimizer step raises +- with pytest.raises(Exception) as exc_info: +- optimizer_with_nominal_state.step() +- +- assert err_msg in str(exc_info.value) +- +- # Eval step raises +- with pytest.raises(Exception) as exc_info: +- model_with_nominal_state.eval() +- model_with_nominal_state(inputs, labels) +- +- assert err_msg in str(exc_info.value) +- +- # Get parameters size does not raise +- params_size = model_with_nominal_state.get_parameters_size() +- assert params_size > 0 +- +- +-def test_train_with_nominal_checkpoint(): +- with tempfile.TemporaryDirectory() as temp_dir: +- artifacts = _create_training_artifacts(temp_dir, nominal_checkpoint=True) +- +- # Create Checkpoint State with nominal checkpoint as well as the complete checkpoint. +- complete_state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) +- nominal_state = CheckpointState.load_checkpoint(artifacts.nominal_checkpoint_file_path) +- +- # Create a Module with both complete and nominal checkpoint states. +- model_with_complete_state = Module(artifacts.training_model_file_path, complete_state) +- model_with_nominal_state = Module(artifacts.training_model_file_path, nominal_state) +- +- optimizer_with_complete_state = Optimizer(artifacts.optimizer_model_file_path, model_with_complete_state) +- optimizer_with_nominal_state = Optimizer(artifacts.optimizer_model_file_path, model_with_nominal_state) +- +- parameter_buffer = model_with_complete_state.get_contiguous_parameters() +- model_with_nominal_state.copy_buffer_to_parameters(parameter_buffer, trainable_only=False) +- +- model_with_complete_state.train() +- model_with_nominal_state.train() +- +- # Generate random data for testing. +- inputs = torch.randn(64, 784).numpy() +- labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() +- +- ort_loss_1 = model_with_complete_state(inputs, labels) +- ort_loss_2 = model_with_nominal_state(inputs, labels) +- +- # Calculate loss using pytorch model to compare it with both the Modules' output. 
+- pt_outputs = artifacts.pt_model(torch.from_numpy(inputs)) +- loss_fn = torch.nn.CrossEntropyLoss() +- pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) +- +- assert np.allclose(ort_loss_1, ort_loss_2) +- assert np.allclose(ort_loss_1, pt_loss.detach().numpy()) +- +- optimizer_with_complete_state.step() +- optimizer_with_nominal_state.step() +- +- new_params_1 = model_with_complete_state.get_contiguous_parameters() +- new_params_2 = model_with_nominal_state.get_contiguous_parameters() +- +- assert np.allclose(new_params_1.numpy(), new_params_2.numpy()) +diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +index 922f5c696..0c381d70c 100644 +--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py ++++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +@@ -12,7 +12,6 @@ import onnx + import pytest + import torch + from onnx import TensorProto, helper +-from packaging.version import Version + from torch._C import _from_dlpack + from torch.utils.dlpack import to_dlpack + +@@ -843,32 +842,6 @@ def test_slice_scel_module(dtype, has_sum): + _run_module_test(NeuralNetSliceScel, dtype, _gen_inputs, 2) + + +-@pytest.mark.skipif( +- Version(torch.__version__) < Version("2.1"), reason="PyTorch has scaled_dot_product_attention since 2.1." +-) +-def test_scaled_dot_product_attention_module(): +- class NeuralNetScaledDotProductAttention(torch.nn.Module): +- def __init__(self): +- super().__init__() +- self.linear1 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) +- self.linear2 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) +- self.linear3 = torch.nn.Linear(64, 64, bias=False, dtype=torch.float16) +- +- def forward(self, q, k, v): +- return torch.nn.functional.scaled_dot_product_attention( +- self.linear1(q), self.linear2(k), self.linear3(v) +- ).to(torch.float16) +- +- def _gen_inputs(dtype): +- return [ +- (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), +- (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), +- (torch.rand(32, 8, 128, 64) * 0.01).to(dtype=torch.float16, device=DEVICE), +- ] +- +- _run_module_test(NeuralNetScaledDotProductAttention, torch.float16, _gen_inputs, 3) +- +- + @pytest.mark.parametrize("dtype", [torch.float32, torch.float16]) + @pytest.mark.parametrize("input_shapes", [([128, 64], [64, 64]), ([16, 64, 128], [16, 128, 64])]) + def test_matmul_tunable_op(dtype, input_shapes): +diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +index 5c53addb8..1369c9c69 100644 +--- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc ++++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +@@ -95,8 +95,7 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { + // Call Save APIs. + PathString checkpoint_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))}; +- ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path, +- false /* nominal checkpoint */)); ++ ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path)); + + /// Phase 3 - Run load checkpoint APIs. + /// And check the result comparable with initial parameter values. 
+@@ -194,8 +193,7 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpointThenLoadFromBufferCPU) { + // Call Save APIs. + PathString checkpoint_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))}; +- ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path, +- false /* nominal checkpoint */)); ++ ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path)); + + /// Phase 3 - Run load checkpoint APIs. + /// And check the result comparable with initial parameter values. +@@ -437,37 +435,4 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) { + std::string restored_s_data = restored_property_bag.GetProperty(s_property_name); + ASSERT_EQ(s_data, restored_s_data); + } +- +-/** +- * Loads a nominal checkpoint. Checks for nominal flag, and that the state is empty. +- * Saves the checkpoint, and loads it again. Checks for nominal flag, and that the state is empty. +- */ +-TEST(CheckpointApiTest, LoadAndSaveNominalCheckpoint) { +- PathString nominal_checkpoint_path{ORT_TSTR("testdata/training_api/nominal_checkpoint")}; +- +- CheckpointState checkpoint_state; +- ASSERT_STATUS_OK(LoadCheckpoint(nominal_checkpoint_path, checkpoint_state)); +- ASSERT_TRUE(checkpoint_state.module_checkpoint_state.is_nominal_state); +- for (auto& [name, param] : checkpoint_state.module_checkpoint_state.named_parameters) { +- ASSERT_TRUE(param->Data().IsTensor()); +- // An empty tensor will have size 1. +- ASSERT_EQ(param->Data().Get().Shape().Size(), 1); +- } +- +- // Remove the temporary directory if it already exists. +- auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); +- TemporaryDirectory tmp_dir{ckpt_test_root_dir}; +- PathString checkpoint_path{ +- ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("nominal_checkpoint_2"))}; +- ASSERT_STATUS_OK(SaveCheckpoint(checkpoint_state, checkpoint_path, false)); +- +- CheckpointState checkpoint_state_2; +- ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_path, checkpoint_state_2)); +- ASSERT_TRUE(checkpoint_state_2.module_checkpoint_state.is_nominal_state); +- for (auto& [name, param] : checkpoint_state_2.module_checkpoint_state.named_parameters) { +- ASSERT_TRUE(param->Data().IsTensor()); +- // An empty tensor will have size 1. 
+- ASSERT_EQ(param->Data().Get().Shape().Size(), 1); +- } +-} + } // namespace onnxruntime::training::test +diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc +index e2232687d..2170f7957 100644 +--- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc ++++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc +@@ -537,167 +537,6 @@ TEST(TrainingApiTest, OptimStep) { + } + } + +-TEST(TrainingApiTest, ModuleAndOptimizerWithNominalState) { +- auto model_uri = MODEL_FOLDER "training_model.onnx"; +- auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; +- auto optim_uri = MODEL_FOLDER "adamw.onnx"; +- +- onnxruntime::training::api::CheckpointState complete_state; +- onnxruntime::training::api::CheckpointState nominal_state; +- auto complete_checkpoint_path = MODEL_FOLDER "checkpoint.ckpt"; +- auto nominal_checkpoint_path = MODEL_FOLDER "nominal_checkpoint"; +- ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(complete_checkpoint_path, complete_state)); +- ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(nominal_checkpoint_path, nominal_state)); +- +- ASSERT_FALSE(complete_state.module_checkpoint_state.is_nominal_state); +- ASSERT_TRUE(nominal_state.module_checkpoint_state.is_nominal_state); +- +- onnxruntime::SessionOptions session_option; +- std::unique_ptr env; +- std::vector> providers; +-#if defined(USE_CUDA) +- providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); +-#endif +- ASSERT_STATUS_OK(Environment::Create(nullptr, env)); +- +- auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), +- std::optional(onnxruntime::ToUTF8String(eval_model_uri)), +- std::optional(onnxruntime::ToUTF8String(optim_uri))); +- auto model_with_complete_state = std::make_unique( +- model_identifier, &complete_state, session_option, +- *env, providers); +- auto model_with_nominal_state = std::make_unique( +- model_identifier, &nominal_state, session_option, +- *env, providers); +- auto optim_with_complete_state = std::make_unique( +- model_identifier, &complete_state, session_option, +- *env, providers); +- auto optim_with_nominal_state = std::make_unique( +- model_identifier, &nominal_state, session_option, +- *env, providers); +- +- // Before running the test, copy all the parameters to the nominal module. 
+- ASSERT_EQ(model_with_complete_state->GetParametersSize(), model_with_nominal_state->GetParametersSize()); +- int64_t params_size = static_cast(model_with_nominal_state->GetParametersSize()); +- OrtValue params_buffer; +- Tensor::InitOrtValue(DataTypeImpl::GetType(), {params_size}, +- onnxruntime::test::TestCPUExecutionProvider()->CreatePreferredAllocators()[0], +- params_buffer); +- ASSERT_STATUS_OK(model_with_complete_state->CopyParametersToBuffer(params_buffer, false)); +- ASSERT_STATUS_OK(model_with_nominal_state->CopyBufferToParameters(params_buffer, false)); +- +- ASSERT_STATUS_OK(optim_with_nominal_state->ConstructOptimizerStateAndInputs()); +- +- OrtValue input, target; +- GenerateRandomInput(std::array{2, 784}, input); +- target = onnxruntime::test::CreateInputOrtValueOnCPU( +- std::array{2}, std::vector(2, 1)); +- auto data_loader = std::vector>(4, std::vector{input, target}); +- +- for (auto it = data_loader.begin(); it != data_loader.end(); ++it) { +- std::vector& inputs = *it; +- std::vector complete_fetches; +- std::vector nominal_fetches; +- ASSERT_STATUS_OK(model_with_complete_state->TrainStep(inputs, complete_fetches)); +- ASSERT_STATUS_OK(model_with_nominal_state->TrainStep(inputs, nominal_fetches)); +- +- ASSERT_GT(complete_fetches.size(), 0); +- for (size_t i = 0; i < complete_fetches.size(); ++i) { +- ASSERT_TRUE(complete_fetches[i].IsTensor()); +- ASSERT_TRUE(nominal_fetches[i].IsTensor()); +- const Tensor& complete_tensor = complete_fetches[i].Get(); +- const Tensor& nominal_tensor = nominal_fetches[i].Get(); +- ASSERT_EQ(complete_tensor.Shape(), nominal_tensor.Shape()); +- ASSERT_EQ(complete_tensor.DataType(), nominal_tensor.DataType()); +- +- std::vector complete_fetches_vec; +- std::vector nominal_fetches_vec; +-#if defined(USE_CUDA) +- CudaOrtValueToCpuVec(complete_fetches[i], complete_fetches_vec); +- CudaOrtValueToCpuVec(nominal_fetches[i], nominal_fetches_vec); +-#else +- CpuOrtValueToVec(complete_fetches[i], complete_fetches_vec); +- CpuOrtValueToVec(nominal_fetches[i], nominal_fetches_vec); +-#endif +- +- for (size_t j = 0; j < complete_fetches_vec.size(); ++j) { +- ASSERT_EQ(complete_fetches_vec[j], nominal_fetches_vec[j]); +- } +- } +- +- ASSERT_STATUS_OK(optim_with_complete_state->Step()); +- ASSERT_STATUS_OK(optim_with_nominal_state->Step()); +- +- for (auto& [name, param] : model_with_complete_state->NamedParameters()) { +- ASSERT_TRUE(param->Data().IsTensor()); +- ASSERT_TRUE(param->Gradient().IsTensor()); +- ASSERT_TRUE(model_with_nominal_state->NamedParameters().at(name)->Data().IsTensor()); +- ASSERT_TRUE(model_with_nominal_state->NamedParameters().at(name)->Gradient().IsTensor()); +- +- const Tensor& complete_data = param->Data().Get(); +- const Tensor& complete_grad = param->Gradient().Get(); +- const Tensor& nominal_data = model_with_nominal_state->NamedParameters().at(name)->Data().Get(); +- const Tensor& nominal_grad = model_with_nominal_state->NamedParameters().at(name)->Gradient().Get(); +- +- ASSERT_EQ(complete_data.Shape(), nominal_data.Shape()); +- ASSERT_EQ(complete_data.DataType(), nominal_data.DataType()); +- ASSERT_EQ(complete_grad.Shape(), nominal_grad.Shape()); +- ASSERT_EQ(complete_grad.DataType(), nominal_grad.DataType()); +- +- std::vector complete_data_vec; +- std::vector complete_grad_vec; +- std::vector nominal_data_vec; +- std::vector nominal_grad_vec; +- +-#if defined(USE_CUDA) +- CudaOrtValueToCpuVec(param->Data(), complete_data_vec); +- CudaOrtValueToCpuVec(param->Gradient(), complete_grad_vec); +- 
CudaOrtValueToCpuVec(model_with_nominal_state->NamedParameters().at(name)->Data(), nominal_data_vec); +- CudaOrtValueToCpuVec(model_with_nominal_state->NamedParameters().at(name)->Gradient(), nominal_grad_vec); +-#else +- CpuOrtValueToVec(param->Data(), complete_data_vec); +- CpuOrtValueToVec(param->Gradient(), complete_grad_vec); +- CpuOrtValueToVec(model_with_nominal_state->NamedParameters().at(name)->Data(), nominal_data_vec); +- CpuOrtValueToVec(model_with_nominal_state->NamedParameters().at(name)->Gradient(), nominal_grad_vec); +-#endif +- +- for (size_t j = 0; j < complete_data_vec.size(); ++j) { +- ASSERT_EQ(complete_data_vec[j], nominal_data_vec[j]); +- ASSERT_EQ(complete_grad_vec[j], nominal_grad_vec[j]); +- } +- } +- +- std::vector complete_eval_fetches; +- std::vector nominal_eval_fetches; +- ASSERT_STATUS_OK(model_with_complete_state->EvalStep(inputs, complete_eval_fetches)); +- ASSERT_STATUS_OK(model_with_nominal_state->EvalStep(inputs, nominal_eval_fetches)); +- +- ASSERT_GT(complete_eval_fetches.size(), 0); +- for (size_t i = 0; i < complete_eval_fetches.size(); ++i) { +- ASSERT_TRUE(complete_eval_fetches[i].IsTensor()); +- ASSERT_TRUE(nominal_eval_fetches[i].IsTensor()); +- const Tensor& complete_tensor = complete_eval_fetches[i].Get(); +- const Tensor& nominal_tensor = nominal_eval_fetches[i].Get(); +- ASSERT_EQ(complete_tensor.Shape(), nominal_tensor.Shape()); +- ASSERT_EQ(complete_tensor.DataType(), nominal_tensor.DataType()); +- +- std::vector complete_eval_fetches_vec; +- std::vector nominal_eval_fetches_vec; +-#if defined(USE_CUDA) +- CudaOrtValueToCpuVec(complete_eval_fetches[i], complete_eval_fetches_vec); +- CudaOrtValueToCpuVec(nominal_eval_fetches[i], nominal_eval_fetches_vec); +-#else +- CpuOrtValueToVec(complete_eval_fetches[i], complete_eval_fetches_vec); +- CpuOrtValueToVec(nominal_eval_fetches[i], nominal_eval_fetches_vec); +-#endif +- +- for (size_t j = 0; j < complete_eval_fetches_vec.size(); ++j) { +- ASSERT_EQ(complete_eval_fetches_vec[j], nominal_eval_fetches_vec[j]); +- } +- } +- } +-} +- + } // namespace test + } // namespace training + } // namespace onnxruntime +diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +index 8f25e1e4c..e46952d87 100644 +--- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc ++++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +@@ -420,79 +420,4 @@ TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { + } + #endif + +-TEST(TrainingCApiTest, ModuleAndOptimizerWithNominalState) { +- auto training_model_uri = MODEL_FOLDER "training_model.onnx"; +- auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; +- auto optimizer_model_uri = MODEL_FOLDER "adamw.onnx"; +- +- Ort::Env env; +- Ort::SessionOptions session_options_for_complete_state; +- Ort::SessionOptions session_options_for_nominal_state; +- Ort::CheckpointState complete_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); +- Ort::CheckpointState nominal_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "nominal_checkpoint"); +- +-#ifdef USE_CUDA +- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_for_complete_state, 0)); +- Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_for_nominal_state, 0)); +-#endif +- +- Ort::TrainingSession complete_training_session = Ort::TrainingSession(env, session_options_for_complete_state, 
complete_state, +- training_model_uri, eval_model_uri, optimizer_model_uri); +- Ort::TrainingSession nominal_training_session = Ort::TrainingSession(env, session_options_for_nominal_state, nominal_state, +- training_model_uri, eval_model_uri, +- optimizer_model_uri); +- +- Ort::Value params_buffer = complete_training_session.ToBuffer(false); +- nominal_training_session.FromBuffer(params_buffer); +- +- for (size_t i = 0; i < 4U; ++i) { +- std::vector x(2 * 784); +- std::vector x_shape{2, 784}; +- GenerateRandomData(x); +- +- std::vector labels{0, 8}; +- std::vector labels_shape{2}; +- +- Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); +- std::vector ort_inputs; +- ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, x.data(), +- x.size() * sizeof(float), +- x_shape.data(), x_shape.size(), +- ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); +- ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels.data(), +- labels.size() * sizeof(int32_t), +- labels_shape.data(), labels_shape.size(), +- ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); +- +- std::vector complete_fetches = complete_training_session.TrainStep(ort_inputs); +- std::vector nominal_fetches = nominal_training_session.TrainStep(ort_inputs); +- +- ASSERT_EQ(complete_fetches.size(), nominal_fetches.size()); +- ASSERT_GT(complete_fetches.size(), 0U); +- for (size_t j = 0; j < complete_fetches.size(); ++j) { +- ASSERT_TRUE(complete_fetches[j].IsTensor()); +- ASSERT_TRUE(nominal_fetches[j].IsTensor()); +- +- auto complete_tensor_info = complete_fetches[j].GetTensorTypeAndShapeInfo(); +- auto nominal_tensor_info = nominal_fetches[j].GetTensorTypeAndShapeInfo(); +- +- ASSERT_EQ(complete_tensor_info.GetShape(), nominal_tensor_info.GetShape()); +- ASSERT_EQ(complete_tensor_info.GetElementType(), nominal_tensor_info.GetElementType()); +- +- gsl::span complete_data = gsl::span(complete_fetches[j].GetTensorMutableData(), +- complete_tensor_info.GetElementCount()); +- gsl::span nominal_data = gsl::span(nominal_fetches[j].GetTensorMutableData(), +- nominal_tensor_info.GetElementCount()); +- +- ASSERT_EQ(complete_data, nominal_data); +- } +- +- complete_training_session.OptimizerStep(); +- nominal_training_session.OptimizerStep(); +- +- complete_training_session.LazyResetGrad(); +- nominal_training_session.LazyResetGrad(); +- } +-} +- + } // namespace onnxruntime::training::test +diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc +index 720bdd7e6..dbcef78c3 100644 +--- a/orttraining/orttraining/training_api/checkpoint.cc ++++ b/orttraining/orttraining/training_api/checkpoint.cc +@@ -174,7 +174,7 @@ Status ToFile(const PathString& checkpoint_path, flatbuffers::FlatBufferBuilder& + Status FromTensorProtos( + gsl::span trainable_tensor_protos, + gsl::span non_trainable_tensor_protos, +- const PathString& checkpoint_path, const bool nominal_checkpoint) { ++ const PathString& checkpoint_path) { + const auto check_unique = [](gsl::span tensor_protos, + InlinedHashSet& unique_names) { + for (const auto& tensor_proto : tensor_protos) { +@@ -230,7 +230,6 @@ Status FromTensorProtos( + fbs::ModuleStateBuilder module_state_builder(builder); + module_state_builder.add_requires_grad_params(fbs_trainable_tensors); + module_state_builder.add_frozen_params(fbs_non_trainable_tensors); +- module_state_builder.add_is_nominal_state(nominal_checkpoint); + flatbuffers::Offset fbs_module_state = module_state_builder.Finish(); + + fbs::CheckpointBuilder 
checkpoint_builder(builder); +@@ -295,7 +294,6 @@ Status FromModuleState(const ModuleCheckpointState& module_state, + fbs::ModuleStateBuilder module_state_builder(builder); + module_state_builder.add_requires_grad_params(fbs_trainable_tensors); + module_state_builder.add_frozen_params(fbs_non_trainable_tensors); +- module_state_builder.add_is_nominal_state(module_state.is_nominal_state); + fbs_module_state = module_state_builder.Finish(); + + return Status::OK(); +@@ -515,8 +513,6 @@ Status ToModuleState( + module_state.named_parameters.insert({name, param}); + } + +- module_state.is_nominal_state = fbs_module_state.is_nominal_state(); +- + return Status::OK(); + } + +@@ -650,10 +646,6 @@ Status ToModelProto(gsl::span checkpoint_bytes, + ORT_RETURN_IF_NOT(frozen_params, + "Checkpoint is invalid. Expected: Valid non-trainable params flatbuffer. Actual: nullptr."); + +- ORT_RETURN_IF(module_state->is_nominal_state(), +- "Cannot load a nominal checkpoint to a model proto. " +- "Expected: Complete checkpoint. Actual: Nominal checkpoint."); +- + InlinedHashMap param_tensor_protos; + param_tensor_protos.reserve( + static_cast(requires_grad_params->size()) + static_cast(frozen_params->size())); +@@ -725,33 +717,14 @@ Status ToCheckpointState(gsl::span checkpoint_bytes, CheckpointSt + + } // namespace load + +-#if !defined(ORT_MINIMAL_BUILD) +-InlinedVector Nominalize(gsl::span tensor_protos) { +- InlinedVector nominal_tensor_protos; +- nominal_tensor_protos.reserve(tensor_protos.size()); +- for (const auto& tensor_proto : tensor_protos) { +- ONNX_NAMESPACE::TensorProto nominal_tensor_proto; +- nominal_tensor_proto.set_name(tensor_proto.name()); +- nominal_tensor_proto.set_data_type(tensor_proto.data_type()); +- nominal_tensor_protos.push_back(nominal_tensor_proto); +- } +- +- return nominal_tensor_protos; +-} +-#endif +- + } // namespace + + #if !defined(ORT_MINIMAL_BUILD) + Status SaveCheckpoint(gsl::span trainable_tensor_protos, + gsl::span non_trainable_tensor_protos, +- const PathString& checkpoint_path, const bool nominal_checkpoint) { ++ const PathString& checkpoint_path) { + ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines"); +- return nominal_checkpoint +- ? save::FromTensorProtos(Nominalize(trainable_tensor_protos), Nominalize(non_trainable_tensor_protos), +- checkpoint_path, nominal_checkpoint) +- : save::FromTensorProtos(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path, +- nominal_checkpoint); ++ return save::FromTensorProtos(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path); + } + #endif + +diff --git a/orttraining/orttraining/training_api/checkpoint.h b/orttraining/orttraining/training_api/checkpoint.h +index 95d3820a3..5d8554662 100644 +--- a/orttraining/orttraining/training_api/checkpoint.h ++++ b/orttraining/orttraining/training_api/checkpoint.h +@@ -49,12 +49,11 @@ Status SaveCheckpoint(const CheckpointState& state, const PathString& checkpoint + * @param trainable_tensor_protos trainable parameters in TensorProto format. + * @param non_trainable_tensor_protos non-trainable parameters in TensorProto format. + * @param checkpoint_path file where checkpoint is saved. +- * @param nominal_checkpoint flag indicating whether to save the complete checkpoint or the nominal checkpoint. 
+ * @return Status + */ + Status SaveCheckpoint(gsl::span trainable_tensor_protos, + gsl::span non_trainable_tensor_protos, +- const PathString& checkpoint_path, const bool nominal_checkpoint); ++ const PathString& checkpoint_path); + #endif + + /** +diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +index ed6d151a5..0e8544a76 100644 +--- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h ++++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +@@ -132,7 +132,6 @@ struct OrtTrainingApi { + * \note Note that the training session created with a checkpoint state uses this state to store the entire + * training state (including model parameters, its gradients, the optimizer states and the properties). + * As a result, it is required that the checkpoint state outlive the lifetime of the training session. +- * \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint. + * + * \param[in] checkpoint_path Path to the checkpoint file + * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. +@@ -464,12 +463,10 @@ struct OrtTrainingApi { + * + * The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call, + * with matching setting for trainable_only argument. All the target parameters must be of the same +- * datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer ++ * datatype. This is a complementary function to OrtTrainingApi::CopyBufferToParameters + * and can be used to load updated buffer values onto the training state. + * Parameter ordering is preserved. + * User is responsible for allocating and freeing the resources used by the parameters_buffer. +- * In case the training session was created with a nominal checkpoint, invoking this function is required +- * to load the updated parameters onto the checkpoint to complete it. + * + * \param[in] sess The `this` pointer to the training session. + * \param[in] trainable_only Whether to skip non-trainable parameters +diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +index e78c16136..218bef524 100644 +--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h ++++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +@@ -58,8 +58,6 @@ using Property = std::variant; + * training state (including model parameters, its gradients, the optimizer states and the properties). + * The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required + * that the checkpoint state outlive the lifetime of the training session. +- * \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint +- * state depending on the version provided while loading the checkpoint. + * + */ + class CheckpointState : public detail::Base { +@@ -388,9 +386,6 @@ class TrainingSession : public detail::Base { + Value ToBuffer(const bool only_trainable); + + /** \brief Loads the training session model parameters from a contiguous buffer +- * +- * In case the training session was created with a nominal checkpoint, invoking this function is required +- * to load the updated parameters onto the checkpoint to complete it. 
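For context, the notes being trimmed here describe the public side of the same buffer round-trip: Ort::TrainingSession::ToBuffer snapshots the parameters into one contiguous Ort::Value, and FromBuffer loads such a buffer back, choosing trainable-only vs. all parameters from the buffer's length. A minimal sketch assembled from the calls that appear in the removed training_capi_tests.cc test elsewhere in this patch (paths are placeholders; on Windows they would need to be wide strings):

    #include "onnxruntime_training_cxx_api.h"

    void RoundTripParameters() {
      Ort::Env env;
      Ort::SessionOptions session_options;
      Ort::CheckpointState state = Ort::CheckpointState::LoadCheckpoint("checkpoint.ckpt");
      Ort::TrainingSession session(env, session_options, state,
                                   "training_model.onnx", "eval_model.onnx", "adamw.onnx");

      // Snapshot every parameter (trainable and frozen) into one contiguous tensor ...
      Ort::Value params = session.ToBuffer(/*only_trainable=*/false);

      // ... and load it back. FromBuffer checks the buffer length against the trainable-only
      // and full parameter sizes and copies accordingly.
      session.FromBuffer(params);
    }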
+ * + * \param[in] buffer Contiguous buffer to load the parameters from. + */ +diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +index 397cba0b0..7d1326a10 100644 +--- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h ++++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +@@ -168,23 +168,22 @@ inline void TrainingSession::FromBuffer(Value& buffer) { + + auto buffer_size = buffer_shape.front(); + +- size_t session_buffer_size = 0U; +- ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false)); +- +- if (buffer_size == static_cast(session_buffer_size)) { +- ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false)); +- return; +- } +- + size_t session_buffer_size_trainable_only = 0U; + ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true)); + + if (buffer_size == static_cast(session_buffer_size_trainable_only)) { + ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true)); + return; +- } else { ++ } ++ ++ size_t session_buffer_size = 0U; ++ ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false)); ++ ++ if (buffer_size != static_cast(session_buffer_size)) { + ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); + } ++ ++ ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false)); + } + + inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string& path_to_checkpoint) { +diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc +index 41ed79d28..cf49a0151 100644 +--- a/orttraining/orttraining/training_api/module.cc ++++ b/orttraining/orttraining/training_api/module.cc +@@ -6,8 +6,6 @@ + #include "core/common/safeint.h" + #include "core/common/string_utils.h" + #include "core/framework/execution_provider.h" +-#include "core/framework/mldata_type_utils.h" +-#include "core/framework/tensorprotoutils.h" + #include "core/session/inference_session.h" + #include "core/session/environment.h" + #include "core/session/onnxruntime_session_options_config_keys.h" +@@ -119,75 +117,6 @@ Status TransformModelInputsForInference(Graph& inference_graph, + return Status::OK(); + } + #endif +- +-InlinedHashMap BuildParameterToInputNodeArgMap(const ModuleCheckpointState& state, +- const InputDefList* model_inputs) { +- ORT_ENFORCE(model_inputs != nullptr, "Model inputs are not defined."); +- InlinedHashMap parameter_to_input_node_arg_map; +- parameter_to_input_node_arg_map.reserve(state.named_parameters.size()); +- for (const auto& input_def : *model_inputs) { +- const std::string& input_name = input_def->Name(); +- const auto param_it = state.named_parameters.find(input_name); +- if (param_it == state.named_parameters.end()) { +- continue; +- } +- parameter_to_input_node_arg_map[input_name] = input_def; +- } +- return parameter_to_input_node_arg_map; +-} +- +-InlinedHashMap BuildParameterToGradInputIndexMap(gsl::span grad_names) { +- InlinedHashMap param_name_to_grad_input_index_map; +- param_name_to_grad_input_index_map.reserve(grad_names.size()); +- for (size_t i = 0; i < grad_names.size(); ++i) { +- std::string param_name; +- utils::GetParamNameFromGradient(grad_names[i], param_name); +- param_name_to_grad_input_index_map.insert({param_name, i}); +- } +- return 
param_name_to_grad_input_index_map; +-} +- +-Status LoadParameter(const std::string& param_name, const Tensor& src_weight_tensor, +- const SessionState& session_state, const bool force_load, +- const InlinedHashMap& param_to_grad_index, +- gsl::span grad_names, Parameter& param) { +- InlinedVector node_info_vec; +- ORT_THROW_IF_ERROR(session_state.GetInputNodeInfo(param_name, node_info_vec)); +- const auto& node_info = node_info_vec.front(); +- const auto target_device = *node_info.device; +- for (auto it = node_info_vec.begin(); it != node_info_vec.end(); ++it) { +- ORT_ENFORCE(target_device == *(it->device), "Inconsistent device requirements found for input: ", param_name); +- } +- +- if (force_load || src_weight_tensor.Location().device.Type() != target_device.Type()) { +- auto weight_allocator = session_state.GetAllocator(target_device); +- ORT_ENFORCE(weight_allocator != nullptr); +- +- // Create a new tensor on the target_device and switch the source_ortvalue to point to this new tensor +- auto dst_weight_tensor = std::make_unique(src_weight_tensor.DataType(), src_weight_tensor.Shape(), +- weight_allocator); +- ORT_THROW_IF_ERROR(session_state.GetDataTransferMgr().CopyTensor(src_weight_tensor, *dst_weight_tensor.get())); +- auto ml_tensor_type = DataTypeImpl::GetType(); +- param.Data().Init(dst_weight_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc()); +- } +- +- if (param.RequiresGrad()) { +- // Create gradient accumulation buffer. +- auto grad_it = param_to_grad_index.find(param_name); +- ORT_ENFORCE(grad_it != param_to_grad_index.end(), "Gradient buffer input not provided for param: ", +- param_name); +- +- const size_t grad_input_index = grad_it->second; +- auto& param_grad_name = grad_names[grad_input_index]; +- +- OrtValue param_grad; +- ORT_THROW_IF_ERROR(utils::CreateZeroValuedOrtValueLike(session_state, param.Data(), param_grad)); +- ORT_THROW_IF_ERROR(param.SetGrad(param_grad_name, param_grad)); +- } +- +- return Status::OK(); +-} +- + } // namespace + + Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { +@@ -322,6 +251,7 @@ Module::Module(const ModelIdentifiers& model_identifiers, + // user inputs, weights, gradients, reset_grad + InlinedVector user_input_names, param_input_names, grad_input_names, reset_grad_name; + ++ std::unordered_map param_name_to_grad_input_index_map; + for (const auto& input_name : train_input_names) { + auto it = state_->module_checkpoint_state.named_parameters.find(input_name); + if (it != state_->module_checkpoint_state.named_parameters.end()) { +@@ -329,6 +259,7 @@ Module::Module(const ModelIdentifiers& model_identifiers, + } else if (input_name == ACCUMULATE_GRAD_CONTROL_INPUT_NAME) { + reset_grad_name.emplace_back(input_name); + } else if (std::string param_name; utils::GetParamNameFromGradient(input_name, param_name)) { ++ param_name_to_grad_input_index_map.insert({param_name, grad_input_names.size()}); + grad_input_names.emplace_back(input_name); + } else { + user_input_names.emplace_back(input_name); +@@ -337,7 +268,11 @@ Module::Module(const ModelIdentifiers& model_identifiers, + + gradients_.resize(grad_input_names.size()); + +- train_input_names_ = TrainInputNames(user_input_names, param_input_names, grad_input_names); ++ train_input_names_ = user_input_names; ++ train_user_input_count_ = user_input_names.size(); ++ train_input_names_.insert(train_input_names_.end(), param_input_names.begin(), param_input_names.end()); ++ train_input_names_.insert(train_input_names_.end(), 
grad_input_names.begin(), grad_input_names.end()); ++ train_input_names_.insert(train_input_names_.end(), reset_grad_name.begin(), reset_grad_name.end()); + + for (const auto& output_name : train_output_names) { + if (std::string param_name; !utils::GetParamNameFromGradient(output_name, param_name)) { +@@ -345,24 +280,58 @@ Module::Module(const ModelIdentifiers& model_identifiers, + } + } + +- if (!state_->module_checkpoint_state.is_nominal_state) { +- // ORT_THROW_IF_ERROR(AllocateMemoryForWeights()); +- // Loop each parameter, and allocate its memory based on the user-specified device. +- const auto param_to_grad_index = BuildParameterToGradInputIndexMap(train_input_names_.GradientInputNames()); +- for (auto& param_name : train_input_names_.WeightsInputNames()) { +- auto params_iter = state_->module_checkpoint_state.named_parameters.find(param_name); +- ORT_ENFORCE(params_iter != state_->module_checkpoint_state.named_parameters.end()); +- +- OrtValue& param_data = params_iter->second->Data(); +- ORT_ENFORCE(param_data.IsTensor(), "Expected: Parameter data should be of tensor type. Actual: ", +- params_iter->second->Name(), " is not a tensor."); +- ORT_THROW_IF_ERROR(LoadParameter(param_name, param_data.Get(), train_sess_->GetSessionState(), +- false /* force_load */, param_to_grad_index, +- train_input_names_.GradientInputNames(), *params_iter->second)); +- weights_.push_back(param_data); +- if (params_iter->second->RequiresGrad()) { +- gradients_[param_to_grad_index.at(param_name)] = params_iter->second->Gradient(); +- } ++ // Loop each parameter, and allocate its memory based on the user-specified device. ++ auto& train_sess_state = train_sess_->GetSessionState(); ++ for (auto& param_name : param_input_names) { ++ auto params_iter = state_->module_checkpoint_state.named_parameters.find(param_name); ++ ORT_ENFORCE(params_iter != state_->module_checkpoint_state.named_parameters.end()); ++ ++ // Retrieve the target device for "param_name". ++ InlinedVector node_info_vec; ++ ORT_THROW_IF_ERROR(train_sess_state.GetInputNodeInfo(param_name, node_info_vec)); ++ const auto& node_info = node_info_vec.front(); ++ const auto target_device = *node_info.device; ++ for (auto it = node_info_vec.begin(); it != node_info_vec.end(); ++it) { ++ ORT_ENFORCE(target_device == *(it->device), "Inconsistent device requirements found for input: ", param_name); ++ } ++ ++ // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning) ++ // Only copies data if the target device is not the same as the current device the buffer is placed on ++ OrtValue& param_data = params_iter->second->Data(); ++ ORT_ENFORCE(param_data.IsTensor()); ++ const Tensor& param_data_tensor = param_data.Get(); ++ // If the source device type is already the same as target device skip copy ++ if (param_data_tensor.Location().device.Type() != target_device.Type()) { ++ // TODO: move this outside of the for loop? 
++ auto target_allocator = train_sess_state.GetAllocator(target_device); ++ ORT_ENFORCE(target_allocator != nullptr); ++ ++ // Create a new tensor on the target_device and switch the source_ortvalue to point to this new tensor ++ auto target_tensor = std::make_unique(param_data_tensor.DataType(), param_data_tensor.Shape(), ++ target_allocator); ++ ORT_THROW_IF_ERROR(train_sess_state.GetDataTransferMgr().CopyTensor(param_data_tensor, *target_tensor.get())); ++ auto ml_tensor_type = DataTypeImpl::GetType(); ++ param_data.Init(target_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc()); ++ } ++ ++ weights_.push_back(param_data); ++ weight_names_.push_back(param_name); ++ ++ // Create gradient buffer when parameter requires gradient. ++ if (params_iter->second->RequiresGrad()) { ++ // Create gradient accumulation buffer. ++ auto it = param_name_to_grad_input_index_map.find(param_name); ++ ORT_ENFORCE(it != param_name_to_grad_input_index_map.end(), "Gradient buffer input not provided for param: ", ++ param_name); ++ ++ const size_t grad_input_index = it->second; ++ auto& param_grad_name = grad_input_names[grad_input_index]; ++ // TODO: don't pre-allocate the gradient buffer. ++ // Gradient usually stays on the same device of its parameter. ++ OrtValue param_grad; ++ ORT_THROW_IF_ERROR(utils::CreateZeroValuedOrtValueLike(train_sess_state, param_data, param_grad)); ++ ORT_THROW_IF_ERROR(params_iter->second->SetGrad(param_grad_name, param_grad)); ++ gradients_[grad_input_index] = params_iter->second->Gradient(); + } + } + +@@ -445,24 +414,16 @@ std::string Module::GetEvalModelOutputName(size_t index) const { + + size_t Module::GetParametersSize(const bool trainable_only) const { + SafeInt parameters_size = 0; +- const auto model_inputs_with_error = GetTrainingModelInputs(); +- ORT_THROW_IF_ERROR(model_inputs_with_error.first); +- ORT_ENFORCE(model_inputs_with_error.second, "Training model graph inputs are not defined."); +- for (const auto& input_def : *model_inputs_with_error.second) { +- const std::string& input_name = input_def->Name(); +- const auto param_it = state_->module_checkpoint_state.named_parameters.find(input_name); +- if (param_it == state_->module_checkpoint_state.named_parameters.end() || +- (trainable_only && !param_it->second->RequiresGrad())) { ++ for (const auto& it : state_->module_checkpoint_state.named_parameters) { ++ if (trainable_only && !it.second->RequiresGrad()) { + continue; + } +- parameters_size += onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_def->Shape()).Size(); ++ parameters_size += it.second->Data().Get().Shape().Size(); + } + return parameters_size; + } + + std::vector> Module::Parameters() const { +- ORT_ENFORCE(!state_->module_checkpoint_state.is_nominal_state, +- "Cannot fetch parameters from a nominal checkpoint state. Please load the model parameters first."); + std::vector> params; + for (auto& it : state_->module_checkpoint_state.named_parameters) { + params.push_back(it.second); +@@ -471,27 +432,23 @@ std::vector> Module::Parameters() const { + } + + std::unordered_map> Module::NamedParameters() const { +- ORT_ENFORCE(!state_->module_checkpoint_state.is_nominal_state, +- "Cannot fetch named parameters from a nominal checkpoint state. 
Please load the model parameters first."); + return state_->module_checkpoint_state.named_parameters; + } + + Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only) { +- ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, +- "Cannot copy parameters from a nominal checkpoint state. Please load the model parameters first."); +- ORT_RETURN_IF_NOT(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); +- ORT_RETURN_IF_NOT(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); ++ ORT_ENFORCE(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); ++ ORT_ENFORCE(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); + auto* init_tensor = parameters_buffer.GetMutable(); + ORT_ENFORCE(nullptr != init_tensor); + auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); +- ORT_RETURN_IF(init_tensor->Shape().Size() != expected_buffer_size, +- "Parameters buffer size incorrect. Expected:", expected_buffer_size, +- ", Actual:", init_tensor->Shape().Size()); ++ ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, ++ "Parameters buffer size incorrect. Expected:", expected_buffer_size, ++ ", Actual:", init_tensor->Shape().Size()); + + const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); + + size_t offset = 0; +- for (const auto& param_name : train_input_names_.WeightsInputNames()) { ++ for (const auto& param_name : weight_names_) { + auto& param = state_->module_checkpoint_state.named_parameters.at(param_name); + if (trainable_only && !param->RequiresGrad()) { + continue; +@@ -501,7 +458,7 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr + + const TensorShape& shape = weight_tensor->Shape(); + auto element_type = init_tensor->DataType(); +- ORT_RETURN_IF(weight_tensor->DataType() != element_type, "Data types must match."); ++ ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + + const OrtMemoryInfo& info = init_tensor->Location(); + std::unique_ptr p_tensor; +@@ -513,102 +470,54 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr + data_buffer + offset, + info); + } else { +- ORT_THROW("Unsupported type: ", element_type, " encountered while copying parameters to buffer. ", +- "Only float is supported."); ++ ORT_THROW("Unsupported type: ", element_type); + } +- ORT_RETURN_IF_ERROR(sess_data_transfer_manager.CopyTensor(*weight_tensor, *p_tensor.get())); ++ ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*weight_tensor, *p_tensor.get())); + offset += shape.Size(); + } + return Status::OK(); + } + + Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only) { +- // In case of a nominal checkpoint state, all parameters need to be loaded into the model. +- // i.e. trainable_only must be false. +- ORT_RETURN_IF(trainable_only && state_->module_checkpoint_state.is_nominal_state, +- "For nominal checkpoint state, all parameters need to be loaded into the model " +- "(trainable_only = false)."); +- ORT_RETURN_IF_NOT(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); +- ORT_RETURN_IF_NOT(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); +- auto* buffer_tensor = parameters_buffer.GetMutable(); +- ORT_RETURN_IF(nullptr == buffer_tensor, "Expected valid parameter buffer. 
Actual: nullptr."); ++ ORT_ENFORCE(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); ++ ORT_ENFORCE(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); ++ auto* init_tensor = parameters_buffer.GetMutable(); ++ ORT_ENFORCE(nullptr != init_tensor); + auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); +- ORT_RETURN_IF(buffer_tensor->Shape().Size() != expected_buffer_size, +- "Parameters buffer size incorrect. Expected:", expected_buffer_size, +- ", Actual:", buffer_tensor->Shape().Size()); ++ ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, ++ "Parameters buffer size incorrect. Expected:", expected_buffer_size, ++ ", Actual:", init_tensor->Shape().Size()); + +- auto& train_sess_state = train_sess_->GetSessionState(); + const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); +- const auto model_inputs_with_error = GetTrainingModelInputs(); +- ORT_RETURN_IF_ERROR(model_inputs_with_error.first); +- ORT_RETURN_IF_NOT(model_inputs_with_error.second, "Training model graph inputs are not defined."); +- const auto param_to_node_arg = BuildParameterToInputNodeArgMap(state_->module_checkpoint_state, +- model_inputs_with_error.second); +- const auto param_to_grad_index = BuildParameterToGradInputIndexMap(train_input_names_.GradientInputNames()); +- +- if (state_->module_checkpoint_state.is_nominal_state) { +- // weights_ vector is not initialized for a nominal state. This function is expected to +- // initialize the weights_. +- ORT_ENFORCE(weights_.empty(), "Weights vector should be empty for a nominal state."); +- } + + size_t offset = 0; +- for (const auto& param_name : train_input_names_.WeightsInputNames()) { ++ for (const auto& param_name : weight_names_) { + auto& param = state_->module_checkpoint_state.named_parameters.at(param_name); + if (trainable_only && !param->RequiresGrad()) { + continue; + } + OrtValue& weight = param->Data(); ++ auto* weight_tensor = weight.GetMutable(); + +- auto param_it = param_to_node_arg.find(param_name); +- const TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto( +- *(param_it->second->Shape())); +- const auto element_type = static_cast( +- onnxruntime::utils::GetMLDataType(*param_it->second)) +- ->GetElementType(); ++ const TensorShape& shape = weight_tensor->Shape(); ++ auto element_type = init_tensor->DataType(); ++ ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + +- const OrtMemoryInfo& info = buffer_tensor->Location(); +- std::unique_ptr src_tensor; ++ const OrtMemoryInfo& info = init_tensor->Location(); ++ std::unique_ptr p_tensor; + + if (onnxruntime::utils::IsPrimitiveDataType(element_type)) { +- float* data_buffer = buffer_tensor->MutableData(); +- src_tensor = std::make_unique(element_type, +- shape, +- data_buffer + offset, +- info); +- } else { +- ORT_THROW("Unsupported type: ", element_type, " encountered while copying buffer to parameters. ", +- "Only float is supported."); +- } +- +- if (state_->module_checkpoint_state.is_nominal_state) { +- // If state is a nominal state, then we first need to allocate the memory for +- // parameters and their gradients in the checkpoint state before copying the data. 
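Both copy routines in this hunk walk the parameters in their declaration order and advance a single element offset into one contiguous float buffer, skipping frozen parameters when trainable_only is set. The following standalone sketch (plain C++ with a hypothetical Param struct, no ONNX Runtime types) mirrors only that bookkeeping; the real code instead wraps buffer + offset in a temporary Tensor and copies through the session's DataTransferManager, and it currently supports float elements only:

    #include <cassert>
    #include <cstddef>
    #include <cstring>
    #include <string>
    #include <vector>

    struct Param {
      std::string name;
      std::vector<float> data;   // stands in for the parameter's Tensor
      bool requires_grad;
    };

    // Pack parameters into one contiguous buffer, in order, optionally skipping frozen ones.
    std::vector<float> Pack(const std::vector<Param>& params, bool trainable_only) {
      std::vector<float> buffer;
      for (const Param& p : params) {
        if (trainable_only && !p.requires_grad) continue;
        buffer.insert(buffer.end(), p.data.begin(), p.data.end());
      }
      return buffer;
    }

    // Unpack a buffer produced with the same trainable_only setting back into the parameters.
    void Unpack(const std::vector<float>& buffer, std::vector<Param>& params, bool trainable_only) {
      std::size_t offset = 0;
      for (Param& p : params) {
        if (trainable_only && !p.requires_grad) continue;
        assert(offset + p.data.size() <= buffer.size());
        std::memcpy(p.data.data(), buffer.data() + offset, p.data.size() * sizeof(float));
        offset += p.data.size();  // the real code advances by shape.Size() in the same way
      }
    }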
+- ORT_RETURN_IF_ERROR(LoadParameter(param_name, *src_tensor, train_sess_state, true, +- param_to_grad_index, train_input_names_.GradientInputNames(), +- *param)); +- weights_.push_back(param->Data()); +- if (param->RequiresGrad()) { +- // It is expected that the gradients_ vector is already initialized with the correct size +- // in the Module constructor (even though the OrtValues contained in the vector are empty). +- gradients_[param_to_grad_index.at(param_name)] = param->Gradient(); +- } ++ float* data_buffer = init_tensor->MutableData(); ++ p_tensor = std::make_unique(element_type, ++ shape, ++ data_buffer + offset, ++ info); + } else { +- // If state is not a nominal state, then we can directly copy the data to the existing +- // parameters in the checkpoint state. +- auto* weight_tensor = weight.GetMutable(); +- ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); +- ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*src_tensor.get(), *weight_tensor)); ++ ORT_THROW("Unsupported type: ", element_type); + } +- ++ ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*p_tensor.get(), *weight_tensor)); + offset += shape.Size(); + } +- +- if (state_->module_checkpoint_state.is_nominal_state) { +- // Once the parameters are loaded, the state is no longer a nominal state. +- state_->module_checkpoint_state.is_nominal_state = false; +- } +- + return Status::OK(); + } + +@@ -618,9 +527,6 @@ Status Module::LazyResetGrad() { + } + + Status Module::TrainStep(const std::vector& inputs, std::vector& outputs) { +- ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, +- "Cannot perform TrainStep with a nominal state. Please load the model parameters first."); +- std::vector> params; + std::vector feeds{inputs}; + feeds.insert(feeds.end(), weights_.begin(), weights_.end()); + feeds.insert(feeds.end(), gradients_.begin(), gradients_.end()); +@@ -629,7 +535,7 @@ Status Module::TrainStep(const std::vector& inputs, std::vector(!accumulate_gradient_, &reset_grad_input); + feeds.push_back(reset_grad_input); + +- ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_.AllInputNames(), feeds, train_output_names_, &outputs)); ++ ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_, feeds, train_output_names_, &outputs)); + + // Reset the flag after every step. In case the ResetGrad was called before running + // the current step, it will have done the effective resetting during the +@@ -640,8 +546,6 @@ Status Module::TrainStep(const std::vector& inputs, std::vector& inputs, std::vector& outputs) { +- ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, +- "Cannot perform EvalStep with a nominal state. Please load the model parameters first."); + ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized."); + std::vector feeds{inputs}; + feeds.insert(feeds.end(), weights_.begin(), weights_.end()); +@@ -656,8 +560,6 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { +- ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, +- "Cannot export the model with a nominal state. Please load the model parameters first."); + ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), + "Eval model was not provided. 
Cannot export a model for inferencing."); + +@@ -684,7 +586,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path + #endif + + size_t Module::GetTrainingModelInputCount() const noexcept { +- return train_input_names_.UserInputNames().size(); ++ return train_user_input_count_; + } + + size_t Module::GetEvalModelInputCount() const noexcept { +@@ -692,10 +594,10 @@ size_t Module::GetEvalModelInputCount() const noexcept { + } + + std::string Module::GetTrainingModelInputName(size_t index) const { +- ORT_ENFORCE(index < train_input_names_.UserInputNames().size(), +- "Train input name index out of range. Expected in range [0-", train_input_names_.UserInputNames().size(), "). Actual: ", ++ ORT_ENFORCE(index < train_user_input_count_, ++ "Train input name index out of range. Expected in range [0-", train_user_input_count_, "). Actual: ", + index); +- return train_input_names_.UserInputNames()[index]; ++ return train_input_names_.at(index); + } + + std::string Module::GetEvalModelInputName(size_t index) const { +@@ -713,43 +615,6 @@ std::pair Module::GetEvalModelInputs() cons + return eval_sess_->GetModelInputs(); + } + +-Module::TrainInputNames::TrainInputNames(gsl::span user_input_names, +- gsl::span weights_input_names, +- gsl::span gradient_input_names) { +- train_input_names_.reserve(user_input_names.size() + +- weights_input_names.size() + +- gradient_input_names.size() + +- 1U); // +1 for the reset gradient flag input +- train_input_index_offsets_.reserve(3); +- +- train_input_names_.insert(train_input_names_.end(), +- user_input_names.begin(), user_input_names.end()); +- train_input_index_offsets_.push_back(train_input_names_.size()); +- train_input_names_.insert(train_input_names_.end(), +- weights_input_names.begin(), weights_input_names.end()); +- train_input_index_offsets_.push_back(train_input_names_.size()); +- train_input_names_.insert(train_input_names_.end(), +- gradient_input_names.begin(), gradient_input_names.end()); +- train_input_index_offsets_.push_back(train_input_names_.size()); +- train_input_names_.push_back(ACCUMULATE_GRAD_CONTROL_INPUT_NAME); +-} +- +-gsl::span Module::TrainInputNames::AllInputNames() const { return train_input_names_; } +- +-gsl::span Module::TrainInputNames::UserInputNames() const { +- return gsl::span{train_input_names_.begin(), train_input_index_offsets_[0]}; +-} +- +-gsl::span Module::TrainInputNames::WeightsInputNames() const { +- return gsl::span{train_input_names_.begin() + train_input_index_offsets_[0], +- train_input_index_offsets_[1] - train_input_index_offsets_[0]}; +-} +- +-gsl::span Module::TrainInputNames::GradientInputNames() const { +- return gsl::span{train_input_names_.begin() + train_input_index_offsets_[1], +- train_input_index_offsets_[2] - train_input_index_offsets_[1]}; +-} +- + } // namespace api + } // namespace training + } // namespace onnxruntime +diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h +index 917887404..f323e6be7 100644 +--- a/orttraining/orttraining/training_api/module.h ++++ b/orttraining/orttraining/training_api/module.h +@@ -53,7 +53,6 @@ struct ModuleCheckpointState { + public: + std::unordered_map> named_parameters; + const DataTransferManager* train_session_data_transfer_mgr; +- bool is_nominal_state = false; + }; + + struct CheckpointState; +@@ -88,28 +87,19 @@ struct Module { + ~Module(); + + // Return the trainable/nontrainable parameters +- // If the parameter state is not available; i.e. 
the module was created using the nominal checkpoint, +- // and the state has not been loaded yet, then this function will raise an exception. + std::vector> Parameters() const; + +- // Return the trainable/nontrainable parameters as a map +- // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, +- // and the state has not been loaded yet, then this function will raise an exception. + std::unordered_map> NamedParameters() const; + + // Reset and release the gradient buffer of all trainable params lazily. + Status LazyResetGrad(); + + // Train Step – does forward and backward computation. The outputs will be the forward’s outputs. +- // Gradients will be accumulated within the Parameter object. +- // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, +- // and the state has not been loaded yet, then this function will return an error. ++ // Gradients will be accumulated within the Parameter object + Status TrainStep(const std::vector& inputs, std::vector& outputs); + + // Eval Step – does forward computation. This will use a separate inference session + // and take in a separate inference graph, while sharing the parameters +- // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, +- // and the state has not been loaded yet, then this function will return an error. + Status EvalStep(const std::vector& inputs, std::vector& outputs); + + // Returns the output count for training graph +@@ -128,20 +118,14 @@ struct Module { + size_t GetParametersSize(const bool trainable_only = true) const; + + // Copy parameters onto contiguous buffer held by parameters_buffer +- // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, +- // and the state has not been loaded yet, then this function will return an error. + Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true); + + // Copy parameter values from contiguous buffer held by parameters_buffer onto parameters +- // This function is responsible for completing the nominal checkpoint state. The checkpoint +- // state will no longer be nominal after the successful completion of this function. + Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true); + + #if !defined(ORT_MINIMAL_BUILD) + // Load the eval model from eval_model_path_or_bytes and transform it for the purpose of +- // inferencing, and serialize to given path. +- // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, +- // and the state has not been loaded yet, then this function will return an error. 
++ // inferencing, and serialize to given path + Status ExportModelForInferencing(const std::string& inference_model_path, + gsl::span graph_output_names) const; + #endif +@@ -168,28 +152,11 @@ struct Module { + std::unique_ptr train_sess_{nullptr}; + std::unique_ptr eval_sess_{nullptr}; + +- struct TrainInputNames { +- private: +- InlinedVector train_input_names_; +- InlinedVector train_input_index_offsets_; // offset range[[0], [1]) = user input names +- // offset range[[1], [2]) = weights input names +- // offset range[[2], [3]) = gradient input names +- public: +- TrainInputNames() = default; +- TrainInputNames(gsl::span user_input_names, +- gsl::span weights_input_names, +- gsl::span gradient_input_names); +- +- gsl::span AllInputNames() const; +- gsl::span UserInputNames() const; +- gsl::span WeightsInputNames() const; +- gsl::span GradientInputNames() const; +- }; +- +- TrainInputNames train_input_names_; ++ InlinedVector train_input_names_; + InlinedVector train_output_names_; + InlinedVector eval_input_names_; + InlinedVector eval_output_names_; ++ InlinedVector weight_names_; + + InlinedVector weights_; + InlinedVector gradients_; +@@ -198,6 +165,7 @@ struct Module { + + bool accumulate_gradient_ = false; + std::optional eval_model_path_; ++ size_t train_user_input_count_{0U}; + size_t eval_user_input_count_{0U}; + }; + +diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +index 0ed41f670..38a9aad96 100644 +--- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc ++++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +@@ -568,16 +568,9 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtChe + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); +- if (chkpt_state->module_checkpoint_state.is_nominal_state) { +- const std::string err_msg = +- "Parameter type and shape cannot be retrieved from nominal checkpoint state. " +- "Please load the parameter states first."; +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); +- } +- + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { +- const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; ++ std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + +@@ -593,15 +586,9 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); +- if (chkpt_state->module_checkpoint_state.is_nominal_state) { +- const std::string err_msg = +- "Parameter cannot be updated for nominal checkpoint state. 
Please load all the parameter states first."; +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); +- } +- + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { +- const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; ++ std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( +@@ -621,15 +608,9 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); +- if (chkpt_state->module_checkpoint_state.is_nominal_state) { +- const std::string err_msg = +- "Parameter cannot be retrieved from nominal checkpoint state. Please load the parameter states first."; +- return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); +- } +- + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { +- const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; ++ std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + +diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc +index 84c35e610..7f583ce8f 100644 +--- a/orttraining/orttraining/training_api/optimizer.cc ++++ b/orttraining/orttraining/training_api/optimizer.cc +@@ -21,8 +21,8 @@ namespace { + constexpr char GROUP_ZERO_NAME[] = "group0"; + static constexpr std::array CommonOptimizerInputs{"learning_rate", "step", "params", "gradients"}; + +-Status GraphInputsAreExpected(gsl::span actual_graph_inputs, +- gsl::span expected_graph_inputs) { ++Status GraphInputsAreExpected(gsl::span actual_graph_inputs, ++ gsl::span expected_graph_inputs) { + const auto stringify = [](const auto& container) { + if (container.empty()) { + return std::string("[]"); +@@ -245,17 +245,8 @@ Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, + if (!find_group_zero) + state_->optimizer_checkpoint_state.group_named_optimizer_states.insert( + {GROUP_ZERO_NAME, std::make_shared()}); +- if (!state_->module_checkpoint_state.is_nominal_state) { +- // Construct the optimizer state and inputs only if the complete state +- // is available. +- // For a nominal state, delay the construction of the optimizer state +- // and inputs until the complete state is available. Once the complete +- // state is available, the optimizer state and inputs can be constructed +- // by invoking ConstructOptimizerStateAndInputs(). 
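The comment being dropped above captures the ordering constraint behind this change: with a nominal checkpoint, the optimizer's momentum states and graph inputs cannot be built until the parameter data exists. A small sketch of that sequencing using the internal training-api types from this patch (include paths and the helper name are assumptions; the lazy alternative is that Step() performs the construction itself on first use):

    #include "core/common/common.h"                    // include paths assumed
    #include "orttraining/training_api/module.h"
    #include "orttraining/training_api/optimizer.h"

    // Hypothetical helper: complete a nominal state, then make the optimizer usable.
    onnxruntime::common::Status PrepareNominalOptimizer(
        onnxruntime::training::api::Module& module,
        onnxruntime::training::api::Optimizer& optimizer,
        OrtValue& full_params_buffer) {
      // All parameters must be loaded (trainable_only = false) before any optimizer state can exist.
      ORT_RETURN_IF_ERROR(module.CopyBufferToParameters(full_params_buffer, /*trainable_only=*/false));
      // Build momentum states and per-step inputs now, rather than waiting for Step() to do it lazily.
      return optimizer.ConstructOptimizerStateAndInputs();
    }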
+- ORT_THROW_IF_ERROR(ConstructOptimizerStateAndInputs()); +- } else { +- delay_optimizer_state_contruction_ = true; +- } ++ ORT_THROW_IF_ERROR(GenerateMomentumNamedStates(state_->optimizer_checkpoint_state)); ++ ORT_THROW_IF_ERROR(ConstructInputs()); + } else { + ORT_THROW_IF_ERROR(LoadStateDict(state_->optimizer_checkpoint_state)); + } +@@ -307,10 +298,6 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers, + } + + Status Optimizer::Step() { +- if (delay_optimizer_state_contruction_) { +- ORT_RETURN_IF_ERROR(ConstructOptimizerStateAndInputs()); +- } +- + OrtValue learning_rate_input, step_input; + utils::WrapInOrtValue(optimizer_state_->learning_rate, &learning_rate_input); + // Use step count + 1 before running optimizer step. +@@ -388,17 +375,6 @@ Status Optimizer::LoadStateDict(OptimizerCheckpointState& optimizer_checkpoint_s + return Status::OK(); + } + +-Status Optimizer::ConstructOptimizerStateAndInputs() { +- ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, +- "The optimizer state cannot be constructed. Please load the model parameters first."); +- ORT_RETURN_IF_ERROR(GenerateMomentumNamedStates(state_->optimizer_checkpoint_state)); +- ORT_RETURN_IF_ERROR(ConstructInputs()); +- +- delay_optimizer_state_contruction_ = false; +- +- return Status::OK(); +-} +- + } // namespace api + } // namespace training + } // namespace onnxruntime +diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h +index 031b11426..d9bc4870b 100644 +--- a/orttraining/orttraining/training_api/optimizer.h ++++ b/orttraining/orttraining/training_api/optimizer.h +@@ -123,15 +123,6 @@ struct Optimizer { + return Status::OK(); + } + +- // Constructs the optimizer state and prepares the model inputs. +- // This is called once during the construction of the Optimizer if the model state is available. +- // In case the optimizer was instantiated with a nominal checkpoint, this function must be +- // called when the model state is available. +- // The optimizer checks if the optimizer state needs to be constructed in the train step function. +- // However, this is exposed as a public function in case the user wants to construct the optimizer +- // state before the train step function is called. +- Status ConstructOptimizerStateAndInputs(); +- + private: + void Initialize(const ModelIdentifiers& model_identifiers, + const std::vector>& providers, +@@ -143,7 +134,8 @@ struct Optimizer { + + // Generates optimizer momentum states for parameters that require grad. + Status GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states); +- // Constructs the ortvalue inputs to be fed to the graph at each step. ++ // Constructs the ortvalue inputs to be fed to the graph ++ // at each step. 
+ Status ConstructInputs(); + + /** +@@ -168,8 +160,6 @@ struct Optimizer { + InlinedVector inputs_; + + int32_t group_count_{0}; +- +- bool delay_optimizer_state_contruction_{false}; + }; + + } // namespace api +diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc +index 78619947b..45f0f0ddc 100644 +--- a/orttraining/orttraining/training_api/training_session.cc ++++ b/orttraining/orttraining/training_api/training_session.cc +@@ -112,16 +112,7 @@ Status TrainingSession::CopyParametersToBuffer(OrtValue& parameters_buffer, cons + } + + Status TrainingSession::CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only) { +- const bool was_nominal_state = state_->module_checkpoint_state.is_nominal_state; +- ORT_RETURN_IF_ERROR(module_->CopyBufferToParameters(parameters_buffer, trainable_only)); +- +- // If the checkpoint state was nominal before loading the params, then we need to construct the +- // optimizer state and inputs. +- if (was_nominal_state) { +- ORT_RETURN_IF_ERROR(optimizer_->ConstructOptimizerStateAndInputs()); +- } +- +- return Status::OK(); ++ return module_->CopyBufferToParameters(parameters_buffer, trainable_only); + } + + #if !defined(ORT_MINIMAL_BUILD) +diff --git a/setup.py b/setup.py +index 67d34b065..e94165fdf 100644 +--- a/setup.py ++++ b/setup.py +@@ -298,7 +298,6 @@ if platform.system() == "Linux": + libs.extend(["libonnxruntime_providers_shared.so"]) + libs.extend(["libonnxruntime_providers_dnnl.so"]) + libs.extend(["libonnxruntime_providers_openvino.so"]) +- libs.extend(["libonnxruntime_providers_vitisai.so"]) + libs.append(providers_cuda_or_rocm) + libs.append(providers_tensorrt_or_migraphx) + libs.append(providers_cann) +@@ -311,7 +310,6 @@ elif platform.system() == "Darwin": + libs.extend(["libonnxruntime_providers_dnnl.dylib"]) + libs.extend(["libonnxruntime_providers_tensorrt.dylib"]) + libs.extend(["libonnxruntime_providers_cuda.dylib"]) +- libs.extend(["libonnxruntime_providers_vitisai.dylib"]) + if nightly_build: + libs.extend(["libonnxruntime_pywrapper.dylib"]) + else: +@@ -322,7 +320,6 @@ else: + libs.extend(["onnxruntime_providers_tensorrt.dll"]) + libs.extend(["onnxruntime_providers_openvino.dll"]) + libs.extend(["onnxruntime_providers_cuda.dll"]) +- libs.extend(["onnxruntime_providers_vitisai.dll"]) + # DirectML Libs + libs.extend(["DirectML.dll"]) + if nightly_build: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py -index b2040b24f..e1ad48699 100644 +index b2040b24f..c42ba9386 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py -@@ -991,6 +991,8 @@ def generate_build_tree( +@@ -56,7 +56,7 @@ class UsageError(BaseError): + + + def _check_python_version(): +- required_minor_version = 8 ++ required_minor_version = 7 + if (sys.version_info.major, sys.version_info.minor) < (3, required_minor_version): + raise UsageError( + f"Invalid Python version. At least Python 3.{required_minor_version} is required. " +@@ -328,12 +328,6 @@ def parse_arguments(): + help="[cross-compiling] Create Windows x86 makefiles. Requires --update and no existing cache " + "CMake setup. Delete CMakeCache.txt if needed", + ) +- parser.add_argument( +- "--rv64", +- action="store_true", +- help="[cross-compiling] Create riscv64 makefiles. Requires --update and no existing cache " +- "CMake setup. 
Delete CMakeCache.txt if needed", +- ) + parser.add_argument( + "--arm", + action="store_true", +@@ -357,18 +351,6 @@ def parse_arguments(): + action="store_true", + help="[cross-compiling] Create ARM64X Binary.", + ) +- parser.add_argument( +- "--riscv_toolchain_root", +- type=str, +- default="", +- help="Path to RISC-V toolchain root dir. e.g. --riscv_toolchain_root=$HOME/riscv-tools/", +- ) +- parser.add_argument( +- "--riscv_qemu_path", +- type=str, +- default="", +- help="Path to RISC-V qemu. e.g. --riscv_qemu_path=$HOME/qemu-dir/qemu-riscv64", +- ) + parser.add_argument("--msvc_toolset", help="MSVC toolset to use. e.g. 14.11") + parser.add_argument("--windows_sdk_version", help="Windows SDK version to use. e.g. 10.0.19041.0") + parser.add_argument("--android", action="store_true", help="Build for Android") +@@ -442,8 +424,8 @@ def parse_arguments(): + parser.add_argument( + "--enable_address_sanitizer", action="store_true", help="Enable address sanitizer. Windows/Linux/MacOS only." + ) +- # The following flag is mostly designed to be used in ONNX Runtime's Azure DevOps/Github build pipelines. Its main purpose is to make the built binaries pass BinSkim scan. +- parser.add_argument("--use_binskim_compliant_compile_flags", action="store_true", help="Use preset compile flags.") ++ # The following feature requires installing some special Visual Studio components that do not get installed by default. Therefore the options is default OFF. ++ parser.add_argument("--enable_qspectre", action="store_true", help="Enable Qspectre. Windows only.") + parser.add_argument( + "--disable_memleak_checker", + action="store_true", +@@ -804,6 +786,11 @@ def get_linux_distro(): + return "", "" + + ++def is_ubuntu_1604(): ++ dist, ver = get_linux_distro() ++ return dist == "Ubuntu" and ver.startswith("16.04") ++ ++ + def get_config_build_dir(build_dir, config): + # build directory per configuration + return os.path.join(build_dir, config) +@@ -857,6 +844,15 @@ def update_submodules(source_dir): + run_subprocess(["git", "submodule", "update", "--init", "--recursive"], cwd=source_dir) + + ++def is_docker(): ++ path = "/proc/self/cgroup" ++ return ( ++ os.path.exists("/.dockerenv") ++ or os.path.isfile(path) ++ and any("docker" in line for line in open(path)) # noqa: SIM115 ++ ) ++ ++ + def install_python_deps(numpy_version=""): + dep_packages = ["setuptools", "wheel", "pytest"] + dep_packages.append(f"numpy=={numpy_version}" if numpy_version else "numpy>=1.16.6") +@@ -991,6 +987,8 @@ def generate_build_tree( disable_optional_type = "optional" in types_to_disable disable_sparse_tensors = "sparsetensor" in types_to_disable @@ -11,12 +37257,1991 @@ index b2040b24f..e1ad48699 100644 cmake_args += [ "-Donnxruntime_RUN_ONNX_TESTS=" + ("ON" if args.enable_onnx_tests else "OFF"), "-Donnxruntime_GENERATE_TEST_REPORTS=ON", -@@ -1583,7 +1585,7 @@ def generate_build_tree( +@@ -1095,19 +1093,6 @@ def generate_build_tree( + "-Donnxruntime_DISABLE_OPTIONAL_TYPE=" + ("ON" if disable_optional_type else "OFF"), + ] + +- if args.rv64: +- add_default_definition(cmake_extra_defines, "onnxruntime_CROSS_COMPILING", "ON") +- if not args.riscv_toolchain_root: +- raise BuildError("The --riscv_toolchain_root option is required to build for riscv64.") +- if not args.skip_tests and not args.riscv_qemu_path: +- raise BuildError("The --riscv_qemu_path option is required for testing riscv64.") +- +- cmake_args += [ +- "-DRISCV_TOOLCHAIN_ROOT:PATH=" + args.riscv_toolchain_root, +- "-DRISCV_QEMU_PATH:PATH=" + args.riscv_qemu_path, +- 
"-DCMAKE_TOOLCHAIN_FILE=" + os.path.join(source_dir, "cmake", "riscv64.toolchain.cmake"), +- ] +- + # By default on Windows we currently support only cross compiling for ARM/ARM64 + # (no native compilation supported through this script). + if args.arm64 or args.arm64ec or args.arm: +@@ -1314,6 +1299,10 @@ def generate_build_tree( + if args.use_webnn: + if not args.build_wasm: + raise BuildError("WebNN is only available for WebAssembly build.") ++ if args.disable_rtti: ++ # Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled ++ # in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/16911 ++ raise BuildError("WebNN is not supported with RTTI disabled.") + cmake_args += ["-Donnxruntime_USE_WEBNN=ON"] + + if args.use_snpe: +@@ -1484,29 +1473,27 @@ def generate_build_tree( + f"-DVERSION_PRIVATE_PART={MM}{DD}", + f"-DVERSION_STRING={ort_major}.{ort_minor}.{build_number}.{source_version[0:7]}", + ] +- ++ cflags = None ++ cxxflags = None ++ ldflags = None ++ cudaflags = [] + for config in configs: +- cflags = [] +- cxxflags = None +- ldflags = None +- cudaflags = [] +- if is_windows() and not args.ios and not args.android and not args.build_wasm: +- njobs = number_of_parallel_jobs(args) +- if njobs > 1: +- if args.parallel == 0: +- cflags += ["/MP"] +- else: +- cflags += ["/MP%d" % njobs] + # Setup default values for cflags/cxxflags/ldflags. + # The values set here are purely for security and compliance purposes. ONNX Runtime should work fine without these flags. + if ( +- (args.use_binskim_compliant_compile_flags or args.enable_address_sanitizer) ++ "CFLAGS" not in os.environ ++ and "CXXFLAGS" not in os.environ ++ and (not args.use_cuda or "CUDAFLAGS" not in os.environ) + and not args.ios + and not args.android + and not args.build_wasm ++ and not args.use_rocm ++ and not (is_linux() and platform.machine() != "aarch64" and platform.machine() != "x86_64") + ): + if is_windows(): +- cflags += ["/guard:cf", "/DWIN32", "/D_WINDOWS"] ++ cflags = ["/guard:cf", "/DWIN32", "/D_WINDOWS"] ++ if args.parallel: ++ cflags += ["/MP"] + if not args.use_gdk: + # Target Windows 10 + cflags += [ +@@ -1518,8 +1505,7 @@ def generate_build_tree( + # The "/profile" flag implies "/DEBUG:FULL /DEBUGTYPE:cv,fixup /OPT:REF /OPT:NOICF /INCREMENTAL:NO /FIXED:NO". We set it for satisfying a Microsoft internal compliance requirement. External users + # do not need to have it. + ldflags = ["/profile", "/DYNAMICBASE"] +- # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. +- if not args.enable_address_sanitizer: ++ if args.enable_qspectre: + cflags += ["/Qspectre"] + if config == "Release": + cflags += ["/O2", "/Ob2", "/DNDEBUG"] +@@ -1527,11 +1513,13 @@ def generate_build_tree( + cflags += ["/O2", "/Ob1", "/DNDEBUG"] + elif config == "Debug": + cflags += ["/Ob0", "/Od", "/RTC1"] ++ if args.enable_address_sanitizer: ++ cflags += ["/fsanitize=address"] + elif config == "MinSizeRel": + cflags += ["/O1", "/Ob1", "/DNDEBUG"] +- if args.enable_address_sanitizer: +- cflags += ["/fsanitize=address"] + cxxflags = cflags.copy() ++ if not args.disable_exceptions: ++ cxxflags += ["/EHsc"] + if args.use_cuda: + # On Windows, nvcc passes /EHsc to the host compiler by default. 
+ cuda_compile_flags_str = "" +@@ -1583,16 +1571,12 @@ def generate_build_tree( "-pipe", "-ggdb3", ] - if is_linux() and platform.machine() == "x86_64": + if is_linux() and is_x86_64_build: # The following flags needs GCC 8 and newer - cflags += ["-fstack-clash-protection"] - if not args.rv64: +- cflags += ["-fstack-clash-protection"] +- if not args.rv64: +- cflags += ["-fcf-protection"] ++ cflags += ["-fstack-clash-protection", "-fcf-protection"] + cxxflags = cflags.copy() + if args.use_cuda: + cudaflags = cflags.copy() +- if cxxflags is None and cflags is not None and len(cflags) != 0: +- cxxflags = cflags.copy() + config_build_dir = get_config_build_dir(build_dir, config) + os.makedirs(config_build_dir, exist_ok=True) + if args.use_tvm: +@@ -1607,7 +1591,7 @@ def generate_build_tree( + ) + preinstalled_dir = Path(build_dir) / config + temp_cmake_args = cmake_args.copy() +- if cflags is not None and cxxflags is not None and len(cflags) != 0 and len(cxxflags) != 0: ++ if cflags is not None and cxxflags is not None: + temp_cmake_args += [ + "-DCMAKE_C_FLAGS=%s" % (" ".join(cflags)), + "-DCMAKE_CXX_FLAGS=%s" % (" ".join(cxxflags)), +@@ -2420,6 +2404,16 @@ def run_csharp_tests(source_dir, build_dir, use_cuda, use_openvino, use_tensorrt + run_subprocess(cmd_args, cwd=csharp_source_dir) + + ++def is_cross_compiling_on_apple(args): ++ if not is_macOS(): ++ return False ++ if args.ios: ++ return True ++ if args.osx_arch != platform.machine(): ++ return True ++ return False ++ ++ + def generate_documentation(source_dir, build_dir, configs, validate): + # Randomly choose one build config + config = next(iter(configs)) +@@ -2734,6 +2728,12 @@ def main(): + log.info("Activating emsdk...") + run_subprocess([emsdk_file, "activate", emsdk_version], cwd=emsdk_dir) + ++ if is_ubuntu_1604(): ++ if args.arm or args.arm64: ++ raise BuildError("Only Windows ARM(64) cross-compiled builds supported currently through this script") ++ if not is_docker() and not args.use_acl and not args.use_armnn: ++ install_python_deps() ++ + if args.enable_pybind and is_windows(): + install_python_deps(args.numpy_version) + +diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +deleted file mode 100644 +index 0de2ac442..000000000 +--- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml ++++ /dev/null +@@ -1,336 +0,0 @@ +-##### start trigger Don't edit it manually, Please do edit set-trigger-rules.py #### +-trigger: +- branches: +- include: +- - main +- - rel-* +- paths: +- exclude: +- - docs/** +- - README.md +- - CONTRIBUTING.md +- - BUILD.md +- - 'js/web' +- - 'onnxruntime/core/providers/js' +-pr: +- branches: +- include: +- - main +- - rel-* +- paths: +- exclude: +- - docs/** +- - README.md +- - CONTRIBUTING.md +- - BUILD.md +- - 'js/web' +- - 'onnxruntime/core/providers/js' +-#### end trigger ####parameters: +- +-# reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +-parameters: +-- name: specificArtifact +- displayName: Use Specific Artifact +- type: boolean +- default: false +-- name: BuildId +- displayName: Specific Artifact's RunId +- type: number +- default: 0 +- +-resources: +- repositories: +- - repository: manylinux +- type: Github +- endpoint: Microsoft +- name: pypa/manylinux +- ref: 5eda9aded5462201e6310105728d33016e637ea7 +- +- - repository: LLaMa2Onnx +- type: Github +- endpoint: Microsoft +- name: Microsoft/Llama-2-Onnx +- 
ref: main +- +-variables: +- - template: templates/common-variables.yml +- - name: docker_base_image +- value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +- - name: linux_trt_version +- value: 8.6.1.6-1.cuda11.8 +- +-stages: +-- stage: Build_Onnxruntime_Cuda +- jobs: +- - job: Linux_Build +- timeoutInMinutes: 120 +- variables: +- skipComponentGovernanceDetection: true +- CCACHE_DIR: $(Pipeline.Workspace)/ccache +- workspace: +- clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU +- steps: +- - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 +- displayName: 'Clean Agent Directories' +- condition: always() +- +- - checkout: self +- clean: true +- submodules: none +- +- - template: templates/get-docker-image-steps.yml +- parameters: +- Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +- Context: tools/ci_build/github/linux/docker +- DockerBuildArgs: " +- --network=host +- --build-arg BASEIMAGE=$(docker_base_image) +- --build-arg TRT_VERSION=$(linux_trt_version) +- --build-arg BUILD_UID=$( id -u ) +- " +- Repository: onnxruntimecuda11build +- +- - task: Cache@2 +- inputs: +- key: '"ccache" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' +- path: $(CCACHE_DIR) +- restoreKeys: | +- "ccache" | "$(Build.SourceBranch)" +- "ccache" +- cacheHitVar: CACHE_RESTORED +- displayName: Cach Task +- +- - script: | +- sudo mkdir -p $(Pipeline.Workspace)/ccache +- condition: ne(variables.CACHE_RESTORED, 'true') +- displayName: Create Cache Dir +- +- - task: CmdLine@2 +- inputs: +- script: | +- mkdir -p $HOME/.onnx +- docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ +- --volume /data/onnx:/data/onnx:ro \ +- --volume $(Build.SourcesDirectory):/onnxruntime_src \ +- --volume $(Build.BinariesDirectory):/build \ +- --volume /data/models:/build/models:ro \ +- --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ +- --volume $(Pipeline.Workspace)/ccache:/cache \ +- -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ +- -e NIGHTLY_BUILD \ +- -e BUILD_BUILDNUMBER \ +- -e CCACHE_DIR=/cache \ +- onnxruntimecuda11build \ +- /bin/bash -c " +- set -ex; \ +- env; \ +- ccache -s; \ +- /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ +- --build_dir /build --cmake_generator Ninja \ +- --config Release --update --build \ +- --skip_submodule_sync \ +- --build_shared_lib \ +- --parallel \ +- --build_wheel \ +- --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ +- --enable_cuda_profiling --enable_cuda_nhwc_ops \ +- --enable_pybind --build_java \ +- --use_cache \ +- --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \ +- ccache -sv; \ +- ccache -z" +- workingDirectory: $(Build.SourcesDirectory) +- +- - task: CmdLine@2 +- inputs: +- script: | +- rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 +- rm -f $(Build.BinariesDirectory)/Release/models +- find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' 
-delete +- cd $(Build.BinariesDirectory)/Release +- find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt +- +- - script: | +- set -ex +- mkdir -p $(Agent.TempDirectory)/ort +- cp $(Build.BinariesDirectory)/Release/dist/*.whl $(Agent.TempDirectory)/ort/ +- displayName: 'Copy Wheels' +- +- - task: PublishPipelineArtifact@0 +- displayName: 'Publish Pipeline Artifact' +- inputs: +- artifactName: 'drop-ort-linux-gpu' +- targetPath: '$(Agent.TempDirectory)/ort' +- +- - template: templates/explicitly-defined-final-tasks.yml +- +-- stage: Stable_Diffusion +- dependsOn: +- - Build_Onnxruntime_Cuda +- jobs: +- - job: Stable_Diffusion +- variables: +- skipComponentGovernanceDetection: true +- CLIP_MODEL_CACHE: $(Agent.TempDirectory)/clip_cache +- STABLE_DIFFUSION_MODEL_CACHE: $(Agent.TempDirectory)/stablediffusion_cache +- GenerateImage_DIR: $(Agent.TempDirectory)/images +- workspace: +- clean: all +- pool: onnxruntime-Linux-GPU-A10-12G +- steps: +- - checkout: self +- clean: true +- submodules: none +- +- - template: templates/flex-downloadPipelineArtifact.yml +- parameters: +- StepName: 'Download Onnxruntime Artifact' +- ArtifactName: 'drop-ort-linux-gpu' +- TargetPath: '$(Build.BinariesDirectory)/Release' +- SpecificArtifact: ${{ parameters.specificArtifact }} +- BuildId: ${{ parameters.BuildId }} +- +- - task: Cache@2 +- inputs: +- key: stable_diffusion | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +- restoreKeys: | +- stable_diffusion | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +- stable_diffusion +- path: $(STABLE_DIFFUSION_MODEL_CACHE) +- displayName: Cache stable diffusion model +- +- - script: | +- mkdir -p $(GenerateImage_DIR) +- docker run --rm --gpus all -v $PWD:/workspace \ +- -v $(Build.BinariesDirectory)/Release:/Release \ +- -v $(STABLE_DIFFUSION_MODEL_CACHE):/model_cache:rw \ +- -v $(GenerateImage_DIR):/images:rw \ +- nvcr.io/nvidia/pytorch:22.11-py3 \ +- bash -c ' \ +- set -ex; \ +- python3 --version; \ +- python3 -m pip install --upgrade pip; \ +- python3 -m pip install /Release/*.whl; \ +- pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \ +- python3 -m pip install -r requirements-cuda11.txt; \ +- python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ +- echo Generate an image guided by a text prompt; \ +- python3 demo_txt2img.py --framework-model-dir /model_cache --seed 1 --deterministic "astronaut riding a horse on mars" ; \ +- find $(pwd)/ORT_CUDA -name "*.png" -exec cp {} /images/ \; ; \ +- popd ; \ +- ' +- displayName: 'Run stable diffusion demo' +- workingDirectory: $(Build.SourcesDirectory) +- +- # For verification we will check the generated image looks . 
+- - task: PublishPipelineArtifact@0 +- displayName: 'Publish code coverage report' +- inputs: +- artifactName: "Generated-Image" +- targetPath: '$(GenerateImage_DIR)' +- +- - task: Cache@2 +- inputs: +- key: clip_model | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py +- restoreKeys: | +- clip_model | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py +- clip_model +- path: $(CLIP_MODEL_CACHE) +- displayName: Cache clip model +- +- - script: | +- docker run --rm --gpus all -v $PWD:/workspace \ +- -v $(CLIP_MODEL_CACHE):/model_cache:rw \ +- nvcr.io/nvidia/pytorch:22.11-py3 \ +- bash -c ' +- set -ex; \ +- python3 --version; \ +- python3 -m pip install --upgrade pip; \ +- pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion/; \ +- image2=$(find $(pwd) -name "astronaut_riding_a_h*.png") ; \ +- pushd test; \ +- python3 -m pip install -r requirements.txt; \ +- echo check demo_txt2image.py generate image; \ +- python3 -u check_image.py --image1 astronaut_riding_txt2image-DDIM-50.png --image2 $image2 --cache_dir /model_cache ; \ +- popd ; \ +- popd ; \ +- ' +- displayName: 'Check the generated image' +- workingDirectory: $(Build.SourcesDirectory) +- +-- stage: Llama2_ONNX_FP16 +- dependsOn: +- - Build_Onnxruntime_Cuda +- jobs: +- - job: Llama2_ONNX_FP16 +- variables: +- skipComponentGovernanceDetection: true +- workspace: +- clean: all +- pool: onnxruntime-Linux-GPU-T4 +- steps: +- - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 +- displayName: 'Clean Agent Directories' +- condition: always() +- +- - checkout: self +- clean: true +- submodules: none +- +- - checkout: LLaMa2Onnx +- clean: true +- submodules: none +- +- - template: templates/flex-downloadPipelineArtifact.yml +- parameters: +- StepName: 'Download Onnxruntime Artifact' +- ArtifactName: 'drop-ort-linux-gpu' +- TargetPath: '$(Build.BinariesDirectory)/ort-artifact/' +- SpecificArtifact: ${{ parameters.specificArtifact }} +- BuildId: ${{ parameters.BuildId }} +- +- - task: DownloadPackage@1 +- displayName: 'Download Llama2 model' +- inputs: +- packageType: upack +- feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' +- version: 1.0.0 +- definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce' +- downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16 +- +- - template: templates/get-docker-image-steps.yml +- parameters: +- Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +- Context: onnxruntime/tools/ci_build/github/linux/docker/ +- ScriptName: onnxruntime/tools/ci_build/get_docker_image.py +- DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" +- Repository: onnxruntimeubi8packagestest +- UpdateDepsTxt: false +- +- - script: | +- docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \ +- -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ +- -v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \ +- onnxruntimeubi8packagestest \ +- bash -c " +- set -ex; \ +- python3 -m pip install --upgrade pip ; \ +- python3 -m pip install /ort-artifact/*.whl ; \ +- python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ +- python3 -m pip install sentencepiece ; \ +- pushd /workspace ; \ +- python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \ +- --embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the 
lightest element?' > /workspace/answer.txt ; \ +- popd ; \ +- " +- displayName: 'Run Llama2 demo' +- workingDirectory: $(Build.SourcesDirectory) +- +- - script: | +- set -ex +- real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt) +- trim_actual=$(tr -dc '[[:print:]]' <<< "$real") +- expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)." +- [ "$expected" == "$trim_actual" ] && exit 0 || exit 1 +- displayName: 'Check result' +diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +index d37e9bdc5..3ddc167bc 100644 +--- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +@@ -28,7 +28,7 @@ stages: + artifactName: 'onnxruntime-android-full-aar' + job_name_suffix: 'Full' + publish_executables: '1' +- pool_name: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ pool_name: 'onnxruntime-Ubuntu2004-AMD-CPU' + + # build Python packages + # Linux GPU only +diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +index 5a50a9964..719a0c484 100644 +--- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml ++++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +@@ -248,7 +248,7 @@ stages: + workspace: + clean: all + timeoutInMinutes: 120 +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + variables: + RocmVersion: '5.6' + steps: +@@ -1023,7 +1023,7 @@ stages: + + - template: nuget/templates/test_win.yml + parameters: +- AgentPool : 'onnxruntime-Win2022-GPU-A10' ++ AgentPool : 'onnxruntime-Win2022-GPU-T4' + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' +@@ -1034,7 +1034,7 @@ stages: + + - template: nuget/templates/test_win.yml + parameters: +- AgentPool : 'onnxruntime-Win2022-GPU-A10' ++ AgentPool : 'onnxruntime-Win2022-GPU-T4' + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' +@@ -1046,7 +1046,7 @@ stages: + + - template: nuget/templates/test_linux.yml + parameters: +- AgentPool : Onnxruntime-Linux-GPU-A10 ++ AgentPool : Onnxruntime-Linux-GPU + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' +@@ -1055,7 +1055,7 @@ stages: + + - template: nuget/templates/test_linux.yml + parameters: +- AgentPool : Onnxruntime-Linux-GPU-A10 ++ AgentPool : Onnxruntime-Linux-GPU + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' + MoreSuffix: '_Linux' +diff --git a/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml b/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml +index 43e668eef..24086b616 100644 +--- a/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/clean-build-docker-image-cache-pipeline.yml +@@ -19,7 +19,8 @@ variables: + jobs: + - job: Clean_Build_Docker_Image_Cache + +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: ++ vmImage: 'ubuntu-20.04' + + timeoutInMinutes: 30 + +@@ -28,6 +29,13 @@ jobs: + submodules: false + fetchDepth: 1 + ++ - task: UsePythonVersion@0 ++ inputs: ++ versionSpec: '3.9' ++ addToPath: true ++ 
architecture: 'x64' ++ displayName: "Use Python 3.9" ++ + - task: AzureCLI@2 + inputs: + azureSubscription: 'AIInfraBuild' +diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +index 0c24d4897..df7b5f59d 100644 +--- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +@@ -126,7 +126,7 @@ stages: + BaseImage: 'registry.access.redhat.com/ubi8/ubi' + OnnxruntimeArch: 'x64' + OnnxruntimeNodejsBindingArch: 'x64' +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PackageJava: false + PackageNodeJS: false + # Nuget Packaging +@@ -151,7 +151,7 @@ stages: + # Testing + - template: nuget/templates/test_win.yml + parameters: +- AgentPool : 'onnxruntime-Win2022-GPU-A10' ++ AgentPool : 'onnxruntime-Win2022-GPU-T4' + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' +@@ -162,7 +162,7 @@ stages: + + - template: nuget/templates/test_win.yml + parameters: +- AgentPool : 'onnxruntime-Win2022-GPU-A10' ++ AgentPool : 'onnxruntime-Win2022-GPU-T4' + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' +@@ -174,7 +174,7 @@ stages: + + - template: nuget/templates/test_linux.yml + parameters: +- AgentPool : Onnxruntime-Linux-GPU-A10 ++ AgentPool : Onnxruntime-Linux-GPU + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' +@@ -184,7 +184,7 @@ stages: + + - template: nuget/templates/test_linux.yml + parameters: +- AgentPool : Onnxruntime-Linux-GPU-A10 ++ AgentPool : Onnxruntime-Linux-GPU + ArtifactSuffix: 'GPU' + StageSuffix: 'GPU' + MoreSuffix: '_Linux' +diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +index a4bd24b4d..07f672c75 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +@@ -46,7 +46,7 @@ stages: + skipComponentGovernanceDetection: true + ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' +@@ -93,7 +93,7 @@ stages: + --config Debug \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_csharp \ + --enable_onnx_tests --enable_address_sanitizer \ + --update --build; +@@ -102,7 +102,7 @@ stages: + --config Debug \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_csharp \ + --enable_onnx_tests --enable_address_sanitizer \ + --test;" +@@ -123,7 +123,7 @@ stages: + skipComponentGovernanceDetection: true + ORT_CACHE_DIR: $(Agent.TempDirectory)/ort_ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' +@@ -228,7 +228,7 @@ stages: + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel 
--use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --build_csharp \ + --enable_onnx_tests \ +diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +index 31decb0c2..146186e9e 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +@@ -43,7 +43,7 @@ jobs: + variables: + CCACHE_DIR: $(Agent.TempDirectory)/ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' +@@ -94,7 +94,7 @@ jobs: + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --skip_tests \ + --cmake_extra_defines onnxruntime_ENABLE_ATEN=ON \ +@@ -126,7 +126,7 @@ jobs: + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --test \ + --cmake_extra_defines onnxruntime_ENABLE_ATEN=ON" +diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml +index b3f5ff963..a5c08e95b 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-cpu-eager-pipeline.yml +@@ -51,7 +51,7 @@ jobs: + timeoutInMinutes: 120 + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + steps: + - checkout: self + clean: true +@@ -80,7 +80,7 @@ jobs: + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --enable_lazy_tensor --enable_training --build_wheel --skip_test \ + workingDirectory: $(Build.SourcesDirectory) + +diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +index 1053a2518..1df36c2f2 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +@@ -141,7 +141,7 @@ jobs: + --config Debug \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --skip_tests \ + --minimal_build \ + --disable_exceptions \ +@@ -222,7 +222,7 @@ jobs: + --build_dir /build/5 --cmake_generator Ninja \ + --config Debug \ + --skip_submodule_sync \ +- --build_shared_lib --use_binskim_compliant_compile_flags \ ++ --build_shared_lib \ + --parallel \ + --minimal_build extended + workingDirectory: $(Build.SourcesDirectory) +@@ -246,7 +246,7 @@ jobs: + --skip_submodule_sync \ + --build_shared_lib \ + --build_wheel \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --skip_tests \ + --disable_ml_ops \ + --disable_types sparsetensor float8 optional \ +@@ -272,7 +272,7 @@ jobs: + --config MinSizeRel \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --minimal_build \ + --disable_exceptions \ + --disable_ml_ops \ +@@ -300,7 +300,7 @@ jobs: + --cmake_generator Ninja \ + 
--config MinSizeRel \ + --skip_submodule_sync \ +- --build_shared_lib --use_binskim_compliant_compile_flags \ ++ --build_shared_lib \ + --parallel \ + --minimal_build extended \ + --disable_exceptions \ +@@ -330,7 +330,7 @@ jobs: + --cmake_generator Ninja \ + --config MinSizeRel \ + --skip_submodule_sync \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --android \ + --android_sdk_path /android_home \ + --android_ndk_path /ndk_home \ +diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +index b19a8b11d..0993a81a0 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +@@ -43,6 +43,7 @@ resources: + ref: 5eda9aded5462201e6310105728d33016e637ea7 + + variables: ++ - template: templates/common-variables.yml + - name: docker_base_image + ${{ if eq(parameters.CudaVersion, '11.8') }}: + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +@@ -55,12 +56,6 @@ variables: + ${{ if eq(parameters.CudaVersion, '12.2') }}: + value: 8.6.1.6-1.cuda12.0 + +- - name: Repository +- ${{ if eq(parameters.CudaVersion, '11.8') }}: +- value: 'onnxruntimecuda11build' +- ${{ if eq(parameters.CudaVersion, '12.2') }}: +- value: 'onnxruntimecuda12build' +- + jobs: + - job: Linux_Build + timeoutInMinutes: 120 +@@ -69,8 +64,7 @@ jobs: + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU +- ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' +@@ -79,25 +73,24 @@ jobs: + - checkout: self + clean: true + submodules: none +- + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " +- --network=host ++ --network=host + --build-arg BASEIMAGE=$(docker_base_image) +- --build-arg TRT_VERSION=$(linux_trt_version) ++ --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " +- Repository: $(Repository) ++ Repository: onnxruntimecuda11build + + - task: Cache@2 + inputs: +- key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"' ++ key: '"ccache" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + restoreKeys: | +- "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)" ++ "ccache" | "$(Build.SourceBranch)" + "ccache" + cacheHitVar: CACHE_RESTORED + displayName: Cach Task +@@ -107,41 +100,41 @@ jobs: + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + +- - script: | +- set -e -x +- mkdir -p $HOME/.onnx +- docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ +- --volume /data/onnx:/data/onnx:ro \ +- --volume $(Build.SourcesDirectory):/onnxruntime_src \ +- --volume $(Build.BinariesDirectory):/build \ +- --volume /data/models:/build/models:ro \ +- --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ +- --volume $(Pipeline.Workspace)/ccache:/cache \ +- -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ +- -e NIGHTLY_BUILD \ +- -e 
BUILD_BUILDNUMBER \ +- -e CCACHE_DIR=/cache \ +- $(Repository) \ +- /bin/bash -c " +- set -ex; \ +- env; \ +- ccache -s; \ +- /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ +- --build_dir /build --cmake_generator Ninja \ +- --config Release --update --build \ +- --skip_submodule_sync \ +- --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ +- --build_wheel \ +- --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \ +- --enable_cuda_profiling --enable_cuda_nhwc_ops \ +- --enable_pybind --build_java \ +- --use_cache \ +- --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ +- ccache -sv; \ +- ccache -z" +- workingDirectory: $(Build.SourcesDirectory) +- displayName: Build Onnxruntime ++ - task: CmdLine@2 ++ inputs: ++ script: | ++ mkdir -p $HOME/.onnx ++ docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ ++ --volume /data/onnx:/data/onnx:ro \ ++ --volume $(Build.SourcesDirectory):/onnxruntime_src \ ++ --volume $(Build.BinariesDirectory):/build \ ++ --volume /data/models:/build/models:ro \ ++ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ ++ --volume $(Pipeline.Workspace)/ccache:/cache \ ++ -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ ++ -e NIGHTLY_BUILD \ ++ -e BUILD_BUILDNUMBER \ ++ -e CCACHE_DIR=/cache \ ++ onnxruntimecuda11build \ ++ /bin/bash -c " ++ set -ex; \ ++ env; \ ++ ccache -s; \ ++ /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ ++ --build_dir /build --cmake_generator Ninja \ ++ --config Release --update --build \ ++ --skip_submodule_sync \ ++ --build_shared_lib \ ++ --parallel \ ++ --build_wheel \ ++ --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ ++ --enable_cuda_profiling --enable_cuda_nhwc_ops \ ++ --enable_pybind --build_java \ ++ --use_cache \ ++ --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75; \ ++ ccache -sv; \ ++ ccache -z" ++ workingDirectory: $(Build.SourcesDirectory) + + - task: CmdLine@2 + inputs: +@@ -166,7 +159,7 @@ jobs: + skipComponentGovernanceDetection: true + workspace: + clean: all +- pool: onnxruntime-Linux-GPU-A10 ++ pool: Onnxruntime-Linux-GPU-T4 + dependsOn: + - Linux_Build + steps: +@@ -186,12 +179,12 @@ jobs: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " +- --network=host ++ --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " +- Repository: $(Repository) ++ Repository: onnxruntimecuda11build + + - task: CmdLine@2 + inputs: +@@ -204,7 +197,7 @@ jobs: + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume /data/onnx:/data/onnx \ +- $(Repository) \ ++ onnxruntimecuda11build \ + /bin/bash -c " + set -ex; \ + cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \ +@@ -215,8 +208,8 @@ jobs: + cd /onnxruntime_src/java && 
/onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \ + cd /tmp; \ + /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \ +- --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \ +- --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ ++ --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --build_wheel --enable_onnx_tests \ ++ --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ + --enable_pybind --build_java --ctest_path '' " + + - template: templates/clean-agent-build-directory-step.yml +diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +index 75e4ba540..4ca11a4d1 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +@@ -114,7 +114,7 @@ jobs: + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --enable_onnx_tests \ + --use_cuda --cuda_home=/usr/local/cuda-${{ parameters.CudaVersion }} --cudnn_home=/usr/local/cuda-${{ parameters.CudaVersion }} \ +diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +index 9cf7a3fb4..f7571a3b7 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml +@@ -46,7 +46,7 @@ jobs: + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + timeoutInMinutes: 120 + + steps: +diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +index 0312b70d2..07910911a 100644 +--- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +@@ -63,7 +63,7 @@ jobs: + python3 tools/ci_build/build.py \ + --build_dir build \ + --config Release \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --use_qnn \ + --qnn_home $(QNN_SDK_ROOT) \ + --cmake_generator=Ninja \ +@@ -73,7 +73,7 @@ jobs: + - script: | + python3 tools/ci_build/build.py \ + --build_dir build \ +- --config Release --use_binskim_compliant_compile_flags \ ++ --config Release \ + --test \ + --qnn_home $(QNN_SDK_ROOT) \ + --cmake_generator=Ninja \ +diff --git a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml +index a3f56f5c4..f5472a49c 100644 +--- a/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/mac-coreml-ci-pipeline.yml +@@ -57,7 +57,7 @@ jobs: + --build_dir build \ + --skip_submodule_sync \ + --cmake_generator=Ninja \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_shared_lib \ + --config Debug \ + --use_cache \ +diff --git a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml 
b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +index a1ca68c82..33701fccf 100644 +--- a/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/mac-ios-ci-pipeline.yml +@@ -62,7 +62,7 @@ jobs: + --use_xcode \ + --config RelWithDebInfo \ + --build_apple_framework \ +- --parallel --use_binskim_compliant_compile_flags ++ --parallel + displayName: (CPU, CoreML, XNNPACK EPs) Build onnxruntime for iOS x86_64 and run tests using simulator + env: + CC: clang +diff --git a/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml +index 7e8e72cad..6893fb95c 100644 +--- a/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/mac-objc-static-analysis-ci-pipeline.yml +@@ -26,7 +26,7 @@ jobs: + --enable_training_apis \ + --cmake_extra_defines CMAKE_EXPORT_COMPILE_COMMANDS=ON \ + --update --skip_submodule_sync \ +- --build --parallel --use_binskim_compliant_compile_flags --target onnx_proto ++ --build --parallel --target onnx_proto + displayName: Generate compile_commands.json and ONNX protobuf files + + - script: | +diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +index 21fc205c7..7f73da23b 100644 +--- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +@@ -41,7 +41,7 @@ stages: + parameters: + NpmPackagingMode: ${{ variables.NpmPackagingMode }} + IsReleasePipeline: true +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PackageName: 'onnxruntime-web' + ExtraBuildArgs: '' + UseWebPoolName: true +@@ -54,7 +54,7 @@ stages: + parameters: + NpmPackagingMode: ${{ variables.NpmPackagingMode }} + BuildConfig: 'Release' +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + PackageName: 'onnxruntime-react-native' + BuildAndroidAARStageDependsOn: 'Precheck_and_extract_commit' + +diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +index 9393fb07d..4e7093f04 100644 +--- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml ++++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +@@ -103,7 +103,7 @@ stages: + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' +- arguments: '$(BuildCommand) --use_binskim_compliant_compile_flags --parallel --path_to_protoc_exe $(Build.BinariesDirectory)\installed\bin\protoc.exe --build_csharp --update --config $(BuildConfig) ${{ variables.build_py_lto_flag }}' ++ arguments: '$(BuildCommand) --path_to_protoc_exe $(Build.BinariesDirectory)\installed\bin\protoc.exe --build_csharp --update --config $(BuildConfig) ${{ variables.build_py_lto_flag }}' + workingDirectory: '$(Build.BinariesDirectory)' + + - ${{ if notIn(parameters['sln_platform'], 'Win32', 'x64') }}: +@@ -176,7 +176,7 @@ stages: + python.exe -m pip install -q --upgrade %WHEEL_FILENAME% + set PATH=%PATH%;$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig) + @echo %PATH% +- python $(Build.SourcesDirectory)\tools\ci_build\build.py $(BuildCommand) --parallel --use_binskim_compliant_compile_flags --test --config 
$(BuildConfig) ${{ variables.build_py_lto_flag }} ++ python $(Build.SourcesDirectory)\tools\ci_build\build.py $(BuildCommand) --test --config $(BuildConfig) ${{ variables.build_py_lto_flag }} + workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' + displayName: 'Run tests' + +diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +index 2567bec9f..f44106c14 100644 +--- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml ++++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +@@ -1,5 +1,5 @@ + parameters: +- AgentPool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ AgentPool: 'onnxruntime-Ubuntu2004-AMD-CPU' + ArtifactSuffix: '' + NugetPackageName : '' + StageSuffix: 'CPU' +diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +index d8f02054a..018672e0b 100644 +--- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +@@ -44,7 +44,7 @@ jobs: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] +- pool: onnxruntime-Ubuntu-2204-Training-CPU ++ pool: onnxruntime-Ubuntu-2004-Training-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' +@@ -102,7 +102,7 @@ jobs: + --config Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --enable_onnx_tests \ + --enable_training \ +diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +index 71b224b65..a53f91fb3 100644 +--- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +@@ -37,7 +37,7 @@ jobs: + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + timeoutInMinutes: 120 + + steps: +@@ -132,7 +132,7 @@ jobs: + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + timeoutInMinutes: 120 + + steps: +diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +index bf1ba71b7..817ace057 100644 +--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml ++++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +@@ -16,7 +16,7 @@ stages: + timeoutInMinutes: 180 + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + + strategy: + matrix: +@@ -69,7 +69,7 @@ stages: + --config Debug Release \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --enable_onnx_tests \ + --enable_pybind --enable_training +diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml 
b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +index 3ec5400da..5ee398767 100644 +--- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml ++++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +@@ -4,7 +4,7 @@ stages: + parameters: + NpmPackagingMode: 'dev' + IsReleasePipeline: true +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + BuildStaticLib: true + ExtraBuildArgs: '' + UseWebPoolName: true +@@ -367,7 +367,7 @@ stages: + timeoutInMinutes: 150 + variables: + skipComponentGovernanceDetection: true +- pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + steps: + - template: templates/set-version-number-variables-step.yml + +@@ -413,7 +413,7 @@ stages: + - job: AndroidCustomBuildScript + workspace: + clean: all +- pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + variables: + dockerImageTag: onnxruntime-android-custom-build + steps: +diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +index 04f555deb..55d3150f2 100644 +--- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +@@ -18,7 +18,7 @@ stages: + - template: templates/py-packaging-linux-test-cpu.yml + parameters: + arch: 'x86_64' +- machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + base_image: 'registry.access.redhat.com/ubi8/ubi' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 +diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +index b0509467e..47d97787d 100644 +--- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +@@ -37,7 +37,7 @@ jobs: + buildArch: x64 + setVcvars: true + ALLOW_RELEASED_ONNX_OPSET_ONLY: '1' +- commonBuildArgs: '--compile_no_warning_as_error --build_dir $(Build.BinariesDirectory)\Windows --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --use_qnn --qnn_home C:\data\qnnsdk\${{parameters.QnnSdk}} --parallel --use_binskim_compliant_compile_flags ' ++ commonBuildArgs: '--compile_no_warning_as_error --build_dir $(Build.BinariesDirectory)\Windows --skip_submodule_sync --build_shared_lib --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --use_qnn --qnn_home C:\data\qnnsdk\${{parameters.QnnSdk}} --parallel' + + steps: + - template: templates/set-version-number-variables-step.yml +diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +index f82c80d4d..e6d8ee35e 100644 +--- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml ++++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-packaging-stage.yml +@@ -105,7 +105,7 @@ stages: + - template: ../templates/py-linux-gpu.yml + parameters: + arch: 'x86_64' +- machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' 
+ extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + docker_base_image: ${{ variables.docker_base_image }} +diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml +index 2a4debcf9..4f440e0f6 100644 +--- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml ++++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml +@@ -20,7 +20,7 @@ stages: + dependsOn: [] + jobs: + - job: +- pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + steps: + - checkout: none + - task: DownloadPipelineArtifact@2 +diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +index 509fea45e..5e61f88b4 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +@@ -33,7 +33,7 @@ parameters: + - name: pool_name + displayName: Pool name + type: string +- default: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ default: 'onnxruntime-Ubuntu2004-AMD-CPU' + + - name: packageName + # now we can build onnxruntime or onnxruntime-mobile for Android, need specify it here +diff --git a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml +index e77b1a400..e664cf69d 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/build-linux-wasm-step.yml +@@ -24,17 +24,19 @@ parameters: + type: string + + steps: ++ - task: Cache@2 ++ inputs: ++ ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: ++ key: ' "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | merge ' ++ ${{else}}: ++ key: '"${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | $(Build.SourceVersion) ' ++ path: ${{parameters.CacheDir}} ++ restoreKeys: | ++ "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} ++ displayName: Cache Task ++ condition: eq('${{parameters.WithCache}}', true) ++ + - ${{if eq(parameters.WithCache, true)}}: +- - task: Cache@2 +- inputs: +- ${{if eq(variables['Build.SourceBranchName'], 'merge')}}: +- key: ' "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | merge ' +- ${{else}}: +- key: '"${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} | $(Build.SourceVersion) ' +- path: ${{parameters.CacheDir}} +- restoreKeys: | +- "${{parameters.TODAY}}" | ${{parameters.AdditionalKey}} +- displayName: Cache Task + - script: | + set -e -x + pushd '$(Build.SourcesDirectory)/cmake/external/emsdk' +diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +index 8bdb395c0..3325e2617 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +@@ -760,7 +760,7 @@ stages: + + - template: ../nuget/templates/test_linux.yml + parameters: +- AgentPool : onnxruntime-Ubuntu2204-AMD-CPU ++ AgentPool : onnxruntime-Ubuntu2004-AMD-CPU + NugetPackageName : 'Microsoft.ML.OnnxRuntime' + ArtifactSuffix: 'CPU' + SpecificArtifact: ${{ parameters.SpecificArtifact }} +@@ -797,7 +797,7 @@ stages: + OS: Linux + BuildId: ${{ parameters.BuildId }} + SpecificArtifact: ${{ 
parameters.SpecificArtifact }} +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + + - template: final-jar-testing.yml + parameters: +diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +index 2da3b8a9b..8538f15e9 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +@@ -19,7 +19,7 @@ parameters: + + - name: PoolName + type: string +- default: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ default: 'onnxruntime-Ubuntu2004-AMD-CPU' + + - name: ArtifactNamePrefix + type: string +@@ -64,7 +64,7 @@ jobs: + docker run --rm --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3.9 \ + /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release \ +- --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/linux-${{parameters.OnnxruntimeArch}}" ++ --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/linux-${{parameters.OnnxruntimeArch}}" + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Build' + +diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +index 95e34cd86..55f6561b7 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +@@ -11,7 +11,7 @@ steps: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' +- version: 1.0.133 ++ version: 1.0.132 + downloadPath: $(Build.BinariesDirectory)/deps + + # The private ADO project +@@ -22,7 +22,7 @@ steps: + packageType: upack + feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' + definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' +- version: 1.0.133 ++ version: 1.0.132 + downloadPath: $(Build.BinariesDirectory)/deps + + # You can add more ADO accounts at here. 
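Editorial note, not part of the patch: the pipeline edits in this region all swap "--parallel --use_binskim_compliant_compile_flags" back to plain "--parallel", mirroring the build.py change earlier in the patch where that option is replaced by "--enable_qspectre". The YAML command lines and the script have to move together because argparse aborts on options it does not recognize. A minimal sketch of that failure mode, using a hypothetical trimmed-down parser rather than the real one in build.py:

    import argparse
    import sys

    # Hypothetical stand-in for the 1.17 build.py argument parser.
    parser = argparse.ArgumentParser()
    parser.add_argument("--parallel", action="store_true")
    parser.add_argument("--enable_qspectre", action="store_true")

    # A pipeline that still passes the newer flag fails immediately:
    # "error: unrecognized arguments: --use_binskim_compliant_compile_flags",
    # argparse exits with status 2, and the CI step fails before any build runs.
    args = parser.parse_args(sys.argv[1:])
    print(args)
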
+diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +index dd703f319..e40c4d0e9 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +@@ -138,7 +138,7 @@ jobs: + Today: $(TODAY) + CacheDir: $(ORT_CACHE_DIR) + AdditionalKey: " $(System.StageName) | ${{ parameters.BuildConfig }} " +- BuildPyArguments: '--config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_csharp --update --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }}' ++ BuildPyArguments: '--config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_csharp --update --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }}' + MsbuildArguments: '-maxcpucount' + BuildArch: ${{ parameters.buildArch }} + Platform: ${{ parameters.msbuildPlatform }} +diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +index 15165e3cb..7b9788d90 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +@@ -1,5 +1,5 @@ + parameters: +- AgentPool : 'onnxruntime-Ubuntu2204-AMD-CPU' ++ AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' + StageName : 'Linux_CI_Dev' + RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' + NuPackScript: '' +diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +index 8972d55f6..6ad5f9f38 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +@@ -32,7 +32,7 @@ stages: + BaseImage: 'registry.access.redhat.com/ubi8/ubi' + OnnxruntimeArch: 'x64' + OnnxruntimeNodejsBindingArch: 'x64' +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} + PackageJava: ${{ parameters.PackageJava }} + PackageNodeJS: ${{ parameters.PackageNodeJS }} +diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +index 360e3d5ef..e6693a6f6 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +@@ -13,7 +13,7 @@ parameters: + + - name: PoolName + type: string +- default: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ default: 'onnxruntime-Ubuntu2004-AMD-CPU' + + - name: SkipPublish + type: boolean +@@ -174,7 +174,7 @@ jobs: + ${{ else }}: + AdditionalKey: wasm_simd_jsep | ${{ parameters.BuildConfig }} + CacheDir: $(ORT_CACHE_DIR)/wasm_simd_jsep +- Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_jsep --enable_wasm_simd --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' ++ Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_jsep --enable_wasm_simd --use_jsep --target 
onnxruntime_webassembly --skip_tests' + DisplayName: 'Build (simd + JSEP)' + WithCache: ${{ parameters.WithCache }} + - template: build-linux-wasm-step.yml +@@ -185,7 +185,7 @@ jobs: + ${{ else }}: + AdditionalKey: wasm_simd_threads_jsep | ${{ parameters.BuildConfig }} + CacheDir: $(ORT_CACHE_DIR)/wasm_simd_threads_jsep +- Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' ++ Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --target onnxruntime_webassembly --skip_tests' + DisplayName: 'Build (simd + threads + JSEP)' + WithCache: ${{ parameters.WithCache }} + +diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml +index 7672b604a..0cb77e222 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml +@@ -47,14 +47,14 @@ steps: + BuildStep: + - script: | + rm -rf $(Build.BinariesDirectory)/Release +- python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --update --build ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --config Release ++ python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --update --build ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --build_shared_lib --config Release + displayName: 'Build ${{ parameters.MacosArch }}' + env: + CCACHE_DIR: ${{ parameters.CacheDir }} + + - ${{ if eq(parameters.MacosArch, 'x86_64') }}: + - script: | +- python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --test ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --config Release ++ python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --test ${{ parameters.AdditionalBuildFlags }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --build_shared_lib --config Release + displayName: 'Running Tests' + + - task: ShellScript@2 +diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +index cf39be23c..51583a25f 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +@@ -336,7 +336,7 @@ stages: + + - template: ../nuget/templates/test_linux.yml + parameters: +- AgentPool : onnxruntime-Ubuntu2204-AMD-CPU ++ AgentPool : onnxruntime-Ubuntu2004-AMD-CPU + NugetPackageName : 'Microsoft.ML.OnnxRuntime.Training' + ArtifactSuffix: 'Training-CPU' + StageSuffix: 'Training_CPU' +diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +index 01cab936a..00ba5ea4a 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml ++++ 
b/tools/ci_build/github/azure-pipelines/templates/py-packaging-selectable-stage.yml +@@ -48,7 +48,7 @@ stages: + timeoutInMinutes: 90 + workspace: + clean: all +- pool: onnxruntime-Ubuntu2204-AMD-CPU ++ pool: onnxruntime-Ubuntu2004-AMD-CPU + strategy: + matrix: + ${{ each PythonVersion in parameters.python_version }}: +diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +index 146e3e584..28870a9ee 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +@@ -182,7 +182,7 @@ stages: + --enable_pybind + --enable_onnx_tests + ${{ parameters.build_py_parameters }} +- --parallel --use_binskim_compliant_compile_flags --update ++ --parallel --update + $(TelemetryOption) + workingDirectory: '$(Build.BinariesDirectory)' + +@@ -388,7 +388,7 @@ stages: + set -e -x + export _PYTHON_HOST_PLATFORM=macosx-${{variables.MACOSX_DEPLOYMENT_TARGET}}-universal2 + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' +- python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --build_dir $(Build.BinariesDirectory) --use_coreml --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --config Release --build_wheel ${{ parameters.build_py_parameters }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" --update --build ++ python3 $(Build.SourcesDirectory)/tools/ci_build/build.py --build_dir $(Build.BinariesDirectory) --use_coreml --skip_submodule_sync --parallel --config Release --build_wheel ${{ parameters.build_py_parameters }} --use_coreml --cmake_extra_defines CMAKE_OSX_ARCHITECTURES="arm64;x86_64" --update --build + displayName: 'Command Line Script' + + - script: | +@@ -435,7 +435,7 @@ stages: + - template: py-linux.yml + parameters: + arch: 'x86_64' +- machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + base_image: 'registry.access.redhat.com/ubi8/ubi' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 +@@ -448,7 +448,7 @@ stages: + - template: py-linux-gpu.yml + parameters: + arch: 'x86_64' +- machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' + extra_build_arg: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} + +diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml +index c6921e151..158037661 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage.yml +@@ -171,7 +171,7 @@ stages: + --build_dir /build \ + --config ${{ variables['buildConfig'] }} \ + --skip_submodule_sync \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --enable_onnx_tests \ + ${{ parameters.build_py_parameters }} \ +diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +index 18368e59c..c83e130dd 100644 +--- 
a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +@@ -110,13 +110,13 @@ jobs: + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > +- --config RelWithDebInfo ++ --config RelWithDebInfo --enable_qspectre + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --enable_pybind + --enable_onnx_tests +- --parallel --use_binskim_compliant_compile_flags --update ++ --parallel --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + workingDirectory: '$(Build.BinariesDirectory)' + +diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml +index 43a80aa4f..2e9e6c6b3 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml +@@ -14,7 +14,7 @@ jobs: + workspace: + clean: all + timeoutInMinutes: 180 +- pool: Ubuntu-2204-rocm-aiinfra ++ pool: Ubuntu-2004-rocm-aiinfra + variables: + - name: PythonVersion + value: ${{ parameters.PythonVersion }} +diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +index ed32c5d0e..d1dff0769 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +@@ -78,6 +78,10 @@ stages: + pip install -r tools/ci_build/github/apple/ios_packaging.requirements.txt + displayName: "Install Python requirements" + ++ - script: | ++ $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_protobuf.sh -p $(Build.BinariesDirectory)/protobuf_install -d $(Build.SourcesDirectory)/cmake/deps.txt ++ displayName: "Build Host Protoc" ++ + # create and test mobile pods + - script: | + python tools/ci_build/github/apple/build_and_assemble_apple_pods.py \ +@@ -87,7 +91,8 @@ stages: + --test \ + --variant ${{ parameters.packageVariant }} \ + --build-settings-file "${{ variables.buildSettingsFile }}" \ +- ${{ variables.optionalIncludeOpsByConfigOption }} ++ ${{ variables.optionalIncludeOpsByConfigOption }} \ ++ -b="--path_to_protoc_exe=$(Build.BinariesDirectory)/protobuf_install/bin/protoc" + displayName: "Build macOS/iOS framework and assemble pod package files" + + - script: | +diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +index 8ed22153f..31e41eb4b 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +@@ -150,9 +150,9 @@ stages: + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + ${{ if eq(parameters['UseIncreasedTimeoutForTests'], 'true') }}: +- arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} --test_all_timeout 72000' ++ arguments: '--config RelWithDebInfo --enable_qspectre --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync 
--build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} --test_all_timeout 72000' + ${{ else }}: +- arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} ' ++ arguments: '--config RelWithDebInfo --enable_qspectre --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} ' + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 +@@ -172,7 +172,7 @@ stages: + condition: and(succeeded(), eq('${{ parameters.runTests}}', true)) + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' +- arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }}' ++ arguments: '--config RelWithDebInfo --enable_qspectre --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }}' + workingDirectory: '$(Build.BinariesDirectory)' + + - script: | +diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +index f2005ec5a..79647cc56 100644 +--- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml ++++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +@@ -127,14 +127,14 @@ jobs: + displayName: 'Build (simd + JSEP)' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' +- arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_jsep --enable_wasm_simd --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' ++ arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_jsep --enable_wasm_simd --use_jsep --target onnxruntime_webassembly --skip_tests' + workingDirectory: '$(Build.BinariesDirectory)' + - ${{ if eq(parameters.BuildJsep, true) }}: + - task: PythonScript@0 + displayName: 'Build (simd + threads + JSEP)' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' +- arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --use_webnn --target onnxruntime_webassembly --skip_tests' ++ arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)\wasm_simd_threads_jsep --enable_wasm_simd --enable_wasm_threads --use_jsep --target onnxruntime_webassembly --skip_tests' + workingDirectory: '$(Build.BinariesDirectory)' + - ${{ if eq(parameters.SkipPublish, false) }}: + - script: | +diff --git a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml +index 24809ccfd..e352a0406 100644 +--- a/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/web-ci-pipeline.yml +@@ -53,7 +53,7 @@ stages: + parameters: + 
NpmPackagingMode: ${{ variables.NpmPackagingMode }} + IsReleasePipeline: false +- PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' ++ PoolName: 'onnxruntime-Ubuntu2004-AMD-CPU' + BuildStaticLib: true + ExtraBuildArgs: $(ExtraBuildArgs) + WASMTemplate: linux-wasm-ci.yml +diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +index 53eea1d69..d65b75ba9 100644 +--- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +@@ -74,7 +74,7 @@ stages: + displayName: 'Build and Test' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' +- arguments: --config Debug --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --disable_memleak_checker --enable_address_sanitizer ++ arguments: --config Debug --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --parallel --cmake_generator "Visual Studio 17 2022" --disable_memleak_checker --enable_address_sanitizer + workingDirectory: '$(Build.BinariesDirectory)' + + +diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +index 6cbe20bb9..2c4e4eb01 100644 +--- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-ci-pipeline.yml +@@ -54,7 +54,7 @@ jobs: + WithCache: True + Today: $(TODAY) + AdditionalKey: "gpu-tensorrt | RelWithDebInfo" +- BuildPyArguments: '--config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75' ++ BuildPyArguments: '--config RelWithDebInfo --enable_qspectre --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75' + MsbuildArguments: $(MsbuildArguments) + BuildArch: 'x64' + Platform: 'x64' +@@ -74,7 +74,7 @@ jobs: + del wheel_filename_file + python.exe -m pip install -q --upgrade %WHEEL_FILENAME% + set PATH=$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo;%PATH% +- python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 ++ python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --enable_qspectre --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt 
--tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 + + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + displayName: 'Run tests' +diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +index 6246bb835..c686fc57a 100644 +--- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml ++++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +@@ -76,7 +76,7 @@ jobs: + WithCache: True + Today: $(TODAY) + AdditionalKey: "win-qnn | $(BuildConfig)" +- BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel --use_binskim_compliant_compile_flags' ++ BuildPyArguments: '--config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --compile_no_warning_as_error --update --cmake_generator "Visual Studio 17 2022" --use_qnn --qnn_home $(QNN_SDK_ROOT) --parallel' + MsbuildArguments: $(MsbuildArguments) + BuildArch: $(buildArch) + Platform: 'x64' +diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh +index aec02f766..106536c00 100755 +--- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh ++++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh +@@ -4,6 +4,6 @@ docker run --gpus all -e NVIDIA_VISIBLE_DEVICES=all --rm --volume \ + $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro --volume /data/onnx:/data/onnx:ro -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}build \ + /usr/bin/python3.9 /onnxruntime_src/tools/ci_build/build.py --enable_lto --build_java --build_nodejs --build_dir /build --config Release \ +---skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ ++--skip_submodule_sync --parallel --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ + --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION \ + --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' +diff --git a/tools/ci_build/github/linux/build_linux_python_package.sh b/tools/ci_build/github/linux/build_linux_python_package.sh +index bc57cf412..1059dd504 100755 +--- a/tools/ci_build/github/linux/build_linux_python_package.sh ++++ b/tools/ci_build/github/linux/build_linux_python_package.sh +@@ -7,9 +7,9 @@ mkdir -p /build/dist + + EXTRA_ARG="" + +-# Put 3.8 at the last because Ubuntu 22.04 use python 3.10 and we will upload the intermediate build files of this +-# config to Azure DevOps Artifacts and download them to a Ubuntu 22.04 machine to run the tests. +-PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp310-cp310/bin/python3.10") ++# Put 3.8 at the last because Ubuntu 20.04 use python 3.8 and we will upload the intermediate build files of this ++# config to Azure DevOps Artifacts and download them to a Ubuntu 20.04 machine to run the tests. 
++PYTHON_EXES=("/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11" "/opt/python/cp312-cp312/bin/python3.12" "/opt/python/cp38-cp38/bin/python3.8") + while getopts "d:p:x:c:" parameter_Option + do case "${parameter_Option}" + in +@@ -23,7 +23,7 @@ c) BUILD_CONFIG=${OPTARG};; + esac + done + +-BUILD_ARGS=("--build_dir" "/build" "--config" "$BUILD_CONFIG" "--update" "--build" "--skip_submodule_sync" "--parallel" "--use_binskim_compliant_compile_flags" "--build_wheel") ++BUILD_ARGS=("--build_dir" "/build" "--config" "$BUILD_CONFIG" "--update" "--build" "--skip_submodule_sync" "--parallel" "--build_wheel") + + if [ "$BUILD_CONFIG" != "Debug" ]; then + BUILD_ARGS+=("--enable_lto") +diff --git a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh +index 7d65a6f73..a65be0cb6 100755 +--- a/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh ++++ b/tools/ci_build/github/linux/build_tensorrt_c_api_package.sh +@@ -4,4 +4,4 @@ mkdir -p $HOME/.onnx + docker run --gpus all -e CFLAGS -e CXXFLAGS -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/onnx:/data/onnx:ro --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda${CUDA_VERSION_MAJOR}xtrt86build \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release \ +---skip_submodule_sync --parallel --use_binskim_compliant_compile_flags --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' ++--skip_submodule_sync --parallel --build_shared_lib --build_java --build_nodejs --use_tensorrt --cuda_version=$CUDA_VERSION --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' +diff --git a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh +index 640028ee7..b35bbfbd5 100755 +--- a/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh ++++ b/tools/ci_build/github/linux/ort_minimal/build_full_ort_and_create_ort_files.sh +@@ -22,7 +22,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir ${BUILD_DIR} --cmake_generator Ninja \ + --config Debug \ + --skip_submodule_sync \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --build_wheel \ + --skip_tests \ + --enable_training_ops \ +diff --git a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh +index 58d493086..2efcff917 100755 +--- a/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh ++++ b/tools/ci_build/github/linux/ort_minimal/build_minimal_ort_and_run_tests.sh +@@ -72,7 +72,7 @@ python3 /onnxruntime_src/tools/ci_build/build.py \ + --config Debug \ + --skip_submodule_sync \ + --build_shared_lib \ +- --parallel --use_binskim_compliant_compile_flags \ ++ --parallel \ + --minimal_build ${MINIMAL_BUILD_ARGS} \ + --disable_ml_ops \ + --include_ops_by_config ${REDUCED_OPS_CONFIG_FILE} \ +diff --git 
a/tools/ci_build/github/linux/run_build.sh b/tools/ci_build/github/linux/run_build.sh +index 25b361087..43e154389 100755 +--- a/tools/ci_build/github/linux/run_build.sh ++++ b/tools/ci_build/github/linux/run_build.sh +@@ -37,7 +37,7 @@ if [ $BUILD_OS = "yocto" ]; then + + make -j$(nproc) + else +- COMMON_BUILD_ARGS="--skip_submodule_sync --enable_onnx_tests --parallel --use_binskim_compliant_compile_flags --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest" ++ COMMON_BUILD_ARGS="--skip_submodule_sync --enable_onnx_tests --parallel --cmake_path /usr/bin/cmake --ctest_path /usr/bin/ctest" + + if [ $BUILD_DEVICE = "gpu" ]; then + _CUDNN_VERSION=$(echo $CUDNN_VERSION | cut -d. -f1-2) +diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh +index 082c561dd..3164a10a0 100755 +--- a/tools/ci_build/github/linux/run_python_tests.sh ++++ b/tools/ci_build/github/linux/run_python_tests.sh +@@ -15,7 +15,7 @@ c) BUILD_CONFIG=${OPTARG};; + esac + done + +-export PATH=/opt/python/cp310-cp310/bin:$PATH ++export PATH=/opt/python/cp38-cp38/bin:$PATH + cd /build + files=(whl/*.whl) + FILE_NAME="${files[0]}" +diff --git a/tools/ci_build/set-trigger-rules.py b/tools/ci_build/set-trigger-rules.py +index d26fec410..cdb75154e 100644 +--- a/tools/ci_build/set-trigger-rules.py ++++ b/tools/ci_build/set-trigger-rules.py +@@ -14,7 +14,6 @@ skip_doc_changes = ["web-ci-pipeline.yml"] + skip_js_changes = [ + "android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml", + "android-x86_64-crosscompile-ci-pipeline.yml", +- "bigmodels-ci-pipeline.yml", + "linux-ci-pipeline.yml", + "linux-cpu-aten-pipeline.yml", + "linux-cpu-eager-pipeline.yml", +@@ -32,6 +31,7 @@ skip_js_changes = [ + "orttraining-linux-ci-pipeline.yml", + "orttraining-linux-gpu-ci-pipeline.yml", + "orttraining-linux-gpu-ortmodule-distributed-test-ci-pipeline.yml", ++ "orttraining-linux-gpu-training-apis.yml", + "orttraining-mac-ci-pipeline.yml", + "win-ci-pipeline.yml", + "win-gpu-ci-pipeline.yml", +diff --git a/tools/scripts/build_riscv64.sh b/tools/scripts/build_riscv64.sh +deleted file mode 100755 +index 65681c0b6..000000000 +--- a/tools/scripts/build_riscv64.sh ++++ /dev/null +@@ -1,129 +0,0 @@ +-#!/bin/bash +-# Copyright (c) 2024 SiFive, Inc. All rights reserved. +-# Copyright (c) 2024, Phoebe Chen +-# Licensed under the MIT License. +- +- +-# The script is a sample for RISC-V 64-bit cross compilation in +-# GNU/Linux, and you should ensure that your environment meets +-# ORT requirements. You may need to make changes before using it. +- +-set -e +-set -o pipefail +- +-# Get directory this script is in +-DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +-OS=$(uname -s) +- +-if [ "$OS" == "Linux" ]; then +- LINUX_DISTRO=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"') +- if [[ "${LINUX_DISTRO}" == "ubuntu" ]] ;then +- DIR_OS="Linux" +- else +- echo "${LINUX_DISTRO} is not supported" +- return 1 +- fi +-else +- echo "$OS is not supported" +- return 1 +-fi +- +-function cleanup { +- if [ -d "$WORK_DIR" ]; then +- rm -rf "$WORK_DIR" +- fi +-} +- +-# The riscv toolchain, qemu and other platform related settings. +-ORT_ROOT_DIR=$DIR/../.. 
+- +-PREBUILT_DIR="${ORT_ROOT_DIR}/riscv_tools" +- +-read -rp "Enter the riscv tools root path(press enter to use default path:${PREBUILT_DIR}): " INPUT_PATH +-if [[ "${INPUT_PATH}" ]]; then +- PREBUILT_DIR=${INPUT_PATH} +-fi +-echo "The riscv tool prefix path: ${PREBUILT_DIR}" +- +-WORK_DIR=$DIR/.prebuilt +- +-# The prebuit toolchain download from riscv-collab works with Ubuntu. +-RISCV_GNU_TOOLCHAIN_URL="https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download" +-TOOLCHAIN_VERSION="2023.11.20" +-RISCV_TOOLCHAIN_FILE_NAME="riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.11.20-nightly.tar.gz" +-RISCV_TOOLCHAIN_FILE_SHA="98d6531b757fac01e065460c19abe8974976c607a8d88631cc5c1529d90ba7ba" +- +-TOOLCHAIN_PATH_PREFIX=${PREBUILT_DIR} +- +-execute () { +- if ! eval "$1"; then +- echo "command:\"$1\" error" +- exit 1 +- fi +-} +- +-execute "mkdir -p $WORK_DIR" +- +-# Call the cleanup function when this tool exits. +-trap cleanup EXIT +- +-# Download and install the toolchain from +-# https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download +-download_file() { +- local file_name="$1" +- local install_path="$2" +- local file_sha="$3" +- +- echo "Install $1 to $2" +- if [[ "$(ls -A "$2")" ]]; then +- read -rp "The file already exists. Keep it (y/n)? " replaced +- case ${replaced:0:1} in +- y|Y ) +- echo "Skip download $1." +- return +- ;; +- * ) +- rm -rf "$2" +- ;; +- esac +- fi +- +- echo "Download ${file_name} ..." +- mkdir -p "$install_path" +- wget --progress=bar:force:noscroll --directory-prefix="${WORK_DIR}" \ +- "${RISCV_GNU_TOOLCHAIN_URL}/${TOOLCHAIN_VERSION}/${file_name}" && \ +- echo "${file_sha} ${WORK_DIR}/${file_name}" | sha256sum -c - +- echo "Extract ${file_name} ..." +- tar -C "${install_path}" -xf "${WORK_DIR}/${file_name}" --no-same-owner \ +- --strip-components=1 +-} +- +- +-read -rp "Install RISCV toolchain(y/n)? " answer +-case ${answer:0:1} in +- y|Y ) +- download_file "${RISCV_TOOLCHAIN_FILE_NAME}" \ +- "${TOOLCHAIN_PATH_PREFIX}" \ +- "${RISCV_TOOLCHAIN_FILE_SHA}" +- ;; +- * ) +- echo "Skip install RISCV toolchain." +- ;; +-esac +-echo "download finished." +- +- +-# RISC-V cross compilation in GNU/Linux +-RISCV_TOOLCHAIN_ROOT=${TOOLCHAIN_PATH_PREFIX} +-RISCV_QEMU_PATH=${TOOLCHAIN_PATH_PREFIX}/bin/qemu-riscv64 +-python3 "${ORT_ROOT_DIR}"/tools/ci_build/build.py \ +- --build_dir "${ORT_ROOT_DIR}/build/${DIR_OS}" \ +- --rv64 \ +- --parallel \ +- --skip_tests \ +- --config RelWithDebInfo \ +- --cmake_generator=Ninja \ +- --riscv_qemu_path="${RISCV_QEMU_PATH}" \ +- --riscv_toolchain_root="${RISCV_TOOLCHAIN_ROOT}" "$@" +- +- +diff --git a/tools/scripts/python_test.sh b/tools/scripts/python_test.sh +index 39d9ed432..bfdd4663f 100755 +--- a/tools/scripts/python_test.sh ++++ b/tools/scripts/python_test.sh +@@ -24,5 +24,5 @@ python3 -m pip install $build_dir/$config/dist/*.whl + + echo Run $config unit tests + pushd $build_dir/$config/ +-python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests --enable_transformers_tool_test --ctest_path "" ++python3 $src_dir/tools/ci_build/build.py --build_dir $build_dir --cmake_generator Ninja --config $config --test --skip_submodule_sync --build_shared_lib --parallel --build_wheel --enable_onnx_tests --enable_transformers_tool_test --ctest_path "" + popd