From 3b1eef4051bfa627e9691bccb28501a4ad6887f9 Mon Sep 17 00:00:00 2001
From: Abhinav Rau
Date: Sun, 1 Sep 2024 19:05:36 -0400
Subject: [PATCH] Added Question Answering quality metric for Synthetic Q&A
 generation

---
 src/common.js                                 |  83 ++++++++---
 src/excel/create_tables.js                    |  75 ++++++++++
 src/excel/excel_common.js                     |   1 +
 src/excel/excel_search_runner.js              |   6 +-
 src/excel/excel_summarization_runner.js       |   6 +-
 src/excel/excel_synthetic_qa_runner.js        | 132 +++++++++++++++---
 src/excel/synthetic_qa_tables.js              |  72 ++--------
 src/task_runner.js                            |  11 +-
 src/vertex_ai.js                              |  13 +-
 .../multi_modal/test_multi_modal_request.json |   4 +-
 .../test_qa_quality_request.json              |  42 ++++++
 .../test_qa_quality_response.json             |  51 +++++++
 test/test_common.js                           |  44 ++++++
 test/test_search_runner.js                    |   2 +-
 test/test_summarization_runner.js             |  21 +--
 test/test_synthetic_qa_runner.js              |  55 +++++---
 test/test_vertex_ai_multimodal.js             |  26 +++-
 17 files changed, 488 insertions(+), 156 deletions(-)
 create mode 100644 src/excel/create_tables.js
 create mode 100644 test/data/question_answering/test_qa_quality_request.json
 create mode 100644 test/data/question_answering/test_qa_quality_response.json

