# Composite action: configures the project with CMake (Release) and builds it
# against an OpenVINO package whose environment is loaded from
# <ov_dir>/setupvars.sh.
name: 'Build App'
inputs:
  ov_dir:
    description: 'Directory where OpenVINO is installed'
    default: './ov'
    required: false
  build_dir:
    description: 'Directory where the app is built'
    default: './build'
    required: false
  build_target:
    description: 'Target to build'
    default: ''
    required: false
runs:
  using: "composite"
  steps:
    - name: Build app
      shell: bash
      # Inputs are handed to the script through environment variables instead of
      # being spliced into the script text with ${{ }}: the values are then never
      # parsed as shell code and paths containing spaces keep working.
      env:
        OV_DIR: ${{ inputs.ov_dir }}
        BUILD_DIR: ${{ inputs.build_dir }}
        BUILD_TARGET: ${{ inputs.build_target }}
      run: |
        source "${OV_DIR}/setupvars.sh"
        cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B "${BUILD_DIR}"
        # BUILD_TARGET is intentionally left unquoted: callers pass several
        # space-separated targets and each must become its own argument.
        # When it is empty, no --target option is emitted at all.
        cmake --build "${BUILD_DIR}" --config Release ${BUILD_TARGET:+--target ${BUILD_TARGET}} -j
def main(model_path: str, images_path: str) -> None:
    """Generate a reference answer for the visual language chat sample.

    Loads every image found at ``images_path``, asks the model at
    ``model_path`` to describe them with greedy decoding (at most 100 new
    tokens), prints the prompt and the decoded answer, and writes the answer
    to ``ref.txt`` in the current working directory.

    Args:
        model_path: Path to the exported model directory.
        images_path: Path to a single image file, or to a directory whose
            direct children with a supported extension are used, sorted by
            file name.

    Raises:
        FileNotFoundError: If no image could be collected from ``images_path``.
    """
    print(f"Selected model: {model_path}\n")

    source = Path(images_path)
    if source.is_file():
        image_files = [source]
    else:
        image_files = sorted(
            (f for f in source.glob("*") if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS),
            key=lambda x: x.name,
        )

    if not image_files:
        raise FileNotFoundError(f"No images found in '{images_path}' directory. Supported formats: {IMAGE_EXTENSIONS}")

    images = []
    for file in image_files:
        # Image.open() is lazy and keeps the file handle open; convert() returns
        # an independent in-memory image, so the handle can be released right away.
        with Image.open(file) as image:
            images.append(image.convert("RGB"))

    print("Images:", image_files)

    model = OVModelForVisualCausalLM.from_pretrained(model_path, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)

    # One image placeholder per loaded image, followed by the text question.
    conversation = [{
        "role": "user",
        "content": [
            *[{"type": "image"} for _ in images],
            {"type": "text", "text": "Describe the images."},
        ],
    }]

    prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    print(prompt)
    inputs = processor(text=[prompt], images=images, return_tensors="pt")
    result = model.generate(**inputs, max_new_tokens=100, do_sample=False)
    # Drop the prompt tokens so that only the newly generated answer is decoded.
    prompt_length = inputs["input_ids"].shape[1]
    decoded = processor.tokenizer.batch_decode(result[:, prompt_length:], skip_special_tokens=True)[0]
    print(decoded)
    # Explicit UTF-8 keeps the reference file independent of the runner's locale.
    # NOTE(review): the surrounding "question:" markers presumably mirror the
    # transcript printed by the chat sample - confirm against the consumer of ref.txt.
    with open("ref.txt", "w", encoding="utf-8") as f:
        f.write(f"question:\n{decoded}\n----------\nquestion:\n")
