Skip to content

Commit

Permalink
[GPU] Shared RoPE func tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Lyamin-Roman committed May 30, 2024
1 parent 26c609a commit f8b6e3c
Show file tree
Hide file tree
Showing 7 changed files with 751 additions and 1,238 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/rotary_pos_emb.hpp"

namespace ov {
namespace test {

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestLlama2,
RoPETestLlama2,
::testing::Values(ov::test::utils::DEVICE_CPU),
RoPETestLlama2::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM,
RoPETestChatGLM,
::testing::Values(ov::test::utils::DEVICE_CPU),
RoPETestChatGLM::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestQwen7b,
RoPETestQwen7b,
::testing::Combine(::testing::Values(true, false),
::testing::Values(ov::test::utils::DEVICE_CPU)),
RoPETestQwen7b::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_RoPETestGPTJ,
RoPETestGPTJ,
::testing::Combine(::testing::Values(true, false),
::testing::Values(ov::test::utils::DEVICE_CPU)),
RoPETestGPTJ::getTestCaseName);
} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,6 @@ std::vector<std::string> disabledTestPatterns() {
// Issue: 136862
R"(.*smoke_ConditionGPUTest_static/StaticConditionLayerGPUTest.CompareWithRefs/IS=\(3.6\)_netPRC=i8_ifCond=PARAM_targetDevice=GPU_.*)",

// TODO: Add RoPE support for Llama2, GPTJ models
R"(.*(RoPEGPUTestLlama2).*)",
R"(.*(RoPEGPUTestGPTJ).*)",
#if defined(_WIN32)
// by calc abs_threshold with expected value
R"(.*smoke_RemoteTensor/OVRemoteTensorBatched_Test.NV12toBGR_buffer/(num_batch_4|num_batch_2).*)",
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/subgraph/rotary_pos_emb.hpp"

namespace ov {
namespace test {

inline void CheckNumberOfNodesWithType(std::shared_ptr<const ov::Model> function,
const std::unordered_set<std::string>& nodeTypes,
size_t expectedCount) {
ASSERT_NE(nullptr, function);
int num_ops = 0;
for (const auto& node : function->get_ordered_ops()) {
const auto& rt_info = node->get_rt_info();
const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
if (nodeTypes.count(layer_type)) {
num_ops++;
}
}
ASSERT_EQ(num_ops, expectedCount);
}

TEST_P(RoPETestLlama2, CompareWithRefs) {
run();
auto function = compiledModel.get_runtime_model();
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
};

TEST_P(RoPETestChatGLM, CompareWithRefs) {
run();
auto function = compiledModel.get_runtime_model();
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
};

TEST_P(RoPETestQwen7b, CompareWithRefs) {
run();
auto function = compiledModel.get_runtime_model();
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
};

TEST_P(RoPETestGPTJ, CompareWithRefs) {
run();
auto function = compiledModel.get_runtime_model();
CheckNumberOfNodesWithType(function, {"RoPE"}, 1);
};

} // namespace test
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/base/ov_subgraph.hpp"

namespace ov {
namespace test {

class RoPETestLlama2 : public SubgraphBaseTest, public testing::WithParamInterface<std::string> {
private:
ov::OutputVector makeCosSinCache(int max_position_embeddings, int rotary_ndims);
std::shared_ptr<ov::Model> buildROPE_Llama2(int batch,
int seq_length,
int max_position_embeddings,
int num_head,
int ndims);
ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1);
protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
void SetUp() override;

public:
static std::string getTestCaseName(const testing::TestParamInfo<std::string>& obj);
};

class RoPETestChatGLM : public SubgraphBaseTest, public testing::WithParamInterface<std::string> {
private:
std::shared_ptr<ov::Model> buildROPE_ChatGLM(int batch, int head_cnt, int rotary_dims);
ov::Tensor create_i32_tensor(const ov::Shape& shape, int start, int step = 1);
protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
void SetUp() override;

public:
static std::string getTestCaseName(const testing::TestParamInfo<std::string>& obj);
};

class RoPETestQwen7b : public SubgraphBaseTest, public testing::WithParamInterface<std::tuple<bool, std::string>> {
private:
std::shared_ptr<ov::Model> buildROPE_QWen7b(bool specialReshape);
protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
void SetUp() override;

public:
static std::string getTestCaseName(const testing::TestParamInfo<std::tuple<bool, std::string>>& obj);
};

class RoPETestGPTJ : public SubgraphBaseTest, public testing::WithParamInterface<std::tuple<bool, std::string>> {
private:
std::shared_ptr<ov::Model> buildROPE_GPTJ(int num_head,
int hidden_dims,
int rotary_dims,
bool hasShapeOf);
protected:
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
void SetUp() override;

public:
static std::string getTestCaseName(const testing::TestParamInfo<std::tuple<bool, std::string>>& obj);
};

} // namespace test
} // namespace ov
Loading

0 comments on commit f8b6e3c

Please sign in to comment.