diff --git a/src/common.js b/src/common.js
index 11c189d..6e6fe4b 100644
--- a/src/common.js
+++ b/src/common.js
@@ -1,3 +1,14 @@
+export function findIndexByColumnsNameIn2DArray(array2D, searchValue) {
+  for (let i = 0; i < array2D.length; i++) {
+    if (array2D[i][0] === searchValue) {
+      // Found the value in the first column; return the row index
+      return i;
+    }
+  }
+  // Value not found; fall back to the first row (index 0)
+  return 0;
+}
+
 // Vertex AI Search Table Format
 
 export const vertex_ai_search_configValues = [
@@ -43,9 +54,8 @@ export const vertex_ai_search_testTableHeader = [
 export var summaryMatching_prompt =
   "You will get two answers to a question, you should determine if they are semantically similar or not. ";
-export var summaryMatching_examples =
-  " examples - answer_1: I was created by X. answer_2: X created me. output:same " +
-  "answer_1:There are 52 days in a year. answer_2: A year is fairly long. output:different ";
+export var summaryMatching_examples = `examples - answer_1: I was created by X. answer_2: X created me. output:same
+   answer_1:There are 52 days in a year. answer_2: A year is fairly long. output:different `;
 
 // Synthetic Q&A Table Format
 export const synth_q_and_a_configValues = [
@@ -55,27 +65,53 @@ export const synth_q_and_a_configValues = [
   ["Gemini Model ID", "gemini-1.5-flash-001"],
   [
     "System Instructions",
-    "You are an expert in reading call center policy and procedure documents." +
-      "Given the attached document, generate a question and answer that customers are likely to ask a call center agent." +
-      "The question should only be sourced from the provided the document.Do not use any other information other than the attached document. " +
-      "Explain your reasoning for the answer by quoting verbatim where in the document the answer is found. Return the results in JSON format.Example: " +
-      "{'question': 'Here is a question?', 'answer': 'Here is the answer', 'reasoning': 'Quote from document'}",
+    `Given the attached document, generate a question and an answer. The question should only be sourced from the provided document. Do not use any other information other than the attached document. Explain your reasoning for the answer by quoting verbatim where in the document the answer is found. Return the results in JSON format. Example: {'question': 'Here is a question?', 'answer': 'Here is the answer', 'reasoning': 'Quote from document'}`,
   ],
-  ["Batch Size (1-10)", "4"], // BatchSize
-  ["Time between Batches in Seconds (1-10)", "2"],
+  [
+    "Prompt",
+    `You are an expert in reading call center policy and procedure documents. Generate a question and answer that a customer would ask a Bank, using the attached document.`,
+  ],
+  ["Generate Q & A Quality", "TRUE"],
+  [
+    "Q & A Quality Prompt",
+    `# Instruction
+You are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models.
+We will provide you with the user prompt and an AI-generated response.
+You should first read the user prompt carefully for analyzing the task, and then evaluate the quality of the responses based on the rules provided in the Evaluation section below.
+
+# Evaluation
+## Metric Definition
+You will be assessing question answering quality, which measures the overall quality of the answer to the question in the user prompt. Pay special attention to length constraints, such as in X words or in Y sentences. The instruction for performing a question-answering task is provided in the user prompt. The response should not contain information that is not present in the context (if it is provided).
+
+You will assign the response a score from 5, 4, 3, 2, 1, following the Rating Rubric and Evaluation Steps.
+Give step-by-step explanations for your scoring, and only choose scores from 5, 4, 3, 2, 1.
+
+## Criteria Definition
+Instruction following: The response demonstrates a clear understanding of the question answering task instructions, satisfying all of the instruction's requirements.
+Groundedness: The response contains information included only in the context if the context is present in the user prompt. The response does not reference any outside information.
+Completeness: The response completely answers the question with sufficient detail.
+Fluency: The response is well-organized and easy to read.
+
+## Rating Rubric
+5: (Very good). The answer follows instructions, is grounded, complete, and fluent.
+4: (Good). The answer follows instructions, is grounded, complete, but is not very fluent.
+3: (Ok). The answer mostly follows instructions, is grounded, answers the question partially and is not very fluent.
+2: (Bad). The answer does not follow the instructions very well, is incomplete or not fully grounded.
+1: (Very bad). The answer does not follow the instructions, is wrong and not grounded.
+
+## Evaluation Steps
+STEP 1: Assess the response in aspects of instruction following, groundedness, completeness, and fluency according to the criteria.
+STEP 2: Score based on the rubric.
+
+Return result in JSON format. example output: { 'rating': 2, 'evaluation': 'reason' }`,
+  ],
+  ["Q & A Quality Model ID", "gemini-1.5-pro-001"],
+  ["Max Concurrent Requests (1-10)", "5"],
+  ["Request Interval in Seconds (1-10)", "1"],
 ];
 
 export const synth_q_and_a_TableHeader = [
-  [
-    "ID",
-    "GCS File URI",
-    "Mime Type",
-    "Generated Question",
-    "Expected Answer",
-    "Reasoning",
-    "Status",
-    "Response Time",
-  ],
+  ["ID", "GCS File URI", "Mime Type", "Generated Question", "Expected Answer", "Q & A Quality"],
 ];
 
 // Summarization Table Format
@@ -108,6 +144,12 @@ export const summarization_TableHeader = [
 ];
 
 // Eval Maps
+export const mapQuestionAnsweringScore = new Map();
+mapQuestionAnsweringScore.set("1", "1-Very Bad");
+mapQuestionAnsweringScore.set("2", "2-Bad");
+mapQuestionAnsweringScore.set("3", "3-OK");
+mapQuestionAnsweringScore.set("4", "4-Good");
+mapQuestionAnsweringScore.set("5", "5-Very Good");
 
 export const mapSummaryQualityScore = new Map();
 mapSummaryQualityScore.set(1, "1-Very Bad");
@@ -184,4 +226,3 @@ export class ResourceNotFoundError extends Error {
     this.statusCode = 404; // Optional: HTTP status code for API errors
   }
 }
-
diff --git a/src/excel/create_tables.js b/src/excel/create_tables.js
new file mode 100644
index 0000000..8a014b0
--- /dev/null
+++ b/src/excel/create_tables.js
@@ -0,0 +1,75 @@
+import { appendError, showStatus } from "../ui.js";
+
+export async function createConfigTable(
+  taskTitle,
+  configValuesArray,
+  tableRangeStart,
+  tableRangeEnd,
+) {
+  await Excel.run(async (context) => {
+    try {
+      const currentWorksheet = context.workbook.worksheets.getActiveWorksheet();
+      currentWorksheet.load("name");
+      await context.sync();
+      const worksheetName = currentWorksheet.name;
+
+      var range = currentWorksheet.getRange("A1");
+      range.values = [[taskTitle]];
+      range.format.font.bold = true;
+      range.format.fill.color = "yellow";
+      range.format.font.size = 16;
+
+      var configTable = currentWorksheet.tables.add(tableRangeStart, true /*hasHeaders*/);
+      configTable.name = `${worksheetName}.ConfigTable`;
+
+      configTable.getHeaderRowRange().values = [configValuesArray[0]];
+
+      configTable.rows.add(null, configValuesArray.slice(1));
+
+      currentWorksheet.getUsedRange().format.autofitColumns();
+      currentWorksheet.getUsedRange().format.autofitRows();
+      currentWorksheet.getRange(tableRangeEnd).format.wrapText = true; // wrap system instructions
+      currentWorksheet.getRange(tableRangeEnd).format.shrinkToFit = true; // shrinkToFit system instructions
+
+      await context.sync();
+    } catch (error) {
+      showStatus(`Exception when creating ${taskTitle} Config Table: ${error.message}`, true);
+      appendError(`Error creating ${taskTitle} Config Table:`, error);
+
+      return;
+    }
+  });
+}
+
+export async function createDataTable(
+  taskTitle,
+  tableHeaderArray,
+  tableRangeStart,
+  tableRangeEnd,
+) {
+  await Excel.run(async (context) => {
+    try {
+      const currentWorksheet = context.workbook.worksheets.getActiveWorksheet();
+      currentWorksheet.load("name");
+      await context.sync();
+      const worksheetName = currentWorksheet.name;
+
+      var velvetTable = currentWorksheet.tables.add(tableRangeStart, true /*hasHeaders*/);
+      velvetTable.name = `${worksheetName}.TestCasesTable`;
+
+      velvetTable.getHeaderRowRange().values = [tableHeaderArray[0]];
+
+      velvetTable.resize(tableRangeEnd);
+      currentWorksheet.getUsedRange().format.autofitColumns();
+      currentWorksheet.getUsedRange().format.autofitRows();
+      currentWorksheet.getUsedRange().format.wrapText = true;
+      currentWorksheet.getUsedRange().format.shrinkToFit = true;
+
+      await context.sync();
+    } catch (error) {
+      showStatus(`Exception when creating ${taskTitle} Data Table: ${error.message}`, true);
+      appendError(`Error creating ${taskTitle} Data Table:`, error);
+      return;
+    }
+  });
+}
diff --git a/src/excel/excel_common.js b/src/excel/excel_common.js
index 67ececd..50315e5 100644
--- a/src/excel/excel_common.js
+++ b/src/excel/excel_common.js
@@ -10,3 +10,4 @@ export function getColumn(table, columnName) {
     showStatus(`Exception when getting column: ${JSON.stringify(error)}`, true);
   }
 }
+
diff --git a/src/excel/excel_search_runner.js b/src/excel/excel_search_runner.js
index 4eadfb9..ff37995 100644
--- a/src/excel/excel_search_runner.js
+++ b/src/excel/excel_search_runner.js
@@ -46,8 +46,8 @@ export class ExcelSearchRunner extends TaskRunner {
         ignoreAdversarialQuery: valueColumn.values[12][0],
         ignoreNonSummarySeekingQuery: valueColumn.values[13][0],
         summaryMatchingAdditionalPrompt: valueColumn.values[14][0],
-        batchSize: valueColumn.values[15][0],
-        timeBetweenCallsInSec: valueColumn.values[16][0],
+        batchSize: parseInt(valueColumn.values[15][0]),
+        timeBetweenCallsInSec: parseInt(valueColumn.values[16][0]),
         accessToken: $("#access-token").val(),
         systemInstruction: "",
         responseMimeType: "text/plain",
@@ -146,7 +146,7 @@ export class ExcelSearchRunner extends TaskRunner {
 
   async cancelAllTasks() {
     this.throttled_process_summary.abort();
-    appendLog(`Cancel Requested  for Search Tasks`);
+    appendLog(`Cancel Requested for Search Tasks`);
   }
 
   async processRow(response_json, context, config, rowNum) {
diff --git a/src/excel/excel_summarization_runner.js b/src/excel/excel_summarization_runner.js
index e2413a3..44a1530 100644
--- a/src/excel/excel_summarization_runner.js
+++ b/src/excel/excel_summarization_runner.js
@@ -41,8 +41,8 @@ export class SummarizationRunner extends TaskRunner {
       generateSummarizationVerbosity: valueColumn.values[7][0],
       generateGroundedness: valueColumn.values[8][0],
       generateFulfillment: valueColumn.values[9][0],
-      batchSize: valueColumn.values[10][0],
-      timeBetweenCallsInSec: valueColumn.values[11][0],
+      batchSize: parseInt(valueColumn.values[10][0]),
+      timeBetweenCallsInSec: parseInt(valueColumn.values[11][0]),
       accessToken: $("#access-token").val(),
 
       systemInstruction: "",
@@ -118,7 +118,7 @@ export class SummarizationRunner extends TaskRunner {
   async getResultFromVertexAI(rowNum, config) {
     const toSummarize = this.toSummarizeColumn.values;
     const full_prompt = config.prompt + " Text to summarize: " + toSummarize[rowNum][0];
-    return await callGeminiMultitModal(rowNum, full_prompt, null, null, config);
+    return await callGeminiMultitModal(rowNum, full_prompt, null, null, null, config.model, config);
   }
 
   async waitForTaskstoFinish() {
diff --git a/src/excel/excel_synthetic_qa_runner.js b/src/excel/excel_synthetic_qa_runner.js
index 3ef85b8..5e3006d 100644
--- a/src/excel/excel_synthetic_qa_runner.js
+++ b/src/excel/excel_synthetic_qa_runner.js
@@ -1,6 +1,7 @@
 import { TaskRunner } from "../task_runner.js";
 
-import { appendError, showStatus } from "../ui.js";
+import { findIndexByColumnsNameIn2DArray, mapQuestionAnsweringScore } from "../common.js";
+import { appendError, appendLog, showStatus } from "../ui.js";
 import { callGeminiMultitModal } from "../vertex_ai.js";
 import { getColumn } from "./excel_common.js";
 
@@ -8,6 +9,9 @@ export class SyntheticQARunner extends TaskRunner {
   constructor() {
     super();
     this.synthQATaskPromiseSet = new Set();
+    this.generateQualityEval_throttled = this.throttle((a, b, c) =>
+      this.generateQualityEval(a, b, c),
+    );
   }
   async getSyntheticQAConfig() {
@@ -20,15 +24,53 @@ export class SyntheticQARunner extends TaskRunner {
       const worksheetName = currentWorksheet.name;
       const configTable = currentWorksheet.tables.getItem(`${worksheetName}.ConfigTable`);
       const valueColumn = getColumn(configTable, "Value");
+      const configColumn = getColumn(configTable, "Config");
       await context.sync();
 
       config = {
-        vertexAIProjectID: valueColumn.values[1][0],
-        vertexAILocation: valueColumn.values[2][0],
-        model: valueColumn.values[3][0],
-        systemInstruction: valueColumn.values[4][0],
-        batchSize: valueColumn.values[5][0],
-        timeBetweenCallsInSec: valueColumn.values[6][0],
+        vertexAIProjectID:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Vertex AI Project ID")
+          ][0],
+        vertexAILocation:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Vertex AI Location")
+          ][0],
+        model:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Gemini Model ID")
+          ][0],
+        systemInstruction:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "System Instructions")
+          ][0],
+        prompt:
+          valueColumn.values[findIndexByColumnsNameIn2DArray(configColumn.values, "Prompt")][0],
+        qaQualityFlag:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Generate Q & A Quality")
+          ][0],
+        qAQualityPrompt:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Q & A Quality Prompt")
+          ][0],
+        qAQualityModel:
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Q & A Quality Model ID")
+          ][0],
+        batchSize: parseInt(
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(configColumn.values, "Max Concurrent Requests (1-10)")
+          ][0],
+        ),
+        timeBetweenCallsInSec: parseInt(
+          valueColumn.values[
+            findIndexByColumnsNameIn2DArray(
+              configColumn.values,
+              "Request Interval in Seconds (1-10)",
+            )
+          ][0],
+        ),
         accessToken: $("#access-token").val(),
         responseMimeType: "application/json",
       };
@@ -59,9 +101,7 @@ export class SyntheticQARunner extends TaskRunner {
     this.mimeTypeColumn = getColumn(testCasesTable, "Mime Type");
     this.generatedQuestionColumn = getColumn(testCasesTable, "Generated Question");
     this.expectedAnswerColumn = getColumn(testCasesTable, "Expected Answer");
-    this.reasoningAColumn = getColumn(testCasesTable, "Reasoning");
-    this.statusColumn = getColumn(testCasesTable, "Status");
-    this.responseTimeColumn = getColumn(testCasesTable, "Response Time");
+    this.qualityColumn = getColumn(testCasesTable, "Q & A Quality");
 
     testCasesTable.rows.load("count");
     await context.sync();
@@ -99,12 +139,14 @@ export class SyntheticQARunner extends TaskRunner {
   async getResultFromVertexAI(rowNum, config) {
     let fileUri = this.fileUriColumn.values;
     let mimeType = this.mimeTypeColumn.values;
-    let prompt = "Generate 1 question and answer";
+
     return await callGeminiMultitModal(
       rowNum,
-      prompt,
+      config.prompt,
+      config.systemInstruction,
       fileUri[rowNum][0],
       mimeType[rowNum][0],
+      config.model,
       config,
     );
   }
@@ -118,6 +160,7 @@ export class SyntheticQARunner extends TaskRunner {
   }
 
   async processRow(response_json, context, config, rowNum) {
+    let numCallsMade = 0;
     try {
       const output = response_json.candidates[0].content.parts[0].text;
       // Set the generated question
@@ -132,24 +175,67 @@ export class SyntheticQARunner extends TaskRunner {
       cell_expectedAnswer.clear(Excel.ClearApplyTo.formats);
       cell_expectedAnswer.values = [[response.answer]];
 
-      // Set the reasoning
-      /* const cell_reasoning = this.reasoningColumn.getRange().getCell(rowNum, 0);
-      cell_reasoning.clear(Excel.ClearApplyTo.formats);
-      cell_reasoning.values = [[response.reasoning]]; */
-
-      // Set the reasoning
-      const cell_status = this.statusColumn.getRange().getCell(rowNum, 0);
-      cell_status.clear(Excel.ClearApplyTo.formats);
-      cell_status.values = [["Success"]];
+      // call to get quality if flag is set
+      if (config.qaQualityFlag) {
+        this.synthQATaskPromiseSet.add(this.generateQualityEval_throttled(config, response, rowNum));
+        // count the extra eval call so the runner's call tally stays accurate
+        numCallsMade++;
+      }
     } catch (err) {
       appendError(`testCaseID: ${rowNum} Error setting QA. Error: ${err.message} `, err);
-      const cell_status = this.statusColumn.getRange().getCell(rowNum, 0);
+      const cell_status = this.generatedQuestionColumn.getRange().getCell(rowNum, 0);
       cell_status.clear(Excel.ClearApplyTo.formats);
       cell_status.format.fill.color = "#FFCCCB";
       cell_status.values = [["Failed. Error: " + err.message]];
     } finally {
       //await context.sync();
     }
-    return 0;
+    // execute the tasks
+    await Promise.allSettled(this.synthQATaskPromiseSet.values());
+
+    return numCallsMade;
+  }
+
+  async generateQualityEval(config, response, rowNum) {
+    try {
+      appendLog(`testCaseID::${rowNum} generateQualityEval Started..`);
+      let fileUri = this.fileUriColumn.values;
+      let mimeType = this.mimeTypeColumn.values;
+
+      const evalPrompt = `${config.qAQualityPrompt} # User Inputs and AI-generated Response
+      ## User Inputs
+      ### Prompt
+      ${config.systemInstruction}
+      ${config.prompt}
+
+      ## AI-generated Response
+      ${JSON.stringify(response)}`;
+
+      const eval_response = await callGeminiMultitModal(
+        rowNum,
+        evalPrompt,
+        "",
+        fileUri[rowNum][0],
+        mimeType[rowNum][0],
+        config.qAQualityModel,
+        config,
+      );
+
+      const eval_output = eval_response.output.candidates[0].content.parts[0].text;
+      // the output is JSON, so parse it and read the rating field
+      const eval_json = JSON.parse(eval_output);
+
+      // Set the eval quality
+      const cell_evalQuality = this.qualityColumn.getRange().getCell(rowNum, 0);
+      cell_evalQuality.clear(Excel.ClearApplyTo.formats);
+      cell_evalQuality.values = [[mapQuestionAnsweringScore.get(eval_json.rating)]];
+      appendLog(`testCaseID::${rowNum} generateQualityEval Finished: Rating: ${eval_json.rating}`);
+    } catch (err) {
+      appendError(`testCaseID: ${rowNum} Error setting Eval QA. Error: ${err.message} `, err);
+      const cell_status = this.qualityColumn.getRange().getCell(rowNum, 0);
+      cell_status.clear(Excel.ClearApplyTo.formats);
+      cell_status.format.fill.color = "#FFCCCB";
Error: " + err.message]]; + } } } diff --git a/src/excel/synthetic_qa_tables.js b/src/excel/synthetic_qa_tables.js index b2edb69..09f1c36 100644 --- a/src/excel/synthetic_qa_tables.js +++ b/src/excel/synthetic_qa_tables.js @@ -1,66 +1,20 @@ import { synth_q_and_a_configValues, synth_q_and_a_TableHeader } from "../common.js"; -import { appendError, showStatus } from "../ui.js"; +import { createConfigTable, createDataTable } from "./create_tables.js"; export async function createSyntheticQAConfigTable() { - await Excel.run(async (context) => { - try { - const currentWorksheet = context.workbook.worksheets.getActiveWorksheet(); - currentWorksheet.load("name"); - await context.sync(); - const worksheetName = currentWorksheet.name; - - var range = currentWorksheet.getRange("A1"); - range.values = [["Generate Synthetic Questions and Answers"]]; - range.format.font.bold = true; - range.format.fill.color = "yellow"; - range.format.font.size = 16; - - var configTable = currentWorksheet.tables.add("A2:B2", true /*hasHeaders*/); - configTable.name = `${worksheetName}.ConfigTable`; - - configTable.getHeaderRowRange().values = [synth_q_and_a_configValues[0]]; - - configTable.rows.add(null, synth_q_and_a_configValues.slice(1)); - - currentWorksheet.getUsedRange().format.autofitColumns(); - currentWorksheet.getUsedRange().format.autofitRows(); - currentWorksheet.getRange("A6:B6").format.wrapText = true; // wrap system instrcutions - currentWorksheet.getRange("A6:B6").format.shrinkToFit = true; // shrinkToFit system instrcutions - - await context.sync(); - } catch (error) { - showStatus(`Exception when creating createSyntheticQA Config Table: ${error.message}`, true); - appendError("Error creating Config Table:", error); - - return; - } - }); + createConfigTable( + "Generate Synthetic Questions and Answers", + synth_q_and_a_configValues, + "A2:B2", + "A12:B12", + ); } export async function createSyntheticQADataTable() { - await Excel.run(async (context) => { - try { - const currentWorksheet = context.workbook.worksheets.getActiveWorksheet(); - currentWorksheet.load("name"); - await context.sync(); - const worksheetName = currentWorksheet.name; - - var velvetTable = currentWorksheet.tables.add("C9:J9", true /*hasHeaders*/); - velvetTable.name = `${worksheetName}.TestCasesTable`; - - velvetTable.getHeaderRowRange().values = [synth_q_and_a_TableHeader[0]]; - - velvetTable.resize("C9:J119"); - currentWorksheet.getUsedRange().format.autofitColumns(); - currentWorksheet.getUsedRange().format.autofitRows(); - currentWorksheet.getUsedRange().format.wrapText = true; - currentWorksheet.getUsedRange().format.shrinkToFit = true; - - await context.sync(); - } catch (error) { - showStatus(`Exception when creating createSyntheticQA DataTable: ${error.message}`, true); - appendError("Error creating Data Table:", error); - return; - } - }); + createDataTable( + "Generate Synthetic Questions and Answers", + synth_q_and_a_TableHeader, + "C13:H13", + "C13:H113", + ); } diff --git a/src/task_runner.js b/src/task_runner.js index 7e83ccc..0ec793f 100644 --- a/src/task_runner.js +++ b/src/task_runner.js @@ -15,7 +15,7 @@ export class TaskRunner { constructor() { this.cancelPressed = false; this.throttle = pThrottle({ - limit: 6, + limit: 10, interval: 1000, }); this.throttled_api_call = this.throttle((a, b) => this.getResultFromVertexAI(a, b)); @@ -48,6 +48,13 @@ export class TaskRunner { let numCallsMade = 0; this.cancelPressed = false; + if (config.batchSize !== null && config.timeBetweenCallsInSec != null) { + this.throttle = 
+      this.throttle = pThrottle({
+        limit: config.batchSize > 10 ? 10 : config.batchSize,
+        interval: config.timeBetweenCallsInSec > 5 ? 5 * 1000 : config.timeBetweenCallsInSec * 1000,
+      });
+      this.throttled_api_call = this.throttle((a, b) => this.getResultFromVertexAI(a, b));
+    }
     // Start timer
     const startTime = new Date();
     let promiseSet = new Set();
@@ -109,6 +116,8 @@ export class TaskRunner {
       // bad requests if some fo the config or auth tken is bad
       if (currentRow === 1) {
         await Promise.resolve(apiPromise);
+        await this.waitForTaskstoFinish();
+        await context.sync();
         if (numFails > 0) {
           break;
         }
diff --git a/src/vertex_ai.js b/src/vertex_ai.js
index 9a939ad..a1b31a7 100644
--- a/src/vertex_ai.js
+++ b/src/vertex_ai.js
@@ -85,12 +85,19 @@ export async function calculateSimilarityUsingPalm2(id, sentence1, sentence2, co
   return { id: id, status_code: status, output: `${output}` };
 }
 
-export async function callGeminiMultitModal(id, prompt, fileUri, mimeType, config) {
+export async function callGeminiMultitModal(
+  id,
+  prompt,
+  systemInstruction,
+  fileUri,
+  mimeType,
+  model_id,
+  config,
+) {
   const token = config.accessToken;
   const projectId = config.vertexAIProjectID;
   const location = config.vertexAILocation;
-  const model_id = config.model;
-  const system_instruction = config.systemInstruction === null ? "" : config.systemInstruction;
+  const system_instruction = systemInstruction === null ? "" : systemInstruction;
 
   var data = {
     contents: [
diff --git a/test/data/multi_modal/test_multi_modal_request.json b/test/data/multi_modal/test_multi_modal_request.json
index 50375fc..10ab4c8 100644
--- a/test/data/multi_modal/test_multi_modal_request.json
+++ b/test/data/multi_modal/test_multi_modal_request.json
@@ -4,7 +4,7 @@
       "role": "user",
       "parts": [
         {
-          "text": "Generate 1 question and answer"
+          "text": "You are an expert in reading call center policy and procedure documents. Generate a question and answer that a customer would ask a Bank, using the attached document."
         },
         {
           "fileData": {
@@ -18,7 +18,7 @@
   "systemInstruction": {
     "parts": [
       {
-        "text": "You are an expert in reading call center policy and procedure documents.Given the attached document, generate a question and answer that customers are likely to ask a call center agent.The question should only be sourced from the provided the document.Do not use any other information other than the attached document. Explain your reasoning for the answer by quoting verbatim where in the document the answer is found. Return the results in JSON format.Example: {'question': 'Here is a question?', 'answer': 'Here is the answer', 'reasoning': 'Quote from document'}"
+        "text": "Given the attached document, generate a question and an answer. The question should only be sourced from the provided document. Do not use any other information other than the attached document. Explain your reasoning for the answer by quoting verbatim where in the document the answer is found. Return the results in JSON format. Example: {'question': 'Here is a question?', 'answer': 'Here is the answer', 'reasoning': 'Quote from document'}"
       }
     ]
   },
diff --git a/test/data/question_answering/test_qa_quality_request.json b/test/data/question_answering/test_qa_quality_request.json
new file mode 100644
index 0000000..bf62b3f
--- /dev/null
+++ b/test/data/question_answering/test_qa_quality_request.json
@@ -0,0 +1,42 @@
+{
+  "contents": [
+    {
+      "role": "user",
+      "parts": [
+        {
+          "text": "# Instruction\nYou are an expert evaluator. Your task is to evaluate the quality of the responses generated by AI models.\nWe will provide you with the user prompt and an AI-generated response.\nYou should first read the user prompt carefully for analyzing the task, and then evaluate the quality of the responses based on the rules provided in the Evaluation section below.\n\n# Evaluation\n## Metric Definition\nYou will be assessing question answering quality, which measures the overall quality of the answer to the question in the user prompt. Pay special attention to length constraints, such as in X words or in Y sentences. The instruction for performing a question-answering task is provided in the user prompt. The response should not contain information that is not present in the context (if it is provided).\n\nYou will assign the response a score from 5, 4, 3, 2, 1, following the Rating Rubric and Evaluation Steps.\nGive step-by-step explanations for your scoring, and only choose scores from 5, 4, 3, 2, 1.\n\n## Criteria Definition\nInstruction following: The response demonstrates a clear understanding of the question answering task instructions, satisfying all of the instruction's requirements.\nGroundedness: The response contains information included only in the context if the context is present in the user prompt. The response does not reference any outside information.\nCompleteness: The response completely answers the question with sufficient detail.\nFluency: The response is well-organized and easy to read.\n\n## Rating Rubric\n5: (Very good). The answer follows instructions, is grounded, complete, and fluent.\n4: (Good). The answer follows instructions, is grounded, complete, but is not very fluent.\n3: (Ok). The answer mostly follows instructions, is grounded, answers the question partially and is not very fluent.\n2: (Bad). The answer does not follow the instructions very well, is incomplete or not fully grounded.\n1: (Very bad). The answer does not follow the instructions, is wrong and not grounded.\n\n## Evaluation Steps\nSTEP 1: Assess the response in aspects of instruction following, groundedness, completeness, and fluency according to the criteria.\nSTEP 2: Score based on the rubric.\nReturn result in JSON format. example output: { 'rating': 2, 'evaluation': 'reason' }\n\n# User Inputs and AI-generated Response\n## User Inputs\n### Prompt\nGiven the attached document, generate a question and an answer. The question should only be sourced from the provided document. Do not use any other information other than the attached document. Explain your reasoning for the answer by quoting verbatim where in the document the answer is found. Return the results in JSON format. Example: {'question': 'Here is a question?', 'answer': 'Here is the answer', 'reasoning': 'Quote from document'}\n\nYou are an expert in reading call center policy and procedure documents. Generate a question and answer that a customer would ask a Bank, using the attached document. ## AI-generated Response\n\nMarkdown\n\n{\n  \"question\": \"What are the circumstances under which Gemini Bank will consider a refund for overdraft fees?\",\n  \"answer\": \"Gemini Bank may consider an overdraft fee refund for situations beyond the customer's control, such as documented bank errors, authorization issues with legitimate transactions, or unexpected delays in receiving a deposit critical to avoid the overdraft. \",\n  \"reasoning\": \"The document states: \"Extenuating Circumstances: We will consider fee refunds for extenuating circumstances beyond the customer's control, such as documented bank errors, authorization issues with legitimate transactions, or unexpected delays in receiving a deposit critical to avoid the overdraft.\"\"\n}"
\",\n \"reasoning\": \"The document states: \"Extenuating Circumstances: We will consider fee refunds for extenuating circumstances beyond the customer's control, such as documented bank errors, authorization issues with legitimate transactions, or unexpected delays in receiving a deposit critical to avoid the overdraft.\"\"\n}" + }, + { + "fileData": { + "mimeType": "application/pdf", + "fileUri": "gs://argolis-arau-gemini-bank/Procedure - Overdraft Fee Refunds.pdf" + } + } + ] + } + ], + "generationConfig": { + "maxOutputTokens": 8192, + "temperature": 1, + "topP": 0.95, + "response_mime_type": "application/json" + }, + "safetySettings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_MEDIUM_AND_ABOVE" + } + ] +} diff --git a/test/data/question_answering/test_qa_quality_response.json b/test/data/question_answering/test_qa_quality_response.json new file mode 100644 index 0000000..b16fa85 --- /dev/null +++ b/test/data/question_answering/test_qa_quality_response.json @@ -0,0 +1,51 @@ +{ + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "{\n\"rating\": \"5\",\n\"evaluation\": \"The AI response correctly identifies a question and answer from the provided document. The answer is accurate and the reasoning is sound and quotes the document verbatim.\"\n}\n" + } + ] + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.034179688, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.111328125 + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.103515625, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.068359375 + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.071777344, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.046142578 + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.03173828, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.09033203 + } + ], + "avgLogprobs": -0.1950966093275282 + } + ], + "usageMetadata": { + "promptTokenCount": 1286, + "candidatesTokenCount": 45, + "totalTokenCount": 1331 + } +} diff --git a/test/test_common.js b/test/test_common.js index 5cbfcc8..7ae6e2b 100644 --- a/test/test_common.js +++ b/test/test_common.js @@ -34,3 +34,47 @@ export function mockVertexAISearchRequestResponse( }); return { requestJson, url, expectedResponse }; } + +export function mockGeminiRequestResponse( + testCaseNum, + expected_status_code, + expectedRequestFile, + expectedResponseFile, + model_id, + config, +) { + const requestData = fs.readFileSync(expectedRequestFile); + const requestJson = JSON.parse(requestData); + + // Read response json from file into variable + const responseData = fs.readFileSync(expectedResponseFile); + const responseJson = JSON.parse(responseData); + + // expected response with row number + const expectedResponse = { + id: testCaseNum, + status_code: expected_status_code, + output: responseJson, + }; + + const url = 
+  const url = `https://${config.vertexAILocation}-aiplatform.googleapis.com/v1/projects/${config.vertexAIProjectID}/locations/${config.vertexAILocation}/publishers/google/models/${model_id}:generateContent`;
+
+  // mock the call with our response we want to return
+  fetchMock.post(url, {
+    status: expected_status_code,
+    headers: {
+      "Content-Type": "application/json",
+    },
+    body: JSON.stringify(responseJson),
+  });
+  return { requestJson, url, expectedResponse };
+}
+
+export function getRequestResponseJsonFromFile(requestJsonFilePath, responseJsonFilePath) {
+  const request = fs.readFileSync(requestJsonFilePath);
+  const response = fs.readFileSync(responseJsonFilePath);
+  return {
+    response_json: JSON.parse(response),
+    request_json: JSON.parse(request),
+  };
+}
diff --git a/test/test_search_runner.js b/test/test_search_runner.js
index 94d137d..57f1555 100644
--- a/test/test_search_runner.js
+++ b/test/test_search_runner.js
@@ -330,7 +330,7 @@ describe("When Search Run Tests is clicked ", () => {
     );
 
     // Execute the tests
-    config.timeBetweenCallsInSec = 0; // set timeout to zero so test doesn't timeout
+    config.batchSize = 10; // set batchSize high so the test doesn't time out
     await excelSearchRunner.executeSearchTests(config);
 
     // Verify mocks are called
diff --git a/test/test_summarization_runner.js b/test/test_summarization_runner.js
index 97b24c5..28c9d54 100644
--- a/test/test_summarization_runner.js
+++ b/test/test_summarization_runner.js
@@ -1,6 +1,5 @@
 import expect from "expect";
 import fetchMock from "fetch-mock";
-import fs from "fs";
 import { default as $, default as JQuery } from "jquery";
 import pkg from "office-addin-mock";
 import sinon from "sinon";
@@ -10,17 +9,15 @@ import {
   createSummarizationEvalConfigTable,
   createSummarizationEvalDataTable,
 } from "../src/excel/summarization_tables.js";
-import { showStatus } from "../src/ui.js";
 import { callGeminiMultitModal } from "../src/vertex_ai.js";
 
 import { summarization_configValues, summarization_TableHeader } from "../src/common.js";
-
+import { getRequestResponseJsonFromFile } from "./test_common.js";
 // mock the UI components
 global.$ = $;
 global.JQuery = JQuery;
-
 global.callGeminiMultitModal = callGeminiMultitModal;
 
 const { OfficeMockObject } = pkg;
@@ -68,7 +65,6 @@ export var testCaseRows = summarization_TableHeader.concat([
 ]);
 
 describe("When Summarization Eval is clicked ", () => {
-  var mockTestData;
   var $stub;
 
   beforeEach(() => {
@@ -80,9 +76,6 @@ describe("When Summarization Eval is clicked ", () => {
       tabulator: sinon.stub(),
       prop: sinon.stub().returns(true),
     });
-
-
     fetchMock.reset();
@@ -343,7 +336,7 @@ describe("When Summarization Eval is clicked ", () => {
       body: JSON.stringify(fulfillment_response_json),
     });
 
-    config.timeBetweenCallsInSec = 0; // No need for delay in tests
+    config.batchSize = 10; // set batchSize high so the test doesn't time out
 
     // Execute the tests
     await summarizationRunner.createSummarizationData(config);
@@ -419,16 +412,6 @@ describe("When Summarization Eval is clicked ", () => {
     expect(fulfillment_quality_cell[0][0]).toEqual(testCaseRows[1][fulfillment_quality_col_index]);
   });
 });
-
-function getRequestResponseJsonFromFile(requestJsonFilePath, responseJsonFilePath) {
-  const request = fs.readFileSync(requestJsonFilePath);
-  const response = fs.readFileSync(responseJsonFilePath);
-  return {
-    response_json: JSON.parse(response),
-    request_json: JSON.parse(request),
-  };
-}
-
 function getCellAndColumnIndexByName(column_name, mockTestData) {
   var col_index = testCaseRows[0].indexOf(column_name);
diff --git a/test/test_synthetic_qa_runner.js b/test/test_synthetic_qa_runner.js
index 1d140c3..6264bcb 100644
--- a/test/test_synthetic_qa_runner.js
+++ b/test/test_synthetic_qa_runner.js
@@ -1,6 +1,5 @@
 import expect from "expect";
 import fetchMock from "fetch-mock";
-import fs from "fs";
 import { default as $, default as JQuery } from "jquery";
 import pkg from "office-addin-mock";
 import sinon from "sinon";
@@ -13,6 +12,7 @@ import {
 } from "../src/excel/synthetic_qa_tables.js";
 import { showStatus } from "../src/ui.js";
 import { callGeminiMultitModal } from "../src/vertex_ai.js";
+import { mockGeminiRequestResponse } from "./test_common.js";
 
 // mock the UI components
 global.showStatus = showStatus;
@@ -29,9 +29,7 @@ export var testCaseRows = synth_q_and_a_TableHeader.concat([
     "application/pdf", //Mime Type
     "If I close my new savings account within 30 days of opening it, will I be charged a fee?", // Generated Question
     "Yes, you will be charged a $25 fee unless the closure is due to a Gemini Bank error in account opening, customer dissatisfaction with a product or service disclosed during the opening process, or insufficient funds.", //Expected Answer
-    "“Account Closure: A customer who closes their new savings account within 30 days of opening will be subject to a $25 fee unless the closure is due to:\nΟ Gemini Bank error in account opening.\nΟ Customer dissatisfaction with a product or service disclosed during the opening process.”", //Reasoning
-    "Success", // Status
-    "10ms", // Response Time
+    "5-Very Good", // Q & A Quality
   ],
 ]);
 
@@ -236,22 +234,40 @@ describe("When Generate Synthetic Q&A is clicked ", () => {
     const config = await syntheticQuestionAnswerRunner.getSyntheticQAConfig();
     expect(config).not.toBe(null);
 
+    // set up mock for the first (Q&A generation) request
     // read the request from json file
-    const requestData = fs.readFileSync("./test/data/multi_modal/test_multi_modal_request.json");
-    const requestJson = JSON.parse(requestData);
-    // Read response json from file into variable
-    const responseData = fs.readFileSync("./test/data/multi_modal/test_multi_modal_response.json");
-    const responseJson = JSON.parse(responseData);
+    fetchMock.config.overwriteRoutes = false;
 
-    const url = `https://${config.vertexAILocation}-aiplatform.googleapis.com/v1/projects/${config.vertexAIProjectID}/locations/${config.vertexAILocation}/publishers/google/models/${config.model}:generateContent`;
-    fetchMock.postOnce(url, {
-      status: 200,
-      headers: { "Content-Type": `application/json` },
-      body: JSON.stringify(responseJson),
-    });
+    const {
+      requestJson: requestJson,
+      url: url,
+      expectedResponse: responseJson,
+    } = mockGeminiRequestResponse(
+      1,
+      200,
+      "./test/data/multi_modal/test_multi_modal_request.json",
+      "./test/data/multi_modal/test_multi_modal_response.json",
+      config.model,
+      config,
+    );
 
-    config.timeBetweenCallsInSec = 0; // No need for delay in tests
+    // set up mock for the Q&A quality eval request
+    // read the request from json file
+    const {
+      requestJson: quality_request_json,
+      url: quality_url,
+      expectedResponse: quality_expectedResponse,
+    } = mockGeminiRequestResponse(
+      1,
+      200,
+      "./test/data/question_answering/test_qa_quality_request.json",
+      "./test/data/question_answering/test_qa_quality_response.json",
+      config.qAQualityModel,
+      config,
+    );
+
+    config.batchSize = 10; // set batchSize high so the test doesn't time out
 
     // Execute the tests
     await syntheticQuestionAnswerRunner.createSyntheticQAData(config);
@@ -279,6 +295,13 @@
     expect(actual_generated_answer[0][0]).toEqual(
       testCaseRows[1][actual_generated_answer_col_index],
     );
+
+    // Match the Question Answer Quality
+    const { cell: question_answer_quality, col_index: question_answer_quality_col_index } =
+      getCellAndColumnIndexByName("Q & A Quality", mockTestData);
+    expect(question_answer_quality[0][0]).toEqual(
+      testCaseRows[1][question_answer_quality_col_index],
+    );
   });
 });
diff --git a/test/test_vertex_ai_multimodal.js b/test/test_vertex_ai_multimodal.js
index 30ce6c3..6e9809d 100644
--- a/test/test_vertex_ai_multimodal.js
+++ b/test/test_vertex_ai_multimodal.js
@@ -60,7 +60,15 @@ describe("When callGeminiMultiModal is called", () => {
       body: JSON.stringify(responseJson),
     });
 
-    const result = await callGeminiMultitModal(1, prompt, fileUri, mimeType, config);
+    const result = await callGeminiMultitModal(
+      1,
+      prompt,
+      config.systemInstruction,
+      fileUri,
+      mimeType,
+      config.model,
+      config,
+    );
 
     // make sure our mock is called
     expect(fetchMock.called()).toBe(true);
@@ -75,7 +83,6 @@ describe("When callGeminiMultiModal is called", () => {
     expect(result.output.candidates[0].content.parts[0].text).toEqual(
       responseJson.candidates[0].content.parts[0].text,
     );
-
   });
 
   it("should fail when you get an authentication error from Vertex AI", async () => {
     const prompt = "What is the sentiment of this text?";
@@ -104,7 +111,7 @@ describe("When callGeminiMultiModal is called", () => {
     });
 
     try {
-      const result = await callGeminiMultitModal(1, prompt, fileUri, mimeType, config);
+      const result = await callGeminiMultitModal(1, prompt, "", fileUri, mimeType, config.model, config);
       assert.fail();
     } catch (err) {
       expect(fetchMock.called()).toBe(true);
@@ -115,6 +122,8 @@ describe("When callGeminiMultiModal is called", () => {
 
   it("should fail when fetch throws exception", async () => {
     const prompt = "What is the sentiment of this text?";
+    const systemInstruction =
+      "You are an expert in reading call center policy and procedure documents. Given the attached document, generate a question and answer that customers are likely to ask a call center agent. The question should only be sourced from the provided document. Do not use any other information other than the attached document. Explain your reasoning for the answer by quoting verbatim where in the document the answer is found. Return the results in JSON format. Example: {'question': 'Here is a question?', 'answer': 'Here is the answer', 'reasoning': 'Quote from document'}";
     const fileUri = "https://example.com/file.txt";
     const mimeType = "text/plain";
     const model_id = "gemini-1.5-flash-001";
@@ -123,7 +132,6 @@ describe("When callGeminiMultiModal is called", () => {
       vertexAIProjectID: "YOUR_PROJECT_ID",
       vertexAILocation: "YOUR_LOCATION",
       model: "gemini-1.5-flash-001",
-      systemInstruction: null,
       responseMimeType: "application/json",
     };
 
@@ -133,7 +141,15 @@ describe("When callGeminiMultiModal is called", () => {
     });
 
     try {
-      const result = await callGeminiMultitModal(1, prompt, fileUri, mimeType, config);
+      const result = await callGeminiMultitModal(
+        1,
+        prompt,
+        systemInstruction,
+        fileUri,
+        mimeType,
+        config.model,
+        config,
+      );
       assert.fail();
     } catch (err) {
       expect(fetchMock.called()).toBe(true);