Skip to content

Commit

Permalink
♻️ refactor: refactor the server db model implement (lobehub#4878)
Browse files Browse the repository at this point in the history
* refactor db

* update tests

* fix test and build

* add tests

* add tests

* add test for files

* add test for chunks

* remove unused method

* add message tests
  • Loading branch information
arvinxx authored Dec 3, 2024
1 parent 1fa7a3c commit 3814853
Show file tree
Hide file tree
Showing 52 changed files with 1,159 additions and 520 deletions.
3 changes: 2 additions & 1 deletion src/app/(main)/repos/[id]/@menu/default.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { notFound } from 'next/navigation';
import { Flexbox } from 'react-layout-kit';

import { serverDB } from '@/database/server';
import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase';

import Head from './Head';
Expand All @@ -14,7 +15,7 @@ type Props = { params: Params };

const MenuPage = async ({ params }: Props) => {
const id = params.id;
const item = await KnowledgeBaseModel.findById(params.id);
const item = await KnowledgeBaseModel.findById(serverDB, params.id);

if (!item) return notFound();

Expand Down
3 changes: 2 additions & 1 deletion src/app/(main)/repos/[id]/page.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { redirect } from 'next/navigation';

import { serverDB } from '@/database/server';
import { KnowledgeBaseModel } from '@/database/server/models/knowledgeBase';
import FileManager from '@/features/FileManager';

Expand All @@ -10,7 +11,7 @@ interface Params {
type Props = { params: Params };

export default async ({ params }: Props) => {
const item = await KnowledgeBaseModel.findById(params.id);
const item = await KnowledgeBaseModel.findById(serverDB, params.id);

if (!item) return redirect('/repos');

Expand Down
4 changes: 3 additions & 1 deletion src/database/schemas/topic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { boolean, jsonb, pgTable, text, unique } from 'drizzle-orm/pg-core';
import { createInsertSchema } from 'drizzle-zod';

import { idGenerator } from '@/database/utils/idGenerator';
import { ChatTopicMetadata } from '@/types/topic';

import { timestamps, timestamptz } from './_helpers';
import { sessions } from './session';
import { users } from './user';
Expand All @@ -21,7 +23,7 @@ export const topics = pgTable(
.notNull(),
clientId: text('client_id'),
historySummary: text('history_summary'),
metadata: jsonb('metadata'),
metadata: jsonb('metadata').$type<ChatTopicMetadata | undefined>(),
...timestamps,
},
(t) => ({
Expand Down
10 changes: 4 additions & 6 deletions src/database/server/core/dbForTest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import { serverDBEnv } from '@/config/db';

import * as schema from '../../schemas';

const migrationsFolder = join(__dirname, '../../migrations');

export const getTestDBInstance = async () => {
let connectionString = serverDBEnv.DATABASE_TEST_URL;

Expand All @@ -23,9 +25,7 @@ export const getTestDBInstance = async () => {

const db = nodeDrizzle(client, { schema });

await nodeMigrator.migrate(db, {
migrationsFolder: join(__dirname, '../../migrations'),
});
await nodeMigrator.migrate(db, { migrationsFolder });

return db;
}
Expand All @@ -37,9 +37,7 @@ export const getTestDBInstance = async () => {

const db = neonDrizzle(client, { schema });

await migrator.migrate(db, {
migrationsFolder: join(__dirname, '../migrations'),
});
await migrator.migrate(db, { migrationsFolder });

return db;
};
12 changes: 3 additions & 9 deletions src/database/server/models/__tests__/_test_template.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';

Expand All @@ -9,14 +9,8 @@ import { SessionGroupModel } from '../sessionGroup';

let serverDB = await getTestDBInstance();

vi.mock('@/database/server/core/db', async () => ({
get serverDB() {
return serverDB;
},
}));

const userId = 'session-group-model-test-user-id';
const sessionGroupModel = new SessionGroupModel(userId);
const sessionGroupModel = new SessionGroupModel(serverDB, userId);

beforeEach(async () => {
await serverDB.delete(users);
Expand Down Expand Up @@ -74,7 +68,7 @@ describe('SessionGroupModel', () => {
await sessionGroupModel.create({ name: 'Test Group 1' });
await sessionGroupModel.create({ name: 'Test Group 333' });

const anotherSessionGroupModel = new SessionGroupModel('user2');
const anotherSessionGroupModel = new SessionGroupModel(serverDB, 'user2');
await anotherSessionGroupModel.create({ name: 'Test Group 2' });

await sessionGroupModel.deleteAll();
Expand Down
10 changes: 2 additions & 8 deletions src/database/server/models/__tests__/agent.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';

Expand All @@ -18,14 +18,8 @@ import { AgentModel } from '../agent';

let serverDB = await getTestDBInstance();

vi.mock('@/database/server/core/db', async () => ({
get serverDB() {
return serverDB;
},
}));

const userId = 'agent-model-test-user-id';
const agentModel = new AgentModel(userId);
const agentModel = new AgentModel(serverDB, userId);

const knowledgeBase = { id: 'kb1', userId, name: 'knowledgeBase' };
const fileList = [
Expand Down
8 changes: 1 addition & 7 deletions src/database/server/models/__tests__/asyncTask.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,8 @@ import { ASYNC_TASK_TIMEOUT, AsyncTaskModel } from '../asyncTask';

let serverDB = await getTestDBInstance();

vi.mock('@/database/server/core/db', async () => ({
get serverDB() {
return serverDB;
},
}));

const userId = 'async-task-model-test-user-id';
const asyncTaskModel = new AsyncTaskModel(userId);
const asyncTaskModel = new AsyncTaskModel(serverDB, userId);

beforeEach(async () => {
await serverDB.delete(users);
Expand Down
171 changes: 155 additions & 16 deletions src/database/server/models/__tests__/chunk.test.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,18 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';

import { getTestDBInstance } from '@/database/server/core/dbForTest';
import { uuid } from '@/utils/uuid';

import {
chunks,
embeddings,
fileChunks,
files,
unstructuredChunks,
users,
} from '../../../schemas';
import { chunks, embeddings, fileChunks, files, unstructuredChunks, users } from '../../../schemas';
import { ChunkModel } from '../chunk';
import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixtures/embedding';

let serverDB = await getTestDBInstance();

vi.mock('@/database/server/core/db', async () => ({
get serverDB() {
return serverDB;
},
}));

const userId = 'chunk-model-test-user-id';
const chunkModel = new ChunkModel(userId);
const chunkModel = new ChunkModel(serverDB, userId);
const sharedFileList = [
{
id: '1',
Expand Down Expand Up @@ -79,6 +67,27 @@ describe('ChunkModel', () => {
expect(createdChunks[0]).toMatchObject(params[0]);
expect(createdChunks[1]).toMatchObject(params[1]);
});

// 测试空参数场景
it('should handle empty params array', async () => {
const result = await chunkModel.bulkCreate([], '1');
expect(result).toHaveLength(0);
});

// 测试事务回滚
it('should rollback transaction on error', async () => {
const invalidParams = [
{ text: 'Chunk 1', userId },
{ index: 'abc', userId }, // 这会导致错误
] as any;

await expect(chunkModel.bulkCreate(invalidParams, '1')).rejects.toThrow();

const createdChunks = await serverDB.query.chunks.findMany({
where: eq(chunks.userId, userId),
});
expect(createdChunks).toHaveLength(0);
});
});

describe('delete', () => {
Expand Down Expand Up @@ -191,6 +200,41 @@ describe('ChunkModel', () => {
expect(result[1].id).toBe(chunk2.id);
expect(result[0].similarity).toBeGreaterThan(result[1].similarity);
});
// 补充无文件 ID 的搜索场景
it('should perform semantic search without fileIds', async () => {
const [chunk1, chunk2] = await serverDB
.insert(chunks)
.values([
{ text: 'Test Chunk 1', userId },
{ text: 'Test Chunk 2', userId },
])
.returning();

await serverDB.insert(embeddings).values([
{ chunkId: chunk1.id, embeddings: designThinkingQuery, userId },
{ chunkId: chunk2.id, embeddings: codeEmbedding, userId },
]);

const result = await chunkModel.semanticSearch({
embedding: designThinkingQuery2,
fileIds: undefined,
query: 'design thinking',
});

expect(result).toBeDefined();
expect(result).toHaveLength(2);
});

// 测试空结果场景
it('should return empty array when no matches found', async () => {
const result = await chunkModel.semanticSearch({
embedding: designThinkingQuery,
fileIds: ['non-existent-file'],
query: 'no matches',
});

expect(result).toHaveLength(0);
});
});

describe('bulkCreateUnstructuredChunks', () => {
Expand Down Expand Up @@ -391,5 +435,100 @@ content in Table html is below:
<table>...</table>
`);
});

it('should handle null text', () => {
const chunk = {
text: null,
type: 'Text',
metadata: {},
};

const result = chunkModel['mapChunkText'](chunk);
expect(result).toBeNull();
});

it('should handle missing metadata for Table type', () => {
const chunk = {
text: 'Table text',
type: 'Table',
metadata: {},
};

const result = chunkModel['mapChunkText'](chunk);
expect(result).toContain('Table text');
expect(result).toContain('content in Table html is below:');
expect(result).toContain('undefined'); // metadata.text_as_html is undefined
});
});

describe('findById', () => {
it('should find a chunk by id', async () => {
// Create a test chunk
const [chunk] = await serverDB
.insert(chunks)
.values({ text: 'Test Chunk', userId })
.returning();

const result = await chunkModel.findById(chunk.id);

expect(result).toBeDefined();
expect(result?.id).toBe(chunk.id);
expect(result?.text).toBe('Test Chunk');
});

it('should return null for non-existent id', async () => {
const result = await chunkModel.findById(uuid());
expect(result).toBeUndefined();
});
});

describe('semanticSearchForChat', () => {
// 测试空文件 ID 列表场景
it('should return empty array when fileIds is empty', async () => {
const result = await chunkModel.semanticSearchForChat({
embedding: designThinkingQuery,
fileIds: [],
query: 'test',
});

expect(result).toHaveLength(0);
});

// 测试结果限制
it('should limit results to 5 items', async () => {
const fileId = '1';
// Create 6 chunks
const chunkResult = await serverDB
.insert(chunks)
.values(
Array(6)
.fill(0)
.map((_, i) => ({ text: `Test Chunk ${i}`, userId })),
)
.returning();

await serverDB.insert(fileChunks).values(
chunkResult.map((chunk) => ({
fileId,
chunkId: chunk.id,
})),
);

await serverDB.insert(embeddings).values(
chunkResult.map((chunk) => ({
chunkId: chunk.id,
embeddings: designThinkingQuery,
userId,
})),
);

const result = await chunkModel.semanticSearchForChat({
embedding: designThinkingQuery2,
fileIds: [fileId],
query: 'test',
});

expect(result).toHaveLength(5);
});
});
});
Loading

0 comments on commit 3814853

Please sign in to comment.