Skip to content

Commit

Permalink
Win build fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed May 27, 2024
1 parent 1a38535 commit 9b672c0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
//
#pragma once

#ifdef _MSC_VER
# pragma warning(disable : 4244)
#endif

#include <algorithm>
#include <cassert>
#include <cfloat>
Expand Down Expand Up @@ -396,25 +400,25 @@ struct AttrAny {
template <typename T>
T cast_to() {
if (any.is<bool>())
return any.as<bool>();
return static_cast<T>(any.as<bool>());
if (any.is<int>())
return any.as<int>();
return static_cast<T>(any.as<int>());
if (any.is<long>())
return any.as<long>();
return static_cast<T>(any.as<long>());
if (any.is<long long>())
return any.as<long long>();
return static_cast<T>(any.as<long long>());
if (any.is<int32_t>())
return any.as<int32_t>();
return static_cast<T>(any.as<int32_t>());
if (any.is<int64_t>())
return any.as<int64_t>();
return static_cast<T>(any.as<int64_t>());
if (any.is<float>())
return any.as<float>();
return static_cast<T>(any.as<float>());
if (any.is<double>())
return any.as<double>();
return static_cast<T>(any.as<double>());
if (any.is<int8_t>())
return any.as<int8_t>();
return static_cast<T>(any.as<int8_t>());
if (any.is<uint8_t>())
return any.as<uint8_t>();
return static_cast<T>(any.as<uint8_t>());
return any.as<T>();
}

Expand Down Expand Up @@ -820,7 +824,9 @@ struct PatternNode {
return node->get_default_output();
}

PatternNode(const Output<Node>& out) : node(out.get_node_shared_ptr()), output_port(static_cast<int>(out.get_index())) {}
PatternNode(const Output<Node>& out)
: node(out.get_node_shared_ptr()),
output_port(static_cast<int>(out.get_index())) {}

PatternNode() {
node = ov::pass::pattern::any_input(ov::pass::pattern::has_static_rank());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {

op::internal::RoPE::Config config;
OutputVector new_args;
config.rotary_ndims = 2 * validator["half_ndims"];
config.rotary_ndims = 2ul * static_cast<size_t>(validator["half_ndims"]);

new_args.push_back(pattern_map.at(x));
new_args.push_back(v_cos);
Expand Down Expand Up @@ -184,11 +184,11 @@ ov::pass::RoPEFusionCosSinPreprocess::RoPEFusionCosSinPreprocess() {
if (pattern_map.count(gather_positions)) {
auto arg_id = rope_node->get_input_size();
rope_node->set_argument(arg_id, pattern_map.at(gather_positions));
config.gather_position_arg_id = arg_id;
config.gather_position_arg_id = static_cast<int>(arg_id);
} else if (pattern_map.count(gather_positions_2d)) {
auto arg_id = rope_node->get_input_size();
rope_node->set_argument(arg_id, pattern_map.at(gather_positions_2d));
config.gather_position_arg_id = arg_id;
config.gather_position_arg_id = static_cast<int>(arg_id);
}
rope_node->set_config(config);
rope_node->validate_and_infer_types();
Expand Down Expand Up @@ -274,8 +274,8 @@ ov::pass::RoPEFusionPreprocess::RoPEFusionPreprocess() {

auto config = rope_node->get_config();
if (pattern_map.count(input_to_slice)) {
config.slice_start = validator["slice_start"];
config.slice_stop = validator["slice_stop"];
config.slice_start = static_cast<size_t>(validator["slice_start"]);
config.slice_stop = static_cast<size_t>(validator["slice_stop"]);
config.input_trans0213 = true;
rope_node->set_argument(0, pattern_map.at(input_to_slice));
} else if (pattern_map.count(input_to_trans)) {
Expand Down Expand Up @@ -446,7 +446,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {

op::internal::RoPE::Config config;
OutputVector new_args;
config.rotary_ndims = validator["ndims"];
config.rotary_ndims = static_cast<size_t>(validator["ndims"]);

config.is_interleaved = true;

Expand Down Expand Up @@ -586,19 +586,19 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id) {

op::internal::RoPE::Config config;
OutputVector new_args;
config.rotary_ndims = validator["ndims"];
config.rotary_ndims = static_cast<size_t>(validator["ndims"]);
config.is_chatglm = true;
config.head_cnt = validator["head_cnt"];
config.head_size = validator["head_size"];
config.head_cnt = static_cast<size_t>(validator["head_cnt"]);
config.head_size = static_cast<size_t>(validator["head_size"]);

if (split_output_id == 0) {
// query : split_output_id == 0
config.slice_start = 0;
config.slice_stop = validator["total_size_q"];
config.slice_stop = static_cast<size_t>(validator["total_size_q"]);
} else {
// key : split_output_id == 1
config.slice_start = validator["total_size_q"];
config.slice_stop = config.slice_start + validator["total_size_k"];
config.slice_start = static_cast<size_t>(validator["total_size_q"]);
config.slice_stop = static_cast<size_t>(config.slice_start + validator["total_size_k"]);
}

new_args.push_back(pattern_map.at(qkv_linear));
Expand Down Expand Up @@ -730,8 +730,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
op::internal::RoPE::Config config;
OutputVector new_args;
config.is_qwen = true;
config.head_cnt = validator["head_cnt"];
config.head_size = validator["head_size"];
config.head_cnt = static_cast<size_t>(validator["head_cnt"]);
config.head_size = static_cast<size_t>(validator["head_size"]);
config.rotary_ndims = config.head_size;

if (split_output_id == 0) {
Expand Down

0 comments on commit 9b672c0

Please sign in to comment.