Skip to content

Commit

Permalink
Refactor ai chat page
Browse files Browse the repository at this point in the history
Move the core logic to its own class that maintains a single chat session instead of creating a new session for every single message.
  • Loading branch information
Eyob Alemu committed Sep 8, 2024
1 parent d2b66bb commit 6a2413e
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 90 deletions.
75 changes: 75 additions & 0 deletions lib/apis/gemini.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import 'package:arxiv/models/chat_message.dart';
import 'package:arxiv/models/paper.dart';
import 'package:google_generative_ai/google_generative_ai.dart';
import 'package:flutter/services.dart' show rootBundle;

class Gemini {
late final GenerativeModel _model;
late final ChatSession _chatSession;

Gemini._internal(String apiKey, String systemPrompt) {
_model = GenerativeModel(
apiKey: apiKey,
model: 'gemini-1.5-flash',
systemInstruction: Content.system(systemPrompt),
generationConfig: GenerationConfig(
temperature: 1,
topK: 64,
topP: 0.95,
maxOutputTokens: 8192,
responseMimeType: 'text/plain',
),
);

_chatSession = _model.startChat();
}

static Future<Gemini> newModel(String apiKey, {Paper? paper}) async {
final systemPrompt = paper != null
? await _getModelSystemMessage(paper)
: await _getGeneralSystemMessage();
return Gemini._internal(apiKey, systemPrompt);
}

Future<ChatMessage> sendMessage(String message) async {
try {
var content = Content.text(message);
var response = await _chatSession.sendMessage(content);
return ChatMessage(Role.ai, response.text?.trim() ?? "");
} catch (e) {
return ChatMessage(Role.ai, e.toString());
}
}

static Future<String> _getModelSystemMessage(Paper paper) async {
var substitutes = {
'paperId': paper.id,
'paperTitle': paper.title,
'paperAuthors': paper.authors,
'paperPublishedDate': paper.publishedAt,
'paperSummary': paper.summary,
};

return await _fromTemplateFile(
'assets/system_message_templates/model.txt', substitutes);
}

static Future<String> _getGeneralSystemMessage() async {
return await _fromTemplateFile(
'assets/system_message_templates/general.txt', {});
}

/// Interpolates values to a text read from a file. The format for a placeholder is {{some_name}}.
static Future<String> _fromTemplateFile(
String fileName, Map<String, dynamic> substitutes) async {
var template = await rootBundle.loadString(fileName);
return template.splitMapJoin(RegExp('{{.*?}}'),
onMatch: (m) => substitutes[_getPlaceholderName(m.group(0))] ?? '');
}

static String _getPlaceholderName(String? placeholderTemplate) {
if (placeholderTemplate == null) return '';

return placeholderTemplate.substring(2, placeholderTemplate.length - 2);
}
}
104 changes: 15 additions & 89 deletions lib/pages/ai_chat_page.dart
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
// ignore_for_file: file_names

import 'package:arxiv/apis/gemini.dart';
import 'package:arxiv/components/api_settings.dart';
import 'package:arxiv/components/each_chat_message.dart';
import 'package:arxiv/components/prompt_suggestions.dart';
import 'package:arxiv/models/chat_message.dart';
import 'package:arxiv/models/paper.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart' show rootBundle;
import 'package:google_generative_ai/google_generative_ai.dart';
import 'package:hive/hive.dart';
import 'package:ionicons/ionicons.dart';
import 'package:theme_provider/theme_provider.dart';

