Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: incrementally move to Aila not using barrel files #289

Merged
merged 7 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,18 @@
"autodocs",
"autoprefixer",
"autosize",
"backmerge",
"backticked",
"beitzah",
"bethclark",
"bugsnag",
"categorisation",
"Categorised",
"centered",
"cloudinary",
"clsx",
"codegen",
"COLOR",
"compat",
"contrib",
"cuid",
Expand All @@ -40,6 +43,7 @@
"Docgen",
"dockerized",
"dopplerhq",
"dotenv",
"EASS",
"EHRC",
"estree",
Expand Down Expand Up @@ -106,6 +110,7 @@
"ponyfill",
"popover",
"portabletext",
"postcss",
"posthog",
"postpack",
"posttest",
Expand All @@ -117,6 +122,7 @@
"psql",
"pusherapp",
"ratelimit",
"refs",
"Regen",
"remeda",
"Rerank",
Expand All @@ -128,6 +134,7 @@
"sslmode",
"SUBJ",
"superjson",
"svgs",
"tailwindcss",
"tanstack",
"testid",
Expand All @@ -144,6 +151,7 @@
"uidotdev",
"unjudged",
"unsets",
"unshallow",
"unsummarised",
"untruncate",
"untruncated",
Expand All @@ -155,6 +163,7 @@
"valign",
"vars",
"vectorstores",
"vercel",
"WCAG",
"webvtt",
"zadd",
Expand Down
2 changes: 1 addition & 1 deletion apps/nextjs/src/app/api/chat/user.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { auth, clerkClient } from "@clerk/nextjs/server";
import { AilaAuthenticationError } from "@oakai/aila";
import { AilaAuthenticationError } from "@oakai/aila/src/core/AilaError";
import { demoUsers, inngest } from "@oakai/core";
import { posthogAiBetaServerClient } from "@oakai/core/src/analytics/posthogAiBetaServerClient";
import { UserBannedError } from "@oakai/core/src/models/userBannedError";
Expand Down
14 changes: 8 additions & 6 deletions apps/nextjs/src/app/api/chat/webActionsPlugin.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import type { AilaPlugin } from "@oakai/aila/src/core/plugins";
import { AilaThreatDetectionError } from "@oakai/aila/src/features/threatDetection/types";
import { AilaThreatDetectionError } from "@oakai/aila/src/features/threatDetection";
import { handleHeliconeError } from "@oakai/aila/src/utils/moderation/moderationErrorHandling";
import { inngest } from "@oakai/core/src/inngest";
import { SafetyViolations as defaultSafetyViolations } from "@oakai/core/src/models/safetyViolations";
import {
SafetyViolations as defaultSafetyViolations,
inngest,
} from "@oakai/core";
import { UserBannedError } from "@oakai/core/src/models/userBannedError";
import type { PrismaClientWithAccelerate } from "@oakai/db";
import { aiLogger } from "@oakai/logger";
Expand Down Expand Up @@ -33,11 +35,11 @@ export const createWebActionsPlugin: PluginCreator = (
prisma,
SafetyViolations,
);
enqueue(heliconeErrorMessage);
await enqueue(heliconeErrorMessage);
}

