From c74b8b7773958844417ce68227c60eb49c8ce910 Mon Sep 17 00:00:00 2001 From: Pete Miller Date: Fri, 15 Dec 2023 21:53:58 +1300 Subject: [PATCH] AI Chat: introduce freemium model concept --- .../brave_settings_leo_assistant_handler.cc | 4 +++- ...ave_settings_localized_strings_provider.cc | 2 +- components/ai_chat/core/browser/constants.cc | 4 ++-- .../core/browser/conversation_driver.cc | 10 +++++----- .../browser/engine/engine_consumer_claude.cc | 11 ++-------- .../browser/engine/engine_consumer_llama.cc | 11 ++-------- components/ai_chat/core/browser/models.cc | 20 +++++++++---------- components/ai_chat/core/browser/models.h | 4 +--- components/ai_chat/core/common/features.cc | 6 ++++-- components/ai_chat/core/common/features.h | 3 ++- .../ai_chat/core/common/mojom/ai_chat.mojom | 12 ++++++++++- components/ai_chat/core/common/pref_names.cc | 4 +++- .../components/feature_button_menu/index.tsx | 13 ++++++++++-- .../resources/page/components/main/index.tsx | 2 +- .../page/state/data-context-provider.tsx | 4 ++-- .../page/stories/components_panel.tsx | 15 ++++++++++++-- components/resources/ai_chat_ui_strings.grdp | 4 ++-- 17 files changed, 75 insertions(+), 54 deletions(-) diff --git a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc index 775b95394b23..0b75c37a2377 100644 --- a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc +++ b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc @@ -15,6 +15,7 @@ #include "brave/browser/ui/sidebar/sidebar_service_factory.h" #include "brave/components/ai_chat/core/browser/ai_chat_metrics.h" #include "brave/components/ai_chat/core/browser/models.h" +#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h" #include "brave/components/ai_chat/core/common/pref_names.h" #include "brave/components/sidebar/sidebar_item.h" @@ -172,7 +173,8 @@ void BraveLeoAssistantHandler::HandleGetModels(const base::Value::List& args) { dict.Set("display_maker", model->display_maker); dict.Set("engine_type", static_cast(model->engine_type)); dict.Set("category", static_cast(model->category)); - dict.Set("is_premium", model->is_premium); + dict.Set("is_premium", + model->access == ai_chat::mojom::ModelAccess::PREMIUM); models_list.Append(std::move(dict)); } diff --git a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc index b733ac131290..d57398b6ce1d 100644 --- a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc +++ b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc @@ -402,7 +402,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source, {"braveLeoAssistantModelSelectionLabel", IDS_SETTINGS_LEO_ASSISTANT_MODEL_SELECTION_LABEL}, {"braveLeoModelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT}, - {"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE}, + {"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE}, {"braveLeoModelSubtitle-chat-leo-expanded", IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE}, {"braveLeoModelSubtitle-chat-claude-instant", diff --git a/components/ai_chat/core/browser/constants.cc b/components/ai_chat/core/browser/constants.cc index 6530173542d2..ee6b65302591 100644 --- a/components/ai_chat/core/browser/constants.cc +++ b/components/ai_chat/core/browser/constants.cc @@ -23,7 +23,7 @@ base::span GetLocalizedStrings() { {"errorNetworkLabel", IDS_CHAT_UI_ERROR_NETWORK}, {"errorRateLimit", IDS_CHAT_UI_ERROR_RATE_LIMIT}, {"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL}, - {"introMessage-chat-default", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT}, + {"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC}, {"introMessage-chat-leo-expanded", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED}, {"introMessage-chat-claude-instant", @@ -74,7 +74,7 @@ base::span GetLocalizedStrings() { {"optionOther", IDS_CHAT_UI_OPTION_OTHER}, {"feedbackError", IDS_CHAT_UI_FEEDBACK_SUBMIT_ERROR}, {"ratingError", IDS_CHAT_UI_RATING_ERROR}, - {"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE}, + {"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE}, {"braveLeoModelSubtitle-chat-leo-expanded", IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE}, {"braveLeoModelSubtitle-chat-claude-instant", diff --git a/components/ai_chat/core/browser/conversation_driver.cc b/components/ai_chat/core/browser/conversation_driver.cc index 1cb293e049a3..b13857aaa9f4 100644 --- a/components/ai_chat/core/browser/conversation_driver.cc +++ b/components/ai_chat/core/browser/conversation_driver.cc @@ -87,7 +87,7 @@ ConversationDriver::ConversationDriver( return; } // Use default premium model for this instance - instance->ChangeModel(kModelsPremiumDefaultKey); + instance->ChangeModel(features::kAIModelsPremiumDefaultKey.Get()); // Make sure default model reflects premium status const auto* current_default = instance->pref_service_ @@ -95,10 +95,10 @@ ConversationDriver::ConversationDriver( ->GetIfString(); if (current_default && - *current_default != kModelsPremiumDefaultKey) { + *current_default != features::kAIModelsPremiumDefaultKey.Get()) { instance->pref_service_->SetDefaultPrefValue( prefs::kDefaultModelKey, - base::Value(kModelsPremiumDefaultKey)); + base::Value(features::kAIModelsPremiumDefaultKey.Get())); } }, // Unretained is ok as credential manager is owned by this class, @@ -189,7 +189,7 @@ void ConversationDriver::InitEngine() { if (model_match == kAllModels.end()) { NOTREACHED() << "Model was not part of static model list"; // Use default - model_match = kAllModels.find(kModelsDefaultKey); + model_match = kAllModels.find(features::kAIModelsDefaultKey.Get()); const auto is_found = model_match != kAllModels.end(); DCHECK(is_found); if (!is_found) { @@ -768,7 +768,7 @@ void ConversationDriver::OnPremiumStatusReceived( if (last_premium_status_ != premium_status && premium_status == mojom::PremiumStatus::Active) { // Change model if we haven't already - ChangeModel(kModelsPremiumDefaultKey); + ChangeModel(features::kAIModelsPremiumDefaultKey.Get()); } last_premium_status_ = premium_status; if (HasUserOptedIn()) { diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc index 6b92ed5f8a7c..99faf2fe5fdd 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc @@ -125,18 +125,11 @@ EngineConsumerClaudeRemote::EngineConsumerClaudeRemote( const mojom::Model& model, scoped_refptr url_loader_factory, AIChatCredentialManager* credential_manager) { - // Allow specific model name to be overriden by feature flag - // TODO(petemill): verify premium status, or ensure server will verify even - // when given a model name override via cli flag param. - std::string model_name = ai_chat::features::kAIModelName.Get(); - if (model_name.empty()) { - model_name = model.name; - } - DCHECK(!model_name.empty()); + DCHECK(!model.name.empty()); base::flat_set stop_sequences(kStopSequences.begin(), kStopSequences.end()); api_ = std::make_unique( - model_name, stop_sequences, url_loader_factory, credential_manager); + model.name, stop_sequences, url_loader_factory, credential_manager); max_page_content_length_ = model.max_page_content_length; } diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc index 6c2c225e7dac..fe15838cd3b2 100644 --- a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc +++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc @@ -243,18 +243,11 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote( const mojom::Model& model, scoped_refptr url_loader_factory, AIChatCredentialManager* credential_manager) { - // Allow specific model name to be overriden by feature flag - // TODO(petemill): verify premium status, or ensure server will verify even - // when given a model name override via cli flag param. - std::string model_name = ai_chat::features::kAIModelName.Get(); - if (model_name.empty()) { - model_name = model.name; - } - DCHECK(!model_name.empty()); + DCHECK(!model.name.empty()); base::flat_set stop_sequences(kStopSequences.begin(), kStopSequences.end()); api_ = std::make_unique( - model_name, stop_sequences, url_loader_factory, credential_manager); + model.name, stop_sequences, url_loader_factory, credential_manager); max_page_content_length_ = model.max_page_content_length; } diff --git a/components/ai_chat/core/browser/models.cc b/components/ai_chat/core/browser/models.cc index 0de03d5460fd..b71b43f503cc 100644 --- a/components/ai_chat/core/browser/models.cc +++ b/components/ai_chat/core/browser/models.cc @@ -33,24 +33,24 @@ namespace ai_chat { // - Long conversation warning threshold: 100k * 0.80 = 80k tokens const base::flat_map kAllModels = { - {"chat-default", - {"chat-default", "llama-2-13b-chat", "llama2 13b", "Meta", - mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, false, - 9000, 9700}}, + {"chat-basic", + {"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta", + mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, + mojom::ModelAccess::BASIC, 9000, 9700}}, {"chat-leo-expanded", {"chat-leo-expanded", "llama-2-70b-chat", "llama2 70b", "Meta", - mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, true, - 9000, 9700}}, + mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, + mojom::ModelAccess::PREMIUM, 9000, 9700}}, {"chat-claude-instant", {"chat-claude-instant", "claude-instant-v1", "Claude Instant", "Anthropic", - mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, true, - 200000, 320000}}, + mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, + mojom::ModelAccess::BASIC_AND_PREMIUM, 200000, 320000}}, }; const std::vector kAllModelKeysDisplayOrder = { - "chat-default", - "chat-leo-expanded", "chat-claude-instant", + "chat-basic", + "chat-leo-expanded", }; } // namespace ai_chat diff --git a/components/ai_chat/core/browser/models.h b/components/ai_chat/core/browser/models.h index 9492e6dfb6d7..6749cf29e3d9 100644 --- a/components/ai_chat/core/browser/models.h +++ b/components/ai_chat/core/browser/models.h @@ -9,13 +9,11 @@ #include #include "base/containers/flat_map.h" +#include "brave/components/ai_chat/core/common/features.h" #include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h" namespace ai_chat { -inline constexpr char kModelsDefaultKey[] = "chat-default"; -inline constexpr char kModelsPremiumDefaultKey[] = "chat-claude-instant"; - // All models that the user can choose for chat conversations. extern const base::flat_map kAllModels; // UI display order for models diff --git a/components/ai_chat/core/common/features.cc b/components/ai_chat/core/common/features.cc index 251762a41635..267cf8c61ccc 100644 --- a/components/ai_chat/core/common/features.cc +++ b/components/ai_chat/core/common/features.cc @@ -21,8 +21,10 @@ BASE_FEATURE(kAIChat, base::FEATURE_DISABLED_BY_DEFAULT #endif ); -const base::FeatureParam kAIModelName{&kAIChat, "ai_model_name", - ""}; +const base::FeatureParam kAIModelsDefaultKey{&kAIChat, "default_model", + "chat-claude-instant"}; +const base::FeatureParam kAIModelsPremiumDefaultKey{&kAIChat, "default_premium_model", + "chat-claude-instant"}; const base::FeatureParam kAIChatSSE{&kAIChat, "ai_chat_sse", true}; const base::FeatureParam kAITemperature{&kAIChat, "temperature", 0.2}; diff --git a/components/ai_chat/core/common/features.h b/components/ai_chat/core/common/features.h index 29bf90b9b8d3..bd44817904c3 100644 --- a/components/ai_chat/core/common/features.h +++ b/components/ai_chat/core/common/features.h @@ -14,7 +14,8 @@ namespace ai_chat::features { BASE_DECLARE_FEATURE(kAIChat); -extern const base::FeatureParam kAIModelName; +extern const base::FeatureParam kAIModelsDefaultKey; +extern const base::FeatureParam kAIModelsPremiumDefaultKey; extern const base::FeatureParam kAIChatSSE; extern const base::FeatureParam kAITemperature; diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom index 52b7e92cce5d..195e1fcca650 100644 --- a/components/ai_chat/core/common/mojom/ai_chat.mojom +++ b/components/ai_chat/core/common/mojom/ai_chat.mojom @@ -36,6 +36,15 @@ enum ModelCategory { CHAT, }; +enum ModelAccess { + // The model only has a single basic tier, accessible by any level + BASIC, + // The model has a basic tier and a more capable premium tier (a.k.a freemium) + BASIC_AND_PREMIUM, + // The model only has a premium tier + PREMIUM, +}; + enum PremiumStatus { Inactive, Active, @@ -96,7 +105,8 @@ struct Model { ModelEngineType engine_type; // user-facing category ModelCategory category; - bool is_premium; + // Which access level grants permission to use the model + ModelAccess access; // max limit to truncate page contents (measured in chars, not tokens) uint32 max_page_content_length; // max limit for overall conversation (measured in chars, not tokens) diff --git a/components/ai_chat/core/common/pref_names.cc b/components/ai_chat/core/common/pref_names.cc index 5892b6cfc74b..288a909a3c87 100644 --- a/components/ai_chat/core/common/pref_names.cc +++ b/components/ai_chat/core/common/pref_names.cc @@ -5,6 +5,7 @@ #include "brave/components/ai_chat/core/common/pref_names.h" +#include "brave/components/ai_chat/core/common/features.h" #include "components/prefs/pref_registry_simple.h" #include "components/prefs/pref_service.h" @@ -21,7 +22,7 @@ void RegisterProfilePrefs(PrefRegistrySimple* registry) { registry->RegisterTimePref(kLastAcceptedDisclaimer, {}); registry->RegisterBooleanPref(kBraveChatAutocompleteProviderEnabled, true); registry->RegisterBooleanPref(kUserDismissedPremiumPrompt, false); - registry->RegisterStringPref(kDefaultModelKey, "chat-default"); + registry->RegisterStringPref(kDefaultModelKey, features::kAIModelsDefaultKey.Get()); #if BUILDFLAG(IS_ANDROID) registry->RegisterBooleanPref(kBraveChatSubscriptionActiveAndroid, false); registry->RegisterStringPref(kBraveChatPurchaseTokenAndroid, ""); @@ -36,6 +37,7 @@ void RegisterProfilePrefsForMigration(PrefRegistrySimple* registry) { void MigrateProfilePrefs(PrefService* profile_prefs) { profile_prefs->ClearPref(kObseleteBraveChatAutoGenerateQuestions); + // TODO(petemill): migrate model key from "chat-default" to "chat-basic" } void RegisterLocalStatePrefs(PrefRegistrySimple* registry) { diff --git a/components/ai_chat/resources/page/components/feature_button_menu/index.tsx b/components/ai_chat/resources/page/components/feature_button_menu/index.tsx index 71adf4c8c649..8941cb40d19b 100644 --- a/components/ai_chat/resources/page/components/feature_button_menu/index.tsx +++ b/components/ai_chat/resources/page/components/feature_button_menu/index.tsx @@ -7,8 +7,9 @@ import * as React from 'react' import ButtonMenu from '@brave/leo/react/buttonMenu' import Button from '@brave/leo/react/button' import Icon from '@brave/leo/react/icon' +import Label from '@brave/leo/react/label' import { getLocale } from '$web-common/locale' -import getPageHandlerInstance from '../../api/page_handler' +import getPageHandlerInstance, * as mojom from '../../api/page_handler' import DataContext from '../../state/context' import styles from './style.module.scss' import classnames from '$web-common/classnames' @@ -55,7 +56,7 @@ export default function FeatureMenu() { {getLocale(`braveLeoModelSubtitle-${model.key}`)}

