Skip to content

Commit

Permalink
Example doc for token classification of Llama and Dependent/Copied Mo…
Browse files Browse the repository at this point in the history
…dels (#34139)

* Added Example Doc for token classification on all tokenClassificationModels copied from llama

* Refactor code to add code sample docstrings for Gemma and Gemma2 models (including modular Gemma)

* Refactor code to update model checkpoint names for Qwen2 models
  • Loading branch information
h3110Fr13nd authored Oct 22, 2024
1 parent 644d528 commit 049682a
Show file tree
Hide file tree
Showing 13 changed files with 90 additions and 3 deletions.
9 changes: 9 additions & 0 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
Expand All @@ -48,6 +49,9 @@
from .configuration_gemma import GemmaConfig


_CHECKPOINT_FOR_DOC = "google/gemma-7b"


class GemmaRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
Expand Down Expand Up @@ -1233,6 +1237,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(GEMMA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/gemma/modular_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@

SPIECE_UNDERLINE = "▁"

_CHECKPOINT_FOR_DOC = "google/gemma-7b"

logger = logging.get_logger(__name__)

Expand Down
9 changes: 9 additions & 0 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal,
Expand All @@ -47,6 +48,9 @@
from .configuration_gemma2 import Gemma2Config


_CHECKPOINT_FOR_DOC = "google/gemma2-7b"


class Gemma2RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
Expand Down Expand Up @@ -1292,6 +1296,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(GEMMA2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/gemma2/modular_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
from ...modeling_flash_attention_utils import _flash_attention_forward


_CHECKPOINT_FOR_DOC = "google/gemma2-7b"

logger = logging.get_logger(__name__)


Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
Expand All @@ -52,6 +53,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "meta-llama/Llama-2-7b-hf"
_CONFIG_FOR_DOC = "LlamaConfig"


Expand Down Expand Up @@ -1446,6 +1448,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
Expand All @@ -55,6 +56,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
_CONFIG_FOR_DOC = "MistralConfig"


Expand Down Expand Up @@ -1242,6 +1244,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
Expand All @@ -65,6 +66,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "mistralai/Mixtral-8x7B-v0.1"
_CONFIG_FOR_DOC = "MixtralConfig"


Expand Down Expand Up @@ -1468,6 +1470,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/nemotron/modeling_nemotron.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_greater_or_equal_2_10,
Expand All @@ -50,6 +51,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "nvidia/nemotron-3-8b-base-4k-hf"
_CONFIG_FOR_DOC = "NemotronConfig"


Expand Down Expand Up @@ -1323,6 +1325,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(NEMOTRON_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
14 changes: 13 additions & 1 deletion src/transformers/models/persimmon/modeling_persimmon.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,19 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
)
from .configuration_persimmon import PersimmonConfig


logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "adept/persimmon-8b-base"
_CONFIG_FOR_DOC = "PersimmonConfig"


Expand Down Expand Up @@ -1120,6 +1127,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(PERSIMMON_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/models/qwen2/modeling_qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
Expand All @@ -58,7 +59,7 @@
logger = logging.get_logger(__name__)


_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B-beta"
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-7B"
_CONFIG_FOR_DOC = "Qwen2Config"


Expand Down Expand Up @@ -1348,6 +1349,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
8 changes: 7 additions & 1 deletion src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
Expand All @@ -56,7 +57,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B"
_CHECKPOINT_FOR_DOC = "Qwen/Qwen2-57B-A14B"
_CONFIG_FOR_DOC = "Qwen2MoeConfig"


Expand Down Expand Up @@ -1533,6 +1534,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/stablelm/modeling_stablelm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
Expand All @@ -56,6 +57,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "stabilityai/stablelm-3b-4e1t"
_CONFIG_FOR_DOC = "StableLmConfig"


Expand Down Expand Up @@ -1396,6 +1398,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(STABLELM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down
7 changes: 7 additions & 0 deletions src/transformers/models/starcoder2/modeling_starcoder2.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
Expand All @@ -56,6 +57,7 @@

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "bigcode/starcoder2-7b"
_CONFIG_FOR_DOC = "Starcoder2Config"


Expand Down Expand Up @@ -1316,6 +1318,11 @@ def set_input_embeddings(self, value):
self.model.embed_tokens = value

@add_start_docstrings_to_model_forward(STARCODER2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TokenClassifierOutput,
config_class=_CONFIG_FOR_DOC,
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down

0 comments on commit 049682a

Please sign in to comment.