RAG #12

Open

wants to merge 10 commits into base: openai
17 changes: 17 additions & 0 deletions firestore.indexes.json
@@ -13,6 +13,23 @@
          "order": "DESCENDING"
        }
      ]
    },
    {
      "collectionGroup": "threadVectors",
      "queryScope": "COLLECTION",
      "fields": [
        {
          "fieldPath": "uid",
          "order": "ASCENDING"
        },
        {
          "fieldPath": "messages",
          "vectorConfig": {
            "dimension": 1536,
            "flat": {}
          }
        }
      ]
    }
  ],
  "fieldOverrides": []
1 change: 1 addition & 0 deletions packages/shared/src/types/index.ts
@@ -1,3 +1,4 @@
export * from './firebase.js';
export * from './thread.js';
export * from './threadContent.js';
export * from './threadVector.js';
10 changes: 10 additions & 0 deletions packages/shared/src/types/threadVector.ts
@@ -0,0 +1,10 @@
import type { Timestamp, WithId } from './firebase.js';
import type { VectorValue } from '@google-cloud/firestore';

export type ThreadVectorData = {
updatedAt: Timestamp;
uid: string;
messages: VectorValue;
};

export type ThreadVector = WithId<ThreadVectorData>;
3 changes: 3 additions & 0 deletions pnpm-lock.yaml


1 change: 1 addition & 0 deletions services/functions/package.json
@@ -20,6 +20,7 @@
"dependencies": {
"@google-cloud/vertexai": "^1.1.0",
"@local/shared": "workspace:^",
"dedent": "^1.5.3",
"firebase-admin": "^12.1.0",
"firebase-functions": "^4.3.1",
"lodash-es": "^4.17.21",
5 changes: 5 additions & 0 deletions services/functions/src/firestore/index.ts
@@ -0,0 +1,5 @@
import { thread } from './thread/index.js';

export const firestore = {
thread,
};
5 changes: 5 additions & 0 deletions services/functions/src/firestore/thread/index.ts
@@ -0,0 +1,5 @@
import { onDocumentDeleted } from './onDocumentDeleted.js';

export const thread = {
onDocumentDeleted,
};
9 changes: 9 additions & 0 deletions services/functions/src/firestore/thread/onDocumentDeleted.ts
@@ -0,0 +1,9 @@
import { deleteThreadVector, threadVectorRef } from '../../models/threadVector.js';
import { getDocumentData } from '../../utils/firebase/firestore.js';
import { onDocumentDeleted as _onDocumentDeleted } from '../../utils/firebase/functions.js';

export const onDocumentDeleted = _onDocumentDeleted({ document: 'threads/{id}' }, async (event) => {
const { id } = event.params;
const { exists } = await getDocumentData(threadVectorRef({ id }));
exists && (await deleteThreadVector({ id }));
});
4 changes: 4 additions & 0 deletions services/functions/src/index.ts
@@ -1,6 +1,10 @@
import './utils/firebase/app.js';
import { firestore as _firestore } from './firestore/index.js';
import { taskQueues as _taskQueues } from './taskQueues/index.js';

process.env.TZ = 'Asia/Tokyo';

export * from './geminiPro.js';
export * from './openai.js';
export const firestore = { ..._firestore };
export const taskQueues = { ..._taskQueues };
13 changes: 13 additions & 0 deletions services/functions/src/models/threadVector.ts
@@ -0,0 +1,13 @@
import { getConverter, getFirestore, serverTimestamp } from '../utils/firebase/firestore.js';
import type { ThreadVectorData } from '@local/shared';

export const threadVectorConverter = getConverter<ThreadVectorData>();

export const threadVectorsRef = () => getFirestore().collection('threadVectors').withConverter(threadVectorConverter);

export const threadVectorRef = ({ id }: { id: string }) => threadVectorsRef().doc(id);

export const setThreadVector = async ({ id, data }: { id: string; data: Partial<ThreadVectorData> }) =>
threadVectorRef({ id }).set({ updatedAt: serverTimestamp(), ...data }, { merge: true });

export const deleteThreadVector = async ({ id }: { id: string }) => threadVectorRef({ id }).delete();
69 changes: 62 additions & 7 deletions services/functions/src/openai.ts
@@ -1,8 +1,13 @@
import dedent from 'dedent';
import OpenAI from 'openai';
import { updateThreadContent, throttleUpdateThreadContent } from './models/threadContent.js';
import { updateThreadContent, throttleUpdateThreadContent, threadContentRef } from './models/threadContent.js';
import { threadVectorsRef } from './models/threadVector.js';
import { env } from './utils/env.js';
import { onCall, logger, HttpsError } from './utils/firebase/functions.js';
import { getCollectionData, getDocumentData } from './utils/firebase/firestore.js';
import { onCall, logger, HttpsError, taskQueues } from './utils/firebase/functions.js';
import { embedding } from './utils/openai.js';
import type { Message, ThreadContent } from '@local/shared';
import type { ChatCompletionMessageParam } from 'openai/resources/index.mjs';

export const openai = onCall<{ threadId: string; model: ThreadContent['model']; messages: Message[] }>(
{ secrets: ['OPENAI_API_KEY'] },
Expand All @@ -13,12 +18,59 @@ export const openai = onCall<{ threadId: string; model: ThreadContent['model'];

try {
const openai = new OpenAI({ apiKey: env('OPENAI_API_KEY'), organization: env('OPENAI_ORGANIZATION_ID') });
const query = await embedding({ input: JSON.stringify(messages.at(-1)), openai });
const similarVectors = await getCollectionData(
threadVectorsRef()
.where('uid', '==', auth.uid)
.findNearest('messages', query, { limit: 4, distanceMeasure: 'COSINE' }),
);
// NOTE: Combining an inequality filter with vector search is not supported, so the current threadId is filtered out here instead
const similarThreadContents = await Promise.all(
similarVectors
.filter(({ id }) => id !== threadId)
.map(async ({ id }) => (await getDocumentData(threadContentRef({ id }))).data),
);
const instruction = dedent(`
# Instruction
1. Analyze the provided past thread data.
- The data is provided in JSON format.
- Each thread is stored with its thread ID as the key.
2. Search for content similar to the current question in the past thread data.
3. If a similar question is found, determine if it meets the following condition:
- The answer to the similar question can be considered effective for the current question.
4. If the condition is met, generate a response following these steps:
a. Generate the response in the same language as the input.
b. Format the response according to the specified response format.
c. Translate the fixed phrases in the response format to the same language as the input.
d. Include a list of all the thread IDs that meet the condition in the response format, each as a separate bullet point.
5. If no similar question is found or the condition is not met, generate a response to the current question as usual.
# Response Format
{Current answer}
The following responses might also be helpful for reference:
- <https://${process.env.GCLOUD_PROJECT}.web.app/?threadId={threadID1}>
- <https://${process.env.GCLOUD_PROJECT}.web.app/?threadId={threadID2}>
...
- <https://${process.env.GCLOUD_PROJECT}.web.app/?threadId={threadIDN}>
# Past Thread Data
${JSON.stringify(
similarThreadContents.reduce(
(acc, { id, messages }) => ({
...acc,
[id]: messages.map(({ role, contents }) => [role, contents[0].value]),
}),
{},
),
)}
`);
const stream = openai.beta.chat.completions.stream({
model,
messages: messages.map(({ role, contents }) => ({
role: role === 'human' ? 'user' : 'assistant',
content: contents[0].value,
})),
messages: [
{ role: 'system', content: instruction },
...messages.map(({ role, contents }) => ({
role: role === 'human' ? 'user' : 'assistant',
content: contents[0].value,
})),
] as ChatCompletionMessageParam[],
stream: true,
});
let content = '';
Expand All @@ -34,7 +86,10 @@ export const openai = onCall<{ threadId: string; model: ThreadContent['model'];
...messages,
{ role: 'ai', contents: [{ type: 'text', value: contentCompletion }] },
];
await updateThreadContent({ id: threadId, data: { messages: finalMessages } });
await Promise.all([
updateThreadContent({ id: threadId, data: { messages: finalMessages } }),
taskQueues.embeddingThreadContent.enqueue({ id: threadId, uid: auth.uid, messages: finalMessages }),
]);
return true;
} catch (error) {
logger.error('Failed to openai.', { error });
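
For readability: the object interpolated under "# Past Thread Data" above reduces each retrieved thread to an array of [role, text] pairs keyed by its thread ID, so the system prompt ends up carrying something like the following (values illustrative):

```ts
// Illustrative shape of the JSON embedded in the system prompt under "# Past Thread Data";
// roles are the app's own 'human' / 'ai' values, not OpenAI's 'user' / 'assistant'.
const pastThreadData = {
  'thread-abc': [
    ['human', 'How do I rotate the API key?'],
    ['ai', 'Open the settings page, then ...'],
  ],
  'thread-def': [
    ['human', 'Where is the deploy config stored?'],
    ['ai', 'It lives in firebase.json ...'],
  ],
};
```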
17 changes: 17 additions & 0 deletions services/functions/src/taskQueues/embeddingThreadContent.ts
@@ -0,0 +1,17 @@
import { FieldValue } from 'firebase-admin/firestore';
import { setThreadVector } from '../models/threadVector.js';
import { onTaskDispatched } from '../utils/firebase/functions.js';
import { embedding } from '../utils/openai.js';
import type { ThreadContent } from '@local/shared';

export const embeddingThreadContent = onTaskDispatched(
{ secrets: ['OPENAI_API_KEY'] },
async ({
data: { id, uid, messages },
}: {
data: { id: string; uid: string; messages: ThreadContent['messages'] };
}) => {
const vector = await embedding({ input: JSON.stringify(messages) });
await setThreadVector({ id, data: { uid, messages: FieldValue.vector(vector) } });
},
);
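
One implicit coupling in this task handler: text-embedding-3-small returns 1536-dimension vectors by default, which is exactly the dimension declared for the messages field in firestore.indexes.json. A small guard along these lines (purely illustrative, not part of the PR) would make that assumption explicit:

```ts
// Illustrative helper: fail fast if the embedding length no longer matches the
// vectorConfig.dimension (1536) declared in firestore.indexes.json.
const assertEmbeddingDimension = (vector: number[], expected = 1536): number[] => {
  if (vector.length !== expected) {
    throw new Error(`Embedding has ${vector.length} dimensions, index expects ${expected}`);
  }
  return vector;
};
```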
5 changes: 5 additions & 0 deletions services/functions/src/taskQueues/index.ts
@@ -0,0 +1,5 @@
import { embeddingThreadContent } from './embeddingThreadContent.js';

export const taskQueues = {
embeddingThreadContent,
};
12 changes: 11 additions & 1 deletion services/functions/src/utils/firebase/firestore.ts
@@ -1,11 +1,15 @@
import { FieldValue, getFirestore as _getFirestore } from 'firebase-admin/firestore';
import type { VectorQuery } from '@google-cloud/firestore';
import type { WithId } from '@local/shared';
import type {
DocumentData,
FirestoreDataConverter,
Firestore,
Timestamp,
WithFieldValue,
CollectionReference,
DocumentReference,
Query,
} from 'firebase-admin/firestore';

let firestore: Firestore;
@@ -35,4 +39,10 @@ const getConverter = <T extends DocumentData>(): FirestoreDataConverter<WithId<T
},
});

export { serverTimestamp, getConverter, getFirestore };
const getDocumentData = async <T>(ref: DocumentReference<T>) =>
ref.get().then((doc) => ({ data: { id: doc.id, ...doc.data() } as WithId<T>, exists: doc.exists }));

const getCollectionData = async <T>(query: CollectionReference<T> | Query<T> | VectorQuery<T>) =>
query.get().then(({ docs }) => docs.map((doc) => ({ id: doc.id, ...doc.data() }) as WithId<T>));

export { serverTimestamp, getConverter, getFirestore, getDocumentData, getCollectionData };
27 changes: 26 additions & 1 deletion services/functions/src/utils/firebase/functions.ts
@@ -1,8 +1,13 @@
/* eslint-disable @typescript-eslint/no-explicit-any */
import { getFunctions } from 'firebase-admin/functions';
import { https, logger } from 'firebase-functions/v2';
import { onDocumentDeleted as _onDocumentDeleted } from 'firebase-functions/v2/firestore';
import { onCall as _onCall } from 'firebase-functions/v2/https';
import { HttpsError } from 'firebase-functions/v2/identity';
import { onTaskDispatched as _onTaskDispatched } from 'firebase-functions/v2/tasks';
import type { DocumentOptions, FirestoreEvent, QueryDocumentSnapshot } from 'firebase-functions/v2/firestore';
import type { CallableOptions, CallableRequest } from 'firebase-functions/v2/https';
import type { TaskQueueOptions, Request } from 'firebase-functions/v2/tasks';

export const defaultRegion = 'asia-northeast1';

@@ -12,4 +17,24 @@ const onCall = <T>(optsOrHandler: CallableOptions | OnCallHandler<T>, _handler?:
return _onCall<T>({ region: defaultRegion, memory: '1GiB', timeoutSeconds: 300, ...optsOrHandler }, handler);
};

export { https, logger, HttpsError, onCall };
type OnDocumentDeletedHandler = (event: FirestoreEvent<QueryDocumentSnapshot | undefined>) => Promise<void>;
const onDocumentDeleted = (opts: DocumentOptions, handler: OnDocumentDeletedHandler) => {
return _onDocumentDeleted({ region: defaultRegion, memory: '1GiB', timeoutSeconds: 300, ...opts }, handler);
};

type OnTaskDispatchedHandler = (request: Request) => Promise<void>;
const onTaskDispatched = (
optsOrHandler: TaskQueueOptions | OnTaskDispatchedHandler,
_handler?: OnTaskDispatchedHandler,
) => {
const handler = _handler ?? (optsOrHandler as OnTaskDispatchedHandler);
return _onTaskDispatched({ region: defaultRegion, memory: '1GiB', timeoutSeconds: 300, ...optsOrHandler }, handler);
};

const taskQueues = {
embeddingThreadContent: getFunctions().taskQueue(
`locations/${defaultRegion}/functions/taskQueues-embeddingThreadContent`,
),
};

export { https, logger, HttpsError, onCall, onDocumentDeleted, onTaskDispatched, taskQueues };
20 changes: 20 additions & 0 deletions services/functions/src/utils/openai.ts
@@ -0,0 +1,20 @@
import OpenAI from 'openai';
import { env } from './env.js';

export const embedding = async ({
input,
openai = new OpenAI({
apiKey: env('OPENAI_API_KEY'),
organization: env('OPENAI_ORGANIZATION_ID'),
}),
}: {
input: string;
openai?: OpenAI;
}) => {
const response = await openai.embeddings.create({
model: 'text-embedding-3-small',
encoding_format: 'float',
input,
});
return response.data[0].embedding;
};
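
Called without an explicit client, the helper builds its own OpenAI instance from OPENAI_API_KEY / OPENAI_ORGANIZATION_ID. A minimal usage sketch (import path and message shape illustrative):

```ts
import { embedding } from './utils/openai.js'; // illustrative path

// Embed a serialized message list; the result is a plain number[] that can be
// wrapped with FieldValue.vector() before writing to Firestore.
const vector = await embedding({ input: JSON.stringify([{ role: 'human', value: 'hello' }]) });
console.log(vector.length); // 1536 for text-embedding-3-small
```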