- {model.isPremium && ( + {model.access === mojom.ModelAccess.PREMIUM && ( )} + {model.access === mojom.ModelAccess.BASIC_AND_PREMIUM && ( + + )} ))} diff --git a/components/ai_chat/resources/page/components/main/index.tsx b/components/ai_chat/resources/page/components/main/index.tsx index 1fb60230ebff..143cd1821193 100644 --- a/components/ai_chat/resources/page/components/main/index.tsx +++ b/components/ai_chat/resources/page/components/main/index.tsx @@ -45,7 +45,7 @@ function Main() { hasAcceptedAgreement && !context.isPremiumStatusFetching && // Avoid flash of content !context.isPremiumUser && - context.currentModel?.isPremium + context.currentModel?.access === mojom.ModelAccess.PREMIUM const shouldShowPremiumSuggestionStandalone = hasAcceptedAgreement && diff --git a/components/ai_chat/resources/page/state/data-context-provider.tsx b/components/ai_chat/resources/page/state/data-context-provider.tsx index d96211634fda..330a2b791148 100644 --- a/components/ai_chat/resources/page/state/data-context-provider.tsx +++ b/components/ai_chat/resources/page/state/data-context-provider.tsx @@ -67,7 +67,7 @@ function DataContextProvider (props: DataContextProviderProps) { const isPremiumUser = premiumStatus !== undefined && premiumStatus !== mojom.PremiumStatus.Inactive const apiHasError = (currentError !== mojom.APIError.None) - const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.isPremium)) + const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.access === mojom.ModelAccess.PREMIUM)) const getConversationHistory = () => { getPageHandlerInstance() @@ -127,7 +127,7 @@ function DataContextProvider (props: DataContextProviderProps) { const switchToDefaultModel = () => { // Select the first non-premium model - const nonPremium = allModels.find(m => !m.isPremium) + const nonPremium = allModels.find(m => [mojom.ModelAccess.BASIC, mojom.ModelAccess.BASIC_AND_PREMIUM].includes(m.access)) if (!nonPremium) { console.error('Could not find a non-premium model!') return diff --git a/components/ai_chat/resources/page/stories/components_panel.tsx b/components/ai_chat/resources/page/stories/components_panel.tsx index ab86ec144452..b44695eed853 100644 --- a/components/ai_chat/resources/page/stories/components_panel.tsx +++ b/components/ai_chat/resources/page/stories/components_panel.tsx @@ -49,7 +49,7 @@ const MODELS: mojom.Model[] = [ displayMaker: 'Company', engineType: mojom.ModelEngineType.LLAMA_REMOTE, category: mojom.ModelCategory.CHAT, - isPremium: false, + access: mojom.ModelAccess.BASIC, maxPageContentLength: 10000, longConversationWarningCharacterLimit: 9700 }, @@ -60,7 +60,18 @@ const MODELS: mojom.Model[] = [ displayMaker: 'Company', engineType: mojom.ModelEngineType.LLAMA_REMOTE, category: mojom.ModelCategory.CHAT, - isPremium: true, + access: mojom.ModelAccess.PREMIUM, + maxPageContentLength: 10000, + longConversationWarningCharacterLimit: 9700 + }, + { + key: '3', + name: 'model-three-freemium', + displayName: 'Model Three', + displayMaker: 'Company', + engineType: mojom.ModelEngineType.LLAMA_REMOTE, + category: mojom.ModelCategory.CHAT, + access: mojom.ModelAccess.BASIC_AND_PREMIUM, maxPageContentLength: 10000, longConversationWarningCharacterLimit: 9700 } diff --git a/components/resources/ai_chat_ui_strings.grdp b/components/resources/ai_chat_ui_strings.grdp index 37f17ed391bd..3ca67f27588d 100644 --- a/components/resources/ai_chat_ui_strings.grdp +++ b/components/resources/ai_chat_ui_strings.grdp @@ -51,7 +51,7 @@ Retry - + Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 13B, a model created by Meta to be performant and applicable to many use cases. @@ -202,7 +202,7 @@ This conversation is too long and cannot continue. There may be other models available with which Leo is capable of maintaining accuracy for longer conversations. - + General purpose chat