OPT - fix docstring and improve tests slightly (huggingface#17228)
* correct some stuff

* fix doc tests

* make style
patrickvonplaten authored and elusenji committed Jun 12, 2022
1 parent 773eee9 commit f0424a4
Showing 3 changed files with 67 additions and 58 deletions.
48 changes: 14 additions & 34 deletions src/transformers/models/opt/modeling_opt.py
@@ -37,12 +37,12 @@

 logger = logging.get_logger(__name__)

-_CHECKPOINT_FOR_DOC = ""
+_CHECKPOINT_FOR_DOC = "facebook/opt-350m"
 _CONFIG_FOR_DOC = "OPTConfig"
 _TOKENIZER_FOR_DOC = "GPT2Tokenizer"

 # Base model docstring
-_EXPECTED_OUTPUT_SHAPE = [1, 8, 768]
+_EXPECTED_OUTPUT_SHAPE = [1, 8, 1024]


 OPT_PRETRAINED_MODEL_ARCHIVE_LIST = [
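These constants are consumed by the code-sample decorators that fill the base model docstring, so the documented shape can be checked directly against the checkpoint. A minimal sketch of that check, assuming the facebook/opt-350m weights can be downloaded from the Hub:

```python
# Sketch: compare the base model's output shape to the documented
# _EXPECTED_OUTPUT_SHAPE. Assumes network access to facebook/opt-350m.
from transformers import GPT2Tokenizer, OPTModel

tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
model = OPTModel.from_pretrained("facebook/opt-350m")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs)

# The updated docstring documents [1, 8, 1024]: batch size, tokenized
# sequence length, and the width of the returned hidden states.
print(list(outputs.last_hidden_state.shape))
```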
@@ -424,25 +424,6 @@ def _set_gradient_checkpointing(self, module, value=False):
         module.gradient_checkpointing = value


-OPT_GENERATION_EXAMPLE = r"""
-    Generation example:
-
-    ```python
-    >>> from transformers import AutoTokenizer, AutoModelForCausalLM
-    >>> model = OPTForCausalLM.from_pretrained("ArthurZ/opt-350m")
-    >>> tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-    >>> TEXTS_TO_GENERATE = "Hey, are you consciours? Can you talk to me?" "Hi there, my name is Barack"
-    >>> inputs = tokenizer([TEXTS_TO_GENERATE], max_length=1024, return_tensors="pt")
-    >>> # Generate
-    >>> generate_ids = model.generate(inputs["input_ids"], num_beams=2, min_length=0, max_length=20)
-    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-    'I'm not conscious.<\s>'
-    ```
-"""

 OPT_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -933,19 +914,18 @@ def forward(
         Example:

         ```python
-        >>> from transformers import OPTTokenizer, OPTForCausalLM
-        # this needs fixing
-        >>> tokenizer = OPTTokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-        >>> model = OPTForCausalLM.from_pretrained("ArthurZ/opt-350m")
-        >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
-        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
-        >>> outputs = model(**inputs)
-        >>> logits = outputs.logits
-        >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
-        >>> list(logits.shape) == expected_shape
-        True
+        >>> from transformers import GPT2Tokenizer, OPTForCausalLM
+        >>> model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
+        >>> tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")
+        >>> prompt = "Hey, are you consciours? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
         ```"""

         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
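Outside the doctest runner, the new example also works as a plain script. A minimal sketch, assuming the facebook/opt-350m checkpoint is reachable; the prompt is kept verbatim from the docstring:

```python
from transformers import GPT2Tokenizer, OPTForCausalLM

model = OPTForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

# Prompt kept verbatim from the docstring (including the "consciours" typo).
prompt = "Hey, are you consciours? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Greedy decoding; max_length caps prompt plus continuation at 30 tokens.
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
```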
76 changes: 52 additions & 24 deletions tests/models/opt/test_modeling_opt.py
@@ -21,7 +21,7 @@

 import timeout_decorator  # noqa

-from transformers import OPTConfig, is_torch_available, pipeline
+from transformers import OPTConfig, is_torch_available
 from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
 from transformers.utils import cached_property
@@ -330,33 +330,61 @@ def test_logits(self):
         assert torch.allclose(logits, logits_meta, atol=1e-4)


 @require_tokenizers
 @slow
 class OPTGenerationTest(unittest.TestCase):
-    def setUp(self):
-        super().setUp()
-        self.all_model_path = ["facebook/opt-125m", "facebook/opt-350m"]
-
-    def test_generation(self):
-        prompts = [
+    @property
+    def prompts(self):
+        return [
             "Today is a beautiful day and I want to",
             "In the city of",
             "Paris is the capital of France and",
             "Computers and mobile phones have taken",
         ]
-        NEXT_TOKENS = [3392, 764, 5, 81]
-        GEN_OUTPUT = []
-
-        tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
-        for model in self.all_model_path:
-            model = OPTForCausalLM.from_pretrained(self.path_model)
-            model = model.eval()
-            model.config.eos_token_id = tokenizer.eos_token_id
-
-            gen = pipeline("text-generation", model=model, tokenizer=tokenizer, return_tensors=True)
-
-            for prompt in prompts:
-                len_input_sentence = len(tokenizer.tokenize(prompt))
-                predicted_next_token = gen(prompt)[0]["generated_token_ids"][len_input_sentence]
-                GEN_OUTPUT.append(predicted_next_token)
-        self.assertListEqual(GEN_OUTPUT, NEXT_TOKENS)
+
+    def test_generation_pre_attn_layer_norm(self):
+        model_id = "facebook/opt-125m"
+
+        EXPECTED_OUTPUTS = [
+            "Today is a beautiful day and I want to thank",
+            "In the city of Rome Canaver Canaver Canaver Canaver",
+            "Paris is the capital of France and Parisdylib",
+            "Computers and mobile phones have taken precedence over",
+        ]
+
+        predicted_outputs = []
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = OPTForCausalLM.from_pretrained(model_id)
+
+        for prompt in self.prompts:
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+            generated_ids = model.generate(input_ids, max_length=10)
+
+            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            predicted_outputs += generated_string
+
+        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
+
+    def test_generation_post_attn_layer_norm(self):
+        model_id = "facebook/opt-350m"
+
+        EXPECTED_OUTPUTS = [
+            "Today is a beautiful day and I want to share",
+            "In the city of San Francisco, the city",
+            "Paris is the capital of France and the capital",
+            "Computers and mobile phones have taken over the",
+        ]
+
+        predicted_outputs = []
+        tokenizer = GPT2Tokenizer.from_pretrained(model_id)
+        model = OPTForCausalLM.from_pretrained(model_id)
+
+        for prompt in self.prompts:
+            input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+
+            generated_ids = model.generate(input_ids, max_length=10)
+
+            generated_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+            predicted_outputs += generated_string
+
+        self.assertListEqual(predicted_outputs, EXPECTED_OUTPUTS)
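Both new tests share one greedy-decoding loop over the `prompts` property; stripped of the unittest harness it reduces to a few lines. A sketch for the pre-layer-norm 125m variant (slow, since it downloads weights; swap in facebook/opt-350m for the post-layer-norm case):

```python
from transformers import GPT2Tokenizer, OPTForCausalLM

model_id = "facebook/opt-125m"
tokenizer = GPT2Tokenizer.from_pretrained(model_id)
model = OPTForCausalLM.from_pretrained(model_id)

prompts = [
    "Today is a beautiful day and I want to",
    "In the city of",
    "Paris is the capital of France and",
    "Computers and mobile phones have taken",
]

for prompt in prompts:
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    # max_length=10 includes the prompt tokens, so only a few new
    # tokens are generated per prompt.
    generated_ids = model.generate(input_ids, max_length=10)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])
```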
1 change: 1 addition & 0 deletions utils/documentation_tests.txt
@@ -35,6 +35,7 @@ src/transformers/models/marian/modeling_marian.py
 src/transformers/models/mbart/modeling_mbart.py
 src/transformers/models/mobilebert/modeling_mobilebert.py
 src/transformers/models/mobilebert/modeling_tf_mobilebert.py
+src/transformers/models/opt/modeling_opt.py
 src/transformers/models/pegasus/modeling_pegasus.py
 src/transformers/models/plbart/modeling_plbart.py
 src/transformers/models/poolformer/modeling_poolformer.py
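Listing the file here opts its docstrings into the repository's doc-test run. A rough local equivalent for just this module (a sketch; the CI invocation may add further flags):

```python
# Sketch: run the module's doctests locally via pytest's doctest collection.
import pytest

pytest.main(["--doctest-modules", "src/transformers/models/opt/modeling_opt.py", "-sv"])
```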