Skip to content

Commit

Permalink
feat: add dall-e adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelass committed May 2, 2023
1 parent 84322a4 commit 9c2f563
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .eslintrc.json
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
"jest": true,
"node": true
},
"globals": {
"BufferEncoding": "readonly"
},
"plugins": ["@typescript-eslint", "unicorn", "unused-imports", "import", "jest", "prettier"],
"ignorePatterns": ["*.d.ts"],
"rules": {
Expand Down
32 changes: 16 additions & 16 deletions examples/book.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import { Agent } from "../src/index.js";
import { GPTModelAdapter } from "../src/openai/index.js";
import type { GPT3Options } from "../src/openai/types.js";
import { DallEModelAdapter, GPTModelAdapter } from "../src/openai/index.js";
import type { DallEOptions, GPT3Options } from "../src/openai/types.js";
import { createFileWriter, FSAdapter } from "../src/store/index.js";
import type { StoreAdapter } from "../src/store/types.js";
import type { ModelMessage } from "../src/types.js";
import type { AgentOptions } from "../src/types.js";
import type { ModelAdapter } from "../src/types.js";
import { createInstruction, sprint } from "../src/utils.js";

import { openai } from "./config.js";

const dir = "out/book";
const store = new FSAdapter(dir);
const fileWriter = createFileWriter(dir);
const imageWriter = createFileWriter(dir, "base64");
const book: ModelMessage & { title: string } = {
title: "AGI can solve all problems",
title: "Mysteries of the pyramids",
};

interface AuthorData extends ModelMessage {
Expand Down Expand Up @@ -42,7 +45,7 @@ const author = new Agent(
historySize: 1,
systemInstruction: createInstruction(
"Scientific Author",
"Write a short story with inline images in markdown. Add a descriptive prompt for each images to be created",
"Write a brief story with inline images in markdown. Add a descriptive prompt for each images to be created",
{
images: [{ path: "string", prompt: "string" }],
files: [{ path: "string", content: "markdown" }],
Expand All @@ -55,24 +58,21 @@ const author = new Agent(
options
);

/*
// Future API
const illustrator = new Agent(
new DallEModelAdapter<DallEOptions>({
size: "1024x1024",
n: 1,
systemInstruction: createInstruction("Illustrator", "Create illustrations for the chapter.", {
files: [{ path: "string", content: "string" }],
}),
}),
new DallEModelAdapter<DallEOptions>(
{
size: "256x256",
n: 1,
},
openai
),
store,
[fileWriter]
{ tools: [imageWriter] }
);
*/

try {
const messageId = await store.set(book);
await sprint(messageId, [author]);
await sprint<ModelAdapter<ModelMessage>, StoreAdapter>(messageId, [author, illustrator]);
console.log("Done");
} catch (error) {
console.error("Error:", error);
Expand Down
41 changes: 40 additions & 1 deletion src/openai/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { ChatCompletionRequestMessage, OpenAIApi } from "openai";
import type { ModelAdapter, ModelMessage } from "../types.js";
import { extractCode } from "../utils.js";

import type { GPTOptions } from "./types.js";
import type { DallEOptions, GPTOptions, ImageMessage } from "./types.js";

/**
* Represents a GPT model adapter that can assign tasks and move to the next task.
Expand Down Expand Up @@ -79,3 +79,42 @@ export class GPTModelAdapter<Options extends GPTOptions> implements ModelAdapter
}
}
}

export class DallEModelAdapter<Options extends DallEOptions> implements ModelAdapter<ModelMessage> {
#options: Options;
#openai: OpenAIApi;
/**
* Creates an instance of the DallEModelAdapter class.
*
* @param {Options} options - The DALL-E model options.
* @param {OpenAIApi} openai - A configured openai API instance.
*/
constructor(options: Options, openai: OpenAIApi) {
this.#options = options;
this.#openai = openai;
}

async assign(task: ImageMessage): Promise<ModelMessage> {
try {
const files = await Promise.all(
task.images.map(async image => {
const response = await this.#openai.createImage({
...this.#options,
prompt: image.prompt,
// eslint-disable-next-line camelcase
response_format: "b64_json",
});
const base64 = response.data.data[0].b64_json;
const content = base64.replace(/^data:image\/\w+;base64,/, "");
return { path: image.path, content };
})
);

return {
files,
};
} catch (error) {
throw new Error(`Error assigning task in DallEModelAdapter: ${error.message}`);
}
}
}
10 changes: 10 additions & 0 deletions src/openai/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import type { CreateImageRequest } from "openai/api.js";
import type { Except } from "type-fest";

import type { ReasonableTemperature } from "../types.js";
import type { ModelMessage } from "../types.js";

/**
* Represents options for the GPT model.
Expand Down Expand Up @@ -43,3 +47,9 @@ export interface GPT4Options extends GPTOptions {
model: "gpt-4";
historySize: 1 | 2 | 3 | 4;
}

export interface ImageMessage extends ModelMessage {
images: [{ path: string; prompt: string }];
}

export type DallEOptions = Except<CreateImageRequest, "prompt" | "response_format">;
5 changes: 3 additions & 2 deletions src/store/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,16 @@ export class FSAdapter implements StoreAdapter {
* Creates a file writer tool for writing output files.
*
* @param {string} dir - The directory where the output files should be written.
* @param {BufferEncoding} [encoding="utf-8"] - the encoding that should vbe used when writing files
* @returns {Tool} - The file writer tool instance.
*/
export function createFileWriter(dir: string): Tool {
export function createFileWriter(dir: string, encoding: BufferEncoding = "utf-8"): Tool {
return {
prop: "files",
async run(message: ModelMessage) {
await Promise.all(
message.files.map(async file =>
writeFile(path.join(dir, "output", file.path), file.content)
writeFile(path.join(dir, "output", file.path), file.content, encoding)
)
);
},
Expand Down
17 changes: 11 additions & 6 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,14 @@ export async function exists(pathLike: string) {
* @async
* @param {string} filePath - The path to the file to be written.
* @param {string} content - The content to be written to the file.
* @param {BufferEncoding} [encoding="utf-8"] - the encoding that should vbe used when writing files
* @returns {Promise<void>} - Resolves when the file is successfully written, otherwise throws an error.
*/
export async function writeFile(filePath: string, content: string): Promise<void> {
export async function writeFile(
filePath: string,
content: string,
encoding: BufferEncoding = "utf8"
): Promise<void> {
try {
const { dir } = path.parse(filePath);

Expand All @@ -66,7 +71,7 @@ export async function writeFile(filePath: string, content: string): Promise<void
}

// Write the content to the file
await fs.writeFile(filePath, content);
await fs.writeFile(filePath, content, { encoding });
} catch (error) {
throw new Error(`Error writing file at path '${filePath}': ${error.message}`);
}
Expand Down Expand Up @@ -143,10 +148,10 @@ export async function getResult(messageId: string, agent: Agent) {
* @param chain An array of agents to be executed in sequence.
* @returns The final message ID produced by the last agent in the chain.
*/
export async function sprint<Model extends ModelAdapter<ModelMessage>, Store extends StoreAdapter>(
featureId: string,
chain: Agent<Model, Store>[]
) {
export async function sprint<
Model extends ModelAdapter<ModelMessage> = ModelAdapter<ModelMessage>,
Store extends StoreAdapter = StoreAdapter
>(featureId: string, chain: Agent<Model, Store>[]) {
return chain.reduce(
async (messageId, agent) => getResult(await messageId, agent),
Promise.resolve(featureId)
Expand Down

0 comments on commit 9c2f563

Please sign in to comment.