Skip to content

Commit

Permalink
Added utils for GPT classification using templates and standardizing response. (#252)
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewplummer authored Dec 10, 2024
1 parent debb274 commit 38991f8
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 1 deletion.
5 changes: 4 additions & 1 deletion services/api/.env
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,7 @@ TWILIO_AUTH_TOKEN=AC21717619d8cf45502cc2f7cd7aee139d
GOOGLE_CLIENT_ID=

# Sign in with Apple
APPLE_SERVICE_ID=
APPLE_SERVICE_ID=

# OpenAI
OPENAI_API_KEY=
21 changes: 21 additions & 0 deletions services/api/__mocks__/openai.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Jest manual mock for the "openai" package. Tests pick the payload that
// the next chat.completions.create() call returns by calling
// OpenAI.setResponse(fixture) first.
let cannedResponse;

function OpenAI() {
  const completions = {
    create: () => cannedResponse,
  };
  return {
    chat: { completions },
  };
}

// Registers the fixture returned by subsequent create() calls.
function setResponse(data) {
  cannedResponse = data;
}

OpenAI.setResponse = setResponse;

module.exports = OpenAI;
21 changes: 21 additions & 0 deletions services/api/src/gpt/classify-fruits.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
--- SYSTEM ---

You are a helpful assistant.

Here is a list of fruits:

{{fruits}}

--- USER ---

The following text describes someone eating a meal. Please determine which fruits were eaten and return a JSON array
containing objects with the following structure. Only output JSON, do not include any explanations.

- "name" - The name of the fruit.
- "color" - The typical color of the fruit.
- "calories" - A rough estimate of the number of calories per serving. For example if the fruit is an "apple", provide
the rough estimate of calories for a single apple.

Text:

{{text}}
23 changes: 23 additions & 0 deletions services/api/src/utils/__fixtures__/gpt-response-array-unformatted.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
"object": "chat.completion",
"created": 1719313006,
"model": "gpt-4o-2024-05-13",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "[{\n\"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}]"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 134,
"completion_tokens": 79,
"total_tokens": 213
},
"system_fingerprint": "fp_3e7d703517"
}
23 changes: 23 additions & 0 deletions services/api/src/utils/__fixtures__/gpt-response-formatted.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
"object": "chat.completion",
"created": 1719313006,
"model": "gpt-4o-2024-05-13",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "```json\n{\n \"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}\n```"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 134,
"completion_tokens": 79,
"total_tokens": 213
},
"system_fingerprint": "fp_3e7d703517"
}
23 changes: 23 additions & 0 deletions services/api/src/utils/__fixtures__/gpt-response-long.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
"object": "chat.completion",
"created": 1719313006,
"model": "gpt-4o-2024-05-13",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "To determine the fruit eaten, we analyze the text provided. Here, the person ate a banana. We'll now create a JSON object including the name, typical color, and a rough estimate of calories per serving for a banana.\n\nHere is the JSON object:\n\n```json\n{\n \"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}\n```"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 134,
"completion_tokens": 79,
"total_tokens": 213
},
"system_fingerprint": "fp_3e7d703517"
}
23 changes: 23 additions & 0 deletions services/api/src/utils/__fixtures__/gpt-response-unformatted.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
"object": "chat.completion",
"created": 1719313006,
"model": "gpt-4o-2024-05-13",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "{\n \"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}"
},
"logprobs": null,
"finish_reason": "stop"
}
],
"usage": {
"prompt_tokens": 134,
"completion_tokens": 79,
"total_tokens": 213
},
"system_fingerprint": "fp_3e7d703517"
}
139 changes: 139 additions & 0 deletions services/api/src/utils/__tests__/gpt.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
const { prompt } = require('../gpt');
// Manual mock (services/api/__mocks__/openai.js): setResponse() queues the
// fixture that the mocked chat.completions.create() will return.
const { setResponse } = require('openai');

// Canned API responses covering the content shapes the model may produce:
// prose around a ```json fence, a bare fenced block, raw JSON, and a raw array.
const responseLong = require('../__fixtures__/gpt-response-long.json');
const responseFormatted = require('../__fixtures__/gpt-response-formatted.json');
const responseUnformatted = require('../__fixtures__/gpt-response-unformatted.json');
const responseArrayUnformatted = require('../__fixtures__/gpt-response-array-unformatted.json');

