diff --git a/src/target/parsers/aprofile.cc b/src/target/parsers/aprofile.cc index 907e0cae72d27..f84c7485a0180 100644 --- a/src/target/parsers/aprofile.cc +++ b/src/target/parsers/aprofile.cc @@ -111,7 +111,8 @@ static TargetFeatures GetFeatures(TargetJSON target) { {"has_sve", Bool(has_feature("sve"))}, {"has_dotprod", Bool(has_feature("dotprod"))}, {"has_matmul_i8", Bool(has_feature("i8mm"))}, - {"has_fp16_simd", Bool(has_feature("fullfp16"))}}; + {"has_fp16_simd", Bool(has_feature("fullfp16"))}, + {"has_sme", Bool(has_feature("sme"))}}; #endif LOG(WARNING) << "Cannot parse Arm(R)-based target features without LLVM support."; diff --git a/tests/cpp/target/parsers/aprofile_test.cc b/tests/cpp/target/parsers/aprofile_test.cc index a134e162fc2d6..d329a9b958ad3 100644 --- a/tests/cpp/target/parsers/aprofile_test.cc +++ b/tests/cpp/target/parsers/aprofile_test.cc @@ -38,6 +38,7 @@ static float defaultI8MM = 8.6; static float optionalI8MM[] = {8.2, 8.3, 8.4, 8.5}; static float defaultDotProd = 8.4; static float optionalDotProd[] = {8.2, 8.3}; +static float optionalSME[] = {9.2, 9.3}; static bool CheckArchitectureAvailability() { #if TVM_LLVM_VERSION > 120 @@ -405,6 +406,21 @@ TEST(AProfileParserInvalid, LLVMUnsupportedArchitecture) { } } +using AProfileOptionalSME = AProfileParserTestWithParam; +TEST_P(AProfileOptionalSME, OptionalSMESupport) { + const std::string arch_attr = "+v9a"; + + TargetJSON target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr}); + TargetFeatures features = Downcast(target.at("features")); + ASSERT_TRUE(IsArch(target)); + ASSERT_FALSE(Downcast(features.at("has_sme"))); + + target = ParseTargetWithAttrs("", "aarch64-arm-none-eabi", {arch_attr, "+sme"}); + features = Downcast(target.at("features")); + ASSERT_TRUE(IsArch(target)); + ASSERT_TRUE(Downcast(features.at("has_sme"))); +} + INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalI8MM, ::testing::ValuesIn(optionalI8MM)); INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalDotProd, ::testing::ValuesIn(optionalDotProd)); @@ -412,6 +428,7 @@ INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalSVE, ::testing::Values(8.0, 8.1, 8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9)); INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalFP16, ::testing::Values(8.2, 8.3, 8.4, 8.5, 8.6, 8.7, 8.8, 8.9)); +INSTANTIATE_TEST_SUITE_P(AProfileParser, AProfileOptionalSME, ::testing::ValuesIn(optionalSME)); } // namespace aprofile } // namespace parsers