From 0016d19c740e46dc8443710e9bee0151908b4410 Mon Sep 17 00:00:00 2001 From: Aleksey Seren Date: Mon, 14 Feb 2022 17:57:53 +0700 Subject: [PATCH] Fix crash if catalog has less then tree unique segments. fix https://github.com/brave/brave-browser/issues/21052 --- .../bandits/epsilon_greedy_bandit_model.cc | 11 ++++++++--- .../epsilon_greedy_bandit_model_unittest.cc | 16 ++++++++++++++++ .../internal/segments/segments_json_reader.cc | 15 +++++++-------- 3 files changed, 31 insertions(+), 11 deletions(-) diff --git a/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model.cc b/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model.cc index 95fb33647225..3f7b00827de6 100644 --- a/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model.cc +++ b/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model.cc @@ -79,6 +79,9 @@ SegmentList GetEligibleSegments() { EpsilonGreedyBanditArmMap GetEligibleArms( const EpsilonGreedyBanditArmMap& arms) { const SegmentList eligible_segments = GetEligibleSegments(); + if (eligible_segments.empty()) { + return {}; + } EpsilonGreedyBanditArmMap eligible_arms; @@ -119,7 +122,7 @@ ArmList GetTopArms(const ArmBucketList& buckets, const size_t count) { ArmList arms = bucket.second; if (arms.size() > available_arms) { // Sample without replacement - base::RandomShuffle(begin(arms), end(arms)); + base::RandomShuffle(std::begin(arms), std::end(arms)); arms.resize(available_arms); } @@ -136,8 +139,10 @@ SegmentList ExploreSegments(const EpsilonGreedyBanditArmMap& arms) { segments.push_back(arm.first); } - base::RandomShuffle(begin(segments), end(segments)); - segments.resize(kTopArmCount); + if (segments.size() > kTopArmCount) { + base::RandomShuffle(std::begin(segments), std::end(segments)); + segments.resize(kTopArmCount); + } BLOG(2, "Exploring epsilon greedy bandit segments:"); for (const auto& segment : segments) { diff --git a/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model_unittest.cc b/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model_unittest.cc index ce2ed632a34e..e15888a3e123 100644 --- a/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model_unittest.cc +++ b/vendor/bat-native-ads/src/bat/ads/internal/ad_serving/ad_targeting/models/behavioral/bandits/epsilon_greedy_bandit_model_unittest.cc @@ -43,6 +43,22 @@ TEST_F(BatAdsEpsilonGreedyBanditModelTest, EXPECT_TRUE(segments.empty()); } +TEST_F(BatAdsEpsilonGreedyBanditModelTest, EligableSegmentsAreEmpty) { + // Arrange + base::test::ScopedFeatureList scoped_feature_list; + scoped_feature_list.InitAndEnableFeatureWithParameters( + features::kEpsilonGreedyBandit, {{"epsilon_value", "0.5"}}); + + processor::EpsilonGreedyBandit processor; + + // Act + EpsilonGreedyBandit model; + const SegmentList segments = model.GetSegments(); + + // Assert + EXPECT_TRUE(segments.empty()); +} + TEST_F(BatAdsEpsilonGreedyBanditModelTest, GetSegmentsIfNeverProcessed) { // Arrange SaveAllSegments(); diff --git a/vendor/bat-native-ads/src/bat/ads/internal/segments/segments_json_reader.cc b/vendor/bat-native-ads/src/bat/ads/internal/segments/segments_json_reader.cc index ab950222dc1a..ac8c2ccddb07 100644 --- a/vendor/bat-native-ads/src/bat/ads/internal/segments/segments_json_reader.cc +++ b/vendor/bat-native-ads/src/bat/ads/internal/segments/segments_json_reader.cc @@ -6,7 +6,6 @@ #include "bat/ads/internal/segments/segments_json_reader.h" #include "base/json/json_reader.h" -#include "base/notreached.h" #include "base/values.h" #include "third_party/abseil-cpp/absl/types/optional.h" @@ -14,26 +13,26 @@ namespace ads { namespace JSONReader { SegmentList ReadSegments(const std::string& json) { - SegmentList segments; - absl::optional value = base::JSONReader::Read(json); if (!value) { - return segments; + return {}; } base::ListValue* list = nullptr; if (!value->GetAsList(&list)) { - return segments; + return {}; } + SegmentList segments; for (const auto& element : list->GetList()) { if (!element.is_string()) { - NOTREACHED(); - continue; + return {}; } const std::string segment = element.GetString(); - DCHECK(!segment.empty()); + if (segment.empty()) { + return {}; + } segments.push_back(segment); }