describe('prompt', () => {
  // JSON must be extracted even when the model surrounds it with explanatory prose.
  it('should succeed for a long response', async () => {
    setResponse(responseLong);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
    });
    expect(result).toEqual({
      name: 'banana',
      color: 'yellow',
      calories: 105,
    });
  });

  // JSON inside a ```json code fence must be parsed.
  it('should succeed for a formatted response', async () => {
    setResponse(responseFormatted);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
    });
    expect(result).toEqual({
      name: 'banana',
      color: 'yellow',
      calories: 105,
    });
  });

  // Bare JSON with no surrounding text must be parsed.
  it('should succeed for an unformatted response', async () => {
    setResponse(responseUnformatted);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
    });
    expect(result).toEqual({
      name: 'banana',
      color: 'yellow',
      calories: 105,
    });
  });

  // Top-level JSON arrays are supported, not only objects.
  it('should succeed for an array response', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
    });
    expect(result).toEqual([
      {
        name: 'banana',
        color: 'yellow',
        calories: 105,
      },
    ]);
  });

  // output: 'messages' returns the rendered template messages plus the
  // assistant reply; this pins the exact template rendering of
  // services/api/src/gpt/classify-fruits.md.
  it('should be able to return all messages', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
      output: 'messages',
    });
    expect(result).toEqual([
      {
        role: 'system',
        content: 'You are a helpful assistant.\n\nHere is a list of fruits:',
      },
      {
        role: 'user',
        content:
          'The following text describes someone eating a meal. Please determine which fruits were eaten and return a JSON array\n' +
          'containing objects with the following structure. Only output JSON, do not include any explanations.\n' +
          '\n' +
          '- "name" - The name of the fruit.\n' +
          '- "color" - The typical color of the fruit.\n' +
          '- "calories" - A rough estimate of the number of calories per serving. For example if the fruit is an "apple", provide\n' +
          '  the rough estimate of calories for a single apple.\n' +
          '\n' +
          'Text:\n' +
          '\n' +
          'I had a burger and some french fries for dinner. For dessert I had a banana.',
      },
      {
        role: 'assistant',
        content: '[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]',
      },
    ]);
  });

  // output: 'text' returns the assistant content verbatim, unparsed.
  it('should be able to output just the text', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
      output: 'text',
    });
    expect(result).toBe('[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]');
  });

  // output: 'raw' returns the full API response object untouched.
  it('should be able to output the raw response', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({
      file: 'classify-fruits',
      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
      output: 'raw',
    });
    expect(result).toEqual({
      id: 'chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1',
      object: 'chat.completion',
      created: 1719313006,
      model: 'gpt-4o-2024-05-13',
      choices: [
        {
          index: 0,
          message: {
            role: 'assistant',
            content: '[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]',
          },
          logprobs: null,
          finish_reason: 'stop',
        },
      ],
      usage: {
        prompt_tokens: 134,
        completion_tokens: 79,
        total_tokens: 213,
      },
      system_fingerprint: 'fp_3e7d703517',
    });
  });
});
94 changes: 94 additions & 0 deletions services/api/src/utils/gpt.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
const fs = require('fs/promises');
const path = require('path');
const OpenAI = require('openai');
const Mustache = require('mustache');
const config = require('@bedrockio/config');
const { memoize } = require('lodash');

// Shared client; the key comes from env config (OPENAI_API_KEY in .env).
const openai = new OpenAI({
  apiKey: config.get('OPENAI_API_KEY'),
});

// Directory holding the markdown prompt templates (services/api/src/gpt).
const TEMPLATE_DIR = path.join(__dirname, '../gpt');
// Splits a rendered template on "--- ROLE ---" dividers, capturing the role
// name and the message body up to the next divider or end of input.
const MESSAGES_REG = /(?:^|\n)-{3,}\s*(\w+)\s*-{3,}(.*?)(?=\n-{3,}|$)/gs;
// Grabs the outermost {...} or [...] span so JSON can be extracted from
// replies that wrap it in prose or ``` fences.
const JSON_REG = /([{[].+[}\]])/s;

const MODEL = 'gpt-4o';