./ov/install_dependencies/install_openvino_dependencies.sh - - name: Build app - run: | - source ./ov/setupvars.sh - cmake -DCMAKE_BUILD_TYPE=Release -S ./ -B ./build/ - cmake --build ./build/ --config Release --target visual_language_chat py_openvino_genai -j - - name: Install dependencies - run: | - source ./ov/setupvars.sh - python -m pip install ./thirdparty/openvino_tokenizers/[transformers] --pre --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly - python -m pip install --upgrade-strategy eager -r ./samples/requirements.txt opencv-python + - uses: ./.github/actions/install_openvino + with: + ov_link: ${{ env.l_u22_ov_link }} + - uses: ./.github/actions/build_app + with: + build_target: 'visual_language_chat py_openvino_genai' + - uses: ./.github/actions/install_python_deps - name: Download and convert tiny-random-minicpmv-2_6 model and an image run: | python -m pip install git+https://github.com/eaidova/optimum-intel.git@ea/minicpmv @@ -764,13 +756,6 @@ jobs: && ./build/samples/cpp/visual_language_chat/visual_language_chat ./tiny-random-minicpmv-2_6/ ./images/ <<< $'Describe the images?' | tee cpp.txt timeout-minutes: 2 - - name: Encode cpp.txt with Python encoding instead of terminal one - shell: python - run: | - with open("cpp.txt", "rb") as f: - content = f.read().decode("utf-8", "replace") - with open("cpp.txt", "wb") as f: - f.write(content.encode("utf-8")) - name: Run visual_language_chat Python sample - tiny-random-minicpmv-2_6 run: > set -o pipefail @@ -779,6 +764,13 @@ jobs: <<< $'Describe the images?' 
| tee py.txt env: PYTHONPATH: "./build/" + - name: Encode cpp.txt with Python encoding instead of terminal one + shell: python + run: | + with open("cpp.txt", "rb") as f: + content = f.read().decode("utf-8", "replace") + with open("cpp.txt", "wb") as f: + f.write(content.encode("utf-8")) - run: diff cpp.txt py.txt - name: Run visual_language_chat C++ sample with 2 prompts - tiny-random-minicpmv-2_6 run: > @@ -803,31 +795,44 @@ jobs: with open("cpp2.txt", "wb") as f: f.write(content.encode("utf-8")) - run: diff cpp2.txt py2.txt - - name: Download and convert LLaVa 1.5 model and an image - run: | - source ./ov/setupvars.sh - optimum-cli export openvino --model llava-hf/llava-1.5-7b-hf ./llava_1_5_7b_ov/ - wget https://llava-vl.github.io/static/images/monalisa.jpg - - name: Run visual_language_chat C++ sample - LLaVa 1.5 - run: > - source ./ov/setupvars.sh - && ./build/samples/cpp/visual_language_chat/visual_language_chat ./llava_1_5_7b_ov/ monalisa.jpg - <<< $'Who drew this painting?\nWhen did the painter live?' - timeout-minutes: 4 - - name: Download and convert LLaVa-Next model - run: | - source ./ov/setupvars.sh - optimum-cli export openvino --model llava-hf/llava-v1.6-mistral-7b-hf ./llava_v1_6_mistral_7b_ov/ - - name: Run visual_language_chat C++ sample - LLaVa-Next - run: > - source ./ov/setupvars.sh - && ./build/samples/cpp/visual_language_chat/visual_language_chat ./llava_v1_6_mistral_7b_ov/ monalisa.jpg - <<< $'Who drew this painting?\nWhen did the painter live?' 
- timeout-minutes: 4 + + visual_language_chat_sample-ubuntu-llava_1_5: + uses: ./.github/workflows/job_vlm_sample_llava.yml + with: + model_id: llava-hf/llava-1.5-7b-hf + model_dir: llava_1_5_7b_ov + + visual_language_chat_sample-ubuntu-llava_next: + uses: ./.github/workflows/job_vlm_sample_llava.yml + with: + model_id: llava-hf/llava-v1.6-mistral-7b-hf + model_dir: llava_v1_6_mistral_7b_ov + + visual_language_chat_sample-ubuntu-internvl2: + runs-on: ubuntu-22.04-16-cores + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + - uses: ./.github/actions/install_openvino + with: + ov_link: ${{ env.l_u22_ov_link }} + - uses: ./.github/actions/build_app + with: + build_target: 'visual_language_chat py_openvino_genai' + - uses: ./.github/actions/install_python_deps - name: Download and convert InternVL2 model run: | + # Lowering transformers version, workaround for https://huggingface.co/OpenGVLab/InternVL2-1B/discussions/7 + python -m pip install -U "transformers<4.45.0" source ./ov/setupvars.sh optimum-cli export openvino --model OpenGVLab/InternVL2-4B ./internvl2_4b_ov/ --trust-remote-code + - name: Download images + run: | + wget https://llava-vl.github.io/static/images/monalisa.jpg - name: Run visual_language_chat C++ sample - InternVL2 run: > source ./ov/setupvars.sh @@ -835,7 +840,6 @@ jobs: <<< $'Who drew this painting?\nWhen did the painter live?' 
timeout-minutes: 4 - cpp-continuous-batching-ubuntu: runs-on: ubuntu-20.04-8-cores defaults: @@ -975,7 +979,7 @@ jobs: cpp-greedy_causal_lm-Qwen-7B-Chat, cpp-beam_search_causal_lm-Qwen1_5-7B-Chat, cpp-beam_search_causal_lm-Phi-2, cpp-beam_search_causal_lm-notus-7b-v1, cpp-speculative_decoding_lm-ubuntu, cpp-prompt_lookup_decoding_lm-ubuntu, cpp-Phi-1_5, cpp-greedy_causal_lm-redpajama-3b-chat, cpp-chat_sample-ubuntu, cpp-continuous-batching-ubuntu, - visual_language_chat_sample-ubuntu, + visual_language_chat_sample-ubuntu-minicpm_v2_6, visual_language_chat_sample-ubuntu-llava_1_5, visual_language_chat_sample-ubuntu-llava_next, visual_language_chat_sample-ubuntu-internvl2, cpp-continuous-batching-windows, cpp-continuous-batching-macos] if: ${{ always() }} runs-on: ubuntu-latest diff --git a/.github/workflows/job_vlm_sample_llava.yml b/.github/workflows/job_vlm_sample_llava.yml new file mode 100644 index 0000000000..300b3df7de --- /dev/null +++ b/.github/workflows/job_vlm_sample_llava.yml @@ -0,0 +1,44 @@ +name: visual_language_chat sample - LLaVA + +on: + workflow_call: + inputs: + model_id: + required: true + type: string + model_dir: + required: true + type: string + +env: + l_u22_ov_link: https://storage.openvinotoolkit.org/repositories/openvino/packages/nightly/2025.0.0-17289-7cf2bbb8391/l_openvino_toolkit_ubuntu22_2025.0.0.dev20241105_x86_64.tgz +jobs: + visual_language_chat_sample-ubuntu-llava: + runs-on: ubuntu-22.04-16-cores + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + - uses: ./.github/actions/install_openvino + with: + ov_link: ${{ env.l_u22_ov_link }} + - uses: ./.github/actions/build_app + with: + build_target: 'visual_language_chat py_openvino_genai' + - uses: ./.github/actions/install_python_deps + - name: Download and convert model + run: | + source ./ov/setupvars.sh + optimum-cli export openvino --model ${{ inputs.model_id }} ./${{ inputs.model_dir }} + - 
// Rounds `value` up to the nearest multiple of `alignment`.
//
// Unlike the bit-mask formula `(value + alignment - 1) & ~(alignment - 1)`,
// which is only correct when `alignment` is a power of two, this works for any
// alignment; for power-of-two alignments both formulas give the same result.
//
// value     - number to round up.
// alignment - granularity to round to; 0 means "no alignment" and returns
//             `value` unchanged instead of collapsing it to 0.
//
// Returns the smallest multiple of `alignment` that is >= `value`. If that
// multiple does not fit into uint32_t the result wraps around (unsigned
// arithmetic), so callers must keep `value` well below UINT32_MAX.
uint32_t align_to(uint32_t value, uint32_t alignment) {
    if (alignment == 0u) {
        return value;
    }
    const uint32_t remainder = value % alignment;
    if (remainder == 0u) {
        // Already aligned.
        return value;
    }
    return value + (alignment - remainder);
}
Please select either \"FAST_COMPILE\" or \"BEST_PERF\"."); +} + std::shared_ptr cvt_kvcache_to_fp16(const std::shared_ptr& model) { ov::preprocess::PrePostProcessor ppp(model); @@ -275,8 +295,12 @@ ov::AnyMap get_default_prefill_config(const std::shared_ptr& model, } ov::AnyMap get_default_generate_config(const std::shared_ptr& model, - const std::optional& npudesc) { + const std::optional& npudesc, + const GenerateHint hint) { auto config = get_default_common_config(model); + if (hint == GenerateHint::BEST_PERF) { + config.emplace("NPUW_ONLINE_PIPELINE", "NONE"); + } // NB: Unconditionally set for generation model config.emplace("NPUW_DQ", "YES"); if (npudesc.has_value() && npudesc->arch == "4000") { @@ -404,8 +428,8 @@ void StaticLLMPipeline::setupAndCompileModels( m_prefill_model = m_kvcache_model->clone(); m_prefill_model->set_friendly_name(m_kvcache_model->get_friendly_name() + "_prefill"); // (7) Reshape both models to static shape - const uint32_t kMaxPromptLen = pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u); - const uint32_t kMinResponseLen = pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u); + const uint32_t kMaxPromptLen = align_to(pop_int_and_cast(properties, "MAX_PROMPT_LEN").value_or(1024u), 64u); + const uint32_t kMinResponseLen = align_to(pop_int_and_cast(properties, "MIN_RESPONSE_LEN").value_or(128u), 64u); KVAxesPosition axes = get_kv_axes(get_model_type_from_json(models_path / "config.json")); m_kvcache_desc = KVCacheDesc { kMaxPromptLen, kMaxPromptLen + kMinResponseLen, 0u, axes.seq_len }; reshape_to_static(m_prefill_model, m_kvcache_desc.max_prompt_size, m_kvcache_desc.max_prompt_size, axes); @@ -414,8 +438,10 @@ void StaticLLMPipeline::setupAndCompileModels( auto prefill_config = pop_or_default( properties, "PREFILL_CONFIG", get_default_prefill_config(m_prefill_model, npudesc) ); + // NB: GENERATE_HINT is only applicable for default generate config! 
+ auto generate_hint = str_to_hint(pop_or_default(properties, "GENERATE_HINT", "FAST_COMPILE")); auto generate_config = pop_or_default( - properties, "GENERATE_CONFIG", get_default_generate_config(m_kvcache_model, npudesc) + properties, "GENERATE_CONFIG", get_default_generate_config(m_kvcache_model, npudesc, generate_hint) ); merge_config_with(prefill_config, properties); merge_config_with(generate_config, properties); diff --git a/src/cpp/src/lora_adapter.cpp b/src/cpp/src/lora_adapter.cpp index 472cc2724e..5e8839513e 100644 --- a/src/cpp/src/lora_adapter.cpp +++ b/src/cpp/src/lora_adapter.cpp @@ -209,6 +209,11 @@ LoRATensors group_lora_tensors(const ConstantMap& tensors, const LoRAPartsParser // Squeeze all dimensions from the right of the shape producing a tensor of 2D shape. NodePtr squeeze_2d (const ov::Output& input) { auto shape = v0::Constant::create(ov::element::i32, {2}, std::vector{0, 0}); + auto dims = static_cast>(input.get_partial_shape()); + OPENVINO_ASSERT( + dims.end() == std::find_if(dims.begin() + 2, dims.end(), [](const ov::Dimension& d) { return d.get_max_length() > 1; }), + "LoRA adapter with not pointwise Convolutional kernel is not supported." + ); auto reshape = std::make_shared(input, shape->output(0), true); return reshape; } @@ -326,8 +331,10 @@ struct LoRAParametersByWeightGetter { ov::Dimension rank = ov::Dimension::dynamic(); if(dynamic_lora_rank) { // Leave rank dynamic if at least one adapter exist for a give node. - if(weight_getter.end() == - std::find_if(weight_getter.begin(), weight_getter.end(), [node](const LoRAWeightGetter& getter) { + // It is important to go over all weight_getter's because they record used LoRA tensors to + // be able to report unused tensors later. + // Hence, avoid find_if here and use an std algorithm that goes over all elements in a sequence. 
+ if(!std::count_if(weight_getter.begin(), weight_getter.end(), [node](const LoRAWeightGetter& getter) { return bool(getter(node->get_friendly_name())); })) { return std::nullopt; @@ -493,7 +500,7 @@ class LoRATransformBase : public ov::pass::MatcherPass { } ~LoRATransformBase () { - DEBUG_PRINT("LoRA applied for " << applied << " layers"); // For debugging purposes only + DEBUG_PRINT("LoRA applied for " << applied << " layers"); } protected: @@ -517,6 +524,9 @@ NodePtr tensors_multiplication(NodePtr input, const NodeVector multipliers, ov:: NodePtr normalized = multipliers[i]; if(normalized->get_output_element_type(0) != target_type) { normalized = std::make_shared(normalized, target_type); + if(std::dynamic_pointer_cast(normalized)) { + input->get_rt_info()["decompression"]; + } } if(normalized->get_output_partial_shape(0).rank().get_length() > 2) { // FIXME: Any other shape patterns possible? @@ -570,6 +580,15 @@ NodePtr decompression_convert (NodePtr node) { // It maps a model signature which is an arbitrary string to OpenVINO infer request. // Defines `evaluate` method that compute a model by a given signature and input tensors. 
class InferRequestSignatureCache { + + // Infer request with additional input-output pairs that are bypassed from input to output to eliminate Parameter -> Result pairs from the OV model + struct RequestWithBypass { + ov::InferRequest request; + std::vector> bypass; // a set of index pairs [j, k], where j is an index of input tensor to be forwarded to k-th output tensor + std::vector inputs; // inputs[i] gives an index in the original input tensor vector to be set to i-th input of the request + std::vector outputs; // outputs[i] gives an index in the original output tensor vector to be set as an i-th output of the request + }; + public: using Signature = std::string; @@ -579,37 +598,78 @@ class InferRequestSignatureCache { return requests.count(signature); } - void insert (const Signature& signature, std::shared_ptr model) { - ov::Core core = ov::genai::utils::singleton_core(); - requests[signature] = core.compile_model(model, device).create_infer_request(); - } + void insert (const Signature& signature, ov::ResultVector& results, ov::ParameterVector& parameters) { + // Detect Parameter -> Result patterns and do not allow them to be included into compiled model to avoid unnecessary overheads, and handle them via a bypass. + // Assume that each parameter from parameters vector do not have other consumers outside model formed by parameters -> ... -> results. + // That allows filter out those parameters that are consumed by Result operations only detecting them by count of consumers instead of + // tracing dependencies inside the model. 
+ + ov::ResultVector request_results; + request_results.reserve(results.size()); + ov::ParameterVector request_parameters; + request_parameters.reserve(parameters.size()); + RequestWithBypass rwb; + + for(size_t result_index = 0; result_index < results.size(); ++result_index) { + auto& result = results[result_index]; + auto parameter = std::dynamic_pointer_cast(result->get_input_node_shared_ptr(0)); + if(parameter) { + // Bypass result + size_t parameter_index = std::distance(parameters.begin(), std::find(parameters.begin(), parameters.end(), parameter)); + rwb.bypass.emplace_back(parameter_index, result_index); + result.reset(); // enough under the assumption there are no other refernces to that result + } else { + // Normal output + request_results.push_back(result); + rwb.outputs.push_back(result_index); + } + } - ov::InferRequest& at(const Signature& signature) { - return requests.at(signature); + for(size_t parameter_index = 0; parameter_index < parameters.size(); ++parameter_index) { + auto& parameter = parameters[parameter_index]; + if(!parameter->get_output_target_inputs(0).empty()) { + request_parameters.push_back(parameter); + rwb.inputs.push_back(parameter_index); + } else { + parameter.reset(); + } + } + + ov::Core core = ov::genai::utils::singleton_core(); + auto model = std::make_shared(request_results, request_parameters); + rwb.request = core.compile_model(model, device).create_infer_request(); + requests.emplace(signature, rwb); } void evaluate(const Signature& signature, const ov::TensorVector& inputs, ov::TensorVector& outputs) { - auto& request = at(signature); + auto& rwb = at(signature); + auto request = rwb.request; auto compiled_model = request.get_compiled_model(); - OPENVINO_ASSERT(inputs.size() == compiled_model.inputs().size()); - OPENVINO_ASSERT(outputs.size() == compiled_model.outputs().size()); - for(size_t i = 0; i < inputs.size(); ++i) { - request.set_input_tensor(i, inputs[i]); + for(size_t i = 0; i < rwb.inputs.size(); ++i) { + 
request.set_input_tensor(i, inputs[rwb.inputs[i]]); } - for(size_t i = 0; i < outputs.size(); ++i) { + for(size_t i = 0; i < rwb.outputs.size(); ++i) { auto target_shape = request.get_compiled_model().output(i).get_partial_shape(); - if(target_shape != outputs[i].get_shape() && target_shape.is_static()) { + auto& output_tensor = outputs[rwb.outputs[i]]; + if(target_shape != output_tensor.get_shape() && target_shape.is_static()) { // do it for static case only, because if target shape is dynamic, the plugin is allowed to set shape on its own - outputs[i].set_shape(target_shape.get_shape()); + output_tensor.set_shape(target_shape.get_shape()); } - request.set_output_tensor(i, outputs[i]); + request.set_output_tensor(i, output_tensor); + } + for(auto bypass: rwb.bypass) { + outputs[bypass.second] = inputs[bypass.first]; } request.infer(); // TODO: Consider using async to increase throughput, requires more complicated archestration } private: - std::unordered_map requests; + RequestWithBypass& at(const Signature& signature) { + return requests.at(signature); + } + + std::unordered_map requests; std::string device; }; @@ -667,8 +727,8 @@ class LoRAFuseTransform : public LoRATransformBase { parameters.push_back(std::make_shared(multiplier->get_output_element_type(0), multiplier->get_output_partial_shape(0))); } auto result = std::make_shared(tensors_multiplication(nullptr, NodeVector{parameters.begin() + 1, parameters.end()}, target, false, 1, false)); - auto weights_model = std::make_shared(ov::ResultVector{result}, parameters); - fusers.insert(signature, weights_model); + ov::ResultVector results{result}; + fusers.insert(signature, results, parameters); } // Newly created constants in the next line are not mmaped unlike original weights, so it will inflate required memory @@ -839,7 +899,7 @@ struct AdapterControllerImpl { ov::Tensor(params_getter.type, ov::Shape{0}) }; auto name = node->get_friendly_name(); - auto lora_weight = prepare_lora_tensors(name, 
params_getter.weight_getter, lora_placeholder, false); + auto lora_weight = prepare_lora_tensors(name, params_getter.weight_getter, lora_placeholder, /*set_empty_tensors=*/false, /*alpha_only=*/false); if(lora_weight.alpha) { return LoRANode( // TODO: Make sure that tensors will not be disposed during constant life time @@ -869,7 +929,6 @@ struct AdapterControllerImpl { } pm.run_passes(model); - model->validate_nodes_and_infer_types(); // FIXME: For debugging purposes only // Collect all variable names to quickly detect which state tensor belongs to this adapter controller later for(const auto& var: variable_ids) { @@ -943,11 +1002,10 @@ struct AdapterControllerImpl { } void set_new_adapter_alphas (ov::InferRequest& infer_request) { - // FIXME: Provide more economical way to update only alphas - set_new_adapter_tensors(infer_request); + set_new_adapter_tensors(infer_request, /*alpha_only=*/true); } - void set_new_adapter_tensors (ov::InferRequest& infer_request) { + void set_new_adapter_tensors (ov::InferRequest& infer_request, bool alpha_only = false) { if(current_config.get_mode() != AdapterConfig::MODE_AUTO && current_config.get_mode() != AdapterConfig::MODE_DYNAMIC && current_config.get_mode() != AdapterConfig::MODE_STATIC_RANK ) { return; } @@ -977,7 +1035,7 @@ struct AdapterControllerImpl { lora_indices.alpha = state_name_to_index.at(lora_var_ids.second.alpha.variable_id); lora_indices.A = state_name_to_index.at(lora_var_ids.second.A.variable_id); lora_indices.B = state_name_to_index.at(lora_var_ids.second.B.variable_id); - set_lora_tensors(state, lora_var_ids.first, lora_var_ids.second, lora_indices, weight_getters); + set_lora_tensors(state, lora_var_ids.first, lora_var_ids.second, lora_indices, weight_getters, alpha_only); } } @@ -988,7 +1046,7 @@ struct AdapterControllerImpl { result.reserve(weight_getters.size()); for(size_t i = 0; i < adapters.size(); ++i) { if(auto lora_tensors = weight_getters[i](lora_name)) { - // FIXME: Introduce more flexible logic 
of setting alpha based on alpha set in the adapter file itself, now it is ignored and only alpha from config is used + // TODO: Is it practical to use alpha from the adapter file itself. In the current code it is ignored and only alpha from config is used. OPENVINO_ASSERT(lora_tensors->A); OPENVINO_ASSERT(lora_tensors->B); lora_tensors->alpha = alpha_as_constant(current_config.get_alpha(adapters[i])); @@ -1002,53 +1060,68 @@ struct AdapterControllerImpl { return result; } - InferRequestSignatureCache::Signature get_tensor_signature(const ov::Output& output) { - return get_tensor_signature(output.get_element_type(), output.get_partial_shape()); + using Signature = InferRequestSignatureCache::Signature; + + Signature get_tensor_signature(const ov::element::Type& type, const ov::PartialShape& shape) { + return '(' + type.get_type_name() + shape.to_string() + ')'; } - InferRequestSignatureCache::Signature get_tensor_signature(const ov::element::Type& type, const ov::PartialShape& shape) { - return type.get_type_name() + shape.to_string(); + Signature get_tensor_signature(const std::shared_ptr& constant) { + return get_tensor_signature(constant->get_element_type(), constant->get_shape()); } - InferRequestSignatureCache::Signature get_lora_signature(const std::vector& inputs, const LoRAParts& outputs) { - InferRequestSignatureCache::Signature signature; - for(const auto& input: inputs) { - signature += - std::string("(") + - "(" + get_tensor_signature(input.alpha) + ")" + - "(" + get_tensor_signature(input.A) + ")" + // TODO: Adjust shape to have a dynamic low-rank LoRA dimension in case of fully static shape doesn't have significant speedup - "(" + get_tensor_signature(input.B) + ")" + // TODO: Adjust shape to have a dynamic low-rank LoRA dimension in case of fully static shape doesn't have significant speedup - ")"; - } + Signature get_tensor_signature(const ov::Tensor& tensor, const PartialShape& overridden_shape) { + return tensor ? 
get_tensor_signature(tensor.get_element_type(), overridden_shape) : Signature(); + } + + Signature get_lora_signature(const std::vector& inputs, const LoRAParts& outputs) { + Signature signature; for(const auto& input: inputs) { signature += std::string("(") + - // Shape is set to be dynamic because it doesn't matter for signature as it is completely determined by the corresponding model - "(" + get_tensor_signature(outputs.alpha.get_element_type(), ov::PartialShape::dynamic(1)) + ")" + - "(" + get_tensor_signature(outputs.A.get_element_type(), ov::PartialShape::dynamic(2)) + ")" + - "(" + get_tensor_signature(outputs.B.get_element_type(), ov::PartialShape::dynamic(2)) + ")" + + get_tensor_signature(input.alpha) + + get_tensor_signature(input.A) + // TODO: Adjust shape to have a dynamic low-rank LoRA dimension in case of fully static shape doesn't have significant speedup + get_tensor_signature(input.B) + // TODO: Adjust shape to have a dynamic low-rank LoRA dimension in case of fully static shape doesn't have significant speedup ")"; } + signature += + std::string("(") + + // Shape is set to be dynamic because it doesn't matter for signature as it is completely determined by the corresponding model + // The corresponding model hasn't been created at the moment when this function is got called, so to avoid duplicated shape propagation logic for + // outputs we just use the target rank of output tensors that is enough to distinguish output signatures. + get_tensor_signature(outputs.alpha, ov::PartialShape::dynamic(1)) + + get_tensor_signature(outputs.A, ov::PartialShape::dynamic(2)) + + get_tensor_signature(outputs.B, ov::PartialShape::dynamic(2)) + + ")"; return signature; } - ov::TensorVector to_tensor_vector(const std::vector& v) { + ov::TensorVector to_tensor_vector(const std::vector& v, bool alpha_only) { ov::TensorVector result; - result.reserve(v.size()*3); + result.reserve(v.size()*(alpha_only ? 
1 : 3)); for(auto const& lora_weights: v) { result.push_back(lora_weights.alpha->get_tensor_view()); - result.push_back(lora_weights.A->get_tensor_view()); - result.push_back(lora_weights.B->get_tensor_view()); + if(!alpha_only) { + result.push_back(lora_weights.A->get_tensor_view()); + result.push_back(lora_weights.B->get_tensor_view()); + } } return result; } - ov::TensorVector to_tensor_vector(const LoRAParts& lora_tensors) { + LoRAParts from_tensor_vector(const ov::TensorVector& v, bool alpha_only) { + OPENVINO_ASSERT(v.size() == (alpha_only ? 1 : 3)); + return LoRAParts(v[0], alpha_only ? ov::Tensor() : v[1], alpha_only ? ov::Tensor() : v[2]); + } + + ov::TensorVector to_tensor_vector(const LoRAParts& lora_tensors, bool alpha_only) { ov::TensorVector result; - result.reserve(3); + result.reserve((alpha_only ? 1 : 3)); result.push_back(lora_tensors.alpha); - result.push_back(lora_tensors.A); - result.push_back(lora_tensors.B); + if(!alpha_only) { + result.push_back(lora_tensors.A); + result.push_back(lora_tensors.B); + } return result; } @@ -1059,13 +1132,14 @@ struct AdapterControllerImpl { ov::Tensor output, size_t offset, size_t concat_axis, + bool alpha_only, std::function(const LoRAWeight&)> input_accessor, std::function parameter_postprocessing = [](const LoRAWeight&, NodePtr node) { return node; } ) { ov::OutputVector concat_inputs; concat_inputs.reserve(inputs.size()); for(size_t i = 0; i < inputs.size(); ++i) { - NodePtr input = parameters[3*i + offset] = input_accessor(inputs[i]); + NodePtr input = parameters[(alpha_only ? 
1 : 3)*i + offset] = input_accessor(inputs[i]); if(input->get_output_element_type(0) != output.get_element_type()) { input = std::make_shared(input, output.get_element_type()); } @@ -1081,16 +1155,9 @@ struct AdapterControllerImpl { result = std::make_shared(concat_inputs, concat_axis); } else { result = concat_inputs.front().get_node_shared_ptr(); - // FIXME: Workaround CPU plugin bug with Parameter -> Result models: add a small constant to force copying input to output - // FIXME: Do it differently: not use model-based evaluation in this case but just pass lora tensor directly as a new state value - if(result == parameters[offset]) { - result = std::make_shared(result, v0::Constant::create(result->get_output_element_type(0), Shape{}, {1e-37f})); - } } results[offset] = std::make_shared(result); - - // TODO: Optimize trivial Parameter->Result cases } LoRAParts empty_adapters(const std::vector& inputs, LoRAParts& outputs) { @@ -1100,14 +1167,16 @@ struct AdapterControllerImpl { return outputs; } - LoRAParts concat_adapters(const std::vector& inputs, LoRAParts& outputs) { + LoRAParts concat_adapters(const std::vector& inputs, LoRAParts& outputs, bool alpha_only) { auto signature = get_lora_signature(inputs, outputs); + size_t inputs_per_adapter = alpha_only ? 
1 : 3; if(!lora_state_evaluators.exist(signature)) { // Prepare LoRA state evaluate model - ov::ParameterVector parameters(3*inputs.size()); - ov::ResultVector results(3); + ov::ParameterVector parameters(inputs_per_adapter*inputs.size()); + ov::ResultVector results(inputs_per_adapter); build_concat_model(parameters, results, inputs, outputs.alpha, 0, 1, + alpha_only, [](const LoRAWeight& lora_weight) { return std::make_shared( lora_weight.alpha->get_output_element_type(0), @@ -1121,27 +1190,31 @@ struct AdapterControllerImpl { return std::make_shared(parameter, lora_rank_constant); }); - build_concat_model(parameters, results, inputs, outputs.A, 1, 0, - [](const LoRAWeight& lora_weight) { - return std::make_shared( - lora_weight.A->get_output_element_type(0), - lora_weight.A->get_output_partial_shape(0)); // TODO: Consider using dynamic LoRA rank dimension instead of static dimension - } - ); + if(!alpha_only) { + build_concat_model(parameters, results, inputs, outputs.A, 1, 0, + alpha_only, + [](const LoRAWeight& lora_weight) { + return std::make_shared( + lora_weight.A->get_output_element_type(0), + lora_weight.A->get_output_partial_shape(0)); // TODO: Consider using dynamic LoRA rank dimension instead of static dimension + } + ); - build_concat_model(parameters, results, inputs, outputs.B, 2, 1, - [](const LoRAWeight& lora_weight) { - return std::make_shared( - lora_weight.B->get_output_element_type(0), - lora_weight.B->get_output_partial_shape(0)); // TODO: Consider using dynamic LoRA rank dimension instead of static dimension - } - ); + build_concat_model(parameters, results, inputs, outputs.B, 2, 1, + alpha_only, + [](const LoRAWeight& lora_weight) { + return std::make_shared( + lora_weight.B->get_output_element_type(0), + lora_weight.B->get_output_partial_shape(0)); // TODO: Consider using dynamic LoRA rank dimension instead of static dimension + } + ); + } - lora_state_evaluators.insert(signature, std::make_shared(results, parameters)); + 
lora_state_evaluators.insert(signature, results, parameters); } - auto output_tensors = to_tensor_vector(outputs); - lora_state_evaluators.evaluate(signature, to_tensor_vector(inputs), output_tensors); - return outputs; + auto output_tensors = to_tensor_vector(outputs, alpha_only); + lora_state_evaluators.evaluate(signature, to_tensor_vector(inputs, alpha_only), output_tensors); + return from_tensor_vector(output_tensors, alpha_only); } ov::Shape dynamic_to_static(const ov::PartialShape& pshape) { @@ -1152,30 +1225,40 @@ struct AdapterControllerImpl { return shape; } - void set_lora_tensors(std::vector& state, const std::string& name, const LoRAVarIDs& lora_var_ids, const LoRAIndices& lora_indices, const std::vector& weight_getters) { + void set_lora_tensors( + std::vector& state, + const std::string& name, + const LoRAVarIDs& lora_var_ids, + const LoRAIndices& lora_indices, + const std::vector& weight_getters, + bool alpha_only + ) { LoRAParts lora_state_tensors{ ov::Tensor(lora_var_ids.alpha.data_type, dynamic_to_static(lora_var_ids.alpha.data_shape)), - ov::Tensor(lora_var_ids.A.data_type, dynamic_to_static(lora_var_ids.A.data_shape)), - ov::Tensor(lora_var_ids.B.data_type, dynamic_to_static(lora_var_ids.B.data_shape)) + alpha_only ? ov::Tensor() : ov::Tensor(lora_var_ids.A.data_type, dynamic_to_static(lora_var_ids.A.data_shape)), + alpha_only ? 
ov::Tensor() : ov::Tensor(lora_var_ids.B.data_type, dynamic_to_static(lora_var_ids.B.data_shape)) }; - auto new_tensors = prepare_lora_tensors(name, weight_getters, lora_state_tensors); + auto new_tensors = prepare_lora_tensors(name, weight_getters, lora_state_tensors, /*set_empty_adapters=*/true, alpha_only); state[lora_indices.alpha].set_state(new_tensors.alpha); - state[lora_indices.A].set_state(new_tensors.A); - state[lora_indices.B].set_state(new_tensors.B); + if(!alpha_only) { + state[lora_indices.A].set_state(new_tensors.A); + state[lora_indices.B].set_state(new_tensors.B); + } } LoRAParts prepare_lora_tensors ( const std::string& name, const std::vector& weight_getters, LoRAParts& output, - bool set_empty_adapters = true + bool set_empty_adapters, + bool alpha_only ) { - auto lora_tensors = collect_applicable_tensors(name, weight_getters); + auto lora_tensors = collect_applicable_tensors(name, weight_getters); // request A and B regardless of alpha_only, because it is a way to get lora_rank later when alpha is broadcasted LoRAParts new_tensors; if(!lora_tensors.empty()) { - new_tensors = concat_adapters(lora_tensors, output); - } else if(set_empty_adapters) { // FIXME: Make it as a separate step outside of this function + new_tensors = concat_adapters(lora_tensors, output, alpha_only); + } else if(set_empty_adapters) { new_tensors = empty_adapters(lora_tensors, output); } return new_tensors; diff --git a/tools/llm_bench/llm_bench_utils/ov_utils.py b/tools/llm_bench/llm_bench_utils/ov_utils.py index deaaac0a41..5f7fd5c7f1 100644 --- a/tools/llm_bench/llm_bench_utils/ov_utils.py +++ b/tools/llm_bench/llm_bench_utils/ov_utils.py @@ -343,7 +343,6 @@ def create_genai_speech_2_txt_model(model_path, device, **kwargs): def create_speech_2txt_model(model_path, device, **kwargs): """Create speech generation model. 
- - model_path: can be model_path or IR path - device: can be CPU - model_type: diff --git a/tools/llm_bench/task/speech_to_text_generation.py b/tools/llm_bench/task/speech_to_text_generation.py index ef61664495..ad5fed8b98 100644 --- a/tools/llm_bench/task/speech_to_text_generation.py +++ b/tools/llm_bench/task/speech_to_text_generation.py @@ -46,7 +46,7 @@ def run_speech_2_txt_generation(input_param, args, md5_list, iter_data_list): max_new_tokens=max_gen_tokens, # 'task' and 'language' parameters are supported for multilingual models only language=speech_language, - task="transcribe", + task="translate", return_timestamps=ret_timestamps ) end = time.perf_counter() @@ -57,7 +57,7 @@ def run_speech_2_txt_generation(input_param, args, md5_list, iter_data_list): start = time.perf_counter() result_text = pipe( raw_speech, - generate_kwargs={"task": 'transcribe', "language": speech_language}, + generate_kwargs={"task": 'translate', "language": speech_language}, return_timestamps=ret_timestamps )["text"] end = time.perf_counter() diff --git a/tools/llm_bench/task/text_generation.py b/tools/llm_bench/task/text_generation.py index d0ef71a027..58f6448fe0 100644 --- a/tools/llm_bench/task/text_generation.py +++ b/tools/llm_bench/task/text_generation.py @@ -195,7 +195,7 @@ def run_text_generation_genai(input_text, num, model, tokenizer, args, iter_data mem_consumption.start_collect_memory_consumption() max_gen_tokens = DEFAULT_OUTPUT_TOKEN_SIZE if args['infer_count'] is None else args['infer_count'] start = time.perf_counter() - generation_result = model.generate(input_text_list, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"]) + generation_result = model.generate(input_text_list, max_new_tokens=max_gen_tokens, num_beams=args["num_beams"], do_sample=False) end = time.perf_counter() generated_text = generation_result.texts perf_metrics = generation_result.perf_metrics