Skip to content

Commit

Permalink
Merge pull request #36 from eban12/main
Browse files Browse the repository at this point in the history
Code refactor and bug fixes
  • Loading branch information
dagmawibabi authored Sep 8, 2024
2 parents e281245 + 6a2413e commit 673f2d5
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 194 deletions.
75 changes: 75 additions & 0 deletions lib/apis/gemini.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import 'package:arxiv/models/chat_message.dart';
import 'package:arxiv/models/paper.dart';
import 'package:google_generative_ai/google_generative_ai.dart';
import 'package:flutter/services.dart' show rootBundle;

class Gemini {
late final GenerativeModel _model;
late final ChatSession _chatSession;

Gemini._internal(String apiKey, String systemPrompt) {
_model = GenerativeModel(
apiKey: apiKey,
model: 'gemini-1.5-flash',
systemInstruction: Content.system(systemPrompt),
generationConfig: GenerationConfig(
temperature: 1,
topK: 64,
topP: 0.95,
maxOutputTokens: 8192,
responseMimeType: 'text/plain',
),
);

_chatSession = _model.startChat();
}

static Future<Gemini> newModel(String apiKey, {Paper? paper}) async {
final systemPrompt = paper != null
? await _getModelSystemMessage(paper)
: await _getGeneralSystemMessage();
return Gemini._internal(apiKey, systemPrompt);
}

Future<ChatMessage> sendMessage(String message) async {
try {
var content = Content.text(message);
var response = await _chatSession.sendMessage(content);
return ChatMessage(Role.ai, response.text?.trim() ?? "");
} catch (e) {
return ChatMessage(Role.ai, e.toString());
}
}

static Future<String> _getModelSystemMessage(Paper paper) async {
var substitutes = {
'paperId': paper.id,
'paperTitle': paper.title,
'paperAuthors': paper.authors,
'paperPublishedDate': paper.publishedAt,
'paperSummary': paper.summary,
};

return await _fromTemplateFile(
'assets/system_message_templates/model.txt', substitutes);
}

static Future<String> _getGeneralSystemMessage() async {
return await _fromTemplateFile(
'assets/system_message_templates/general.txt', {});
}

/// Interpolates values to a text read from a file. The format for a placeholder is {{some_name}}.
static Future<String> _fromTemplateFile(
String fileName, Map<String, dynamic> substitutes) async {
var template = await rootBundle.loadString(fileName);
return template.splitMapJoin(RegExp('{{.*?}}'),
onMatch: (m) => substitutes[_getPlaceholderName(m.group(0))] ?? '');
}

static String _getPlaceholderName(String? placeholderTemplate) {
if (placeholderTemplate == null) return '';

return placeholderTemplate.substring(2, placeholderTemplate.length - 2);
}
}
80 changes: 42 additions & 38 deletions lib/components/each_chat_message.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// ignore_for_file: file_names

import 'package:animated_text_kit/animated_text_kit.dart';
import 'package:arxiv/models/chat_message.dart';
import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:flutter_markdown/flutter_markdown.dart';
Expand All @@ -19,7 +20,7 @@ class EachChatMessage extends StatefulWidget {
required this.toolsOn,
});

final dynamic response;
final ChatMessage response;
final dynamic toolsOn;

