[CPU] Enable u8 kv cache by default (#27454)
### Details:
 - *Enable u8 kv cache by default*

### Tickets:
 - *[152621](https://jira.devtools.intel.com/browse/CVS-152621)*
luo-cheng2021 authored Nov 13, 2024
1 parent 0080d90 commit 2d148ec
Showing 5 changed files with 47 additions and 6 deletions.
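For context before the diff, a minimal sketch (not part of this commit) of how the affected property is set and read through the public OpenVINO API; the model path is a placeholder:

#include <iostream>
#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    // "model.xml" is a placeholder path. After this commit, the CPU plugin
    // defaults the KV cache to u8 on x86_64; pinning the precision explicitly
    // restores the previous behavior:
    auto compiled = core.compile_model("model.xml", "CPU",
                                       ov::hint::kv_cache_precision(ov::element::f16));
    // The effective precision can be read back from the compiled model:
    std::cout << compiled.get_property(ov::hint::kv_cache_precision) << std::endl;
}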
4 changes: 4 additions & 0 deletions src/plugins/intel_cpu/src/config.cpp
@@ -358,6 +358,7 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
             }
         } else if (key == ov::hint::kv_cache_precision.name()) {
             try {
+                kvCachePrecisionSetExplicitly = true;
                 auto const prec = val.as<ov::element::Type>();
                 if (one_of(prec, ov::element::f32, ov::element::f16, ov::element::bf16, ov::element::u8)) {
                     kvCachePrecision = prec;
@@ -411,6 +412,9 @@ void Config::readProperties(const ov::AnyMap& prop, const ModelType modelType) {
         if (!fcDynamicQuantizationGroupSizeSetExplicitly) {
             fcDynamicQuantizationGroupSize = 0;
         }
+        if (!kvCachePrecisionSetExplicitly) {
+            kvCachePrecision = ov::element::f32;
+        }
     }

     if (!prop.empty())
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/config.h
@@ -51,14 +51,16 @@ struct Config {
     std::string device_id = {};
     float fcSparseWeiDecompressionRate = 1.0f;
     uint64_t fcDynamicQuantizationGroupSize = 32;
-    ov::element::Type kvCachePrecision = ov::element::f16;
     bool fcDynamicQuantizationGroupSizeSetExplicitly = false;
+    bool kvCachePrecisionSetExplicitly = false;
 #if defined(OV_CPU_WITH_ACL)
     bool aclFastMath = false;
 #endif
 #if defined(OPENVINO_ARCH_X86_64)
+    ov::element::Type kvCachePrecision = ov::element::u8;
     size_t rtCacheCapacity = 5000ul;
 #else
+    ov::element::Type kvCachePrecision = ov::element::f16;
     // TODO: Executor cache may lead to incorrect behavior on oneDNN ACL primitives
     size_t rtCacheCapacity = 0ul;
 #endif
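Summarizing the two files above: the effective default now depends on the platform, the execution mode, and whether the user set the hint. A hedged paraphrase of the selection logic (my sketch, not repository code):

// Paraphrase of the defaulting behavior in config.h/config.cpp above (sketch).
ov::element::Type effective_kv_cache_precision(bool is_x86_64,
                                               bool accuracy_mode,
                                               bool set_explicitly,
                                               ov::element::Type user_value) {
    if (set_explicitly)
        return user_value;               // an explicit hint always wins
    if (accuracy_mode)
        return ov::element::f32;         // ACCURACY mode reverts to full precision
    return is_x86_64 ? ov::element::u8   // new default on x86_64
                     : ov::element::f16; // default elsewhere is unchanged
}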
9 changes: 5 additions & 4 deletions src/plugins/intel_cpu/src/memory_state.cpp
@@ -297,18 +297,19 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) {
         auto S = internal.size(3);
         auto nthr = parallel_get_max_threads();
         std::vector<PlainTensor> buffers(nthr);
+        m_scale_zp.resize<float>({L0, B, H, 2});
         parallel_for3d(B, H, L0, [&](size_t ithr, size_t b, size_t h, size_t m) {
             buffers[ithr].resize<float>({S});
-            cpu_convert(external.ptr_v(b, h, m),
+            cpu_convert(external.ptr_v(m, b, h),
                         buffers[ithr].ptr<float>(),
                         external.m_dt,
                         element::f32,
                         S);
             attn_quant_u8(buffers[ithr].ptr<float>(),
-                          internal.ptr<uint8_t>(b, h, m),
+                          internal.ptr<uint8_t>(m, b, h),
                           S,
-                          m_scale_zp.at<float>({b, h, m, size_t{0}}),
-                          m_scale_zp.at<float>({b, h, m, size_t{1}}));
+                          m_scale_zp.at<float>({m, b, h, size_t{0}}),
+                          m_scale_zp.at<float>({m, b, h, size_t{1}}));
         });
     } else {
         m_internal_mem->load(external_mem);
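The hunk above sizes the scale/zero-point tensor as {L0, B, H, 2} and indexes it token-first (m, b, h), matching the new cache layout. attn_quant_u8 itself is not shown in this diff; a minimal sketch of the per-row asymmetric u8 quantization it presumably performs (function name, rounding, and clamping details are assumptions):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Quantize one row of S floats to u8 with a per-row scale and zero-point.
static void quant_row_u8(const float* src, uint8_t* dst, std::size_t S,
                         float& scale, float& zp) {
    const auto [mn, mx] = std::minmax_element(src, src + S);
    scale = (*mx - *mn) / 255.0f;        // map [min, max] onto [0, 255]
    if (scale == 0.0f)
        scale = 1.0f;                    // guard against constant rows
    zp = -*mn / scale;                   // zero-point in the quantized domain
    for (std::size_t i = 0; i < S; ++i) {
        const float q = std::round(src[i] / scale + zp);
        dst[i] = static_cast<uint8_t>(std::clamp(q, 0.0f, 255.0f));
    }
}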
@@ -194,6 +194,17 @@ TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeDynamicQuantiz
     ASSERT_EQ(groupSize, 0);
 }

+TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckAccuracyModeKVCachePrecision) {
+    ov::Core core;
+
+    ASSERT_NO_THROW(core.set_property(deviceName, ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)));
+    ov::CompiledModel compiledModel = core.compile_model(model, deviceName);
+
+    auto kv_cache_precision_value = ov::element::undefined;
+    ASSERT_NO_THROW(kv_cache_precision_value = compiledModel.get_property(ov::hint::kv_cache_precision));
+    ASSERT_EQ(kv_cache_precision_value, ov::element::f32);
+}
+
 const auto bf16_if_can_be_emulated = ov::with_cpu_x86_avx512_core() ? ov::element::bf16 : ov::element::f32;

 TEST_F(OVClassConfigTestCPU, smoke_CpuExecNetworkCheckExecutionModeIsAvailableInCoreAndModel) {
@@ -2,6 +2,7 @@
 // SPDX-License-Identifier: Apache-2.0
 //

+#include "openvino/core/type/float16.hpp"
 #include "openvino/opsets/opset13.hpp"
 #include "openvino/pass/manager.hpp"
 #include "transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp"
@@ -207,6 +208,10 @@ class ConcatSDPTransposeTestBase : public testing::WithParamInterface<ConcatSDPT
             ov::Tensor t{ov::element::f32, shape};
             strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
             inputs.insert({param, t});
+        } else if (param->get_element_type() == ov::element::f16) {
+            ov::Tensor t{ov::element::f16, shape};
+            strided_iota(static_cast<ov::float16*>(t.data()), t.get_size(), val, 0.1f);
+            inputs.insert({param, t});
         } else {
             ASSERT_TRUE(param->get_element_type() == ov::element::bf16);
             ov::Tensor t{ov::element::bf16, shape};
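strided_iota is a helper defined elsewhere in this test file; judging from the call sites above, it presumably fills a buffer with an arithmetic sequence, roughly:

#include <cstddef>

// Assumed shape of the helper (it is not shown in this diff): fill n elements
// starting at val, stepping by stride, converted to the element type T.
template <typename T>
void strided_iota(T* first, std::size_t n, float val, float stride) {
    for (std::size_t i = 0; i < n; ++i)
        first[i] = static_cast<T>(val + stride * i);
}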
Expand Down Expand Up @@ -365,6 +370,15 @@ class ConcatSDPTransposeTestSetState : public ConcatSDPTransposeTestBase {
}
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
function = model;
// on spr, all kvccache precision will be covered and all paths for get/set_state will be tested
auto input_type = model->get_parameters()[0]->get_element_type();
if (input_type == ov::element::f32) {
configuration[ov::hint::kv_cache_precision.name()] = "f32";
} else if (input_type == ov::element::bf16) {
configuration[ov::hint::kv_cache_precision.name()] = "bf16";
} else {
configuration[ov::hint::kv_cache_precision.name()] = "u8";
}
prepare();
std::vector<ov::Tensor> outputs;
// case 1: initialization + pastkv reaches limitation, remove some state
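A side note on the hunk above: the test stores plain strings in the configuration map, which works because ov::Any converts strings to ov::element::Type on read. A small sketch of that mechanism (my illustration, not test code):

#include <cassert>
#include <openvino/openvino.hpp>

int main() {
    // String values stored in the configuration map convert to
    // ov::element::Type when read back, which is why the test above
    // can assign "f32"/"bf16"/"u8" directly.
    ov::AnyMap configuration;
    configuration[ov::hint::kv_cache_precision.name()] = "u8";
    auto prec = configuration.at(ov::hint::kv_cache_precision.name()).as<ov::element::Type>();
    assert(prec == ov::element::u8);
}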
@@ -407,6 +421,15 @@ class ConcatSDPTransposeTestSetState : public ConcatSDPTransposeTestBase {

 TEST_P(ConcatSDPTransposeTestSetState, CompareWithRefs) {
     SKIP_IF_CURRENT_TEST_IS_DISABLED();
+    ElementType inType;
+    InputShapeAndTransposeOrder inputShapeAndOrders;
+    bool hasShapeOf;
+    std::tie(inType, inputShapeAndOrders, hasShapeOf) = this->GetParam();
+
+    // skip bf16 tests on platforms without bf16 support
+    if (inType == ElementType::bf16 && !ov::with_cpu_x86_bfloat16())
+        GTEST_SKIP();
+
     auto actualOutputs = run_test(function);
     CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
     CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
Expand Down Expand Up @@ -438,7 +461,7 @@ const std::vector<InputShapeAndTransposeOrder> inputShapeAndReordersSetState = {

INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTransposeTestSetState,
ConcatSDPTransposeTestSetState,
::testing::Combine(::testing::Values(ElementType::f32),
::testing::Combine(::testing::Values(ElementType::f32, ElementType::bf16, ElementType::f16),
::testing::ValuesIn(inputShapeAndReordersSetState),
::testing::Values(false)),
ConcatSDPTransposeTest::getTestCaseName);
