Skip to content

Commit

Permalink
[AI Chat]: Conversation starter pack, for unassociated content (#26379)
Browse files Browse the repository at this point in the history
  • Loading branch information
fallaciousreasoning authored Nov 13, 2024
1 parent c73c5e6 commit 879786a
Show file tree
Hide file tree
Showing 9 changed files with 263 additions and 93 deletions.
99 changes: 82 additions & 17 deletions components/ai_chat/core/browser/conversation_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
#include "brave/components/ai_chat/core/browser/conversation_handler.h"

#include <algorithm>
#include <iterator>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

#include "base/files/file_path.h"
#include "base/memory/weak_ptr.h"
#include "base/rand_util.h"
#include "base/strings/utf_string_conversions.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
Expand All @@ -33,6 +36,10 @@
#include "services/network/public/cpp/shared_url_loader_factory.h"
#include "ui/base/l10n/l10n_util.h"

#define STARTER_PROMPT(TYPE) \
l10n_util::GetStringUTF8(IDS_AI_CHAT_STATIC_STARTER_TITLE_##TYPE), \
l10n_util::GetStringUTF8(IDS_AI_CHAT_STATIC_STARTER_PROMPT_##TYPE)

namespace ai_chat {

namespace {
Expand All @@ -43,6 +50,8 @@ using ai_chat::mojom::ConversationTurn;
using AssociatedContentDelegate =
ConversationHandler::AssociatedContentDelegate;

constexpr size_t kDefaultSuggestionsCount = 4;

} // namespace

AssociatedContentDelegate::AssociatedContentDelegate()
Expand Down Expand Up @@ -131,6 +140,16 @@ void AssociatedContentDelegate::OnTextEmbedderInitialized(bool initialized) {
pending_top_similarity_requests_.clear();
}

ConversationHandler::Suggestion::Suggestion(std::string title)
: title(std::move(title)) {}
ConversationHandler::Suggestion::Suggestion(std::string title,
std::string prompt)
: title(std::move(title)), prompt(std::move(prompt)) {}
ConversationHandler::Suggestion::Suggestion(Suggestion&&) = default;
ConversationHandler::Suggestion& ConversationHandler::Suggestion::operator=(
Suggestion&&) = default;
ConversationHandler::Suggestion::~Suggestion() = default;

ConversationHandler::ConversationHandler(
const mojom::Conversation* conversation,
AIChatService* ai_chat_service,
Expand All @@ -154,6 +173,7 @@ ConversationHandler::ConversationHandler(
models_observer_.Observe(model_service_.get());
// TODO(petemill): differ based on premium status, if different
ChangeModel(model_service->GetDefaultModelKey());
MaybeSeedOrClearSuggestions();
}

ConversationHandler::~ConversationHandler() {
Expand Down Expand Up @@ -352,9 +372,12 @@ void ConversationHandler::GetState(GetStateCallback callback) {

BuildAssociatedContentInfo();

std::vector<std::string> suggestions;
std::ranges::transform(suggestions_, std::back_inserter(suggestions),
[](const auto& s) { return s.title; });
mojom::ConversationStatePtr state = mojom::ConversationState::New(
metadata_->uuid, is_request_in_progress_, std::move(models_copy),
model_key, suggestions_, suggestion_generation_status_,
model_key, std::move(suggestions), suggestion_generation_status_,
associated_content_info_->Clone(), should_send_page_contents_,
current_error_);

Expand Down Expand Up @@ -505,13 +528,7 @@ void ConversationHandler::SubmitHumanConversationEntry(
DCHECK(latest_turn->character_type == mojom::CharacterType::HUMAN);
is_request_in_progress_ = true;
OnAPIRequestInProgressChanged();
// If it's a suggested question, remove it
auto found_question_iter =
base::ranges::find(suggestions_, latest_turn->text);
if (found_question_iter != suggestions_.end()) {
suggestions_.erase(found_question_iter);
OnSuggestedQuestionsChanged();
}

// Directly modify Entry's text to remove engine-breaking substrings
if (!has_edits) { // Edits are already sanitized.
engine_->SanitizeInput(latest_turn->text);
Expand All @@ -525,7 +542,15 @@ void ConversationHandler::SubmitHumanConversationEntry(
// callers of SubmitHumanConversationEntry mojo API currently don't have
// action_type specified.
std::string question_part = latest_turn->text;
if (latest_turn->action_type == mojom::ActionType::UNSPECIFIED) {
// If it's a suggested question, remove it
auto found_question_iter =
base::ranges::find(suggestions_, latest_turn->text, &Suggestion::title);
if (found_question_iter != suggestions_.end()) {
question_part =
found_question_iter->prompt.value_or(found_question_iter->title);
suggestions_.erase(found_question_iter);
OnSuggestedQuestionsChanged();
} else if (latest_turn->action_type == mojom::ActionType::UNSPECIFIED) {
if (latest_turn->text ==
l10n_util::GetStringUTF8(IDS_CHAT_UI_SUMMARIZE_PAGE)) {
latest_turn->action_type = mojom::ActionType::SUMMARIZE_PAGE;
Expand Down Expand Up @@ -671,9 +696,17 @@ void ConversationHandler::SubmitSummarizationRequest() {
SubmitHumanConversationEntry(std::move(turn));
}

void ConversationHandler::GetSuggestedQuestions(
GetSuggestedQuestionsCallback callback) {
std::move(callback).Run(suggestions_, suggestion_generation_status_);
std::vector<std::string> ConversationHandler::GetSuggestedQuestionsForTest() {
std::vector<std::string> suggestions;
base::ranges::transform(suggestions_, std::back_inserter(suggestions),
[](const auto& s) { return s.title; });
return suggestions;
}

void ConversationHandler::SetSuggestedQuestionForTest(std::string title,
std::string prompt) {
suggestions_.clear();
suggestions_.emplace_back(title, prompt);
}

void ConversationHandler::GenerateQuestions() {
Expand Down Expand Up @@ -1012,14 +1045,38 @@ void ConversationHandler::MaybeSeedOrClearSuggestions() {
const bool is_page_associated =
IsContentAssociationPossible() && should_send_page_contents_;

if (!is_page_associated && !suggestions_.empty()) {
if (!is_page_associated) {
suggestions_.clear();
suggestion_generation_status_ = mojom::SuggestionGenerationStatus::None;
suggestions_.emplace_back(STARTER_PROMPT(MEMO));
suggestions_.emplace_back(STARTER_PROMPT(INTERVIEW));
suggestions_.emplace_back(STARTER_PROMPT(STUDY_PLAN));
suggestions_.emplace_back(STARTER_PROMPT(PROJECT_TIMELINE));
suggestions_.emplace_back(STARTER_PROMPT(MARKETING_STRATEGY));
suggestions_.emplace_back(STARTER_PROMPT(PRESENTATION_OUTLINE));
suggestions_.emplace_back(STARTER_PROMPT(BRAINSTORM));
suggestions_.emplace_back(STARTER_PROMPT(PROFESSIONAL_EMAIL));
suggestions_.emplace_back(STARTER_PROMPT(BUSINESS_PROPOSAL));

// We don't have an external list of all the available suggestions, so we
// generate all of them and remove random ones until we have the required
// number and then shuffle the result.
while (suggestions_.size() > kDefaultSuggestionsCount) {
auto remove_at = base::RandInt(0, suggestions_.size() - 1);
suggestions_.erase(suggestions_.begin() + remove_at);
}
base::RandomShuffle(suggestions_.begin(), suggestions_.end());
OnSuggestedQuestionsChanged();
return;
}

if (is_page_associated && suggestions_.empty() &&
// This means we have the default suggestions
if (suggestion_generation_status_ ==
mojom::SuggestionGenerationStatus::None) {
suggestions_.clear();
}

if (suggestions_.empty() &&
suggestion_generation_status_ !=
mojom::SuggestionGenerationStatus::IsGenerating &&
suggestion_generation_status_ !=
Expand Down Expand Up @@ -1207,18 +1264,20 @@ void ConversationHandler::OnEngineCompletionComplete(
void ConversationHandler::OnSuggestedQuestionsResponse(
EngineConsumer::SuggestedQuestionResult result) {
if (result.has_value()) {
suggestions_.insert(suggestions_.end(), result->begin(), result->end());
std::ranges::transform(result.value(), std::back_inserter(suggestions_),
[](const auto& s) { return Suggestion(s); });
suggestion_generation_status_ =
mojom::SuggestionGenerationStatus::HasGenerated;
DVLOG(2) << "Got questions:" << base::JoinString(result.value(), "\n");
} else {
// TODO(nullhook): Set a specialized error state generated questions
suggestion_generation_status_ =
mojom::SuggestionGenerationStatus::CanGenerate;
DVLOG(2) << "Got no questions";
}

// Notify observers
OnSuggestedQuestionsChanged();
DVLOG(2) << "Got questions:" << base::JoinString(suggestions_, "\n");
}

void ConversationHandler::OnModelListUpdated() {
Expand Down Expand Up @@ -1373,8 +1432,12 @@ void ConversationHandler::OnAssociatedContentFaviconImageDataChanged() {
}

void ConversationHandler::OnSuggestedQuestionsChanged() {
std::vector<std::string> suggestions;
std::ranges::transform(suggestions_, std::back_inserter(suggestions),
[](const auto& s) { return s.title; });

for (auto& client : conversation_ui_handlers_) {
client->OnSuggestedQuestionsChanged(suggestions_,
client->OnSuggestedQuestionsChanged(suggestions,
suggestion_generation_status_);
}
}
Expand All @@ -1386,3 +1449,5 @@ void ConversationHandler::OnAPIRequestInProgressChanged() {
}

} // namespace ai_chat

#undef STARTER_PROMPT
19 changes: 17 additions & 2 deletions components/ai_chat/core/browser/conversation_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#define BRAVE_COMPONENTS_AI_CHAT_CORE_BROWSER_CONVERSATION_HANDLER_H_

#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <utility>
Expand Down Expand Up @@ -203,7 +204,8 @@ class ConversationHandler : public mojom::ConversationHandler,
void ModifyConversation(uint32_t turn_index,
const std::string& new_text) override;
void SubmitSummarizationRequest() override;
void GetSuggestedQuestions(GetSuggestedQuestionsCallback callback) override;
std::vector<std::string> GetSuggestedQuestionsForTest();
void SetSuggestedQuestionForTest(std::string title, std::string prompt);
void GenerateQuestions() override;
void GetAssociatedContentInfo(
GetAssociatedContentInfoCallback callback) override;
Expand Down Expand Up @@ -284,6 +286,19 @@ class ConversationHandler : public mojom::ConversationHandler,
FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedder);
FRIEND_TEST_ALL_PREFIXES(PageContentRefineTest, TextEmbedderInitialized);

struct Suggestion {
std::string title;
std::optional<std::string> prompt;

explicit Suggestion(std::string title);
Suggestion(std::string title, std::string prompt);
Suggestion(const Suggestion&) = delete;
Suggestion& operator=(const Suggestion&) = delete;
Suggestion(Suggestion&&);
Suggestion& operator=(Suggestion&&);
~Suggestion();
};

void InitEngine();
void BuildAssociatedContentInfo();
bool IsContentAssociationPossible();
Expand Down Expand Up @@ -350,7 +365,7 @@ class ConversationHandler : public mojom::ConversationHandler,
std::vector<mojom::ConversationTurnPtr> chat_history_;
mojom::ConversationTurnPtr pending_conversation_entry_;
// Any previously-generated suggested questions
std::vector<std::string> suggestions_;
std::vector<Suggestion> suggestions_;
std::string selected_language_;
// Is a conversation engine request in progress (does not include
// non-conversation engine requests.
Expand Down
49 changes: 37 additions & 12 deletions components/ai_chat/core/browser/conversation_handler_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "base/functional/overloaded.h"
#include "base/memory/scoped_refptr.h"
#include "base/ranges/algorithm.h"
#include "base/run_loop.h"
#include "base/scoped_observation.h"
#include "base/task/sequenced_task_runner.h"
#include "base/task/thread_pool.h"
Expand Down Expand Up @@ -355,7 +356,7 @@ TEST_F(ConversationHandlerUnitTest, GetState) {
testing::ElementsAre(l10n_util::GetStringUTF8(
IDS_CHAT_UI_SUMMARIZE_PAGE)));
} else {
EXPECT_TRUE(state->suggested_questions.empty());
EXPECT_EQ(4u, state->suggested_questions.size());
}
EXPECT_EQ(state->suggestion_status,
should_send_content
Expand Down Expand Up @@ -440,11 +441,7 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText) {
// once conversation history is committed.
EXPECT_FALSE(site_info->is_content_association_possible);
}));
conversation_handler_->GetSuggestedQuestions(
base::BindLambdaForTesting([&](const std::vector<std::string>& questions,
mojom::SuggestionGenerationStatus status) {
EXPECT_TRUE(questions.empty());
}));
EXPECT_TRUE(conversation_handler_->GetSuggestedQuestionsForTest().empty());

EXPECT_TRUE(conversation_handler_->HasAnyHistory());
const auto& history = conversation_handler_->GetConversationHistory();
Expand Down Expand Up @@ -527,12 +524,9 @@ TEST_F(ConversationHandlerUnitTest, SubmitSelectedText_WithAssociatedContent) {

// Should not be any LLM-generated suggested questions yet because they
// weren't asked for
conversation_handler_->GetSuggestedQuestions(
base::BindLambdaForTesting([&](const std::vector<std::string>& questions,
mojom::SuggestionGenerationStatus status) {
EXPECT_EQ(1u, questions.size());
EXPECT_EQ(questions[0], "Summarize this page");
}));
const auto questions = conversation_handler_->GetSuggestedQuestionsForTest();
EXPECT_EQ(1u, questions.size());
EXPECT_EQ(questions[0], "Summarize this page");

const auto& history2 = conversation_handler_->GetConversationHistory();
std::vector<mojom::ConversationTurnPtr> expected_history2;
Expand Down Expand Up @@ -1317,6 +1311,37 @@ TEST_F(ConversationHandlerUnitTest_NoAssociatedContent, GenerateQuestions) {
testing::Mock::VerifyAndClearExpectations(engine);
}

TEST_F(ConversationHandlerUnitTest_NoAssociatedContent,
GeneratesQuestionsByDefault) {
EXPECT_EQ(4u, conversation_handler_->GetSuggestedQuestionsForTest().size());
}

TEST_F(ConversationHandlerUnitTest_NoAssociatedContent,
SelectingDefaultQuestionSendsPrompt) {
conversation_handler_->SetSuggestedQuestionForTest("the thing",
"do the thing!");
auto suggestions = conversation_handler_->GetSuggestedQuestionsForTest();
EXPECT_EQ(1u, suggestions.size());

// Mock engine response
MockEngineConsumer* engine = static_cast<MockEngineConsumer*>(
conversation_handler_->GetEngineForTesting());

base::RunLoop loop;
// The prompt should be submitted to the engine, not the title.
EXPECT_CALL(*engine,
GenerateAssistantResponse(false, StrEq(""), _, "do the thing!",
StrEq(""), _, _))
.WillOnce(testing::InvokeWithoutArgs(&loop, &base::RunLoop::Quit));

conversation_handler_->SubmitHumanConversationEntry("the thing");
loop.Run();
testing::Mock::VerifyAndClearExpectations(engine);

// Suggestion should be removed
EXPECT_EQ(0u, conversation_handler_->GetSuggestedQuestionsForTest().size());
}

TEST_F(ConversationHandlerUnitTest, SelectedLanguage) {
MockEngineConsumer* engine = static_cast<MockEngineConsumer*>(
conversation_handler_->GetEngineForTesting());
Expand Down
4 changes: 0 additions & 4 deletions components/ai_chat/core/common/mojom/ai_chat.mojom
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,6 @@ interface ConversationHandler {
// Get all visible history entries, including in-progress responses
GetConversationHistory() => (array<ConversationTurn> conversation_history);

// List of all suggested questions
GetSuggestedQuestions() => (
array<string> questions, SuggestionGenerationStatus suggestion_status);

// The browser should generate some questions and fire an event when they
// are ready.
GenerateQuestions();
Expand Down
Loading

0 comments on commit 879786a

Please sign in to comment.