Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle null metadata docs in bq retriever #695

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 42 additions & 15 deletions js/plugins/vertexai/src/vector-search/bigquery.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
*/

import { Document, DocumentDataSchema } from '@genkit-ai/ai/retriever';
import { BigQuery } from '@google-cloud/bigquery';
import { logger } from '@genkit-ai/core/logging';
import { BigQuery, QueryRowsResponse } from '@google-cloud/bigquery';
import { ZodError } from 'zod';
import { DocumentIndexer, DocumentRetriever, Neighbor } from './types';

/**
* Creates a BigQuery Document Retriever.
*
Expand All @@ -36,34 +39,58 @@ export const getBigQueryDocumentRetriever = (
const bigQueryRetriever: DocumentRetriever = async (
neighbors: Neighbor[]
): Promise<Document[]> => {
const ids = neighbors
const ids: string[] = neighbors
.map((neighbor) => neighbor.datapoint?.datapointId)
.filter(Boolean);
.filter(Boolean) as string[];

const query = `
SELECT * FROM \`${datasetId}.${tableId}\`
WHERE id IN UNNEST(@ids)
`;

const options = {
query,
params: { ids },
};
const [rows] = await bq.query(options);
const docs: Document[] = rows
.map((row) => {
const docData = {

let rows: QueryRowsResponse[0];

try {
[rows] = await bq.query(options);
} catch (queryError) {
logger.error('Failed to execute BigQuery query:', queryError);
return [];
}

const documents: Document[] = [];

for (const row of rows) {
try {
const docData: { content: any; metadata?: any } = {
content: JSON.parse(row.content),
metadata: JSON.parse(row.metadata),
};
const parsedDocData = DocumentDataSchema.safeParse(docData);
if (parsedDocData.success) {
return new Document(parsedDocData.data);

if (row.metadata) {
docData.metadata = JSON.parse(row.metadata);
}

const parsedDocData = DocumentDataSchema.parse(docData);
documents.push(new Document(parsedDocData));
} catch (error) {
const id = row.id;
const errorPrefix = `Failed to parse document data for document with ID ${id}:`;

if (error instanceof ZodError || error instanceof Error) {
logger.warn(`${errorPrefix} ${error.message}`);
} else {
logger.warn(errorPrefix);
}
return null;
})
.filter((doc): doc is Document => !!doc);
}
}

return docs;
return documents;
};

return bigQueryRetriever;
};

Expand Down
168 changes: 168 additions & 0 deletions js/plugins/vertexai/tests/vector-search/bigquery_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { Document } from '@genkit-ai/ai/retriever';
import { BigQuery } from '@google-cloud/bigquery';
import assert from 'node:assert';
import { describe, it } from 'node:test';
import { getBigQueryDocumentRetriever } from '../../src';

/**
 * Minimal stand-in for the BigQuery client used by the retriever tests.
 * `query` resolves to `[mockRows]` (mirroring the client's tuple response),
 * or rejects with "Query failed" when `shouldThrowError` is set.
 */
class MockBigQuery {
  query: Function;

  constructor(config: { mockRows: any[]; shouldThrowError?: boolean }) {
    const { mockRows, shouldThrowError = false } = config;

    this.query = async (_options: {
      query: string;
      params: { ids: string[] };
    }) => {
      if (shouldThrowError) {
        throw new Error('Query failed');
      }
      return [mockRows];
    };
  }
}

describe('getBigQueryDocumentRetriever', () => {
  // Builds a retriever backed by a MockBigQuery that returns the given rows
  // (or rejects every query when shouldThrowError is true).
  const retrieverWithRows = (mockRows: any[], shouldThrowError = false) =>
    getBigQueryDocumentRetriever(
      new MockBigQuery({ mockRows, shouldThrowError }) as unknown as BigQuery,
      'test-table',
      'test-dataset'
    );

  it('returns a function that retrieves documents from BigQuery', async () => {
    const doc1 = Document.fromText('content1');
    const doc2 = Document.fromText('content2');

    const documentRetriever = retrieverWithRows([
      { id: '1', content: JSON.stringify(doc1.content), metadata: null },
      { id: '2', content: JSON.stringify(doc2.content), metadata: null },
    ]);

    const documents = await documentRetriever([
      { datapoint: { datapointId: '1' } },
      { datapoint: { datapointId: '2' } },
    ]);

    assert.deepStrictEqual(documents, [doc1, doc2]);
  });

  it('returns an empty array when no documents match', async () => {
    const documentRetriever = retrieverWithRows([]);

    const documents = await documentRetriever([
      { datapoint: { datapointId: '3' } },
    ]);

    assert.deepStrictEqual(documents, []);
  });

  it('handles BigQuery query errors', async () => {
    const documentRetriever = retrieverWithRows([], true);
    // no need to assert the error, just make sure it doesn't throw
    await documentRetriever([{ datapoint: { datapointId: '1' } }]);
  });

  it('filters out invalid documents', async () => {
    const validDoc = Document.fromText('valid content');

    const documentRetriever = retrieverWithRows([
      { id: '1', content: JSON.stringify(validDoc.content), metadata: null },
      { id: '2', content: 'invalid JSON', metadata: null },
    ]);

    const documents = await documentRetriever([
      { datapoint: { datapointId: '1' } },
      { datapoint: { datapointId: '2' } },
    ]);

    assert.deepStrictEqual(documents, [validDoc]);
  });

  it('handles missing content in documents', async () => {
    const validDoc = Document.fromText('valid content');

    const documentRetriever = retrieverWithRows([
      { id: '1', content: JSON.stringify(validDoc.content), metadata: null },
      { id: '2', content: null, metadata: null },
    ]);

    const documents = await documentRetriever([
      { datapoint: { datapointId: '1' } },
      { datapoint: { datapointId: '2' } },
    ]);

    assert.deepStrictEqual(documents, [validDoc]);
  });
});
2 changes: 1 addition & 1 deletion js/testapps/vertexai-vector-search-bigquery/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ configureGenkit({
googleAuth: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
},
vectorSearchIndexOptions: [
vectorSearchOptions: [
{
publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME,
indexEndpointId: VECTOR_SEARCH_INDEX_ENDPOINT_ID,
Expand Down
2 changes: 1 addition & 1 deletion js/testapps/vertexai-vector-search-custom/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ configureGenkit({
googleAuth: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
},
vectorSearchIndexOptions: [
vectorSearchOptions: [
{
publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME,
indexEndpointId: VECTOR_SEARCH_INDEX_ENDPOINT_ID,
Expand Down
2 changes: 1 addition & 1 deletion js/testapps/vertexai-vector-search-firestore/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ configureGenkit({
googleAuth: {
scopes: ['https://www.googleapis.com/auth/cloud-platform'],
},
vectorSearchIndexOptions: [
vectorSearchOptions: [
{
publicDomainName: VECTOR_SEARCH_PUBLIC_DOMAIN_NAME,
indexEndpointId: VECTOR_SEARCH_INDEX_ENDPOINT_ID,
Expand Down
Loading