diff --git a/src/cpp/src/image_generation/models/clip_text_model.cpp b/src/cpp/src/image_generation/models/clip_text_model.cpp index efbc840d4f..72fdc63082 100644 --- a/src/cpp/src/image_generation/models/clip_text_model.cpp +++ b/src/cpp/src/image_generation/models/clip_text_model.cpp @@ -118,13 +118,20 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string const size_t text_embedding_batch_size = do_classifier_free_guidance ? 2 : 1; auto perform_tokenization = [&](const std::string& prompt, ov::Tensor input_ids) { - std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); - ov::Tensor input_ids_token = m_clip_tokenizer.encode(prompt).input_ids; - std::copy_n(input_ids_token.data(), input_ids_token.get_size(), input_ids.data()); + + if (input_ids.get_element_type() == ov::element::i32) { + std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); + std::copy_n(input_ids_token.data(), input_ids_token.get_size(), input_ids.data()); + } else { + std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); + std::copy_n(input_ids_token.data(), input_ids_token.get_size(), input_ids.data()); + } }; - ov::Tensor input_ids(ov::element::i32, {text_embedding_batch_size, m_config.max_position_embeddings}); + ov::Tensor input_ids = m_request.get_input_tensor(); + input_ids.set_shape({text_embedding_batch_size, m_config.max_position_embeddings}); + size_t current_batch_idx = 0; if (do_classifier_free_guidance) { @@ -141,7 +148,6 @@ ov::Tensor CLIPTextModel::infer(const std::string& pos_prompt, const std::string {current_batch_idx + 1, m_config.max_position_embeddings})); // text embeddings - m_request.set_tensor("input_ids", input_ids); m_request.infer(); return m_request.get_output_tensor(0); diff --git a/src/cpp/src/image_generation/models/clip_text_model_with_projection.cpp b/src/cpp/src/image_generation/models/clip_text_model_with_projection.cpp index 982800a701..1160c30b6a 100644 --- a/src/cpp/src/image_generation/models/clip_text_model_with_projection.cpp +++ b/src/cpp/src/image_generation/models/clip_text_model_with_projection.cpp @@ -109,13 +109,20 @@ ov::Tensor CLIPTextModelWithProjection::infer(const std::string& pos_prompt, con const size_t text_embedding_batch_size = do_classifier_free_guidance ? 2 : 1; auto perform_tokenization = [&](const std::string& prompt, ov::Tensor input_ids) { - std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); - ov::Tensor input_ids_token = m_clip_tokenizer.encode(prompt).input_ids; - std::copy_n(input_ids_token.data(), input_ids_token.get_size(), input_ids.data()); + + if (input_ids.get_element_type() == ov::element::i32) { + std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); + std::copy_n(input_ids_token.data(), input_ids_token.get_size(), input_ids.data()); + } else { + std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); + std::copy_n(input_ids_token.data(), input_ids_token.get_size(), input_ids.data()); + } }; - ov::Tensor input_ids(ov::element::i64, {text_embedding_batch_size, m_config.max_position_embeddings}); + ov::Tensor input_ids = m_request.get_input_tensor(); + input_ids.set_shape({text_embedding_batch_size, m_config.max_position_embeddings}); + size_t current_batch_idx = 0; if (do_classifier_free_guidance) { @@ -132,7 +139,6 @@ ov::Tensor CLIPTextModelWithProjection::infer(const std::string& pos_prompt, con {current_batch_idx + 1, m_config.max_position_embeddings})); // text embeddings - m_request.set_tensor("input_ids", input_ids); m_request.infer(); return m_request.get_output_tensor(0); diff --git a/src/cpp/src/image_generation/models/t5_encoder_model.cpp b/src/cpp/src/image_generation/models/t5_encoder_model.cpp index 21df456d46..a83697b2e6 100644 --- a/src/cpp/src/image_generation/models/t5_encoder_model.cpp +++ b/src/cpp/src/image_generation/models/t5_encoder_model.cpp @@ -80,8 +80,13 @@ ov::Tensor T5EncoderModel::infer(const std::string& pos_prompt, const std::strin ov::Tensor input_ids_token = m_tokenizer.encode(prompt).input_ids; size_t min_length = std::min(input_ids.get_size(), input_ids_token.get_size()); - std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); - std::copy_n(input_ids_token.data(), min_length, input_ids.data()); + if (input_ids.get_element_type() == ov::element::i32) { + std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); + std::copy_n(input_ids_token.data(), min_length, input_ids.data()); + } else { + std::fill_n(input_ids.data(), input_ids.get_size(), pad_token_id); + std::copy_n(input_ids_token.data(), min_length, input_ids.data()); + } }; ov::Tensor input_ids = m_request.get_input_tensor(); @@ -114,7 +119,6 @@ ov::Tensor T5EncoderModel::infer(const std::string& pos_prompt, const std::strin {current_batch_idx + 1, input_ids.get_shape()[1]})); // text embeddings - m_request.set_tensor("input_ids", input_ids); m_request.infer(); return m_request.get_output_tensor(0); diff --git a/src/cpp/src/image_generation/models/unet_inference_dynamic.hpp b/src/cpp/src/image_generation/models/unet_inference_dynamic.hpp index 6dc285f76d..914fbcf50b 100644 --- a/src/cpp/src/image_generation/models/unet_inference_dynamic.hpp +++ b/src/cpp/src/image_generation/models/unet_inference_dynamic.hpp @@ -12,11 +12,8 @@ namespace genai { class UNet2DConditionModel::UNetInferenceDynamic : public UNet2DConditionModel::UNetInference { - public: - - virtual void compile(std::shared_ptr model, const std::string& device, const ov::AnyMap& properties) override - { + virtual void compile(std::shared_ptr model, const std::string& device, const ov::AnyMap& properties) override { ov::Core core = utils::singleton_core(); ov::CompiledModel compiled_model = core.compile_model(model, device, properties); @@ -24,20 +21,17 @@ class UNet2DConditionModel::UNetInferenceDynamic : public UNet2DConditionModel:: m_request = compiled_model.create_infer_request(); } - virtual void set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) override - { + virtual void set_hidden_states(const std::string& tensor_name, ov::Tensor encoder_hidden_states) override { OPENVINO_ASSERT(m_request, "UNet model must be compiled first"); m_request.set_tensor(tensor_name, encoder_hidden_states); } - virtual void set_adapters(AdapterController &adapter_controller, const AdapterConfig& adapters) override - { + virtual void set_adapters(AdapterController &adapter_controller, const AdapterConfig& adapters) override { OPENVINO_ASSERT(m_request, "UNet model must be compiled first"); adapter_controller.apply(m_request, adapters); } - virtual ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep) override - { + virtual ov::Tensor infer(ov::Tensor sample, ov::Tensor timestep) override { OPENVINO_ASSERT(m_request, "UNet model must be compiled first. Cannot infer non-compiled model"); m_request.set_tensor("sample", sample); @@ -49,10 +43,8 @@ class UNet2DConditionModel::UNetInferenceDynamic : public UNet2DConditionModel:: } private: - ov::InferRequest m_request; }; - } // namespace genai } // namespace ov \ No newline at end of file diff --git a/src/cpp/src/image_generation/models/unet_inference_static_bs1.hpp b/src/cpp/src/image_generation/models/unet_inference_static_bs1.hpp index 7aa6f6301c..f63a8ea237 100644 --- a/src/cpp/src/image_generation/models/unet_inference_static_bs1.hpp +++ b/src/cpp/src/image_generation/models/unet_inference_static_bs1.hpp @@ -42,8 +42,7 @@ class UNet2DConditionModel::UNetInferenceStaticBS1 : public UNet2DConditionModel ov::CompiledModel compiled_model = core.compile_model(model, device, properties); ov::genai::utils::print_compiled_model_properties(compiled_model, "UNet 2D Condition batch-1 model"); - for (int i = 0; i < m_native_batch_size; i++) - { + for (int i = 0; i < m_native_batch_size; i++) { m_requests[i] = compiled_model.create_infer_request(); } }