Skip to content

Commit

Permalink
Add strong response types for vertex evaluators
Browse files Browse the repository at this point in the history
  • Loading branch information
tagboola committed May 2, 2024
1 parent 225dad6 commit 983658f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 12 deletions.
41 changes: 35 additions & 6 deletions js/plugins/vertexai/src/evaluation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import { BaseDataPoint } from '@genkit-ai/ai/evaluator';
import { Action } from '@genkit-ai/core';
import { GoogleAuth } from 'google-auth-library';
import { JSONClient } from 'google-auth-library/build/src/auth/googleauth';
import z from 'zod';
import { EvaluatorFactory } from './evaluator_factory';

/**
Expand Down Expand Up @@ -57,10 +58,6 @@ export function vertexEvaluators(
const metricType = isConfig(metric) ? metric.type : metric;
const metricSpec = isConfig(metric) ? metric.metricSpec : {};

console.log(
`Creating evaluator for metric ${metricType} with metricSpec ${metricSpec}`
);

switch (metricType) {
case VertexAIEvaluationMetricType.BLEU: {
return createBleuEvaluator(factory, metricSpec);
Expand All @@ -84,6 +81,12 @@ function isConfig(
return (config as VertexAIEvaluationMetricConfig).type !== undefined;
}

const BleuResponseSchema = z.object({
bleuResults: z.object({
bleuMetricValues: z.array(z.object({ score: z.number() })),
}),
});

// TODO: Add support for batch inputs
function createBleuEvaluator(
factory: EvaluatorFactory,
Expand All @@ -95,6 +98,7 @@ function createBleuEvaluator(
displayName: 'BLEU',
definition:
'Computes the BLEU score by comparing the output against the ground truth',
responseSchema: BleuResponseSchema,
},
(datapoint) => {
if (!datapoint.reference) {
Expand Down Expand Up @@ -124,6 +128,12 @@ function createBleuEvaluator(
);
}

const RougeResponseSchema = z.object({
rougeResults: z.object({
rougeMetricValues: z.array(z.object({ score: z.number() })),
}),
});

// TODO: Add support for batch inputs
function createRougeEvaluator(
factory: EvaluatorFactory,
Expand All @@ -135,6 +145,7 @@ function createRougeEvaluator(
displayName: 'ROUGE',
definition:
'Computes the ROUGE score by comparing the output against the ground truth',
responseSchema: RougeResponseSchema,
},
(datapoint) => {
if (!datapoint.reference) {
Expand Down Expand Up @@ -162,6 +173,14 @@ function createRougeEvaluator(
);
}

const SafetyResponseSchema = z.object({
safetyResult: z.object({
score: z.number(),
explanation: z.string(),
confidence: z.number(),
}),
});

function createSafetyEvaluator(
factory: EvaluatorFactory,
metricSpec: any
Expand All @@ -171,6 +190,7 @@ function createSafetyEvaluator(
metric: VertexAIEvaluationMetricType.SAFETY,
displayName: 'Safety',
definition: 'Assesses the level of safety of an output',
responseSchema: SafetyResponseSchema,
},
(datapoint) => {
return {
Expand All @@ -182,7 +202,7 @@ function createSafetyEvaluator(
},
};
},
(response: any, datapoint: BaseDataPoint) => {
(response, datapoint: BaseDataPoint) => {
return {
testCaseId: datapoint.testCaseId,
evaluation: {
Expand All @@ -196,6 +216,14 @@ function createSafetyEvaluator(
);
}

const GroundednessResponseSchema = z.object({
groundednessResult: z.object({
score: z.number(),
explanation: z.string(),
confidence: z.number(),
}),
});

function createGroundednessEvaluator(
factory: EvaluatorFactory,
metricSpec: any
Expand All @@ -206,6 +234,7 @@ function createGroundednessEvaluator(
displayName: 'Groundedness',
definition:
'Assesses the ability to provide or reference information included only in the context',
responseSchema: GroundednessResponseSchema,
},
(datapoint) => {
return {
Expand All @@ -218,7 +247,7 @@ function createGroundednessEvaluator(
},
};
},
(response: any, datapoint: BaseDataPoint) => {
(response, datapoint: BaseDataPoint) => {
return {
testCaseId: datapoint.testCaseId,
evaluation: {
Expand Down
31 changes: 25 additions & 6 deletions js/plugins/vertexai/src/evaluator_factory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import { Action } from '@genkit-ai/core';
import { runInNewSpan } from '@genkit-ai/core/tracing';
import { GoogleAuth } from 'google-auth-library';
import { JSONClient } from 'google-auth-library/build/src/auth/googleauth';
import z from 'zod';
import { VertexAIEvaluationMetricType } from './evaluation';

export class EvaluatorFactory {
Expand All @@ -28,14 +29,18 @@ export class EvaluatorFactory {
private readonly projectId: string
) {}

create(
create<ResponseType extends z.ZodTypeAny>(
config: {
metric: VertexAIEvaluationMetricType;
displayName: string;
definition: string;
responseSchema: ResponseType;
},
toRequest: (datapoint: BaseDataPoint) => any,
responseHandler: (response: any, datapoint: BaseDataPoint) => any
responseHandler: (
response: z.infer<ResponseType>,
datapoint: BaseDataPoint
) => any
): Action {
return defineEvaluator(
{
Expand All @@ -44,14 +49,21 @@ export class EvaluatorFactory {
definition: config.definition,
},
async (datapoint: BaseDataPoint) => {
const response = await this.evaluateInstances(toRequest(datapoint));
const responseSchema = config.responseSchema;
const response = await this.evaluateInstances(
toRequest(datapoint),
responseSchema
);

return responseHandler(response, datapoint);
}
);
}

async evaluateInstances(partialRequest: any) {
async evaluateInstances<ResponseType extends z.ZodTypeAny>(
partialRequest: any,
responseSchema: ResponseType
): Promise<z.infer<ResponseType>> {
const locationName = `projects/${this.projectId}/locations/${this.location}`;
return await runInNewSpan(
{
Expand All @@ -64,15 +76,22 @@ export class EvaluatorFactory {
location: locationName,
...partialRequest,
};

metadata.input = request;
const client = await this.auth.getClient();
const url = `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`;
const response = await client.request({
url: `https://${this.location}-aiplatform.googleapis.com/v1beta1/${locationName}:evaluateInstances`,
url,
method: 'POST',
body: JSON.stringify(request),
});
metadata.output = response.data;
return response.data as any;

try {
return responseSchema.parse(response.data);
} catch (e) {
throw new Error(`Error parsing ${url} API response: ${e}`);
}
}
);
}
Expand Down

0 comments on commit 983658f

Please sign in to comment.