FIX: Address final comments for transformers integration (huggingface#13)

* fix modeling final nits and add proper test file

* for now leave empty tests

* add integration test

* push new test
younesbelkada authored Mar 14, 2024
1 parent c841cc7 commit 966ec9c
Showing 3 changed files with 131 additions and 35 deletions.
19 changes: 6 additions & 13 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -106,23 +106,13 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# For BC we register cos and sin cached
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
t = t / self.scaling_factor
freqs = torch.outer(t, self.inv_freq)
emb = torch.repeat_interleave(freqs, 2, dim=-1)
self.register_buffer("_cos_cached", emb.cos().to(torch.get_default_dtype()), persistent=False)
self.register_buffer("_sin_cached", emb.sin().to(torch.get_default_dtype()), persistent=False)

@torch.no_grad()
def forward(self, x, position_ids, seq_len=None):
if seq_len is not None:
logger.warning_once("The `seq_len` argument is deprecated and unused. It will be removed in v4.39.")

def forward(self, x, position_ids):
# x: [bs, num_attention_heads, seq_len, head_size]
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()

# Force float32 since bfloat16 loses precision on long contexts
# See https://github.com/huggingface/transformers/pull/29285
device_type = x.device.type
@@ -1027,9 +1017,11 @@ def _update_causal_mask(self, attention_mask, input_tensor):
return causal_mask


# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with Llama->Cohere
class CohereForCausalLM(CoherePreTrainedModel):
_tied_weights_keys = ["model.embed_tokens.weight", "lm_head.weight"]
_tied_weights_keys = ["lm_head.weight"]

# Ignore copy
def __init__(self, config):
super().__init__(config)
self.model = CohereModel(config)
@@ -1058,6 +1050,7 @@ def set_decoder(self, decoder):
def get_decoder(self):
return self.model

# Ignore copy
@add_start_docstrings_to_model_forward(COHERE_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
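In short, the hunks above drop CohereRotaryEmbedding's cached _cos_cached/_sin_cached buffers and the deprecated seq_len argument: the new forward(x, position_ids) derives cos and sin directly from position_ids on every call. Below is a minimal, self-contained sketch of that on-the-fly computation, reconstructed from the removed cached path (the new forward body is truncated above, so the exact float32/autocast handling is not reproduced); rotary_cos_sin is a hypothetical helper name, not part of the diff.

import torch

def rotary_cos_sin(inv_freq, position_ids, dtype=torch.float32):
    # inv_freq: [dim // 2], position_ids: [batch, seq_len]
    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    # [batch, dim // 2, seq_len] -> [batch, seq_len, dim // 2]
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
    # Cohere interleaves the frequencies (repeat_interleave) rather than concatenating them as Llama does
    emb = torch.repeat_interleave(freqs, 2, dim=-1)
    return emb.cos().to(dtype), emb.sin().to(dtype)

dim, base = 64, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
cos, sin = rotary_cos_sin(inv_freq, torch.arange(8)[None, :])  # batch of 1, seq_len 8
print(cos.shape)  # torch.Size([1, 8, 64])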
1 change: 1 addition & 0 deletions src/transformers/utils/fx.py
@@ -135,6 +135,7 @@ def _generate_supported_model_class_names(
"hubert",
"layoutlm",
"llama",
"cohere",
"lxmert",
"m2m_100",
"marian",
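Adding "cohere" to _generate_supported_model_class_names registers the architecture with transformers' HFTracer, so Cohere models can be symbolically traced with torch.fx. A minimal sketch of the assumed usage follows; it reuses the tiny hf-internal-testing/cohere-random checkpoint from the tests below purely to keep the model small, and fetching its config requires network access.

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils.fx import symbolic_trace

# Tiny random Cohere checkpoint, used here only so the traced model stays small.
config = AutoConfig.from_pretrained("hf-internal-testing/cohere-random")
model = AutoModelForCausalLM.from_config(config).eval()

# symbolic_trace returns a torch.fx.GraphModule specialized to the given input names.
traced = symbolic_trace(model, input_names=["input_ids", "attention_mask"])

input_ids = torch.tensor([[1, 5, 7, 9]])
out = traced(input_ids=input_ids, attention_mask=torch.ones_like(input_ids))
print(out["logits"].shape)  # assumes the dict-like output used by the common fx tests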
146 changes: 124 additions & 22 deletions tests/models/cohere/test_modeling_cohere.py
@@ -20,21 +20,27 @@

from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
require_bitsandbytes,
require_torch,
require_torch_multi_gpu,
require_torch_sdpa,
slow,
torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ids_tensor
from ...test_modeling_common import ModelTesterMixin, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
import torch

from transformers import CohereForCausalLM, CohereModel
from transformers import AutoTokenizer, CohereForCausalLM, CohereModel


# Copied from transformers.tests.models.llama.LlamaModelTester with Llama->Cohere
class CohereModelTester:
def __init__(
self,
@@ -109,6 +115,7 @@ def prepare_config_and_inputs(self):

return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels

# Ignore copy
def get_config(self):
return CohereConfig(
vocab_size=self.vocab_size,
@@ -124,6 +131,7 @@ def get_config(self):
is_decoder=False,
initializer_range=self.initializer_range,
pad_token_id=self.pad_token_id,
eos_token_id=self.pad_token_id,
)

def create_and_check_model(
@@ -262,7 +270,7 @@ def prepare_config_and_inputs_for_common(self):


@require_torch
class CohereModelTest(unittest.TestCase):
class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CohereModel, CohereForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (CohereForCausalLM,) if is_torch_available() else ()
pipeline_model_mapping = (
@@ -285,6 +293,14 @@ def setUp(self):
self.model_tester = CohereModelTester(self)
self.config_tester = ConfigTester(self, config_class=CohereConfig, hidden_size=37)

def test_config(self):
self.config_tester.run_common_tests()

@unittest.skip("TODO @gante fix this for Cohere")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@@ -295,26 +311,112 @@ def test_model_various_embeddings(self):
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)

@unittest.skip("TODO @gante fix this")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
pass
@require_bitsandbytes
@require_torch_sdpa
@require_torch_multi_gpu
@slow
def test_eager_matches_sdpa_generate(self):
"""
Overwriting the common test as the test is flaky on tiny models
"""
max_new_tokens = 30

model_id = "CohereForAI/c4ai-command-r-v01-4bit"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model_sdpa = CohereForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map="auto"
)
self.assertTrue(model_sdpa.config._attn_implementation == "sdpa")

model_eager = CohereForCausalLM.from_pretrained(
model_id, torch_dtype=torch.float16, attn_implementation="eager", device_map="auto"
)

self.assertTrue(model_eager.config._attn_implementation == "eager")

for name, submodule in model_eager.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
raise ValueError("The eager model should not have SDPA attention layers")

has_sdpa = False
for name, submodule in model_sdpa.named_modules():
if "SdpaAttention" in submodule.__class__.__name__:
has_sdpa = True
break
if not has_sdpa:
raise ValueError("The SDPA model should have SDPA attention layers")

texts = [
"hi here's a longer context, getting longer and",
"Hello this is a very long sentence my friend, very long for real",
"Today I am in Paris and",
]

for padding_side in ["left", "right"]:
tokenizer.padding_side = padding_side
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(texts, return_tensors="pt", padding=True).to(torch_device)

res_eager = model_eager.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
res_sdpa = model_sdpa.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

with self.subTest(f"{padding_side}"):
torch.testing.assert_close(
res_eager,
res_sdpa,
msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
)


@require_torch
@slow
class CohereIntegrationTest(unittest.TestCase):
@unittest.skip("Logits are not exactly the same, once we fix the instabalities somehow, will update!")
@slow
def test_model_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = CohereForCausalLM.from_pretrained("CohereForAI/c4ai-command-r-v01", device_map="auto")
out = model(torch.tensor(input_ids).unsqueeze(0))
# Expected mean on dim = -1
EXPECTED_MEAN = torch.tensor([[0.5077, -2.5771, -1.1590, -2.6220, -1.7837, -2.4421, -1.3293, -2.2028]])
torch.testing.assert_close(out[0].mean(-1).cpu(), EXPECTED_MEAN, atol=1e-2, rtol=1e-2)
# slicing logits[0, 0, 0:30]
EXPECTED_SLICE = torch.tensor([ 1.8525, 5.0039, 2.7734, 3.6270, 0.9390, -0.4587, 3.4062, 0.9468, \
3.7324, 1.2344, 5.3047, 4.7266, 5.9414, 5.5195, 1.8047, 3.5215, \
1.5752, 3.7031, 6.2891, 3.4785, 2.0293, 4.2539, 2.8086, 4.7070, \
3.6953, 4.0391, 3.9766, 3.3066, 2.9395, 3.3105]) # fmt: skip
torch.testing.assert_close(out[0][0, 0, :30].cpu(), EXPECTED_SLICE, atol=1e-5, rtol=1e-5)
@require_torch_multi_gpu
def test_batched_4bit(self):
model_id = "CohereForAI/c4ai-command-r-v01-4bit"

EXPECTED_TEXT = [
'Hello today I am going to show you how to make a simple and easy card using the new stamp set called "Hello" from the Occasions catalog. This set is so versatile and can be used for many occasions. I used the new In',
"Hi there, here we are again with another great collection of free fonts. This time we have gathered 10 free fonts that you can download and use in your designs. These fonts are free for personal and commercial use. So",
]

model = CohereForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

tokenizer.pad_token = tokenizer.eos_token

text = ["Hello today I am going to show you how to", "Hi there, here we are"]
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
self.assertEqual(tokenizer.batch_decode(output, skip_special_tokens=True), EXPECTED_TEXT)

def test_batched_small_model_logits(self):
# Since the model is very large, we created a random cohere model so that we can do a simple
# logits check on it.
model_id = "hf-internal-testing/cohere-random"

EXPECTED_LOGITS = torch.Tensor(
[
[[0.0000, 0.1866, -0.1997], [0.0000, -0.0736, 0.1785], [0.0000, -0.1965, -0.0569]],
[[0.0000, -0.0302, 0.1488], [0.0000, -0.0402, 0.1351], [0.0000, -0.0341, 0.1116]],
]
).to(torch_device)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = CohereForCausalLM.from_pretrained(model_id, low_cpu_mem_usage=True, torch_dtype=torch.float16).to(
torch_device
)

tokenizer.pad_token = tokenizer.eos_token

text = ["Hello today I am going to show you how to", "Hi there, here we are"]
inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)

with torch.no_grad():
output = model(**inputs)

logits = output.logits
self.assertTrue(torch.allclose(EXPECTED_LOGITS, logits[:, :3, :3], rtol=1e-3, atol=1e-3))
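
Note that CohereIntegrationTest and the SDPA-vs-eager comparison above are decorated with @slow, so, per the usual transformers test-suite convention, they only run when the RUN_SLOW=1 environment variable is set (for example RUN_SLOW=1 python -m pytest tests/models/cohere/test_modeling_cohere.py); the require_torch_multi_gpu and require_bitsandbytes decorators additionally gate them on the matching hardware and packages.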
