AI Chat: introduce freemium model concept (#21398)
* AI Chat: introduce freemium model concept

* AI Chat: Add mixtral model

* AI Chat: the premium default also becomes leo-expanded

* AI Chat: non-premium users on freemium models get a rate-limit error that prompts them to switch to a basic (free) model as a secondary action

* fix format from chromium upgrade

* AI Chat: introduce "is_freemium_available" boolean feature param

If true, certain freemium models are available to non-premium users. If false, those models are premium-only.
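
For orientation, here is a minimal sketch of how a boolean feature param like this is typically declared in Chromium-style code. It is illustrative only: the real declaration lives in ai_chat's features files (not quoted in the hunks below), the parent feature object and the default value here are assumptions, and only the accessor features::kFreemiumAvailable.Get() is confirmed by the diff.

// Sketch only -- not the actual declaration from this commit.
#include "base/feature_list.h"
#include "base/metrics/field_trial_params.h"

namespace ai_chat::features {

// Assumed parent feature; the real feature object may differ.
BASE_FEATURE(kAIChat, "AIChat", base::FEATURE_ENABLED_BY_DEFAULT);

// "is_freemium_available": if true, freemium models are offered to
// non-premium users as well; if false, those models stay premium-only.
// The default value below is illustrative.
const base::FeatureParam<bool> kFreemiumAvailable{
    &kAIChat, "is_freemium_available", true};

}  // namespace ai_chat::features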

* AI Chat: migrate default model key pref from "chat-default" to "chat-basic"
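
The migration code itself is not in the hunks quoted below; the following is a hypothetical sketch, assuming a user-set value is rewritten in place. The helper name and placement are invented for illustration; prefs::kDefaultModelKey is the real pref key used throughout the diff.

// Hypothetical sketch -- not the migration code shipped in this commit.
#include "base/values.h"
#include "components/prefs/pref_service.h"

namespace ai_chat::prefs {

void MigrateDefaultModelKeyPref(PrefService* pref_service) {
  // Only rewrite a value the user explicitly set; the recommended default
  // is handled separately (see SetDefaultPrefValue in conversation_driver.cc).
  const base::Value* user_value =
      pref_service->GetUserPrefValue(kDefaultModelKey);
  if (user_value && user_value->is_string() &&
      user_value->GetString() == "chat-default") {
    pref_service->SetString(kDefaultModelKey, "chat-basic");
  }
}

}  // namespace ai_chat::prefs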

* AI Chat: reduce Claude page content character length so that token-length mismatches occur less often
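
Rough arithmetic behind this change, inferred from the comment block and values visible in models.cc below (a reading of the diff, not an authoritative statement): the Claude budget reserves about 50k tokens for page content, and the previous 200,000-character cap implies roughly 4 characters per token; lowering the cap to 180,000 characters targets about 45,000 tokens, leaving headroom when real text tokenizes at fewer than 4 characters per token.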

* AI Chat: model menu uses model display name instead of API name
petemill committed Jan 16, 2024
1 parent 26ffe8c commit 4b8584b
Showing 24 changed files with 268 additions and 139 deletions.
26 changes: 9 additions & 17 deletions browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc
@@ -147,26 +147,18 @@ void BraveLeoAssistantHandler::HandleResetLeoData(
}

void BraveLeoAssistantHandler::HandleGetModels(const base::Value::List& args) {
std::vector<ai_chat::mojom::ModelPtr> models(
ai_chat::kAllModelKeysDisplayOrder.size());
// Ensure we return only in intended display order
std::transform(ai_chat::kAllModelKeysDisplayOrder.cbegin(),
ai_chat::kAllModelKeysDisplayOrder.cend(), models.begin(),
[](auto& model_key) {
auto model_match = ai_chat::kAllModels.find(model_key);
DCHECK(model_match != ai_chat::kAllModels.end());
return model_match->second.Clone();
});
auto& models = ai_chat::GetAllModels();
base::Value::List models_list;
for (auto& model : models) {
base::Value::Dict dict;
dict.Set("key", model->key);
dict.Set("name", model->name);
dict.Set("display_name", model->display_name);
dict.Set("display_maker", model->display_maker);
dict.Set("engine_type", static_cast<int>(model->engine_type));
dict.Set("category", static_cast<int>(model->category));
dict.Set("is_premium", model->is_premium);
dict.Set("key", model.key);
dict.Set("name", model.name);
dict.Set("display_name", model.display_name);
dict.Set("display_maker", model.display_maker);
dict.Set("engine_type", static_cast<int>(model.engine_type));
dict.Set("category", static_cast<int>(model.category));
dict.Set("is_premium",
model.access == ai_chat::mojom::ModelAccess::PREMIUM);
models_list.Append(std::move(dict));
}

@@ -76,8 +76,8 @@ const char16_t kEnableNftDiscoveryLearnMoreUrl[] =
u"https://github.com/brave/brave-browser/wiki/"
u"NFT-Discovery";

void BraveAddCommonStrings(content::WebUIDataSource* html_source,
Profile* profile) {
void BraveAddCommonStrings(content::WebUIDataSource *html_source,
Profile *profile) {
webui::LocalizedString localized_strings[] = {
{"importExtensions", IDS_SETTINGS_IMPORT_EXTENSIONS_CHECKBOX},
{"importPayments", IDS_SETTINGS_IMPORT_PAYMENTS_CHECKBOX},
@@ -194,7 +194,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source,
IDS_SETTINGS_APPEARANCE_SETTINGS_SIDEBAR_ENABLED_DESC},
{"appearanceSettingsSidebarDisabledDesc",
IDS_SETTINGS_APPEARANCE_SETTINGS_SIDEBAR_DISABLED_DESC},
#endif // defined(TOOLKIT_VIEWS)
#endif // defined(TOOLKIT_VIEWS)
#if BUILDFLAG(ENABLE_BRAVE_VPN)
{"showBraveVPNButton", IDS_SETTINGS_SHOW_VPN_BUTTON},
{"showBraveVPNButtonSubLabel", IDS_SETTINGS_SHOW_VPN_BUTTON_SUB_LABEL},
@@ -394,7 +394,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source,
{"braveLeoAssistantModelSelectionLabel",
IDS_SETTINGS_LEO_ASSISTANT_MODEL_SELECTION_LABEL},
{"braveLeoModelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT},
{"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
{"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
@@ -743,15 +743,15 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source,
l10n_util::GetStringFUTF16(
IDS_SETTINGS_BRAVE_SHIELDS_DEFAULTS_DESCRIPTION_2,
kBraveUIRewardsURL));
} // NOLINT(readability/fn_size)
} // NOLINT(readability/fn_size)

void BraveAddResources(content::WebUIDataSource* html_source,
Profile* profile) {
void BraveAddResources(content::WebUIDataSource *html_source,
Profile *profile) {
BraveSettingsUI::AddResources(html_source, profile);
}

void BraveAddAboutStrings(content::WebUIDataSource* html_source,
Profile* profile) {
void BraveAddAboutStrings(content::WebUIDataSource *html_source,
Profile *profile) {
std::u16string license = l10n_util::GetStringFUTF16(
IDS_BRAVE_VERSION_UI_LICENSE, kBraveLicenseUrl,
base::ASCIIToUTF16(chrome::kChromeUICreditsURL),
Expand All @@ -762,17 +762,17 @@ void BraveAddAboutStrings(content::WebUIDataSource* html_source,
html_source->AddString("aboutProductLicense", license);
}

void BraveAddSyncStrings(content::WebUIDataSource* html_source) {
void BraveAddSyncStrings(content::WebUIDataSource *html_source) {
std::u16string passphraseDecryptionErrorMessage = l10n_util::GetStringFUTF16(
IDS_BRAVE_SYNC_PASSPHRASE_DECRYPTION_ERROR_MESSAGE, kBraveSyncGuideUrl);
html_source->AddString("braveSyncPassphraseDecryptionErrorMessage",
passphraseDecryptionErrorMessage);
}

} // namespace
} // namespace

void BraveAddLocalizedStrings(content::WebUIDataSource* html_source,
Profile* profile) {
void BraveAddLocalizedStrings(content::WebUIDataSource *html_source,
Profile *profile) {
BraveAddCommonStrings(html_source, profile);
BraveAddResources(html_source, profile);
BraveAddAboutStrings(html_source, profile);
@@ -942,4 +942,4 @@ void BraveAddLocalizedStrings(content::WebUIDataSource* html_source,
#endif
}

} // namespace settings
} // namespace settings
11 changes: 7 additions & 4 deletions components/ai_chat/core/browser/constants.cc
@@ -23,12 +23,15 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"errorNetworkLabel", IDS_CHAT_UI_ERROR_NETWORK},
{"errorRateLimit", IDS_CHAT_UI_ERROR_RATE_LIMIT},
{"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL},
{"introMessage-chat-default", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT},
{"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC},
{"introMessage-chat-leo-expanded",
IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED},
{"introMessage-chat-claude-instant",
IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_CLAUDE_INSTANT},
{"modelNameSyntax", IDS_CHAT_UI_MODEL_NAME_SYNTAX},
{"modelFreemiumLabelNonPremium",
IDS_CHAT_UI_MODEL_FREEMIUM_LABEL_NON_PREMIUM},
{"modelFreemiumLabelPremium", IDS_CHAT_UI_MODEL_FREEMIUM_LABEL_PREMIUM},
{"modelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT},
{"menuNewChat", IDS_CHAT_UI_MENU_NEW_CHAT},
{"menuGoPremium", IDS_CHAT_UI_MENU_GO_PREMIUM},
@@ -44,8 +47,8 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"premiumFeature_2", IDS_CHAT_UI_PREMIUM_FEATURE_2},
{"premiumLabel", IDS_CHAT_UI_PREMIUM_LABEL},
{"premiumPricing", IDS_CHAT_UI_PREMIUM_PRICING},
{"switchToDefaultModelButtonLabel",
IDS_CHAT_UI_SWITCH_TO_DEFAULT_MODEL_BUTTON_LABEL},
{"switchToBasicModelButtonLabel",
IDS_CHAT_UI_SWITCH_TO_BASIC_MODEL_BUTTON_LABEL},
{"dismissButtonLabel", IDS_CHAT_UI_DISMISS_BUTTON_LABEL},
{"unlockPremiumTitle", IDS_CHAT_UI_UNLOCK_PREMIUM_TITLE},
{"premiumFeature_1_desc", IDS_CHAT_UI_PREMIUM_FEATURE_1_DESC},
@@ -74,7 +77,7 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"optionOther", IDS_CHAT_UI_OPTION_OTHER},
{"feedbackError", IDS_CHAT_UI_FEEDBACK_SUBMIT_ERROR},
{"ratingError", IDS_CHAT_UI_RATING_ERROR},
{"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
{"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
53 changes: 29 additions & 24 deletions components/ai_chat/core/browser/conversation_driver.cc
@@ -59,15 +59,18 @@ ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
base::BindRepeating(&ConversationDriver::OnUserOptedIn,
weak_ptr_factory_.GetWeakPtr()));

// Engines and model names are selectable per conversation, not static.
// Start with default from pref value but only if user set. We can't rely on
// actual default pref value since we should vary if user is premium or not.
// Model choice is selectable per conversation, not global.
// Start with the default from the pref value. If the user is premium, the
// premium model differs from the non-premium default, and the user hasn't
// customized the model pref, then switch the user to the premium default.
// TODO(petemill): When we have an event for premium status changed, and a
// profile service for AIChat, then we can call
// |pref_service_->SetDefaultPrefValue| when the user becomes premium. With
// that, we'll be able to simply call GetString(prefs::kDefaultModelKey) and
// not vary on premium status.
if (!pref_service_->GetUserPrefValue(prefs::kDefaultModelKey)) {
// not have to fetch premium status.
if (!pref_service_->GetUserPrefValue(prefs::kDefaultModelKey) &&
features::kAIModelsPremiumDefaultKey.Get() !=
features::kAIModelsDefaultKey.Get()) {
credential_manager_->GetPremiumStatus(base::BindOnce(
[](ConversationDriver* instance, mojom::PremiumStatus status) {
instance->last_premium_status_ = status;
Expand All @@ -76,18 +79,18 @@ ConversationDriver::ConversationDriver(raw_ptr<PrefService> pref_service,
return;
}
// Use default premium model for this instance
instance->ChangeModel(kModelsPremiumDefaultKey);
instance->ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
// Make sure default model reflects premium status
const auto* current_default =
instance->pref_service_
->GetDefaultPrefValue(prefs::kDefaultModelKey)
->GetIfString();

if (current_default &&
*current_default != kModelsPremiumDefaultKey) {
*current_default != features::kAIModelsPremiumDefaultKey.Get()) {
instance->pref_service_->SetDefaultPrefValue(
prefs::kDefaultModelKey,
base::Value(kModelsPremiumDefaultKey));
base::Value(features::kAIModelsPremiumDefaultKey.Get()));
}
},
// Unretained is ok as credential manager is owned by this class,
@@ -118,16 +121,19 @@ ConversationDriver::~ConversationDriver() = default;
void ConversationDriver::ChangeModel(const std::string& model_key) {
DCHECK(!model_key.empty());
// Check that the key exists
if (kAllModels.find(model_key) == kAllModels.end()) {
auto* new_model = GetModel(model_key);
if (!new_model) {
NOTREACHED() << "No matching model found for key: " << model_key;
return;
}
model_key_ = model_key;
model_key_ = new_model->key;
InitEngine();
}

const mojom::Model& ConversationDriver::GetCurrentModel() {
return kAllModels.find(model_key_)->second;
auto* model = GetModel(model_key_);
DCHECK(model);
return *model;
}

const std::vector<ConversationTurn>& ConversationDriver::GetConversationHistory() {
Expand All @@ -145,33 +151,32 @@ void ConversationDriver::OnConversationActiveChanged(bool is_conversation_active

void ConversationDriver::InitEngine() {
DCHECK(!model_key_.empty());
auto model_match = kAllModels.find(model_key_);
auto* model = GetModel(model_key_);
// Make sure we get a valid model, defaulting to static default or first.
if (model_match == kAllModels.end()) {
if (!model) {
NOTREACHED() << "Model was not part of static model list";
// Use default
model_match = kAllModels.find(kModelsDefaultKey);
const auto is_found = model_match != kAllModels.end();
DCHECK(is_found);
if (!is_found) {
model_match = kAllModels.begin();
model = GetModel(features::kAIModelsDefaultKey.Get());
DCHECK(model);
if (!model) {
// Use first if given bad default value
model = &GetAllModels().at(0);
}
}

auto model = model_match->second;
// Model's key might not be the same as what we asked for (e.g. if the model
// no longer exists).
model_key_ = model.key;
model_key_ = model->key;

// Engine enum on model to decide which one
if (model.engine_type == mojom::ModelEngineType::LLAMA_REMOTE) {
if (model->engine_type == mojom::ModelEngineType::LLAMA_REMOTE) {
VLOG(1) << "Started AI engine: llama";
engine_ = std::make_unique<EngineConsumerLlamaRemote>(
model, url_loader_factory_, credential_manager_.get());
*model, url_loader_factory_, credential_manager_.get());
} else {
VLOG(1) << "Started AI engine: claude";
engine_ = std::make_unique<EngineConsumerClaudeRemote>(
model, url_loader_factory_, credential_manager_.get());
*model, url_loader_factory_, credential_manager_.get());
}

// Pending requests have been deleted along with the model engine
@@ -719,7 +724,7 @@ void ConversationDriver::OnPremiumStatusReceived(
if (last_premium_status_ != premium_status &&
premium_status == mojom::PremiumStatus::Active) {
// Change model if we haven't already
ChangeModel(kModelsPremiumDefaultKey);
ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
}
last_premium_status_ = premium_status;
std::move(parent_callback).Run(premium_status);
11 changes: 2 additions & 9 deletions components/ai_chat/core/browser/engine/engine_consumer_claude.cc
@@ -125,18 +125,11 @@ EngineConsumerClaudeRemote::EngineConsumerClaudeRemote(
const mojom::Model& model,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
AIChatCredentialManager* credential_manager) {
// Allow specific model name to be overriden by feature flag
// TODO(petemill): verify premium status, or ensure server will verify even
// when given a model name override via cli flag param.
std::string model_name = ai_chat::features::kAIModelName.Get();
if (model_name.empty()) {
model_name = model.name;
}
DCHECK(!model_name.empty());
DCHECK(!model.name.empty());
base::flat_set<std::string_view> stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique<RemoteCompletionClient>(
model_name, stop_sequences, url_loader_factory, credential_manager);
model.name, stop_sequences, url_loader_factory, credential_manager);

max_page_content_length_ = model.max_page_content_length;
}
11 changes: 2 additions & 9 deletions components/ai_chat/core/browser/engine/engine_consumer_llama.cc
@@ -243,18 +243,11 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote(
const mojom::Model& model,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
AIChatCredentialManager* credential_manager) {
// Allow specific model name to be overriden by feature flag
// TODO(petemill): verify premium status, or ensure server will verify even
// when given a model name override via cli flag param.
std::string model_name = ai_chat::features::kAIModelName.Get();
if (model_name.empty()) {
model_name = model.name;
}
DCHECK(!model_name.empty());
DCHECK(!model.name.empty());
base::flat_set<std::string_view> stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique<RemoteCompletionClient>(
model_name, stop_sequences, url_loader_factory, credential_manager);
model.name, stop_sequences, url_loader_factory, credential_manager);

max_page_content_length_ = model.max_page_content_length;
}
53 changes: 32 additions & 21 deletions components/ai_chat/core/browser/models.cc
@@ -7,14 +7,17 @@

#include "brave/components/ai_chat/core/browser/models.h"

#include "base/no_destructor.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"

namespace ai_chat {

// When adding new models, especially for display, make sure to add the UI
// strings to ai_chat_ui_strings.grdp and ai_chat/core/constants.cc.
// This also applies for modifying keys, since some of the strings are based
// on the model key.
// on the model key. Also be sure to migrate prefs if changing or removing
// keys.

// Llama2 Token Allocation:
// - Llama2 has a context limit: tokens + max_new_tokens <= 4096
@@ -32,25 +35,33 @@ namespace ai_chat {
// - Reserved for page content: 100k / 2 = 50k tokens
// - Long conversation warning threshold: 100k * 0.80 = 80k tokens

const base::flat_map<std::string_view, mojom::Model> kAllModels = {
{"chat-default",
{"chat-default", "llama-2-13b-chat", "llama2 13b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, false,
9000, 9700}},
{"chat-leo-expanded",
{"chat-leo-expanded", "llama-2-70b-chat", "llama2 70b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, true,
9000, 9700}},
{"chat-claude-instant",
{"chat-claude-instant", "claude-instant-v1", "Claude Instant", "Anthropic",
mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, true,
200000, 320000}},
};

const std::vector<std::string_view> kAllModelKeysDisplayOrder = {
"chat-default",
"chat-leo-expanded",
"chat-claude-instant",
};
const std::vector<ai_chat::mojom::Model>& GetAllModels() {
static const auto kFreemiumAccess =
features::kFreemiumAvailable.Get() ? mojom::ModelAccess::BASIC_AND_PREMIUM
: mojom::ModelAccess::PREMIUM;
static const base::NoDestructor<std::vector<mojom::Model>> kModels({
{"chat-leo-expanded", "mixtral-8x7b-instruct", "Mixtral", "Mistral AI",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
kFreemiumAccess, 9000, 9700},
{"chat-claude-instant", "claude-instant-v1", "Claude Instant",
"Anthropic", mojom::ModelEngineType::CLAUDE_REMOTE,
mojom::ModelCategory::CHAT, kFreemiumAccess, 180000, 320000},
{"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
mojom::ModelAccess::BASIC, 9000, 9700},
});
return *kModels;
}

const ai_chat::mojom::Model* GetModel(std::string_view key) {
auto& models = GetAllModels();
auto match = std::find_if(
models.cbegin(), models.cend(),
[&key](const mojom::Model& item) { return item.key == key; });
if (match != models.cend()) {
return &*match;
}
return nullptr;
}

} // namespace ai_chat