merged main
czaloom committed Aug 19, 2024
2 parents a3b997f + 40f7b3b commit 9085608
Showing 18 changed files with 5,971 additions and 1,619 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -29,7 +29,7 @@ repos:
args: [--line-length=79]

- repo: https://github.com/RobertCraigie/pyright-python
-   rev: v1.1.350
+   rev: v1.1.376
hooks:
- id: pyright
additional_dependencies: [
@@ -52,7 +52,7 @@ repos:
"psycopg2-binary",
"pgvector",
"openai",
"mistralai<=0.4.2",
"mistralai>=1.0",
"absl-py",
"nltk",
"rouge_score",
2 changes: 1 addition & 1 deletion api/pyproject.toml
@@ -18,7 +18,7 @@ dependencies = [
"structlog",
"pgvector",
"openai",
"mistralai <= 0.4.2",
"mistralai >= 1.0",
"absl-py",
"nltk",
"rouge_score",
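Context for the pin bump in both files: mistralai 1.0 is a breaking rewrite of the Python SDK, so loosening the pin from <=0.4.2 to >=1.0 forces the call-site migration seen in the test changes below. A minimal before/after sketch of the chat call, based on the published mistralai APIs (model name and key are placeholders):

# mistralai <= 0.4.2
from mistralai.client import MistralClient
from mistralai.models.chat_completion import ChatMessage

client = MistralClient(api_key="dummy_key")
response = client.chat(
    model="mistral-small-latest",
    messages=[ChatMessage(role="user", content="Hello")],
)

# mistralai >= 1.0
from mistralai import Mistral

client = Mistral(api_key="dummy_key")
response = client.chat.complete(
    model="mistral-small-latest",
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)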
62 changes: 24 additions & 38 deletions api/tests/functional-tests/backend/core/test_llm_clients.py
@@ -3,14 +3,13 @@
from unittest.mock import MagicMock

import pytest
- from mistralai.exceptions import MistralException
- from mistralai.models.chat_completion import (
+ from mistralai.models import (
+     AssistantMessage,
+     ChatCompletionChoice,
      ChatCompletionResponse,
-     ChatCompletionResponseChoice,
-     ChatMessage,
-     FinishReason,
+     UsageInfo,
  )
- from mistralai.models.common import UsageInfo
+ from mistralai.models.sdkerror import SDKError as MistralSDKError
from openai import OpenAIError
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import ChatCompletion, Choice
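Migration note on the imports above: mistralai 1.0 flattens the models package, and the hunks below track the renames — ChatCompletionResponseChoice becomes ChatCompletionChoice, the generic ChatMessage gives way to role-specific models such as AssistantMessage, the FinishReason enum is replaced by plain strings like "stop" and "length", and request failures surface as mistralai.models.sdkerror.SDKError rather than MistralException.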
@@ -1023,11 +1022,6 @@ def _create_mock_chat_completion_none_content(
# Check that the WrappedOpenAIClient does not alter the messages.
assert fake_message == client._process_messages(fake_message)

-     # OpenAI only allows the roles of system, user and assistant.
-     invalid_message = [{"role": "invalid", "content": "Some content."}]
-     with pytest.raises(ValueError):
-         client._process_messages(invalid_message)

# The OpenAI Client should be able to connect if the API key is set as the environment variable.
os.environ["OPENAI_API_KEY"] = "dummy_key"
client = WrappedOpenAIClient(model_name="model_name")
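Note that with the role-validation block removed, _process_messages no longer rejects non-standard roles itself; invalid roles are presumably left for the OpenAI SDK or API to reject.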
Expand Down Expand Up @@ -1080,15 +1074,15 @@ def _create_mock_chat_completion_with_bad_length(
model="gpt-3.5-turbo",
object="chat.completion",
choices=[
-             ChatCompletionResponseChoice(
-                 finish_reason=FinishReason("length"),
+             ChatCompletionChoice(
+                 finish_reason="length",
                  index=0,
-                 message=ChatMessage(
-                     role="role",
-                     content="some content",
-                     name=None,
+                 message=AssistantMessage(
+                     role="assistant",
+                     content="some response",
+                     name=None,  # type: ignore - mistralai issue
                      tool_calls=None,
-                     tool_call_id=None,
+                     tool_call_id=None,  # type: ignore - mistralai issue
),
)
],
@@ -1106,15 +1100,15 @@ def _create_mock_chat_completion(
model="gpt-3.5-turbo",
object="chat.completion",
choices=[
-             ChatCompletionResponseChoice(
-                 finish_reason=FinishReason("stop"),
+             ChatCompletionChoice(
+                 finish_reason="stop",
                  index=0,
-                 message=ChatMessage(
-                     role="role",
+                 message=AssistantMessage(
+                     role="assistant",
                      content="some response",
-                     name=None,
+                     name=None,  # type: ignore - mistralai issue
                      tool_calls=None,
-                     tool_call_id=None,
+                     tool_call_id=None,  # type: ignore - mistralai issue
),
)
],
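The # type: ignore - mistralai issue comments in these mock rewrites work around the mistralai 1.0 type stubs, which appear to declare name and tool_call_id as non-optional on AssistantMessage even though None is accepted at runtime; without them the stricter pyright v1.1.376 pinned above fails the pre-commit hook.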
@@ -1128,20 +1122,12 @@ def _create_mock_chat_completion(
client = WrappedMistralAIClient(
api_key="invalid_key", model_name="model_name"
)
fake_message = [{"role": "role", "content": "content"}]
with pytest.raises(MistralException):
fake_message = [{"role": "assistant", "content": "content"}]
with pytest.raises(MistralSDKError):
client.connect()
client(fake_message)

-     assert [
-         ChatMessage(
-             role="role",
-             content="content",
-             name=None,
-             tool_calls=None,
-             tool_call_id=None,
-         )
-     ] == client._process_messages(fake_message)
+     assert fake_message == client._process_messages(fake_message)

# The Mistral Client should be able to connect if the API key is set as the environment variable.
os.environ["MISTRAL_API_KEY"] = "dummy_key"
@@ -1151,18 +1137,18 @@ def _create_mock_chat_completion(
client.client = MagicMock()

# The metric computation should fail if the request fails.
-     client.client.chat = _create_bad_request
+     client.client.chat.complete = _create_bad_request
with pytest.raises(ValueError) as e:
client(fake_message)

# The metric computation should fail when the finish reason is bad length.
-     client.client.chat = _create_mock_chat_completion_with_bad_length
+     client.client.chat.complete = _create_mock_chat_completion_with_bad_length
with pytest.raises(ValueError) as e:
client(fake_message)
assert "reached max token limit" in str(e)

# The metric computation should run successfully when the finish reason is stop.
-     client.client.chat = _create_mock_chat_completion
+     client.client.chat.complete = _create_mock_chat_completion
assert client(fake_message) == "some response"


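For orientation, a hypothetical sketch (not the actual valor source) of the call path these mocks exercise — inferred from the assertions above, the wrapper now routes through client.chat.complete(...) instead of calling client.chat(...) as it did against mistralai 0.4:

from mistralai import Mistral


def mistral_chat(client: Mistral, model_name: str, messages: list[dict]) -> str:
    # the tests above monkeypatch client.chat.complete with mock responses
    response = client.chat.complete(model=model_name, messages=messages)
    choice = response.choices[0]
    if choice.finish_reason == "length":
        # mirrors the "reached max token limit" ValueError asserted above
        raise ValueError("LLM response reached max token limit")
    assert choice.message.content is not None
    return choice.message.content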
147 changes: 147 additions & 0 deletions api/tests/functional-tests/backend/metrics/test_detection.py
@@ -1,13 +1,16 @@
+ import numpy as np
import pytest
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from valor_api import crud, enums, schemas
+ from valor_api.backend import core
from valor_api.backend.metrics.detection import (
RankedPair,
_compute_detailed_curves,
_compute_detection_metrics,
_compute_detection_metrics_with_detailed_precision_recall_curve,
+     _convert_annotations_to_common_type,
compute_detection_metrics,
)
from valor_api.backend.models import (
@@ -2276,3 +2279,147 @@ def test_detection_exceptions(db: Session):

# show that no errors raised
compute_detection_metrics(db=db, evaluation_id=evaluation_id)


def test__convert_annotations_to_common_type(db: Session):

dataset_name = "dataset"
model_name = "model"

xmin, xmax, ymin, ymax = 11, 45, 37, 102
h, w = 150, 200
mask = np.zeros((h, w), dtype=bool)
mask[ymin:ymax, xmin:xmax] = True

pts = [
(float(xmin), float(ymin)),
(float(xmin), float(ymax)),
(float(xmax), float(ymax)),
(float(xmax), float(ymin)),
(float(xmin), float(ymin)),
]
poly = schemas.Polygon(value=[pts])
raster = schemas.Raster.from_numpy(mask)
box = schemas.Box.from_extrema(xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax)
datum = schemas.Datum(uid="123")

gt_box = schemas.GroundTruth(
datum=datum,
dataset_name=dataset_name,
annotations=[
schemas.Annotation(
bounding_box=box,
labels=[schemas.Label(key="box", value="value")],
is_instance=True,
)
],
)
gt_polygon = schemas.GroundTruth(
datum=datum,
dataset_name=dataset_name,
annotations=[
schemas.Annotation(
polygon=poly,
labels=[schemas.Label(key="polygon", value="value")],
is_instance=True,
)
],
)
gt_raster = schemas.GroundTruth(
datum=datum,
dataset_name=dataset_name,
annotations=[
schemas.Annotation(
raster=raster,
labels=[schemas.Label(key="raster", value="value")],
is_instance=True,
)
],
)

pd_box = schemas.Prediction(
datum=datum,
dataset_name=dataset_name,
model_name=model_name,
annotations=[
schemas.Annotation(
bounding_box=box,
labels=[schemas.Label(key="box", value="value", score=0.88)],
is_instance=True,
)
],
)
pd_polygon = schemas.Prediction(
datum=datum,
dataset_name=dataset_name,
model_name=model_name,
annotations=[
schemas.Annotation(
polygon=poly,
labels=[
schemas.Label(key="polygon", value="value", score=0.89)
],
is_instance=True,
)
],
)
pd_raster = schemas.Prediction(
datum=datum,
dataset_name=dataset_name,
model_name=model_name,
annotations=[
schemas.Annotation(
raster=raster,
labels=[schemas.Label(key="raster", value="value", score=0.9)],
is_instance=True,
)
],
)

gts = [
(enums.AnnotationType.BOX, gt_box),
(enums.AnnotationType.POLYGON, gt_polygon),
(enums.AnnotationType.RASTER, gt_raster),
]
pds = [
(enums.AnnotationType.BOX, pd_box),
(enums.AnnotationType.POLYGON, pd_polygon),
(enums.AnnotationType.RASTER, pd_raster),
]

for gt_type, gt in gts:
for pd_type, pd in pds:
crud.create_dataset(
db=db, dataset=schemas.Dataset(name=dataset_name)
)
crud.create_groundtruths(db=db, groundtruths=[gt])
crud.finalize(db=db, dataset_name="dataset")
crud.create_model(db=db, model=schemas.Model(name=model_name))
crud.create_predictions(db=db, predictions=[pd])

dataset = core.fetch_dataset(db=db, name=dataset_name)
model = core.fetch_model(db=db, name=model_name)

for target_type in [
enums.AnnotationType.RASTER,
enums.AnnotationType.POLYGON,
enums.AnnotationType.BOX,
]:
if min(gt_type, pd_type) >= target_type:
_convert_annotations_to_common_type(
db=db,
datasets=[dataset],
model=model,
target_type=target_type,
)
else:
with pytest.raises(ValueError):
_convert_annotations_to_common_type(
db=db,
datasets=[dataset],
model=model,
target_type=target_type,
)

crud.delete(db=db, dataset_name=dataset_name)
crud.delete(db=db, model_name=model_name)
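The nested target_type loop above hinges on enums.AnnotationType being an ordered enum with BOX < POLYGON < RASTER: an annotation can be converted to an equal or simpler geometry (RASTER → POLYGON → BOX) but never enriched upward, so min(gt_type, pd_type) >= target_type is exactly the "convertible" condition. A toy illustration of that assumption (not valor code):

from valor_api import enums

# the ordering hypothesis the loop relies on: richer geometry compares greater
assert (
    enums.AnnotationType.BOX
    < enums.AnnotationType.POLYGON
    < enums.AnnotationType.RASTER
)

# e.g. a BOX ground truth paired with a RASTER prediction can only be
# evaluated at the BOX level: min(BOX, RASTER) >= RASTER is False, so
# requesting a RASTER target raises ValueError in the test above
assert not (
    min(enums.AnnotationType.BOX, enums.AnnotationType.RASTER)
    >= enums.AnnotationType.RASTER
)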
8 changes: 2 additions & 6 deletions api/tests/functional-tests/crud/test_create_delete.py
@@ -1064,15 +1064,11 @@ def test_gt_seg_as_mask_or_polys(
assert len(segs.annotations) == 2

assert segs.annotations[0].raster and segs.annotations[1].raster
-     decoded_mask0 = np.array(
-         _bytes_to_pil(b64decode(segs.annotations[0].raster.mask))
-     )
+     decoded_mask0 = segs.annotations[0].raster.array
assert decoded_mask0.shape == mask.shape
np.testing.assert_equal(decoded_mask0, mask)

-     decoded_mask1 = np.array(
-         _bytes_to_pil(b64decode(segs.annotations[1].raster.mask))
-     )
+     decoded_mask1 = segs.annotations[1].raster.array
assert decoded_mask1.shape == mask.shape
np.testing.assert_equal(decoded_mask1, mask)

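The simplification above relies on schemas.Raster exposing the decoded mask directly via .array, replacing the manual b64decode-plus-PIL decoding. A minimal round-trip sketch using only the API already exercised in this diff (from_numpy on write, .array on read):

import numpy as np

from valor_api import schemas

mask = np.zeros((150, 200), dtype=bool)
mask[37:102, 11:45] = True

raster = schemas.Raster.from_numpy(mask)
# .array decodes the stored mask back into the original boolean numpy array
np.testing.assert_equal(raster.array, mask)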