Skip to content

Commit

Permalink
Addressed reviewers' comments
Browse files: browse the repository at this point in the history
  • Loading branch information
levmckinney committed Aug 30, 2023
1 parent e6e396b commit 99076b6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
2 changes: 2 additions & 0 deletions tests/test_lenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def test_tuned_lens_generate_smoke(random_small_model: trf.PreTrainedModel):
)
assert tokens.shape[-1] <= 11
assert tokens.shape[-1] > 1
assert input_ids == tokens[:, :1]
assert input_ids == th.tensor([bos_token_id]), "Don't mutate input_ids!"

tokens = tuned_lens.generate(
model=random_small_model,
Expand Down
6 changes: 3 additions & 3 deletions tuned_lens/nn/lenses.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,7 @@ def generate(
"""
eos_token = model.generation_config.eos_token_id

tokens = input_ids.clone()
tokens = input_ids
if tokens.ndim == 1:
tokens = tokens.unsqueeze(0)
batch, prompt_len = tokens.shape
Expand All @@ -364,8 +364,8 @@ def generate(
new_logits = self.forward(new_hidden, layer)
if do_sample:
new_logits = new_logits / temp
new_logits = th.nn.functional.log_softmax(new_logits, dim=-1)
new_tokens = th.multinomial(new_logits.exp(), num_samples=1)
probs = new_logits.softmax(dim=-1)
new_tokens = th.multinomial(probs, num_samples=1)
else:
new_tokens = new_logits.argmax(dim=-1, keepdim=True)

Expand Down

0 comments on commit 99076b6

Please sign in to comment.