diff --git a/.github/workflows/causal_lm_cpp.yml b/.github/workflows/causal_lm_cpp.yml index 07352c0818..73ce951164 100644 --- a/.github/workflows/causal_lm_cpp.yml +++ b/.github/workflows/causal_lm_cpp.yml @@ -683,7 +683,7 @@ jobs: diff pred2.txt ref.txt echo "Chat sample python" passed - py-vlm_chat_sample-ubuntu: + visual_language_chat_sample-ubuntu: runs-on: ubuntu-22.04-16-cores steps: - uses: actions/checkout@v4 @@ -859,6 +859,7 @@ jobs: cpp-beam_search_causal_lm-Qwen-7B-Chat, cpp-beam_search_causal_lm-Qwen1_5-7B-Chat, cpp-beam_search_causal_lm-Phi-2, cpp-beam_search_causal_lm-notus-7b-v1, cpp-speculative_decoding_lm-ubuntu, cpp-prompt_lookup_decoding_lm-ubuntu, cpp-Phi-1_5, cpp-greedy_causal_lm-redpajama-3b-chat, cpp-chat_sample-ubuntu, cpp-continuous-batching-ubuntu, + visual_language_chat_sample-ubuntu, cpp-continuous-batching-windows, cpp-continuous-batching-macos] if: ${{ always() }} runs-on: ubuntu-latest diff --git a/samples/CMakeLists.txt b/samples/CMakeLists.txt index f6a94dfca9..2a8f26ff4d 100644 --- a/samples/CMakeLists.txt +++ b/samples/CMakeLists.txt @@ -25,7 +25,8 @@ install(DIRECTORY cpp/greedy_causal_lm cpp/multinomial_causal_lm # Don't install prompt_lookup_decoding_lm and speculative_decoding_lm because they don't use openvino_genai library and arent verifyed yet. - # Don't install continuous_batching_accuracy and continuous_batching_benchmark because they depend on json. + # Don't install continuous_batching_accuracy and continuous_batching_benchmark because CB isn't ready. + cpp/visual_language_chat cpp/whisper_speech_recognition cpp/text2image cpp/lora_greedy_causal_lm diff --git a/samples/cpp/visual_language_chat/CMakeLists.txt b/samples/cpp/visual_language_chat/CMakeLists.txt index 0df2b5ab5c..9a1b21632f 100644 --- a/samples/cpp/visual_language_chat/CMakeLists.txt +++ b/samples/cpp/visual_language_chat/CMakeLists.txt @@ -1,9 +1,11 @@ # Copyright (C) 2023-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -find_package(OpenVINOGenAI REQUIRED PATHS - "${CMAKE_BINARY_DIR}" # Reuse the package from the build. - ${OpenVINO_DIR} # GenAI may be installed alogside OpenVINO. +find_package(OpenVINOGenAI REQUIRED + PATHS + "${CMAKE_BINARY_DIR}" # Reuse the package from the build. + ${OpenVINO_DIR} # GenAI may be installed alogside OpenVINO. 
+ NO_CMAKE_FIND_ROOT_PATH ) file(DOWNLOAD @@ -14,11 +16,11 @@ file(DOWNLOAD add_executable(visual_language_chat visual_language_chat.cpp load_image.cpp) target_include_directories(visual_language_chat PRIVATE "${CMAKE_CURRENT_SOUCE_DIR}" "${CMAKE_BINARY_DIR}") target_link_libraries(visual_language_chat PRIVATE openvino::genai) + set_target_properties(visual_language_chat PROPERTIES - COMPILE_PDB_NAME chat_sample + COMPILE_PDB_NAME visual_language_chat # Ensure out of box LC_RPATH on macOS with SIP INSTALL_RPATH_USE_LINK_PATH ON) -target_compile_features(visual_language_chat PRIVATE cxx_std_11) install(TARGETS visual_language_chat RUNTIME DESTINATION samples_bin/ diff --git a/samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py b/samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py index a08c3ad55b..7d2f0f1175 100644 --- a/samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py +++ b/samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py @@ -9,22 +9,58 @@ from transformers import AutoModel, AutoTokenizer, AutoProcessor, TextIteratorStreamer from transformers.generation import GenerationMixin from transformers import AutoConfig, GenerationConfig -from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPooling +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask from pathlib import Path from huggingface_hub import snapshot_download import types -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Union from openvino.runtime import opset13 import openvino as ov import openvino_tokenizers import numpy as np import gc +from openvino.runtime.passes import Manager, MatcherPass, WrapType, Matcher +import time text_emb_path = Path("embed_tokens.xml") image_emb_path = Path("image_encoder.xml") resampler_path = Path("resampler.xml") llm_path = Path("language_model.xml") +class InsertSlice(MatcherPass): + def __init__(self): + MatcherPass.__init__(self) + self.model_changed = False + + param = WrapType("opset10.Result") + + def callback(matcher: Matcher) -> bool: + root = matcher.get_match_root() + if root is None: + return False + if len(root.get_output_partial_shape(0)) == 3: + parent = root.input_value(0).get_node() + grand_parent = parent.input_value(0).get_node() + + grand_parent_output = parent.input(0).get_source_output() + consumers = grand_parent_output.get_target_inputs() + start = np.array([0, -1, 0], dtype=np.int32) + stop = np.array([1, -2, grand_parent_output.get_partial_shape()[-1].get_length()], dtype=np.int32) + step = np.array([1, -1, 1], dtype=np.int32) + axes = np.array([0, 1, 2], dtype=np.int32) + slice = opset13.slice(grand_parent, start, stop, step, axes, name="inserted_slice") + for consumer in consumers: + consumer.replace_source_output(slice.output(0)) + self.model_changed = True + # Use new operation for additional matching + self.register_new_node(slice) + print("applied slice for lm head") + + return True + + self.register_matcher(Matcher(param, "InsertSlice"), callback) + def model_has_state(ov_model: ov.Model): return len(ov_model.get_sinks()) > 0 @@ -324,13 +360,151 @@ def convert_vision_encoder(model, model_dir): tgt_sizes = torch.tensor([[23, 45]]) if not (model_dir / image_emb_path).exists(): print("⌛ Convert Image embedding model") + def siglip_vis_embed_forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + 
position_ids: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + if position_ids is None: + batch_size = pixel_values.size(0) + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + def siglip_attn_forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, key_states, value_states, attention_mask, is_causal=attention_mask is None + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None + + def siglip_transformer_forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + tgt_sizes: Optional[torch.IntTensor] = None, + position_ids: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + 
pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings( + pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, tgt_sizes=tgt_sizes, position_ids=position_ids + ) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) if not self._use_flash_attention_2 else patch_attention_mask + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + if not return_dict: + return (last_hidden_state, None) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=None, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + vpm = model.vpm + vpm.embeddings.forward = types.MethodType(siglip_vis_embed_forward, vpm.embeddings) + for layer in vpm.encoder.layers: + layer.self_attn.forward = types.MethodType(siglip_attn_forward, layer.self_attn) + vpm.forward = types.MethodType(siglip_transformer_forward, vpm) + pixel_values = torch.randn([1, 3, 14, 14490]) patch_attn_mask = torch.zeros((1, 1, 1035), dtype=torch.bool) patch_attn_mask[0, 0, : tgt_sizes[0][0] * tgt_sizes[0][1]] = True - ov_model = ov.convert_model(model.vpm, example_input={"pixel_values": pixel_values, "tgt_sizes": tgt_sizes, "patch_attention_mask": patch_attn_mask}) + position_ids = prepare_vis_position_ids( + pixel_values, patch_attn_mask, tgt_sizes, model.config.vision_config.patch_size, model.config.vision_config.image_size // model.config.patch_size + ) + ov_model = ov.convert_model(vpm, example_input={"pixel_values": pixel_values, "position_ids": position_ids, "patch_attention_mask": patch_attn_mask}) ov.save_model(ov_model, model_dir / image_emb_path) del ov_model cleanup_torchscript_cache() + gc.collect() print("✅ Image embedding model successfully converted") if not (model_dir / resampler_path).exists(): @@ -343,7 +517,9 @@ def resampler_forward(self, x, pos_embed, key_padding_mask): q = self.ln_q(self.query) # Q * D - out = self.attn(self._repeat(q, bs), x + pos_embed, x, key_padding_mask=key_padding_mask)[0] # Q * B * D # L * B * D + L * B * D + q_bs = q.unsqueeze(1).repeat(1, bs, 1) + + out = self.attn(q_bs, x + pos_embed, x, key_padding_mask=key_padding_mask)[0] # Q * B * D # L * B * D + L * B * D # out: Q * B * D x = out.permute(1, 0, 2) # B * Q * D @@ -369,6 +545,8 @@ def resampler_forward(self, x, pos_embed, key_padding_mask): ov.save_model(ov_model, model_dir / resampler_path) del ov_model cleanup_torchscript_cache() + del model.resampler + gc.collect() print("✅ Resampler model successfully converted") @@ -380,11 +558,38 @@ def copy_llm_files(model_dir, dst_dir): shutil.copy(model_dir / llm_path.parent / "modeling_navit_siglip.py", model_dir / dst_dir / "modeling_navit_siglip.py") +def prepare_vis_position_ids(pixel_values, patch_attention_mask, tgt_sizes, patch_size, num_patches_per_side): + batch_size = pixel_values.size(0) + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // patch_size, max_im_w // patch_size + boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side) + position_ids 
= torch.full(size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + if tgt_sizes is not None: + nb_patches_h = tgt_sizes[batch_idx][0] + nb_patches_w = tgt_sizes[batch_idx][1] + else: + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + return position_ids + + core = ov.Core() class OvModelForCausalLMWithEmb(GenerationMixin): - def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> None: + def __init__(self, model_dir, device="CPU", ov_config=None, compile=True, slice_lm_head=True) -> None: self._supports_cache_class = False self.config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True) self.config.is_decoder = True @@ -393,6 +598,8 @@ def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> Non model_dir = Path(model_dir) self.model = core.read_model(model_dir / "language_model.xml") self.token_emb = core.read_model(model_dir / "embed_tokens.xml") + if slice_lm_head: + self.slice_lm_head() self.request = None self.token_emb_request = None self._device = device.upper() @@ -402,9 +609,16 @@ def __init__(self, model_dir, device="CPU", ov_config=None, compile=True) -> Non self._past_length = None self.input_names = [input_t.get_any_name() for input_t in self.model.inputs] self.main_input_name = "input_ids" + self.llm_times = [] if compile: self.compile() + def slice_lm_head(self): + manager = Manager() + manager.register_pass(InsertSlice()) + manager.run_passes(self.model) + self.model.validate_nodes_and_infer_types() + def compile(self): if self.request is None: self.request = core.compile_model(self.model, self._device, self.ov_config).create_infer_request() @@ -446,6 +660,7 @@ def prepare_inputs( inputs = {} # past_key_values are not used explicitly, instead they are handled inside the model if past_key_values is None: + self.llm_times = [] # This is the first iteration in a sequence, reset all states if self.request is not None: self.request.reset_state() @@ -657,20 +872,39 @@ def get_vllm_embedding(self, data): for i in range(B): patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True - vision_batch_size = 1 + vision_batch_size = 32 all_pixel_values = all_pixel_values if B > vision_batch_size: hs = [] for i in range(0, B, vision_batch_size): start_idx = i end_idx = i + vision_batch_size - tmp_hs = torch.from_numpy( - self.vpm([all_pixel_values[start_idx:end_idx], patch_attn_mask[start_idx:end_idx], tgt_sizes[start_idx:end_idx]])[0] + block_pxl_values = all_pixel_values[start_idx:end_idx] + block_patch_attn_mask = patch_attn_mask[start_idx:end_idx] + block_tgt_sizes = tgt_sizes[start_idx:end_idx] + block_position_ids = prepare_vis_position_ids( + block_pxl_values, + block_patch_attn_mask, + block_tgt_sizes, + self.config.vision_config.patch_size, + self.config.vision_config.image_size // self.config.patch_size, ) + start = time.perf_counter() + tmp_hs = torch.from_numpy(self.vpm([block_pxl_values, block_patch_attn_mask, block_position_ids])[0]) + 
self.vpm_times.append(time.perf_counter() - start) hs.append(tmp_hs) vision_embedding = torch.cat(hs, dim=0) else: - vision_embedding = torch.from_numpy(self.vpm([all_pixel_values, patch_attn_mask, tgt_sizes])[0]) + position_ids = prepare_vis_position_ids( + all_pixel_values, + patch_attn_mask, + tgt_sizes, + self.config.vision_config.patch_size, + self.config.vision_config.image_size // self.config.patch_size, + ) + start = time.perf_counter() + vision_embedding = torch.from_numpy(self.vpm([all_pixel_values, patch_attn_mask, position_ids])[0]) + self.vpm_times.append(time.perf_counter() - start) vision_embedding = self.resampler(vision_embedding, tgt_sizes) start = 0 @@ -801,6 +1035,8 @@ def chat( use_image_id=None, **kwargs, ): + self.vpm_times = [] + self.resampler_times = [] if isinstance(msgs[0], list): batched = True else: @@ -844,7 +1080,6 @@ def chat( copy_msgs = deepcopy(msgs) assert len(msgs) > 0, "msgs is empty" - assert sampling or not stream, "if use stream mode, make sure sampling=True" if image is not None and isinstance(copy_msgs[0]["content"], str): copy_msgs[0]["content"] = [image, copy_msgs[0]["content"]] @@ -882,7 +1117,6 @@ def chat( generation_config = {"top_p": 0.8, "top_k": 100, "temperature": 0.7, "do_sample": True, "repetition_penalty": 1.05} else: generation_config = { - "num_beams": 3, "repetition_penalty": 1.2, } @@ -958,8 +1192,8 @@ def main(): gc.collect() convert_vision_encoder(model, model_dir) - ov_cpm = init_model(model_dir, "CPU") - print(ov_cpm.chat(Image.open(requests.get("https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11", stream=True).raw), [{"role": "user", "content": "What is unusual on this image?"}], ov_cpm.processor.tokenizer)) + # ov_cpm = init_model(model_dir, "CPU") + # print(ov_cpm.chat(Image.open(requests.get("https://github.com/openvinotoolkit/openvino_notebooks/assets/29454499/d5fbbd1a-d484-415c-88cb-9986625b7b11", stream=True).raw), [{"role": "user", "content": "What is unusual on this image?"}], ov_cpm.processor.tokenizer)) if "__main__" == __name__: main() diff --git a/samples/cpp/visual_language_chat/load_image.cpp b/samples/cpp/visual_language_chat/load_image.cpp index 85fe7e2fbe..855f7567bf 100644 --- a/samples/cpp/visual_language_chat/load_image.cpp +++ b/samples/cpp/visual_language_chat/load_image.cpp @@ -13,7 +13,7 @@ ov::Tensor utils::load_image(const std::filesystem::path& image_path) { image_path.string().c_str(), &x, &y, &channels_in_file, desired_channels); if (!image) { - throw std::runtime_error{"Failed to load the image"}; + throw std::runtime_error{"Failed to load the image."}; } struct SharedImageAllocator { unsigned char* image; @@ -22,11 +22,11 @@ ov::Tensor utils::load_image(const std::filesystem::path& image_path) { if (channels * height * width == bytes) { return image; } - throw std::runtime_error{"Unexpected number of bytes was requested to allocate"}; + throw std::runtime_error{"Unexpected number of bytes was requested to allocate."}; } void deallocate(void*, size_t bytes, size_t) { if (channels * height * width != bytes) { - throw std::runtime_error{"Unexpected number of bytes was requested to deallocate"}; + throw std::runtime_error{"Unexpected number of bytes was requested to deallocate."}; } std::free(image); image = nullptr; diff --git a/samples/requirements.txt b/samples/requirements.txt index 18145bed85..4821d6dbef 100644 --- a/samples/requirements.txt +++ b/samples/requirements.txt @@ -3,5 +3,4
optimum[openvino]==1.22.0 einops==0.8.0 # For Qwen transformers_stream_generator==0.0.5 # For Qwen diffusers==0.30.3 -pillow torchvision # needed for mini-CPM export script. Need to remove when we switch to exporting with optimum-intel. diff --git a/src/cpp/include/openvino/genai/processor_config.hpp b/src/cpp/include/openvino/genai/processor_config.hpp index 9a70d1f3ae..bef6754e14 100644 --- a/src/cpp/include/openvino/genai/processor_config.hpp +++ b/src/cpp/include/openvino/genai/processor_config.hpp @@ -14,6 +14,7 @@ namespace ov::genai { /// preprocessor_config.json. class OPENVINO_GENAI_EXPORTS ProcessorConfig { public: + size_t image_size = 980; /// @brief Dimensions of the smaller, non-overlapping patches that the /// input image is divided into before being fed into the /// transformer model. Used to divide image height and width. diff --git a/src/cpp/src/clip.hpp b/src/cpp/src/clip.hpp index c8965a4890..99c06a05d2 100644 --- a/src/cpp/src/clip.hpp +++ b/src/cpp/src/clip.hpp @@ -25,6 +25,8 @@ struct clip_ctx { std::vector buf_compute_meta; projector_type proj_type = PROJECTOR_TYPE_RESAMPLER; + size_t patch_size = 0; + size_t image_size = 0; }; // RGB uint8 image diff --git a/src/cpp/src/llm_pipeline.cpp b/src/cpp/src/llm_pipeline.cpp index 4a46b525e7..ff7ceb051e 100644 --- a/src/cpp/src/llm_pipeline.cpp +++ b/src/cpp/src/llm_pipeline.cpp @@ -18,24 +18,6 @@ #include "openvino/genai/lora_adapter.hpp" #include "lora_helper.hpp" -namespace { - -ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& fisrt, const ov::genai::TokenizedInputs& second){ - auto first_size = fisrt.input_ids.get_size(); - auto second_size = second.input_ids.get_size(); - ov::Shape new_shape{1, first_size - second_size}; - - ov::Tensor new_input_ids(ov::element::i64, new_shape); - auto data_ptr = fisrt.input_ids.data(); - std::copy(data_ptr + second_size, data_ptr + first_size, new_input_ids.data()); - - ov::Tensor new_attention_mask(ov::element::i64, new_shape); - std::fill_n(new_attention_mask.data(), new_shape[1], 1); - - return {new_input_ids, new_attention_mask}; -} -} - namespace ov { namespace genai { @@ -153,7 +135,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase { encoded_input = new_chat_tokens; } else { auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens_)); - encoded_input = subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); + encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); } m_templated_chat_history = new_templated_chat_history; // TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied diff --git a/src/cpp/src/utils.hpp b/src/cpp/src/utils.hpp index c149bb308f..fe6e4eed14 100644 --- a/src/cpp/src/utils.hpp +++ b/src/cpp/src/utils.hpp @@ -86,6 +86,20 @@ ProcessorConfig from_any_map( std::pair split_core_complile_config(const ov::AnyMap& plugin_config); +inline ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& first, const ov::genai::TokenizedInputs& second){ + auto first_size = first.input_ids.get_size(); + auto second_size = second.input_ids.get_size(); + ov::Shape new_shape{1, first_size - second_size}; + + ov::Tensor new_input_ids(ov::element::i64, new_shape); + auto data_ptr = first.input_ids.data(); + std::copy(data_ptr + second_size, data_ptr + first_size, new_input_ids.data()); + + ov::Tensor
new_attention_mask(ov::element::i64, new_shape); + std::fill_n(new_attention_mask.data(), new_shape[1], 1); + + return {new_input_ids, new_attention_mask}; +} } // namespace utils } // namespace genai } // namespace ov diff --git a/src/cpp/src/vision_encoder.cpp b/src/cpp/src/vision_encoder.cpp index a35a5d8db7..05539b67dc 100644 --- a/src/cpp/src/vision_encoder.cpp +++ b/src/cpp/src/vision_encoder.cpp @@ -216,6 +216,65 @@ ov::Tensor preprocess_for_encoder(const ov::Tensor& images, size_t kernel) { return permuted_tensor; } +// torch.bucketize(fractional_coords, boundaries, right=True) +std::vector bucket_size_right(const std::vector& fractional_coords, const std::vector& boundaries) { + std::vector bucket_coords(fractional_coords.size()); + std::transform(fractional_coords.begin(), fractional_coords.end(), bucket_coords.begin(), [&boundaries](float fractional_coord) { + return std::distance(boundaries.begin(), std::upper_bound(boundaries.begin(), boundaries.end(), fractional_coord)); + }); + return bucket_coords; +} + +ov::Tensor prepare_vis_position_ids( + const ov::Tensor& pixel_values, + const ov::Tensor& patch_attention_mask, + const std::vector tgt_sizes, + size_t patch_size, + size_t num_patches_per_side +) { + size_t batch_size = pixel_values.get_shape().at(0); + size_t max_im_h = pixel_values.get_shape().at(2), max_im_w = pixel_values.get_shape().at(3); + size_t max_nb_patches_h = max_im_h / patch_size, max_nb_patches_w = max_im_w / patch_size; + std::vector boundaries(1.0f * num_patches_per_side - 1); + std::generate(boundaries.begin(), boundaries.end(), [num_patches_per_side, val = 0.0f]() mutable { + val += 1.0f / num_patches_per_side; + return val; + }); + size_t position_ids_batch_elem = max_nb_patches_h * max_nb_patches_w; + ov::Tensor position_ids{ov::element::i64, {batch_size, position_ids_batch_elem}}; + // throw std::runtime_error(""); + int64_t* res_data = position_ids.data(); + std::fill_n(res_data, position_ids.get_size(), 0); + + for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + size_t nb_patches_h = tgt_sizes.at(batch_idx).height; + size_t nb_patches_w = tgt_sizes.at(batch_idx).width; + + std::vector fractional_coords_h(nb_patches_h); + std::generate(fractional_coords_h.begin(), fractional_coords_h.end(), [nb_patches_h, val = -1.0f / nb_patches_h]() mutable { + val += 1.0f / nb_patches_h; + return val; + }); + std::vector fractional_coords_w(nb_patches_w); + std::generate(fractional_coords_w.begin(), fractional_coords_w.end(), [nb_patches_w, val = -1.0f / nb_patches_w]() mutable { + val += 1.0f / nb_patches_w; + return val; + }); + + std::vector bucket_coords_h = bucket_size_right(fractional_coords_h, boundaries); + std::vector bucket_coords_w = bucket_size_right(fractional_coords_w, boundaries); + + std::vector pos_ids(bucket_coords_h.size() * bucket_coords_w.size()); + for (size_t col = 0; col < bucket_coords_h.size(); ++col) { + for (size_t row = 0; row < bucket_coords_w.size(); ++row) {; + pos_ids.at(col * bucket_coords_w.size() + row) = bucket_coords_h.at(col) * num_patches_per_side + bucket_coords_w.at(row); + } + } + std::copy(pos_ids.begin(), pos_ids.end(), res_data + batch_idx * position_ids_batch_elem); + } + return position_ids; +} + EncodedImage llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const ov::Tensor& img, ov::InferRequest& encoder, int max_slice_nums, int scale_resolution, size_t patch_size, bool never_split) { clip_image_u8 source{ int(img.get_shape().at(3)), @@ -244,14 +303,11 @@ EncodedImage 
llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const o ov::Tensor patch_attention_mask{ov::element::boolean, {pixel_values.get_shape().at(0), 1, resized_source_size.height * resized_source_size.width}}; std::fill_n(patch_attention_mask.data(), patch_attention_mask.get_size(), true); encoder.set_tensor("patch_attention_mask", patch_attention_mask); - ov::Tensor tgt_sizes{ov::element::i64, {1, 2}}; - int64_t* tgt_sizes_data = tgt_sizes.data(); - tgt_sizes_data[0] = resized_source_size.height; - tgt_sizes_data[1] = resized_source_size.width; - encoder.set_tensor("tgt_sizes", tgt_sizes); + ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, {resized_source_size}, ctx_clip.patch_size, ctx_clip.image_size / ctx_clip.patch_size); + encoder.set_tensor("position_ids", position_ids); encoder.infer(); const ov::Tensor& output_tensor = encoder.get_output_tensor(); - ov::Tensor resized_source{output_tensor.get_element_type(), output_tensor.get_shape()}; + ov::Tensor resized_source{ov::element::f32, output_tensor.get_shape()}; output_tensor.copy_to(resized_source); if (1 == preprocessed.size()) { @@ -266,27 +322,24 @@ EncodedImage llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const o size_t n_patches = size.height / patch_size * size.width / patch_size, old_hidden_size = resized_source.get_shape().at(2); ov::Tensor encoded_slices{ov::element::f32, {preprocessed.size() - 1, preprocessed.at(1).size(), n_patches, old_hidden_size}}; - // там внутри есть какая-то операция которая констант фолдит батч и из-за этого нельзя использовать отличный от того что был при экспорте - // констант фолдит она его в торч скрипте - // Even though batch can't be used, it's still possible to use async. for (size_t row = 1; row < preprocessed.size(); ++row) { for (size_t col = 0; col < preprocessed.at(row).size(); ++col) { clip_image_f32& elem = preprocessed.at(row).at(col); sliced_sizes.push_back({elem.ny / patch_size, elem.nx / patch_size}); - encoder.set_tensor("pixel_values", preprocess_for_encoder( + ov::Tensor pixel_values = preprocess_for_encoder( {ov::element::f32, {1, 3, size_t(elem.ny), size_t(elem.nx)}, elem.buf.data()}, patch_size - )); + ); + encoder.set_tensor("pixel_values", pixel_values); ov::Tensor patch_attention_mask{ov::element::boolean, {1, 1, sliced_sizes.back().height * sliced_sizes.back().width}}; std::fill_n(patch_attention_mask.data(), patch_attention_mask.get_size(), true); encoder.set_tensor("patch_attention_mask", patch_attention_mask); - ov::Tensor tgt_sizes{ov::element::i64, {1, 2}}; - int64_t* tgt_sizes_data = tgt_sizes.data(); - tgt_sizes_data[0] = sliced_sizes.back().height; - tgt_sizes_data[1] = sliced_sizes.back().width; - encoder.set_tensor("tgt_sizes", tgt_sizes); + ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, {sliced_sizes.back()}, ctx_clip.patch_size, ctx_clip.image_size / ctx_clip.patch_size); + encoder.set_tensor("position_ids", position_ids); + const ov::Tensor& old = encoder.get_output_tensor(); encoder.set_output_tensor({ov::element::f32, {1, n_patches, old_hidden_size}, encoded_slices.data() + ((row - 1) * preprocessed.at(row).size() + col) * n_patches * old_hidden_size}); encoder.infer(); + encoder.set_output_tensor(old); } } return {resized_source, resized_source_size, encoded_slices, sliced_sizes}; @@ -305,6 +358,8 @@ VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const std:: EncodedImage VisionEncoder::encode(const ov::Tensor& image, const 
ProcessorConfig& config) { clip_ctx ctx_clip; + ctx_clip.patch_size = m_processor_config.patch_size; + ctx_clip.image_size = m_processor_config.image_size; std::copy(config.norm_mean.begin(), config.norm_mean.end(), ctx_clip.image_mean); std::copy(config.norm_std.begin(), config.norm_std.end(), ctx_clip.image_std); return llava_image_embed_make_with_bytes_slice(ctx_clip, image, m_encoder, config.max_slice_nums, config.scale_resolution, config.patch_size, 0 == config.max_slice_nums); diff --git a/src/cpp/src/vlm_pipeline.cpp b/src/cpp/src/vlm_pipeline.cpp index 5910d54bd8..99c38c976d 100644 --- a/src/cpp/src/vlm_pipeline.cpp +++ b/src/cpp/src/vlm_pipeline.cpp @@ -338,12 +338,11 @@ DecodedResults VLMPipeline::generate( const StreamerVariant& streamer ) { std::string images_prompt; - EncodedImage embeds; - if (!rgbs.empty()) { - OPENVINO_ASSERT(1 == rgbs.size(), "TODO: Only a single image allowed"); - embeds = m_vision_encoder.encode(rgbs.at(0)); + std::vector embeds; + for (const ov::Tensor& rgb : rgbs) { + EncodedImage encoded_image = m_vision_encoder.encode(rgb); if (m_vlm_config.use_image_id) { - images_prompt = m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end; + images_prompt += m_vlm_config.im_id_start + std::to_string(image_id) + m_vlm_config.im_id_end; ++image_id; } std::string unk64; @@ -351,8 +350,8 @@ DecodedResults VLMPipeline::generate( unk64 += m_vlm_config.unk; } images_prompt += m_vlm_config.im_start + unk64 + m_vlm_config.im_end; - if (embeds.slices) { - ov::Shape slices_shape = embeds.slices.get_shape(); + if (encoded_image.slices) { + ov::Shape slices_shape = encoded_image.slices.get_shape(); for (size_t row_idx = 0; row_idx < slices_shape.at(0); ++row_idx) { for (size_t col_idx = 0; col_idx < slices_shape.at(1); ++col_idx) { images_prompt += m_vlm_config.slice_start + unk64 + m_vlm_config.slice_end; @@ -365,9 +364,10 @@ DecodedResults VLMPipeline::generate( // Strangely, \n isn't placed between . images_prompt += '\n'; } + embeds.push_back(std::move(encoded_image)); } images_prompt += prompt; - std::string new_templated_chat_history; + ov::Tensor encoded_input; if (m_is_chat_conversation) { // KV cache in model already contains prompts and answers from previous iterations. // So only new prompt wrapped into chat template to be sent into model. Tokenizer always returns @@ -379,8 +379,29 @@ DecodedResults VLMPipeline::generate( // KV cache contains it. So we have to add it manually or get it by tokenization all chat history. 
m_history.push_back({{"role", "user"}, {"content", images_prompt}}); constexpr bool add_generation_prompt = true; - new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + std::string new_templated_chat_history = m_tokenizer.apply_chat_template(m_history, add_generation_prompt); + ov::Tensor new_chat_tokens = m_tokenizer.encode(new_templated_chat_history).input_ids; + if (0 == m_language.get_tensor("attention_mask").get_shape().at(1)) { + encoded_input = new_chat_tokens; + } else { + TokenizedInputs prev_chat_tokens = m_tokenizer.encode( + m_templated_chat_history + ); + encoded_input = utils::subtract_chat_tokenized_inputs( + {new_chat_tokens}, prev_chat_tokens + ).input_ids; + } + m_templated_chat_history = std::move(new_templated_chat_history); + } else { + encoded_input = m_tokenizer.encode(images_prompt).input_ids; } + m_embedding.set_input_tensor(encoded_input); + m_embedding.infer(); + ov::Tensor inputs_embeds = m_embedding.get_output_tensor(); + OPENVINO_ASSERT( + m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2), + "Unexpected embedding size" + ); ov::Tensor special_tokens = m_tokenizer.encode( m_vlm_config.im_start + m_vlm_config.im_end @@ -391,59 +412,37 @@ DecodedResults VLMPipeline::generate( 4 == special_tokens.get_shape().at(1), "Every special token must be represented with a single int." ); - size_t im_start_id = special_tokens.data()[0]; - size_t im_end_id = special_tokens.data()[1]; - size_t slice_start_id = special_tokens.data()[2]; - size_t slice_end_id = special_tokens.data()[3]; - ov::Tensor input_ids = m_tokenizer.encode(new_templated_chat_history).input_ids; - m_embedding.set_input_tensor(input_ids); - m_embedding.infer(); - ov::Tensor inputs_embeds = m_embedding.get_output_tensor(); - OPENVINO_ASSERT( - m_vlm_config.hidden_size == inputs_embeds.get_shape().at(2), - "Unexpected embedding size" - ); - if (!rgbs.empty()) { - int64_t* ids = input_ids.data(); - const ov::Tensor& resampled_source = resample(*this, embeds.resized_source, {embeds.resized_source_size}); + int64_t im_start_id = special_tokens.data()[0]; + int64_t im_end_id = special_tokens.data()[1]; + int64_t slice_start_id = special_tokens.data()[2]; + int64_t slice_end_id = special_tokens.data()[3]; + int64_t im_start_pos = 0, slice_start_pos = 0; + int64_t* begin = encoded_input.data(); + int64_t* ids = begin; + size_t encoded_input_size = encoded_input.get_size(); + int64_t* end = ids + encoded_input_size; + float* inputs_embeds_data = inputs_embeds.data(); + for (const EncodedImage& encoded_image : embeds) { + const ov::Tensor& resampled_source = resample(*this, encoded_image.resized_source, {encoded_image.resized_source_size}); float* emb = resampled_source.data(); - bool replacing = false; - for (size_t token_idx = 0; token_idx < inputs_embeds.get_shape().at(1); ++token_idx) { - if (im_start_id == ids[token_idx]) { - replacing = true; - } - if (replacing) { - std::copy_n(emb, resampled_source.get_size(), inputs_embeds.data() + token_idx * m_vlm_config.hidden_size); - token_idx += resampled_source.get_shape().at(1); - replacing = false; - break; - } - } - if (embeds.slices) { + ids = std::find(ids, end, im_start_id); + OPENVINO_ASSERT(end != ids); + std::copy_n(emb, resampled_source.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size); + ids += m_vlm_config.query_num; + if (encoded_image.slices) { size_t token_idx = 0; - const ov::Shape& slices_shape = embeds.slices.get_shape(); - const std::vector& sliced_sizes = 
embeds.slices_sizes; + const ov::Shape& slices_shape = encoded_image.slices.get_shape(); + const std::vector& sliced_sizes = encoded_image.slices_sizes; for (size_t i = 0; i < slices_shape.at(0); ++i) { for (size_t ja = 0; ja < slices_shape.at(1); ++ja) { size_t d2 = slices_shape.at(2); size_t d3 = slices_shape.at(3); - ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, embeds.slices.data() + (i * slices_shape.at(1) + ja) * d2 * d3}; + ov::Tensor encoded_view{ov::element::f32, {1, d2, d3}, encoded_image.slices.data() + (i * slices_shape.at(1) + ja) * d2 * d3}; const ov::Tensor& vision_embed_tensor_i_j = resample(*this, encoded_view, {sliced_sizes.at(i * slices_shape.at(1) + ja)}); - for (; token_idx < inputs_embeds.get_shape().at(1); ++token_idx) { - if (slice_start_id == ids[token_idx]) { - replacing = true; - } - if (slice_end_id == ids[token_idx]) { - replacing = false; - break; - } - if (replacing) { - std::copy_n(vision_embed_tensor_i_j.data(), vision_embed_tensor_i_j.get_size(), inputs_embeds.data() + token_idx * m_vlm_config.hidden_size); - token_idx += vision_embed_tensor_i_j.get_shape().at(1); - replacing = false; - break; - } - } + ids = std::find(ids, end, slice_start_id); + OPENVINO_ASSERT(end != ids); + std::copy_n(vision_embed_tensor_i_j.data(), vision_embed_tensor_i_j.get_size(), inputs_embeds_data + std::distance(begin, ids) * m_vlm_config.hidden_size); + ids += m_vlm_config.query_num; } } } @@ -519,39 +518,19 @@ DecodedResults VLMPipeline::generate( streamer_ptr->end(); } + std::string decoded_results = m_tokenizer.decode(generated); if (m_is_chat_conversation) { - // auto new_chat_tokens = m_tokenizer.encode(new_templated_chat_history); - // if (m_is_cache_empty) { - // encoded_input = new_chat_tokens; - // } else { - // auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history); - // encoded_input = subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens); - // } - // m_templated_chat_history = new_templated_chat_history; + // Tail of chat template is missing in KV cache. + // Find the tail to concatenate it with the next input prompt. + m_templated_chat_history.append(decoded_results); + m_history.push_back({{"role", "assistant"}, {"content", decoded_results}}); } else { for (auto& variable : m_language.query_state()) { variable.reset(); } m_language.get_tensor("attention_mask").set_shape({1, 0}); - } - DecodedResults results; - results.texts = {m_tokenizer.decode(generated)}; - - // TODO: implement performance metrics - results.perf_metrics = ov::genai::PerfMetrics(); - results.perf_metrics.m_evaluated = false; - results.perf_metrics.generate_duration = {0, 0}; - results.perf_metrics.inference_duration= {0, 0}; - results.perf_metrics.tokenization_duration = {0, 0}; - results.perf_metrics.detokenization_duration= {0, 0}; - results.perf_metrics.ttft = {0, 0}; - results.perf_metrics.tpot= {0, 0}; - results.perf_metrics.ipot= {0, 0}; - results.perf_metrics.throughput= {0, 0}; - results.perf_metrics.num_generated_tokens = generated.size(); - results.perf_metrics.num_input_tokens= 0; - - return results; + } + return {{std::move(decoded_results)}}; } DecodedResults VLMPipeline::generate( @@ -559,13 +538,23 @@ DecodedResults VLMPipeline::generate( const ov::AnyMap& config_map ) { auto image = config_map.find(ov::genai::image.name()); + auto images = config_map.find(ov::genai::images.name()); + OPENVINO_ASSERT( + config_map.end() == image || config_map.end() == images, + "Only one property can be set: image or images."
+ ); + std::vector rgbs; + if (config_map.end() != image) { + rgbs = {image->second.as()}; + } if (config_map.end() != images) { + rgbs = images->second.as>(); + } ov::genai::OptionalGenerationConfig config_arg = utils::get_config_from_map(config_map); GenerationConfig config = (config_arg.has_value()) ? *config_arg : get_generation_config(); config.update_generation_config(config_map); return generate( prompt, - config_map.end() == image ? std::vector{} - : std::vector{image->second.as()}, + rgbs, config, utils::get_streamer_from_map(config_map) ); diff --git a/src/docs/BUILD.md b/src/docs/BUILD.md index 79d6ce861a..77657620a0 100644 --- a/src/docs/BUILD.md +++ b/src/docs/BUILD.md @@ -43,11 +43,11 @@ OpenVINO GenAI can be built as an extra module during the OpenVINO build process 1. Clone OpenVINO and OpenVINO GenAI repositories: ```sh git clone --recursive https://github.com/openvinotoolkit/openvino.git - git clone --recursive https://github.com/openvinotoolkit/openvino_genai.git + git clone --recursive https://github.com/openvinotoolkit/openvino.genai.git ``` 2. Configure CMake with OpenVINO extra modules: ```sh - cmake -DOPENVINO_EXTRA_MODULES=./openvino_genai -DCPACK_ARCHIVE_COMPONENT_INSTALL=OFF -S ./openvino -B ./build + cmake -DOPENVINO_EXTRA_MODULES=./openvino.genai -DCPACK_ARCHIVE_COMPONENT_INSTALL=OFF -S ./openvino -B ./build ``` 3. Build OpenVINO archive with GenAI: ```sh diff --git a/thirdparty/openvino_tokenizers b/thirdparty/openvino_tokenizers index b6c36a3026..e74460f8b7 160000 --- a/thirdparty/openvino_tokenizers +++ b/thirdparty/openvino_tokenizers @@ -1 +1 @@ -Subproject commit b6c36a302696329f008e4425c9d98c4e00194a24 +Subproject commit e74460f8b78c26ad46ccaccc0ee34d7ccccf56f7
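
A minimal usage sketch for the multi-image path that this patch adds to VLMPipeline::generate(). Only the image/images property names, the "one of the two" assertion, DecodedResults::texts, and the sample's utils::load_image() are taken from the diff above; the public header name, the VLMPipeline(model_dir, device) constructor signature, and the variadic properties overload of generate() are assumptions for illustration.

#include "load_image.hpp"                    // sample helper: utils::load_image()
#include <openvino/genai/vlm_pipeline.hpp>   // assumed public header for VLMPipeline
#include <iostream>
#include <vector>

int main(int argc, char* argv[]) {
    // argv[1]: model dir produced by export_MiniCPM-V-2_6.py, argv[2..]: image files
    ov::genai::VLMPipeline pipe(argv[1], "CPU");
    std::vector<ov::Tensor> rgbs;
    for (int i = 2; i < argc; ++i) {
        rgbs.push_back(utils::load_image(argv[i]));
    }
    // Pass either a single ov::genai::image(...) or a batch via ov::genai::images(...);
    // generate() now asserts that only one of the two properties is set.
    ov::genai::DecodedResults res = pipe.generate(
        "What is common between these images?",
        ov::genai::images(rgbs)
    );
    std::cout << res.texts.at(0) << '\n';
    return 0;
}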
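A self-contained check of the torch.bucketize(..., right=True) equivalence that the new bucket_size_right() helper in vision_encoder.cpp relies on: std::upper_bound returns the index of the first boundary strictly greater than the coordinate, which is exactly the "right" bucketing convention. Element types and the tiny 4-patch example below are chosen here for illustration, not taken from the patch.

#include <algorithm>
#include <cstdio>
#include <iterator>
#include <vector>

// Mirrors torch.bucketize(coords, boundaries, right=True).
static std::vector<size_t> bucket_size_right(const std::vector<float>& coords,
                                             const std::vector<float>& boundaries) {
    std::vector<size_t> buckets(coords.size());
    std::transform(coords.begin(), coords.end(), buckets.begin(), [&](float c) {
        return std::distance(boundaries.begin(),
                             std::upper_bound(boundaries.begin(), boundaries.end(), c));
    });
    return buckets;
}

int main() {
    // Fractional patch coordinates for a 4-patch row mapped onto 8 buckets per side.
    std::vector<float> fractional{0.0f, 0.25f, 0.5f, 0.75f};
    std::vector<float> boundaries;
    for (size_t i = 1; i < 8; ++i) boundaries.push_back(i / 8.0f);
    for (size_t b : bucket_size_right(fractional, boundaries))
        std::printf("%zu ", b);  // prints: 0 2 4 6
    return 0;
}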