-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added utils for GPT classification using templates and standardizing …
…response. (#252)
- Loading branch information
1 parent
debb274
commit 38991f8
Showing
9 changed files
with
371 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
let mock; | ||
|
||
/**
 * Manual mock of the `openai` client for tests.
 *
 * Returns a plain object exposing only `chat.completions.create`. Because an
 * object is returned explicitly, this behaves the same whether it is invoked
 * as `new OpenAI(...)` or called directly. Constructor arguments are ignored.
 *
 * @returns {{ chat: { completions: { create: () => Promise<*> } } }}
 */
function OpenAI() {
  return {
    chat: {
      completions: {
        // Declared async so the mock hands back a promise, matching how the
        // code under test consumes it (it awaits the result). Resolves with
        // the payload most recently registered through `setResponse`.
        async create() {
          return mock;
        },
      },
    },
  };
}
|
||
/**
 * Registers the payload that the mocked `chat.completions.create` will
 * return on subsequent calls. The value is stored as-is (not copied).
 *
 * @param {*} data - Response to hand back from the mock.
 */
function setResponse(data) {
  mock = data;
}
|
||
OpenAI.setResponse = setResponse; | ||
|
||
module.exports = OpenAI; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
--- SYSTEM --- | ||
|
||
You are a helpful assistant. | ||
|
||
Here is a list of fruits: | ||
|
||
{{fruits}} | ||
|
||
--- USER --- | ||
|
||
The following text describes someone eating a meal. Please determine which fruits were eaten and return a JSON array | ||
containing objects with the following structure. Only output JSON, do not include any explanations. | ||
|
||
- "name" - The name of the fruit. | ||
- "color" - The typical color of the fruit. | ||
- "calories" - A rough estimate of the number of calories per serving. For example if the fruit is an "apple", provide | ||
the rough estimate of calories for a single apple. | ||
|
||
Text: | ||
|
||
{{text}} |
23 changes: 23 additions & 0 deletions
23
services/api/src/utils/__fixtures__/gpt-response-array-unformatted.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
{ | ||
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1", | ||
"object": "chat.completion", | ||
"created": 1719313006, | ||
"model": "gpt-4o-2024-05-13", | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": "[{\n\"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}]" | ||
}, | ||
"logprobs": null, | ||
"finish_reason": "stop" | ||
} | ||
], | ||
"usage": { | ||
"prompt_tokens": 134, | ||
"completion_tokens": 79, | ||
"total_tokens": 213 | ||
}, | ||
"system_fingerprint": "fp_3e7d703517" | ||
} |
23 changes: 23 additions & 0 deletions
23
services/api/src/utils/__fixtures__/gpt-response-formatted.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
{ | ||
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1", | ||
"object": "chat.completion", | ||
"created": 1719313006, | ||
"model": "gpt-4o-2024-05-13", | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": "```json\n{\n \"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}\n```" | ||
}, | ||
"logprobs": null, | ||
"finish_reason": "stop" | ||
} | ||
], | ||
"usage": { | ||
"prompt_tokens": 134, | ||
"completion_tokens": 79, | ||
"total_tokens": 213 | ||
}, | ||
"system_fingerprint": "fp_3e7d703517" | ||
} |
23 changes: 23 additions & 0 deletions
23
services/api/src/utils/__fixtures__/gpt-response-long.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
{ | ||
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1", | ||
"object": "chat.completion", | ||
"created": 1719313006, | ||
"model": "gpt-4o-2024-05-13", | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": "To determine the fruit eaten, we analyze the text provided. Here, the person ate a banana. We'll now create a JSON object including the name, typical color, and a rough estimate of calories per serving for a banana.\n\nHere is the JSON object:\n\n```json\n{\n \"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}\n```" | ||
}, | ||
"logprobs": null, | ||
"finish_reason": "stop" | ||
} | ||
], | ||
"usage": { | ||
"prompt_tokens": 134, | ||
"completion_tokens": 79, | ||
"total_tokens": 213 | ||
}, | ||
"system_fingerprint": "fp_3e7d703517" | ||
} |
23 changes: 23 additions & 0 deletions
23
services/api/src/utils/__fixtures__/gpt-response-unformatted.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
{ | ||
"id": "chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1", | ||
"object": "chat.completion", | ||
"created": 1719313006, | ||
"model": "gpt-4o-2024-05-13", | ||
"choices": [ | ||
{ | ||
"index": 0, | ||
"message": { | ||
"role": "assistant", | ||
"content": "{\n \"name\": \"banana\",\n \"color\": \"yellow\",\n \"calories\": 105\n}" | ||
}, | ||
"logprobs": null, | ||
"finish_reason": "stop" | ||
} | ||
], | ||
"usage": { | ||
"prompt_tokens": 134, | ||
"completion_tokens": 79, | ||
"total_tokens": 213 | ||
}, | ||
"system_fingerprint": "fp_3e7d703517" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
const { prompt } = require('../gpt'); | ||
const { setResponse } = require('openai'); | ||
|
||
const responseLong = require('../__fixtures__/gpt-response-long.json'); | ||
const responseFormatted = require('../__fixtures__/gpt-response-formatted.json'); | ||
const responseUnformatted = require('../__fixtures__/gpt-response-unformatted.json'); | ||
const responseArrayUnformatted = require('../__fixtures__/gpt-response-array-unformatted.json'); | ||
|
||
describe('prompt', () => {
  // Input shared by every case: the template to load and the meal description.
  const params = {
    file: 'classify-fruits',
    text: 'I had a burger and some french fries for dinner. For dessert I had a banana.',
  };

  // Parsed object expected from the single-object fixtures.
  const banana = {
    name: 'banana',
    color: 'yellow',
    calories: 105,
  };

  // Assistant content carried by the array fixture, exactly as returned.
  const arrayContent = '[{\n"name": "banana",\n "color": "yellow",\n "calories": 105\n}]';

  // Fixtures that all wrap the same JSON object in different amounts of prose
  // and code fencing.
  const objectCases = [
    ['should succeed for a long response', responseLong],
    ['should succeed for a formatted response', responseFormatted],
    ['should succeed for an unformatted response', responseUnformatted],
  ];

  for (const [title, response] of objectCases) {
    it(title, async () => {
      setResponse(response);
      const result = await prompt({ ...params });
      expect(result).toEqual(banana);
    });
  }

  it('should succeed for an array response', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({ ...params });
    expect(result).toEqual([banana]);
  });

  it('should be able to return all messages', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({ ...params, output: 'messages' });

    // The user message is the rendered template section with the input text
    // substituted at the end.
    const userContent = [
      'The following text describes someone eating a meal. Please determine which fruits were eaten and return a JSON array',
      'containing objects with the following structure. Only output JSON, do not include any explanations.',
      '',
      '- "name" - The name of the fruit.',
      '- "color" - The typical color of the fruit.',
      '- "calories" - A rough estimate of the number of calories per serving. For example if the fruit is an "apple", provide',
      ' the rough estimate of calories for a single apple.',
      '',
      'Text:',
      '',
      'I had a burger and some french fries for dinner. For dessert I had a banana.',
    ].join('\n');

    expect(result).toEqual([
      {
        role: 'system',
        content: 'You are a helpful assistant.\n\nHere is a list of fruits:',
      },
      {
        role: 'user',
        content: userContent,
      },
      {
        role: 'assistant',
        content: arrayContent,
      },
    ]);
  });

  it('should be able to output just the text', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({ ...params, output: 'text' });
    expect(result).toBe(arrayContent);
  });

  it('should be able to output the raw response', async () => {
    setResponse(responseArrayUnformatted);
    const result = await prompt({ ...params, output: 'raw' });
    expect(result).toEqual({
      id: 'chatcmpl-9dy8si0kRlF27OZiDtA4Y38u4lfO1',
      object: 'chat.completion',
      created: 1719313006,
      model: 'gpt-4o-2024-05-13',
      choices: [
        {
          index: 0,
          message: {
            role: 'assistant',
            content: arrayContent,
          },
          logprobs: null,
          finish_reason: 'stop',
        },
      ],
      usage: {
        prompt_tokens: 134,
        completion_tokens: 79,
        total_tokens: 213,
      },
      system_fingerprint: 'fp_3e7d703517',
    });
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
const fs = require('fs/promises'); | ||
const path = require('path'); | ||
const OpenAI = require('openai'); | ||
const Mustache = require('mustache'); | ||
const config = require('@bedrockio/config'); | ||
const { memoize } = require('lodash'); | ||
|
||
// Model requested for every chat completion.
const MODEL = 'gpt-4o';

// Directory the prompt templates are read from, relative to this file.
const TEMPLATE_DIR = path.join(__dirname, '../gpt');

// Splits a rendered template into sections. A header of the form
// `--- ROLE ---` (three or more dashes either side) captures the role as
// group 1; group 2 lazily captures everything up to the next line that
// starts with three dashes, or the end of the input.
const MESSAGES_REG = /(?:^|\n)-{3,}\s*(\w+)\s*-{3,}(.*?)(?=\n-{3,}|$)/gs;

// Greedily captures from the first `{` or `[` through the last `}` or `]`,
// so a JSON payload can be lifted out of surrounding prose or code fences.
const JSON_REG = /([{[].+[}\]])/s;

// Client shared by all requests in this module; the key comes from config.
const openai = new OpenAI({
  apiKey: config.get('OPENAI_API_KEY'),
});
|
||
/**
 * Builds the chat messages for a template and runs a completion on them.
 *
 * @param {Object} options - Forwarded unchanged to both `getMessages` and
 *   `runCompletion`.
 * @returns {Promise<*>} Whatever `runCompletion` resolves with.
 */
async function prompt(options) {
  const conversation = await getMessages(options);
  const result = await runCompletion(conversation, options);
  return result;
}
|
||
/**
 * Loads a prompt template, renders it with the supplied parameters and
 * splits the result into chat messages.
 *
 * @param {Object} options
 * @param {string} options.file - Template name handed to `loadTemplate`.
 * @param {string} [options.output] - Ignored here; pulled out only so that it
 *   is not exposed to the template as a parameter.
 * @param {...*} options.rest - Every other key is a template parameter.
 * @returns {Promise<Array<{role: string, content: string}>>} One entry per
 *   `--- ROLE ---` section, with the role lower-cased and content trimmed.
 */
async function getMessages(options) {
  const { file, output, ...params } = options;
  const template = await loadTemplate(file);

  // Mustache HTML-escapes `{{var}}` by default, which would send entities
  // such as `&quot;` or `&#39;` to the model in place of the original text.
  // The output here is a plain-text prompt, so escaping is disabled for this
  // render only (the per-call config object needs mustache.js 4.1 or newer).
  const rendered = Mustache.render(template, transformParams(params), {}, {
    escape: (value) => String(value),
  });

  return Array.from(rendered.matchAll(MESSAGES_REG), ([, role, content]) => ({
    role: role.toLowerCase(),
    content: content.trim(),
  }));
}
|
||
/**
 * Converts template parameters into values that render readably in a prompt.
 *
 * - Arrays become a dash-prefixed list, one element per line.
 * - Other non-null objects become pretty-printed JSON (2-space indent).
 * - Everything else, including `null` and `undefined`, passes through as-is.
 *
 * @param {Object} params - Raw template parameters.
 * @returns {Object} A new object with the same keys and converted values.
 */
function transformParams(params) {
  const result = {};
  for (const [key, value] of Object.entries(params)) {
    if (Array.isArray(value)) {
      result[key] = value.map((el) => `- ${el}`).join('\n');
    } else if (value !== null && typeof value === 'object') {
      // `typeof null` is 'object'; without the explicit guard a null
      // parameter would be stringified to the literal text "null".
      result[key] = JSON.stringify(value, null, 2);
    } else {
      result[key] = value;
    }
  }
  return result;
}
|
||
/**
 * Sends the messages to the chat completions API and shapes the result.
 *
 * @param {Array<{role: string, content: string}>} messages - Conversation to send.
 * @param {Object} options
 * @param {'json'|'text'|'messages'|'raw'} [options.output='json'] - Result shape:
 *   - `raw`: the API response object, untouched.
 *   - `text`: the content of the first choice's message.
 *   - `messages`: the input messages followed by the first choice's message.
 *   - `json`: the JSON value parsed out of the first choice's content.
 * @returns {Promise<*>} The value described by `output`.
 * @throws {Error} If `output` is not one of the supported values, or if no
 *   parseable JSON is found in `json` mode (original failure kept as `cause`).
 */
async function runCompletion(messages, options) {
  const { output = 'json' } = options;

  // Fail on an unknown mode before making the request; otherwise the call
  // would be made and the result silently discarded.
  if (!['raw', 'text', 'messages', 'json'].includes(output)) {
    throw new Error(`Unsupported output type "${output}".`);
  }

  const response = await openai.chat.completions.create({
    model: MODEL,
    messages,
  });

  if (output === 'raw') {
    return response;
  }

  const { message } = response.choices[0];

  if (output === 'text') {
    return message.content;
  }

  if (output === 'messages') {
    return [...messages, message];
  }

  try {
    // The model may wrap the JSON in prose or code fences, so parse only the
    // bracketed span. A missing match throws here too and is reported below.
    const match = message.content.match(JSON_REG);
    return JSON.parse(match[1]);
  } catch (error) {
    throw new Error('Unable to derive JSON object in response.', { cause: error });
  }
}
|
||
/**
 * Reads a prompt template from `TEMPLATE_DIR`, caching the result per `file`
 * argument so each template is read from disk at most once.
 *
 * @param {string} file - Template name; `.md` is appended when absent.
 * @returns {Promise<string>} The template source as UTF-8 text.
 */
const loadTemplate = memoize(async (file) => {
  const filename = file.endsWith('.md') ? file : `${file}.md`;
  try {
    return await fs.readFile(path.join(TEMPLATE_DIR, filename), 'utf8');
  } catch (error) {
    // memoize stores the returned promise even when it rejects. Evict the
    // entry so one failed read does not fail every later call for this
    // template; the cache key is the original first argument.
    loadTemplate.cache.delete(file);
    throw error;
  }
});
|
||
module.exports = { | ||
prompt, | ||
}; |