-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: memory layer + stepping execution (#15)
Co-authored-by: Mike Grabowski <[email protected]>
- Loading branch information
Showing
11 changed files
with
234 additions
and
50 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -173,3 +173,6 @@ dist | |
|
||
# Finder (MacOS) folder config | ||
.DS_Store | ||
|
||
last-run-id.txt | ||
context-*.json |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
/** | ||
* Example borrowed from CrewAI. | ||
*/ | ||
import { agent } from '@dead-simple-ai-agent/framework/agent' | ||
import { iterate } from '@dead-simple-ai-agent/framework/teamwork' | ||
import { tool } from '@dead-simple-ai-agent/framework/tool' | ||
import { workflow } from '@dead-simple-ai-agent/framework/workflow' | ||
import { promises as fs } from 'fs' | ||
import { tmpdir } from 'os' | ||
import { join } from 'path' | ||
import { z } from 'zod' | ||
|
||
import { lookupWikipedia } from './tools.js' | ||
|
||
async function requestUserInput(prompt: string): Promise<string> { | ||
return new Promise((resolve) => { | ||
console.log(prompt) | ||
process.stdin.once('data', (data) => { | ||
resolve(data.toString().trim()) | ||
}) | ||
}) | ||
} | ||
export const askPatient = tool({ | ||
description: 'Tool for asking patient a question', | ||
parameters: z.object({ | ||
query: z.string().describe('The question to ask the patient'), | ||
}), | ||
execute: ({ query }): Promise<string> => { | ||
return requestUserInput(query) | ||
}, | ||
}) | ||
|
||
const nurse = agent({ | ||
role: 'Nurse,doctor assistant', | ||
description: ` | ||
You are skille nurse / doctor assistant. | ||
You role is to cooperate with reporter to create a pre-visit note for a patient that is about to come for a visit. | ||
Ask user questions about the patient's health and symptoms. | ||
Ask one question at time up to 5 questions. | ||
`, | ||
tools: { | ||
ask_question: askPatient, | ||
}, | ||
}) | ||
|
||
const reporter = agent({ | ||
role: 'Reporter', | ||
description: ` | ||
You are skilled at preparing great looking markdown reports. | ||
Prepare a report for a patient that is about to come for a visit. | ||
Add info about the patient's health and symptoms. | ||
If something is not clear use Wikipedia to check. | ||
`, | ||
tools: { | ||
lookupWikipedia, | ||
}, | ||
}) | ||
|
||
const preVisitNoteWorkflow = workflow({ | ||
members: [nurse, reporter], | ||
description: ` | ||
Create a pre-visit note for a patient that is about to come for a visit. | ||
The note should include the patient's health and symptoms. | ||
Include: | ||
- symptoms, | ||
- health issues, | ||
- medications, | ||
- allergies, | ||
- surgeries | ||
Never ask fo: | ||
- personal data, | ||
- sensitive data, | ||
- any data that can be used to identify the patient. | ||
`, | ||
output: ` | ||
A markdown report for the patient's pre-visit note. | ||
`, | ||
}) | ||
|
||
const tmpDir = tmpdir() | ||
const dbPath = join(tmpDir, 'stepping_survey_workflow_db.json') | ||
|
||
if (await fs.exists(dbPath)) { | ||
try { | ||
const messages = JSON.parse(await fs.readFile(dbPath, 'utf-8')) | ||
preVisitNoteWorkflow.messages.push(...messages) | ||
|
||
console.log('🛟 Loaded workflow from', dbPath) | ||
} catch (error) { | ||
console.log(`🚨Error while loading workflow from ${dbPath}. Starting new workflow.`) | ||
} | ||
} | ||
|
||
const result = await iterate(preVisitNoteWorkflow) | ||
|
||
console.log(result) | ||
|
||
await fs.writeFile(dbPath, JSON.stringify(result.messages, null, 2), 'utf-8') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,5 +25,8 @@ | |
}, | ||
"trustedDependencies": [ | ||
"core-js" | ||
] | ||
], | ||
"dependencies": { | ||
"nanoid": "^5.0.9" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,63 +1,84 @@ | ||
import s from 'dedent' | ||
|
||
import { executeTaskWithAgent } from './executor.js' | ||
import { getNextTask } from './supervisor/nextTask.js' | ||
import { selectAgent } from './supervisor/selectAgent.js' | ||
import { Message } from './types.js' | ||
import { Workflow } from './workflow.js' | ||
|
||
async function execute(workflow: Workflow, messages: Message[]): Promise<string> { | ||
// eslint-disable-next-line no-constant-condition | ||
const task = await getNextTask(workflow.provider, messages) | ||
export async function iterate(workflow: Workflow): Promise<Workflow> { | ||
const { messages, provider, members } = workflow | ||
|
||
const task = await getNextTask(provider, messages) | ||
if (!task) { | ||
return messages.at(-1)!.content as string // end of the recursion | ||
return { | ||
...workflow, | ||
messages, | ||
status: 'finished', | ||
} | ||
} | ||
|
||
if (workflow.maxIterations && messages.length > workflow.maxIterations) { | ||
console.debug('Max iterations exceeded ', workflow.maxIterations) | ||
return messages.at(-1)!.content as string | ||
// tbd: implement `final answer` flow to generate output message | ||
if (messages.length > workflow.maxIterations) { | ||
return { | ||
...workflow, | ||
messages, | ||
status: 'interrupted', | ||
} | ||
} | ||
|
||
// tbd: get rid of console.logs, use telemetry instead | ||
console.log('🚀 Next task:', task) | ||
|
||
messages.push({ | ||
role: 'user', | ||
content: task, | ||
}) | ||
|
||
// tbd: this throws, handle it | ||
const selectedAgent = await selectAgent(workflow.provider, task, workflow.members) | ||
const selectedAgent = await selectAgent(provider, task, members) | ||
console.log('🚀 Selected agent:', selectedAgent.role) | ||
|
||
// tbd: this should just be a try/catch | ||
// tbd: do not return string, but more information or keep memory in agent | ||
const agentRequest: Message[] = [ | ||
...messages, | ||
{ | ||
role: 'user', | ||
content: task, | ||
}, | ||
] | ||
|
||
try { | ||
const result = await executeTaskWithAgent(selectedAgent, messages, workflow.members) | ||
messages.push({ | ||
role: 'assistant', | ||
content: result, | ||
}) | ||
const result = await executeTaskWithAgent(selectedAgent, agentRequest, members) | ||
return { | ||
...workflow, | ||
messages: [ | ||
...agentRequest, | ||
{ | ||
role: 'assistant', | ||
content: result, | ||
}, | ||
], | ||
status: 'running', | ||
} | ||
} catch (error) { | ||
console.log('🚀 Task error:', error) | ||
messages.push({ | ||
role: 'assistant', | ||
content: error instanceof Error ? error.message : 'Unknown error', | ||
}) | ||
return { | ||
...workflow, | ||
messages: [ | ||
...agentRequest, | ||
{ | ||
role: 'assistant', | ||
content: error instanceof Error ? error.message : 'Unknown error', | ||
}, | ||
], | ||
status: 'failed', | ||
} | ||
} | ||
|
||
return execute(workflow, messages) | ||
} | ||
|
||
export async function teamwork(workflow: Workflow): Promise<string> { | ||
const messages = [ | ||
{ | ||
role: 'assistant' as const, | ||
content: s` | ||
Here is description of the workflow and expected output by the user: | ||
<workflow>${workflow.description}</workflow> | ||
<output>${workflow.output}</output> | ||
`, | ||
}, | ||
] | ||
return execute(workflow, messages) | ||
const result = await iterate(workflow) | ||
|
||
if (result.status === 'running') { | ||
return teamwork(result) | ||
} | ||
|
||
if (result.status === 'finished') { | ||
return result.messages.at(-1)!.content as string | ||
} | ||
|
||
// tbd: recover from errors | ||
// tbd: request final answer if took too long | ||
throw new Error('Workflow failed. This is not implemented yet.') | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
{ | ||
"name": "@dead-simple-ai-agent/memory-lowdb", | ||
"version": "0.0.1", | ||
"description": "A dead simple AI agent framework", | ||
"author": "Piotr Karwatka <[email protected]>", | ||
"exports": { | ||
".": { | ||
"bun": "./src/index.ts" | ||
} | ||
}, | ||
"type": "module", | ||
"dependencies": { | ||
"dedent": "^1.5.3", | ||
"openai": "^4.76.0", | ||
"zod": "^3.23.8", | ||
"lowdb": "^7.0.1", | ||
"@dead-simple-ai-agent/framework": "^0.0.1" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import { Context } from '@dead-simple-ai-agent/framework/executor' | ||
import { Low } from 'lowdb' | ||
import { JSONFile } from 'lowdb/node' | ||
|
||
export async function save(context: Context) { | ||
const db = new Low<Context>(new JSONFile(`context-${context.id}.json`), context) | ||
await db.write() | ||
} | ||
|
||
export async function load(context: Context): Promise<Context> { | ||
const db = new Low<Context>(new JSONFile(`context-${context.id}.json`), {} as Context) | ||
await db.read() | ||
return { ...db.data, ...context } // because team members are not serializable | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
{ | ||
"extends": "../../tsconfig.json" | ||
} |