Creation of the PylateModelCard (#67)
* Creation of the PylateModelCard
* Fixing ruff top import
* Removing the example making tests fail
* Changing Sentence Transformer model default to PyLate
* Moving files to a dedicated subfolder
* Removing all the awfully copy-pasted redundant code to extend ST properly
* Adding init for hf_hub
* Changing docstring for automatic parsing documentation
* Consistency in the model args
* Adding save to tests to test saving/model card creation
Showing 6 changed files with 598 additions and 5 deletions.
pylate/hf_hub/__init__.py
@@ -0,0 +1,3 @@
from .model_card import PylateModelCardData

__all__ = ["PylateModelCardData"]
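For orientation, a minimal usage sketch (not part of the commit; the metadata values are placeholders): with the subpackage importable as pylate.hf_hub, the re-export above makes the dataclass available directly.

from pylate.hf_hub import PylateModelCardData

# Placeholder metadata, for illustration only.
card_data = PylateModelCardData(
    language="en",
    license="mit",
    model_name="Example ColBERT fine-tune",
    train_datasets=[{"name": "MS MARCO", "id": "microsoft/ms_marco"}],
)
print(card_data.library_name)  # "PyLate" (set below, init=False)
print(card_data.tags)  # defaults include "ColBERT" and "PyLate"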
pylate/hf_hub/model_card.py
@@ -0,0 +1,259 @@
from __future__ import annotations

import logging
from collections import defaultdict
from dataclasses import dataclass, field, fields
from pathlib import Path
from platform import python_version
from typing import TYPE_CHECKING, Any, Literal

import torch
import transformers
from huggingface_hub import ModelCard
from sentence_transformers import SentenceTransformerModelCardData
from sentence_transformers import __version__ as sentence_transformers_version
from sentence_transformers.util import (
    is_accelerate_available,
    is_datasets_available,
)
from torch import nn
from transformers.integrations import CodeCarbonCallback

from ..__version__ import __version__ as pylate_version

if is_datasets_available():
    pass

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
    from sentence_transformers.SentenceTransformer import SentenceTransformer
    from sentence_transformers.trainer import SentenceTransformerTrainer


IGNORED_FIELDS = ["model", "trainer", "eval_results_dict"]


def get_versions() -> dict[str, Any]:
    versions = {
        "python": python_version(),
        "sentence_transformers": sentence_transformers_version,
        "transformers": transformers.__version__,
        "torch": torch.__version__,
        "pylate": pylate_version,
    }
    if is_accelerate_available():
        from accelerate import __version__ as accelerate_version

        versions["accelerate"] = accelerate_version
    if is_datasets_available():
        from datasets import __version__ as datasets_version

        versions["datasets"] = datasets_version
    from tokenizers import __version__ as tokenizers_version

    versions["tokenizers"] = tokenizers_version
    return versions


@dataclass
class PylateModelCardData(SentenceTransformerModelCardData):
    """
    A dataclass for storing data used in the model card.

    Parameters
    ----------
    language
        The model language, either a string or a list of strings, e.g., "en" or ["en", "de", "nl"].
    license
        The license of the model, e.g., "apache-2.0", "mit", or "cc-by-nc-sa-4.0".
    model_name
        The pretty name of the model, e.g., "SentenceTransformer based on microsoft/mpnet-base".
    model_id
        The model ID for pushing the model to the Hub, e.g., "tomaarsen/sbert-mpnet-base-allnli".
    train_datasets
        A list of dictionaries containing names and/or Hugging Face dataset IDs for training datasets,
        e.g., [{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}, {"name": "STSB"}].
    eval_datasets
        A list of dictionaries containing names and/or Hugging Face dataset IDs for evaluation datasets,
        e.g., [{"name": "SNLI", "id": "stanfordnlp/snli"}, {"id": "mteb/stsbenchmark-sts"}].
    task_name
        The human-readable task the model is trained on, e.g., "semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more".
    tags
        A list of tags for the model, e.g., ["sentence-transformers", "sentence-similarity", "feature-extraction"].
    """

    # Potentially provided by the user
    language: str | list[str] | None = field(default_factory=list)
    license: str | None = None
    model_name: str | None = None
    model_id: str | None = None
    train_datasets: list[dict[str, str]] = field(default_factory=list)
    eval_datasets: list[dict[str, str]] = field(default_factory=list)
    task_name: str = "semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more"
    tags: list[str] | None = field(
        default_factory=lambda: [
            "ColBERT",
            "PyLate",
            "sentence-transformers",
            "sentence-similarity",
            "feature-extraction",
        ]
    )
    generate_widget_examples: Literal["deprecated"] = "deprecated"

    # Automatically filled by `ModelCardCallback` and the Trainer directly
    base_model: str | None = field(default=None, init=False)
    base_model_revision: str | None = field(default=None, init=False)
    non_default_hyperparameters: dict[str, Any] = field(
        default_factory=dict, init=False
    )
    all_hyperparameters: dict[str, Any] = field(default_factory=dict, init=False)
    eval_results_dict: dict[SentenceEvaluator, dict[str, Any]] | None = field(
        default_factory=dict, init=False
    )
    training_logs: list[dict[str, float]] = field(default_factory=list, init=False)
    widget: list[dict[str, str]] = field(default_factory=list, init=False)
    predict_example: str | None = field(default=None, init=False)
    label_example_list: list[dict[str, str]] = field(default_factory=list, init=False)
    code_carbon_callback: CodeCarbonCallback | None = field(default=None, init=False)
    citations: dict[str, str] = field(default_factory=dict, init=False)
    best_model_step: int | None = field(default=None, init=False)
    trainer: SentenceTransformerTrainer | None = field(
        default=None, init=False, repr=False
    )
    datasets: list[str] = field(default_factory=list, init=False, repr=False)

    # Utility fields
    first_save: bool = field(default=True, init=False)
    widget_step: int = field(default=-1, init=False)

    # Computed once, always unchanged
    pipeline_tag: str = field(default="sentence-similarity", init=False)
    library_name: str = field(default="PyLate", init=False)
    version: dict[str, str] = field(default_factory=get_versions, init=False)

    # Passed via `register_model` only
    model: SentenceTransformer | None = field(default=None, init=False, repr=False)

    def set_losses(self, losses: list[nn.Module]) -> None:
        citations = {
            "Sentence Transformers": """
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084"
}""",
            "PyLate": """
@misc{PyLate,
    title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
    author={Chaffin, Antoine and Sourty, Raphaël},
    url={https://github.com/lightonai/pylate},
    year={2024}
}""",
        }
        for loss in losses:
            try:
                citations[loss.__class__.__name__] = loss.citation
            except Exception:
                pass
        inverted_citations = defaultdict(list)
        for loss, citation in citations.items():
            inverted_citations[citation].append(loss)

        def join_list(losses: list[str]) -> str:
            if len(losses) > 1:
                return ", ".join(losses[:-1]) + " and " + losses[-1]
            return losses[0]

        self.citations = {
            join_list(losses): citation
            for citation, losses in inverted_citations.items()
        }
        self.add_tags(
            [
                f"loss:{loss}"
                for loss in {loss.__class__.__name__: loss for loss in losses}
            ]
        )

    def to_dict(self) -> dict[str, Any]:
        # Try to set the base model
        if self.first_save and not self.base_model:
            try:
                self.try_to_set_base_model()
            except Exception:
                pass

        # Set the model name
        if not self.model_name:
            if self.base_model:
                self.model_name = f"PyLate model based on {self.base_model}"
            else:
                self.model_name = "PyLate"

        super_dict = {field.name: getattr(self, field.name) for field in fields(self)}

        # Compute required formats from the (usually post-training) evaluation data
        if self.eval_results_dict:
            try:
                super_dict.update(self.format_eval_metrics())
            except Exception as exc:
                logger.warning(f"Error while formatting evaluation metrics: {exc}")
                raise exc

        # Compute required formats for the during-training evaluation data
        if self.training_logs:
            try:
                super_dict.update(self.format_training_logs())
            except Exception as exc:
                logger.warning(f"Error while formatting training logs: {exc}")

        super_dict["hide_eval_lines"] = len(self.training_logs) > 100

        # Try to add the code carbon callback data
        if (
            self.code_carbon_callback
            and self.code_carbon_callback.tracker
            and self.code_carbon_callback.tracker._start_time is not None
        ):
            super_dict.update(self.get_codecarbon_data())

        # Add some additional metadata stored in the model itself
        super_dict["document_length"] = self.model.document_length
        super_dict["query_length"] = self.model.query_length
        super_dict["output_dimensionality"] = (
            self.model.get_sentence_embedding_dimension()
        )
        super_dict["model_string"] = str(self.model)
        if self.model.similarity_fn_name:
            super_dict["similarity_fn_name"] = {
                "cosine": "Cosine Similarity",
                "dot": "Dot Product",
                "euclidean": "Euclidean Distance",
                "manhattan": "Manhattan Distance",
            }.get(
                self.model.similarity_fn_name,
                self.model.similarity_fn_name.replace("_", " ").title(),
            )
        else:
            super_dict["similarity_fn_name"] = "Cosine Similarity"

        self.first_save = False

        for key in IGNORED_FIELDS:
            super_dict.pop(key, None)
        return super_dict


def generate_model_card(model: SentenceTransformer) -> str:
    template_path = Path(__file__).parent / "model_card_template.md"
    model_card = ModelCard.from_template(
        card_data=model.model_card_data, template_path=template_path, hf_emoji="🐕"
    )
    return model_card.content
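As a hedged end-to-end sketch (assuming, per the commit notes above, that PyLate's ColBERT model class installs a PylateModelCardData as its model_card_data, and that the subpackage lands at pylate.hf_hub; the checkpoint name is a placeholder):

from pylate import models
from pylate.hf_hub.model_card import generate_model_card

# Placeholder checkpoint; any ColBERT-style model should work.
model = models.ColBERT(model_name_or_path="lightonai/colbertv2.0")
markdown = generate_model_card(model)  # renders model_card_template.md with the card data
print(markdown[:300])

Note that generate_model_card resolves the template relative to the module file, so model_card_template.md must ship in the same subfolder as model_card.py.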