Skip to content

Commit

Permalink
AI Chat: introduce freemium model concept
Browse files Browse the repository at this point in the history
  • Loading branch information
petemill committed Jan 9, 2024
1 parent 29c4e03 commit c74b8b7
Show file tree
Hide file tree
Showing 17 changed files with 75 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "brave/browser/ui/sidebar/sidebar_service_factory.h"
#include "brave/components/ai_chat/core/browser/ai_chat_metrics.h"
#include "brave/components/ai_chat/core/browser/models.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-shared.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom.h"
#include "brave/components/ai_chat/core/common/pref_names.h"
#include "brave/components/sidebar/sidebar_item.h"
Expand Down Expand Up @@ -172,7 +173,8 @@ void BraveLeoAssistantHandler::HandleGetModels(const base::Value::List& args) {
dict.Set("display_maker", model->display_maker);
dict.Set("engine_type", static_cast<int>(model->engine_type));
dict.Set("category", static_cast<int>(model->category));
dict.Set("is_premium", model->is_premium);
dict.Set("is_premium",
model->access == ai_chat::mojom::ModelAccess::PREMIUM);
models_list.Append(std::move(dict));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ void BraveAddCommonStrings(content::WebUIDataSource* html_source,
{"braveLeoAssistantModelSelectionLabel",
IDS_SETTINGS_LEO_ASSISTANT_MODEL_SELECTION_LABEL},
{"braveLeoModelCategory-chat", IDS_CHAT_UI_MODEL_CATEGORY_CHAT},
{"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
{"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
Expand Down
4 changes: 2 additions & 2 deletions components/ai_chat/core/browser/constants.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"errorNetworkLabel", IDS_CHAT_UI_ERROR_NETWORK},
{"errorRateLimit", IDS_CHAT_UI_ERROR_RATE_LIMIT},
{"retryButtonLabel", IDS_CHAT_UI_RETRY_BUTTON_LABEL},
{"introMessage-chat-default", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT},
{"introMessage-chat-basic", IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC},
{"introMessage-chat-leo-expanded",
IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED},
{"introMessage-chat-claude-instant",
Expand Down Expand Up @@ -74,7 +74,7 @@ base::span<const webui::LocalizedString> GetLocalizedStrings() {
{"optionOther", IDS_CHAT_UI_OPTION_OTHER},
{"feedbackError", IDS_CHAT_UI_FEEDBACK_SUBMIT_ERROR},
{"ratingError", IDS_CHAT_UI_RATING_ERROR},
{"braveLeoModelSubtitle-chat-default", IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE},
{"braveLeoModelSubtitle-chat-basic", IDS_CHAT_UI_CHAT_BASIC_SUBTITLE},
{"braveLeoModelSubtitle-chat-leo-expanded",
IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE},
{"braveLeoModelSubtitle-chat-claude-instant",
Expand Down
10 changes: 5 additions & 5 deletions components/ai_chat/core/browser/conversation_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,18 +87,18 @@ ConversationDriver::ConversationDriver(
return;
}
// Use default premium model for this instance
instance->ChangeModel(kModelsPremiumDefaultKey);
instance->ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
// Make sure default model reflects premium status
const auto* current_default =
instance->pref_service_
->GetDefaultPrefValue(prefs::kDefaultModelKey)
->GetIfString();

if (current_default &&
*current_default != kModelsPremiumDefaultKey) {
*current_default != features::kAIModelsPremiumDefaultKey.Get()) {
instance->pref_service_->SetDefaultPrefValue(
prefs::kDefaultModelKey,
base::Value(kModelsPremiumDefaultKey));
base::Value(features::kAIModelsPremiumDefaultKey.Get()));
}
},
// Unretained is ok as credential manager is owned by this class,
Expand Down Expand Up @@ -189,7 +189,7 @@ void ConversationDriver::InitEngine() {
if (model_match == kAllModels.end()) {
NOTREACHED() << "Model was not part of static model list";
// Use default
model_match = kAllModels.find(kModelsDefaultKey);
model_match = kAllModels.find(features::kAIModelsDefaultKey.Get());
const auto is_found = model_match != kAllModels.end();
DCHECK(is_found);
if (!is_found) {
Expand Down Expand Up @@ -768,7 +768,7 @@ void ConversationDriver::OnPremiumStatusReceived(
if (last_premium_status_ != premium_status &&
premium_status == mojom::PremiumStatus::Active) {
// Change model if we haven't already
ChangeModel(kModelsPremiumDefaultKey);
ChangeModel(features::kAIModelsPremiumDefaultKey.Get());
}
last_premium_status_ = premium_status;
if (HasUserOptedIn()) {
Expand Down
11 changes: 2 additions & 9 deletions components/ai_chat/core/browser/engine/engine_consumer_claude.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,11 @@ EngineConsumerClaudeRemote::EngineConsumerClaudeRemote(
const mojom::Model& model,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
AIChatCredentialManager* credential_manager) {
// Allow specific model name to be overriden by feature flag
// TODO(petemill): verify premium status, or ensure server will verify even
// when given a model name override via cli flag param.
std::string model_name = ai_chat::features::kAIModelName.Get();
if (model_name.empty()) {
model_name = model.name;
}
DCHECK(!model_name.empty());
DCHECK(!model.name.empty());
base::flat_set<std::string_view> stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique<RemoteCompletionClient>(
model_name, stop_sequences, url_loader_factory, credential_manager);
model.name, stop_sequences, url_loader_factory, credential_manager);

max_page_content_length_ = model.max_page_content_length;
}
Expand Down
11 changes: 2 additions & 9 deletions components/ai_chat/core/browser/engine/engine_consumer_llama.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,18 +243,11 @@ EngineConsumerLlamaRemote::EngineConsumerLlamaRemote(
const mojom::Model& model,
scoped_refptr<network::SharedURLLoaderFactory> url_loader_factory,
AIChatCredentialManager* credential_manager) {
// Allow specific model name to be overriden by feature flag
// TODO(petemill): verify premium status, or ensure server will verify even
// when given a model name override via cli flag param.
std::string model_name = ai_chat::features::kAIModelName.Get();
if (model_name.empty()) {
model_name = model.name;
}
DCHECK(!model_name.empty());
DCHECK(!model.name.empty());
base::flat_set<std::string_view> stop_sequences(kStopSequences.begin(),
kStopSequences.end());
api_ = std::make_unique<RemoteCompletionClient>(
model_name, stop_sequences, url_loader_factory, credential_manager);
model.name, stop_sequences, url_loader_factory, credential_manager);

max_page_content_length_ = model.max_page_content_length;
}
Expand Down
20 changes: 10 additions & 10 deletions components/ai_chat/core/browser/models.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,24 +33,24 @@ namespace ai_chat {
// - Long conversation warning threshold: 100k * 0.80 = 80k tokens

const base::flat_map<std::string_view, mojom::Model> kAllModels = {
{"chat-default",
{"chat-default", "llama-2-13b-chat", "llama2 13b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, false,
9000, 9700}},
{"chat-basic",
{"chat-basic", "llama-2-13b-chat", "llama2 13b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
mojom::ModelAccess::BASIC, 9000, 9700}},
{"chat-leo-expanded",
{"chat-leo-expanded", "llama-2-70b-chat", "llama2 70b", "Meta",
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT, true,
9000, 9700}},
mojom::ModelEngineType::LLAMA_REMOTE, mojom::ModelCategory::CHAT,
mojom::ModelAccess::PREMIUM, 9000, 9700}},
{"chat-claude-instant",
{"chat-claude-instant", "claude-instant-v1", "Claude Instant", "Anthropic",
mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT, true,
200000, 320000}},
mojom::ModelEngineType::CLAUDE_REMOTE, mojom::ModelCategory::CHAT,
mojom::ModelAccess::BASIC_AND_PREMIUM, 200000, 320000}},
};

const std::vector<std::string_view> kAllModelKeysDisplayOrder = {
"chat-default",
"chat-leo-expanded",
"chat-claude-instant",
"chat-basic",
"chat-leo-expanded",
};

} // namespace ai_chat
4 changes: 1 addition & 3 deletions components/ai_chat/core/browser/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
#include <vector>

#include "base/containers/flat_map.h"
#include "brave/components/ai_chat/core/common/features.h"
#include "brave/components/ai_chat/core/common/mojom/ai_chat.mojom-forward.h"

namespace ai_chat {

inline constexpr char kModelsDefaultKey[] = "chat-default";
inline constexpr char kModelsPremiumDefaultKey[] = "chat-claude-instant";

// All models that the user can choose for chat conversations.
extern const base::flat_map<std::string_view, mojom::Model> kAllModels;
// UI display order for models
Expand Down
6 changes: 4 additions & 2 deletions components/ai_chat/core/common/features.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,10 @@ BASE_FEATURE(kAIChat,
base::FEATURE_DISABLED_BY_DEFAULT
#endif
);
const base::FeatureParam<std::string> kAIModelName{&kAIChat, "ai_model_name",
""};
const base::FeatureParam<std::string> kAIModelsDefaultKey{&kAIChat, "default_model",
"chat-claude-instant"};
const base::FeatureParam<std::string> kAIModelsPremiumDefaultKey{&kAIChat, "default_premium_model",
"chat-claude-instant"};
const base::FeatureParam<bool> kAIChatSSE{&kAIChat, "ai_chat_sse", true};
const base::FeatureParam<double> kAITemperature{&kAIChat, "temperature", 0.2};

Expand Down
3 changes: 2 additions & 1 deletion components/ai_chat/core/common/features.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
namespace ai_chat::features {

BASE_DECLARE_FEATURE(kAIChat);
extern const base::FeatureParam<std::string> kAIModelName;
extern const base::FeatureParam<std::string> kAIModelsDefaultKey;
extern const base::FeatureParam<std::string> kAIModelsPremiumDefaultKey;
extern const base::FeatureParam<bool> kAIChatSSE;
extern const base::FeatureParam<double> kAITemperature;

Expand Down
12 changes: 11 additions & 1 deletion components/ai_chat/core/common/mojom/ai_chat.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,15 @@ enum ModelCategory {
CHAT,
};

enum ModelAccess {
// The model only has a single basic tier, accessible by any level
BASIC,
// The model has a basic tier and a more capable premium tier (a.k.a freemium)
BASIC_AND_PREMIUM,
// The model only has a premium tier
PREMIUM,
};

enum PremiumStatus {
Inactive,
Active,
Expand Down Expand Up @@ -96,7 +105,8 @@ struct Model {
ModelEngineType engine_type;
// user-facing category
ModelCategory category;
bool is_premium;
// Which access level grants permission to use the model
ModelAccess access;
// max limit to truncate page contents (measured in chars, not tokens)
uint32 max_page_content_length;
// max limit for overall conversation (measured in chars, not tokens)
Expand Down
4 changes: 3 additions & 1 deletion components/ai_chat/core/common/pref_names.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "brave/components/ai_chat/core/common/pref_names.h"

#include "brave/components/ai_chat/core/common/features.h"
#include "components/prefs/pref_registry_simple.h"
#include "components/prefs/pref_service.h"

Expand All @@ -21,7 +22,7 @@ void RegisterProfilePrefs(PrefRegistrySimple* registry) {
registry->RegisterTimePref(kLastAcceptedDisclaimer, {});
registry->RegisterBooleanPref(kBraveChatAutocompleteProviderEnabled, true);
registry->RegisterBooleanPref(kUserDismissedPremiumPrompt, false);
registry->RegisterStringPref(kDefaultModelKey, "chat-default");
registry->RegisterStringPref(kDefaultModelKey, features::kAIModelsDefaultKey.Get());
#if BUILDFLAG(IS_ANDROID)
registry->RegisterBooleanPref(kBraveChatSubscriptionActiveAndroid, false);
registry->RegisterStringPref(kBraveChatPurchaseTokenAndroid, "");
Expand All @@ -36,6 +37,7 @@ void RegisterProfilePrefsForMigration(PrefRegistrySimple* registry) {

void MigrateProfilePrefs(PrefService* profile_prefs) {
profile_prefs->ClearPref(kObseleteBraveChatAutoGenerateQuestions);
// TODO(petemill): migrate model key from "chat-default" to "chat-basic"
}

void RegisterLocalStatePrefs(PrefRegistrySimple* registry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@ import * as React from 'react'
import ButtonMenu from '@brave/leo/react/buttonMenu'
import Button from '@brave/leo/react/button'
import Icon from '@brave/leo/react/icon'
import Label from '@brave/leo/react/label'
import { getLocale } from '$web-common/locale'
import getPageHandlerInstance from '../../api/page_handler'
import getPageHandlerInstance, * as mojom from '../../api/page_handler'
import DataContext from '../../state/context'
import styles from './style.module.scss'
import classnames from '$web-common/classnames'
Expand Down Expand Up @@ -55,14 +56,22 @@ export default function FeatureMenu() {
{getLocale(`braveLeoModelSubtitle-${model.key}`)}
</p>
</div>
{model.isPremium && (
{model.access === mojom.ModelAccess.PREMIUM && (
<Icon
className={classnames({
[styles.lockOpen]: context.isPremiumUser
})}
name={context.isPremiumUser ? 'lock-open' : 'lock-plain'}
/>
)}
{model.access === mojom.ModelAccess.BASIC_AND_PREMIUM && (
<Label
mode={context.isPremiumUser ? 'loud' : 'default'}
color='blue'
>
{context.isPremiumUser ? 'Unlimited' : 'Basic'}
</Label>
)}
</div>
</leo-menu-item>
))}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function Main() {
hasAcceptedAgreement &&
!context.isPremiumStatusFetching && // Avoid flash of content
!context.isPremiumUser &&
context.currentModel?.isPremium
context.currentModel?.access === mojom.ModelAccess.PREMIUM

const shouldShowPremiumSuggestionStandalone =
hasAcceptedAgreement &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function DataContextProvider (props: DataContextProviderProps) {
const isPremiumUser = premiumStatus !== undefined && premiumStatus !== mojom.PremiumStatus.Inactive

const apiHasError = (currentError !== mojom.APIError.None)
const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.isPremium))
const shouldDisableUserInput = !!(apiHasError || isGenerating || (!isPremiumUser && currentModel?.access === mojom.ModelAccess.PREMIUM))

const getConversationHistory = () => {
getPageHandlerInstance()
Expand Down Expand Up @@ -127,7 +127,7 @@ function DataContextProvider (props: DataContextProviderProps) {

const switchToDefaultModel = () => {
// Select the first non-premium model
const nonPremium = allModels.find(m => !m.isPremium)
const nonPremium = allModels.find(m => [mojom.ModelAccess.BASIC, mojom.ModelAccess.BASIC_AND_PREMIUM].includes(m.access))
if (!nonPremium) {
console.error('Could not find a non-premium model!')
return
Expand Down
15 changes: 13 additions & 2 deletions components/ai_chat/resources/page/stories/components_panel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ const MODELS: mojom.Model[] = [
displayMaker: 'Company',
engineType: mojom.ModelEngineType.LLAMA_REMOTE,
category: mojom.ModelCategory.CHAT,
isPremium: false,
access: mojom.ModelAccess.BASIC,
maxPageContentLength: 10000,
longConversationWarningCharacterLimit: 9700
},
Expand All @@ -60,7 +60,18 @@ const MODELS: mojom.Model[] = [
displayMaker: 'Company',
engineType: mojom.ModelEngineType.LLAMA_REMOTE,
category: mojom.ModelCategory.CHAT,
isPremium: true,
access: mojom.ModelAccess.PREMIUM,
maxPageContentLength: 10000,
longConversationWarningCharacterLimit: 9700
},
{
key: '3',
name: 'model-three-freemium',
displayName: 'Model Three',
displayMaker: 'Company',
engineType: mojom.ModelEngineType.LLAMA_REMOTE,
category: mojom.ModelCategory.CHAT,
access: mojom.ModelAccess.BASIC_AND_PREMIUM,
maxPageContentLength: 10000,
longConversationWarningCharacterLimit: 9700
}
Expand Down
4 changes: 2 additions & 2 deletions components/resources/ai_chat_ui_strings.grdp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
<message name="IDS_CHAT_UI_RETRY_BUTTON_LABEL" desc="A button label to retry API again">
Retry
</message>
<message name="IDS_CHAT_UI_INTRO_MESSAGE_CHAT_DEFAULT" desc="AI Chat intro message for the default model">
<message name="IDS_CHAT_UI_INTRO_MESSAGE_CHAT_BASIC" desc="AI Chat intro message for the default model">
Hi, I'm Leo. I'm a fully hosted AI assistant by Brave. I'm powered by Llama 13B, a model created by Meta to be performant and applicable to many use cases.
</message>
<message name="IDS_CHAT_UI_INTRO_MESSAGE_CHAT_LEO_EXPANDED" desc="AI Chat intro message for the expanded model">
Expand Down Expand Up @@ -202,7 +202,7 @@
<message name="IDS_CHAT_UI_CONVERSATION_END_ERROR" desc="An error issued when the user cannot continue chatting">
This conversation is too long and cannot continue. There may be other models available with which Leo is capable of maintaining accuracy for longer conversations.
</message>
<message name="IDS_CHAT_UI_CHAT_DEFAULT_SUBTITLE" desc="A description for default chat model">
<message name="IDS_CHAT_UI_CHAT_BASIC_SUBTITLE" desc="A description for default chat model">
General purpose chat
</message>
<message name="IDS_CHAT_UI_CHAT_LEO_EXPANDED_SUBTITLE" desc="A description for llama-2-70b model">
Expand Down

0 comments on commit c74b8b7

Please sign in to comment.