From 291db8066f437f3fab9e9baf96a78f3b525e4e78 Mon Sep 17 00:00:00 2001
From: Antoine Chaffin <38869395+NohTow@users.noreply.github.com>
Date: Mon, 25 Nov 2024 16:16:44 +0100
Subject: [PATCH] Creation of the PylateModelCard (#67)
* Creation of the PylateModelCard
* Fixing ruff top import
* Removing the example making tests fail
* Changing Sentence Transformer model default to PyLate
* Moving files to a dedicated subfolder
* Removing all the awfully copy-pasted redundant code to extend ST properly
* Adding init for hf_hub
* Changing docstring for automatic parsing documentation
* Consistency in the model args
* Adding save to tests to test saving/model card creation
---
pylate/hf_hub/__init__.py | 3 +
pylate/hf_hub/model_card.py | 259 +++++++++++++++++++++
pylate/hf_hub/model_card_template.md | 329 +++++++++++++++++++++++++++
pylate/models/colbert.py | 8 +-
tests/test_contrastive.py | 2 +
tests/test_kd.py | 2 +
6 files changed, 598 insertions(+), 5 deletions(-)
create mode 100644 pylate/hf_hub/__init__.py
create mode 100644 pylate/hf_hub/model_card.py
create mode 100644 pylate/hf_hub/model_card_template.md
diff --git a/pylate/hf_hub/__init__.py b/pylate/hf_hub/__init__.py
new file mode 100644
index 0000000..dae0b6a
--- /dev/null
+++ b/pylate/hf_hub/__init__.py
@@ -0,0 +1,3 @@
+from .model_card import PylateModelCardData
+
+__all__ = ["PylateModelCardData"]
diff --git a/pylate/hf_hub/model_card.py b/pylate/hf_hub/model_card.py
new file mode 100644
index 0000000..80cb330
--- /dev/null
+++ b/pylate/hf_hub/model_card.py
@@ -0,0 +1,259 @@
+from __future__ import annotations
+
+import logging
+from collections import defaultdict
+from dataclasses import dataclass, field, fields
+from pathlib import Path
+from platform import python_version
+from typing import TYPE_CHECKING, Any, Literal
+
+import torch
+import transformers
+from huggingface_hub import ModelCard
+from sentence_transformers import SentenceTransformerModelCardData
+from sentence_transformers import __version__ as sentence_transformers_version
+from sentence_transformers.util import (
+ is_accelerate_available,
+ is_datasets_available,
+)
+from torch import nn
+from transformers.integrations import CodeCarbonCallback
+
+from ..__version__ import __version__ as pylate_version
+
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+ from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
+ from sentence_transformers.SentenceTransformer import SentenceTransformer
+ from sentence_transformers.trainer import SentenceTransformerTrainer
+
+
+IGNORED_FIELDS = ["model", "trainer", "eval_results_dict"]
+
+
+def get_versions() -> dict[str, Any]:
+ versions = {
+ "python": python_version(),
+ "sentence_transformers": sentence_transformers_version,
+ "transformers": transformers.__version__,
+ "torch": torch.__version__,
+ "pylate": pylate_version,
+ }
+ if is_accelerate_available():
+ from accelerate import __version__ as accelerate_version
+
+ versions["accelerate"] = accelerate_version
+ if is_datasets_available():
+ from datasets import __version__ as datasets_version
+
+ versions["datasets"] = datasets_version
+ from tokenizers import __version__ as tokenizers_version
+
+ versions["tokenizers"] = tokenizers_version
+ return versions
+
+
+@dataclass
+class PylateModelCardData(SentenceTransformerModelCardData):
+ """
+ A dataclass for storing data used in the model card.
+
+ Parameters
+ ----------
+ language
+ The model language, either a string or a list of strings, e.g., "en" or ["en", "de", "nl"].
+ license
+ The license of the model, e.g., "apache-2.0", "mit", or "cc-by-nc-sa-4.0".
+ model_name
+        The pretty name of the model, e.g., "PyLate model based on microsoft/mpnet-base".
+ model_id
+ The model ID for pushing the model to the Hub, e.g., "tomaarsen/sbert-mpnet-base-allnli".
+ train_datasets
+ A list of dictionaries containing names and/or Hugging Face dataset IDs for training datasets,
+ e.g., [{"name": "SNLI", "id": "stanfordnlp/snli"}, {"name": "MultiNLI", "id": "nyu-mll/multi_nli"}, {"name": "STSB"}].
+ eval_datasets
+ A list of dictionaries containing names and/or Hugging Face dataset IDs for evaluation datasets,
+ e.g., [{"name": "SNLI", "id": "stanfordnlp/snli"}, {"id": "mteb/stsbenchmark-sts"}].
+ task_name
+ The human-readable task the model is trained on, e.g., "semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more".
+ tags
+        A list of tags for the model, e.g., ["ColBERT", "PyLate", "sentence-transformers", "sentence-similarity", "feature-extraction"].
+ """
+
+ # Potentially provided by the user
+ language: str | list[str] | None = field(default_factory=list)
+ license: str | None = None
+ model_name: str | None = None
+ model_id: str | None = None
+ train_datasets: list[dict[str, str]] = field(default_factory=list)
+ eval_datasets: list[dict[str, str]] = field(default_factory=list)
+ task_name: str = "semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more"
+ tags: list[str] | None = field(
+ default_factory=lambda: [
+ "ColBERT",
+ "PyLate",
+ "sentence-transformers",
+ "sentence-similarity",
+ "feature-extraction",
+ ]
+ )
+ generate_widget_examples: Literal["deprecated"] = "deprecated"
+
+ # Automatically filled by `ModelCardCallback` and the Trainer directly
+ base_model: str | None = field(default=None, init=False)
+ base_model_revision: str | None = field(default=None, init=False)
+ non_default_hyperparameters: dict[str, Any] = field(
+ default_factory=dict, init=False
+ )
+ all_hyperparameters: dict[str, Any] = field(default_factory=dict, init=False)
+ eval_results_dict: dict[SentenceEvaluator, dict[str, Any]] | None = field(
+ default_factory=dict, init=False
+ )
+ training_logs: list[dict[str, float]] = field(default_factory=list, init=False)
+ widget: list[dict[str, str]] = field(default_factory=list, init=False)
+ predict_example: str | None = field(default=None, init=False)
+ label_example_list: list[dict[str, str]] = field(default_factory=list, init=False)
+ code_carbon_callback: CodeCarbonCallback | None = field(default=None, init=False)
+ citations: dict[str, str] = field(default_factory=dict, init=False)
+ best_model_step: int | None = field(default=None, init=False)
+ trainer: SentenceTransformerTrainer | None = field(
+ default=None, init=False, repr=False
+ )
+ datasets: list[str] = field(default_factory=list, init=False, repr=False)
+
+ # Utility fields
+ first_save: bool = field(default=True, init=False)
+ widget_step: int = field(default=-1, init=False)
+
+ # Computed once, always unchanged
+ pipeline_tag: str = field(default="sentence-similarity", init=False)
+ library_name: str = field(default="PyLate", init=False)
+ version: dict[str, str] = field(default_factory=get_versions, init=False)
+
+ # Passed via `register_model` only
+ model: SentenceTransformer | None = field(default=None, init=False, repr=False)
+
+ def set_losses(self, losses: list[nn.Module]) -> None:
+ citations = {
+ "Sentence Transformers": """
+@inproceedings{reimers-2019-sentence-bert,
+ title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
+ author = "Reimers, Nils and Gurevych, Iryna",
+ booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
+ month = "11",
+ year = "2019",
+ publisher = "Association for Computational Linguistics",
+ url = "https://arxiv.org/abs/1908.10084"
+}""",
+ "PyLate": """
+@misc{PyLate,
+title={PyLate: Flexible Training and Retrieval for Late Interaction Models},
+author={Chaffin, Antoine and Sourty, Raphaël},
+url={https://github.com/lightonai/pylate},
+year={2024}
+}""",
+ }
+ for loss in losses:
+ try:
+ citations[loss.__class__.__name__] = loss.citation
+            except Exception:
+                # Not every loss defines a `citation` attribute; skip those that don't.
+                pass
+ inverted_citations = defaultdict(list)
+ for loss, citation in citations.items():
+ inverted_citations[citation].append(loss)
+
+ def join_list(losses: list[str]) -> str:
+ if len(losses) > 1:
+ return ", ".join(losses[:-1]) + " and " + losses[-1]
+ return losses[0]
+
+ self.citations = {
+ join_list(losses): citation
+ for citation, losses in inverted_citations.items()
+ }
+        # The dict comprehension deduplicates losses by their class name.
+        self.add_tags(
+ [
+ f"loss:{loss}"
+ for loss in {loss.__class__.__name__: loss for loss in losses}
+ ]
+ )
+
+ def to_dict(self) -> dict[str, Any]:
+ # Try to set the base model
+ if self.first_save and not self.base_model:
+ try:
+ self.try_to_set_base_model()
+ except Exception:
+ pass
+
+ # Set the model name
+ if not self.model_name:
+ if self.base_model:
+ self.model_name = f"PyLate model based on {self.base_model}"
+ else:
+ self.model_name = "PyLate"
+
+ super_dict = {field.name: getattr(self, field.name) for field in fields(self)}
+
+ # Compute required formats from the (usually post-training) evaluation data
+ if self.eval_results_dict:
+ try:
+ super_dict.update(self.format_eval_metrics())
+ except Exception as exc:
+ logger.warning(f"Error while formatting evaluation metrics: {exc}")
+ raise exc
+
+ # Compute required formats for the during-training evaluation data
+ if self.training_logs:
+ try:
+ super_dict.update(self.format_training_logs())
+ except Exception as exc:
+ logger.warning(f"Error while formatting training logs: {exc}")
+
+ super_dict["hide_eval_lines"] = len(self.training_logs) > 100
+
+ # Try to add the code carbon callback data
+ if (
+ self.code_carbon_callback
+ and self.code_carbon_callback.tracker
+ and self.code_carbon_callback.tracker._start_time is not None
+ ):
+ super_dict.update(self.get_codecarbon_data())
+
+ # Add some additional metadata stored in the model itself
+ super_dict["document_length"] = self.model.document_length
+ super_dict["query_length"] = self.model.query_length
+ super_dict["output_dimensionality"] = (
+ self.model.get_sentence_embedding_dimension()
+ )
+ super_dict["model_string"] = str(self.model)
+ if self.model.similarity_fn_name:
+ super_dict["similarity_fn_name"] = {
+ "cosine": "Cosine Similarity",
+ "dot": "Dot Product",
+ "euclidean": "Euclidean Distance",
+ "manhattan": "Manhattan Distance",
+ }.get(
+ self.model.similarity_fn_name,
+ self.model.similarity_fn_name.replace("_", " ").title(),
+ )
+ else:
+ super_dict["similarity_fn_name"] = "Cosine Similarity"
+
+ self.first_save = False
+
+ for key in IGNORED_FIELDS:
+ super_dict.pop(key, None)
+ return super_dict
+
+
+def generate_model_card(model: SentenceTransformer) -> str:
+    """Render the PyLate model card template with the model's card data."""
+    template_path = Path(__file__).parent / "model_card_template.md"
+ model_card = ModelCard.from_template(
+ card_data=model.model_card_data, template_path=template_path, hf_emoji="🐕"
+ )
+ return model_card.content
diff --git a/pylate/hf_hub/model_card_template.md b/pylate/hf_hub/model_card_template.md
new file mode 100644
index 0000000..0eeefcb
--- /dev/null
+++ b/pylate/hf_hub/model_card_template.md
@@ -0,0 +1,329 @@
+---
+# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1
+# Doc / guide: https://huggingface.co/docs/hub/model-cards
+{{ card_data }}
+---
+
+# {{ model_name if model_name else "PyLate model" }}
+
+This is a [PyLate](https://github.com/lightonai/pylate) model{% if base_model %} finetuned from [{{ base_model }}](https://huggingface.co/{{ base_model }}){% else %} trained{% endif %}{% if train_datasets | selectattr("name") | list %} on the {% for dataset in (train_datasets | selectattr("name")) %}{% if dataset.id %}[{{ dataset.name if dataset.name else dataset.id }}](https://huggingface.co/datasets/{{ dataset.id }}){% else %}{{ dataset.name }}{% endif %}{% if not loop.last %}{% if loop.index == (train_datasets | selectattr("name") | list | length - 1) %} and {% else %}, {% endif %}{% endif %}{% endfor %} dataset{{"s" if train_datasets | selectattr("name") | list | length > 1 else ""}}{% endif %}. It maps sentences & paragraphs to sequences of {{ output_dimensionality }}-dimensional dense vectors and can be used for semantic textual similarity using the MaxSim operator.
+
+## Model Details
+
+### Model Description
+- **Model Type:** PyLate model
+{% if base_model -%}
+ {%- if base_model_revision -%}
+    - **Base model:** [{{ base_model }}](https://huggingface.co/{{ base_model }}) <!-- at revision {{ base_model_revision }} -->
+ {%- else -%}
+ - **Base model:** [{{ base_model }}](https://huggingface.co/{{ base_model }})
+ {%- endif -%}
+{%- else -%}
+
+{%- endif %}
+- **Document Length:** {{ document_length }} tokens
+- **Query Length:** {{ query_length }} tokens
+- **Output Dimensionality:** {{ output_dimensionality }} dimensions
+- **Similarity Function:** MaxSim
+{% if train_datasets | selectattr("name") | list -%}
+ - **Training Dataset{{"s" if train_datasets | selectattr("name") | list | length > 1 else ""}}:**
+ {%- for dataset in (train_datasets | selectattr("name")) %}
+ {%- if dataset.id %}
+ - [{{ dataset.name if dataset.name else dataset.id }}](https://huggingface.co/datasets/{{ dataset.id }})
+ {%- else %}
+ - {{ dataset.name }}
+ {%- endif %}
+ {%- endfor %}
+{%- else -%}
+
+{%- endif %}
+{% if language -%}
+ - **Language{{"s" if language is not string and language | length > 1 else ""}}:**
+ {%- if language is string %} {{ language }}
+ {%- else %} {% for lang in language -%}
+ {{ lang }}{{ ", " if not loop.last else "" }}
+ {%- endfor %}
+ {%- endif %}
+{%- else -%}
+
+{%- endif %}
+{% if license -%}
+ - **License:** {{ license }}
+{%- else -%}
+
+{%- endif %}
+
+### Model Sources
+
+- **Documentation:** [PyLate Documentation](https://lightonai.github.io/pylate/)
+- **Repository:** [PyLate on GitHub](https://github.com/lightonai/pylate)
+- **Hugging Face:** [PyLate models on Hugging Face](https://huggingface.co/models?library=PyLate)
+
+### Full Model Architecture
+
+```
+{{ model_string }}
+```
+
+## Usage
+First install the PyLate library:
+
+```bash
+pip install -U pylate
+```
+
+### Retrieval
+
+PyLate provides a streamlined interface to index and retrieve documents using ColBERT models. Indexing relies on the Voyager HNSW index to efficiently store the document embeddings and enable fast retrieval.
+
+#### Indexing documents
+
+First, load the ColBERT model and initialize the Voyager index, then encode and index your documents:
+
+```python
+from pylate import indexes, models, retrieve
+
+# Step 1: Load the ColBERT model
+model = models.ColBERT(
+    model_name_or_path="{{ model_id | default('pylate_model_id', true) }}",
+)
+
+# Step 2: Initialize the Voyager index
+index = indexes.Voyager(
+ index_folder="pylate-index",
+ index_name="index",
+ override=True, # This overwrites the existing index if any
+)
+
+# Step 3: Encode the documents
+documents_ids = ["1", "2", "3"]
+documents = ["document 1 text", "document 2 text", "document 3 text"]
+
+documents_embeddings = model.encode(
+ documents,
+ batch_size=32,
+ is_query=False, # Ensure that it is set to False to indicate that these are documents, not queries
+ show_progress_bar=True,
+)
+
+# Step 4: Add document embeddings to the index by providing embeddings and corresponding ids
+index.add_documents(
+ documents_ids=documents_ids,
+ documents_embeddings=documents_embeddings,
+)
+```
+
+Note that you do not have to recreate the index and encode the documents every time. Once you have created an index and added the documents, you can re-use the index later by loading it:
+
+```python
+# To load an index, simply instantiate it with the correct folder/name and without overriding it
+index = indexes.Voyager(
+ index_folder="pylate-index",
+ index_name="index",
+)
+```
+
+#### Retrieving top-k documents for queries
+
+Once the documents are indexed, you can retrieve the top-k most relevant documents for a given set of queries.
+To do so, initialize the ColBERT retriever with the index you want to search in, encode the queries, and then retrieve the top-k documents to get the ids and relevance scores of the best matches:
+
+```python
+# Step 1: Initialize the ColBERT retriever
+retriever = retrieve.ColBERT(index=index)
+
+# Step 2: Encode the queries
+queries_embeddings = model.encode(
+ ["query for document 3", "query for document 1"],
+ batch_size=32,
+    is_query=True,  # Ensure that it is set to True to indicate that these are queries
+ show_progress_bar=True,
+)
+
+# Step 3: Retrieve top-k documents
+scores = retriever.retrieve(
+ queries_embeddings=queries_embeddings,
+ k=10, # Retrieve the top 10 matches for each query
+)
+```
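+
+Each entry in `scores` corresponds to one query. As a minimal sketch of consuming the output (assuming each match is a dictionary exposing an id and a relevance score, as in the PyLate documentation; verify against your installed version):
+
+```python
+# Sketch: iterate over the assumed per-query lists of matches and print
+# each retrieved document id together with its relevance score.
+for query_index, query_matches in enumerate(scores):
+    for match in query_matches:
+        print(query_index, match["id"], match["score"])
+```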
+
+### Reranking
+If you only want to use the ColBERT model to perform reranking on top of your first-stage retrieval pipeline without building an index, you can simply use the `rank.rerank` function and pass the queries and documents to rerank:
+
+```python
+from pylate import rank, models
+
+queries = [
+ "query A",
+ "query B",
+]
+
+documents = [
+ ["document A", "document B"],
+ ["document 1", "document C", "document B"],
+]
+
+documents_ids = [
+ [1, 2],
+ [1, 3, 2],
+]
+
+model = models.ColBERT(
+    model_name_or_path="{{ model_id | default('pylate_model_id', true) }}",
+)
+
+queries_embeddings = model.encode(
+ queries,
+ is_query=True,
+)
+
+documents_embeddings = model.encode(
+ documents,
+ is_query=False,
+)
+
+reranked_documents = rank.rerank(
+ documents_ids=documents_ids,
+ queries_embeddings=queries_embeddings,
+ documents_embeddings=documents_embeddings,
+)
+```
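+
+As a minimal sketch of reading the reranked output (assuming `rerank` returns one list per query, ordered by decreasing relevance, with each entry carrying the document id and its score; verify against your installed version):
+
+```python
+# Sketch: print the assumed best match (first entry) for each query.
+for ranked in reranked_documents:
+    best = ranked[0]
+    print(best["id"], best["score"])
+```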
+
+
+
+
+
+
+{% if eval_metrics %}
+## Evaluation
+
+### Metrics
+{% for metrics in eval_metrics %}
+#### {{ metrics.description }}
+{% if metrics.dataset_name %}* Dataset: `{{ metrics.dataset_name }}`{% endif %}
+* Evaluated with {% if metrics.class_name.startswith("sentence_transformers.") %}[{{ metrics.class_name.split(".")[-1] }}](https://sbert.net/docs/package_reference/sentence_transformer/evaluation.html#sentence_transformers.evaluation.{{ metrics.class_name.split(".")[-1] }}){% else %}{{ metrics.class_name }}{% endif %}
+
+{{ metrics.table }}
+{%- endfor %}{% endif %}
+
+
+
+
+## Training Details
+{% for dataset_type, dataset_list in [("training", train_datasets), ("evaluation", eval_datasets)] %}{% if dataset_list %}
+### {{ dataset_type.title() }} Dataset{{"s" if dataset_list | length > 1 else ""}}
+{% for dataset in dataset_list %}
+#### {{ dataset['name'] or 'Unnamed Dataset' }}
+
+{% if dataset['name'] %}* Dataset: {% if 'id' in dataset %}[{{ dataset['name'] }}](https://huggingface.co/datasets/{{ dataset['id'] }}){% else %}{{ dataset['name'] }}{% endif %}
+{%- if 'revision' in dataset and 'id' in dataset %} at [{{ dataset['revision'][:7] }}](https://huggingface.co/datasets/{{ dataset['id'] }}/tree/{{ dataset['revision'] }}){% endif %}{% endif %}
+{% if dataset['size'] %}* Size: {{ "{:,}".format(dataset['size']) }} {{ dataset_type }} samples
+{% endif %}* Columns: {% if dataset['columns'] | length == 1 %}{{ dataset['columns'][0] }}{% elif dataset['columns'] | length == 2 %}{{ dataset['columns'][0] }} and {{ dataset['columns'][1] }}{% else %}{{ dataset['columns'][:-1] | join(', ') }}, and {{ dataset['columns'][-1] }}{% endif %}
+{% if dataset['stats_table'] %}* Approximate statistics based on the first {{ [dataset['size'], 1000] | min }} samples:
+{{ dataset['stats_table'] }}{% endif %}{% if dataset['examples_table'] %}* Samples:
+{{ dataset['examples_table'] }}{% endif %}* Loss: {% if dataset["loss"]["fullname"].startswith("sentence_transformers.") %}[{{ dataset["loss"]["fullname"].split(".")[-1] }}](https://sbert.net/docs/package_reference/sentence_transformer/losses.html#{{ dataset["loss"]["fullname"].split(".")[-1].lower() }}){% else %}{{ dataset["loss"]["fullname"] }}{% endif %}{% if "config_code" in dataset["loss"] %} with these parameters:
+{{ dataset["loss"]["config_code"] }}{% endif %}
+{% endfor %}{% endif %}{% endfor -%}
+
+{% if all_hyperparameters %}
+### Training Hyperparameters
+{% if non_default_hyperparameters -%}
+#### Non-Default Hyperparameters
+
+{% for name, value in non_default_hyperparameters.items() %}- `{{ name }}`: {{ value }}
+{% endfor %}{%- endif %}
+#### All Hyperparameters
+<details><summary>Click to expand</summary>
+
+{% for name, value in all_hyperparameters.items() %}- `{{ name }}`: {{ value }}
+{% endfor %}</details>
+
+{% endif %}
+
+{%- if eval_lines %}
+### Training Logs
+{% if hide_eval_lines %}<details><summary>Click to expand</summary>
+
+{% endif -%}
+{{ eval_lines }}{% if explain_bold_in_eval %}
+* The bold row denotes the saved checkpoint.{% endif %}
+{%- if hide_eval_lines %}
+</details>{% endif %}
+{% endif %}
+
+{%- if co2_eq_emissions %}
+### Environmental Impact
+Carbon emissions were measured using [CodeCarbon](https://github.com/mlco2/codecarbon).
+- **Energy Consumed**: {{ "%.3f"|format(co2_eq_emissions["energy_consumed"]) }} kWh
+- **Carbon Emitted**: {{ "%.3f"|format(co2_eq_emissions["emissions"] / 1000) }} kg of CO2
+- **Hours Used**: {{ co2_eq_emissions["hours_used"] }} hours
+
+### Training Hardware
+- **On Cloud**: {{ "Yes" if co2_eq_emissions["on_cloud"] else "No" }}
+- **GPU Model**: {{ co2_eq_emissions["hardware_used"] or "No GPU used" }}
+- **CPU Model**: {{ co2_eq_emissions["cpu_model"] }}
+- **RAM Size**: {{ "%.2f"|format(co2_eq_emissions["ram_total_size"]) }} GB
+{% endif %}
+### Framework Versions
+- Python: {{ version["python"] }}
+- Sentence Transformers: {{ version["sentence_transformers"] }}
+- PyLate: {{ version["pylate"] }}
+- Transformers: {{ version["transformers"] }}
+- PyTorch: {{ version["torch"] }}
+- Accelerate: {{ version["accelerate"] }}
+- Datasets: {{ version["datasets"] }}
+- Tokenizers: {{ version["tokenizers"] }}
+
+
+## Citation
+
+### BibTeX
+{% for loss_name, citation in citations.items() %}
+#### {{ loss_name }}
+```bibtex
+{{ citation | trim }}
+```
+{% endfor %}
+
+
+
+
+
\ No newline at end of file
diff --git a/pylate/models/colbert.py b/pylate/models/colbert.py
index ca64cbb..c57cb3b 100644
--- a/pylate/models/colbert.py
+++ b/pylate/models/colbert.py
@@ -13,10 +13,6 @@
from numpy import ndarray
from scipy.cluster import hierarchy
from sentence_transformers import SentenceTransformer
-from sentence_transformers.model_card import (
- SentenceTransformerModelCardData,
- generate_model_card,
-)
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.models import Transformer
from sentence_transformers.quantization import quantize_embeddings
@@ -25,6 +21,7 @@
from torch import nn
from tqdm.autonotebook import trange
+from ..hf_hub.model_card import PylateModelCardData, generate_model_card
from ..utils import _start_multi_process_pool
from .Dense import Dense
@@ -217,7 +214,7 @@ def __init__(
model_kwargs: dict | None = None,
tokenizer_kwargs: dict | None = None,
config_kwargs: dict | None = None,
- model_card_data: Optional[SentenceTransformerModelCardData] = None,
+ model_card_data: PylateModelCardData | None = None,
) -> None:
self.query_prefix = query_prefix
self.document_prefix = document_prefix
@@ -225,6 +222,7 @@ def __init__(
self.document_length = document_length
self.attend_to_expansion_tokens = attend_to_expansion_tokens
self.skiplist_words = skiplist_words
+ model_card_data = model_card_data or PylateModelCardData()
super(ColBERT, self).__init__(
model_name_or_path=model_name_or_path,
diff --git a/tests/test_contrastive.py b/tests/test_contrastive.py
index 4edbbb7..52e1a8b 100644
--- a/tests/test_contrastive.py
+++ b/tests/test_contrastive.py
@@ -65,6 +65,8 @@ def test_contrastive_training() -> None:
trainer.train()
+ model.save_pretrained("tests/contrastive/final")
+
assert os.path.isdir("tests/contrastive")
metrics = dev_evaluation(
diff --git a/tests/test_kd.py b/tests/test_kd.py
index 139876a..0cc02b8 100644
--- a/tests/test_kd.py
+++ b/tests/test_kd.py
@@ -63,6 +63,8 @@ def test_kd_training() -> None:
trainer.train()
+ model.save_pretrained("tests/kd/final")
+
assert os.path.isdir("tests/kd")
if os.path.exists(path="tests/kd"):