Multiple images miniCPM-V-2_6 (#919)
TODO:
- [ ] Remove `ov::Core` from constructors.
- [ ] Hide files and API.

---------

Co-authored-by: wenyi5608 <[email protected]>
Co-authored-by: Yang,Su <[email protected]>
Co-authored-by: Yaroslav Tarkan <[email protected]>
Co-authored-by: Alina Kladieva <[email protected]>
Co-authored-by: Pavel Esir <[email protected]>
Co-authored-by: Pavel Esir <[email protected]>
Co-authored-by: Artur Paniukov <[email protected]>
Co-authored-by: Ekaterina Aidova <[email protected]>
Co-authored-by: Ilya Lavrenov <[email protected]>
Co-authored-by: Mikhail Ryzhov <[email protected]>
11 people authored Oct 9, 2024
1 parent a1feff9 commit 6d2763a
Showing 14 changed files with 426 additions and 146 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/causal_lm_cpp.yml
@@ -683,7 +683,7 @@ jobs:
diff pred2.txt ref.txt
echo "Chat sample python" passed
py-vlm_chat_sample-ubuntu:
visual_language_chat_sample-ubuntu:
runs-on: ubuntu-22.04-16-cores
steps:
- uses: actions/checkout@v4
@@ -859,6 +859,7 @@ jobs:
cpp-beam_search_causal_lm-Qwen-7B-Chat, cpp-beam_search_causal_lm-Qwen1_5-7B-Chat, cpp-beam_search_causal_lm-Phi-2,
cpp-beam_search_causal_lm-notus-7b-v1, cpp-speculative_decoding_lm-ubuntu, cpp-prompt_lookup_decoding_lm-ubuntu,
cpp-Phi-1_5, cpp-greedy_causal_lm-redpajama-3b-chat, cpp-chat_sample-ubuntu, cpp-continuous-batching-ubuntu,
visual_language_chat_sample-ubuntu,
cpp-continuous-batching-windows, cpp-continuous-batching-macos]
if: ${{ always() }}
runs-on: ubuntu-latest
3 changes: 2 additions & 1 deletion samples/CMakeLists.txt
@@ -25,7 +25,8 @@ install(DIRECTORY
cpp/greedy_causal_lm
cpp/multinomial_causal_lm
# Don't install prompt_lookup_decoding_lm and speculative_decoding_lm because they don't use the openvino_genai library and aren't verified yet.
# Don't install continuous_batching_accuracy and continuous_batching_benchmark because they depend on json.
# Don't install continuous_batching_accuracy and continuous_batching_benchmark because CB isn't ready.
cpp/visual_language_chat
cpp/whisper_speech_recognition
cpp/text2image
cpp/lora_greedy_causal_lm
12 changes: 7 additions & 5 deletions samples/cpp/visual_language_chat/CMakeLists.txt
@@ -1,9 +1,11 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

find_package(OpenVINOGenAI REQUIRED PATHS
"${CMAKE_BINARY_DIR}" # Reuse the package from the build.
${OpenVINO_DIR} # GenAI may be installed alongside OpenVINO.
find_package(OpenVINOGenAI REQUIRED
PATHS
"${CMAKE_BINARY_DIR}" # Reuse the package from the build.
${OpenVINO_DIR} # GenAI may be installed alongside OpenVINO.
NO_CMAKE_FIND_ROOT_PATH
)

file(DOWNLOAD
@@ -14,11 +16,11 @@ file(DOWNLOAD
add_executable(visual_language_chat visual_language_chat.cpp load_image.cpp)
target_include_directories(visual_language_chat PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_BINARY_DIR}")
target_link_libraries(visual_language_chat PRIVATE openvino::genai)

set_target_properties(visual_language_chat PROPERTIES
COMPILE_PDB_NAME chat_sample
COMPILE_PDB_NAME visual_language_chat
# Ensure out of box LC_RPATH on macOS with SIP
INSTALL_RPATH_USE_LINK_PATH ON)
target_compile_features(visual_language_chat PRIVATE cxx_std_11)

install(TARGETS visual_language_chat
RUNTIME DESTINATION samples_bin/
260 changes: 247 additions & 13 deletions samples/cpp/visual_language_chat/export_MiniCPM-V-2_6.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions samples/cpp/visual_language_chat/load_image.cpp
@@ -13,7 +13,7 @@ ov::Tensor utils::load_image(const std::filesystem::path& image_path) {
image_path.string().c_str(),
&x, &y, &channels_in_file, desired_channels);
if (!image) {
throw std::runtime_error{"Failed to load the image"};
throw std::runtime_error{"Failed to load the image."};
}
struct SharedImageAllocator {
unsigned char* image;
@@ -22,11 +22,11 @@ ov::Tensor utils::load_image(const std::filesystem::path& image_path) {
if (channels * height * width == bytes) {
return image;
}
throw std::runtime_error{"Unexpected number of bytes was requested to allocate"};
throw std::runtime_error{"Unexpected number of bytes was requested to allocate."};
}
void deallocate(void*, size_t bytes, size_t) {
if (channels * height * width != bytes) {
throw std::runtime_error{"Unexpected number of bytes was requested to deallocate"};
throw std::runtime_error{"Unexpected number of bytes was requested to deallocate."};
}
std::free(image);
image = nullptr;
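For context: the SharedImageAllocator in the hunk above lets load_image expose the stb-decoded pixels as an ov::Tensor without copying, which is why the size checks throw the runtime errors whose messages this commit touches. Below is a minimal, self-contained sketch of the same zero-copy allocator pattern; OwningBufferAllocator and its members are illustrative names, not part of the sample code.

#include <cstdlib>
#include <stdexcept>
#include <openvino/runtime/tensor.hpp>

// Illustrative allocator: hands a pre-existing heap buffer to ov::Tensor
// instead of allocating new memory, and frees it on deallocation.
struct OwningBufferAllocator {
    unsigned char* buffer;  // e.g. the pointer returned by stbi_load()
    size_t size;            // expected allocation size in bytes
    void* allocate(size_t bytes, size_t) {
        if (bytes != size) {
            throw std::runtime_error{"Unexpected number of bytes was requested to allocate."};
        }
        return buffer;  // no copy: the tensor reads the existing buffer
    }
    void deallocate(void* handle, size_t bytes, size_t) {
        if (handle == buffer && bytes == size) {
            std::free(buffer);
            buffer = nullptr;
        }
    }
    bool is_equal(const OwningBufferAllocator& other) const noexcept {
        return buffer == other.buffer;
    }
};

// Usage for an RGB image of height x width pixels loaded elsewhere:
// ov::Tensor tensor{ov::element::u8, {1, height, width, 3},
//                   OwningBufferAllocator{data, height * width * 3}};

The size checks mirror the ones in the hunk above: they guard against the tensor requesting any layout other than the tightly packed HxWxC buffer that stb produces.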
1 change: 0 additions & 1 deletion samples/requirements.txt
@@ -3,5 +3,4 @@ optimum[openvino]==1.22.0
einops==0.8.0 # For Qwen
transformers_stream_generator==0.0.5 # For Qwen
diffusers==0.30.3
pillow
torchvision # Needed for the mini-CPM export script; remove once export switches to optimum-intel.
1 change: 1 addition & 0 deletions src/cpp/include/openvino/genai/processor_config.hpp
@@ -14,6 +14,7 @@ namespace ov::genai {
/// preprocessor_config.json.
class OPENVINO_GENAI_EXPORTS ProcessorConfig {
public:
size_t image_size = 980;
/// @brief Dimensions of the smaller, non-overlapping patches that the
/// input image is divided into before being fed into the
/// transformer model. Used to divide image height and width.
2 changes: 2 additions & 0 deletions src/cpp/src/clip.hpp
@@ -25,6 +25,8 @@ struct clip_ctx {
std::vector<uint8_t> buf_compute_meta;

projector_type proj_type = PROJECTOR_TYPE_RESAMPLER;
size_t patch_size = 0;
size_t image_size = 0;
};

// RGB uint8 image
20 changes: 1 addition & 19 deletions src/cpp/src/llm_pipeline.cpp
@@ -18,24 +18,6 @@
#include "openvino/genai/lora_adapter.hpp"
#include "lora_helper.hpp"

namespace {

ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& first, const ov::genai::TokenizedInputs& second){
auto first_size = first.input_ids.get_size();
auto second_size = second.input_ids.get_size();
ov::Shape new_shape{1, first_size - second_size};

ov::Tensor new_input_ids(ov::element::i64, new_shape);
auto data_ptr = first.input_ids.data<int64_t>();
std::copy(data_ptr + second_size, data_ptr + first_size, new_input_ids.data<int64_t>());

ov::Tensor new_attention_mask(ov::element::i64, new_shape);
std::fill_n(new_attention_mask.data<int64_t>(), new_shape[1], 1);

return {new_input_ids, new_attention_mask};
}
}

namespace ov {
namespace genai {

@@ -153,7 +135,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
encoded_input = new_chat_tokens;
} else {
auto prev_chat_tokens = m_tokenizer.encode(m_templated_chat_history, ov::genai::add_special_tokens(add_special_tokens_));
encoded_input = subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
encoded_input = utils::subtract_chat_tokenized_inputs(new_chat_tokens, prev_chat_tokens);
}
m_templated_chat_history = new_templated_chat_history;
// TODO: Forbid LoRA config change if we are in the chat mode, because it requires regenerating the history with LoRA applied
14 changes: 14 additions & 0 deletions src/cpp/src/utils.hpp
@@ -86,6 +86,20 @@ ProcessorConfig from_any_map(

std::pair<ov::AnyMap, ov::AnyMap> split_core_complile_config(const ov::AnyMap& plugin_config);

inline ov::genai::TokenizedInputs subtract_chat_tokenized_inputs(const ov::genai::TokenizedInputs& first, const ov::genai::TokenizedInputs& second){
auto first_size = first.input_ids.get_size();
auto second_size = second.input_ids.get_size();
ov::Shape new_shape{1, first_size - second_size};

ov::Tensor new_input_ids(ov::element::i64, new_shape);
auto data_ptr = first.input_ids.data<int64_t>();
std::copy(data_ptr + second_size, data_ptr + first_size, new_input_ids.data<int64_t>());

ov::Tensor new_attention_mask(ov::element::i64, new_shape);
std::fill_n(new_attention_mask.data<int64_t>(), new_shape[1], 1);

return {new_input_ids, new_attention_mask};
}
} // namespace utils
} // namespace genai
} // namespace ov
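For context: subtract_chat_tokenized_inputs, moved here from llm_pipeline.cpp, strips the already-processed history prefix from a freshly tokenized chat template so only the newest turn's tokens are fed to the model; it assumes the previous history tokens form a prefix of the new ones. A hedged worked example follows; make_inputs is a hypothetical helper written for illustration, not part of the library.

#include <algorithm>
#include <initializer_list>
#include <openvino/openvino.hpp>
#include "utils.hpp"  // assumed include path for ov::genai::utils

// Hypothetical helper building TokenizedInputs from literal token ids,
// with an all-ones attention mask of matching shape.
ov::genai::TokenizedInputs make_inputs(std::initializer_list<int64_t> tokens) {
    ov::Tensor ids{ov::element::i64, {1, tokens.size()}};
    std::copy(tokens.begin(), tokens.end(), ids.data<int64_t>());
    ov::Tensor mask{ov::element::i64, {1, tokens.size()}};
    std::fill_n(mask.data<int64_t>(), tokens.size(), 1);
    return {ids, mask};
}

// If the previous history tokenized to {1, 15, 27} and the updated history
// to {1, 15, 27, 88, 99}, then
//   ov::genai::utils::subtract_chat_tokenized_inputs(
//       make_inputs({1, 15, 27, 88, 99}), make_inputs({1, 15, 27}))
// returns input_ids {88, 99} and an all-ones attention mask of shape {1, 2}.

This is the helper that StatefulLLMPipeline now calls in chat mode, as shown in the llm_pipeline.cpp hunk above.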
87 changes: 71 additions & 16 deletions src/cpp/src/vision_encoder.cpp
@@ -216,6 +216,65 @@ ov::Tensor preprocess_for_encoder(const ov::Tensor& images, size_t kernel) {
return permuted_tensor;
}

// torch.bucketize(fractional_coords, boundaries, right=True)
std::vector<int64_t> bucket_size_right(const std::vector<float>& fractional_coords, const std::vector<float>& boundaries) {
std::vector<int64_t> bucket_coords(fractional_coords.size());
std::transform(fractional_coords.begin(), fractional_coords.end(), bucket_coords.begin(), [&boundaries](float fractional_coord) {
return std::distance(boundaries.begin(), std::upper_bound(boundaries.begin(), boundaries.end(), fractional_coord));
});
return bucket_coords;
}

ov::Tensor prepare_vis_position_ids(
const ov::Tensor& pixel_values,
const ov::Tensor& patch_attention_mask,
const std::vector<HeightWidth> tgt_sizes,
size_t patch_size,
size_t num_patches_per_side
) {
size_t batch_size = pixel_values.get_shape().at(0);
size_t max_im_h = pixel_values.get_shape().at(2), max_im_w = pixel_values.get_shape().at(3);
size_t max_nb_patches_h = max_im_h / patch_size, max_nb_patches_w = max_im_w / patch_size;
std::vector<float> boundaries(1.0f * num_patches_per_side - 1);
std::generate(boundaries.begin(), boundaries.end(), [num_patches_per_side, val = 0.0f]() mutable {
val += 1.0f / num_patches_per_side;
return val;
});
size_t position_ids_batch_elem = max_nb_patches_h * max_nb_patches_w;
ov::Tensor position_ids{ov::element::i64, {batch_size, position_ids_batch_elem}};
int64_t* res_data = position_ids.data<int64_t>();
std::fill_n(res_data, position_ids.get_size(), 0);

for (size_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
size_t nb_patches_h = tgt_sizes.at(batch_idx).height;
size_t nb_patches_w = tgt_sizes.at(batch_idx).width;

std::vector<float> fractional_coords_h(nb_patches_h);
std::generate(fractional_coords_h.begin(), fractional_coords_h.end(), [nb_patches_h, val = -1.0f / nb_patches_h]() mutable {
val += 1.0f / nb_patches_h;
return val;
});
std::vector<float> fractional_coords_w(nb_patches_w);
std::generate(fractional_coords_w.begin(), fractional_coords_w.end(), [nb_patches_w, val = -1.0f / nb_patches_w]() mutable {
val += 1.0f / nb_patches_w;
return val;
});

std::vector<int64_t> bucket_coords_h = bucket_size_right(fractional_coords_h, boundaries);
std::vector<int64_t> bucket_coords_w = bucket_size_right(fractional_coords_w, boundaries);

std::vector<int64_t> pos_ids(bucket_coords_h.size() * bucket_coords_w.size());
for (size_t col = 0; col < bucket_coords_h.size(); ++col) {
for (size_t row = 0; row < bucket_coords_w.size(); ++row) {
pos_ids.at(col * bucket_coords_w.size() + row) = bucket_coords_h.at(col) * num_patches_per_side + bucket_coords_w.at(row);
}
}
std::copy(pos_ids.begin(), pos_ids.end(), res_data + batch_idx * position_ids_batch_elem);
}
return position_ids;
}

EncodedImage llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const ov::Tensor& img, ov::InferRequest& encoder, int max_slice_nums, int scale_resolution, size_t patch_size, bool never_split) {
clip_image_u8 source{
int(img.get_shape().at(3)),
@@ -244,14 +303,11 @@ EncodedImage llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const o
ov::Tensor patch_attention_mask{ov::element::boolean, {pixel_values.get_shape().at(0), 1, resized_source_size.height * resized_source_size.width}};
std::fill_n(patch_attention_mask.data<bool>(), patch_attention_mask.get_size(), true);
encoder.set_tensor("patch_attention_mask", patch_attention_mask);
ov::Tensor tgt_sizes{ov::element::i64, {1, 2}};
int64_t* tgt_sizes_data = tgt_sizes.data<int64_t>();
tgt_sizes_data[0] = resized_source_size.height;
tgt_sizes_data[1] = resized_source_size.width;
encoder.set_tensor("tgt_sizes", tgt_sizes);
ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, {resized_source_size}, ctx_clip.patch_size, ctx_clip.image_size / ctx_clip.patch_size);
encoder.set_tensor("position_ids", position_ids);
encoder.infer();
const ov::Tensor& output_tensor = encoder.get_output_tensor();
ov::Tensor resized_source{output_tensor.get_element_type(), output_tensor.get_shape()};
ov::Tensor resized_source{ov::element::f32, output_tensor.get_shape()};
output_tensor.copy_to(resized_source);

if (1 == preprocessed.size()) {
@@ -266,27 +322,24 @@ EncodedImage llava_image_embed_make_with_bytes_slice(clip_ctx& ctx_clip, const o
size_t n_patches = size.height / patch_size * size.width / patch_size,
old_hidden_size = resized_source.get_shape().at(2);
ov::Tensor encoded_slices{ov::element::f32, {preprocessed.size() - 1, preprocessed.at(1).size(), n_patches, old_hidden_size}};
// There is an operation inside the model that constant-folds the batch dimension, so a batch size different from the export-time one can't be used
// the constant folding happens in the TorchScript
// Even though batch can't be used, it's still possible to use async.
for (size_t row = 1; row < preprocessed.size(); ++row) {
for (size_t col = 0; col < preprocessed.at(row).size(); ++col) {
clip_image_f32& elem = preprocessed.at(row).at(col);
sliced_sizes.push_back({elem.ny / patch_size, elem.nx / patch_size});
encoder.set_tensor("pixel_values", preprocess_for_encoder(
ov::Tensor pixel_values = preprocess_for_encoder(
{ov::element::f32, {1, 3, size_t(elem.ny), size_t(elem.nx)}, elem.buf.data()},
patch_size
));
);
encoder.set_tensor("pixel_values", pixel_values);
ov::Tensor patch_attention_mask{ov::element::boolean, {1, 1, sliced_sizes.back().height * sliced_sizes.back().width}};
std::fill_n(patch_attention_mask.data<bool>(), patch_attention_mask.get_size(), true);
encoder.set_tensor("patch_attention_mask", patch_attention_mask);
ov::Tensor tgt_sizes{ov::element::i64, {1, 2}};
int64_t* tgt_sizes_data = tgt_sizes.data<int64_t>();
tgt_sizes_data[0] = sliced_sizes.back().height;
tgt_sizes_data[1] = sliced_sizes.back().width;
encoder.set_tensor("tgt_sizes", tgt_sizes);
ov::Tensor position_ids = prepare_vis_position_ids(pixel_values, patch_attention_mask, {sliced_sizes.back()}, ctx_clip.patch_size, ctx_clip.image_size / ctx_clip.patch_size);
encoder.set_tensor("position_ids", position_ids);
const ov::Tensor& old = encoder.get_output_tensor();
encoder.set_output_tensor({ov::element::f32, {1, n_patches, old_hidden_size}, encoded_slices.data<float>() + ((row - 1) * preprocessed.at(row).size() + col) * n_patches * old_hidden_size});
encoder.infer();
encoder.set_output_tensor(old);
}
}
return {resized_source, resized_source_size, encoded_slices, sliced_sizes};
Expand All @@ -305,6 +358,8 @@ VisionEncoder::VisionEncoder(const std::filesystem::path& model_dir, const std::

EncodedImage VisionEncoder::encode(const ov::Tensor& image, const ProcessorConfig& config) {
clip_ctx ctx_clip;
ctx_clip.patch_size = m_processor_config.patch_size;
ctx_clip.image_size = m_processor_config.image_size;
std::copy(config.norm_mean.begin(), config.norm_mean.end(), ctx_clip.image_mean);
std::copy(config.norm_std.begin(), config.norm_std.end(), ctx_clip.image_std);
return llava_image_embed_make_with_bytes_slice(ctx_clip, image, m_encoder, config.max_slice_nums, config.scale_resolution, config.patch_size, 0 == config.max_slice_nums);
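For context: prepare_vis_position_ids replaces the former tgt_sizes input with explicit per-patch position ids. Each slice's patch grid is mapped onto num_patches_per_side = image_size / patch_size bins per axis (with the image_size of 980 added to ProcessorConfig above and assuming a MiniCPM-style patch_size of 14, which the real code reads from the processor config, that would be 70 bins), and bucket_size_right reproduces torch.bucketize(fractional_coords, boundaries, right=True) via std::upper_bound. A standalone sketch of the bucketing step, with illustrative sizes:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
    // Illustrative grid: 4 position bins per axis (the real code uses
    // image_size / patch_size). Boundaries sit at k / num_patches_per_side.
    const size_t num_patches_per_side = 4;
    std::vector<float> boundaries;  // {0.25, 0.5, 0.75}
    for (size_t k = 1; k < num_patches_per_side; ++k) {
        boundaries.push_back(float(k) / num_patches_per_side);
    }
    // A slice two patches high: fractional row coordinates {0.0, 0.5}.
    for (float coord : {0.0f, 0.5f}) {
        // std::upper_bound matches torch.bucketize(..., right=True): it
        // returns the index of the first boundary strictly greater than
        // the coordinate.
        int64_t bucket = std::distance(
            boundaries.begin(),
            std::upper_bound(boundaries.begin(), boundaries.end(), coord));
        std::cout << coord << " -> bucket " << bucket << '\n';  // 0 and 2
    }
    // prepare_vis_position_ids then flattens the 2-D grid:
    // position_id = bucket_h * num_patches_per_side + bucket_w.
    return 0;
}

Flattening row and column buckets this way maps slices of different resolutions onto one fixed positional grid, which is what lets the encoder take position_ids directly instead of the old tgt_sizes tensor.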