Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Wovchena committed Jul 12, 2024
1 parent 5c615bf commit 3afc16d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/python_tests/test_generate_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pathlib import Path
import torch
import functools
import math
from ov_genai_test_utils import (
get_models_list,
read_model,
Expand Down Expand Up @@ -680,7 +681,7 @@ def get_continuous_batching(path):


@pytest.mark.parametrize("generation_config", test_configs)
@pytest.mark.parametrize("prompt", prompts)
@pytest.mark.parametrize("prompt", batched_prompts)
@pytest.mark.precommit
def test_continuous_batching_vs_stateful(prompt, generation_config):
model_id, path, tokenizer, model, stateful = read_model((
Expand All @@ -691,5 +692,9 @@ def test_continuous_batching_vs_stateful(prompt, generation_config):
config.max_new_tokens = 100
cb = get_continuous_batching(path)
generated = cb.generate(prompt, **generation_config)
ref = stateful.generate(prompt, **generation_config)
assert generated == ref
reference = stateful.generate(prompt, **generation_config)
assert generated.texts == reference.texts
if 1 != generation_config.get("num_beams", 1):
# Stateful puts zeroes to generated.scores. Don't compare them.
for gen, ref in zip(generated.scores, reference.scores):
assert math.isclose(gen, ref, abs_tol=0.0001)

0 comments on commit 3afc16d

Please sign in to comment.