class AIChatPage extends StatefulWidget {
const AIChatPage({super.key, required this.paperData});

final dynamic paperData;
final Paper? paperData;

@override
State<AIChatPage> createState() => _AIChatPageState();
Expand All @@ -24,14 +24,14 @@ class _AIChatPageState extends State<AIChatPage> {
TextEditingController userMessageController = TextEditingController();
ScrollController scrollController = ScrollController();
var apiKey = "";
var aiResponse = "";
var systemPrompt = "";
List<ChatMessage> chatList = [];
var apiKeySettingsOn = false;
var toolsOn = true;

final _systemLoadingTrigger = "SYMLOADINGANIMATION";

late final Gemini model;

var paperPromptSuggestions = [
"Who wrote this paper?",
"What is the title of this paper?",
Expand Down Expand Up @@ -76,31 +76,8 @@ class _AIChatPageState extends State<AIChatPage> {
chatList.add(ChatMessage(Role.user, message));
chatList.add(ChatMessage(Role.system, _systemLoadingTrigger));
scrollToTheBottom();
ChatMessage aiResponseObject;

try {
var model = GenerativeModel(
apiKey: apiKey,
model: 'gemini-1.5-flash',
systemInstruction: Content.system(systemPrompt),
generationConfig: GenerationConfig(
temperature: 1,
topK: 64,
topP: 0.95,
maxOutputTokens: 8192,
responseMimeType: 'text/plain',
),
);

var chat = model.startChat();
var content = Content.text(message);

var response = await chat.sendMessage(content);
aiResponse = response.text?.trim() ?? "";
aiResponseObject = ChatMessage(Role.ai, aiResponse);
} catch (e) {
aiResponseObject = ChatMessage(Role.ai, e.toString());
}
ChatMessage aiResponseObject = await model.sendMessage(message);

chatList.removeLast();
setState(() {});
Expand All @@ -116,14 +93,16 @@ class _AIChatPageState extends State<AIChatPage> {
setState(() {});
}

void configAPIKey() async {
void configModel() async {
Box apiBox = await Hive.openBox("apibox");
apiKey = await apiBox.get("apikey") ?? "";
await Hive.close();
if (apiKey == "") {
apiKeySettingsOn = true;
} else {

if (apiKey.isNotEmpty) {
model = await Gemini.newModel(apiKey, paper: widget.paperData);
apiKeySettingsOn = false;
} else {
apiKeySettingsOn = true;
}
setState(() {});
}
Expand All @@ -148,64 +127,11 @@ class _AIChatPageState extends State<AIChatPage> {
setState(() {});
}

void setupModelSystemMessage() async {
var paperId = widget.paperData["id"].toString().substring(
widget.paperData["id"].lastIndexOf("/") + 1,
widget.paperData["id"].length);
var paperTitle = widget.paperData["title"]
.toString()
.replaceAll(RegExp(r'\\n'), '')
.replaceAll(RegExp(r'\\ '), '');
var paperAuthors = widget.paperData["author"]
.toString()
.replaceAll("name:", "")
.replaceAll(RegExp("[\\[\\]\\{\\}]"), "");
var paperPublishedDate =
widget.paperData["published"].toString().substring(0, 10);
var paperSummary = widget.paperData["summary"]
.trim()
.replaceAll(RegExp(r'\\n'), ' ')
.replaceAll(RegExp(r'\\'), '');

var substitutes = {
'paperId': paperId,
'paperTitle': paperTitle,
'paperAuthors': paperAuthors,
'paperPublishedDate': paperPublishedDate,
'paperSummary': paperSummary
};

systemPrompt = await fromTemplateFile(
'assets/system_message_templates/model.txt', substitutes);
}

void setupGeneralSystemMessage() async {
systemPrompt = await fromTemplateFile(
'assets/system_message_templates/general.txt', {});
}

/// Interpolates values to a text read from a file. The format for a placeholder is {{some_name}}.
Future<String> fromTemplateFile(
String fileName, Map<String, dynamic> substitutes) async {
var template = await rootBundle.loadString(fileName);
return template.splitMapJoin(RegExp('{{.*?}}'),
onMatch: (m) => substitutes[getPlaceholderName(m.group(0))] ?? '');
}

String getPlaceholderName(String? placeholderTemplate) {
if (placeholderTemplate == null) return '';

return placeholderTemplate.substring(2, placeholderTemplate.length - 2);
}

@override
void initState() {
super.initState();
getToggleTools();
configAPIKey();
widget.paperData == ""
? setupGeneralSystemMessage()
: setupModelSystemMessage();
configModel();
}

@override
Expand Down Expand Up @@ -280,12 +206,12 @@ class _AIChatPageState extends State<AIChatPage> {
),
apiKeySettingsOn == true
? APISettings(
configAPIKey: configAPIKey,
configAPIKey: configModel,
)
: PromptSuggestions(
chatWithAI: chatWithAI,
userMessageController: userMessageController,
promptSuggestions: widget.paperData == ""
promptSuggestions: widget.paperData == null
? generalPromptSuggestions
: paperPromptSuggestions,
),
Expand Down
2 changes: 1 addition & 1 deletion lib/pages/home_page.dart
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ class _HomePageState extends State<HomePage> {
context,
MaterialPageRoute(
builder: (context) => const AIChatPage(
paperData: "",
paperData: null,
),
),
);
Expand Down

0 comments on commit 6a2413e

Please sign in to comment.