implemented basic tuned lens generate function
levmckinney committed Aug 21, 2023
1 parent 9bf1f35 commit e6e396b
Showing 2 changed files with 88 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/test_lenses.py
@@ -129,3 +129,28 @@ def test_from_model_and_pretrained_propogates_kwargs(
TunedLens.from_unembed_and_pretrained(
lens_resource_id="does not use", unembed=unembed, resource_id="bar"
)


def test_tuned_lens_generate_smoke(random_small_model: trf.PreTrainedModel):
tuned_lens = TunedLens.from_model(random_small_model)
bos_token_id = random_small_model.config.bos_token_id
input_ids = th.tensor([bos_token_id])
tokens = tuned_lens.generate(
model=random_small_model,
layer=2,
do_sample=True,
input_ids=input_ids,
max_new_tokens=10,
)
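    # A single BOS-token prompt plus 1-10 new tokens: final length lies in [2, 11].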
assert tokens.shape[-1] <= 11
assert tokens.shape[-1] > 1

tokens = tuned_lens.generate(
model=random_small_model,
layer=2,
input_ids=input_ids,
do_sample=False,
max_new_tokens=10,
)
assert tokens.shape[-1] <= 11
assert tokens.shape[-1] > 1
63 changes: 63 additions & 0 deletions tuned_lens/nn/lenses.py
@@ -316,3 +316,66 @@ def forward(self, h: th.Tensor, idx: int) -> th.Tensor:
def __len__(self) -> int:
"""Return the number of layer translators in the lens."""
return len(self.layer_translators)

@th.inference_mode()
def generate(
self,
model: PreTrainedModel,
layer: int,
input_ids: th.Tensor,
do_sample: bool = True,
temp: float = 1.0,
max_new_tokens: int = 100,
) -> th.Tensor:
"""Generate from the tuned lens at the given layer.
Args:
model: The base model the generate from. Usually the model this lens trained
on.
layer: The layer to generate from.
input_ids: (batch x prompt_len) The input ids to generate from.
do_sample: Whether to use sampling or greedy decoding.
temp: The temperature to use for sampling.
max_new_tokens: The maximum number of tokens to generate.
Returns:
The prompt concatenated with the newly generated tokens.
"""
        eos_token = model.generation_config.eos_token_id

        tokens = input_ids.clone()
        if tokens.ndim == 1:
            tokens = tokens.unsqueeze(0)
        batch, prompt_len = tokens.shape
        del prompt_len
        past_key_values = None
        # Track which sequences have finished, on the same device as the tokens.
        done = th.zeros(batch, dtype=th.bool, device=tokens.device)

        # Feed the full prompt on the first step; on later steps feed only the
        # newly generated token, since earlier positions are already in the KV cache.
        model_inputs = tokens
        for _ in range(max_new_tokens):
            output = model(
                input_ids=model_inputs,
                output_hidden_states=True,
                use_cache=True,
                past_key_values=past_key_values,
            )
past_key_values = output.past_key_values
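            # Decode the newest position's hidden state at `layer` into logits
            # using the lens's translator for that layer.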
hidden = output.hidden_states[layer]
new_hidden = hidden[:, -1, :]
new_logits = self.forward(new_hidden, layer)
if do_sample:
new_logits = new_logits / temp
new_logits = th.nn.functional.log_softmax(new_logits, dim=-1)
new_tokens = th.multinomial(new_logits.exp(), num_samples=1)
else:
new_tokens = new_logits.argmax(dim=-1, keepdim=True)

            # Once a sequence has produced an EOS token, it should emit only EOS
            # from then on, never any other token.
            if eos_token is not None:
                # `new_tokens` is (batch, 1); squeeze so `done` keeps shape (batch,).
                done = done | (new_tokens.squeeze(-1) == eos_token)
                new_tokens = new_tokens.masked_fill(done.unsqueeze(-1), eos_token)
            tokens = th.cat([tokens, new_tokens], dim=-1)
            # Only the new token needs to be fed on the next step.
            model_inputs = new_tokens
            # Halt generation once every sequence has produced an EOS token.
            if done.all():
                break

return tokens
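
For reference, here is a minimal usage sketch of the new method. It is not part of the commit: the model name and prompt are illustrative, and the freshly initialized lens that `TunedLens.from_model` returns (as in the smoke test above) stands in for a trained lens, which would normally be loaded instead.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

from tuned_lens.nn.lenses import TunedLens

# Illustrative model choice; any Hugging Face causal LM that exposes
# output_hidden_states should work.
model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")

# Fresh lens built from the model, as in the smoke test; in practice you
# would load trained translator weights instead.
tuned_lens = TunedLens.from_model(model)

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
tokens = tuned_lens.generate(
    model=model,
    layer=3,  # decode from an intermediate layer's hidden states
    input_ids=input_ids,
    do_sample=True,
    temp=0.8,
    max_new_tokens=20,
)
print(tokenizer.decode(tokens[0]))
```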