if (error instanceof Error) {
enqueue({
await enqueue({
type: "error",
message: error.message,
value: `Sorry, an error occurred: ${error.message}`,
Expand Down Expand Up @@ -80,7 +82,7 @@ export const createWebActionsPlugin: PluginCreator = (
} catch (error) {
if (error instanceof UserBannedError) {
log.info("User is banned, queueing account lock message");
enqueue({
await enqueue({
type: "action",
action: "SHOW_ACCOUNT_LOCKED",
});
Expand Down
1 change: 1 addition & 0 deletions packages/aila/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"@pollyjs/persister-fs": "^6.0.6",
"@types/jest": "^29.5.12",
"eslint": "^8.56.0",
"eslint-plugin-turbo": "^2.2.3",
"jest": "^29.7.0",
"setup-polly-jest": "^0.11.0",
"ts-jest": "^29.1.4",
Expand Down
4 changes: 2 additions & 2 deletions packages/aila/src/core/Aila.liveWithOpenAI.test.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { AilaInitializationOptions } from ".";
import { Aila } from ".";
import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser";
import { Aila } from "./Aila";
import { checkLastMessage, expectPatch, expectText } from "./Aila.testHelpers";
import type { AilaInitializationOptions } from "./types";

const runInCI = process.env.CI === "true";
const runManually = process.env.RUN_LLM_TESTS === "true";
Expand Down
10 changes: 6 additions & 4 deletions packages/aila/src/core/Aila.test.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import { Aila } from ".";
import type { Polly } from "@pollyjs/core";

import { setupPolly } from "../../tests/mocks/setupPolly";
import { MockCategoriser } from "../features/categorisation/categorisers/MockCategoriser";
import { Aila } from "./Aila";
import { AilaAuthenticationError } from "./AilaError";
import { MockLLMService } from "./llm/MockLLMService";

describe("Aila", () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let polly: any;
let polly: Polly;

beforeAll(() => {
polly = setupPolly();
Expand Down Expand Up @@ -348,7 +349,9 @@ describe("Aila", () => {
const mockCategoriser = new MockCategoriser({ mockedLessonPlan });

const mockLLMResponse = [
// eslint-disable-next-line @typescript-eslint/quotes, quotes
'{"type":"patch","reasoning":"Update title","value":{"op":"replace","path":"/title","value":"Updated Mocked Lesson Plan"}}␞\n',
// eslint-disable-next-line @typescript-eslint/quotes, quotes
'{"type":"patch","reasoning":"Update subject","value":{"op":"replace","path":"/subject","value":"Updated Mocked Subject"}}␞\n',
];
const mockLLMService = new MockLLMService(mockLLMResponse);
Expand Down Expand Up @@ -379,7 +382,6 @@ describe("Aila", () => {
// Use MockLLMService to generate a response
await ailaInstance.generateSync({ input: "Test input" });

console.log("Generated");
// Check if MockLLMService updates were applied
expect(ailaInstance.lesson.plan.title).toBe("Updated Mocked Lesson Plan");
expect(ailaInstance.lesson.plan.subject).toBe("Updated Mocked Subject");
Expand Down
2 changes: 1 addition & 1 deletion packages/aila/src/core/Aila.testHelpers.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { expect } from "@jest/globals";
import invariant from "tiny-invariant";

import type { Aila } from ".";
import type { MessagePart, TextDocument } from "../protocol/jsonPatchProtocol";
import type { Aila } from "./Aila";

export function checkAssistantResponse(content: string) {
// Check that the response is a string (not JSON)
Expand Down
2 changes: 1 addition & 1 deletion packages/aila/src/core/Aila.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import type { PrismaClientWithAccelerate } from "@oakai/db";
import { prisma as globalPrisma } from "@oakai/db";
import { prisma as globalPrisma } from "@oakai/db/client";
import { aiLogger } from "@oakai/logger";

import {
Expand Down
17 changes: 7 additions & 10 deletions packages/aila/src/core/chat/AilaChat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,13 @@ import {
} from "@oakai/core/src/utils/subjects";
import invariant from "tiny-invariant";

import type { AilaChatService, AilaServices } from "../..";
import { AilaError } from "../..";
import { DEFAULT_MODEL, DEFAULT_TEMPERATURE } from "../../constants";
import type {
AilaGenerationStatus} from "../../features/generation";
import {
AilaGeneration
} from "../../features/generation";
import type { AilaChatService } from "../../core/AilaServices";
import type { AilaServices } from "../../core/AilaServices";
import { AilaGeneration } from "../../features/generation/AilaGeneration";
import type { AilaGenerationStatus } from "../../features/generation/types";
import { generateMessageId } from "../../helpers/chat/generateMessageId";
import type {
JsonPatchDocumentOptional} from "../../protocol/jsonPatchProtocol";
import type { JsonPatchDocumentOptional } from "../../protocol/jsonPatchProtocol";
import {
LLMMessageSchema,
parseMessageParts,
Expand All @@ -24,6 +20,7 @@ import type {
AilaPersistedChat,
AilaRagRelevantLesson,
} from "../../protocol/schema";
import { AilaError } from "../AilaError";
import type { LLMService } from "../llm/LLMService";
import { OpenAIService } from "../llm/OpenAIService";
import type { AilaPromptBuilder } from "../prompt/AilaPromptBuilder";
Expand Down Expand Up @@ -279,7 +276,7 @@ export class AilaChat implements AilaChatService {
invariant(responseText, "Response text not set");
await this._generation.complete({ status, responseText });
}
this._generation.persist(status);
await this._generation.persist(status);
}

private async persistChat() {
Expand Down
8 changes: 5 additions & 3 deletions packages/aila/src/core/chat/AilaStreamHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ export class AilaStreamHandler {
await this.readFromStream();
}
} catch (e) {
this.handleStreamError(e);
await this.handleStreamError(e);
log.info("Stream error", e, this._chat.iteration, this._chat.id);
} finally {
this._isStreaming = false;
Expand All @@ -53,7 +53,9 @@ export class AilaStreamHandler {
log.info("Chat completed", this._chat.iteration, this._chat.id);
} catch (e) {
this._chat.aila.errorReporter?.reportError(e);
throw new AilaChatError("Chat completion failed", { cause: e });
controller.error(
new AilaChatError("Chat completion failed", { cause: e }),
);
} finally {
this.closeController();
log.info("Stream closed", this._chat.iteration, this._chat.id);
Expand Down Expand Up @@ -107,7 +109,7 @@ export class AilaStreamHandler {
for (const plugin of this._chat.aila.plugins ?? []) {
await plugin.onStreamError?.(error, {
aila: this._chat.aila,
enqueue: this._chat.enqueue,
enqueue: (patch) => this._chat.enqueue(patch),
});
}

Expand Down
2 changes: 2 additions & 0 deletions packages/aila/src/core/chat/PatchEnqueuer.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@ describe("PatchEnqueuer", () => {

await patchEnqueuer.enqueuePatch(path, value);

// eslint-disable-next-line @typescript-eslint/unbound-method
expect(controller.enqueue).toHaveBeenCalled();
const expectedPatch = `\n␞\n${JSON.stringify({
type: "patch",
reasoning: "generated",
value: { op: "add", path, value },
status: "complete",
})}\n␞\n`;
// eslint-disable-next-line @typescript-eslint/unbound-method
expect(controller.enqueue).toHaveBeenCalledWith(expectedPatch);
});

Expand Down
6 changes: 3 additions & 3 deletions packages/aila/src/core/prompt/AilaPromptBuilder.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { jsonrepair } from "jsonrepair";

import type { Message } from "..";
import type { AilaServices } from "../..";
import type { AilaServices } from "../../core/AilaServices";
import { tryWithErrorReporting } from "../../helpers/errorReporting";
import type { Message } from "../chat";

export abstract class AilaPromptBuilder {
protected _aila: AilaServices;
Expand Down Expand Up @@ -41,7 +41,7 @@ export abstract class AilaPromptBuilder {
}
return row;
},
`Failed to parse row`,
"Failed to parse row",
"info",
{
category: "json",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { TemplateProps } from "@oakai/core/src/prompts/lesson-assistant";
import { template } from "@oakai/core/src/prompts/lesson-assistant";
import { prisma as globalPrisma } from "@oakai/db";
import { prisma as globalPrisma } from "@oakai/db/client";
import { aiLogger } from "@oakai/logger";

import { DEFAULT_RAG_LESSON_PLANS } from "../../../constants";
Expand Down
2 changes: 1 addition & 1 deletion packages/aila/src/core/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { PrismaClientWithAccelerate } from "@oakai/db";
import type { PrismaClientWithAccelerate } from "@oakai/db/client";

import type { AilaAmericanismsFeature } from "../features/americanisms";
import type { AnalyticsAdapter } from "../features/analytics";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="./american-british-english-translator.d.ts" />
import { textify } from "@oakai/core/src/models/lessonPlans";
import { textify } from "@oakai/core/src/utils/textify";
import translator from "american-british-english-translator";

import type { AilaAmericanismsFeature } from ".";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { CategoriseKeyStageAndSubjectResponse } from "@oakai/core/src/rag";
import { CategoriseKeyStageAndSubjectResponse } from "@oakai/core/src/rag/categorisation";
import { keyStages, subjects } from "@oakai/core/src/utils/subjects";
import { aiLogger } from "@oakai/logger";
import type { ChatCompletionMessageParam } from "openai/resources";
Expand Down
4 changes: 2 additions & 2 deletions packages/aila/src/features/generation/AilaGeneration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ import {
generateAilaPromptVersionVariantSlug,
} from "@oakai/core/src/prompts/lesson-assistant/variants";
import type { Prompt } from "@oakai/db";
import { prisma } from "@oakai/db";
import { prisma } from "@oakai/db/client";
import { aiLogger } from "@oakai/logger";
import { kv } from "@vercel/kv";
import { getEncoding } from "js-tiktoken";

import type { AilaServices } from "../../core";
import type { AilaServices } from "../../core/AilaServices";
import type { AilaChat } from "../../core/chat";
import type { AilaGenerationStatus } from "./types";

Expand Down
14 changes: 9 additions & 5 deletions packages/aila/src/features/generation/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import { AilaGeneration } from ".";
import type { AilaInitializationOptions } from "../..";
import { Aila } from "../..";
import { Aila } from "../../core/Aila";
import type { Message } from "../../core/chat";
import { AilaChat } from "../../core/chat";
import type { AilaInitializationOptions } from "../../core/types";
import { AilaGeneration } from "./AilaGeneration";

const ailaArgs: AilaInitializationOptions = {
plugins: [],
Expand All @@ -21,7 +21,9 @@ describe("calculateTokenUsage", () => {
];

const mockEncoding = {
encode: jest.fn().mockImplementation((text) => text.split(" ").length),
encode: jest
.fn()
.mockImplementation((text: string) => text.split(" ").length),
};
jest.mock("js-tiktoken", () => ({
getEncoding: () => mockEncoding,
Expand Down Expand Up @@ -57,7 +59,9 @@ describe("calculateTokenUsage", () => {
];

const mockEncoding = {
encode: jest.fn().mockImplementation((text) => text.split(" ").length),
encode: jest
.fn()
.mockImplementation((text: string) => text.split(" ").length),
};
jest.mock("js-tiktoken", () => ({
getEncoding: () => mockEncoding,
Expand Down
12 changes: 4 additions & 8 deletions packages/aila/src/features/moderation/AilaModeration.ts
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
import { Moderations } from "@oakai/core";
import { Moderations } from "@oakai/core/src/models/moderations";
import {
getCategoryGroup,
getMockModerationResult,
getSafetyResult,
isToxic,
} from "@oakai/core/src/utils/ailaModeration/helpers";
import type { ModerationResult } from "@oakai/core/src/utils/ailaModeration/moderationSchema";
import type {
Moderation,
PrismaClientWithAccelerate} from "@oakai/db";
import {
prisma as globalPrisma,
} from "@oakai/db";
import type { Moderation, PrismaClientWithAccelerate } from "@oakai/db";
import { prisma as globalPrisma } from "@oakai/db";
import { aiLogger } from "@oakai/logger";
import invariant from "tiny-invariant";

import type { AilaServices } from "../../core";
import type { AilaServices } from "../../core/AilaServices";
import type { Message } from "../../core/chat";
import type { AilaPluginContext } from "../../core/plugins/types";
import { getLastAssistantMessage } from "../../helpers/chat/getLastAssistantMessage";
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { Moderations } from "@oakai/core";
import type { Moderation} from "@oakai/db";
import { prisma } from "@oakai/db";
import { Moderations } from "@oakai/core/src/models/moderations";
import { prisma } from "@oakai/db/client";
import type { Moderation } from "@prisma/client";

export async function getSessionModerations(
appSessionId: string,
Expand Down
Loading
Loading