From c74b8b7773958844417ce68227c60eb49c8ce910 Mon Sep 17 00:00:00 2001
From: Pete Miller
Date: Fri, 15 Dec 2023 21:53:58 +1300
Subject: [PATCH] AI Chat: introduce freemium model concept
---
.../brave_settings_leo_assistant_handler.cc | 4 +++-
...ave_settings_localized_strings_provider.cc | 2 +-
components/ai_chat/core/browser/constants.cc | 4 ++--
.../core/browser/conversation_driver.cc | 10 +++++-----
.../browser/engine/engine_consumer_claude.cc | 11 ++--------
.../browser/engine/engine_consumer_llama.cc | 11 ++--------
components/ai_chat/core/browser/models.cc | 20 +++++++++----------
components/ai_chat/core/browser/models.h | 4 +---
components/ai_chat/core/common/features.cc | 6 ++++--
components/ai_chat/core/common/features.h | 3 ++-
.../ai_chat/core/common/mojom/ai_chat.mojom | 12 ++++++++++-
components/ai_chat/core/common/pref_names.cc | 4 +++-
.../components/feature_button_menu/index.tsx | 13 ++++++++++--
.../resources/page/components/main/index.tsx | 2 +-
.../page/state/data-context-provider.tsx | 4 ++--
.../page/stories/components_panel.tsx | 15 ++++++++++++--
components/resources/ai_chat_ui_strings.grdp | 4 ++--
17 files changed, 75 insertions(+), 54 deletions(-)
diff --git a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc
index 775b95394b23..0b75c37a2377 100644
--- a/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc
+++ b/browser/ui/webui/settings/brave_settings_leo_assistant_handler.cc
@@ -15,6 +15,7 @@
#include "brave/browser/ui/sidebar/sidebar_service_factory.h"
#include "brave/components/ai_chat/core/browser/ai_chat_metrics.h"
#include "brave/components/ai_chat/core/browser/models.h"
+#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/ai_chat/core/common/pref_names.h"
#include "brave/components/sidebar/sidebar_item.h"
@@ -172,7 +173,8 @@ void BraveLeoAssistantHandler::HandleGetModels(const base::Value::List& args) {
dict.Set("display_maker", model->display_maker);
dict.Set("engine_type", static_cast(model->engine_type));
dict.Set("category", static_cast(model->category));
- dict.Set("is_premium", model->is_premium);
+ dict.Set("is_premium",
+ model->access == ai_chat::mojom::ModelAccess::PREMIUM);
models_list.Append(std::move(dict));
}
diff --git a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc
index b733ac131290..d57398b6ce1d 100644
--- a/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc
+++ b/browser/ui/webui/settings/brave_settings_localized_strings_provider.cc
@@ -402,7 +402,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source,
{"braveLeoAssistantModelSelectionLabel",
IDS_SETTINGS_LEO_ASSISTANT_MODEL_SELECTION_LABEL},
{"braveLeoModelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT},
- {"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
+ {"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
diff --git a/components/ai_chat/core/browser/constants.cc b/components/ai_chat/core/browser/constants.cc
index 6530173542d2..ee6b65302591 100644
--- a/components/ai_chat/core/browser/constants.cc
+++ b/components/ai_chat/core/browser/constants.cc
@@ -23,7 +23,7 @@ base::span GetLocalizedStrings() {
{"errorNetworkLabel", IDS_CHAT_UI_ERROR_NETWORK},
{"errorRateLimit", IDS_CHAT_UI_ERROR_RATE_LIMIT},
{"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL},
- {"introMessage-chat-default", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT},
+ {"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC},
{"introMessage-chat-leo-expanded",
IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED},
{"introMessage-chat-claude-instant",
@@ -74,7 +74,7 @@ base::span GetLocalizedStrings() {
{"optionOther", IDS_CHAT_UI_OPTION_OTHER},
{"feedbackError", IDS_CHAT_UI_FEEDBACK_SUBMIT_ERROR},
{"ratingError", IDS_CHAT_UI_RATING_ERROR},
- {"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
+ {"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
diff --git a/components/ai_chat/core/browser/conversation_driver.cc b/components/ai_chat/core/browser/conversation_driver.cc
index 1cb293e049a3..b13857aaa9f4 100644
--- a/components/ai_chat/core/browser/conversation_driver.cc
+++ b/components/ai_chat/core/browser/conversation_driver.cc
@@ -87,7 +87,7 @@ ConversationDriver::ConversationDriver(
return;
}
// Use default premium model for this instance
- instance->ChangeModel(kModelsPremiumDefaultKey);
+ instance->ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
// Make sure default model reflects premium status
const auto* current_default =
instance->pref_service_
@@ -95,10 +95,10 @@ ConversationDriver::ConversationDriver(
->GetIfString();
if (current_default &&
- *current_default != kModelsPremiumDefaultKey) {
+ *current_default != features::kAIModelsPremiumDefaultKey.Get()) {
instance->pref_service_->SetDefaultPrefValue(
prefs::kDefaultModelKey,
- base::Value(kModelsPremiumDefaultKey));
+ base::Value(features::kAIModelsPremiumDefaultKey.Get()));
}
},
// Unretained is ok as credential manager is owned by this class,
@@ -189,7 +189,7 @@ void ConversationDriver::InitEngine() {
if (model_match == kAllModels.end()) {
NOTREACHED() << "Model was not part of static model list";
// Use default
- model_match = kAllModels.find(kModelsDefaultKey);
+ model_match = kAllModels.find(features::kAIModelsDefaultKey.Get());
const auto is_found = model_match != kAllModels.end();
DCHECK(is_found);
if (!is_found) {
@@ -768,7 +768,7 @@ void ConversationDriver::OnPremiumStatusReceived(
if (last_premium_status_ != premium_status &&
premium_status == mojom::PremiumStatus::Active) {
// Change model if we haven't already
- ChangeModel(kModelsPremiumDefaultKey);
+ ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
}
last_premium_status_ = premium_status;
if (HasUserOptedIn()) {
diff --git a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc
index 6b92ed5f8a7c..99faf2fe5fdd 100644
--- a/components/ai_chat/core/browser/engine/engine_consumer_claude.cc
+++ b/components/ai_chat/core/browser/engine/engine_consumer_claude.cc
@@ -125,18 +125,11 @@ EngineConsumerClaudeRemote::EngineConsumerClaudeRemote(
const mojom::Model& model,
scoped_refptr url_loader_factory,
AIChatCredentialManager* credential_manager) {
- // Allow specific model name to be overriden by feature flag
- // TODO(petemill): verify premium status, or ensure server will verify even
- // when given a model name override via cli flag param.
- std::string model_name = ai_chat::features::kAIModelName.Get();
- if (model_name.empty()) {
- model_name = model.name;
- }
- DCHECK(!model_name.empty());
+ DCHECK(!model.name.empty());
base::flat_set stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique(
- model_name, stop_sequences, url_loader_factory, credential_manager);
+ model.name, stop_sequences, url_loader_factory, credential_manager);
max_page_content_length_ = model.max_page_content_length;
}
diff --git a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc
index 6c2c225e7dac..fe15838cd3b2 100644
--- a/components/ai_chat/core/browser/engine/engine_consumer_llama.cc
+++ b/components/ai_chat/core/browser/engine/engine_consumer_llama.cc
@@ -243,18 +243,11 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote(
const mojom::Model& model,
scoped_refptr url_loader_factory,
AIChatCredentialManager* credential_manager) {
- // Allow specific model name to be overriden by feature flag
- // TODO(petemill): verify premium status, or ensure server will verify even
- // when given a model name override via cli flag param.
- std::string model_name = ai_chat::features::kAIModelName.Get();
- if (model_name.empty()) {
- model_name = model.name;
- }
- DCHECK(!model_name.empty());
+ DCHECK(!model.name.empty());
base::flat_set stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique(
- model_name, stop_sequences, url_loader_factory, credential_manager);
+ model.name, stop_sequences, url_loader_factory, credential_manager);
max_page_content_length_ = model.max_page_content_length;
}
diff --git a/components/ai_chat/core/browser/models.cc b/components/ai_chat/core/browser/models.cc
index 0de03d5460fd..b71b43f503cc 100644
--- a/components/ai_chat/core/browser/models.cc
+++ b/components/ai_chat/core/browser/models.cc
@@ -33,24 +33,24 @@ namespace ai_chat {
// - Long conversation warning threshold: 100k * 0.80 = 80k tokens
const base::flat_map kAllModels = {
- {"chat-default",
- {"chat-default", "llama-2-13b-chat", "llama2 13b", "Meta",
- mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, false,
- 9000, 9700}},
+ {"chat-basic",
+ {"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta",
+ mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
+ mojom::ModelAccess::BASIC, 9000, 9700}},
{"chat-leo-expanded",
{"chat-leo-expanded", "llama-2-70b-chat", "llama2 70b", "Meta",
- mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, true,
- 9000, 9700}},
+ mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
+ mojom::ModelAccess::PREMIUM, 9000, 9700}},
{"chat-claude-instant",
{"chat-claude-instant", "claude-instant-v1", "Claude Instant", "Anthropic",
- mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, true,
- 200000, 320000}},
+ mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT,
+ mojom::ModelAccess::BASIC_AND_PREMIUM, 200000, 320000}},
};
const std::vector kAllModelKeysDisplayOrder = {
- "chat-default",
- "chat-leo-expanded",
"chat-claude-instant",
+ "chat-basic",
+ "chat-leo-expanded",
};
} // namespace ai_chat
diff --git a/components/ai_chat/core/browser/models.h b/components/ai_chat/core/browser/models.h
index 9492e6dfb6d7..6749cf29e3d9 100644
--- a/components/ai_chat/core/browser/models.h
+++ b/components/ai_chat/core/browser/models.h
@@ -9,13 +9,11 @@
#include
#include "base/containers/flat_map.h"
+#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"
namespace ai_chat {
-inline constexpr char kModelsDefaultKey[] = "chat-default";
-inline constexpr char kModelsPremiumDefaultKey[] = "chat-claude-instant";
-
// All models that the user can choose for chat conversations.
extern const base::flat_map kAllModels;
// UI display order for models
diff --git a/components/ai_chat/core/common/features.cc b/components/ai_chat/core/common/features.cc
index 251762a41635..267cf8c61ccc 100644
--- a/components/ai_chat/core/common/features.cc
+++ b/components/ai_chat/core/common/features.cc
@@ -21,8 +21,10 @@ BASE_FEATURE(kAIChat,
base::FEATURE_DISABLED_BY_DEFAULT
#endif
);
-const base::FeatureParam kAIModelName{&kAIChat, "ai_model_name",
- ""};
+const base::FeatureParam kAIModelsDefaultKey{&kAIChat, "default_model",
+ "chat-claude-instant"};
+const base::FeatureParam kAIModelsPremiumDefaultKey{&kAIChat, "default_premium_model",
+ "chat-claude-instant"};
const base::FeatureParam kAIChatSSE{&kAIChat, "ai_chat_sse", true};
const base::FeatureParam kAITemperature{&kAIChat, "temperature", 0.2};
diff --git a/components/ai_chat/core/common/features.h b/components/ai_chat/core/common/features.h
index 29bf90b9b8d3..bd44817904c3 100644
--- a/components/ai_chat/core/common/features.h
+++ b/components/ai_chat/core/common/features.h
@@ -14,7 +14,8 @@
namespace ai_chat::features {
BASE_DECLARE_FEATURE(kAIChat);
-extern const base::FeatureParam kAIModelName;
+extern const base::FeatureParam kAIModelsDefaultKey;
+extern const base::FeatureParam kAIModelsPremiumDefaultKey;
extern const base::FeatureParam kAIChatSSE;
extern const base::FeatureParam kAITemperature;
diff --git a/components/ai_chat/core/common/mojom/ai_chat.mojom b/components/ai_chat/core/common/mojom/ai_chat.mojom
index 52b7e92cce5d..195e1fcca650 100644
--- a/components/ai_chat/core/common/mojom/ai_chat.mojom
+++ b/components/ai_chat/core/common/mojom/ai_chat.mojom
@@ -36,6 +36,15 @@ enum ModelCategory {
CHAT,
};
+enum ModelAccess {
+ // The model only has a single basic tier, accessible by any level
+ BASIC,
+ // The model has a basic tier and a more capable premium tier (a.k.a freemium)
+ BASIC_AND_PREMIUM,
+ // The model only has a premium tier
+ PREMIUM,
+};
+
enum PremiumStatus {
Inactive,
Active,
@@ -96,7 +105,8 @@ struct Model {
ModelEngineType engine_type;
// user-facing category
ModelCategory category;
- bool is_premium;
+ // Which access level grants permission to use the model
+ ModelAccess access;
// max limit to truncate page contents (measured in chars, not tokens)
uint32 max_page_content_length;
// max limit for overall conversation (measured in chars, not tokens)
diff --git a/components/ai_chat/core/common/pref_names.cc b/components/ai_chat/core/common/pref_names.cc
index 5892b6cfc74b..288a909a3c87 100644
--- a/components/ai_chat/core/common/pref_names.cc
+++ b/components/ai_chat/core/common/pref_names.cc
@@ -5,6 +5,7 @@
#include "brave/components/ai_chat/core/common/pref_names.h"
+#include "brave/components/ai_chat/core/common/features.h"
#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/pref_service.h"
@@ -21,7 +22,7 @@ void RegisterProfilePrefs(PrefRegistrySimple* registry) {
registry->RegisterTimePref(kLastAcceptedDisclaimer, {});
registry->RegisterBooleanPref(kBraveChatAutocompleteProviderEnabled, true);
registry->RegisterBooleanPref(kUserDismissedPremiumPrompt, false);
- registry->RegisterStringPref(kDefaultModelKey, "chat-default");
+ registry->RegisterStringPref(kDefaultModelKey, features::kAIModelsDefaultKey.Get());
#if BUILDFLAG(IS_ANDROID)
registry->RegisterBooleanPref(kBraveChatSubscriptionActiveAndroid, false);
registry->RegisterStringPref(kBraveChatPurchaseTokenAndroid, "");
@@ -36,6 +37,7 @@ void RegisterProfilePrefsForMigration(PrefRegistrySimple* registry) {
void MigrateProfilePrefs(PrefService* profile_prefs) {
profile_prefs->ClearPref(kObseleteBraveChatAutoGenerateQuestions);
+ // TODO(petemill): migrate model key from "chat-default" to "chat-basic"
}
void RegisterLocalStatePrefs(PrefRegistrySimple* registry) {
diff --git a/components/ai_chat/resources/page/components/feature_button_menu/index.tsx b/components/ai_chat/resources/page/components/feature_button_menu/index.tsx
index 71adf4c8c649..8941cb40d19b 100644
--- a/components/ai_chat/resources/page/components/feature_button_menu/index.tsx
+++ b/components/ai_chat/resources/page/components/feature_button_menu/index.tsx
@@ -7,8 +7,9 @@ import * as React from 'react'
import ButtonMenu from '@brave/leo/react/buttonMenu'
import Button from '@brave/leo/react/button'
import Icon from '@brave/leo/react/icon'
+import Label from '@brave/leo/react/label'
import { getLocale } from '$web-common/locale'
-import getPageHandlerInstance from '../../api/page_handler'
+import getPageHandlerInstance, * as mojom from '../../api/page_handler'
import DataContext from '../../state/context'
import styles from './style.module.scss'
import classnames from '$web-common/classnames'
@@ -55,7 +56,7 @@ export default function FeatureMenu() {
{getLocale(`braveLeoModelSubtitle-${model.key}`)}
- {model.isPremium && (
+ {model.access === mojom.ModelAccess.PREMIUM && (
)}
+ {model.access === mojom.ModelAccess.BASIC_AND_PREMIUM && (
+
+ )}
))}
diff --git a/components/ai_chat/resources/page/components/main/index.tsx b/components/ai_chat/resources/page/components/main/index.tsx
index 1fb60230ebff..143cd1821193 100644
--- a/components/ai_chat/resources/page/components/main/index.tsx
+++ b/components/ai_chat/resources/page/components/main/index.tsx
@@ -45,7 +45,7 @@ function Main() {
hasAcceptedAgreement &&
!context.isPremiumStatusFetching && // Avoid flash of content
!context.isPremiumUser &&
- context.currentModel?.isPremium
+ context.currentModel?.access === mojom.ModelAccess.PREMIUM
const shouldShowPremiumSuggestionStandalone =
hasAcceptedAgreement &&
diff --git a/components/ai_chat/resources/page/state/data-context-provider.tsx b/components/ai_chat/resources/page/state/data-context-provider.tsx
index d96211634fda..330a2b791148 100644
--- a/components/ai_chat/resources/page/state/data-context-provider.tsx
+++ b/components/ai_chat/resources/page/state/data-context-provider.tsx
@@ -67,7 +67,7 @@ function DataContextProvider (props: DataContextProviderProps) {
const isPremiumUser = premiumStatus !== undefined && premiumStatus !== mojom.PremiumStatus.Inactive
const apiHasError = (currentError !== mojom.APIError.None)
- const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.isPremium))
+ const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.access === mojom.ModelAccess.PREMIUM))
const getConversationHistory = () => {
getPageHandlerInstance()
@@ -127,7 +127,7 @@ function DataContextProvider (props: DataContextProviderProps) {
const switchToDefaultModel = () => {
// Select the first non-premium model
- const nonPremium = allModels.find(m => !m.isPremium)
+ const nonPremium = allModels.find(m => [mojom.ModelAccess.BASIC, mojom.ModelAccess.BASIC_AND_PREMIUM].includes(m.access))
if (!nonPremium) {
console.error('Could not find a non-premium model!')
return
diff --git a/components/ai_chat/resources/page/stories/components_panel.tsx b/components/ai_chat/resources/page/stories/components_panel.tsx
index ab86ec144452..b44695eed853 100644
--- a/components/ai_chat/resources/page/stories/components_panel.tsx
+++ b/components/ai_chat/resources/page/stories/components_panel.tsx
@@ -49,7 +49,7 @@ const MODELS: mojom.Model[] = [
displayMaker: 'Company',
engineType: mojom.ModelEngineType.LLAMA_REMOTE,
category: mojom.ModelCategory.CHAT,
- isPremium: false,
+ access: mojom.ModelAccess.BASIC,
maxPageContentLength: 10000,
longConversationWarningCharacterLimit: 9700
},
@@ -60,7 +60,18 @@ const MODELS: mojom.Model[] = [
displayMaker: 'Company',
engineType: mojom.ModelEngineType.LLAMA_REMOTE,
category: mojom.ModelCategory.CHAT,
- isPremium: true,
+ access: mojom.ModelAccess.PREMIUM,
+ maxPageContentLength: 10000,
+ longConversationWarningCharacterLimit: 9700
+ },
+ {
+ key: '3',
+ name: 'model-three-freemium',
+ displayName: 'Model Three',
+ displayMaker: 'Company',
+ engineType: mojom.ModelEngineType.LLAMA_REMOTE,
+ category: mojom.ModelCategory.CHAT,
+ access: mojom.ModelAccess.BASIC_AND_PREMIUM,
maxPageContentLength: 10000,
longConversationWarningCharacterLimit: 9700
}
diff --git a/components/resources/ai_chat_ui_strings.grdp b/components/resources/ai_chat_ui_strings.grdp
index 37f17ed391bd..3ca67f27588d 100644
--- a/components/resources/ai_chat_ui_strings.grdp
+++ b/components/resources/ai_chat_ui_strings.grdp
@@ -51,7 +51,7 @@
Retry
-
+
Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 13B, a model created by Meta to be performant and applicable to many use cases.
@@ -202,7 +202,7 @@
This conversation is too long and cannot continue. There may be other models available with which Leo is capable of maintaining accuracy for longer conversations.
-
+
General purpose chat