@override
Expand All @@ -32,13 +33,18 @@ class _EachChatMessageState extends State<EachChatMessage> {
var speedFactor = 0.1;
var isSpeaking = false;

void readReponse() async {
var message =
widget.response["content"].toString().substring(0, 6) == "SYMMDX"
? widget.response["content"]
.toString()
.substring(6, widget.response["content"].length)
: widget.response["content"];
final _markdownPrefix = "SYMMDX";

bool isMarkdown(String content) {
return content.substring(0, 6) == _markdownPrefix;
}

void readResponse() async {
var message = isMarkdown(widget.response.content)
? widget.response.content
.toString()
.substring(6, widget.response.content.length)
: widget.response.content;
if (isSpeaking == false) {
await tts.setLanguage("en-US");
tts.setSpeechRate(speedRate);
Expand All @@ -51,22 +57,20 @@ class _EachChatMessageState extends State<EachChatMessage> {
}

void shareResponse() async {
var message =
widget.response["content"].toString().substring(0, 6) == "SYMMDX"
? widget.response["content"]
.toString()
.substring(6, widget.response["content"].length)
: widget.response["content"];
var message = isMarkdown(widget.response.content)
? widget.response.content
.toString()
.substring(6, widget.response.content.length)
: widget.response.content;
Share.share(message.toString().trim());
}

void copyResponse() async {
var message =
widget.response["content"].toString().substring(0, 6) == "SYMMDX"
? widget.response["content"]
.toString()
.substring(6, widget.response["content"].length)
: widget.response["content"];
var message = isMarkdown(widget.response.content)
? widget.response.content
.toString()
.substring(6, widget.response.content.length)
: widget.response.content;
await Clipboard.setData(
ClipboardData(
text: message,
Expand Down Expand Up @@ -94,7 +98,7 @@ class _EachChatMessageState extends State<EachChatMessage> {
@override
Widget build(BuildContext context) {
return Row(
mainAxisAlignment: widget.response["role"] == "USER"
mainAxisAlignment: widget.response.role == Role.user
? MainAxisAlignment.end
: MainAxisAlignment.start,
children: [
Expand All @@ -104,8 +108,8 @@ class _EachChatMessageState extends State<EachChatMessage> {
Row(
crossAxisAlignment: CrossAxisAlignment.start,
children: [
widget.response["role"] == "AI" ||
widget.response["role"] == "SYSTEM"
widget.response.role == Role.ai ||
widget.response.role == Role.system
? Padding(
padding: const EdgeInsets.only(top: 6.0, left: 10.0),
child: Icon(
Expand Down Expand Up @@ -146,17 +150,17 @@ class _EachChatMessageState extends State<EachChatMessage> {
Colors.grey[100],
borderRadius: BorderRadius.circular(10.0),
),
child: widget.response["role"] == "USER"
child: widget.response.role == Role.user
? Padding(
padding: const EdgeInsets.symmetric(
horizontal: 13.0,
vertical: 10.0,
),
child: Text(
widget.response["content"],
widget.response.content,
),
)
: widget.response["role"] == "SYSTEM"
: widget.response.role == Role.system
? Container(
padding: const EdgeInsets.symmetric(
horizontal: 20.0,
Expand All @@ -172,15 +176,15 @@ class _EachChatMessageState extends State<EachChatMessage> {
size: 30,
),
)
: widget.response["content"]
: widget.response.content
.toString()
.substring(0, 6) ==
"SYMMDX"
? Markdown(
data: widget.response["content"]
data: widget.response.content
.toString()
.substring(
6, widget.response["content"].length)
6, widget.response.content.length)
.trim(),
selectable: true,
shrinkWrap: true,
Expand All @@ -201,31 +205,31 @@ class _EachChatMessageState extends State<EachChatMessage> {
isRepeatingAnimation: false,
animatedTexts: [
TypewriterAnimatedText(
widget.response["content"]
widget.response.content
.toString()
.trim(),
textStyle: TextStyle(
color: widget.response["content"]
color: widget.response.content
.toString()
.trim()
.startsWith(
"GenerativeAIException") ||
widget.response["content"]
widget.response.content
.toString()
.trim()
.startsWith(
"ClientException") ||
widget.response["content"]
widget.response.content
.toString()
.trim()
.startsWith(
"HandshakeException") ||
widget.response["content"]
widget.response.content
.toString()
.trim()
.startsWith(
"API key not valid") ||
widget.response["content"]
widget.response.content
.toString()
.trim()
.startsWith(
Expand All @@ -244,7 +248,7 @@ class _EachChatMessageState extends State<EachChatMessage> {
),
),
),
widget.response["role"] == "USER"
widget.response.role == Role.user
? Padding(
padding: const EdgeInsets.only(top: 8.0, right: 10.0),
child: Icon(
Expand All @@ -260,15 +264,15 @@ class _EachChatMessageState extends State<EachChatMessage> {
],
),
// TOOLS
widget.response["role"] == "AI" && widget.toolsOn == true
widget.response.role == "AI" && widget.toolsOn == true
? Container(
padding: const EdgeInsets.only(left: 50.0, bottom: 14.0),
child: Row(
mainAxisAlignment: MainAxisAlignment.center,
children: [
GestureDetector(
onTap: () {
readReponse();
readResponse();
},
child: Container(
padding: const EdgeInsets.symmetric(
Expand Down
26 changes: 5 additions & 21 deletions lib/components/each_paper_card.dart
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,6 @@ import 'package:ionicons/ionicons.dart';
import 'package:share_plus/share_plus.dart';
import 'package:theme_provider/theme_provider.dart';

bool containsLatex(String title) {
final latexRegex = RegExp(r'[$\\{}]');
return latexRegex.hasMatch(title);
}

class EachPaperCard extends StatefulWidget {
const EachPaperCard({
super.key,
Expand Down Expand Up @@ -101,16 +96,7 @@ class _EachPaperCardState extends State<EachPaperCard> {

@override
Widget build(BuildContext context) {
String title = widget.eachPaper.title
.toString()
.replaceAll(RegExp(r'\\n'), '')
.replaceAll(RegExp(r'\\ '), '');

if (containsLatex(title) == true) {
title = title.replaceAll(RegExp(r'\$ '), r' \) ');
title = title.replaceAll(RegExp(r' \$'), r' \( ');
title = title.replaceAll(r'$', r' \) ');
}
String title = widget.eachPaper.title;

return Container(
margin: const EdgeInsets.only(
Expand Down Expand Up @@ -140,10 +126,8 @@ class _EachPaperCardState extends State<EachPaperCard> {
children: [
// ID and Published Date
IDAndDate(
id: widget.eachPaper.id.substring(
widget.eachPaper.id.lastIndexOf("/") + 1,
widget.eachPaper.id.length),
date: widget.eachPaper.publishedAt.substring(0, 10),
id: widget.eachPaper.id,
date: widget.eachPaper.publishedAt,
),

// TITLE
Expand All @@ -154,7 +138,7 @@ class _EachPaperCardState extends State<EachPaperCard> {
),
child: Container(
padding: const EdgeInsets.only(bottom: 5.0),
child: containsLatex(title)
child: Paper.containsLatex(title)
? TeXView(
child: TeXViewDocument(
title,
Expand Down Expand Up @@ -183,7 +167,7 @@ class _EachPaperCardState extends State<EachPaperCard> {
Padding(
padding: const EdgeInsets.only(bottom: 2.0),
child: Text(
"Published: ${widget.eachPaper.publishedAt.toString().substring(0, 10)}",
"Published: ${widget.eachPaper.publishedAt}",
style: const TextStyle(
fontSize: 12.0,
),
Expand Down
22 changes: 3 additions & 19 deletions lib/components/summary_bottom_sheet.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import 'package:hive/hive.dart';
import 'package:ionicons/ionicons.dart';
import 'package:theme_provider/theme_provider.dart';

import 'package:arxiv/components/each_paper_card.dart';

class SummaryBottomSheet extends StatefulWidget {
const SummaryBottomSheet({
super.key,
Expand Down Expand Up @@ -87,26 +85,12 @@ class _SummaryBottomSheetState extends State<SummaryBottomSheet> {
isSpeaking = false;
setState(() {});
});
summary = widget.paperData.summary
.trim()
.replaceAll(RegExp(r'\\n'), ' ')
.replaceAll(
RegExp(r'\\'),
'',
);
summary = widget.paperData.summary;
}

@override
Widget build(BuildContext context) {
String summary = widget.paperData.summary
.trim()
.replaceAll(RegExp(r'\\n'), ' ')
.replaceAll(RegExp(r'\\'), '');
if (containsLatex(summary)) {
summary = summary.replaceAll(RegExp(r'\$ '), r' \) ');
summary = summary.replaceAll(RegExp(r' \$'), r' \( ');
summary = summary.replaceAll(r'$', r' \) ');
}
String summary = widget.paperData.summary;
return Scaffold(
backgroundColor: Colors.transparent,
body: Container(
Expand Down Expand Up @@ -284,7 +268,7 @@ class _SummaryBottomSheetState extends State<SummaryBottomSheet> {
top: 10.0,
bottom: 100.0,
),
child: (containsLatex(summary)
child: (Paper.containsLatex(summary)
? TeXView(
child: TeXViewDocument(
summary,
Expand Down
8 changes: 8 additions & 0 deletions lib/models/chat_message.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
enum Role { ai, user, system }

class ChatMessage {
Role role;
String content;

ChatMessage(this.role, this.content);
}
Loading

0 comments on commit 673f2d5

Please sign in to comment.