/**
 * Renders a prompt template into chat messages and runs a completion.
 *
 * @param {Object} options - Template name (`file`), output mode (`output`),
 *   and any remaining keys used as template parameters.
 * @returns {Promise<*>} Shape depends on `options.output` (see runCompletion).
 */
async function prompt(options) {
  return runCompletion(await getMessages(options), options);
}

/**
 * Loads a markdown template, renders it with Mustache, and splits the result
 * on "--- ROLE ---" dividers into chat messages.
 *
 * @param {Object} options - `file` names the template; `output` is consumed
 *   elsewhere; all remaining keys become template parameters.
 * @returns {Promise<Array<{role: string, content: string}>>}
 */
async function getMessages(options) {
  const { file, output, ...params } = options;

  const template = await loadTemplate(file);
  const rendered = Mustache.render(template, transformParams(params));

  // Each divider match captures the role name and the body that follows.
  return Array.from(rendered.matchAll(MESSAGES_REG), ([, role, body]) => ({
    role: role.toLowerCase(),
    content: body.trim(),
  }));
}

/**
 * Normalizes template parameters for Mustache rendering: arrays become
 * markdown bullet lists and plain objects become pretty-printed JSON;
 * everything else passes through unchanged.
 *
 * @param {Object} params - Raw template parameters.
 * @returns {Object} A new object with the same keys and transformed values.
 */
function transformParams(params) {
  const result = {};
  for (let [key, value] of Object.entries(params)) {
    if (Array.isArray(value)) {
      // Render arrays as a markdown bullet list, one item per line.
      value = value
        .map((el) => {
          return `- ${el}`;
        })
        .join('\n');
    } else if (value !== null && typeof value === 'object') {
      // Pretty-print objects as JSON. The null guard matters because
      // typeof null === 'object'; without it, null parameters would inject
      // the literal text "null" into the prompt instead of rendering empty.
      value = JSON.stringify(value, null, 2);
    }
    result[key] = value;
  }
  return result;
}

/**
 * Sends chat messages to the completions endpoint and shapes the result.
 *
 * @param {Array<{role: string, content: string}>} messages - Chat messages.
 * @param {Object} options
 * @param {string} [options.output='json'] - One of:
 *   - 'json'     parse the first {...} or [...] span in the reply (default).
 *   - 'text'     the raw assistant message content.
 *   - 'messages' the input messages plus the assistant reply appended.
 *   - 'raw'      the full API response object.
 * @returns {Promise<*>} Shape depends on `options.output`.
 * @throws {Error} When 'json' output is requested but no parseable JSON is
 *   found, or when `output` is not a recognized mode.
 */
async function runCompletion(messages, options) {
  const { output = 'json' } = options;

  const response = await openai.chat.completions.create({
    model: MODEL,
    messages,
  });

  const choice = response.choices[0];
  const content = choice.message.content;

  if (output === 'raw') {
    return response;
  } else if (output === 'text') {
    return content;
  } else if (output === 'messages') {
    return [...messages, choice.message];
  } else if (output === 'json') {
    try {
      // The model may wrap JSON in prose or ``` fences; extract the
      // outermost {...} or [...] span before parsing.
      const match = content.match(JSON_REG);
      return JSON.parse(match[1]);
    } catch (error) {
      // Preserve the underlying match/parse failure for debugging.
      throw new Error('Unable to derive JSON object in response.', {
        cause: error,
      });
    }
  } else {
    // Previously an unrecognized output mode silently resolved to undefined.
    throw new Error(`Unknown output type "${output}".`);
  }
}

// Reads a prompt template from TEMPLATE_DIR, appending ".md" when the name
// has no extension. Results are memoized per template name so each file is
// read from disk at most once.
const loadTemplate = memoize(async (name) => {
  const file = name.endsWith('.md') ? name : `${name}.md`;
  try {
    return await fs.readFile(path.join(TEMPLATE_DIR, file), 'utf8');
  } catch (error) {
    // lodash memoize would otherwise cache the rejected promise, making a
    // transient read failure permanent for this key. Evict it so the next
    // call retries. The cache is keyed on the original argument.
    loadTemplate.cache.delete(name);
    throw error;
  }
});

// Public API: only `prompt` is exposed; the other functions are internal.
module.exports = {
  prompt,
};

0 comments on commit 38991f8

Please sign in to comment.