
More runtime improvements #713

Merged: 8 commits on Nov 24, 2024
9 changes: 7 additions & 2 deletions plugins/google/modelgauge/suts/google_genai_client.py
@@ -1,8 +1,9 @@
-import google.generativeai as genai  # type: ignore
 from abc import abstractmethod
+from typing import Dict, List, Optional
+
+import google.generativeai as genai  # type: ignore
 from google.generativeai.types import HarmCategory, HarmBlockThreshold  # type: ignore
 from pydantic import BaseModel
-from typing import Dict, List, Optional
 
 from modelgauge.general import APIException
 from modelgauge.prompt import TextPrompt
@@ -128,6 +129,10 @@ def translate_response(self, request: GoogleGenAiRequest, response: GoogleGenAiResponse) -> SUTResponse:
                     f"The candidate does not have any content,"
                     f" but its finish reason {candidate.finish_reason} does not qualify as a refusal."
                 )
+        if not completions:
+            # This is apparently a refusal. At least, that's what happens consistently with a set of
+            # prompts in the CSE, SRC, and SXC hazards.
+            completions = [SUTCompletion(text=REFUSAL_RESPONSE)]
         return SUTResponse(completions=completions)


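A note on the google_genai_client.py change above: with its default safety settings, Gemini appears to signal some refusals by returning no candidates at all, rather than a candidate with a refusal-qualifying finish reason. Here is a minimal sketch of the fallback in isolation, assuming the modelgauge.sut types; the helper name _completions_or_refusal is hypothetical:

from typing import List

from modelgauge.sut import REFUSAL_RESPONSE, SUTCompletion, SUTResponse


def _completions_or_refusal(completions: List[SUTCompletion]) -> SUTResponse:
    # An empty candidate list is treated as an implicit refusal; per the PR,
    # this happens consistently for a set of prompts in the CSE, SRC, and
    # SXC hazards.
    if not completions:
        completions = [SUTCompletion(text=REFUSAL_RESPONSE)]
    return SUTResponse(completions=completions)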
27 changes: 24 additions & 3 deletions plugins/google/tests/test_google_genai_client.py
@@ -1,11 +1,13 @@
+import json
+from unittest.mock import patch
+
 import pytest
 from google.generativeai.protos import Candidate, GenerateContentResponse  # type: ignore
 from google.generativeai.types import HarmCategory, HarmBlockThreshold, generation_types  # type: ignore
 
-from unittest.mock import patch
-
 from modelgauge.general import APIException
 from modelgauge.prompt import SUTOptions, TextPrompt
+from modelgauge.sut import REFUSAL_RESPONSE, SUTCompletion, SUTResponse
 from modelgauge.suts.google_genai_client import (  # type: ignore
     GEMINI_HARM_CATEGORIES,
     GoogleAiApiKey,
@@ -15,7 +17,6 @@
     GoogleGenAiRequest,
     GoogleGenAiResponse,
 )
-from modelgauge.sut import REFUSAL_RESPONSE, SUTCompletion, SUTResponse
 
 _FINISH_REASON_NORMAL = Candidate.FinishReason.STOP
 _FINISH_REASON_SAFETY = Candidate.FinishReason.SAFETY
@@ -235,6 +236,26 @@ def test_google_genai_translate_response_refusal(google_default_sut, fake_native_response_refusal, some_request):
     assert response == SUTResponse(completions=[SUTCompletion(text=REFUSAL_RESPONSE)])
 
 
+def test_google_genai_translate_response_no_completions(google_default_sut, fake_native_response_refusal, some_request):
+    no_completions = GoogleGenAiResponse(
+        **json.loads(
+            """{
+                "candidates": [],
+                "usage_metadata": {
+                    "prompt_token_count": 19,
+                    "total_token_count": 19,
+                    "cached_content_token_count": 0,
+                    "candidates_token_count": 0
+                }
+            }
+            """
+        )
+    )
+    response = google_default_sut.translate_response(some_request, no_completions)
+
+    assert response == SUTResponse(completions=[SUTCompletion(text=REFUSAL_RESPONSE)])
+
+
 def test_google_genai_disabled_safety_translate_response_refusal_raises_exception(
     google_disabled_safety_sut, fake_native_response_refusal, some_request
 ):
4 changes: 4 additions & 0 deletions plugins/mistral/modelgauge/suts/mistral_client.py
@@ -35,6 +35,7 @@ def client(self) -> Mistral:
         if not self._client:
             self._client = Mistral(
                 api_key=self.api_key,
+                timeout_ms=BACKOFF_MAX_ELAPSED_MILLIS * 3,
                 retry_config=RetryConfig(
                     "backoff",
                     BackoffStrategy(
@@ -50,6 +51,9 @@
 
     def request(self, req: dict):
         response = None
+        if self.client.chat.sdk_configuration._hooks.before_request_hooks:
+            # work around a bug in the client
+            self.client.chat.sdk_configuration._hooks.before_request_hooks = []
         try:
             response = self.client.chat.complete(**req)
             return response
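To make the mistral_client.py timing change concrete: timeout_ms is set to three times the maximum elapsed backoff, so the SDK's overall request deadline cannot fire while backoff retries are still permitted. Below is a sketch of the resulting client construction, assuming the speakeasy-style RetryConfig and BackoffStrategy signatures from the mistralai SDK; the constant values are illustrative, not the ones actually defined in mistral_client.py:

from mistralai import Mistral
from mistralai.utils import BackoffStrategy, RetryConfig

# Illustrative values only; the real constants live in mistral_client.py.
BACKOFF_INITIAL_MILLIS = 500
BACKOFF_MAX_INTERVAL_MILLIS = 10_000
BACKOFF_EXPONENT = 1.5
BACKOFF_MAX_ELAPSED_MILLIS = 60_000

client = Mistral(
    api_key="YOUR_API_KEY",
    # Three times the retry budget: retries exhaust before the deadline does.
    timeout_ms=BACKOFF_MAX_ELAPSED_MILLIS * 3,
    retry_config=RetryConfig(
        "backoff",
        BackoffStrategy(
            BACKOFF_INITIAL_MILLIS,
            BACKOFF_MAX_INTERVAL_MILLIS,
            BACKOFF_EXPONENT,
            BACKOFF_MAX_ELAPSED_MILLIS,
        ),
        retry_connection_errors=True,
    ),
)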
11 changes: 7 additions & 4 deletions src/modelbench/run_journal.py
@@ -20,10 +20,13 @@ def for_journal(o):
     if isinstance(o, TestRunItem):
         return {"test": o.test.uid, "item": o.source_id(), "sut": o.sut.uid}
     if isinstance(o, SUTResponse):
-        completion = o.completions[0]
-        result = {"response_text": completion.text}
-        if completion.top_logprobs is not None:
-            result["logprobs"] = for_journal(completion.top_logprobs)
+        if o.completions:
+            completion = o.completions[0]
+            result = {"response_text": completion.text}
+            if completion.top_logprobs is not None:
+                result["logprobs"] = for_journal(completion.top_logprobs)
+        else:
+            result = {"response_text": None}
         return result
     elif isinstance(o, BaseModel):
         return for_journal(o.model_dump(exclude_defaults=True, exclude_none=True))
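The run_journal.py change makes journaling robust to empty SUT responses. A quick before/after sketch (for_journal is the function patched above; previously an empty completions list raised IndexError on o.completions[0]):

from modelbench.run_journal import for_journal
from modelgauge.sut import SUTCompletion, SUTResponse

# Normal response: the first completion's text is journaled.
assert for_journal(SUTResponse(completions=[SUTCompletion(text="hello")])) == {"response_text": "hello"}

# Defective response with no completions: now journaled as an explicit null
# instead of crashing the journal writer.
assert for_journal(SUTResponse(completions=[])) == {"response_text": None}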
3 changes: 3 additions & 0 deletions tests/modelbench_tests/test_run_journal.py
@@ -103,6 +103,9 @@ def test_sut_response(self):
         assert logprobs["token"] == "f"
         assert logprobs["logprob"] == 1.0
 
+    def test_defective_sut_response(self):
+        assert for_journal(SUTResponse(completions=[])) == {"response_text": None}
+
     def test_exception(self):
         f = getframeinfo(currentframe())
         try: