From 38991f82330103f8123f6fafcb4835d2f3160b27 Mon Sep 17 00:00:00 2001
From: Andrew Plummer
Date: Tue, 10 Dec 2024 16:11:52 -0500
Subject: [PATCH] Added utils for GPT classification using templates and
 standardizing responses. (#252)

---
 services/api/.env                             |   5 +-
 services/api/__mocks__/openai.js              |  21 +++
 services/api/src/gpt/classify-fruits.md       |  21 +++
 .../gpt-response-array-unformatted.json       |  23 +++
 .../__fixtures__/gpt-response-formatted.json  |  23 +++
 .../utils/__fixtures__/gpt-response-long.json |  23 +++
 .../gpt-response-unformatted.json             |  23 +++
 services/api/src/utils/__tests__/gpt.js       | 139 ++++++++++++++++++
 services/api/src/utils/gpt.js                 |  94 ++++++++++++
 9 files changed, 371 insertions(+), 1 deletion(-)
 create mode 100644 services/api/__mocks__/openai.js
 create mode 100644 services/api/src/gpt/classify-fruits.md
 create mode 100644 services/api/src/utils/__fixtures__/gpt-response-array-unformatted.json
 create mode 100644 services/api/src/utils/__fixtures__/gpt-response-formatted.json
 create mode 100644 services/api/src/utils/__fixtures__/gpt-response-long.json
 create mode 100644 services/api/src/utils/__fixtures__/gpt-response-unformatted.json
 create mode 100644 services/api/src/utils/__tests__/gpt.js
 create mode 100644 services/api/src/utils/gpt.js

diff --git a/services/api/.env b/services/api/.env
index c689f8f0..4bd39930 100644
--- a/services/api/.env
+++ b/services/api/.env
@@ -66,4 +66,7 @@ TWILIO_AUTH_TOKEN=AC21717619d8cf45502cc2f7cd7aee139d
 GOOGLE_CLIENT_ID=
 
 # Sign in with Apple
-APPLE_SERVICE_ID=
\ No newline at end of file
+APPLE_SERVICE_ID=
+
+# OpenAI
+OPENAI_API_KEY=
\ No newline at end of file
diff --git a/services/api/__mocks__/openai.js b/services/api/__mocks__/openai.js
new file mode 100644
index 00000000..65d9265f
--- /dev/null
+++ b/services/api/__mocks__/openai.js
@@ -0,0 +1,21 @@
+let mock;
+
+function OpenAI() {
+  return {
+    chat: {
+      completions: {
+        create() {
+          return mock;
+        },
+      },
+    },
+  };
+}
+
+function setResponse(data) {
+  mock = data;
+}
+
+OpenAI.setResponse = setResponse;
+
+module.exports = OpenAI;
diff --git a/services/api/src/gpt/classify-fruits.md b/services/api/src/gpt/classify-fruits.md
new file mode 100644
index 00000000..c599cd91
--- /dev/null
+++ b/services/api/src/gpt/classify-fruits.md
@@ -0,0 +1,21 @@
+--- SYSTEM ---
+
+You are a helpful assistant.
+
+Here is a list of fruits:
+
+{{fruits}}
+
+--- USER ---
+
+The following text describes someone eating a meal. Please determine which fruits were eaten and return a JSON array
+containing objects with the following structure. Only output JSON; do not include any explanations.
+
+- "name" - The name of the fruit.
+- "color" - The typical color of the fruit.
+- "calories" - A rough estimate of the number of calories per serving. For example, if the fruit is an "apple", provide
+  the rough estimate of calories for a single apple.
+
+Text:
+
+{{text}}
diff --git a/services/api/src/utils/__fixtures__/gpt-response-array-unformatted.json b/services/api/src/utils/__fixtures__/gpt-response-array-unformatted.json
new file mode 100644
index 00000000..d04a8d0b
--- /dev/null
+++ b/services/api/src/utils/__fixtures__/gpt-response-array-unformatted.json
@@ -0,0 +1,23 @@
+{
+  "id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
+  "object": "chat.completion",
+  "created": 1719313006,
+  "model": "gpt-4o-2024-05-13",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "[{\n\"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}]"
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 134,
+    "completion_tokens": 79,
+    "total_tokens": 213
+  },
+  "system_fingerprint": "fp_3e7d703517"
+}
diff --git a/services/api/src/utils/__fixtures__/gpt-response-formatted.json b/services/api/src/utils/__fixtures__/gpt-response-formatted.json
new file mode 100644
index 00000000..b984d51d
--- /dev/null
+++ b/services/api/src/utils/__fixtures__/gpt-response-formatted.json
@@ -0,0 +1,23 @@
+{
+  "id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
+  "object": "chat.completion",
+  "created": 1719313006,
+  "model": "gpt-4o-2024-05-13",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "```json\n{\n  \"name\": \"banana\",\n  \"color\": \"yellow\",\n  \"calories\": 105\n}\n```"
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 134,
+    "completion_tokens": 79,
+    "total_tokens": 213
+  },
+  "system_fingerprint": "fp_3e7d703517"
+}
diff --git a/services/api/src/utils/__fixtures__/gpt-response-long.json b/services/api/src/utils/__fixtures__/gpt-response-long.json
new file mode 100644
index 00000000..23d13bc3
--- /dev/null
+++ b/services/api/src/utils/__fixtures__/gpt-response-long.json
@@ -0,0 +1,23 @@
+{
+  "id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
+  "object": "chat.completion",
+  "created": 1719313006,
+  "model": "gpt-4o-2024-05-13",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "To determine the fruit eaten, we analyze the text provided. Here, the person ate a banana. We'll now create a JSON object including the name, typical color, and a rough estimate of calories per serving for a banana.\n\nHere is the JSON object:\n\n```json\n{\n  \"name\": \"banana\",\n  \"color\": \"yellow\",\n  \"calories\": 105\n}\n```"
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 134,
+    "completion_tokens": 79,
+    "total_tokens": 213
+  },
+  "system_fingerprint": "fp_3e7d703517"
+}
diff --git a/services/api/src/utils/__fixtures__/gpt-response-unformatted.json b/services/api/src/utils/__fixtures__/gpt-response-unformatted.json
new file mode 100644
index 00000000..9d1c747e
--- /dev/null
+++ b/services/api/src/utils/__fixtures__/gpt-response-unformatted.json
@@ -0,0 +1,23 @@
+{
+  "id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1",
+  "object": "chat.completion",
+  "created": 1719313006,
+  "model": "gpt-4o-2024-05-13",
+  "choices": [
+    {
+      "index": 0,
+      "message": {
+        "role": "assistant",
+        "content": "{\n  \"name\": \"banana\",\n  \"color\": \"yellow\",\n  \"calories\": 105\n}"
+      },
+      "logprobs": null,
+      "finish_reason": "stop"
+    }
+  ],
+  "usage": {
+    "prompt_tokens": 134,
+    "completion_tokens": 79,
+    "total_tokens": 213
+  },
+  "system_fingerprint": "fp_3e7d703517"
+}
diff --git a/services/api/src/utils/__tests__/gpt.js b/services/api/src/utils/__tests__/gpt.js
new file mode 100644
index 00000000..813ae499
--- /dev/null
+++ b/services/api/src/utils/__tests__/gpt.js
@@ -0,0 +1,139 @@
+const { prompt } = require('../gpt');
+const { setResponse } = require('openai');
+
+const responseLong = require('../__fixtures__/gpt-response-long.json');
+const responseFormatted = require('../__fixtures__/gpt-response-formatted.json');
+const responseUnformatted = require('../__fixtures__/gpt-response-unformatted.json');
+const responseArrayUnformatted = require('../__fixtures__/gpt-response-array-unformatted.json');
+
+describe('prompt', () => {
+  it('should succeed for a long response', async () => {
+    setResponse(responseLong);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+    });
+    expect(result).toEqual({
+      name: 'banana',
+      color: 'yellow',
+      calories: 105,
+    });
+  });
+
+  it('should succeed for a formatted response', async () => {
+    setResponse(responseFormatted);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+    });
+    expect(result).toEqual({
+      name: 'banana',
+      color: 'yellow',
+      calories: 105,
+    });
+  });
+
+  it('should succeed for an unformatted response', async () => {
+    setResponse(responseUnformatted);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+    });
+    expect(result).toEqual({
+      name: 'banana',
+      color: 'yellow',
+      calories: 105,
+    });
+  });
+
+  it('should succeed for an array response', async () => {
+    setResponse(responseArrayUnformatted);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+    });
+    expect(result).toEqual([
+      {
+        name: 'banana',
+        color: 'yellow',
+        calories: 105,
+      },
+    ]);
+  });
+
+  it('should be able to return all messages', async () => {
+    setResponse(responseArrayUnformatted);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+      output: 'messages',
+    });
+    expect(result).toEqual([
+      {
+        role: 'system',
+        content: 'You are a helpful assistant.\n\nHere is a list of fruits:',
+      },
+      {
+        role: 'user',
+        content:
+          'The following text describes someone eating a meal. Please determine which fruits were eaten and return a JSON array\n' +
+          'containing objects with the following structure. Only output JSON; do not include any explanations.\n' +
+          '\n' +
+          '- "name" - The name of the fruit.\n' +
+          '- "color" - The typical color of the fruit.\n' +
+          '- "calories" - A rough estimate of the number of calories per serving. For example, if the fruit is an "apple", provide\n' +
+          '  the rough estimate of calories for a single apple.\n' +
+          '\n' +
+          'Text:\n' +
+          '\n' +
+          'I had a burger and some french fries for dinner. For dessert I had a banana.',
+      },
+      {
+        role: 'assistant',
+        content: '[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]',
+      },
+    ]);
+  });
+
+  it('should be able to output just the text', async () => {
+    setResponse(responseArrayUnformatted);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+      output: 'text',
+    });
+    expect(result).toBe('[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]');
+  });
+
+  it('should be able to output the raw response', async () => {
+    setResponse(responseArrayUnformatted);
+    const result = await prompt({
+      file: 'classify-fruits',
+      text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
+      output: 'raw',
+    });
+    expect(result).toEqual({
+      id: 'chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1',
+      object: 'chat.completion',
+      created: 1719313006,
+      model: 'gpt-4o-2024-05-13',
+      choices: [
+        {
+          index: 0,
+          message: {
+            role: 'assistant',
+            content: '[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]',
+          },
+          logprobs: null,
+          finish_reason: 'stop',
+        },
+      ],
+      usage: {
+        prompt_tokens: 134,
+        completion_tokens: 79,
+        total_tokens: 213,
+      },
+      system_fingerprint: 'fp_3e7d703517',
+    });
+  });
+});
diff --git a/services/api/src/utils/gpt.js b/services/api/src/utils/gpt.js
new file mode 100644
index 00000000..525d5fc2
--- /dev/null
+++ b/services/api/src/utils/gpt.js
@@ -0,0 +1,94 @@
+const fs = require('fs/promises');
+const path = require('path');
+const OpenAI = require('openai');
+const Mustache = require('mustache');
+const config = require('@bedrockio/config');
+const { memoize } = require('lodash');
+
+const openai = new OpenAI({
+  apiKey: config.get('OPENAI_API_KEY'),
+});
+
+const TEMPLATE_DIR = path.join(__dirname, '../gpt');
+const MESSAGES_REG = /(?:^|\n)-{3,}\s*(\w+)\s*-{3,}(.*?)(?=\n-{3,}|$)/gs;
+const JSON_REG = /([{[].+[}\]])/s;
+
+const MODEL = 'gpt-4o';
+
+async function prompt(options) {
+  const messages = await getMessages(options);
+  return await runCompletion(messages, options);
+}
+
+async function getMessages(options) {
+  const { file, output, ...rest } = options;
+  const template = await loadTemplate(file);
+
+  const raw = Mustache.render(template, transformParams(rest));
+
+  const messages = [];
+  for (let match of raw.matchAll(MESSAGES_REG)) {
+    const [, role, content] = match;
+    messages.push({
+      role: role.toLowerCase(),
+      content: content.trim(),
+    });
+  }
+
+  return messages;
+}
+
+function transformParams(params) {
+  const result = {};
+  for (let [key, value] of Object.entries(params)) {
+    if (Array.isArray(value)) {
+      value = value
+        .map((el) => {
+          return `- ${el}`;
+        })
+        .join('\n');
+    } else if (typeof value === 'object') {
+      value = JSON.stringify(value, null, 2);
+    }
+    result[key] = value;
+  }
+  return result;
+}
+
+async function runCompletion(messages, options) {
+  const { output = 'json' } = options;
+
+  const response = await openai.chat.completions.create({
+    model: MODEL,
+    messages,
+  });
+
+  let content = response.choices[0].message.content;
+
+  if (output === 'raw') {
+    return response;
+  } else if (output === 'text') {
+    return content;
+  } else if (output === 'messages') {
+    const { message } = response.choices[0];
+    return [...messages, message];
+  } else if (output === 'json') {
+    try {
+      const match = content.match(JSON_REG);
+      return JSON.parse(match[1]);
+    } catch (error) {
+      throw new Error('Unable to derive JSON object from response.');
+    }
+  }
+}
+
+const loadTemplate = memoize(async (file) => {
+  if (!file.endsWith('.md')) {
+    file += '.md';
+  }
+  return await fs.readFile(path.join(TEMPLATE_DIR, file), 'utf8');
+});
+
+module.exports = {
+  prompt,
+};
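
Usage note: a minimal sketch of driving the new `prompt` util with the
`classify-fruits` template added above. The require path and the example
`fruits`/`text` values are illustrative assumptions, not part of the patch:

    // e.g. from another module under services/api/src
    const { prompt } = require('./utils/gpt');

    async function classifyMeal() {
      // Loads src/gpt/classify-fruits.md, renders it with Mustache (the
      // array param becomes "- apple\n- banana\n- cherry" via
      // transformParams), splits the result into SYSTEM/USER messages,
      // and sends them to the chat completions API (gpt-4o).
      const result = await prompt({
        file: 'classify-fruits',
        fruits: ['apple', 'banana', 'cherry'],
        text: 'I had a banana with breakfast.',
      });
      // With the default output of 'json', the JSON found in the reply is
      // extracted and parsed, e.g.
      // [{ name: 'banana', color: 'yellow', calories: 105 }]
      console.log(result);
    }

    classifyMeal();

Passing `output: 'text' | 'messages' | 'raw'` instead returns the raw
completion text, the full message history, or the unmodified API response.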