Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
abcdabcd987 committed May 2, 2024
1 parent 3afb7d9 commit 74c4442
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
49 changes: 16 additions & 33 deletions tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from punica.ops import sample_prob, sample_topk, sample_topp

MAX_SAMPLING_ROUNDS = 32
trials = 5000
trials = 10000
device = torch.device("cuda:0")
dtype_str_list = ["float32"]
# dtype_str_list = ["float32", "float16", "bfloat16"]
vocab_list = [100, 5000]
vocab_list = [20, 5000]


@pytest.mark.parametrize("dtype_str", dtype_str_list)
Expand All @@ -31,25 +31,20 @@ def test_sample_prob(dtype_str: str, batch_size: int, beam_size: int, vocab: int
)
samples = sample_prob(probs, uniform_samples)
assert (samples < vocab).all()
cnt[
torch.arange(batch_size)[:, None, None],
torch.arange(beam_size)[None, :, None],
samples[:, None],
] += 1
samples = samples[..., None].to(torch.int64)
cnt.scatter_add_(-1, samples, cnt.new_ones(samples.shape))

torch.testing.assert_close(cnt / trials, probs.float(), rtol=0, atol=0.01)
torch.testing.assert_close(cnt / trials, probs.float(), rtol=0, atol=0.02)


@pytest.mark.parametrize("dtype_str", dtype_str_list)
@pytest.mark.parametrize("batch_size", [1, 13])
@pytest.mark.parametrize("beam_size", [1, 5])
@pytest.mark.parametrize("vocab", vocab_list)
@pytest.mark.parametrize("topk", [5, 100])
def test_sample_topk(
dtype_str: str, batch_size: int, beam_size: int, vocab: int, topk: int
):
def test_sample_topk(dtype_str: str, batch_size: int, beam_size: int, vocab: int):
torch.manual_seed(0xABCDABCD987)
dtype = getattr(torch, dtype_str)
topk = 5

probs = torch.nn.functional.softmax(
torch.randn((batch_size, beam_size, vocab), device=device), dim=-1, dtype=dtype
Expand All @@ -62,33 +57,24 @@ def test_sample_topk(
)
samples = sample_topk(probs, uniform_samples, topk)
assert (samples < vocab).all()
cnt[
torch.arange(batch_size)[:, None, None],
torch.arange(beam_size)[None, :, None],
samples[:, None],
] += 1
samples = samples[..., None].to(torch.int64)
cnt.scatter_add_(-1, samples, cnt.new_ones(samples.shape))

mask_idx = torch.topk(probs, k=vocab - topk, largest=False, dim=-1).indices
probs[
torch.arange(batch_size)[:, None, None],
torch.arange(beam_size)[None, :, None],
mask_idx,
] = 0
probs.scatter_(-1, mask_idx, probs.new_zeros(mask_idx.shape))
probs = probs / probs.sum(dim=-1, keepdim=True)
assert cnt[probs == 0].sum() == 0
torch.testing.assert_close(cnt / trials, probs.float(), rtol=0, atol=0.01)
torch.testing.assert_close(cnt / trials, probs.float(), rtol=0, atol=0.02)


@pytest.mark.parametrize("dtype_str", dtype_str_list)
@pytest.mark.parametrize("batch_size", [1, 13])
@pytest.mark.parametrize("beam_size", [1, 5])
@pytest.mark.parametrize("vocab", vocab_list)
@pytest.mark.parametrize("topp", [0.5, 0.9])
def test_sample_topp(
dtype_str: str, batch_size: int, beam_size: int, vocab: int, topp: float
):
def test_sample_topp(dtype_str: str, batch_size: int, beam_size: int, vocab: int):
torch.manual_seed(0xABCDABCD987)
dtype = getattr(torch, dtype_str)
topp = 0.2

probs = torch.nn.functional.softmax(
torch.randn((batch_size, beam_size, vocab), device=device), dim=-1, dtype=dtype
Expand All @@ -101,11 +87,8 @@ def test_sample_topp(
)
samples = sample_topp(probs, uniform_samples, topp)
assert (samples < vocab).all()
cnt[
torch.arange(batch_size)[:, None, None],
torch.arange(beam_size)[None, :, None],
samples[:, None],
] += 1
samples = samples[..., None].to(torch.int64)
cnt.scatter_add_(-1, samples, cnt.new_ones(samples.shape))

sorted_probs, sorted_indices = torch.sort(probs, descending=False, dim=-1)
cum_probs = torch.cumsum(sorted_probs, dim=-1)
Expand All @@ -116,4 +99,4 @@ def test_sample_topp(
probs[indices_to_remove] = 0
probs = probs / probs.sum(dim=-1, keepdim=True)
assert cnt[probs == 0].sum() == 0
torch.testing.assert_close(cnt / trials, probs.float(), rtol=0, atol=0.01)
torch.testing.assert_close(cnt / trials, probs.float(), rtol=0, atol=0.02)
2 changes: 1 addition & 1 deletion third_party/flashinfer

0 comments on commit 74c4442

Please sign in to comment.