Skip to content

Commit

Permalink
changed typing to be BaseSchema and BaseType respectively. If things …
Browse files Browse the repository at this point in the history
…break it will be here
  • Loading branch information
gclomax committed Nov 20, 2024
1 parent 1343ca9 commit 48c2331
Show file tree
Hide file tree
Showing 10 changed files with 58 additions and 34 deletions.
12 changes: 8 additions & 4 deletions packages/aila/src/core/quiz/AilaQuizReranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,17 @@ import type { z } from "zod";

import type { QuizPath, QuizQuestion } from "../../protocol/schema";
import type { LooseLessonPlan } from "../../protocol/schema";
import { selectHighestRated, type BaseType } from "./ChoiceModels";
import {
selectHighestRated,
type BaseType,
type BaseSchema,
} from "./ChoiceModels";
import { evaluateQuiz } from "./OpenAIRanker";
import { processArray } from "./apiCallingUtils";
import { withRandomDelay } from "./apiCallingUtils";
import type { AilaQuizReranker } from "./interfaces";

export abstract class BasedOnRagAilaQuizReranker<T extends BaseType>
export abstract class BasedOnRagAilaQuizReranker<T extends typeof BaseSchema>
implements AilaQuizReranker<T>
{
abstract rerankQuiz(quizzes: QuizQuestion[][]): Promise<number[]>;
Expand All @@ -24,10 +28,10 @@ export abstract class BasedOnRagAilaQuizReranker<T extends BaseType>
// This takes a quiz array and evaluates it using the rating schema and quiz type and returns an array of evaluation schema objects.
// TODO: GCLOMAX - move evaluate quiz out to use dependancy injection - can then pass the different types of reranker types.
// TODO: GCLOMAX - Cache this. This is where a lot of the expensive openai calling takes place.
public async evaluateQuizArray<T extends BaseType>(
public async evaluateQuizArray(
quizArray: QuizQuestion[][],
lessonPlan: LooseLessonPlan,
ratingSchema: T,
ratingSchema: typeof BaseSchema,
quizType: QuizPath,
): Promise<T[]> {
// Decorates to delay the evaluation of each quiz. There is probably a better library for this.
Expand Down
8 changes: 4 additions & 4 deletions packages/aila/src/core/quiz/AilaQuizRerankerFactory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ import type { AilaQuizService } from "../AilaServices";
import { AilaQuiz } from "./AilaQuiz";
import { BasedOnRagAilaQuizReranker } from "./AilaQuizReranker";
import type { BaseType } from "./ChoiceModels";
import type { BaseSchema } from "./ChoiceModels";
import { TestSchemaReranker } from "./SchemaReranker";
import type {
AilaQuizFactory,
AilaQuizReranker,
AilaQuizRerankerFactory,
quizRecommenderType,
quizRerankerType,
} from "./interfaces";
import type { QuizRerankerType } from "./schema";

export class AilaQuizRerankerFactoryImpl implements AilaQuizRerankerFactory {
public createAilaQuizReranker(
quizType: quizRerankerType,
): AilaQuizReranker<BaseType> {
quizType: QuizRerankerType,
): AilaQuizReranker<typeof BaseSchema> {
return new TestSchemaReranker();
}
}
Expand Down
17 changes: 15 additions & 2 deletions packages/aila/src/core/quiz/AilaQuizVariants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,22 @@ export abstract class BaseQuizGenerator implements AilaQuizGeneratorService {
protected rerankService: CohereReranker;

constructor() {
if (
!process.env.I_DOT_AI_ELASTIC_CLOUD_ID ||
!process.env.I_DOT_AI_ELASTIC_KEY
) {
throw new Error(
"Environment variables for Elastic Cloud ID and API Key must be set",
);
}
this.client = new Client({
cloud: { id: process.env.I_DOT_AI_ELASTIC_CLOUD_ID as string },
auth: { apiKey: process.env.I_DOT_AI_ELASTIC_KEY as string },
cloud: {
id: process.env.I_DOT_AI_ELASTIC_CLOUD_ID,
},

auth: {
apiKey: process.env.I_DOT_AI_ELASTIC_KEY,
},
});

this.cohere = new CohereClient({
Expand Down
14 changes: 9 additions & 5 deletions packages/aila/src/core/quiz/BaseFullQuizService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import type {
QuizQuestion,
} from "../../protocol/schema";
import type { AilaQuizGeneratorService } from "../AilaServices";
import type { BaseType } from "./ChoiceModels";
import type { BaseSchema, BaseType } from "./ChoiceModels";
import { MLQuizGenerator } from "./MLQuizGenerator";
import { TestSchemaReranker } from "./SchemaReranker";
import { SimpleQuizSelector } from "./SimpleQuizSelector";
Expand All @@ -17,7 +17,7 @@ import type { quizPatchType } from "./interfaces";

export abstract class BaseFullQuizService implements FullQuizService {
public abstract quizSelector: QuizSelector<BaseType>;
public abstract quizReranker: AilaQuizReranker<BaseType>;
public abstract quizReranker: AilaQuizReranker<typeof BaseSchema>;
public abstract quizGenerators: AilaQuizGeneratorService[];

public async createBestQuiz(
Expand Down Expand Up @@ -49,13 +49,17 @@ export abstract class BaseFullQuizService implements FullQuizService {
quizType,
);

const bestQuiz = this.quizSelector.selectBestQuiz(quizzes, quizRankings);
const parsedRankings = quizRankings.map((ranking) => ranking.parse({}));
const bestQuiz = this.quizSelector.selectBestQuiz(quizzes, parsedRankings);
return bestQuiz;
}
}

export class SimpleFullQuizService extends BaseFullQuizService {
public quizSelector: QuizSelector<BaseType> = new SimpleQuizSelector();
public quizReranker: AilaQuizReranker<BaseType> = new TestSchemaReranker();
public quizSelector: QuizSelector<BaseType> =
new SimpleQuizSelector<BaseType>();

public quizReranker: AilaQuizReranker<typeof BaseSchema> =
new TestSchemaReranker();
public quizGenerators: AilaQuizGeneratorService[] = [new MLQuizGenerator()];
}
2 changes: 1 addition & 1 deletion packages/aila/src/core/quiz/BaseQuizSelector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import type { z } from "zod";
import type { QuizQuestion } from "../../protocol/schema";
import type { RatingFunction } from "./ChoiceModels";
import type { MaxRatingFunctionApplier } from "./ChoiceModels";
import type { BaseType } from "./ChoiceModels";
import type { BaseSchema, BaseType } from "./ChoiceModels";
import type { QuizSelector } from "./interfaces";

export abstract class BaseQuizSelector<T extends BaseType>
Expand Down
2 changes: 1 addition & 1 deletion packages/aila/src/core/quiz/ChoiceModels.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import { starterQuizSuitabilitySchema } from "./RerankerStructuredOutputSchema";

// TODO: GCLOMAX - make blended comparison functions - i.e taking the same question from each quiz and finding the best one.

const BaseSchema = z.object({
export const BaseSchema = z.object({
// Add any common fields here
rating: z.number(),
});
Expand Down
2 changes: 1 addition & 1 deletion packages/aila/src/core/quiz/OpenAIRanker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { createOpenAIClient } from "@oakai/core/src/llm/openai";
import { replayCanvasIntegration } from "@sentry/nextjs";
import type { OpenAI } from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import { ChatCompletionMessageParam } from "openai/resources/index.mjs";
// import { ChatCompletionMessageParam } from "openai/resources/index.mjs";
import { z } from "zod";

import { DEFAULT_MODEL } from "../../constants";
Expand Down
17 changes: 8 additions & 9 deletions packages/aila/src/core/quiz/SchemaReranker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,29 @@ import type {
QuizQuestion,
} from "../../protocol/schema";
import { BasedOnRagAilaQuizReranker } from "./AilaQuizReranker";
import {
testRatingSchema,
type TestRating,
} from "./RerankerStructuredOutputSchema";
import { testRatingSchema } from "./RerankerStructuredOutputSchema";

export class TestSchemaReranker extends BasedOnRagAilaQuizReranker<TestRating> {
export class TestSchemaReranker extends BasedOnRagAilaQuizReranker<
typeof testRatingSchema
> {
public rerankQuiz(quizzes: QuizQuestion[][]): Promise<number[]> {
return Promise.resolve([]);
}
public inputSchema = testRatingSchema;
public evaluateStarterQuiz(
quizzes: QuizQuestion[][],
lessonPlan: LooseLessonPlan,
ratingSchema: TestRating,
ratingSchema: typeof testRatingSchema,
quizType: QuizPath,
): Promise<TestRating[]> {
): Promise<(typeof testRatingSchema)[]> {
return this.evaluateQuizArray(quizzes, lessonPlan, ratingSchema, quizType);
}
public evaluateExitQuiz(
quizzes: QuizQuestion[][],
lessonPlan: LooseLessonPlan,
ratingSchema: TestRating,
ratingSchema: typeof testRatingSchema,
quizType: QuizPath,
): Promise<TestRating[]> {
): Promise<(typeof testRatingSchema)[]> {
return this.evaluateQuizArray(quizzes, lessonPlan, ratingSchema, quizType);
}
}
Expand Down
4 changes: 3 additions & 1 deletion packages/aila/src/core/quiz/SimpleQuizSelector.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import { BaseQuizSelector } from "./BaseQuizSelector";
import { selectHighestRated, type RatingFunction } from "./ChoiceModels";
import type { MaxRatingFunctionApplier } from "./ChoiceModels";
import type { BaseType } from "./ChoiceModels";
import { BaseSchema } from "./ChoiceModels";

// TODO: GCLOMAX - Why on earth is this not working?????????????
export class SimpleQuizSelector<
T extends BaseType,
> extends BaseQuizSelector<T> {
public ratingFunction: RatingFunction<T> = (item) => item.rating;
public ratingFunction: RatingFunction<T> = (item: T) => item.rating;
public maxRatingFunctionApplier: MaxRatingFunctionApplier<T> =
selectHighestRated;
// constructor(
Expand Down
14 changes: 8 additions & 6 deletions packages/aila/src/core/quiz/interfaces.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type {
AilaQuizGeneratorService,
AilaQuizService,
} from "../AilaServices";
import type { BaseSchema } from "./ChoiceModels";
import type {
BaseType,
MaxRatingFunctionApplier,
Expand Down Expand Up @@ -61,22 +62,23 @@ export interface AilaQuizVariantService {
): Promise<JsonPatchDocument>;
}

export interface AilaQuizReranker<T extends BaseType> {
export interface AilaQuizReranker<T extends typeof BaseSchema> {
rerankQuiz(quizzes: QuizQuestion[][]): Promise<number[]>;
evaluateQuizArray<T extends BaseType>(
evaluateQuizArray(
quizzes: QuizQuestion[][],
lessonPlan: LooseLessonPlan,
ratingSchema: T,
ratingSchema: typeof BaseSchema,
quizType: QuizPath,
): Promise<T[]>;
ratingSchema?: T;
quizType?: QuizPath;
ratingFunction?: RatingFunction<T>;
ratingFunction?: RatingFunction<BaseType>;
}

// TODO: GCLOMAX - make generic by extending BaseType and BaseSchema as <T,U>
export interface FullQuizService {
quizSelector: QuizSelector<BaseType>;
quizReranker: AilaQuizReranker<BaseType>;
quizReranker: AilaQuizReranker<typeof BaseSchema>;
quizGenerators: AilaQuizGeneratorService[];
}

Expand Down Expand Up @@ -145,7 +147,7 @@ export interface AilaQuizFactory {
export interface AilaQuizRerankerFactory {
createAilaQuizReranker(
quizType: QuizRerankerType,
): AilaQuizReranker<BaseType>;
): AilaQuizReranker<typeof BaseSchema>;
}

export interface QuizSelectorFactory {
Expand Down

0 comments on commit 48c2331

Please sign in to comment.