diff --git a/.gitignore b/.gitignore
index c013e51..e4ecd84 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,4 +158,4 @@ evaluation_datasets/
 *.tsv
 
 /test-model/
-
+colbert-training/
\ No newline at end of file
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..e82405f
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,17 @@
+test:
+	pytest pylate
+	pytest tests
+
+ruff:
+	ruff format pylate
+
+lint:
+	ruff check pylate
+
+livedoc:
+	python docs/parse
+	mkdocs build --clean
+	mkdocs serve --dirtyreload
+
+deploydoc:
+	mkdocs gh-deploy --force
\ No newline at end of file
diff --git a/README.md b/README.md
index 631cb90..d3b369a 100644
--- a/README.md
+++ b/README.md
@@ -1,37 +1,72 @@
-# giga-cherche
+<div align="center">
+  <h1>PyLate</h1>
+  <p>Efficient training and retrieval with ColBERT</p>
+</div>
 
-giga-cherche is a library based on [sentence-transformers](https://github.com/UKPLab/sentence-transformers) to train and use ColBERT models.
+<p align="center"><img width=500 src="docs/img/logo.png"/></p>
 
-# Installation
+<div align="center">
+  <!-- Documentation -->
+  <a href="https://github.com/lightonai/pylate"><img src="https://img.shields.io/badge/Documentation-purple.svg?style=flat-square" alt="documentation"></a>
+  <!-- License -->
+  <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="license"></a>
+</div>
 
-giga-cherche can be installed by running the setup.py file with the needed extras from the following list:
-- ```index``` if you want to use the proposed indexes
-- ```eval``` if you need to run BEIR evaluations
-- ```dev``` if you want to contribute to the repository
-  
-For example, to run the BEIR evaluations using giga-cherche indexes:
-```python setup.py install --extras eval, index```
 
-# Modeling
-The modeling of giga-cherche is based on sentence-transformers which allow to build a ColBERT model from any encoder available by appending a projection layer applied to the output of the encoders to reduce the embeddings dimension. 
+
+PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize training, inference, and retrieval using ColBERT models. With PyLate, you can efficiently train ColBERT models on Triplet loss or Knowledge Distillation and deploy them for document retrieval tasks with ease.
+
+## Installation
+
+We can install pylate using:
+
+```bash
+pip install pylate
 ```
-from pylate import models
-model_name = "bert-base-uncased"
-model = models.ColBERT(model_name_or_path=model_name)
+
+Install with evaluation dependencies:
+
+```bash
+pip install "pylate[eval]"
+```
+
+## Documentation 
+
+The complete documentation is available [here](https://lightonai.github.io/pylate/), which includes in-depth guides, examples, and API references.
+
+## Datasets
+
+PyLate supports Hugging Face [Datasets](https://huggingface.co/docs/datasets/en/index), enabling seamless triplet / knowledge distillation based training. Below is an example of creating a custom dataset for training:
+
+```python
+from datasets import Dataset
+
+dataset = [
+    {
+        "query": "example query 1",
+        "positive": "example positive document 1",
+        "negative": "example negative document 1",
+    },
+    {
+        "query": "example query 2",
+        "positive": "example positive document 2",
+        "negative": "example negative document 2",
+    },
+    {
+        "query": "example query 3",
+        "positive": "example positive document 3",
+        "negative": "example negative document 3",
+    },
+]
+
+dataset = Dataset.from_list(mapping=dataset)
+
+train_dataset, test_dataset = dataset.train_test_split(test_size=0.3)
 ```
-The following parameters can be passed to the constructor to set different properties of the model:
-- ```embedding_size```, the output size of the projection layer and so the dimension of the embeddings
-- ```query_prefix```, the string version of the query marker to be prepended when encoding queries
-- ```document_prefix```, the string version of the document marker to be prepended when encoding documents
-- ```query_length```, the length of the query to truncate / pad to with mask tokens
-- ```document_length```, the length of the document to truncate
-- ```attend_to_expansion_tokens```, whether queries tokens should attend to MASK expansion tokens (original ColBERT did not)
-- ```skiplist_words```, a list of words to ignore in documents during scoring (default to punctuation)
 
 ## Training
 
-Given that giga-cherche ColBERT models are sentence-transformers models, we can benefit from all the bells and whistles from the latest update, including multi-gpu and BF16 training.
-For now, you can train ColBERT models using triplets dataset (datasets containing a positive and a negative for each query). The syntax is the same as sentence-transformers, using the specific elements adapted to ColBERT from giga-cherche:
+Here’s a simple example of training a ColBERT model on the MSMARCO dataset using PyLate. This script demonstrates training with triplet loss and evaluating the model on a test set.
 
 ```python
 from datasets import load_dataset
@@ -39,36 +74,50 @@ from sentence_transformers import (
     SentenceTransformerTrainer,
     SentenceTransformerTrainingArguments,
 )
+from sentence_transformers.training_args import BatchSamplers
 
-from pylate import losses, models, datasets, evaluation
+from pylate import evaluation, losses, models, utils
 
-model_name = "bert-base-uncased"
-batch_size = 32
-num_train_epochs = 1
-output_dir = "colbert_base"
-
-model = models.ColBERT(model_name_or_path=model_name)
+# Define the model
+model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
 
+# Load dataset
 dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
-splits = dataset.train_test_split(test_size=0.1)
-train_dataset = splits["train"]
-eval_dataset = splits["test"]
 
-train_loss = losses.ColBERT(model=model)
+# Split the dataset to create a test set
+train_dataset, eval_dataset = dataset.train_test_split(test_size=0.01)
+
+# Shuffle and select a subset of the dataset for demonstration purposes
+MAX_TRAIN_SIZE, MAX_EVAL_SIZE = 100, 100
+train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_TRAIN_SIZE))
+eval_dataset = eval_dataset.shuffle(seed=21).select(range(MAX_EVAL_SIZE))
 
+# Define the loss function
+train_loss = losses.Contrastive(model=model)
+
+args = SentenceTransformerTrainingArguments(
+    output_dir="colbert-training",
+    num_train_epochs=1,
+    per_device_train_batch_size=32,
+    per_device_eval_batch_size=32,
+    fp16=False,  # Some GPUs support FP16 which is faster than FP32
+    bf16=False,  # Some GPUs support BF16 which is a faster FP16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,
+    # Tracking parameters:
+    eval_strategy="steps",
+    eval_steps=0.1,
+    save_strategy="steps",
+    save_steps=5000,
+    save_total_limit=2,
+    learning_rate=3e-6,
+)
+
+# Evaluation procedure
 dev_evaluator = evaluation.ColBERTTripletEvaluator(
     anchors=eval_dataset["query"],
     positives=eval_dataset["positive"],
     negatives=eval_dataset["negative"],
 )
-args = SentenceTransformerTrainingArguments(
-    output_dir=output_dir,
-    num_train_epochs=num_train_epochs,
-    per_device_train_batch_size=batch_size,
-    per_device_eval_batch_size=batch_size,
-    bf16=True,
-    learning_rate=3e-6,
-)
 
 trainer = SentenceTransformerTrainer(
     model=model,
@@ -77,174 +126,125 @@ trainer = SentenceTransformerTrainer(
     eval_dataset=eval_dataset,
     loss=train_loss,
     evaluator=dev_evaluator,
-    data_collator=utils.ColBERTCollator(model.tokenize),
+    data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
 )
 
 trainer.train()
-```
 
-## Tokenization
-
-```
-import ast 
-
-def add_queries_and_documents(Examples dict) -> dict:
-    """Add queries and documents text to the examples."""
-    scores = ast.literal_eval(node_or_string=example["scores"])
-    processed_example = {"scores": scores, "query": queries[example["query_id"]]}
-
-    n_scores = len(scores)
-    for i in range(n_scores):
-        processed_example[f"document_{i}"] = documents[example[f"document_id_{i}"]]
-    
-    return processed_example
+model.save_pretrained("custom-colbert-model")
 ```
 
-##  Inference
-Once trained, the model can then be loaded to perform inference (you can also load the models directly from Hugging Face, for example using the provided ColBERTv2 model [NohTow/colbertv2_sentence_transformer](https://huggingface.co/NohTow/colbertv2_sentence_transformer)):
+After training, the model can be loaded like this:
 
 ```python
-model = ColBERT(
-    "NohTow/colbertv2_sentence_transformer",
-)
-```
-
-You can then call the ```encode``` function to get the embeddings corresponding to your queries:
-
-```python
-queries_embeddings = model.encode(
-        ["Who is the president of the USA?", "When was the last president of the USA elected?"],
-    )
-```
-
-When encoding documents, simply set the ```is_query``` parameter to false:
+from pylate import models
 
-```python
-documents_embeddings = model.encode(
-        ["Joseph Robinette Biden Jr. is an American politician who is the 46th and current president of the United States since 2021. A member of the Democratic Party, he previously served as the 47th vice president from 2009 to 2017 under President Barack Obama and represented Delaware in the United States Senate from 1973 to 2009.", "Donald John Trump (born June 14, 1946) is an American politician, media personality, and businessman who served as the 45th president of the United States from 2017 to 2021."],
-        is_query=False,
-    )
+model = models.ColBERT(model_name_or_path="custom-colbert-model")
 ```
 
-By default, this will return a list of numpy arrays containing the different embeddings of each sequence in the batch. You can pass the argument ```convert_to_tensor=True``` to get a list of tensors.
+##  Retrieve
 
-We also provide the option to pool the document embeddings using hierarchical clustering. Our recent study showed that we can pool the document embeddings by a factor of 2 to halve the memory consumption of the embeddings without degrading performance. This is done by feeding ```pool_factor=2```to the encode function. Bigger pooling values can be used to obtain different size/performance trade-offs.
-Note that query embeddings cannot be pooled.
-
-You can then compute the ColBERT max-sim scores like this:
+PyLate allows easy retrieval of top documents for a given query set using the trained ColBERT model and Voyager index.
 
 ```python
-from pylate import scores
-similarity_scores = scores.colbert_scores(query_embeddings, document_embeddings)
-```
+from pylate import indexes, models, retrieve
 
-## Indexing
+model = models.ColBERT(
+    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+)
 
-We provide a ColBERT index based on the [Weaviate vectordb](https://weaviate.io/). To speed-up the processing, the latest async client is used and the document candidates are generated using an HNSW index, which replace the IVF index from the original ColBERT. 
+index = indexes.Voyager(
+    index_folder="pylate-index",
+    index_name="index",
+    override=True,
+)
 
-Before being able to create and use an index, you need to need to launch the Weaviate server using Docker (```docker compose up```).
+retriever = retrieve.ColBERT(index=index)
+```
 
-To populate an index, simply create it and then add the computed embeddings with their corresponding ids:
+Once the model and index are set up, we can add documents to the index:
 
 ```python
-from pylate import indexes
+documents_ids = ["1", "2", "3"]
 
-index = indexes.Weaviate(name="test_index")
+documents = [
+    "document 1 text", "document 2 text", "document 3 text"
+]
 
+# Encode the documents
 documents_embeddings = model.encode(
-    ["Document text 1", "Document text 2"],
-    is_query=False,
+    documents,
+    batch_size=32,
+    is_query=False, # Encoding documents
+    show_progress_bar=True,
 )
 
+# Add the documents ids and embeddings to the Voyager index
 index.add_documents(
-    doc_ids=["1", "2"],
-    doc_embeddings=documents_embeddings,
+    documents_ids=documents_ids,
+    documents_embeddings=documents_embeddings,
 )
 ```
 
-We can also remove documents from the index using their ids:
-
-```python
-index.remove_documents(["1"])
-```
-
-To retrieve documents from the index, you can use the following code snippet:
+Then we can retrieve the top-k documents for a given query set:
 
 ```python
-from pylate import retrieve
-
-retriever = retrieve.ColBERT(Weaviate)
-
 queries_embeddings = model.encode(
-    ["A query related to the documents", "Another query"],
+    ["query for document 3", "query for document 1"],
+    batch_size=32,
+    is_query=True, # Encoding queries
+    show_progress_bar=True,
 )
 
-retrieved_chunks = retriever.retrieve(queries_embeddings, k=10)
-```
-
-You can also simply rerank a list of ids produced by an upstream retrieval module (such as BM25):
-
-```python
-from pylate import rerank
-
-reranker = rerank.ColBERT(Weaviate)
-
-reranked_chunks = reranker.rerank(
-    queries_embeddings, batch_doc_ids=[["7912", "4983"], ["8726", "7891"]]
+scores = retriever.retrieve(
+    queries_embeddings=queries_embeddings, 
+    k=10,
 )
-```
 
-## Evaluation
+print(scores)
+```
 
-We can eavaluate the performance of the model using the BEIR evaluation framework. The following code snippet shows how to evaluate the model on the SciFact dataset:
+Sample Output:
 
 ```python
-from pylate import evaluation, indexes, models, retrieve, utils
-
-model = models.ColBERT(
-    model_name_or_path="NohTow/colbertv2_sentence_transformer",
-)
-index = indexes.Weaviate(recreate=True, max_doc_length=model.document_length)
+[
+    [
+        {"id": "3", "score": 11.266985893249512},
+        {"id": "1", "score": 10.303335189819336},
+        {"id": "2", "score": 9.502392768859863},
+    ],
+    [
+        {"id": "1", "score": 10.88800048828125},
+        {"id": "3", "score": 9.950843811035156},
+        {"id": "2", "score": 9.602447509765625},
+    ],
+]
+```
 
-retriever = retrieve.ColBERT(index=index)
+## Contributing
 
-# Input dataset for evaluation
-documents, queries, qrels = evaluation.load_beir(
-    dataset_name="scifact",
-    split="test",
-)
+We welcome contributions! To get started:
 
+1. Install the development dependencies:
 
-for batch in utils.iter_batch(documents, batch_size=500):
-    documents_embeddings = model.encode(
-        sentences=[document["text"] for document in batch],
-        convert_to_numpy=True,
-        is_query=False,
-    )
+```bash
+pip install "pylate[dev]"
+```
 
-    index.add_documents(
-        doc_ids=[document["id"] for document in batch],
-        doc_embeddings=documents_embeddings,
-    )
+2. Run tests:
 
+```bash
+make test
+```
 
-scores = []
-for batch in utils.iter_batch(queries, batch_size=5):
-    queries_embeddings = model.encode(
-        sentences=[query["text"] for query in batch],
-        convert_to_numpy=True,
-        is_query=True,
-    )
+3. Format code with Ruff:
 
-    scores.extend(retriever.retrieve(queries=queries_embeddings, k=10))
+```bash
+make ruff
+```
 
+4. Build the documentation:
 
-print(
-    evaluation.evaluate(
-        scores=scores,
-        qrels=qrels,
-        queries=queries,
-        metrics=["map", "ndcg@10", "ndcg@100", "recall@10", "recall@100"],
-    )
-)
+```bash
+make livedoc
 ```
\ No newline at end of file
diff --git a/docs/.pages b/docs/.pages
new file mode 100644
index 0000000..d2d970d
--- /dev/null
+++ b/docs/.pages
@@ -0,0 +1,4 @@
+nav:
+    - documentation
+    - benchmarks
+    - api
\ No newline at end of file
diff --git a/docs/CNAME b/docs/CNAME
new file mode 100644
index 0000000..4173ebd
--- /dev/null
+++ b/docs/CNAME
@@ -0,0 +1 @@
+lighton.github.io/pylate/
\ No newline at end of file
diff --git a/docs/api/.pages b/docs/api/.pages
new file mode 100644
index 0000000..c2ca59a
--- /dev/null
+++ b/docs/api/.pages
@@ -0,0 +1,4 @@
+title: API reference
+arrange:
+  - overview.md
+  - ...
diff --git a/docs/api/evaluation/.pages b/docs/api/evaluation/.pages
new file mode 100644
index 0000000..86f0c6e
--- /dev/null
+++ b/docs/api/evaluation/.pages
@@ -0,0 +1 @@
+title: evaluation
\ No newline at end of file
diff --git a/docs/api/evaluation/ColBERTDistillationEvaluator.md b/docs/api/evaluation/ColBERTDistillationEvaluator.md
new file mode 100644
index 0000000..e68a94e
--- /dev/null
+++ b/docs/api/evaluation/ColBERTDistillationEvaluator.md
@@ -0,0 +1,110 @@
+# ColBERTDistillationEvaluator
+
+ColBERT Distillation Evaluator. This class is used to monitor the distillation process of a ColBERT model.
+
+
+
+## Parameters
+
+- **queries** (*list[str]*)
+
+    Set of queries.
+
+- **documents** (*list[list[str]]*)
+
+    Set of documents. Each query has a list of documents. Each document is a list of strings. Number of documents should be the same for each query.
+
+- **scores** (*list[list[float]]*)
+
+    The scores associated with the documents. Each query / documents pairs has a list of scores.
+
+- **name** (*str*) – defaults to ``
+
+    The name of the evaluator.
+
+- **batch_size** (*int*) – defaults to `16`
+
+    The batch size.
+
+- **show_progress_bar** (*bool*) – defaults to `False`
+
+    Whether to show the progress bar.
+
+- **write_csv** (*bool*) – defaults to `True`
+
+    Whether to write the results to a CSV file.
+
+- **truncate_dim** (*int | None*) – defaults to `None`
+
+    The dimension to truncate the embeddings.
+
+- **normalize_scores** (*bool*) – defaults to `True`
+
+
+## Attributes
+
+- **description**
+
+    Returns a human-readable description of the evaluator: BinaryClassificationEvaluator -> Binary Classification  1. Remove "Evaluator" from the class name 2. Add a space before every capital letter
+
+
+## Examples
+
+```python
+>>> from pylate import models, evaluation
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
+... )
+
+>>> queries = [
+...     "query A",
+...     "query B",
+... ]
+
+>>> documents = [
+...     ["document A", "document B", "document C"],
+...     ["document C C", "document B B", "document A A"],
+... ]
+
+>>> scores = [
+...     [0.9, 0.1, 0.05],
+...     [0.05, 0.9, 0.1],
+... ]
+
+>>> distillation_evaluator = evaluation.ColBERTDistillationEvaluator(
+...     queries=queries,
+...     documents=documents,
+...     scores=scores,
+...     write_csv=True,
+... )
+
+>>> results = distillation_evaluator(model=model, output_path=".")
+
+>>> assert "kl_divergence" in results
+>>> assert isinstance(results["kl_divergence"], float)
+
+>>> import pandas as pd
+>>> df = pd.read_csv(distillation_evaluator.csv_file)
+>>> assert df.columns.tolist() == distillation_evaluator.csv_headers
+```
+
+## Methods
+
+???- note "__call__"
+
+    This is called during training to evaluate the model. It returns a score for the evaluation with a higher score indicating a better result.
+
+    Args:     model: the model to evaluate     output_path: path where predictions and metrics are written         to     epoch: the epoch where the evaluation takes place. This is         used for the file prefixes. If this is -1, then we         assume evaluation on test data.     steps: the steps in the current epoch at time of the         evaluation. This is used for the file prefixes. If this         is -1, then we assume evaluation at the end of the         epoch.  Returns:     Either a score for the evaluation with a higher score     indicating a better result, or a dictionary with scores. If     the latter is chosen, then `evaluator.primary_metric` must     be defined
+
+    **Parameters**
+
+    - **model**     (*'SentenceTransformer'*)    
+    - **output_path**     (*str*)     – defaults to `None`    
+    - **epoch**     (*int*)     – defaults to `-1`    
+    - **steps**     (*int*)     – defaults to `-1`    
+    
+???- note "prefix_name_to_metrics"
+
+???- note "store_metrics_in_model_card_data"
+
diff --git a/docs/api/evaluation/ColBERTTripletEvaluator.md b/docs/api/evaluation/ColBERTTripletEvaluator.md
new file mode 100644
index 0000000..312a731
--- /dev/null
+++ b/docs/api/evaluation/ColBERTTripletEvaluator.md
@@ -0,0 +1,112 @@
+# ColBERTTripletEvaluator
+
+Evaluate a model based on a set of triples. The evaluation will compare the score between the anchor and the positive sample with the score between the anchor and the negative sample. The accuracy is computed as the number of times the score between the anchor and the positive sample is higher than the score between the anchor and the negative sample.
+
+
+
+## Parameters
+
+- **anchors** (*list[str]*)
+
+    Sentences to check similarity to. (e.g. a query)
+
+- **positives** (*list[str]*)
+
+    List of positive sentences
+
+- **negatives** (*list[str]*)
+
+    List of negative sentences
+
+- **name** (*str*) – defaults to ``
+
+    Name for the output.
+
+- **batch_size** (*int*) – defaults to `32`
+
+    Batch size used to compute embeddings.
+
+- **show_progress_bar** (*bool*) – defaults to `False`
+
+    If true, prints a progress bar.
+
+- **write_csv** (*bool*) – defaults to `True`
+
+    Wether or not to write results to a CSV file.
+
+- **truncate_dim** (*int | None*) – defaults to `None`
+
+    The dimension to truncate sentence embeddings to. If None, do not truncate.
+
+
+## Attributes
+
+- **description**
+
+    Returns a human-readable description of the evaluator: BinaryClassificationEvaluator -> Binary Classification  1. Remove "Evaluator" from the class name 2. Add a space before every capital letter
+
+
+## Examples
+
+```python
+>>> from pylate import evaluation, models
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+...     device="cpu",
+... )
+
+>>> anchors = [
+...     "fruits are healthy.",
+...     "fruits are healthy.",
+... ]
+
+>>> positives = [
+...     "fruits are good for health.",
+...     "Fruits are growing in the trees.",
+... ]
+
+>>> negatives = [
+...     "Fruits are growing in the trees.",
+...     "fruits are good for health.",
+... ]
+
+>>> triplet_evaluation = evaluation.ColBERTTripletEvaluator(
+...     anchors=anchors,
+...     positives=positives,
+...     negatives=negatives,
+...     write_csv=True,
+... )
+
+>>> results = triplet_evaluation(model=model, output_path=".")
+
+>>> results
+{'accuracy': 0.5}
+
+>>> triplet_evaluation.csv_headers
+['epoch', 'steps', 'accuracy']
+
+>>> import pandas as pd
+>>> df = pd.read_csv(triplet_evaluation.csv_file)
+>>> assert df.columns.tolist() == triplet_evaluation.csv_headers
+```
+
+## Methods
+
+???- note "__call__"
+
+    Evaluate the model on the triplet dataset. Measure the scoring between the anchor and the positive with every other positive and negative samples using HITS@K.
+
+    **Parameters**
+
+    - **model**     (*pylate.models.colbert.ColBERT*)    
+    - **output_path**     (*str*)     – defaults to `None`    
+    - **epoch**     (*int*)     – defaults to `-1`    
+    - **steps**     (*int*)     – defaults to `-1`    
+    
+???- note "from_input_examples"
+
+???- note "prefix_name_to_metrics"
+
+???- note "store_metrics_in_model_card_data"
+
diff --git a/docs/api/evaluation/evaluate.md b/docs/api/evaluation/evaluate.md
new file mode 100644
index 0000000..6ed3fc6
--- /dev/null
+++ b/docs/api/evaluation/evaluate.md
@@ -0,0 +1,25 @@
+# evaluate
+
+Evaluate candidates matchs.
+
+
+
+## Parameters
+
+- **scores** (*list[list[dict]]*)
+
+- **qrels** (*dict*)
+
+    Qrels.
+
+- **queries** (*list[str]*)
+
+    index of queries of qrels.
+
+- **metrics** (*list | None*) – defaults to `None`
+
+    Metrics to compute.
+
+
+
+
diff --git a/docs/api/evaluation/get-beir-triples.md b/docs/api/evaluation/get-beir-triples.md
new file mode 100644
index 0000000..61481a1
--- /dev/null
+++ b/docs/api/evaluation/get-beir-triples.md
@@ -0,0 +1,40 @@
+# get_beir_triples
+
+Build BEIR triples.
+
+
+
+## Parameters
+
+- **documents** (*list*)
+
+    Documents.
+
+- **queries** (*list[str]*)
+
+    Queries.
+
+- **qrels** (*dict*)
+
+
+
+## Examples
+
+```python
+>>> from pylate import evaluation
+
+>>> documents, queries, qrels = evaluation.load_beir(
+...     "scifact",
+...     split="test",
+... )
+
+>>> triples = evaluation.get_beir_triples(
+...     documents=documents,
+...     queries=queries,
+...     qrels=qrels
+... )
+
+>>> len(triples)
+339
+```
+
diff --git a/docs/api/evaluation/load-beir.md b/docs/api/evaluation/load-beir.md
new file mode 100644
index 0000000..56c9777
--- /dev/null
+++ b/docs/api/evaluation/load-beir.md
@@ -0,0 +1,38 @@
+# load_beir
+
+Load BEIR dataset.
+
+
+
+## Parameters
+
+- **dataset_name** (*str*)
+
+    Name of the beir dataset.
+
+- **split** (*str*) – defaults to `test`
+
+    Split to load.
+
+
+
+## Examples
+
+```python
+>>> from pylate import evaluation
+
+>>> documents, queries, qrels = evaluation.load_beir(
+...     "scifact",
+...     split="test",
+... )
+
+>>> len(documents)
+5183
+
+>>> len(queries)
+300
+
+>>> len(qrels)
+300
+```
+
diff --git a/docs/api/indexes/.pages b/docs/api/indexes/.pages
new file mode 100644
index 0000000..e163c68
--- /dev/null
+++ b/docs/api/indexes/.pages
@@ -0,0 +1 @@
+title: indexes
\ No newline at end of file
diff --git a/docs/api/indexes/Voyager.md b/docs/api/indexes/Voyager.md
new file mode 100644
index 0000000..aaa36fc
--- /dev/null
+++ b/docs/api/indexes/Voyager.md
@@ -0,0 +1,121 @@
+# Voyager
+
+Voyager index. The Voyager index is a fast and efficient index for approximate nearest neighbor search.
+
+
+
+## Parameters
+
+- **index_folder** (*str*) – defaults to `indexes`
+
+- **index_name** (*str*) – defaults to `colbert`
+
+- **override** (*bool*) – defaults to `False`
+
+    Whether to override the collection if it already exists.
+
+- **embedding_size** (*int*) – defaults to `128`
+
+    The number of dimensions of the embeddings.
+
+- **M** (*int*) – defaults to `64`
+
+    The number of subquantizers.
+
+- **ef_construction** (*int*) – defaults to `200`
+
+    The number of candidates to evaluate during the construction of the index.
+
+- **ef_search** (*int*) – defaults to `200`
+
+    The number of candidates to evaluate during the search.
+
+
+
+## Examples
+
+```python
+>>> from pylate import indexes, models
+
+>>> index = indexes.Voyager(
+...     index_folder="test_indexes",
+...     index_name="colbert",
+...     override=True,
+...     embedding_size=128,
+... )
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+... )
+
+>>> documents_embeddings = model.encode(
+...     ["fruits are healthy.", "fruits are good for health.", "fruits are bad for health."],
+...     is_query=False,
+... )
+
+>>> index = index.add_documents(
+...     documents_ids=["1", "2", "3"],
+...     documents_embeddings=documents_embeddings
+... )
+
+>>> queries_embeddings = model.encode(
+...     ["fruits are healthy.", "fruits are good for health and fun."],
+...     is_query=True,
+... )
+
+>>> matchs = index(queries_embeddings, k=30)
+
+>>> assert matchs["distances"].shape[0] == 2
+>>> assert isinstance(matchs, dict)
+>>> assert "documents_ids" in matchs
+>>> assert "distances" in matchs
+
+>>> queries_embeddings = model.encode(
+...     "fruits are healthy.",
+...     is_query=True,
+... )
+
+>>> matchs = index(queries_embeddings, k=30)
+
+>>> assert matchs["distances"].shape[0] == 1
+>>> assert isinstance(matchs, dict)
+>>> assert "documents_ids" in matchs
+>>> assert "distances" in matchs
+```
+
+## Methods
+
+???- note "__call__"
+
+    Query the index for the nearest neighbors of the queries embeddings.
+
+    **Parameters**
+
+    - **queries_embeddings**     (*numpy.ndarray | torch.Tensor*)    
+    - **k**     (*int*)     – defaults to `10`    
+    
+???- note "add_documents"
+
+    Add documents to the index.
+
+    **Parameters**
+
+    - **documents_ids**     (*str | list[str]*)    
+    - **documents_embeddings**     (*list[numpy.ndarray | torch.Tensor]*)    
+    
+???- note "get_documents_embeddings"
+
+    Retrieve document embeddings for re-ranking from Voyager.
+
+    **Parameters**
+
+    - **document_ids**     (*list[list[str]]*)    
+    
+???- note "remove_documents"
+
+    Remove documents from the index.
+
+    **Parameters**
+
+    - **documents_ids**     (*list[str]*)    
+    
diff --git a/docs/api/losses/.pages b/docs/api/losses/.pages
new file mode 100644
index 0000000..a5e19eb
--- /dev/null
+++ b/docs/api/losses/.pages
@@ -0,0 +1 @@
+title: losses
\ No newline at end of file
diff --git a/docs/api/losses/Contrastive.md b/docs/api/losses/Contrastive.md
new file mode 100644
index 0000000..17e658e
--- /dev/null
+++ b/docs/api/losses/Contrastive.md
@@ -0,0 +1,507 @@
+# Contrastive
+
+Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
+
+
+
+## Parameters
+
+- **model** (*[models.ColBERT](../../models/ColBERT)*)
+
+    ColBERT model.
+
+- **score_metric** – defaults to `<function colbert_scores at 0x125931cf0>`
+
+    ColBERT scoring function. Defaults to colbert_scores.
+
+- **size_average** (*bool*) – defaults to `True`
+
+    Average by the size of the mini-batch.
+
+
+
+## Examples
+
+```python
+>>> from pylate import models, losses
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
+... )
+
+>>> loss = losses.Contrastive(model=model)
+
+>>> anchor = model.tokenize([
+...     "fruits are healthy.",
+... ], is_query=True)
+
+>>> positive = model.tokenize([
+...     "fruits are good for health.",
+... ], is_query=False)
+
+>>> negative = model.tokenize([
+...     "fruits are bad for health.",
+... ], is_query=False)
+
+>>> sentence_features = [anchor, positive, negative]
+
+>>> loss = loss(sentence_features=sentence_features)
+>>> assert isinstance(loss.item(), float)
+```
+
+## Methods
+
+???- note "__call__"
+
+    Call self as a function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "add_module"
+
+    Add a child module to the current module.
+
+    The module can be accessed as an attribute using the given name.  Args:     name (str): name of the child module. The child module can be         accessed from this module using the given name     module (Module): child module to be added to the module.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "apply"
+
+    Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+
+    Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).  Args:     fn (:class:`Module` -> None): function to be applied to each submodule  Returns:     Module: self  Example::      >>> @torch.no_grad()     >>> def init_weights(m):     >>>     print(m)     >>>     if type(m) == nn.Linear:     >>>         m.weight.fill_(1.0)     >>>         print(m.weight)     >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))     >>> net.apply(init_weights)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )
+
+    **Parameters**
+
+    - **fn**     (*Callable[[ForwardRef('Module')], NoneType]*)    
+    
+???- note "bfloat16"
+
+    Casts all floating point parameters and buffers to ``bfloat16`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "buffers"
+
+    Return an iterator over module buffers.
+
+    Args:     recurse (bool): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module.  Yields:     torch.Tensor: module buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for buf in model.buffers():     >>>     print(type(buf), buf.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "children"
+
+    Return an iterator over immediate children modules.
+
+    Yields:     Module: a child module
+
+    
+???- note "compile"
+
+    Compile this Module's forward using :func:`torch.compile`.
+
+    This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`.  See :func:`torch.compile` for details on the arguments for this function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "cpu"
+
+    Move all model parameters and buffers to the CPU.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "cuda"
+
+    Move all model parameters and buffers to the GPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.  .. note::     This method modifies the module in-place.  Args:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "double"
+
+    Casts all floating point parameters and buffers to ``double`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "eval"
+
+    Set the module in evaluation mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.  See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it.  Returns:     Module: self
+
+    
+???- note "extra_repr"
+
+    Set the extra representation of the module.
+
+    To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
+
+    
+???- note "float"
+
+    Casts all floating point parameters and buffers to ``float`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "forward"
+
+    Compute the Constrastive loss.
+
+    **Parameters**
+
+    - **sentence_features**     (*Iterable[dict[str, torch.Tensor]]*)    
+    - **labels**     (*torch.Tensor | None*)     – defaults to `None`    
+    
+???- note "get_buffer"
+
+    Return the buffer given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the buffer         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.Tensor: The buffer referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not a         buffer
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_extra_state"
+
+    Return any extra state to include in the module's state_dict.
+
+    Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`.  Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.  Returns:     object: Any extra state to store in the module's state_dict
+
+    
+???- note "get_parameter"
+
+    Return the parameter given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the Parameter         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.nn.Parameter: The Parameter referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Parameter``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_submodule"
+
+    Return the submodule given by ``target`` if it exists, otherwise throw an error.
+
+    For example, let's say you have an ``nn.Module`` ``A`` that looks like this:  .. code-block:: text      A(         (net_b): Module(             (net_c): Module(                 (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))             )             (linear): Linear(in_features=100, out_features=200, bias=True)         )     )  (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.)  To check whether or not we have the ``linear`` submodule, we would call ``get_submodule("net_b.linear")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule("net_b.net_c.conv")``.  The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used.  Args:     target: The fully-qualified string name of the submodule         to look for. (See above example for how to specify a         fully-qualified string.)  Returns:     torch.nn.Module: The submodule referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Module``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "half"
+
+    Casts all floating point parameters and buffers to ``half`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "ipu"
+
+    Move all model parameters and buffers to the IPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "load_state_dict"
+
+    Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
+
+    If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.  .. warning::     If :attr:`assign` is ``True`` the optimizer must be created after     the call to :attr:`load_state_dict` unless     :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.  Args:     state_dict (dict): a dict containing parameters and         persistent buffers.     strict (bool, optional): whether to strictly enforce that the keys         in :attr:`state_dict` match the keys returned by this module's         :meth:`~torch.nn.Module.state_dict` function. Default: ``True``     assign (bool, optional): When ``False``, the properties of the tensors         in the current module are preserved while when ``True``, the         properties of the Tensors in the state dict are preserved. The only         exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s         for which the value from the module is preserved.         Default: ``False``  Returns:     ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:         * **missing_keys** is a list of str containing the missing keys         * **unexpected_keys** is a list of str containing the unexpected keys  Note:     If a parameter or buffer is registered as ``None`` and its corresponding key     exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a     ``RuntimeError``.
+
+    **Parameters**
+
+    - **state_dict**     (*Mapping[str, Any]*)    
+    - **strict**     (*bool*)     – defaults to `True`    
+    - **assign**     (*bool*)     – defaults to `False`    
+    
+???- note "modules"
+
+    Return an iterator over all modules in the network.
+
+    Yields:     Module: a module in the network  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.modules()):     ...     print(idx, '->', m)      0 -> Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )     1 -> Linear(in_features=2, out_features=2, bias=True)
+
+    
+???- note "named_buffers"
+
+    Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
+
+    Args:     prefix (str): prefix to prepend to all buffer names.     recurse (bool, optional): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module. Defaults to True.     remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.  Yields:     (str, torch.Tensor): Tuple containing the name and buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, buf in self.named_buffers():     >>>     if name in ['running_var']:     >>>         print(buf.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_children"
+
+    Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
+
+    Yields:     (str, Module): Tuple containing a name and child module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, module in model.named_children():     >>>     if name in ['conv4', 'conv5']:     >>>         print(module)
+
+    
+???- note "named_modules"
+
+    Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
+
+    Args:     memo: a memo to store the set of modules already added to the result     prefix: a prefix that will be added to the name of the module     remove_duplicate: whether to remove the duplicated module instances in the result         or not  Yields:     (str, Module): Tuple of name and module  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.named_modules()):     ...     print(idx, '->', m)      0 -> ('', Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     ))     1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+    **Parameters**
+
+    - **memo**     (*Optional[Set[ForwardRef('Module')]]*)     – defaults to `None`    
+    - **prefix**     (*str*)     – defaults to ``    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_parameters"
+
+    Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
+
+    Args:     prefix (str): prefix to prepend to all parameter names.     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.     remove_duplicate (bool, optional): whether to remove the duplicated         parameters in the result. Defaults to True.  Yields:     (str, Parameter): Tuple containing the name and parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, param in self.named_parameters():     >>>     if name in ['bias']:     >>>         print(param.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "parameters"
+
+    Return an iterator over module parameters.
+
+    This is typically passed to an optimizer.  Args:     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.  Yields:     Parameter: module parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for param in model.parameters():     >>>     print(type(param), param.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "register_backward_hook"
+
+    Register a backward hook on the module.
+
+    This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and the behavior of this function will change in future versions.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    
+???- note "register_buffer"
+
+    Add a buffer to the module.
+
+    This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`.  Buffers can be accessed as attributes using given names.  Args:     name (str): name of the buffer. The buffer can be accessed         from this module using the given name     tensor (Tensor or None): buffer to be registered. If ``None``, then operations         that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,         the buffer is **not** included in the module's :attr:`state_dict`.     persistent (bool): whether the buffer is part of this module's         :attr:`state_dict`.  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> self.register_buffer('running_mean', torch.zeros(num_features))
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **tensor**     (*Optional[torch.Tensor]*)    
+    - **persistent**     (*bool*)     – defaults to `True`    
+    
+???- note "register_forward_hook"
+
+    Register a forward hook on the module.
+
+    The hook will be called every time after :func:`forward` has computed an output.  If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature::      hook(module, args, output) -> None or modified output  If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::      hook(module, args, kwargs, output) -> None or modified output  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If ``True``, the provided ``hook`` will be fired         before all existing ``forward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``forward`` hooks registered with         :func:`register_module_forward_hook` will fire before all hooks         registered by this method.         Default: ``False``     with_kwargs (bool): If ``True``, the ``hook`` will be passed the         kwargs given to the forward function.         Default: ``False``     always_call (bool): If ``True`` the ``hook`` will be run regardless of         whether an exception is raised while calling the Module.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    - **always_call**     (*bool*)     – defaults to `False`    
+    
+???- note "register_forward_pre_hook"
+
+    Register a forward pre-hook on the module.
+
+    The hook will be called every time before :func:`forward` is invoked.  If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::      hook(module, args) -> None or modified input  If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::      hook(module, args, kwargs) -> None or a tuple of modified input and kwargs  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``forward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``forward_pre`` hooks registered with         :func:`register_module_forward_pre_hook` will fire before all         hooks registered by this method.         Default: ``False``     with_kwargs (bool): If true, the ``hook`` will be passed the kwargs         given to the forward function.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_hook"
+
+    Register a backward hook on the module.
+
+    The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::      hook(module, grad_input, grad_output) -> tuple(Tensor) or None  The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs or outputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``backward`` hooks registered with         :func:`register_module_full_backward_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_pre_hook"
+
+    Register a backward pre-hook on the module.
+
+    The hook will be called every time the gradients for the module are computed. The hook should have the following signature::      hook(module, grad_output) -> tuple[Tensor] or None  The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``backward_pre`` hooks registered with         :func:`register_module_full_backward_pre_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_load_state_dict_post_hook"
+
+    Register a post hook to be run after module's ``load_state_dict`` is called.
+
+    It should have the following signature::     hook(module, incompatible_keys) -> None  The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.  The given incompatible_keys can be modified inplace if needed.  Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "register_module"
+
+    Alias for :func:`add_module`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "register_parameter"
+
+    Add a parameter to the module.
+
+    The parameter can be accessed as an attribute using given name.  Args:     name (str): name of the parameter. The parameter can be accessed         from this module using the given name     param (Parameter or None): parameter to be added to the module. If         ``None``, then operations that run on parameters, such as :attr:`cuda`,         are ignored. If ``None``, the parameter is **not** included in the         module's :attr:`state_dict`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **param**     (*Optional[torch.nn.parameter.Parameter]*)    
+    
+???- note "register_state_dict_pre_hook"
+
+    Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
+
+    These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made.
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "requires_grad_"
+
+    Change if autograd should record operations on parameters in this module.
+
+    This method sets the parameters' :attr:`requires_grad` attributes in-place.  This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).  See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it.  Args:     requires_grad (bool): whether autograd should record operations on                           parameters in this module. Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **requires_grad**     (*bool*)     – defaults to `True`    
+    
+???- note "set_extra_state"
+
+    Set extra state contained in the loaded `state_dict`.
+
+    This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`.  Args:     state (dict): Extra state from the `state_dict`
+
+    **Parameters**
+
+    - **state**     (*Any*)    
+    
+???- note "share_memory"
+
+    See :meth:`torch.Tensor.share_memory_`.
+
+    
+???- note "state_dict"
+
+    Return a dictionary containing references to the whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included.  .. note::     The returned object is a shallow copy. It contains references     to the module's parameters and buffers.  .. warning::     Currently ``state_dict()`` also accepts positional arguments for     ``destination``, ``prefix`` and ``keep_vars`` in order. However,     this is being deprecated and keyword arguments will be enforced in     future releases.  .. warning::     Please avoid the use of argument ``destination`` as it is not     designed for end-users.  Args:     destination (dict, optional): If provided, the state of module will         be updated into the dict and the same object is returned.         Otherwise, an ``OrderedDict`` will be created and returned.         Default: ``None``.     prefix (str, optional): a prefix added to parameter and buffer         names to compose the keys in state_dict. Default: ``''``.     keep_vars (bool, optional): by default the :class:`~torch.Tensor` s         returned in the state dict are detached from autograd. If it's         set to ``True``, detaching will not be performed.         Default: ``False``.  Returns:     dict:         a dictionary containing a whole state of the module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> module.state_dict().keys()     ['bias', 'weight']
+
+    **Parameters**
+
+    - **args**    
+    - **destination**     – defaults to `None`    
+    - **prefix**     – defaults to ``    
+    - **keep_vars**     – defaults to `False`    
+    
+???- note "to"
+
+    Move and/or cast the parameters and buffers.
+
+    This can be called as  .. function:: to(device=None, dtype=None, non_blocking=False)    :noindex:  .. function:: to(dtype, non_blocking=False)    :noindex:  .. function:: to(tensor, non_blocking=False)    :noindex:  .. function:: to(memory_format=torch.channels_last)    :noindex:  Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype`\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.  See below for examples.  .. note::     This method modifies the module in-place.  Args:     device (:class:`torch.device`): the desired device of the parameters         and buffers in this module     dtype (:class:`torch.dtype`): the desired floating point or complex dtype of         the parameters and buffers in this module     tensor (torch.Tensor): Tensor whose dtype and device are the desired         dtype and device for all parameters and buffers in this module     memory_format (:class:`torch.memory_format`): the desired memory         format for 4D parameters and buffers in this module (keyword         only argument)  Returns:     Module: self  Examples::      >>> # xdoctest: +IGNORE_WANT("non-deterministic")     >>> linear = nn.Linear(2, 2)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]])     >>> linear.to(torch.double)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]], dtype=torch.float64)     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)     >>> gpu1 = torch.device("cuda:1")     >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')     >>> cpu = torch.device("cpu")     >>> linear.to(cpu)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16)      >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)     >>> linear.weight     Parameter containing:     tensor([[ 0.3741+0.j,  0.2382+0.j],             [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)     >>> linear(torch.ones(3, 2, dtype=torch.cdouble))     tensor([[0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "to_empty"
+
+    Move the parameters and buffers to the specified device without copying storage.
+
+    Args:     device (:class:`torch.device`): The desired device of the parameters         and buffers in this module.     recurse (bool): Whether parameters and buffers of submodules should         be recursively moved to the specified device.  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, str, torch.device, NoneType]*)    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "train"
+
+    Set the module in training mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  Args:     mode (bool): whether to set training mode (``True``) or evaluation                  mode (``False``). Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **mode**     (*bool*)     – defaults to `True`    
+    
+???- note "type"
+
+    Casts all parameters and buffers to :attr:`dst_type`.
+
+    .. note::     This method modifies the module in-place.  Args:     dst_type (type or string): the desired type  Returns:     Module: self
+
+    **Parameters**
+
+    - **dst_type**     (*Union[torch.dtype, str]*)    
+    
+???- note "xpu"
+
+    Move all model parameters and buffers to the XPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "zero_grad"
+
+    Reset gradients of all model parameters.
+
+    See similar function under :class:`torch.optim.Optimizer` for more context.  Args:     set_to_none (bool): instead of setting to zero, set the grads to None.         See :meth:`torch.optim.Optimizer.zero_grad` for details.
+
+    **Parameters**
+
+    - **set_to_none**     (*bool*)     – defaults to `True`    
+    
diff --git a/docs/api/losses/Distillation.md b/docs/api/losses/Distillation.md
new file mode 100644
index 0000000..9f0e5ba
--- /dev/null
+++ b/docs/api/losses/Distillation.md
@@ -0,0 +1,511 @@
+# Distillation
+
+Distillation loss for ColBERT model. The loss is computed with respect to the format of SentenceTransformer library.
+
+
+
+## Parameters
+
+- **model** (*[models.ColBERT](../../models/ColBERT)*)
+
+    SentenceTransformer model.
+
+- **score_metric** (*Callable*) – defaults to `<function colbert_kd_scores at 0x178a65120>`
+
+    Function that returns a score between two sequences of embeddings.
+
+- **size_average** (*bool*) – defaults to `True`
+
+    Average by the size of the mini-batch or perform sum.
+
+- **normalize_scores** (*bool*) – defaults to `True`
+
+
+
+## Examples
+
+```python
+>>> from pylate import models, losses
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
+... )
+
+>>> distillation = losses.Distillation(model=model)
+
+>>> query = model.tokenize([
+...     "fruits are healthy.",
+... ], is_query=True)
+
+>>> documents = model.tokenize([
+...     "fruits are good for health.",
+...     "fruits are bad for health."
+... ], is_query=False)
+
+>>> sentence_features = [query, documents]
+
+>>> labels = torch.tensor([
+...     [0.7, 0.3],
+... ], dtype=torch.float32)
+
+>>> loss = distillation(sentence_features=sentence_features, labels=labels)
+
+>>> assert isinstance(loss.item(), float)
+```
+
+## Methods
+
+???- note "__call__"
+
+    Call self as a function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "add_module"
+
+    Add a child module to the current module.
+
+    The module can be accessed as an attribute using the given name.  Args:     name (str): name of the child module. The child module can be         accessed from this module using the given name     module (Module): child module to be added to the module.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "apply"
+
+    Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+
+    Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).  Args:     fn (:class:`Module` -> None): function to be applied to each submodule  Returns:     Module: self  Example::      >>> @torch.no_grad()     >>> def init_weights(m):     >>>     print(m)     >>>     if type(m) == nn.Linear:     >>>         m.weight.fill_(1.0)     >>>         print(m.weight)     >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))     >>> net.apply(init_weights)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )
+
+    **Parameters**
+
+    - **fn**     (*Callable[[ForwardRef('Module')], NoneType]*)    
+    
+???- note "bfloat16"
+
+    Casts all floating point parameters and buffers to ``bfloat16`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "buffers"
+
+    Return an iterator over module buffers.
+
+    Args:     recurse (bool): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module.  Yields:     torch.Tensor: module buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for buf in model.buffers():     >>>     print(type(buf), buf.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "children"
+
+    Return an iterator over immediate children modules.
+
+    Yields:     Module: a child module
+
+    
+???- note "compile"
+
+    Compile this Module's forward using :func:`torch.compile`.
+
+    This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`.  See :func:`torch.compile` for details on the arguments for this function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "cpu"
+
+    Move all model parameters and buffers to the CPU.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "cuda"
+
+    Move all model parameters and buffers to the GPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.  .. note::     This method modifies the module in-place.  Args:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "double"
+
+    Casts all floating point parameters and buffers to ``double`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "eval"
+
+    Set the module in evaluation mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.  See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it.  Returns:     Module: self
+
+    
+???- note "extra_repr"
+
+    Set the extra representation of the module.
+
+    To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
+
+    
+???- note "float"
+
+    Casts all floating point parameters and buffers to ``float`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "forward"
+
+    Computes the distillation loss with respect to SentenceTransformer.
+
+    **Parameters**
+
+    - **sentence_features**     (*Iterable[dict[str, torch.Tensor]]*)    
+    - **labels**     (*torch.Tensor*)    
+    
+???- note "get_buffer"
+
+    Return the buffer given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the buffer         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.Tensor: The buffer referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not a         buffer
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_extra_state"
+
+    Return any extra state to include in the module's state_dict.
+
+    Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`.  Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.  Returns:     object: Any extra state to store in the module's state_dict
+
+    
+???- note "get_parameter"
+
+    Return the parameter given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the Parameter         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.nn.Parameter: The Parameter referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Parameter``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_submodule"
+
+    Return the submodule given by ``target`` if it exists, otherwise throw an error.
+
+    For example, let's say you have an ``nn.Module`` ``A`` that looks like this:  .. code-block:: text      A(         (net_b): Module(             (net_c): Module(                 (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))             )             (linear): Linear(in_features=100, out_features=200, bias=True)         )     )  (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.)  To check whether or not we have the ``linear`` submodule, we would call ``get_submodule("net_b.linear")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule("net_b.net_c.conv")``.  The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used.  Args:     target: The fully-qualified string name of the submodule         to look for. (See above example for how to specify a         fully-qualified string.)  Returns:     torch.nn.Module: The submodule referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Module``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "half"
+
+    Casts all floating point parameters and buffers to ``half`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "ipu"
+
+    Move all model parameters and buffers to the IPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "load_state_dict"
+
+    Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
+
+    If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.  .. warning::     If :attr:`assign` is ``True`` the optimizer must be created after     the call to :attr:`load_state_dict` unless     :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.  Args:     state_dict (dict): a dict containing parameters and         persistent buffers.     strict (bool, optional): whether to strictly enforce that the keys         in :attr:`state_dict` match the keys returned by this module's         :meth:`~torch.nn.Module.state_dict` function. Default: ``True``     assign (bool, optional): When ``False``, the properties of the tensors         in the current module are preserved while when ``True``, the         properties of the Tensors in the state dict are preserved. The only         exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s         for which the value from the module is preserved.         Default: ``False``  Returns:     ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:         * **missing_keys** is a list of str containing the missing keys         * **unexpected_keys** is a list of str containing the unexpected keys  Note:     If a parameter or buffer is registered as ``None`` and its corresponding key     exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a     ``RuntimeError``.
+
+    **Parameters**
+
+    - **state_dict**     (*Mapping[str, Any]*)    
+    - **strict**     (*bool*)     – defaults to `True`    
+    - **assign**     (*bool*)     – defaults to `False`    
+    
+???- note "modules"
+
+    Return an iterator over all modules in the network.
+
+    Yields:     Module: a module in the network  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.modules()):     ...     print(idx, '->', m)      0 -> Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )     1 -> Linear(in_features=2, out_features=2, bias=True)
+
+    
+???- note "named_buffers"
+
+    Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
+
+    Args:     prefix (str): prefix to prepend to all buffer names.     recurse (bool, optional): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module. Defaults to True.     remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.  Yields:     (str, torch.Tensor): Tuple containing the name and buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, buf in self.named_buffers():     >>>     if name in ['running_var']:     >>>         print(buf.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_children"
+
+    Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
+
+    Yields:     (str, Module): Tuple containing a name and child module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, module in model.named_children():     >>>     if name in ['conv4', 'conv5']:     >>>         print(module)
+
+    
+???- note "named_modules"
+
+    Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
+
+    Args:     memo: a memo to store the set of modules already added to the result     prefix: a prefix that will be added to the name of the module     remove_duplicate: whether to remove the duplicated module instances in the result         or not  Yields:     (str, Module): Tuple of name and module  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.named_modules()):     ...     print(idx, '->', m)      0 -> ('', Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     ))     1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+    **Parameters**
+
+    - **memo**     (*Optional[Set[ForwardRef('Module')]]*)     – defaults to `None`    
+    - **prefix**     (*str*)     – defaults to ``    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_parameters"
+
+    Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
+
+    Args:     prefix (str): prefix to prepend to all parameter names.     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.     remove_duplicate (bool, optional): whether to remove the duplicated         parameters in the result. Defaults to True.  Yields:     (str, Parameter): Tuple containing the name and parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, param in self.named_parameters():     >>>     if name in ['bias']:     >>>         print(param.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "parameters"
+
+    Return an iterator over module parameters.
+
+    This is typically passed to an optimizer.  Args:     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.  Yields:     Parameter: module parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for param in model.parameters():     >>>     print(type(param), param.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "register_backward_hook"
+
+    Register a backward hook on the module.
+
+    This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and the behavior of this function will change in future versions.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    
+???- note "register_buffer"
+
+    Add a buffer to the module.
+
+    This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`.  Buffers can be accessed as attributes using given names.  Args:     name (str): name of the buffer. The buffer can be accessed         from this module using the given name     tensor (Tensor or None): buffer to be registered. If ``None``, then operations         that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,         the buffer is **not** included in the module's :attr:`state_dict`.     persistent (bool): whether the buffer is part of this module's         :attr:`state_dict`.  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> self.register_buffer('running_mean', torch.zeros(num_features))
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **tensor**     (*Optional[torch.Tensor]*)    
+    - **persistent**     (*bool*)     – defaults to `True`    
+    
+???- note "register_forward_hook"
+
+    Register a forward hook on the module.
+
+    The hook will be called every time after :func:`forward` has computed an output.  If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature::      hook(module, args, output) -> None or modified output  If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::      hook(module, args, kwargs, output) -> None or modified output  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If ``True``, the provided ``hook`` will be fired         before all existing ``forward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``forward`` hooks registered with         :func:`register_module_forward_hook` will fire before all hooks         registered by this method.         Default: ``False``     with_kwargs (bool): If ``True``, the ``hook`` will be passed the         kwargs given to the forward function.         Default: ``False``     always_call (bool): If ``True`` the ``hook`` will be run regardless of         whether an exception is raised while calling the Module.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    - **always_call**     (*bool*)     – defaults to `False`    
+    
+???- note "register_forward_pre_hook"
+
+    Register a forward pre-hook on the module.
+
+    The hook will be called every time before :func:`forward` is invoked.  If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::      hook(module, args) -> None or modified input  If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::      hook(module, args, kwargs) -> None or a tuple of modified input and kwargs  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``forward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``forward_pre`` hooks registered with         :func:`register_module_forward_pre_hook` will fire before all         hooks registered by this method.         Default: ``False``     with_kwargs (bool): If true, the ``hook`` will be passed the kwargs         given to the forward function.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_hook"
+
+    Register a backward hook on the module.
+
+    The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::      hook(module, grad_input, grad_output) -> tuple(Tensor) or None  The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs or outputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``backward`` hooks registered with         :func:`register_module_full_backward_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_pre_hook"
+
+    Register a backward pre-hook on the module.
+
+    The hook will be called every time the gradients for the module are computed. The hook should have the following signature::      hook(module, grad_output) -> tuple[Tensor] or None  The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``backward_pre`` hooks registered with         :func:`register_module_full_backward_pre_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_load_state_dict_post_hook"
+
+    Register a post hook to be run after module's ``load_state_dict`` is called.
+
+    It should have the following signature::     hook(module, incompatible_keys) -> None  The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.  The given incompatible_keys can be modified inplace if needed.  Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "register_module"
+
+    Alias for :func:`add_module`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "register_parameter"
+
+    Add a parameter to the module.
+
+    The parameter can be accessed as an attribute using given name.  Args:     name (str): name of the parameter. The parameter can be accessed         from this module using the given name     param (Parameter or None): parameter to be added to the module. If         ``None``, then operations that run on parameters, such as :attr:`cuda`,         are ignored. If ``None``, the parameter is **not** included in the         module's :attr:`state_dict`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **param**     (*Optional[torch.nn.parameter.Parameter]*)    
+    
+???- note "register_state_dict_pre_hook"
+
+    Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
+
+    These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made.
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "requires_grad_"
+
+    Change if autograd should record operations on parameters in this module.
+
+    This method sets the parameters' :attr:`requires_grad` attributes in-place.  This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).  See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it.  Args:     requires_grad (bool): whether autograd should record operations on                           parameters in this module. Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **requires_grad**     (*bool*)     – defaults to `True`    
+    
+???- note "set_extra_state"
+
+    Set extra state contained in the loaded `state_dict`.
+
+    This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`.  Args:     state (dict): Extra state from the `state_dict`
+
+    **Parameters**
+
+    - **state**     (*Any*)    
+    
+???- note "share_memory"
+
+    See :meth:`torch.Tensor.share_memory_`.
+
+    
+???- note "state_dict"
+
+    Return a dictionary containing references to the whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included.  .. note::     The returned object is a shallow copy. It contains references     to the module's parameters and buffers.  .. warning::     Currently ``state_dict()`` also accepts positional arguments for     ``destination``, ``prefix`` and ``keep_vars`` in order. However,     this is being deprecated and keyword arguments will be enforced in     future releases.  .. warning::     Please avoid the use of argument ``destination`` as it is not     designed for end-users.  Args:     destination (dict, optional): If provided, the state of module will         be updated into the dict and the same object is returned.         Otherwise, an ``OrderedDict`` will be created and returned.         Default: ``None``.     prefix (str, optional): a prefix added to parameter and buffer         names to compose the keys in state_dict. Default: ``''``.     keep_vars (bool, optional): by default the :class:`~torch.Tensor` s         returned in the state dict are detached from autograd. If it's         set to ``True``, detaching will not be performed.         Default: ``False``.  Returns:     dict:         a dictionary containing a whole state of the module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> module.state_dict().keys()     ['bias', 'weight']
+
+    **Parameters**
+
+    - **args**    
+    - **destination**     – defaults to `None`    
+    - **prefix**     – defaults to ``    
+    - **keep_vars**     – defaults to `False`    
+    
+???- note "to"
+
+    Move and/or cast the parameters and buffers.
+
+    This can be called as  .. function:: to(device=None, dtype=None, non_blocking=False)    :noindex:  .. function:: to(dtype, non_blocking=False)    :noindex:  .. function:: to(tensor, non_blocking=False)    :noindex:  .. function:: to(memory_format=torch.channels_last)    :noindex:  Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype`\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.  See below for examples.  .. note::     This method modifies the module in-place.  Args:     device (:class:`torch.device`): the desired device of the parameters         and buffers in this module     dtype (:class:`torch.dtype`): the desired floating point or complex dtype of         the parameters and buffers in this module     tensor (torch.Tensor): Tensor whose dtype and device are the desired         dtype and device for all parameters and buffers in this module     memory_format (:class:`torch.memory_format`): the desired memory         format for 4D parameters and buffers in this module (keyword         only argument)  Returns:     Module: self  Examples::      >>> # xdoctest: +IGNORE_WANT("non-deterministic")     >>> linear = nn.Linear(2, 2)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]])     >>> linear.to(torch.double)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]], dtype=torch.float64)     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)     >>> gpu1 = torch.device("cuda:1")     >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')     >>> cpu = torch.device("cpu")     >>> linear.to(cpu)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16)      >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)     >>> linear.weight     Parameter containing:     tensor([[ 0.3741+0.j,  0.2382+0.j],             [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)     >>> linear(torch.ones(3, 2, dtype=torch.cdouble))     tensor([[0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "to_empty"
+
+    Move the parameters and buffers to the specified device without copying storage.
+
+    Args:     device (:class:`torch.device`): The desired device of the parameters         and buffers in this module.     recurse (bool): Whether parameters and buffers of submodules should         be recursively moved to the specified device.  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, str, torch.device, NoneType]*)    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "train"
+
+    Set the module in training mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  Args:     mode (bool): whether to set training mode (``True``) or evaluation                  mode (``False``). Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **mode**     (*bool*)     – defaults to `True`    
+    
+???- note "type"
+
+    Casts all parameters and buffers to :attr:`dst_type`.
+
+    .. note::     This method modifies the module in-place.  Args:     dst_type (type or string): the desired type  Returns:     Module: self
+
+    **Parameters**
+
+    - **dst_type**     (*Union[torch.dtype, str]*)    
+    
+???- note "xpu"
+
+    Move all model parameters and buffers to the XPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "zero_grad"
+
+    Reset gradients of all model parameters.
+
+    See similar function under :class:`torch.optim.Optimizer` for more context.  Args:     set_to_none (bool): instead of setting to zero, set the grads to None.         See :meth:`torch.optim.Optimizer.zero_grad` for details.
+
+    **Parameters**
+
+    - **set_to_none**     (*bool*)     – defaults to `True`    
+    
diff --git a/docs/api/models/.pages b/docs/api/models/.pages
new file mode 100644
index 0000000..ad7497b
--- /dev/null
+++ b/docs/api/models/.pages
@@ -0,0 +1 @@
+title: models
\ No newline at end of file
diff --git a/docs/api/models/ColBERT.md b/docs/api/models/ColBERT.md
new file mode 100644
index 0000000..963c715
--- /dev/null
+++ b/docs/api/models/ColBERT.md
@@ -0,0 +1,956 @@
+# ColBERT
+
+Loads or creates a ColBERT model that can be used to map sentences / text to multi-vectors embeddings.
+
+
+
+## Parameters
+
+- **model_name_or_path** (*str | None*) – defaults to `None`
+
+    If it is a filepath on disc, it loads the model from that path. If it is not a path, it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model from the Hugging Face Hub with that name.
+
+- **modules** (*Optional[Iterable[torch.nn.modules.module.Module]]*) – defaults to `None`
+
+    A list of torch Modules that should be called sequentially, can be used to create custom SentenceTransformer models from scratch.
+
+- **device** (*str | None*) – defaults to `None`
+
+    Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU can be used.
+
+- **prompts** (*dict[str, str] | None*) – defaults to `None`
+
+    A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text. The prompt text will be prepended before any text to encode. For example: `{"query": "query: ", "passage": "passage: "}` or `{"clustering": "Identify the main category based on the titles in "}`.
+
+- **default_prompt_name** (*str | None*) – defaults to `None`
+
+    The name of the prompt that should be used by default. If not set, no prompt will be applied.
+
+- **similarity_fn_name** (*Union[str, sentence_transformers.similarity_functions.SimilarityFunction, NoneType]*) – defaults to `None`
+
+    The name of the similarity function to use. Valid options are "cosine", "dot", "euclidean", and "manhattan". If not set, it is automatically set to "cosine" if `similarity` or `similarity_pairwise` are called while `model.similarity_fn_name` is still `None`.
+
+- **cache_folder** (*str | None*) – defaults to `None`
+
+    Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
+
+- **trust_remote_code** (*bool*) – defaults to `False`
+
+    Whether or not to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
+
+- **revision** (*str | None*) – defaults to `None`
+
+    The specific model version to use. It can be a branch name, a tag name, or a commit id, for a stored model on Hugging Face.
+
+- **local_files_only** (*bool*) – defaults to `False`
+
+    Whether or not to only look at local files (i.e., do not try to download the model).
+
+- **token** (*bool | str | None*) – defaults to `None`
+
+    Hugging Face authentication token to download private models.
+
+- **use_auth_token** (*bool | str | None*) – defaults to `None`
+
+    Deprecated argument. Please use `token` instead.
+
+- **truncate_dim** (*int | None*) – defaults to `None`
+
+    The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is only applicable during inference when :meth:`SentenceTransformer.encode` is called.
+
+- **embedding_size** (*int | None*) – defaults to `None`
+
+    The output size of the projection layer. Default to 128.
+
+- **bias** (*bool*) – defaults to `False`
+
+- **query_prefix** (*str | None*) – defaults to `[Q] `
+
+    Prefix to add to the queries.
+
+- **document_prefix** (*str | None*) – defaults to `[D] `
+
+    Prefix to add to the documents.
+
+- **add_special_tokens** (*bool*) – defaults to `True`
+
+    Add the prefix to the inputs.
+
+- **truncation** (*bool*) – defaults to `True`
+
+    Truncate the inputs to the encoder max lengths or use sliding window encoding.
+
+- **query_length** (*int | None*) – defaults to `None`
+
+    The length of the query to truncate/pad to with mask tokens. If set, will override the config value. Default to 32.
+
+- **document_length** (*int | None*) – defaults to `None`
+
+    The max length of the document to truncate. If set, will override the config value. Default to 180.
+
+- **attend_to_expansion_tokens** (*bool*) – defaults to `False`
+
+    Whether to attend to the expansion tokens in the attention layers model. If False, the original tokens will not only attend to the expansion tokens, only the expansion tokens will attend to the original tokens. Default is False (as in the original ColBERT codebase).
+
+- **skiplist_words** (*list[str] | None*) – defaults to `None`
+
+- **model_kwargs** (*dict | None*) – defaults to `None`
+
+    Additional model configuration parameters to be passed to the Huggingface Transformers model. Particularly useful options are:  - ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`. The     different options are:          1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified ``dtype``,         ignoring the model's ``config.torch_dtype`` if one exists. If not specified - the model will get         loaded in ``torch.float`` (fp32).          2. ``"auto"`` - A ``torch_dtype`` entry in the ``config.json`` file of the model will be attempted         to be used. If this entry isn't found then next check the ``dtype`` of the first weight in the         checkpoint that's of a floating point type and use that as ``dtype``. This will load the model using         the ``dtype`` it was saved in at the end of the training. It can't be used as an indicator of how the         model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32. - ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of     `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention     <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),     or `"flash_attention_2"` (using `Dao-AILab/flash-attention     <https://github.com/Dao-AILab/flash-attention>`_). By default, if available, SDPA will be used for     torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.  See the `PreTrainedModel.from_pretrained <https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_ documentation for more details.
+
+- **tokenizer_kwargs** (*dict | None*) – defaults to `None`
+
+    Additional tokenizer configuration parameters to be passed to the Huggingface Transformers tokenizer. See the `AutoTokenizer.from_pretrained <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_ documentation for more details.
+
+- **config_kwargs** (*dict | None*) – defaults to `None`
+
+    Additional model configuration parameters to be passed to the Huggingface Transformers config. See the `AutoConfig.from_pretrained <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_ documentation for more details.
+
+- **model_card_data** (*Optional[sentence_transformers.model_card.SentenceTransformerModelCardData]*) – defaults to `None`
+
+    A model card data object that contains information about the model. This is used to generate a model card when saving the model. If not set, a default model card data object is created.
+
+
+## Attributes
+
+- **device**
+
+    Get torch.device from module, assuming that the whole module has one device. In case there are no PyTorch parameters, fall back to CPU.
+
+- **max_seq_length**
+
+    Returns the maximal input sequence length for the model. Longer inputs will be truncated.  Returns:     int: The maximal input sequence length.  Example:     ::          from sentence_transformers import SentenceTransformer          model = SentenceTransformer("all-mpnet-base-v2")         print(model.max_seq_length)         # => 384
+
+- **similarity**
+
+    Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity scores between all embeddings from the first parameter and all embeddings from the second parameter. This differs from `similarity_pairwise` which computes the similarity between each pair of embeddings.  Args:     embeddings1 (Union[Tensor, ndarray]): [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.     embeddings2 (Union[Tensor, ndarray]): [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.  Returns:     Tensor: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.  Example:     ::          >>> model = SentenceTransformer("all-mpnet-base-v2")         >>> sentences = [         ...     "The weather is so nice!",         ...     "It's so sunny outside.",         ...     "He's driving to the movie theater.",         ...     "She's going to the cinema.",         ... ]         >>> embeddings = model.encode(sentences, normalize_embeddings=True)         >>> model.similarity(embeddings, embeddings)         tensor([[1.0000, 0.7235, 0.0290, 0.1309],                 [0.7235, 1.0000, 0.0613, 0.1129],                 [0.0290, 0.0613, 1.0000, 0.5027],                 [0.1309, 0.1129, 0.5027, 1.0000]])         >>> model.similarity_fn_name         "cosine"         >>> model.similarity_fn_name = "euclidean"         >>> model.similarity(embeddings, embeddings)         tensor([[-0.0000, -0.7437, -1.3935, -1.3184],                 [-0.7437, -0.0000, -1.3702, -1.3320],                 [-1.3935, -1.3702, -0.0000, -0.9973],                 [-1.3184, -1.3320, -0.9973, -0.0000]])
+
+- **similarity_fn_name**
+
+    Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.  Returns:     Optional[str]: The name of the similarity function. Can be None if not set, in which case any uses of     :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise` default to "cosine".  Example:     >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")     >>> model.similarity_fn_name     'dot'
+
+- **similarity_pairwise**
+
+    Compute the similarity between two collections of embeddings. The output will be a vector with the similarity scores between each pair of embeddings.  Args:     embeddings1 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.     embeddings2 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.  Returns:     Tensor: A [num_embeddings]-shaped torch tensor with pairwise similarity scores.  Example:     ::          >>> model = SentenceTransformer("all-mpnet-base-v2")         >>> sentences = [         ...     "The weather is so nice!",         ...     "It's so sunny outside.",         ...     "He's driving to the movie theater.",         ...     "She's going to the cinema.",         ... ]         >>> embeddings = model.encode(sentences, normalize_embeddings=True)         >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])         tensor([0.7235, 0.5027])         >>> model.similarity_fn_name         "cosine"         >>> model.similarity_fn_name = "euclidean"         >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])         tensor([-0.7437, -0.9973])
+
+- **tokenizer**
+
+    Property to get the tokenizer that is used by this model
+
+
+## Examples
+
+```python
+>>> from pylate import models
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+...     device="cpu",
+... )
+
+>>> embeddings = model.encode("Hello, how are you?")
+>>> assert isinstance(embeddings, np.ndarray)
+
+>>> embeddings = model.encode([
+...     "Hello, how are you?",
+...     "How is the weather today?"
+... ])
+
+>>> assert len(embeddings) == 2
+>>> assert isinstance(embeddings[0], np.ndarray)
+>>> assert isinstance(embeddings[1], np.ndarray)
+
+>>> embeddings = model.encode([
+...     [
+...         "Hello, how are you?",
+...         "How is the weather today?"
+...     ],
+...     [
+...         "Hello, how are you?",
+...         "How is the weather today?"
+...     ],
+... ])
+
+>>> assert len(embeddings) == 2
+
+>>> model.save_pretrained("test-model")
+
+>>> model = models.ColBERT("test-model")
+
+>>> embeddings = model.encode([
+...     "Hello, how are you?",
+...     "How is the weather today?"
+... ])
+
+>>> assert len(embeddings) == 2
+>>> assert isinstance(embeddings[0], np.ndarray)
+>>> assert isinstance(embeddings[1], np.ndarray)
+```
+
+## Methods
+
+???- note "__call__"
+
+    Call self as a function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "add_module"
+
+    Add a child module to the current module.
+
+    The module can be accessed as an attribute using the given name.  Args:     name (str): name of the child module. The child module can be         accessed from this module using the given name     module (Module): child module to be added to the module.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "append"
+
+    Append a given module to the end.
+
+    Args:     module (nn.Module): module to append
+
+    **Parameters**
+
+    - **module**     (*torch.nn.modules.module.Module*)    
+    
+???- note "apply"
+
+    Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+
+    Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).  Args:     fn (:class:`Module` -> None): function to be applied to each submodule  Returns:     Module: self  Example::      >>> @torch.no_grad()     >>> def init_weights(m):     >>>     print(m)     >>>     if type(m) == nn.Linear:     >>>         m.weight.fill_(1.0)     >>>         print(m.weight)     >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))     >>> net.apply(init_weights)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )
+
+    **Parameters**
+
+    - **fn**     (*Callable[[ForwardRef('Module')], NoneType]*)    
+    
+???- note "bfloat16"
+
+    Casts all floating point parameters and buffers to ``bfloat16`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "buffers"
+
+    Return an iterator over module buffers.
+
+    Args:     recurse (bool): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module.  Yields:     torch.Tensor: module buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for buf in model.buffers():     >>>     print(type(buf), buf.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "children"
+
+    Return an iterator over immediate children modules.
+
+    Yields:     Module: a child module
+
+    
+???- note "compile"
+
+    Compile this Module's forward using :func:`torch.compile`.
+
+    This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`.  See :func:`torch.compile` for details on the arguments for this function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "cpu"
+
+    Move all model parameters and buffers to the CPU.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "cuda"
+
+    Move all model parameters and buffers to the GPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.  .. note::     This method modifies the module in-place.  Args:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+        Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU can be used.
+    
+???- note "double"
+
+    Casts all floating point parameters and buffers to ``double`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "encode"
+
+    Computes sentence embeddings.
+
+    **Parameters**
+
+    - **sentences**     (*str | list[str]*)    
+    - **prompt_name**     (*str | None*)     – defaults to `None`    
+    - **prompt**     (*str | None*)     – defaults to `None`    
+    - **batch_size**     (*int*)     – defaults to `32`    
+    - **show_progress_bar**     (*bool*)     – defaults to `None`    
+    - **precision**     (*Literal['float32', 'int8', 'uint8', 'binary', 'ubinary']*)     – defaults to `float32`    
+    - **convert_to_numpy**     (*bool*)     – defaults to `True`    
+    - **convert_to_tensor**     (*bool*)     – defaults to `False`    
+    - **padding**     (*bool*)     – defaults to `False`    
+    - **device**     (*str*)     – defaults to `None`    
+        Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU can be used.
+    - **normalize_embeddings**     (*bool*)     – defaults to `True`    
+    - **is_query**     (*bool*)     – defaults to `True`    
+    - **pool_factor**     (*int*)     – defaults to `1`    
+    - **protected_tokens**     (*int*)     – defaults to `1`    
+    
+???- note "encode_multi_process"
+
+    Encodes a list of sentences using multiple processes and GPUs via :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`. The sentences are chunked into smaller packages and sent to individual processes, which encode them on different GPUs or CPUs. This method is only suitable for encoding large sets of sentences.
+
+    **Parameters**
+
+    - **sentences**     (*list[str]*)    
+    - **pool**     (*dict[str, object]*)    
+    - **prompt_name**     (*str | None*)     – defaults to `None`    
+    - **prompt**     (*str | None*)     – defaults to `None`    
+    - **batch_size**     (*int*)     – defaults to `32`    
+    - **chunk_size**     (*int*)     – defaults to `None`    
+    - **precision**     (*Literal['float32', 'int8', 'uint8', 'binary', 'ubinary']*)     – defaults to `float32`    
+    - **normalize_embeddings**     (*bool*)     – defaults to `True`    
+    - **padding**     (*bool*)     – defaults to `False`    
+    - **is_query**     (*bool*)     – defaults to `True`    
+    - **pool_factor**     (*int*)     – defaults to `1`    
+    - **protected_tokens**     (*int*)     – defaults to `1`    
+    
+???- note "eval"
+
+    Set the module in evaluation mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.  See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it.  Returns:     Module: self
+
+    
+???- note "evaluate"
+
+    Evaluate the model based on an evaluator
+
+    Args:     evaluator (SentenceEvaluator): The evaluator used to evaluate the model.     output_path (str, optional): The path where the evaluator can write the results. Defaults to None.  Returns:     The evaluation results.
+
+    **Parameters**
+
+    - **evaluator**     (*sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator*)    
+    - **output_path**     (*str*)     – defaults to `None`    
+    
+???- note "extend"
+
+???- note "extra_repr"
+
+    Set the extra representation of the module.
+
+    To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
+
+    
+???- note "fit"
+
+    Deprecated training method from before Sentence Transformers v3.0, it is recommended to use :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method uses :class:`~sentence_transformers.trainer.SentenceTransformerTrainer` behind the scenes, but does not provide as much flexibility as the Trainer itself.
+
+    This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the smallest one to make sure of equal training with each dataset, i.e. round robin sampling.  This method should produce equivalent results in v3.0+ as before v3.0, but if you encounter any issues with your existing training scripts, then you may wish to use :meth:`SentenceTransformer.old_fit <sentence_transformers.SentenceTransformer.old_fit>` instead. That uses the old training method from before v3.0.  Args:     train_objectives: Tuples of (DataLoader, LossFunction). Pass         more than one for multi-task learning     evaluator: An evaluator (sentence_transformers.evaluation)         evaluates the model performance during training on held-         out dev data. It is used to determine the best model         that is saved to disc.     epochs: Number of epochs for training     steps_per_epoch: Number of training steps per epoch. If set         to None (default), one epoch is equal the DataLoader         size from train_objectives.     scheduler: Learning rate scheduler. Available schedulers:         constantlr, warmupconstant, warmuplinear, warmupcosine,         warmupcosinewithhardrestarts     warmup_steps: Behavior depends on the scheduler. For         WarmupLinear (default), the learning rate is increased         from o up to the maximal learning rate. After these many         training steps, the learning rate is decreased linearly         back to zero.     optimizer_class: Optimizer     optimizer_params: Optimizer parameters     weight_decay: Weight decay for model parameters     evaluation_steps: If > 0, evaluate the model using evaluator         after each number of training steps     output_path: Storage path for the model and evaluation files     save_best_model: If true, the best model (according to         evaluator) is stored at output_path     max_grad_norm: Used for gradient normalization.     use_amp: Use Automatic Mixed Precision (AMP). Only for         Pytorch >= 1.6.0     callback: Callback function that is invoked after each         evaluation. It must accept the following three         parameters in this order: `score`, `epoch`, `steps`     show_progress_bar: If True, output a tqdm progress bar     checkpoint_path: Folder to save checkpoints during training     checkpoint_save_steps: Will save a checkpoint after so many         steps     checkpoint_save_total_limit: Total number of checkpoints to         store
+
+    **Parameters**
+
+    - **train_objectives**     (*Iterable[Tuple[torch.utils.data.dataloader.DataLoader, torch.nn.modules.module.Module]]*)    
+    - **evaluator**     (*sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator*)     – defaults to `None`    
+    - **epochs**     (*int*)     – defaults to `1`    
+    - **steps_per_epoch**     – defaults to `None`    
+    - **scheduler**     (*str*)     – defaults to `WarmupLinear`    
+    - **warmup_steps**     (*int*)     – defaults to `10000`    
+    - **optimizer_class**     (*Type[torch.optim.optimizer.Optimizer]*)     – defaults to `<class 'torch.optim.adamw.AdamW'>`    
+    - **optimizer_params**     (*Dict[str, object]*)     – defaults to `{'lr': 2e-05}`    
+    - **weight_decay**     (*float*)     – defaults to `0.01`    
+    - **evaluation_steps**     (*int*)     – defaults to `0`    
+    - **output_path**     (*str*)     – defaults to `None`    
+    - **save_best_model**     (*bool*)     – defaults to `True`    
+    - **max_grad_norm**     (*float*)     – defaults to `1`    
+    - **use_amp**     (*bool*)     – defaults to `False`    
+    - **callback**     (*Callable[[float, int, int], NoneType]*)     – defaults to `None`    
+    - **show_progress_bar**     (*bool*)     – defaults to `True`    
+    - **checkpoint_path**     (*str*)     – defaults to `None`    
+    - **checkpoint_save_steps**     (*int*)     – defaults to `500`    
+    - **checkpoint_save_total_limit**     (*int*)     – defaults to `0`    
+    
+???- note "float"
+
+    Casts all floating point parameters and buffers to ``float`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "forward"
+
+    Define the computation performed at every call.
+
+    Should be overridden by all subclasses.  .. note::     Although the recipe for forward pass needs to be defined within     this function, one should call the :class:`Module` instance afterwards     instead of this since the former takes care of running the     registered hooks while the latter silently ignores them.
+
+    **Parameters**
+
+    - **input**     (*Any*)    
+    
+???- note "get_buffer"
+
+    Return the buffer given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the buffer         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.Tensor: The buffer referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not a         buffer
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_extra_state"
+
+    Return any extra state to include in the module's state_dict.
+
+    Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`.  Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.  Returns:     object: Any extra state to store in the module's state_dict
+
+    
+???- note "get_max_seq_length"
+
+    Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.
+
+    Returns:     Optional[int]: The maximal sequence length that the model accepts, or None if it is not defined.
+
+    
+???- note "get_parameter"
+
+    Return the parameter given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the Parameter         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.nn.Parameter: The Parameter referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Parameter``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_sentence_embedding_dimension"
+
+    Returns the number of dimensions in the output of :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
+
+    Returns:     Optional[int]: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
+
+    
+???- note "get_sentence_features"
+
+???- note "get_submodule"
+
+    Return the submodule given by ``target`` if it exists, otherwise throw an error.
+
+    For example, let's say you have an ``nn.Module`` ``A`` that looks like this:  .. code-block:: text      A(         (net_b): Module(             (net_c): Module(                 (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))             )             (linear): Linear(in_features=100, out_features=200, bias=True)         )     )  (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.)  To check whether or not we have the ``linear`` submodule, we would call ``get_submodule("net_b.linear")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule("net_b.net_c.conv")``.  The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used.  Args:     target: The fully-qualified string name of the submodule         to look for. (See above example for how to specify a         fully-qualified string.)  Returns:     torch.nn.Module: The submodule referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Module``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "gradient_checkpointing_enable"
+
+???- note "half"
+
+    Casts all floating point parameters and buffers to ``half`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "insert"
+
+???- note "insert_prefix_token"
+
+    Inserts a prefix token at the beginning of each sequence in the input tensor.
+
+    **Parameters**
+
+    - **input_ids**     (*torch.Tensor*)    
+    - **prefix_id**     (*int*)    
+    
+???- note "ipu"
+
+    Move all model parameters and buffers to the IPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+        Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU can be used.
+    
+???- note "load"
+
+???- note "load_state_dict"
+
+    Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
+
+    If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.  .. warning::     If :attr:`assign` is ``True`` the optimizer must be created after     the call to :attr:`load_state_dict` unless     :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.  Args:     state_dict (dict): a dict containing parameters and         persistent buffers.     strict (bool, optional): whether to strictly enforce that the keys         in :attr:`state_dict` match the keys returned by this module's         :meth:`~torch.nn.Module.state_dict` function. Default: ``True``     assign (bool, optional): When ``False``, the properties of the tensors         in the current module are preserved while when ``True``, the         properties of the Tensors in the state dict are preserved. The only         exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s         for which the value from the module is preserved.         Default: ``False``  Returns:     ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:         * **missing_keys** is a list of str containing the missing keys         * **unexpected_keys** is a list of str containing the unexpected keys  Note:     If a parameter or buffer is registered as ``None`` and its corresponding key     exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a     ``RuntimeError``.
+
+    **Parameters**
+
+    - **state_dict**     (*Mapping[str, Any]*)    
+    - **strict**     (*bool*)     – defaults to `True`    
+    - **assign**     (*bool*)     – defaults to `False`    
+    
+???- note "modules"
+
+    Return an iterator over all modules in the network.
+
+    Yields:     Module: a module in the network  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.modules()):     ...     print(idx, '->', m)      0 -> Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )     1 -> Linear(in_features=2, out_features=2, bias=True)
+
+    
+???- note "named_buffers"
+
+    Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
+
+    Args:     prefix (str): prefix to prepend to all buffer names.     recurse (bool, optional): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module. Defaults to True.     remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.  Yields:     (str, torch.Tensor): Tuple containing the name and buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, buf in self.named_buffers():     >>>     if name in ['running_var']:     >>>         print(buf.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_children"
+
+    Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
+
+    Yields:     (str, Module): Tuple containing a name and child module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, module in model.named_children():     >>>     if name in ['conv4', 'conv5']:     >>>         print(module)
+
+    
+???- note "named_modules"
+
+    Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
+
+    Args:     memo: a memo to store the set of modules already added to the result     prefix: a prefix that will be added to the name of the module     remove_duplicate: whether to remove the duplicated module instances in the result         or not  Yields:     (str, Module): Tuple of name and module  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.named_modules()):     ...     print(idx, '->', m)      0 -> ('', Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     ))     1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+    **Parameters**
+
+    - **memo**     (*Optional[Set[ForwardRef('Module')]]*)     – defaults to `None`    
+    - **prefix**     (*str*)     – defaults to ``    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_parameters"
+
+    Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
+
+    Args:     prefix (str): prefix to prepend to all parameter names.     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.     remove_duplicate (bool, optional): whether to remove the duplicated         parameters in the result. Defaults to True.  Yields:     (str, Parameter): Tuple containing the name and parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, param in self.named_parameters():     >>>     if name in ['bias']:     >>>         print(param.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "old_fit"
+
+    Deprecated training method from before Sentence Transformers v3.0, it is recommended to use :class:`sentence_transformers.trainer.SentenceTransformerTrainer` instead. This method should only be used if you encounter issues with your existing training scripts after upgrading to v3.0+.
+
+    This training approach uses a list of DataLoaders and Loss functions to train the model. Each DataLoader is sampled in turn for one batch. We sample only as many batches from each DataLoader as there are in the smallest one to make sure of equal training with each dataset, i.e. round robin sampling.  Args:     train_objectives: Tuples of (DataLoader, LossFunction). Pass         more than one for multi-task learning     evaluator: An evaluator (sentence_transformers.evaluation)         evaluates the model performance during training on held-         out dev data. It is used to determine the best model         that is saved to disc.     epochs: Number of epochs for training     steps_per_epoch: Number of training steps per epoch. If set         to None (default), one epoch is equal the DataLoader         size from train_objectives.     scheduler: Learning rate scheduler. Available schedulers:         constantlr, warmupconstant, warmuplinear, warmupcosine,         warmupcosinewithhardrestarts     warmup_steps: Behavior depends on the scheduler. For         WarmupLinear (default), the learning rate is increased         from o up to the maximal learning rate. After these many         training steps, the learning rate is decreased linearly         back to zero.     optimizer_class: Optimizer     optimizer_params: Optimizer parameters     weight_decay: Weight decay for model parameters     evaluation_steps: If > 0, evaluate the model using evaluator         after each number of training steps     output_path: Storage path for the model and evaluation files     save_best_model: If true, the best model (according to         evaluator) is stored at output_path     max_grad_norm: Used for gradient normalization.     use_amp: Use Automatic Mixed Precision (AMP). Only for         Pytorch >= 1.6.0     callback: Callback function that is invoked after each         evaluation. It must accept the following three         parameters in this order: `score`, `epoch`, `steps`     show_progress_bar: If True, output a tqdm progress bar     checkpoint_path: Folder to save checkpoints during training     checkpoint_save_steps: Will save a checkpoint after so many         steps     checkpoint_save_total_limit: Total number of checkpoints to         store
+
+    **Parameters**
+
+    - **train_objectives**     (*Iterable[Tuple[torch.utils.data.dataloader.DataLoader, torch.nn.modules.module.Module]]*)    
+    - **evaluator**     (*sentence_transformers.evaluation.SentenceEvaluator.SentenceEvaluator*)     – defaults to `None`    
+    - **epochs**     (*int*)     – defaults to `1`    
+    - **steps_per_epoch**     – defaults to `None`    
+    - **scheduler**     (*str*)     – defaults to `WarmupLinear`    
+    - **warmup_steps**     (*int*)     – defaults to `10000`    
+    - **optimizer_class**     (*Type[torch.optim.optimizer.Optimizer]*)     – defaults to `<class 'torch.optim.adamw.AdamW'>`    
+    - **optimizer_params**     (*Dict[str, object]*)     – defaults to `{'lr': 2e-05}`    
+    - **weight_decay**     (*float*)     – defaults to `0.01`    
+    - **evaluation_steps**     (*int*)     – defaults to `0`    
+    - **output_path**     (*str*)     – defaults to `None`    
+    - **save_best_model**     (*bool*)     – defaults to `True`    
+    - **max_grad_norm**     (*float*)     – defaults to `1`    
+    - **use_amp**     (*bool*)     – defaults to `False`    
+    - **callback**     (*Callable[[float, int, int], NoneType]*)     – defaults to `None`    
+    - **show_progress_bar**     (*bool*)     – defaults to `True`    
+    - **checkpoint_path**     (*str*)     – defaults to `None`    
+    - **checkpoint_save_steps**     (*int*)     – defaults to `500`    
+    - **checkpoint_save_total_limit**     (*int*)     – defaults to `0`    
+    
+???- note "parameters"
+
+    Return an iterator over module parameters.
+
+    This is typically passed to an optimizer.  Args:     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.  Yields:     Parameter: module parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for param in model.parameters():     >>>     print(type(param), param.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "pool_embeddings_hierarchical"
+
+    Pools the embeddings hierarchically by clustering and averaging them.
+
+    **Parameters**
+
+    - **documents_embeddings**     (*list[torch.Tensor]*)    
+    - **pool_factor**     (*int*)     – defaults to `1`    
+    - **protected_tokens**     (*int*)     – defaults to `1`    
+    
+    **Returns**
+
+    *list*:     A list of pooled embeddings for each document.
+    
+???- note "pop"
+
+???- note "push_to_hub"
+
+    Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.
+
+    Args:     repo_id (str): Repository name for your model in the Hub, including the user or organization.     token (str, optional): An authentication token (See https://huggingface.co/settings/token)     private (bool, optional): Set to true, for hosting a private model     safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way     commit_message (str, optional): Message to commit while pushing.     local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded     exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible     replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card     train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.  Returns:     str: The url of the commit of your model in the repository on the Hugging Face Hub.
+
+    **Parameters**
+
+    - **repo_id**     (*str*)    
+    - **token**     (*Optional[str]*)     – defaults to `None`    
+        Hugging Face authentication token to download private models.
+    - **private**     (*Optional[bool]*)     – defaults to `None`    
+    - **safe_serialization**     (*bool*)     – defaults to `True`    
+    - **commit_message**     (*str*)     – defaults to `Add new SentenceTransformer model.`    
+    - **local_model_path**     (*Optional[str]*)     – defaults to `None`    
+    - **exist_ok**     (*bool*)     – defaults to `False`    
+    - **replace_model_card**     (*bool*)     – defaults to `False`    
+    - **train_datasets**     (*Optional[List[str]]*)     – defaults to `None`    
+    
+???- note "register_backward_hook"
+
+    Register a backward hook on the module.
+
+    This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and the behavior of this function will change in future versions.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    
+???- note "register_buffer"
+
+    Add a buffer to the module.
+
+    This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`.  Buffers can be accessed as attributes using given names.  Args:     name (str): name of the buffer. The buffer can be accessed         from this module using the given name     tensor (Tensor or None): buffer to be registered. If ``None``, then operations         that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,         the buffer is **not** included in the module's :attr:`state_dict`.     persistent (bool): whether the buffer is part of this module's         :attr:`state_dict`.  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> self.register_buffer('running_mean', torch.zeros(num_features))
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **tensor**     (*Optional[torch.Tensor]*)    
+    - **persistent**     (*bool*)     – defaults to `True`    
+    
+???- note "register_forward_hook"
+
+    Register a forward hook on the module.
+
+    The hook will be called every time after :func:`forward` has computed an output.  If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature::      hook(module, args, output) -> None or modified output  If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::      hook(module, args, kwargs, output) -> None or modified output  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If ``True``, the provided ``hook`` will be fired         before all existing ``forward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``forward`` hooks registered with         :func:`register_module_forward_hook` will fire before all hooks         registered by this method.         Default: ``False``     with_kwargs (bool): If ``True``, the ``hook`` will be passed the         kwargs given to the forward function.         Default: ``False``     always_call (bool): If ``True`` the ``hook`` will be run regardless of         whether an exception is raised while calling the Module.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    - **always_call**     (*bool*)     – defaults to `False`    
+    
+???- note "register_forward_pre_hook"
+
+    Register a forward pre-hook on the module.
+
+    The hook will be called every time before :func:`forward` is invoked.  If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::      hook(module, args) -> None or modified input  If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::      hook(module, args, kwargs) -> None or a tuple of modified input and kwargs  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``forward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``forward_pre`` hooks registered with         :func:`register_module_forward_pre_hook` will fire before all         hooks registered by this method.         Default: ``False``     with_kwargs (bool): If true, the ``hook`` will be passed the kwargs         given to the forward function.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_hook"
+
+    Register a backward hook on the module.
+
+    The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::      hook(module, grad_input, grad_output) -> tuple(Tensor) or None  The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs or outputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``backward`` hooks registered with         :func:`register_module_full_backward_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_pre_hook"
+
+    Register a backward pre-hook on the module.
+
+    The hook will be called every time the gradients for the module are computed. The hook should have the following signature::      hook(module, grad_output) -> tuple[Tensor] or None  The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``backward_pre`` hooks registered with         :func:`register_module_full_backward_pre_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_load_state_dict_post_hook"
+
+    Register a post hook to be run after module's ``load_state_dict`` is called.
+
+    It should have the following signature::     hook(module, incompatible_keys) -> None  The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.  The given incompatible_keys can be modified inplace if needed.  Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "register_module"
+
+    Alias for :func:`add_module`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "register_parameter"
+
+    Add a parameter to the module.
+
+    The parameter can be accessed as an attribute using given name.  Args:     name (str): name of the parameter. The parameter can be accessed         from this module using the given name     param (Parameter or None): parameter to be added to the module. If         ``None``, then operations that run on parameters, such as :attr:`cuda`,         are ignored. If ``None``, the parameter is **not** included in the         module's :attr:`state_dict`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **param**     (*Optional[torch.nn.parameter.Parameter]*)    
+    
+???- note "register_state_dict_pre_hook"
+
+    Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
+
+    These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made.
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "requires_grad_"
+
+    Change if autograd should record operations on parameters in this module.
+
+    This method sets the parameters' :attr:`requires_grad` attributes in-place.  This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).  See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it.  Args:     requires_grad (bool): whether autograd should record operations on                           parameters in this module. Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **requires_grad**     (*bool*)     – defaults to `True`    
+    
+???- note "save"
+
+    Saves a model and its configuration files to a directory, so that it can be loaded with ``SentenceTransformer(path)`` again.
+
+    Args:     path (str): Path on disc where the model will be saved.     model_name (str, optional): Optional model name.     create_model_card (bool, optional): If True, create a README.md with basic information about this model.     train_datasets (list[str], optional): Optional list with the names of the datasets used to train the model.     safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model         the traditional (but unsafe) PyTorch way.
+
+    **Parameters**
+
+    - **path**     (*str*)    
+    - **model_name**     (*str | None*)     – defaults to `None`    
+    - **create_model_card**     (*bool*)     – defaults to `True`    
+    - **train_datasets**     (*list[str] | None*)     – defaults to `None`    
+    - **safe_serialization**     (*bool*)     – defaults to `True`    
+    
+???- note "save_pretrained"
+
+    Saves a model and its configuration files to a directory, so that it can be loaded with ``SentenceTransformer(path)`` again.
+
+    Args:     path (str): Path on disc where the model will be saved.     model_name (str, optional): Optional model name.     create_model_card (bool, optional): If True, create a README.md with basic information about this model.     train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.     safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model         the traditional (but unsafe) PyTorch way.
+
+    **Parameters**
+
+    - **path**     (*str*)    
+    - **model_name**     (*Optional[str]*)     – defaults to `None`    
+    - **create_model_card**     (*bool*)     – defaults to `True`    
+    - **train_datasets**     (*Optional[List[str]]*)     – defaults to `None`    
+    - **safe_serialization**     (*bool*)     – defaults to `True`    
+    
+???- note "save_to_hub"
+
+    DEPRECATED, use `push_to_hub` instead.
+
+    Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.  Args:     repo_id (str): Repository name for your model in the Hub, including the user or organization.     token (str, optional): An authentication token (See https://huggingface.co/settings/token)     private (bool, optional): Set to true, for hosting a private model     safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way     commit_message (str, optional): Message to commit while pushing.     local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded     exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible     replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card     train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.  Returns:     str: The url of the commit of your model in the repository on the Hugging Face Hub.
+
+    **Parameters**
+
+    - **repo_id**     (*str*)    
+    - **organization**     (*Optional[str]*)     – defaults to `None`    
+    - **token**     (*Optional[str]*)     – defaults to `None`    
+        Hugging Face authentication token to download private models.
+    - **private**     (*Optional[bool]*)     – defaults to `None`    
+    - **safe_serialization**     (*bool*)     – defaults to `True`    
+    - **commit_message**     (*str*)     – defaults to `Add new SentenceTransformer model.`    
+    - **local_model_path**     (*Optional[str]*)     – defaults to `None`    
+    - **exist_ok**     (*bool*)     – defaults to `False`    
+    - **replace_model_card**     (*bool*)     – defaults to `False`    
+    - **train_datasets**     (*Optional[List[str]]*)     – defaults to `None`    
+    
+???- note "set_extra_state"
+
+    Set extra state contained in the loaded `state_dict`.
+
+    This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`.  Args:     state (dict): Extra state from the `state_dict`
+
+    **Parameters**
+
+    - **state**     (*Any*)    
+    
+???- note "set_pooling_include_prompt"
+
+    Sets the `include_prompt` attribute in the pooling layer in the model, if there is one.
+
+    This is useful for INSTRUCTOR models, as the prompt should be excluded from the pooling strategy for these models.  Args:     include_prompt (bool): Whether to include the prompt in the pooling layer.  Returns:     None
+
+    **Parameters**
+
+    - **include_prompt**     (*bool*)    
+    
+???- note "share_memory"
+
+    See :meth:`torch.Tensor.share_memory_`.
+
+    
+???- note "skiplist_mask"
+
+    Create a mask for the set of input_ids that are in the skiplist.
+
+    **Parameters**
+
+    - **input_ids**     (*torch.Tensor*)    
+    - **skiplist**     (*list[int]*)    
+    
+???- note "smart_batching_collate"
+
+    Transforms a batch from a SmartBatchingDataset to a batch of tensors for the model Here, batch is a list of InputExample instances: [InputExample(...), ...]
+
+    Args:     batch: a batch from a SmartBatchingDataset  Returns:     a batch of tensors for the model
+
+    **Parameters**
+
+    - **batch**     (*List[ForwardRef('InputExample')]*)    
+    
+???- note "start_multi_process_pool"
+
+    Starts a multi-process pool to process the encoding with several independent processes. This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised to start only one process per GPU. This method works together with encode_multi_process and stop_multi_process_pool.
+
+    **Parameters**
+
+    - **target_devices**     (*list[str]*)     – defaults to `None`    
+    
+    **Returns**
+
+    *dict*:     A dictionary with the target processes, an input queue, and an output queue.
+    
+???- note "state_dict"
+
+    Return a dictionary containing references to the whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included.  .. note::     The returned object is a shallow copy. It contains references     to the module's parameters and buffers.  .. warning::     Currently ``state_dict()`` also accepts positional arguments for     ``destination``, ``prefix`` and ``keep_vars`` in order. However,     this is being deprecated and keyword arguments will be enforced in     future releases.  .. warning::     Please avoid the use of argument ``destination`` as it is not     designed for end-users.  Args:     destination (dict, optional): If provided, the state of module will         be updated into the dict and the same object is returned.         Otherwise, an ``OrderedDict`` will be created and returned.         Default: ``None``.     prefix (str, optional): a prefix added to parameter and buffer         names to compose the keys in state_dict. Default: ``''``.     keep_vars (bool, optional): by default the :class:`~torch.Tensor` s         returned in the state dict are detached from autograd. If it's         set to ``True``, detaching will not be performed.         Default: ``False``.  Returns:     dict:         a dictionary containing a whole state of the module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> module.state_dict().keys()     ['bias', 'weight']
+
+    **Parameters**
+
+    - **args**    
+    - **destination**     – defaults to `None`    
+    - **prefix**     – defaults to ``    
+    - **keep_vars**     – defaults to `False`    
+    
+???- note "stop_multi_process_pool"
+
+    Stops all processes started with start_multi_process_pool.
+
+    Args:     pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.  Returns:     None
+
+    - **pool**     (*Dict[Literal['input', 'output', 'processes'], Any]*)    
+    
+???- note "to"
+
+    Move and/or cast the parameters and buffers.
+
+    This can be called as  .. function:: to(device=None, dtype=None, non_blocking=False)    :noindex:  .. function:: to(dtype, non_blocking=False)    :noindex:  .. function:: to(tensor, non_blocking=False)    :noindex:  .. function:: to(memory_format=torch.channels_last)    :noindex:  Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype`\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.  See below for examples.  .. note::     This method modifies the module in-place.  Args:     device (:class:`torch.device`): the desired device of the parameters         and buffers in this module     dtype (:class:`torch.dtype`): the desired floating point or complex dtype of         the parameters and buffers in this module     tensor (torch.Tensor): Tensor whose dtype and device are the desired         dtype and device for all parameters and buffers in this module     memory_format (:class:`torch.memory_format`): the desired memory         format for 4D parameters and buffers in this module (keyword         only argument)  Returns:     Module: self  Examples::      >>> # xdoctest: +IGNORE_WANT("non-deterministic")     >>> linear = nn.Linear(2, 2)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]])     >>> linear.to(torch.double)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]], dtype=torch.float64)     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)     >>> gpu1 = torch.device("cuda:1")     >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')     >>> cpu = torch.device("cpu")     >>> linear.to(cpu)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16)      >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)     >>> linear.weight     Parameter containing:     tensor([[ 0.3741+0.j,  0.2382+0.j],             [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)     >>> linear(torch.ones(3, 2, dtype=torch.cdouble))     tensor([[0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "to_empty"
+
+    Move the parameters and buffers to the specified device without copying storage.
+
+    Args:     device (:class:`torch.device`): The desired device of the parameters         and buffers in this module.     recurse (bool): Whether parameters and buffers of submodules should         be recursively moved to the specified device.  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, str, torch.device, NoneType]*)    
+        Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU can be used.
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "tokenize"
+
+    Tokenizes the input texts.
+
+    Args:     texts (Union[list[str], list[dict], list[tuple[str, str]]]): A list of texts to be tokenized.     is_query (bool): Flag to indicate if the texts are queries. Defaults to True.     pad_document (bool): Flag to indicate if documents should be padded to max length. Defaults to False.  Returns:     dict[str, torch.Tensor]: A dictionary of tensors with the tokenized texts, including "input_ids",         "attention_mask", and optionally "token_type_ids".
+
+    **Parameters**
+
+    - **texts**     (*list[str] | list[dict] | list[tuple[str, str]]*)    
+    - **is_query**     (*bool*)     – defaults to `True`    
+    - **pad_document**     (*bool*)     – defaults to `False`    
+    
+???- note "train"
+
+    Set the module in training mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  Args:     mode (bool): whether to set training mode (``True``) or evaluation                  mode (``False``). Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **mode**     (*bool*)     – defaults to `True`    
+    
+???- note "truncate_sentence_embeddings"
+
+    In this context, :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>` outputs sentence embeddings truncated at dimension ``truncate_dim``.
+
+    This may be useful when you are using the same model for different applications where different dimensions are needed.  Args:     truncate_dim (int, optional): The dimension to truncate sentence embeddings to. ``None`` does no truncation.  Example:     ::          from sentence_transformers import SentenceTransformer          model = SentenceTransformer("all-mpnet-base-v2")          with model.truncate_sentence_embeddings(truncate_dim=16):             embeddings_truncated = model.encode(["hello there", "hiya"])         assert embeddings_truncated.shape[-1] == 16
+
+    **Parameters**
+
+    - **truncate_dim**     (*Optional[int]*)    
+        The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is only applicable during inference when :meth:`SentenceTransformer.encode` is called.
+    
+???- note "type"
+
+    Casts all parameters and buffers to :attr:`dst_type`.
+
+    .. note::     This method modifies the module in-place.  Args:     dst_type (type or string): the desired type  Returns:     Module: self
+
+    **Parameters**
+
+    - **dst_type**     (*Union[torch.dtype, str]*)    
+    
+???- note "xpu"
+
+    Move all model parameters and buffers to the XPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+        Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU can be used.
+    
+???- note "zero_grad"
+
+    Reset gradients of all model parameters.
+
+    See similar function under :class:`torch.optim.Optimizer` for more context.  Args:     set_to_none (bool): instead of setting to zero, set the grads to None.         See :meth:`torch.optim.Optimizer.zero_grad` for details.
+
+    **Parameters**
+
+    - **set_to_none**     (*bool*)     – defaults to `True`    
+    
diff --git a/docs/api/models/Dense.md b/docs/api/models/Dense.md
new file mode 100644
index 0000000..b9007dd
--- /dev/null
+++ b/docs/api/models/Dense.md
@@ -0,0 +1,525 @@
+# Dense
+
+Performs linear projection on the token embeddings to a lower dimension.
+
+
+
+## Parameters
+
+- **in_features** (*int*)
+
+    Size of the embeddings in output of the tansformer.
+
+- **out_features** (*int*)
+
+    Size of the output embeddings after linear projection
+
+- **bias** (*bool*) – defaults to `True`
+
+    Add a bias vector
+
+- **activation_function** – defaults to `Identity()`
+
+- **init_weight** (*torch.Tensor*) – defaults to `None`
+
+    Initial value for the matrix of the linear layer
+
+- **init_bias** (*torch.Tensor*) – defaults to `None`
+
+    Initial value for the bias of the linear layer.
+
+
+
+## Examples
+
+```python
+>>> from pylate import models
+
+>>> model = models.Dense(
+...     in_features=768,
+...     out_features=128,
+... )
+
+>>> features = {
+...     "token_embeddings": torch.randn(2, 768),
+... }
+
+>>> projected_features = model(features)
+
+>>> assert projected_features["token_embeddings"].shape == (2, 128)
+>>> assert isinstance(model, DenseSentenceTransformer)
+```
+
+## Methods
+
+???- note "__call__"
+
+    Call self as a function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "add_module"
+
+    Add a child module to the current module.
+
+    The module can be accessed as an attribute using the given name.  Args:     name (str): name of the child module. The child module can be         accessed from this module using the given name     module (Module): child module to be added to the module.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "apply"
+
+    Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
+
+    Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`).  Args:     fn (:class:`Module` -> None): function to be applied to each submodule  Returns:     Module: self  Example::      >>> @torch.no_grad()     >>> def init_weights(m):     >>>     print(m)     >>>     if type(m) == nn.Linear:     >>>         m.weight.fill_(1.0)     >>>         print(m.weight)     >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))     >>> net.apply(init_weights)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Linear(in_features=2, out_features=2, bias=True)     Parameter containing:     tensor([[1., 1.],             [1., 1.]], requires_grad=True)     Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )
+
+    **Parameters**
+
+    - **fn**     (*Callable[[ForwardRef('Module')], NoneType]*)    
+    
+???- note "bfloat16"
+
+    Casts all floating point parameters and buffers to ``bfloat16`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "buffers"
+
+    Return an iterator over module buffers.
+
+    Args:     recurse (bool): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module.  Yields:     torch.Tensor: module buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for buf in model.buffers():     >>>     print(type(buf), buf.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "children"
+
+    Return an iterator over immediate children modules.
+
+    Yields:     Module: a child module
+
+    
+???- note "compile"
+
+    Compile this Module's forward using :func:`torch.compile`.
+
+    This Module's `__call__` method is compiled and all arguments are passed as-is to :func:`torch.compile`.  See :func:`torch.compile` for details on the arguments for this function.
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "cpu"
+
+    Move all model parameters and buffers to the CPU.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "cuda"
+
+    Move all model parameters and buffers to the GPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on GPU while being optimized.  .. note::     This method modifies the module in-place.  Args:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "double"
+
+    Casts all floating point parameters and buffers to ``double`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "eval"
+
+    Set the module in evaluation mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  This is equivalent with :meth:`self.train(False) <torch.nn.Module.train>`.  See :ref:`locally-disable-grad-doc` for a comparison between `.eval()` and several similar mechanisms that may be confused with it.  Returns:     Module: self
+
+    
+???- note "extra_repr"
+
+    Set the extra representation of the module.
+
+    To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
+
+    
+???- note "float"
+
+    Casts all floating point parameters and buffers to ``float`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "forward"
+
+    Performs linear projection on the token embeddings.
+
+    **Parameters**
+
+    - **features**     (*dict[str, torch.Tensor]*)    
+    
+???- note "from_sentence_transformers"
+
+    Converts a SentenceTransformer Dense model to a Dense model. Our Dense model does not have the activation function.
+
+    - **dense**     (*sentence_transformers.models.Dense.Dense*)    
+    
+???- note "get_buffer"
+
+    Return the buffer given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the buffer         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.Tensor: The buffer referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not a         buffer
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_config_dict"
+
+???- note "get_extra_state"
+
+    Return any extra state to include in the module's state_dict.
+
+    Implement this and a corresponding :func:`set_extra_state` for your module if you need to store extra state. This function is called when building the module's `state_dict()`.  Note that extra state should be picklable to ensure working serialization of the state_dict. We only provide provide backwards compatibility guarantees for serializing Tensors; other objects may break backwards compatibility if their serialized pickled form changes.  Returns:     object: Any extra state to store in the module's state_dict
+
+    
+???- note "get_parameter"
+
+    Return the parameter given by ``target`` if it exists, otherwise throw an error.
+
+    See the docstring for ``get_submodule`` for a more detailed explanation of this method's functionality as well as how to correctly specify ``target``.  Args:     target: The fully-qualified string name of the Parameter         to look for. (See ``get_submodule`` for how to specify a         fully-qualified string.)  Returns:     torch.nn.Parameter: The Parameter referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Parameter``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "get_sentence_embedding_dimension"
+
+???- note "get_submodule"
+
+    Return the submodule given by ``target`` if it exists, otherwise throw an error.
+
+    For example, let's say you have an ``nn.Module`` ``A`` that looks like this:  .. code-block:: text      A(         (net_b): Module(             (net_c): Module(                 (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))             )             (linear): Linear(in_features=100, out_features=200, bias=True)         )     )  (The diagram shows an ``nn.Module`` ``A``. ``A`` has a nested submodule ``net_b``, which itself has two submodules ``net_c`` and ``linear``. ``net_c`` then has a submodule ``conv``.)  To check whether or not we have the ``linear`` submodule, we would call ``get_submodule("net_b.linear")``. To check whether we have the ``conv`` submodule, we would call ``get_submodule("net_b.net_c.conv")``.  The runtime of ``get_submodule`` is bounded by the degree of module nesting in ``target``. A query against ``named_modules`` achieves the same result, but it is O(N) in the number of transitive modules. So, for a simple check to see if some submodule exists, ``get_submodule`` should always be used.  Args:     target: The fully-qualified string name of the submodule         to look for. (See above example for how to specify a         fully-qualified string.)  Returns:     torch.nn.Module: The submodule referenced by ``target``  Raises:     AttributeError: If the target string references an invalid         path or resolves to something that is not an         ``nn.Module``
+
+    **Parameters**
+
+    - **target**     (*str*)    
+    
+???- note "half"
+
+    Casts all floating point parameters and buffers to ``half`` datatype.
+
+    .. note::     This method modifies the module in-place.  Returns:     Module: self
+
+    
+???- note "ipu"
+
+    Move all model parameters and buffers to the IPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on IPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "load"
+
+    Load a Dense layer.
+
+    - **input_path**    
+    
+???- note "load_state_dict"
+
+    Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
+
+    If :attr:`strict` is ``True``, then the keys of :attr:`state_dict` must exactly match the keys returned by this module's :meth:`~torch.nn.Module.state_dict` function.  .. warning::     If :attr:`assign` is ``True`` the optimizer must be created after     the call to :attr:`load_state_dict` unless     :func:`~torch.__future__.get_swap_module_params_on_conversion` is ``True``.  Args:     state_dict (dict): a dict containing parameters and         persistent buffers.     strict (bool, optional): whether to strictly enforce that the keys         in :attr:`state_dict` match the keys returned by this module's         :meth:`~torch.nn.Module.state_dict` function. Default: ``True``     assign (bool, optional): When ``False``, the properties of the tensors         in the current module are preserved while when ``True``, the         properties of the Tensors in the state dict are preserved. The only         exception is the ``requires_grad`` field of :class:`~torch.nn.Parameter`s         for which the value from the module is preserved.         Default: ``False``  Returns:     ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:         * **missing_keys** is a list of str containing the missing keys         * **unexpected_keys** is a list of str containing the unexpected keys  Note:     If a parameter or buffer is registered as ``None`` and its corresponding key     exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a     ``RuntimeError``.
+
+    **Parameters**
+
+    - **state_dict**     (*Mapping[str, Any]*)    
+    - **strict**     (*bool*)     – defaults to `True`    
+    - **assign**     (*bool*)     – defaults to `False`    
+    
+???- note "modules"
+
+    Return an iterator over all modules in the network.
+
+    Yields:     Module: a module in the network  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.modules()):     ...     print(idx, '->', m)      0 -> Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     )     1 -> Linear(in_features=2, out_features=2, bias=True)
+
+    
+???- note "named_buffers"
+
+    Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
+
+    Args:     prefix (str): prefix to prepend to all buffer names.     recurse (bool, optional): if True, then yields buffers of this module         and all submodules. Otherwise, yields only buffers that         are direct members of this module. Defaults to True.     remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.  Yields:     (str, torch.Tensor): Tuple containing the name and buffer  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, buf in self.named_buffers():     >>>     if name in ['running_var']:     >>>         print(buf.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_children"
+
+    Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
+
+    Yields:     (str, Module): Tuple containing a name and child module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, module in model.named_children():     >>>     if name in ['conv4', 'conv5']:     >>>         print(module)
+
+    
+???- note "named_modules"
+
+    Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
+
+    Args:     memo: a memo to store the set of modules already added to the result     prefix: a prefix that will be added to the name of the module     remove_duplicate: whether to remove the duplicated module instances in the result         or not  Yields:     (str, Module): Tuple of name and module  Note:     Duplicate modules are returned only once. In the following     example, ``l`` will be returned only once.  Example::      >>> l = nn.Linear(2, 2)     >>> net = nn.Sequential(l, l)     >>> for idx, m in enumerate(net.named_modules()):     ...     print(idx, '->', m)      0 -> ('', Sequential(       (0): Linear(in_features=2, out_features=2, bias=True)       (1): Linear(in_features=2, out_features=2, bias=True)     ))     1 -> ('0', Linear(in_features=2, out_features=2, bias=True))
+
+    **Parameters**
+
+    - **memo**     (*Optional[Set[ForwardRef('Module')]]*)     – defaults to `None`    
+    - **prefix**     (*str*)     – defaults to ``    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "named_parameters"
+
+    Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
+
+    Args:     prefix (str): prefix to prepend to all parameter names.     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.     remove_duplicate (bool, optional): whether to remove the duplicated         parameters in the result. Defaults to True.  Yields:     (str, Parameter): Tuple containing the name and parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for name, param in self.named_parameters():     >>>     if name in ['bias']:     >>>         print(param.size())
+
+    **Parameters**
+
+    - **prefix**     (*str*)     – defaults to ``    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    - **remove_duplicate**     (*bool*)     – defaults to `True`    
+    
+???- note "parameters"
+
+    Return an iterator over module parameters.
+
+    This is typically passed to an optimizer.  Args:     recurse (bool): if True, then yields parameters of this module         and all submodules. Otherwise, yields only parameters that         are direct members of this module.  Yields:     Parameter: module parameter  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> for param in model.parameters():     >>>     print(type(param), param.size())     <class 'torch.Tensor'> (20L,)     <class 'torch.Tensor'> (20L, 1L, 5L, 5L)
+
+    **Parameters**
+
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "register_backward_hook"
+
+    Register a backward hook on the module.
+
+    This function is deprecated in favor of :meth:`~torch.nn.Module.register_full_backward_hook` and the behavior of this function will change in future versions.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    
+???- note "register_buffer"
+
+    Add a buffer to the module.
+
+    This is typically used to register a buffer that should not to be considered a model parameter. For example, BatchNorm's ``running_mean`` is not a parameter, but is part of the module's state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting :attr:`persistent` to ``False``. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module's :attr:`state_dict`.  Buffers can be accessed as attributes using given names.  Args:     name (str): name of the buffer. The buffer can be accessed         from this module using the given name     tensor (Tensor or None): buffer to be registered. If ``None``, then operations         that run on buffers, such as :attr:`cuda`, are ignored. If ``None``,         the buffer is **not** included in the module's :attr:`state_dict`.     persistent (bool): whether the buffer is part of this module's         :attr:`state_dict`.  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> self.register_buffer('running_mean', torch.zeros(num_features))
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **tensor**     (*Optional[torch.Tensor]*)    
+    - **persistent**     (*bool*)     – defaults to `True`    
+    
+???- note "register_forward_hook"
+
+    Register a forward hook on the module.
+
+    The hook will be called every time after :func:`forward` has computed an output.  If ``with_kwargs`` is ``False`` or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the output. It can modify the input inplace but it will not have effect on forward since this is called after :func:`forward` is called. The hook should have the following signature::      hook(module, args, output) -> None or modified output  If ``with_kwargs`` is ``True``, the forward hook will be passed the ``kwargs`` given to the forward function and be expected to return the output possibly modified. The hook should have the following signature::      hook(module, args, kwargs, output) -> None or modified output  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If ``True``, the provided ``hook`` will be fired         before all existing ``forward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``forward`` hooks registered with         :func:`register_module_forward_hook` will fire before all hooks         registered by this method.         Default: ``False``     with_kwargs (bool): If ``True``, the ``hook`` will be passed the         kwargs given to the forward function.         Default: ``False``     always_call (bool): If ``True`` the ``hook`` will be run regardless of         whether an exception is raised while calling the Module.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...], Any], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    - **always_call**     (*bool*)     – defaults to `False`    
+    
+???- note "register_forward_pre_hook"
+
+    Register a forward pre-hook on the module.
+
+    The hook will be called every time before :func:`forward` is invoked.  If ``with_kwargs`` is false or not specified, the input contains only the positional arguments given to the module. Keyword arguments won't be passed to the hooks and only to the ``forward``. The hook can modify the input. User can either return a tuple or a single modified value in the hook. We will wrap the value into a tuple if a single value is returned (unless that value is already a tuple). The hook should have the following signature::      hook(module, args) -> None or modified input  If ``with_kwargs`` is true, the forward pre-hook will be passed the kwargs given to the forward function. And if the hook modifies the input, both the args and kwargs should be returned. The hook should have the following signature::      hook(module, args, kwargs) -> None or a tuple of modified input and kwargs  Args:     hook (Callable): The user defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``forward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``forward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``forward_pre`` hooks registered with         :func:`register_module_forward_pre_hook` will fire before all         hooks registered by this method.         Default: ``False``     with_kwargs (bool): If true, the ``hook`` will be passed the kwargs         given to the forward function.         Default: ``False``  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Union[Callable[[~T, Tuple[Any, ...]], Optional[Any]], Callable[[~T, Tuple[Any, ...], Dict[str, Any]], Optional[Tuple[Any, Dict[str, Any]]]]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    - **with_kwargs**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_hook"
+
+    Register a backward hook on the module.
+
+    The hook will be called every time the gradients with respect to a module are computed, i.e. the hook will execute if and only if the gradients with respect to module outputs are computed. The hook should have the following signature::      hook(module, grad_input, grad_output) -> tuple(Tensor) or None  The :attr:`grad_input` and :attr:`grad_output` are tuples that contain the gradients with respect to the inputs and outputs respectively. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the input that will be used in place of :attr:`grad_input` in subsequent computations. :attr:`grad_input` will only correspond to the inputs given as positional arguments and all kwarg arguments are ignored. Entries in :attr:`grad_input` and :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs or outputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward`` hooks on         this :class:`torch.nn.modules.Module`. Note that global         ``backward`` hooks registered with         :func:`register_module_full_backward_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor], Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_full_backward_pre_hook"
+
+    Register a backward pre-hook on the module.
+
+    The hook will be called every time the gradients for the module are computed. The hook should have the following signature::      hook(module, grad_output) -> tuple[Tensor] or None  The :attr:`grad_output` is a tuple. The hook should not modify its arguments, but it can optionally return a new gradient with respect to the output that will be used in place of :attr:`grad_output` in subsequent computations. Entries in :attr:`grad_output` will be ``None`` for all non-Tensor arguments.  For technical reasons, when this hook is applied to a Module, its forward function will receive a view of each Tensor passed to the Module. Similarly the caller will receive a view of each Tensor returned by the Module's forward function.  .. warning ::     Modifying inputs inplace is not allowed when using backward hooks and     will raise an error.  Args:     hook (Callable): The user-defined hook to be registered.     prepend (bool): If true, the provided ``hook`` will be fired before         all existing ``backward_pre`` hooks on this         :class:`torch.nn.modules.Module`. Otherwise, the provided         ``hook`` will be fired after all existing ``backward_pre`` hooks         on this :class:`torch.nn.modules.Module`. Note that global         ``backward_pre`` hooks registered with         :func:`register_module_full_backward_pre_hook` will fire before         all hooks registered by this method.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**     (*Callable[[ForwardRef('Module'), Union[Tuple[torch.Tensor, ...], torch.Tensor]], Union[NoneType, Tuple[torch.Tensor, ...], torch.Tensor]]*)    
+    - **prepend**     (*bool*)     – defaults to `False`    
+    
+???- note "register_load_state_dict_post_hook"
+
+    Register a post hook to be run after module's ``load_state_dict`` is called.
+
+    It should have the following signature::     hook(module, incompatible_keys) -> None  The ``module`` argument is the current module that this hook is registered on, and the ``incompatible_keys`` argument is a ``NamedTuple`` consisting of attributes ``missing_keys`` and ``unexpected_keys``. ``missing_keys`` is a ``list`` of ``str`` containing the missing keys and ``unexpected_keys`` is a ``list`` of ``str`` containing the unexpected keys.  The given incompatible_keys can be modified inplace if needed.  Note that the checks performed when calling :func:`load_state_dict` with ``strict=True`` are affected by modifications the hook makes to ``missing_keys`` or ``unexpected_keys``, as expected. Additions to either set of keys will result in an error being thrown when ``strict=True``, and clearing out both missing and unexpected keys will avoid an error.  Returns:     :class:`torch.utils.hooks.RemovableHandle`:         a handle that can be used to remove the added hook by calling         ``handle.remove()``
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "register_module"
+
+    Alias for :func:`add_module`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **module**     (*Optional[ForwardRef('Module')]*)    
+    
+???- note "register_parameter"
+
+    Add a parameter to the module.
+
+    The parameter can be accessed as an attribute using given name.  Args:     name (str): name of the parameter. The parameter can be accessed         from this module using the given name     param (Parameter or None): parameter to be added to the module. If         ``None``, then operations that run on parameters, such as :attr:`cuda`,         are ignored. If ``None``, the parameter is **not** included in the         module's :attr:`state_dict`.
+
+    **Parameters**
+
+    - **name**     (*str*)    
+    - **param**     (*Optional[torch.nn.parameter.Parameter]*)    
+    
+???- note "register_state_dict_pre_hook"
+
+    Register a pre-hook for the :meth:`~torch.nn.Module.state_dict` method.
+
+    These hooks will be called with arguments: ``self``, ``prefix``, and ``keep_vars`` before calling ``state_dict`` on ``self``. The registered hooks can be used to perform pre-processing before the ``state_dict`` call is made.
+
+    **Parameters**
+
+    - **hook**    
+    
+???- note "requires_grad_"
+
+    Change if autograd should record operations on parameters in this module.
+
+    This method sets the parameters' :attr:`requires_grad` attributes in-place.  This method is helpful for freezing part of the module for finetuning or training parts of a model individually (e.g., GAN training).  See :ref:`locally-disable-grad-doc` for a comparison between `.requires_grad_()` and several similar mechanisms that may be confused with it.  Args:     requires_grad (bool): whether autograd should record operations on                           parameters in this module. Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **requires_grad**     (*bool*)     – defaults to `True`    
+    
+???- note "save"
+
+???- note "set_extra_state"
+
+    Set extra state contained in the loaded `state_dict`.
+
+    This function is called from :func:`load_state_dict` to handle any extra state found within the `state_dict`. Implement this function and a corresponding :func:`get_extra_state` for your module if you need to store extra state within its `state_dict`.  Args:     state (dict): Extra state from the `state_dict`
+
+    **Parameters**
+
+    - **state**     (*Any*)    
+    
+???- note "share_memory"
+
+    See :meth:`torch.Tensor.share_memory_`.
+
+    
+???- note "state_dict"
+
+    Return a dictionary containing references to the whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to ``None`` are not included.  .. note::     The returned object is a shallow copy. It contains references     to the module's parameters and buffers.  .. warning::     Currently ``state_dict()`` also accepts positional arguments for     ``destination``, ``prefix`` and ``keep_vars`` in order. However,     this is being deprecated and keyword arguments will be enforced in     future releases.  .. warning::     Please avoid the use of argument ``destination`` as it is not     designed for end-users.  Args:     destination (dict, optional): If provided, the state of module will         be updated into the dict and the same object is returned.         Otherwise, an ``OrderedDict`` will be created and returned.         Default: ``None``.     prefix (str, optional): a prefix added to parameter and buffer         names to compose the keys in state_dict. Default: ``''``.     keep_vars (bool, optional): by default the :class:`~torch.Tensor` s         returned in the state dict are detached from autograd. If it's         set to ``True``, detaching will not be performed.         Default: ``False``.  Returns:     dict:         a dictionary containing a whole state of the module  Example::      >>> # xdoctest: +SKIP("undefined vars")     >>> module.state_dict().keys()     ['bias', 'weight']
+
+    **Parameters**
+
+    - **args**    
+    - **destination**     – defaults to `None`    
+    - **prefix**     – defaults to ``    
+    - **keep_vars**     – defaults to `False`    
+    
+???- note "to"
+
+    Move and/or cast the parameters and buffers.
+
+    This can be called as  .. function:: to(device=None, dtype=None, non_blocking=False)    :noindex:  .. function:: to(dtype, non_blocking=False)    :noindex:  .. function:: to(tensor, non_blocking=False)    :noindex:  .. function:: to(memory_format=torch.channels_last)    :noindex:  Its signature is similar to :meth:`torch.Tensor.to`, but only accepts floating point or complex :attr:`dtype`\ s. In addition, this method will only cast the floating point or complex parameters and buffers to :attr:`dtype` (if given). The integral parameters and buffers will be moved :attr:`device`, if that is given, but with dtypes unchanged. When :attr:`non_blocking` is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.  See below for examples.  .. note::     This method modifies the module in-place.  Args:     device (:class:`torch.device`): the desired device of the parameters         and buffers in this module     dtype (:class:`torch.dtype`): the desired floating point or complex dtype of         the parameters and buffers in this module     tensor (torch.Tensor): Tensor whose dtype and device are the desired         dtype and device for all parameters and buffers in this module     memory_format (:class:`torch.memory_format`): the desired memory         format for 4D parameters and buffers in this module (keyword         only argument)  Returns:     Module: self  Examples::      >>> # xdoctest: +IGNORE_WANT("non-deterministic")     >>> linear = nn.Linear(2, 2)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]])     >>> linear.to(torch.double)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1913, -0.3420],             [-0.5113, -0.2325]], dtype=torch.float64)     >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA1)     >>> gpu1 = torch.device("cuda:1")     >>> linear.to(gpu1, dtype=torch.half, non_blocking=True)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16, device='cuda:1')     >>> cpu = torch.device("cpu")     >>> linear.to(cpu)     Linear(in_features=2, out_features=2, bias=True)     >>> linear.weight     Parameter containing:     tensor([[ 0.1914, -0.3420],             [-0.5112, -0.2324]], dtype=torch.float16)      >>> linear = nn.Linear(2, 2, bias=None).to(torch.cdouble)     >>> linear.weight     Parameter containing:     tensor([[ 0.3741+0.j,  0.2382+0.j],             [ 0.5593+0.j, -0.4443+0.j]], dtype=torch.complex128)     >>> linear(torch.ones(3, 2, dtype=torch.cdouble))     tensor([[0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j],             [0.6122+0.j, 0.1150+0.j]], dtype=torch.complex128)
+
+    **Parameters**
+
+    - **args**    
+    - **kwargs**    
+    
+???- note "to_empty"
+
+    Move the parameters and buffers to the specified device without copying storage.
+
+    Args:     device (:class:`torch.device`): The desired device of the parameters         and buffers in this module.     recurse (bool): Whether parameters and buffers of submodules should         be recursively moved to the specified device.  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, str, torch.device, NoneType]*)    
+    - **recurse**     (*bool*)     – defaults to `True`    
+    
+???- note "train"
+
+    Set the module in training mode.
+
+    This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. :class:`Dropout`, :class:`BatchNorm`, etc.  Args:     mode (bool): whether to set training mode (``True``) or evaluation                  mode (``False``). Default: ``True``.  Returns:     Module: self
+
+    **Parameters**
+
+    - **mode**     (*bool*)     – defaults to `True`    
+    
+???- note "type"
+
+    Casts all parameters and buffers to :attr:`dst_type`.
+
+    .. note::     This method modifies the module in-place.  Args:     dst_type (type or string): the desired type  Returns:     Module: self
+
+    **Parameters**
+
+    - **dst_type**     (*Union[torch.dtype, str]*)    
+    
+???- note "xpu"
+
+    Move all model parameters and buffers to the XPU.
+
+    This also makes associated parameters and buffers different objects. So it should be called before constructing optimizer if the module will live on XPU while being optimized.  .. note::     This method modifies the module in-place.  Arguments:     device (int, optional): if specified, all parameters will be         copied to that device  Returns:     Module: self
+
+    **Parameters**
+
+    - **device**     (*Union[int, torch.device, NoneType]*)     – defaults to `None`    
+    
+???- note "zero_grad"
+
+    Reset gradients of all model parameters.
+
+    See similar function under :class:`torch.optim.Optimizer` for more context.  Args:     set_to_none (bool): instead of setting to zero, set the grads to None.         See :meth:`torch.optim.Optimizer.zero_grad` for details.
+
+    **Parameters**
+
+    - **set_to_none**     (*bool*)     – defaults to `True`    
+    
diff --git a/docs/api/overview.md b/docs/api/overview.md
new file mode 100644
index 0000000..a87e005
--- /dev/null
+++ b/docs/api/overview.md
@@ -0,0 +1,57 @@
+# Overview
+
+## evaluation
+
+
+**Classes**
+
+- [ColBERTDistillationEvaluator](../evaluation/ColBERTDistillationEvaluator)
+- [ColBERTTripletEvaluator](../evaluation/ColBERTTripletEvaluator)
+
+**Functions**
+
+- [evaluate](../evaluation/evaluate)
+- [get_beir_triples](../evaluation/get-beir-triples)
+- [load_beir](../evaluation/load-beir)
+
+## indexes
+
+- [Voyager](../indexes/Voyager)
+
+## losses
+
+- [Contrastive](../losses/Contrastive)
+- [Distillation](../losses/Distillation)
+
+## models
+
+- [ColBERT](../models/ColBERT)
+- [Dense](../models/Dense)
+
+## rank
+
+- [rerank](../rank/rerank)
+
+## retrieve
+
+- [ColBERT](../retrieve/ColBERT)
+
+## scores
+
+- [colbert_kd_scores](../scores/colbert-kd-scores)
+- [colbert_scores](../scores/colbert-scores)
+- [colbert_scores_pairwise](../scores/colbert-scores-pairwise)
+
+## utils
+
+
+**Classes**
+
+- [ColBERTCollator](../utils/ColBERTCollator)
+- [KDProcessing](../utils/KDProcessing)
+
+**Functions**
+
+- [convert_to_tensor](../utils/convert-to-tensor)
+- [iter_batch](../utils/iter-batch)
+
diff --git a/docs/api/rank/.pages b/docs/api/rank/.pages
new file mode 100644
index 0000000..1b9333c
--- /dev/null
+++ b/docs/api/rank/.pages
@@ -0,0 +1 @@
+title: rank
\ No newline at end of file
diff --git a/docs/api/rank/rerank.md b/docs/api/rank/rerank.md
new file mode 100644
index 0000000..7d52a8f
--- /dev/null
+++ b/docs/api/rank/rerank.md
@@ -0,0 +1,76 @@
+# rerank
+
+Rerank the documents based on the queries embeddings.
+
+
+
+## Parameters
+
+- **documents_ids** (*list[list[int | str]]*)
+
+    The documents ids.
+
+- **queries_embeddings** (*list[list[float | int] | numpy.ndarray | torch.Tensor]*)
+
+    The queries embeddings which is a dictionary of queries and their embeddings.
+
+- **documents_embeddings** (*list[list[float | int] | numpy.ndarray | torch.Tensor]*)
+
+    The documents embeddings which is a dictionary of documents ids and their embeddings.
+
+- **device** (*str*) – defaults to `None`
+
+
+
+## Examples
+
+```python
+>>> from pylate import models, rank
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
+... )
+
+>>> queries = [
+...     "query A",
+...     "query B",
+... ]
+
+>>> documents = [
+...     ["document A", "document B"],
+...     ["document 1", "document C", "document B"],
+... ]
+
+>>> documents_ids = [
+...    [1, 2],
+...    [1, 3, 2],
+... ]
+
+>>> queries_embeddings = model.encode(
+...     queries,
+...     is_query=True,
+...     batch_size=1,
+... )
+
+>>> documents_embeddings = model.encode(
+...     documents,
+...     is_query=False,
+...     batch_size=1,
+... )
+
+>>> reranked_documents = rank.rerank(
+...     documents_ids=documents_ids,
+...     queries_embeddings=queries_embeddings,
+...     documents_embeddings=documents_embeddings,
+... )
+
+>>> assert isinstance(reranked_documents, list)
+>>> assert len(reranked_documents) == 2
+>>> assert len(reranked_documents[0]) == 2
+>>> assert len(reranked_documents[1]) == 3
+>>> assert isinstance(reranked_documents[0], list)
+>>> assert isinstance(reranked_documents[0][0], dict)
+>>> assert "id" in reranked_documents[0][0]
+>>> assert "score" in reranked_documents[0][0]
+```
+
diff --git a/docs/api/retrieve/.pages b/docs/api/retrieve/.pages
new file mode 100644
index 0000000..ac400f7
--- /dev/null
+++ b/docs/api/retrieve/.pages
@@ -0,0 +1 @@
+title: retrieve
\ No newline at end of file
diff --git a/docs/api/retrieve/ColBERT.md b/docs/api/retrieve/ColBERT.md
new file mode 100644
index 0000000..d36bc9d
--- /dev/null
+++ b/docs/api/retrieve/ColBERT.md
@@ -0,0 +1,93 @@
+# ColBERT
+
+ColBERT retriever.
+
+
+
+## Parameters
+
+- **index** (*[indexes.Voyager](../../indexes/Voyager)*)
+
+
+
+## Examples
+
+```python
+>>> from pylate import indexes, models, retrieve
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+...     device="cpu",
+... )
+
+>>> documents_ids = ["1", "2"]
+
+>>> documents = [
+...     "fruits are healthy.",
+...     "fruits are good for health.",
+... ]
+
+>>> documents_embeddings = model.encode(
+...     sentences=documents,
+...     batch_size=1,
+...     is_query=False,
+... )
+
+>>> index = indexes.Voyager(
+...     index_folder="test_indexes",
+...     index_name="colbert",
+...     override=True,
+...     embedding_size=128,
+... )
+
+>>> index = index.add_documents(
+...     documents_ids=documents_ids,
+...     documents_embeddings=documents_embeddings,
+... )
+
+>>> retriever = retrieve.ColBERT(index=index)
+
+>>> queries_embeddings = model.encode(
+...     ["fruits are healthy.", "fruits are good for health."],
+...     batch_size=1,
+...     is_query=True,
+... )
+
+>>> results = retriever.retrieve(
+...     queries_embeddings=queries_embeddings,
+...     k=2,
+...     device="cpu",
+... )
+
+>>> assert isinstance(results, list)
+>>> assert len(results) == 2
+
+>>> queries_embeddings = model.encode(
+...     "fruits are healthy.",
+...     batch_size=1,
+...     is_query=True,
+... )
+
+>>> results = retriever.retrieve(
+...     queries_embeddings=queries_embeddings,
+...     k=2,
+...     device="cpu",
+... )
+
+>>> assert isinstance(results, list)
+>>> assert len(results) == 1
+```
+
+## Methods
+
+???- note "retrieve"
+
+    Retrieve documents for a list of queries.
+
+    **Parameters**
+
+    - **queries_embeddings**     (*list[list | numpy.ndarray | torch.Tensor]*)    
+    - **k**     (*int*)     – defaults to `10`    
+    - **k_index**     (*int | None*)     – defaults to `None`    
+    - **device**     (*str | None*)     – defaults to `None`    
+    
diff --git a/docs/api/scores/.pages b/docs/api/scores/.pages
new file mode 100644
index 0000000..15a75d3
--- /dev/null
+++ b/docs/api/scores/.pages
@@ -0,0 +1 @@
+title: scores
\ No newline at end of file
diff --git a/docs/api/scores/colbert-kd-scores.md b/docs/api/scores/colbert-kd-scores.md
new file mode 100644
index 0000000..032a9a4
--- /dev/null
+++ b/docs/api/scores/colbert-kd-scores.md
@@ -0,0 +1,47 @@
+# colbert_kd_scores
+
+Computes the ColBERT scores between queries and documents embeddings. This scoring function is dedicated to the knowledge distillation pipeline.
+
+
+
+## Parameters
+
+- **queries_embeddings** (*list | numpy.ndarray | torch.Tensor*)
+
+- **documents_embeddings** (*list | numpy.ndarray | torch.Tensor*)
+
+- **mask** (*torch.Tensor*) – defaults to `None`
+
+
+
+## Examples
+
+```python
+>>> import torch
+
+>>> queries_embeddings = torch.tensor([
+...     [[1.], [0.], [0.], [0.]],
+...     [[0.], [2.], [0.], [0.]],
+...     [[0.], [0.], [3.], [0.]],
+... ])
+
+>>> documents_embeddings = torch.tensor([
+...     [[[10.], [0.], [1.]], [[20.], [0.], [1.]], [[30.], [0.], [1.]]],
+...     [[[0.], [100.], [1.]], [[0.], [200.], [1.]], [[0.], [300.], [1.]]],
+...     [[[1.], [0.], [1000.]], [[1.], [0.], [2000.]], [[10.], [0.], [3000.]]],
+... ])
+>>> mask = torch.tensor([
+...     [[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]],
+...     [[1., 1., 1.], [1., 1., 1.], [1., 1., 1.]],
+...     [[1., 1., 1.], [1., 1., 1.], [1., 1., 0.]],
+... ])
+>>> colbert_kd_scores(
+...     queries_embeddings=queries_embeddings,
+...     documents_embeddings=documents_embeddings,
+...     mask=mask
+... )
+tensor([[  10.,   20.,   30.],
+        [ 200.,  400.,  600.],
+        [3000., 6000., 30.]])
+```
+
diff --git a/docs/api/scores/colbert-scores-pairwise.md b/docs/api/scores/colbert-scores-pairwise.md
new file mode 100644
index 0000000..af1e567
--- /dev/null
+++ b/docs/api/scores/colbert-scores-pairwise.md
@@ -0,0 +1,44 @@
+# colbert_scores_pairwise
+
+Computes the ColBERT score for each query-document pair. The score is computed as the sum of maximum similarities between the query and the document for corresponding pairs.
+
+
+
+## Parameters
+
+- **queries_embeddings** (*torch.Tensor*)
+
+    The first tensor. The queries embeddings. Shape: (batch_size, num tokens queries, embedding_size)
+
+- **documents_embeddings** (*torch.Tensor*)
+
+    The second tensor. The documents embeddings. Shape: (batch_size, num tokens documents, embedding_size)
+
+
+
+## Examples
+
+```python
+>>> import torch
+
+>>> queries_embeddings = torch.tensor([
+...     [[1.], [0.], [0.], [0.]],
+...     [[0.], [2.], [0.], [0.]],
+...     [[0.], [0.], [3.], [0.]],
+... ])
+
+>>> documents_embeddings = torch.tensor([
+...     [[10.], [0.], [1.]],
+...     [[0.], [100.], [1.]],
+...     [[1.], [0.], [1000.]],
+... ])
+
+>>> scores = colbert_scores_pairwise(
+...     queries_embeddings=queries_embeddings,
+...     documents_embeddings=documents_embeddings
+... )
+
+>>> scores
+tensor([  10.,  200., 3000.])
+```
+
diff --git a/docs/api/scores/colbert-scores.md b/docs/api/scores/colbert-scores.md
new file mode 100644
index 0000000..9ac37fd
--- /dev/null
+++ b/docs/api/scores/colbert-scores.md
@@ -0,0 +1,48 @@
+# colbert_scores
+
+Computes the ColBERT scores between queries and documents embeddings. The score is computed as the sum of maximum similarities between the query and the document.
+
+
+
+## Parameters
+
+- **queries_embeddings** (*list | numpy.ndarray | torch.Tensor*)
+
+    The first tensor. The queries embeddings. Shape: (batch_size, num tokens queries, embedding_size)
+
+- **documents_embeddings** (*list | numpy.ndarray | torch.Tensor*)
+
+    The second tensor. The documents embeddings. Shape: (batch_size, num tokens documents, embedding_size)
+
+- **mask** (*torch.Tensor*) – defaults to `None`
+
+
+
+## Examples
+
+```python
+>>> import torch
+
+>>> queries_embeddings = torch.tensor([
+...     [[1.], [0.], [0.], [0.]],
+...     [[0.], [2.], [0.], [0.]],
+...     [[0.], [0.], [3.], [0.]],
+... ])
+
+>>> documents_embeddings = torch.tensor([
+...     [[10.], [0.], [1.]],
+...     [[0.], [100.], [1.]],
+...     [[1.], [0.], [1000.]],
+... ])
+
+>>> scores = colbert_scores(
+...     queries_embeddings=queries_embeddings,
+...     documents_embeddings=documents_embeddings
+... )
+
+>>> scores
+tensor([[  10.,  100., 1000.],
+        [  20.,  200., 2000.],
+        [  30.,  300., 3000.]])
+```
+
diff --git a/docs/api/utils/.pages b/docs/api/utils/.pages
new file mode 100644
index 0000000..7b72313
--- /dev/null
+++ b/docs/api/utils/.pages
@@ -0,0 +1 @@
+title: utils
\ No newline at end of file
diff --git a/docs/api/utils/ColBERTCollator.md b/docs/api/utils/ColBERTCollator.md
new file mode 100644
index 0000000..e771922
--- /dev/null
+++ b/docs/api/utils/ColBERTCollator.md
@@ -0,0 +1,70 @@
+# ColBERTCollator
+
+Collator for ColBERT model.
+
+
+
+## Parameters
+
+- **tokenize_fn** (*Callable*)
+
+    The function to tokenize the input text.
+
+- **valid_label_columns** (*list[str] | None*) – defaults to `None`
+
+    The name of the columns that contain the labels: scores or labels.
+
+
+
+## Examples
+
+```python
+>>> from pylate import models, utils
+
+>>> model = models.ColBERT(
+...     model_name_or_path="sentence-transformers/all-MiniLM-L6-v2", device="cpu"
+... )
+
+>>> collator = utils.ColBERTCollator(
+...     tokenize_fn=model.tokenize,
+... )
+
+>>> features = [
+...     {
+...         "query": "fruits are healthy.",
+...         "positive": "fruits are good for health.",
+...         "negative": "fruits are bad for health.",
+...         "label": [0.7, 0.3]
+...     }
+... ]
+
+>>> features = collator(features=features)
+
+>>> fields = [
+...     "query_input_ids",
+...     "positive_input_ids",
+...     "negative_input_ids",
+...     "query_attention_mask",
+...     "positive_attention_mask",
+...     "negative_attention_mask",
+...     "query_token_type_ids",
+...     "positive_token_type_ids",
+...     "negative_token_type_ids",
+... ]
+
+>>> for field in fields:
+...     assert field in features
+...     assert isinstance(features[field], torch.Tensor)
+...     assert features[field].ndim == 2
+```
+
+## Methods
+
+???- note "__call__"
+
+    Collate a list of features into a batch.
+
+    **Parameters**
+
+    - **features**     (*list[dict]*)    
+    
diff --git a/docs/api/utils/KDProcessing.md b/docs/api/utils/KDProcessing.md
new file mode 100644
index 0000000..386a561
--- /dev/null
+++ b/docs/api/utils/KDProcessing.md
@@ -0,0 +1,74 @@
+# KDProcessing
+
+Dataset processing class for knowledge distillation training.
+
+
+
+## Parameters
+
+- **queries** (*datasets.arrow_dataset.Dataset*)
+
+    Queries dataset.
+
+- **documents** (*datasets.arrow_dataset.Dataset*)
+
+    Documents dataset.
+
+- **n_ways** (*int*) – defaults to `32`
+
+
+
+## Examples
+
+```python
+>>> from datasets import load_dataset
+>>> from pylate import utils
+
+>>> train = load_dataset(
+...    path="lightonai/lighton-ms-marco-mini",
+...    name="train",
+...    split="train",
+... )
+
+>>> queries = load_dataset(
+...    path="lightonai/lighton-ms-marco-mini",
+...    name="queries",
+...    split="train",
+... )
+
+>>> documents = load_dataset(
+...    path="lightonai/lighton-ms-marco-mini",
+...    name="documents",
+...    split="train",
+... )
+
+>>> train.set_transform(
+...    utils.KDProcessing(
+...        queries=queries, documents=documents
+...    ).transform,
+... )
+
+>>> for sample in train:
+...     assert "documents" in sample and isinstance(sample["documents"], list)
+...     assert "query" in sample and isinstance(sample["query"], str)
+...     assert "scores" in sample and isinstance(sample["scores"], list)
+```
+
+## Methods
+
+???- note "map"
+
+    Process a single example.
+
+    **Parameters**
+
+    - **example**     (*dict*)    
+    
+???- note "transform"
+
+    Update the input dataset with the queries and documents.
+
+    **Parameters**
+
+    - **examples**     (*dict*)    
+    
diff --git a/docs/api/utils/convert-to-tensor.md b/docs/api/utils/convert-to-tensor.md
new file mode 100644
index 0000000..b5b121a
--- /dev/null
+++ b/docs/api/utils/convert-to-tensor.md
@@ -0,0 +1,52 @@
+# convert_to_tensor
+
+Converts a list or numpy array to a torch tensor.
+
+
+
+## Parameters
+
+- **x** (*torch.Tensor | numpy.ndarray | list[torch.Tensor | numpy.ndarray | list | float]*)
+
+    The input data. It can be a torch tensor, a numpy array, or a list of torch tensors, numpy arrays, or lists.
+
+
+
+## Examples
+
+```python
+>>> import numpy as np
+>>> import torch
+
+>>> x = torch.tensor([[1., 1., 1.], [2., 2., 2.]])
+>>> convert_to_tensor(x)
+tensor([[1., 1., 1.],
+        [2., 2., 2.]])
+
+>>> x = np.array([[1., 1., 1.], [2., 2., 2.]], dtype=np.float32)
+>>> convert_to_tensor(x)
+tensor([[1., 1., 1.],
+        [2., 2., 2.]])
+
+>>> x = []
+>>> convert_to_tensor(x)
+tensor([])
+
+>>> x = [np.array([1., 1., 1.])]
+>>> convert_to_tensor(x)
+tensor([[1., 1., 1.]])
+
+>>> x = [[1., 1., 1.]]
+>>> convert_to_tensor(x)
+tensor([[1., 1., 1.]])
+
+>>> x = [torch.tensor([1., 1., 1.]), torch.tensor([2., 2., 2.])]
+>>> convert_to_tensor(x)
+tensor([[1., 1., 1.],
+        [2., 2., 2.]])
+
+>>> x = np.array([], dtype=np.float32)
+>>> convert_to_tensor(x)
+tensor([])
+```
+
diff --git a/docs/api/utils/iter-batch.md b/docs/api/utils/iter-batch.md
new file mode 100644
index 0000000..ea93419
--- /dev/null
+++ b/docs/api/utils/iter-batch.md
@@ -0,0 +1,39 @@
+# iter_batch
+
+Iterate over a list of elements by batch.
+
+
+
+## Parameters
+
+- **X** (*list[str]*)
+
+- **batch_size** (*int*)
+
+- **tqdm_bar** (*bool*) – defaults to `True`
+
+- **desc** (*str*) – defaults to ``
+
+
+
+## Examples
+
+```python
+>>> from pylate import utils
+
+>>> X = [
+...  "element 0",
+...  "element 1",
+...  "element 2",
+...  "element 3",
+...  "element 4",
+... ]
+
+>>> n_samples = 0
+>>> for batch in utils.iter_batch(X, batch_size=2):
+...     n_samples += len(batch)
+
+>>> n_samples
+5
+```
+
diff --git a/docs/benchmarks/.pages b/docs/benchmarks/.pages
new file mode 100644
index 0000000..05a507a
--- /dev/null
+++ b/docs/benchmarks/.pages
@@ -0,0 +1,3 @@
+title: Benchmarks
+nav:
+    - Benchmarks: benchmarks.md
diff --git a/docs/benchmarks/benchmarks.md b/docs/benchmarks/benchmarks.md
new file mode 100644
index 0000000..c66d2d3
--- /dev/null
+++ b/docs/benchmarks/benchmarks.md
@@ -0,0 +1,11 @@
+# ColBERT Benchmarks
+
+=== "Table"
+
+    | Model                                   | Dataset   |   Language |   NDCG@10  |  NDCG@100  |   RECALL@10    |  RECALL@100 |
+    |:----------------------------------------|:----------|-----------:|-----------:|-----------:|---------------:|------------:|
+    | sentence-transformers/all-mpnet-base-v2 | dataset_x |   English  |   0.677864 | 0.645041   |    0.453154    |    876.714  |
+    | sentence-transformers/all-mpnet-base-v2 | dataset_y |   English  |   0.880581 | 0.858687   |   13.5424      |  10153.7    |
+    | sentence-transformers/all-mpnet-base-v2 | dataset_z |   English  |   0.878303 | 0.863555   |    0.873312    |    552.609  |
+    | sentence-transformers/all-mpnet-base-v2 | dataset_a |   English  |   0.999443 | 0.404494   |    1.33633     |   6617.5    |
+
diff --git a/docs/css/version-select.css b/docs/css/version-select.css
new file mode 100644
index 0000000..9b6074d
--- /dev/null
+++ b/docs/css/version-select.css
@@ -0,0 +1,5 @@
+@media only screen and (max-width:76.1875em) {
+    #version-selector {
+        padding: .6rem .8rem;
+    }
+}
\ No newline at end of file
diff --git a/docs/documentation/.pages b/docs/documentation/.pages
new file mode 100644
index 0000000..878d647
--- /dev/null
+++ b/docs/documentation/.pages
@@ -0,0 +1,6 @@
+title: Documentation
+nav:
+    - Training: training.md
+    - Datasets: datasets.md
+    - Retrieval: retrieval.md
+    - Evaluation: evaluation.md
diff --git a/docs/documentation/datasets.md b/docs/documentation/datasets.md
new file mode 100644
index 0000000..f1c908b
--- /dev/null
+++ b/docs/documentation/datasets.md
@@ -0,0 +1,248 @@
+PyLate models are designed to be compatible with Hugging Face datasets, facilitating seamless integration for tasks such as knowledge distillation and contrastive model training. Below are examples illustrating how to load and prepare datasets for these specific training objectives.
+
+## Knowledge distillation dataset
+
+For fine-tuning a model using knowledge distillation loss, three distinct dataset files are required: train, queries, and documents. Each file contains unique and complementary information necessary for the distillation process.
+
+A sample dataset for knowledge distillation is available here: [MS Marco Mini Dataset](https://huggingface.co/datasets/lightonai/lighton-ms-marco-mini).
+
+```python
+from datasets import load_dataset
+
+train = load_dataset(
+    "lightonai/lighton-ms-marco-mini",
+    "train",
+    split="train",
+)
+
+queries = load_dataset(
+    "lightonai/lighton-ms-marco-mini",
+    "queries",
+    split="train",
+)
+
+documents = load_dataset(
+    "lightonai/lighton-ms-marco-mini",
+    "documents",
+    split="train",
+)
+```
+
+Where:
+
+- `train`: Contains three columns: `['query_id', 'document_ids', 'scores']`
+    - `query_id` refers to the query identifier.
+    - `document_ids` is a list of document IDs relevant to the query.
+    - `scores` corresponds to the relevance scores between the query and each document.
+
+Example entry:
+
+```python
+{
+    "query_id": 54528,
+    "document_ids": [
+        6862419,
+        335116,
+        339186,
+        7509316,
+        7361291,
+        7416534,
+        5789936,
+        5645247,
+    ],
+    "scores": [
+        0.4546215673141326,
+        0.6575686537173476,
+        0.26825184192900203,
+        0.5256195579370395,
+        0.879939718687207,
+        0.7894968184862693,
+        0.6450100468854655,
+        0.5823844608171467,
+    ],
+}
+```
+
+Note: Ensure that the length of `document_ids` matches the length of `scores`.
+
+- `queries`: Contains two columns: `['query_id', 'text']`
+
+Example entry:
+
+```python
+{"query_id": 749480, "text": "what is function of magnesium in human body"}
+```
+
+- `documents`: contains two columns: `['document_ids', 'text']`
+
+Example entry:
+
+```python
+{
+    "document_id": 136062,
+    "text": "Document text",
+}
+```
+
+### Knowledge distillation dataset from list
+
+You can also create custom datasets from list in Python. This example demonstrates how to build and split the `train`, `queries`, and `documents` datasets
+
+```python
+from datasets import Dataset
+
+dataset = [
+    {
+        "query_id": 54528,
+        "document_ids": [
+            6862419,
+            335116,
+            339186,
+            7509316,
+            7361291,
+            7416534,
+            5789936,
+            5645247,
+        ],
+        "scores": [
+            0.4546215673141326,
+            0.6575686537173476,
+            0.26825184192900203,
+            0.5256195579370395,
+            0.879939718687207,
+            0.7894968184862693,
+            0.6450100468854655,
+            0.5823844608171467,
+        ],
+    },
+    {
+        "query_id": 749480,
+        "document_ids": [
+            6862419,
+            335116,
+            339186,
+            7509316,
+            7361291,
+            7416534,
+            5789936,
+            5645247,
+        ],
+        "scores": [
+            0.2546215673141326,
+            0.7575686537173476,
+            0.96825184192900203,
+            0.0256195579370395,
+            0.779939718687207,
+            0.2894968184862693,
+            0.1450100468854655,
+            0.7823844608171467,
+        ],
+    },
+]
+
+
+dataset = Dataset.from_list(mapping=dataset)
+
+train_dataset, test_dataset = dataset.train_test_split(test_size=0.3)
+```
+
+Queries and Documents Dataset Example
+
+```python
+from datasets import Dataset
+
+documents = [
+    {"document_id": 6862419, "text": "Document text"},
+    {"document_id": 335116, "text": "Document text"},
+    {"document_id": 339186, "text": "Document text"},
+    {"document_id": 7509316, "text": "Document text"},
+    {"document_id": 7361291, "text": "Document text"},
+    {"document_id": 7416534, "text": "Document text"},
+    {"document_id": 5789936, "text": "Document text"},
+    {"document_id": 5645247, "text": "Document text"},
+]
+
+queries = [
+    {"query_id": 749480, "text": "what is function of magnesium in human body"},
+    {"query_id": 54528, "text": "what is the capital of France"},
+]
+
+documents = Dataset.from_list(mapping=documents)
+
+queries = Dataset.from_list(mapping=queries)
+```
+
+## Constrastive dataset
+
+Contrastive training involves datasets that contain a query, a positive document (relevant to the query), and a negative document (irrelevant to the query). The model is trained to maximize the similarity between the query and the positive document while minimizing the similarity with the negative document.
+
+### Loading a pre-built contrastive dataset
+
+You can directly download an existing contrastive dataset from Hugging Face's hub, such as the [msmarco-bm25 triplet dataset](https://huggingface.co/datasets/sentence-transformers/msmarco-bm25).
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
+
+train_dataset, test_dataset = dataset.train_test_split(test_size=0.001)
+```
+
+Then we can shuffle the dataset:
+
+```python
+train_dataset = train_dataset.shuffle(seed=42)
+```
+
+And select a subset of the dataset if needed:
+
+```python
+train_dataset = train_dataset.select(range(10_000))
+```
+
+### Creating a contrastive dataset from list
+
+If you want to create a custom contrastive dataset, you can do so by manually specifying the query, positive, and negative samples.
+
+```python
+from datasets import Dataset
+
+dataset = [
+    {
+        "query": "example query 1",
+        "positive": "example positive document 1",
+        "negative": "example negative document 1",
+    },
+    {
+        "query": "example query 2",
+        "positive": "example positive document 2",
+        "negative": "example negative document 2",
+    },
+    {
+        "query": "example query 3",
+        "positive": "example positive document 3",
+        "negative": "example negative document 3",
+    },
+]
+
+dataset = Dataset.from_list(mapping=dataset)
+
+train_dataset, test_dataset = dataset.train_test_split(test_size=0.3)
+```
+
+### Loading a contrastive dataset from a local parquet file
+
+To load a local dataset stored in a Parquet file:
+
+```python
+from datasets import load_dataset
+
+dataset = load_dataset(
+    path="parquet", 
+    data_files="dataset.parquet", 
+    split="train"
+)
+
+train_dataset, test_dataset = dataset.train_test_split(test_size=0.001)
+```
+
diff --git a/docs/documentation/evaluation.md b/docs/documentation/evaluation.md
new file mode 100644
index 0000000..5390985
--- /dev/null
+++ b/docs/documentation/evaluation.md
@@ -0,0 +1,169 @@
+## Retrieval evaluation
+
+This guide demonstrates an end-to-end pipeline to evaluate the performance of the ColBERT model on retrieval tasks. The pipeline involves three key steps: indexing documents, retrieving top-k documents for a given set of queries, and evaluating the retrieval results using standard metrics.
+
+### Retrieval Evaluation Pipeline
+
+```python
+from pylate import evaluation, indexes, models, retrieve
+
+# Step 1: Initialize the ColBERT model
+model = models.ColBERT(
+    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+    device="cuda" # "cpu" or "cuda" or "mps"
+)
+
+# Step 2: Create a Voyager index
+index = indexes.Voyager(
+    index_folder="pylate-index",
+    index_name="index",
+    override=True,  # Overwrite any existing index
+)
+
+# Step 3: Load the documents, queries, and relevance judgments (qrels)
+documents, queries, qrels = evaluation.load_beir(
+    "scifact",  # Specify the dataset (e.g., "scifact")
+    split="test",  # Specify the split (e.g., "test")
+)
+
+# Step 4: Encode the documents
+documents_embeddings = model.encode(
+    [document["text"] for document in documents],
+    batch_size=32,
+    is_query=False,  # Indicate that these are documents
+    show_progress_bar=True,
+)
+
+# Step 5: Add document embeddings to the index
+index.add_documents(
+    documents_ids=[document["id"] for document in documents],
+    documents_embeddings=documents_embeddings,
+)
+
+# Step 6: Encode the queries
+queries_embeddings = model.encode(
+    queries,
+    batch_size=32,
+    is_query=True,  # Indicate that these are queries
+    show_progress_bar=True,
+)
+
+# Step 7: Retrieve top-k documents
+retriever = retrieve.ColBERT(index=index)
+scores = retriever.retrieve(
+    queries_embeddings=queries_embeddings,
+    k=100,  # Retrieve the top 100 matches for each query
+)
+
+# Step 8: Evaluate the retrieval results
+results = evaluation.evaluate(
+    scores=scores,
+    qrels=qrels,
+    queries=queries,
+    metrics=[f"ndcg@{k}" for k in [1, 3, 5, 10, 100]] # NDCG for different k values
+    + [f"hits@{k}" for k in [1, 3, 5, 10, 100]]       # Hits at different k values
+    + ["map"]                                         # Mean Average Precision (MAP)
+    + ["recall@10", "recall@100"]                     # Recall at k
+    + ["precision@10", "precision@100"],              # Precision at k
+)
+
+print(results)
+```
+
+The output is a dictionary containing various evaluation metrics. Here’s a sample output:
+
+```python
+{
+    "ndcg@1": 0.47333333333333333,
+    "ndcg@3": 0.543862513095773,
+    "ndcg@5": 0.5623210323686343,
+    "ndcg@10": 0.5891793972249917,
+    "ndcg@100": 0.5891793972249917,
+    "hits@1": 0.47333333333333333,
+    "hits@3": 0.64,
+    "hits@5": 0.7033333333333334,
+    "hits@10": 0.8,
+    "hits@100": 0.8,
+    "map": 0.5442202380952381,
+    "recall@10": 0.7160555555555556,
+    "recall@100": 0.7160555555555556,
+    "precision@10": 0.08,
+    "precision@100": 0.008000000000000002,
+}
+```
+
+Key Points:
+
+1. is_query flag: Always set is_query=True when encoding queries and is_query=False when encoding documents. This ensures the model applies the correct prefixes for queries and documents.
+2. Evaluation metrics: The pipeline supports a wide range of evaluation metrics, including NDCG, hits, MAP, recall, and precision, with different cutoff points.
+3. Relevance judgments (qrels): The qrels are used to calculate how well the retrieved documents match the ground truth.
+
+### Beir datasets
+
+The following table lists the datasets available in the BEIR benchmark along with their names, types, number of queries, corpus size, and relevance degree per query. Source: [BEIR Datasets](https://github.com/beir-cellar/beir?tab=readme-ov-file)
+
+=== "Table"
+
+    | Dataset       | BEIR-Name       | Type              | Queries | Corpus      |
+    |---------------|-----------------|-------------------|---------|-------------|
+    | MSMARCO       | msmarco          | train, dev, test  | 6,980   | 8,840,000   |
+    | TREC-COVID    | trec-covid       | test              | 50      | 171,000     |
+    | NFCorpus      | nfcorpus         | train, dev, test  | 323     | 3,600       |
+    | BioASQ        | bioasq           | train, test       | 500     | 14,910,000  |
+    | NQ            | nq               | train, test       | 3,452   | 2,680,000   |
+    | HotpotQA      | hotpotqa         | train, dev, test  | 7,405   | 5,230,000   |
+    | FiQA-2018     | fiqa             | train, dev, test  | 648     | 57,000      |
+    | Signal-1M(RT) | signal1m         | test              | 97      | 2,860,000   |
+    | TREC-NEWS     | trec-news        | test              | 57      | 595,000     |
+    | Robust04      | robust04         | test              | 249     | 528,000     |
+    | ArguAna       | arguana          | test              | 1,406   | 8,670       |
+    | Touche-2020   | webis-touche2020 | test              | 49      | 382,000     |
+    | CQADupstack   | cqadupstack      | test              | 13,145  | 457,000     |
+    | Quora         | quora            | dev, test         | 10,000  | 523,000     |
+    | DBPedia       | dbpedia-entity   | dev, test         | 400     | 4,630,000   |
+    | SCIDOCS       | scidocs          | test              | 1,000   | 25,000      |
+    | FEVER         | fever            | train, dev, test  | 6,666   | 5,420,000   |
+    | Climate-FEVER | climate-fever    | test              | 1,535   | 5,420,000   |
+    | SciFact       | scifact          | train, test       | 300     | 5,000       |
+
+
+### Metrics
+
+PyLate evaluation is based on [Ranx Python library](https://amenra.github.io/ranx/metrics/) to compute standard Information Retrieval metrics. The following metrics are supported:
+
+=== "Table"
+
+    | Metric                    | Alias         | @k  |
+    |----------------------------|---------------|-----|
+    | Hits                       | hits          | Yes |
+    | Hit Rate / Success         | hit_rate      | Yes |
+    | Precision                  | precision     | Yes |
+    | Recall                     | recall        | Yes |
+    | F1                         | f1            | Yes |
+    | R-Precision                | r_precision   | No  |
+    | Bpref                      | bpref         | No  |
+    | Rank-biased Precision      | rbp           | No  |
+    | Mean Reciprocal Rank       | mrr           | Yes |
+    | Mean Average Precision     | map           | Yes |
+    | DCG                        | dcg           | Yes |
+    | DCG Burges                 | dcg_burges    | Yes |
+    | NDCG                       | ndcg          | Yes |
+    | NDCG Burges                | ndcg_burges   | Yes |
+
+
+For any details about the metrics, please refer to [Ranx documentation](https://amenra.github.io/ranx/metrics/).
+
+Sample code to evaluate the retrieval results using specific metrics:
+
+```python
+results = evaluation.evaluate(
+    scores=scores,
+    qrels=qrels,
+    queries=queries,
+    metrics=[f"ndcg@{k}" for k in [1, 3, 5, 10, 100]] # NDCG for different k values
+    + [f"hits@{k}" for k in [1, 3, 5, 10, 100]]       # Hits at different k values
+    + ["map"]                                         # Mean Average Precision (MAP)
+    + ["recall@10", "recall@100"]                     # Recall at k
+    + ["precision@10", "precision@100"],              # Precision at k
+)
+```
\ No newline at end of file
diff --git a/docs/documentation/retrieval.md b/docs/documentation/retrieval.md
new file mode 100644
index 0000000..9452d60
--- /dev/null
+++ b/docs/documentation/retrieval.md
@@ -0,0 +1,103 @@
+## ColBERT Retrieval
+
+The ColBERT retrieval module provide a streamlined interface to index and retrieve documents using the ColBERT model. It leverages the Voyager index to efficiently handle document embeddings and enable fast retrieval.
+
+### Indexing documents
+
+First, initialize the ColBERT model and Voyager index, then encode and index your documents:
+
+1. Initialize the ColBERT model.
+2. Set up the Voyager index.
+3. Encode documents: Ensure `is_query=False` when encoding documents so the model knows it is processing documents rather than queries.
+4. Add documents to the index: Provide both document IDs and their corresponding embeddings to the index.
+
+Here’s an example code for indexing:
+
+```python
+from pylate import indexes, models, retrieve
+
+# Step 1: Initialize the ColBERT model
+model = models.ColBERT(
+    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+)
+
+# Step 2: Create a Voyager index
+index = indexes.Voyager(
+    index_folder="pylate-index",
+    index_name="index",
+    override=True,  # This overwrites the existing index if any
+)
+
+# Step 3: Encode the documents
+documents_ids = ["1", "2", "3"]
+documents = ["document 1 text", "document 2 text", "document 3 text"]
+
+documents_embeddings = model.encode(
+    documents,
+    batch_size=32,
+    is_query=False,  # Indicate that these are documents, not queries
+    show_progress_bar=True,
+)
+
+# Step 4: Add document embeddings to the index
+index.add_documents(
+    documents_ids=documents_ids,
+    documents_embeddings=documents_embeddings,
+)
+```
+
+### Retrieving top-k documents for queries
+
+Once the documents are indexed, you can retrieve the top-k most relevant documents for a given set of queries.
+
+1. Initialize the ColBERT retriever
+2. Encode the queries: Use the same ColBERT model. Be sure to set `is_query=True`, so the system can differentiate between queries and documents.
+3. Retrieve top-k documents: Pass the query embeddings to the retriever to get the top matches, including document IDs and relevance scores.
+
+Here’s the code for retrieving relevant documents:
+
+```python
+# Step 1: Initialize the ColBERT retriever
+retriever = retrieve.ColBERT(index=index)
+
+# Step 2: Encode the queries
+queries_embeddings = model.encode(
+    ["query for document 3", "query for document 1"],
+    batch_size=32,
+    is_query=True,  # Indicate that these are queries
+    show_progress_bar=True,
+)
+
+# Step 3: Retrieve top-k documents
+scores = retriever.retrieve(
+    queries_embeddings=queries_embeddings, 
+    k=10,  # Retrieve the top 10 matches for each query
+)
+
+print(scores)
+```
+
+Example output
+
+```python
+[
+    [   # Candidates for the first query
+        {"id": "3", "score": 11.266985893249512},
+        {"id": "1", "score": 10.303335189819336},
+        {"id": "2", "score": 9.502392768859863},
+    ],
+    [   # Candidates for the second query
+        {"id": "1", "score": 10.88800048828125},
+        {"id": "3", "score": 9.950843811035156},
+        {"id": "2", "score": 9.602447509765625},
+    ],
+]
+```
+
+## Remove documents from the index
+
+To remove documents from the index, use the `remove_documents` method. Provide the document IDs you want to remove from the index.
+
+```python
+index.remove_documents(["1", "2"])
+```
\ No newline at end of file
diff --git a/docs/documentation/training.md b/docs/documentation/training.md
new file mode 100644
index 0000000..493d91a
--- /dev/null
+++ b/docs/documentation/training.md
@@ -0,0 +1,229 @@
+# ColBERT Training
+
+PyLate supports multi-GPU training for ColBERT models, allowing efficient scaling across multiple GPUs. There are two primary ways to fine-tune ColBERT models using PyLate:
+
+1. Contrastive Loss (Simplest Method): The easiest way to fine-tune your model is by using Contrastive Loss, which only requires a dataset containing triplets—each consisting of a query, a positive document (relevant to the query), and a negative document (irrelevant to the query). This method trains the model to maximize the similarity between the query and the positive document, while minimizing it with the negative document.
+
+2. Knowledge Distillation: To fine-tune a ColBERT model using Knowledge Distillation, you need to provide a dataset with three components: queries, documents, and the relevance scores between them. This method compresses the knowledge of a larger model / more accurate model (cross-encoder) into a smaller one, using the relevance scores to guide the training process.
+
+## Knowledge Distillation Training
+
+Knowledge distillation training aïm at training a ColBERT models to reproduce the outputs of Cross-Encoder model. This is done by using a dataset containing queries, documents and the scores between them.
+
+Example Code for Knowledge Distillation Training:
+
+```python
+from datasets import load_dataset
+from sentence_transformers import (
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+
+from pylate import losses, models, utils
+
+# Load the datasets required for knowledge distillation (train, queries, documents)
+train = load_dataset(
+    path="./datasets/msmarco_fr_full",
+    name="train",
+)
+
+queries = load_dataset(
+    path="./datasets/msmarco_fr_full",
+    name="queries",
+)
+
+documents = load_dataset(
+    path="./datasets/msmarco_fr_full",
+    name="documents",
+)
+
+# Apply transformations to align the data with the knowledge distillation process
+train.set_transform(
+    utils.KDProcessing(queries=queries, documents=documents).transform,
+)
+
+# Define the model, training parameters, and output directory
+model_name = "bert-base-uncased"
+batch_size = 16
+num_train_epochs = 1
+output_dir = "output/distillation_run-bert-base"
+
+# Initialize the ColBERT model for distillation
+model = models.ColBERT(model_name_or_path=model_name)
+
+# Configure the training arguments (e.g., epochs, batch size, learning rate)
+args = SentenceTransformerTrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    fp16=False,
+    bf16=False,
+    logging_steps=10,
+    run_name="distillation_run-bert-base",
+    learning_rate=1e-5,
+)
+
+# Use the Distillation loss function for training
+train_loss = losses.Distillation(model=model)
+
+# Initialize the trainer for the distillation process
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train,
+    loss=train_loss,
+    data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
+)
+
+# Start the training process
+trainer.train()
+```
+
+## Contrastive Training
+
+Contrastive training is used to improve the model's ability to differentiate between relevant and irrelevant documents by maximizing the similarity between a query and a relevant document while minimizing it with irrelevant documents.
+
+Example Code for Contrastive Training:
+
+```python
+from sentence_transformers import (
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+
+from datasets import load_dataset
+from pylate import evaluation, losses, models, utils
+
+# Define model parameters for contrastive training
+model_name = "output/answerai-colbert-small-v1"  # Choose the pre-trained model you want to use
+batch_size = 32  # A larger batch size often improves results, but requires more GPU memory
+num_train_epochs = 1  # Adjust based on your requirements
+
+# Set the output directory for saving the trained model
+output_dir = "output/msmarco_bm25_contrastive_bert-base-uncased"
+
+# Initialize the ColBERT model, adding a linear layer if it's not already a ColBERT model
+model = models.ColBERT(model_name_or_path=model_name)
+
+# Load the contrastive dataset (query, positive, and negative pairs)
+dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
+
+# Split the dataset into training and evaluation subsets
+splits = dataset.train_test_split(test_size=0.01)
+train_dataset = splits["train"]
+eval_dataset = splits["test"]
+
+# Define the contrastive loss function for training
+train_loss = losses.Contrastive(model=model)
+
+# Set up an evaluator for validation using the contrastive approach (query, positive, negative)
+dev_evaluator = evaluation.ColBERTTripletEvaluator(
+    anchors=eval_dataset["query"],
+    positives=eval_dataset["positive"],
+    negatives=eval_dataset["negative"],
+)
+
+# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
+args = SentenceTransformerTrainingArguments(
+    output_dir=output_dir,
+    num_train_epochs=num_train_epochs,
+    per_device_train_batch_size=batch_size,
+    per_device_eval_batch_size=batch_size,
+    fp16=False,  # Disable FP16 if the GPU does not support it
+    bf16=True,   # Enable BF16 if supported by the GPU
+    eval_strategy="steps",
+    eval_steps=0.1,
+    save_strategy="steps",
+    save_steps=5000,
+    save_total_limit=2,
+    logging_steps=10,
+    report_to="none",  # Set to 'none' to avoid sending data to monitoring services like W&B
+    run_name="msmarco_bm25_contrastive_bert-base-uncased",
+    learning_rate=3e-6,  # Adjust learning rate based on the task
+)
+
+# Initialize the trainer for the contrastive training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+    data_collator=utils.ColBERTCollator(model.tokenize),
+)
+
+# Start the training process
+trainer.train()
+
+```
+
+## Sentence Transformers Training Arguments
+
+The table below lists the arguments for the `SentenceTransformerTrainingArguments` class. Feel free to refer to the [SentenceTransformers](https://sbert.net/docs/sentence_transformer/training_overview.html#) library documentation for more information
+
+=== "Table"
+| Parameter                         | Name                                 | Definition                                                                                                                                                                                                                                                                     | Training Performance |  Observing Performance |
+|------------------------------------|--------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------|------------------------------------------------------------|
+| `output_dir`                       | `str`                                | The output directory where the model predictions and checkpoints will be written.                                                                                                                                                                                               |                                                           |                                                            |
+| `overwrite_output_dir`             | `bool`, *optional*, defaults to `False`| If `True`, overwrite the content of the output directory. Use this to continue training if `output_dir` points to a checkpoint directory.                                                                                                                                      |                                                           |                                                            |
+| `do_train`                         | `bool`, *optional*, defaults to `False`| Whether to run training or not. Intended to be used by your training/evaluation scripts.                                                                                                                                                                                        |                                                           |                                                            |
+| `do_eval`                          | `bool`, *optional*                   | Whether to run evaluation on the validation set. Will be `True` if `eval_strategy` is not `"no"`. Intended to be used by your training/evaluation scripts.                                                                                                                      |                                                           |                                                            |
+| `do_predict`                       | `bool`, *optional*, defaults to `False`| Whether to run predictions on the test set or not. Intended to be used by your training/evaluation scripts.                                                                                                                                                                      |                                                           |                                                            |
+| `eval_strategy`                    | `str` or `~trainer_utils.IntervalStrategy`, *optional*, defaults to `"no"`| The evaluation strategy to adopt during training. Possible values are `"no"`, `"steps"`, or `"epoch"`.                                                                                                                                                                         |                                                           | ✅                                                         |
+| `prediction_loss_only`             | `bool`, *optional*, defaults to `False`| When performing evaluation and generating predictions, only returns the loss.                                                                                                                                                                                                   |                                                           |                                                            |
+| `per_device_train_batch_size`      | `int`, *optional*, defaults to 8      | The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for training.                                                                                                                                                                                                                   | ✅                                                         |                                                            |
+| `per_device_eval_batch_size`       | `int`, *optional*, defaults to 8      | The batch size per GPU/XPU/TPU/MPS/NPU core/CPU for evaluation.                                                                                                                                                                                                                 | ✅                                                         |                                                            |
+| `gradient_accumulation_steps`      | `int`, *optional*, defaults to 1      | Number of updates steps to accumulate gradients before performing a backward/update pass.                                                                                                                                                                                       | ✅                                                         |                                                            |
+| `eval_accumulation_steps`          | `int`, *optional*                    | Number of predictions steps to accumulate the output tensors before moving the results to CPU.                                                                                                                                                                                  | ✅                                                         |                                                            |
+| `eval_delay`                       | `float`, *optional*                  | Number of epochs or steps to wait before the first evaluation depending on `eval_strategy`.                                                                                                                                                                                     |                                                           |                                                            |
+| `torch_empty_cache_steps`          | `int`, *optional*                    | Number of steps to wait before calling `torch.<device>.empty_cache()` to avoid CUDA out-of-memory errors.                                                                                                                                                                       |                                                           |                                                            |
+| `learning_rate`                    | `float`, *optional*, defaults to 5e-5| The initial learning rate for `AdamW` optimizer.                                                                                                                                                                                                                                | ✅                                                         |                                                            |                                                                                                                                                    |                                                           |                                                            |
+| `num_train_epochs`                 | `float`, *optional*, defaults to 3.0 | Total number of training epochs to perform.                                                                                                                                                                                                                                     | ✅                                                         |                                                            |
+| `max_steps`                        | `int`, *optional*, defaults to -1     | If set to a positive number, the total number of training steps to perform. Overrides `num_train_epochs`.                                                                                                                                                                       | ✅                                                         |                                                            |
+| `lr_scheduler_type`                | `str` or `SchedulerType`, *optional*, defaults to `"linear"`| The scheduler type to use.                                                                                                                                                                                                                                                     | ✅                                                         |                                                            |
+| `lr_scheduler_kwargs`              | `dict`, *optional*, defaults to {}    | Extra arguments for the learning rate scheduler.                                                                                                                                                                                                                                |                                                           |                                                            |
+| `warmup_ratio`                     | `float`, *optional*, defaults to 0.0 | Ratio of total training steps used for linear warmup from 0 to `learning_rate`.                                                                                                                                                                                                 | ✅                                                         |                                                            |
+| `warmup_steps`                     | `int`, *optional*, defaults to 0      | Number of steps used for linear warmup from 0 to `learning_rate`. Overrides any effect of `warmup_ratio`.                                                                                                                                                                       |                                                           |                                                            |
+| `log_level`                        | `str`, *optional*, defaults to `passive`| Logger log level to use on the main process.                                                                                                                                                                                                                                    |                                                           | ✅                                                         |
+| `log_level_replica`                | `str`, *optional*, defaults to `"warning"`| Logger log level to use on replicas. Same choices as `log_level`.                                                                                                                                                                                                               |                                                           |                                                            |
+| `log_on_each_node`                 | `bool`, *optional*, defaults to `True`| Whether to log using `log_level` once per node or only on the main node.                                                                                                                                                                                                        |                                                           |                                                            |
+| `logging_dir`                      | `str`, *optional*                    | TensorBoard log directory.                                                                                                                                                                                                                                                     |                                                           |                                                            |
+| `logging_strategy`                 | `str` or `~trainer_utils.IntervalStrategy`, *optional*, defaults to `"steps"`| The logging strategy to adopt during training. Possible values are `"no"`, `"epoch"`, or `"steps"`.                                                                                                                                                                            |                                                           | ✅                                                         |
+| `logging_first_step`               | `bool`, *optional*, defaults to `False`| Whether to log the first `global_step` or not.                                                                                                                                                                                                                                  |                                                           |                                                            |
+| `logging_steps`                    | `int` or `float`, *optional*, defaults to 500| Number of update steps between two logs if `logging_strategy="steps"`.                                                                                                                                                                                                          |                                                           | ✅                                                         |
+| `logging_nan_inf_filter`           | `bool`, *optional*, defaults to `True`| Whether to filter `nan` and `inf` losses for logging.                                                                                                                                                                                                                           |                                                           |                                                            |
+| `save_strategy`                    | `str` or `~trainer_utils.IntervalStrategy`, *optional*, defaults to `"steps"`| The checkpoint save strategy to adopt during training.                                                                                                                                                                                                                          |                                                           | ✅                                                         |
+| `save_steps`                       | `int` or `float`, *optional*, defaults to 500| Number of update steps before two checkpoint saves if `save_strategy="steps"`.                                                                                                                                                                                                  |                                                           | ✅                                                         |
+| `save_total_limit`                 | `int`, *optional*                    | Limit for total number of checkpoints.                                                                                                                                                                                                                                          |                                                           | ✅                                                         |
+| `save_safetensors`                 | `bool`, *optional*, defaults to `True`| Use safetensors saving and loading for state dicts instead of default `torch.load` and `torch.save`.                                                                                                                                                                            |                                                           |                                                            |
+| `save_on_each_node`                | `bool`, *optional*, defaults to `False`| Whether to save models and checkpoints on each node or only on the main one during multi-node distributed training.                                                                                                                                                              |                                                           |                                                            |
+| `seed`                             | `int`, *optional*, defaults to 42     | Random seed set at the beginning of training for reproducibility.                                                                                                                                                                                                               |                                                           |                                                            |
+| `auto_find_batch_size`             | `bool`, *optional*, defaults to `False`| Whether to find a batch size that will fit into memory automatically.                                                                                                                                                                                                           | ✅                                                         |                                                            |
+| `fp16`                             | `bool`, *optional*, defaults to `False`| Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.                                                                                                                                                                                               | ✅                                                         |                                                            |
+| `bf16`                             | `bool`, *optional*, defaults to `False`| Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training.                                                                                                                                                                                               | ✅                                                         |                                                            |
+| `push_to_hub`                      | `bool`, *optional*, defaults to `False`| Whether to push the model to the Hub every time the model is saved.                                                                                                                                                                                                             |                                                           | ✅                                                         |
+| `hub_model_id`                     | `str`, *optional*                    | The name of the repository to keep in sync with the local `output_dir`.                                                                                                                                                                                                         |                                                           | ✅                                                         |
+| `hub_strategy`                     | `str` or `~trainer_utils.HubStrategy`, *optional*, defaults to `"every_save"`| Defines the scope of what is pushed to the Hub and when.                                                                                                                                                                                                                        |                                                           | ✅                                                         |
+| `hub_private_repo`                 | `bool`, *optional*, defaults to `False`| If `True`, the Hub repo will be set to private.                                                                                                                                                                                                                                 |                                                           | ✅                                                         |
+| `load_best_model_at_end`           | `bool`, *optional*, defaults to `False`| Whether or not to load the best model found during training at the end of training.                                                                                                                                                                                             |                                                           | ✅                                                         |
+| `report_to`                        | `str` or `List[str]`, *optional*, defaults to `"all"`| The list of integrations to report the results and logs to.                                                                                                                                                                                                                     |                                                           | ✅                                                         |
+
+
+
+## Sentence Transformer Trainer arguments
+
+=== "Table"
+
+    | Parameter   | Name                                                                                             | Definition                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   |
+    |-------------|--------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+    | model       | `~sentence_transformers.SentenceTransformer`, *optional*                                          | The model to train, evaluate, or use for predictions. If not provided, a `model_init` must be passed.                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                         |
+    | args        | `~sentence_transformers.training_args.SentenceTransformerTrainingArguments`, *optional*           | The arguments to tweak for training. Defaults to a basic instance of `SentenceTransformerTrainingArguments` with the `output_dir` set to a directory named *tmp_trainer* in the current directory if not provided.                                                                                                                                                                                                                                                                                                                                                                      |
+    | train_dataset | `datasets.Dataset`, `datasets.DatasetDict`, or `Dict[str, datasets.Dataset]`, *optional*        | The dataset to use for training. Must have a format accepted by your loss function. Refer to `Training Overview > Dataset Format`.                                                                                                                                                                                                                                                                                                                                                                                                                                                     |
+    | eval_dataset | `datasets.Dataset`, `datasets.DatasetDict`, or `Dict[str, datasets.Dataset]`, *optional*         | The dataset to use for evaluation. Must have a format accepted by your loss function. Refer to `Training Overview > Dataset Format`.                                                                                                                                                                                                                                                                                                                                                                                                                                                   |
+    | loss        | `torch.nn.Module`, `Dict[str, torch.nn.Module]`, Callable, or Dict[str, Callable], *optional*     | The loss function to use for training. It can be a loss class instance, a dictionary mapping dataset names to loss instances, a function returning a loss instance given a model, or a dictionary mapping dataset names to such functions. Defaults to `CoSENTLoss` if not provided.                                                                                                                                                                                                                                                                                                    |
+    | evaluator   | `~sentence_transformers.evaluation.SentenceEvaluator` or `List[~sentence_transformers.evaluation.SentenceEvaluator]`, *optional* | The evaluator instance for useful metrics during training. Can be used with or without an `eval_dataset`. A list of evaluators will be wrapped in a `SequentialEvaluator` to run sequentially. Generally, evaluator metrics are more useful than loss values from `eval_dataset`.                                                                                                                                                                                                                                                                                                       |
+    | callbacks   | `List[transformers.TrainerCallback]`, *optional*                                                  | A list of callbacks to customize the training loop. Adds to the list of default callbacks. To remove a default callback, use the `Trainer.remove_callback` method.                                                                                                                                                                                                                                                                                                                                                                                                                                                            |
+    | optimizers  | `Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)` | A tuple containing the optimizer and scheduler to use. Defaults to an instance of `torch.optim.AdamW` for the model and a scheduler given by `transformers.get_linear_schedule_with_warmup`, controlled by `args`.                                                                                                                                                                                                                                                                                                                                                                      |
+
+
diff --git a/docs/img/favicon.png b/docs/img/favicon.png
new file mode 100644
index 0000000..cd74773
Binary files /dev/null and b/docs/img/favicon.png differ
diff --git a/docs/img/logo.png b/docs/img/logo.png
new file mode 100644
index 0000000..9b19354
Binary files /dev/null and b/docs/img/logo.png differ
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..df1faea
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,250 @@
+<div align="center">
+  <h1>PyLate</h1>
+  <p>Efficient training and retrieval with ColBERT</p>
+</div>
+
+<p align="center"><img width=500 src="img/logo.png"/></p>
+
+<div align="center">
+  <!-- Documentation -->
+  <a href="https://github.com/lightonai/pylate"><img src="https://img.shields.io/badge/Documentation-purple.svg?style=flat-square" alt="documentation"></a>
+  <!-- License -->
+  <a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="license"></a>
+</div>
+
+
+
+PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize training, inference, and retrieval using ColBERT models. With PyLate, you can efficiently train ColBERT models on Triplet loss or Knowledge Distillation and deploy them for document retrieval tasks with ease.
+
+## Installation
+
+We can install pylate using:
+
+```bash
+pip install pylate
+```
+
+Install with evaluation dependencies:
+
+```bash
+pip install "pylate[eval]"
+```
+
+## Documentation 
+
+The complete documentation is available [here](https://lightonai.github.io/pylate/), which includes in-depth guides, examples, and API references.
+
+## Datasets
+
+PyLate supports Hugging Face [Datasets](https://huggingface.co/docs/datasets/en/index), enabling seamless triplet / knowledge distillation based training. Below is an example of creating a custom dataset for training:
+
+```python
+from datasets import Dataset
+
+dataset = [
+    {
+        "query": "example query 1",
+        "positive": "example positive document 1",
+        "negative": "example negative document 1",
+    },
+    {
+        "query": "example query 2",
+        "positive": "example positive document 2",
+        "negative": "example negative document 2",
+    },
+    {
+        "query": "example query 3",
+        "positive": "example positive document 3",
+        "negative": "example negative document 3",
+    },
+]
+
+dataset = Dataset.from_list(mapping=dataset)
+
+train_dataset, test_dataset = dataset.train_test_split(test_size=0.3)
+```
+
+## Training
+
+Here’s a simple example of training a ColBERT model on the MSMARCO dataset using PyLate. This script demonstrates training with triplet loss and evaluating the model on a test set.
+
+```python
+from datasets import load_dataset
+from sentence_transformers import (
+    SentenceTransformerTrainer,
+    SentenceTransformerTrainingArguments,
+)
+from sentence_transformers.training_args import BatchSamplers
+
+from pylate import evaluation, losses, models, utils
+
+# Define the model
+model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
+
+# Load dataset
+dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
+
+# Split the dataset to create a test set
+train_dataset, eval_dataset = dataset.train_test_split(test_size=0.01)
+
+# Shuffle and select a subset of the dataset for demonstration purposes
+MAX_TRAIN_SIZE, MAX_EVAL_SIZE = 100, 100
+train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_TRAIN_SIZE))
+eval_dataset = eval_dataset.shuffle(seed=21).select(range(MAX_EVAL_SIZE))
+
+# Define the loss function
+train_loss = losses.Contrastive(model=model)
+
+args = SentenceTransformerTrainingArguments(
+    output_dir="colbert-training",
+    num_train_epochs=1,
+    per_device_train_batch_size=32,
+    per_device_eval_batch_size=32,
+    fp16=False,  # Some GPUs support FP16 which is faster than FP32
+    bf16=False,  # Some GPUs support BF16 which is a faster FP16
+    batch_sampler=BatchSamplers.NO_DUPLICATES,
+    # Tracking parameters:
+    eval_strategy="steps",
+    eval_steps=0.1,
+    save_strategy="steps",
+    save_steps=5000,
+    save_total_limit=2,
+    learning_rate=3e-6,
+)
+
+# Evaluation procedure
+dev_evaluator = evaluation.ColBERTTripletEvaluator(
+    anchors=eval_dataset["query"],
+    positives=eval_dataset["positive"],
+    negatives=eval_dataset["negative"],
+)
+
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+    data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
+)
+
+trainer.train()
+
+model.save_pretrained("custom-colbert-model")
+```
+
+After training, the model can be loaded like this:
+
+```python
+from pylate import models
+
+model = models.ColBERT(model_name_or_path="custom-colbert-model")
+```
+
+##  Retrieve
+
+PyLate allows easy retrieval of top documents for a given query set using the trained ColBERT model and Voyager index.
+
+```python
+from pylate import indexes, models, retrieve
+
+model = models.ColBERT(
+    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
+)
+
+index = indexes.Voyager(
+    index_folder="pylate-index",
+    index_name="index",
+    override=True,
+)
+
+retriever = retrieve.ColBERT(index=index)
+```
+
+Once the model and index are set up, we can add documents to the index:
+
+```python
+documents_ids = ["1", "2", "3"]
+
+documents = [
+    "document 1 text", "document 2 text", "document 3 text"
+]
+
+# Encode the documents
+documents_embeddings = model.encode(
+    documents,
+    batch_size=32,
+    is_query=False, # Encoding documents
+    show_progress_bar=True,
+)
+
+# Add the documents ids and embeddings to the Voyager index
+index.add_documents(
+    documents_ids=documents_ids,
+    documents_embeddings=documents_embeddings,
+)
+```
+
+Then we can retrieve the top-k documents for a given query set:
+
+```python
+queries_embeddings = model.encode(
+    ["query for document 3", "query for document 1"],
+    batch_size=32,
+    is_query=True, # Encoding queries
+    show_progress_bar=True,
+)
+
+scores = retriever.retrieve(
+    queries_embeddings=queries_embeddings, 
+    k=10,
+)
+
+print(scores)
+```
+
+Sample Output:
+
+```python
+[
+    [
+        {"id": "3", "score": 11.266985893249512},
+        {"id": "1", "score": 10.303335189819336},
+        {"id": "2", "score": 9.502392768859863},
+    ],
+    [
+        {"id": "1", "score": 10.88800048828125},
+        {"id": "3", "score": 9.950843811035156},
+        {"id": "2", "score": 9.602447509765625},
+    ],
+]
+```
+
+## Contributing
+
+We welcome contributions! To get started:
+
+1. Install the development dependencies:
+
+```bash
+pip install "pylate[dev]"
+```
+
+2. Run tests:
+
+```bash
+make test
+```
+
+3. Format code with Ruff:
+
+```bash
+make ruff
+```
+
+4. Build the documentation:
+
+```bash
+make livedoc
+```
\ No newline at end of file
diff --git a/docs/javascripts/config.js b/docs/javascripts/config.js
new file mode 100644
index 0000000..a80ddbf
--- /dev/null
+++ b/docs/javascripts/config.js
@@ -0,0 +1,16 @@
+window.MathJax = {
+    tex: {
+        inlineMath: [["\\(", "\\)"]],
+        displayMath: [["\\[", "\\]"]],
+        processEscapes: true,
+        processEnvironments: true
+    },
+    options: {
+        ignoreHtmlClass: ".*|",
+        processHtmlClass: "arithmatex"
+    }
+};
+
+document$.subscribe(() => {
+    MathJax.typesetPromise()
+})
\ No newline at end of file
diff --git a/docs/javascripts/tablesort.js b/docs/javascripts/tablesort.js
new file mode 100644
index 0000000..ee04e90
--- /dev/null
+++ b/docs/javascripts/tablesort.js
@@ -0,0 +1,6 @@
+document$.subscribe(function () {
+    var tables = document.querySelectorAll("article table:not([class])")
+    tables.forEach(function (table) {
+        new Tablesort(table)
+    })
+})
\ No newline at end of file
diff --git a/docs/js/version-select.js b/docs/js/version-select.js
new file mode 100644
index 0000000..127ab63
--- /dev/null
+++ b/docs/js/version-select.js
@@ -0,0 +1,49 @@
+window.addEventListener("DOMContentLoaded", function () {
+    // This is a bit hacky. Figure out the base URL from a known CSS file the
+    // template refers to...
+    var ex = new RegExp("/?css/version-select.css$");
+    var sheet = document.querySelector('link[href$="version-select.css"]');
+
+    var ABS_BASE_URL = sheet.href.replace(ex, "");
+    var CURRENT_VERSION = ABS_BASE_URL.split("/").pop();
+
+    function makeSelect(options, selected) {
+        var select = document.createElement("select");
+        select.classList.add("form-control");
+
+        options.forEach(function (i) {
+            var option = new Option(i.text, i.value, undefined,
+                i.value === selected);
+            select.add(option);
+        });
+
+        return select;
+    }
+
+    var xhr = new XMLHttpRequest();
+    xhr.open("GET", ABS_BASE_URL + "/../versions.json");
+    xhr.onload = function () {
+        var versions = JSON.parse(this.responseText);
+
+        var realVersion = versions.find(function (i) {
+            return i.version === CURRENT_VERSION ||
+                i.aliases.includes(CURRENT_VERSION);
+        }).version;
+
+        var select = makeSelect(versions.map(function (i) {
+            return { text: i.title, value: i.version };
+        }), realVersion);
+        select.addEventListener("change", function (event) {
+            window.location.href = ABS_BASE_URL + "/../" + this.value;
+        });
+
+        var container = document.createElement("div");
+        container.id = "version-selector";
+        container.className = "md-nav__item";
+        container.appendChild(select);
+
+        var sidebar = document.querySelector(".md-nav--primary > .md-nav__list");
+        sidebar.parentNode.insertBefore(container, sidebar);
+    };
+    xhr.send();
+});
\ No newline at end of file
diff --git a/docs/parse/__main__.py b/docs/parse/__main__.py
new file mode 100644
index 0000000..6615a3a
--- /dev/null
+++ b/docs/parse/__main__.py
@@ -0,0 +1,474 @@
+"""This script is responsible for building the API reference. The API reference is located in
+docs/api. The script scans through all the modules, classes, and functions. It processes
+the __doc__ of each object and formats it so that MkDocs can process it in turn.
+"""
+
+import functools
+import importlib
+import inspect
+import os
+import pathlib
+import re
+import shutil
+
+from numpydoc.docscrape import ClassDoc, FunctionDoc
+
+package = "pylate"
+
+shutil.copy("README.md", "docs/index.md")
+
+
+with open("docs/index.md", mode="r") as file:
+    content = file.read()
+
+with open("docs/index.md", mode="w") as file:
+    file.write(content.replace("docs/img/logo.png", "img/logo.png"))
+
+
+def paragraph(text):
+    return f"{text}\n"
+
+
+def h1(text):
+    return paragraph(f"# {text}")
+
+
+def h2(text):
+    return paragraph(f"## {text}")
+
+
+def h3(text):
+    return paragraph(f"### {text}")
+
+
+def h4(text):
+    return paragraph(f"#### {text}")
+
+
+def link(caption, href):
+    return f"[{caption}]({href})"
+
+
+def code(text):
+    return f"`{text}`"
+
+
+def li(text):
+    return f"- {text}\n"
+
+
+def snake_to_kebab(text):
+    return text.replace("_", "-")
+
+
+def inherit_docstring(c, meth):
+    """Since Python 3.5, inspect.getdoc is supposed to return the docstring from a parent class
+    if a class has none. However this doesn't seem to work for Cython classes.
+    """
+
+    doc = None
+
+    for ancestor in inspect.getmro(c):
+        try:
+            ancestor_meth = getattr(ancestor, meth)
+        except AttributeError:
+            break
+        doc = inspect.getdoc(ancestor_meth)
+        if doc:
+            break
+
+    return doc
+
+
+def inherit_signature(c, method_name):
+    m = getattr(c, method_name)
+    sig = inspect.signature(m)
+
+    params = []
+
+    for param in sig.parameters.values():
+        if param.name == "self" or param.annotation is not param.empty:
+            params.append(param)
+            continue
+
+        for ancestor in inspect.getmro(c):
+            try:
+                ancestor_meth = inspect.signature(getattr(ancestor, m.__name__))
+            except AttributeError:
+                break
+            try:
+                ancestor_param = ancestor_meth.parameters[param.name]
+            except KeyError:
+                break
+            if ancestor_param.annotation is not param.empty:
+                param = param.replace(annotation=ancestor_param.annotation)
+                break
+
+        params.append(param)
+
+    return_annotation = sig.return_annotation
+    if return_annotation is inspect._empty:
+        for ancestor in inspect.getmro(c):
+            try:
+                ancestor_meth = inspect.signature(getattr(ancestor, m.__name__))
+            except AttributeError:
+                break
+            if ancestor_meth.return_annotation is not inspect._empty:
+                return_annotation = ancestor_meth.return_annotation
+                break
+
+    return sig.replace(parameters=params, return_annotation=return_annotation)
+
+
+def pascal_to_kebab(string):
+    string = re.sub("(.)([A-Z][a-z]+)", r"\1-\2", string)
+    string = re.sub("(.)([0-9]+)", r"\1-\2", string)
+    return re.sub("([a-z0-9])([A-Z])", r"\1-\2", string).lower()
+
+
+class Linkifier:
+    def __init__(self):
+        path_index = {}
+        name_index = {}
+
+        modules = {
+            module: importlib.import_module(f"{package}.{module}")
+            for module in importlib.import_module(f"{package}").__all__
+        }
+
+        def index_module(mod_name, mod, path):
+            path = os.path.join(path, mod_name)
+            dotted_path = path.replace("/", ".")
+
+            for func_name, func in inspect.getmembers(mod, inspect.isfunction):
+                for e in (
+                    f"{mod_name}.{func_name}",
+                    f"{dotted_path}.{func_name}",
+                    f"{func.__module__}.{func_name}",
+                ):
+                    path_index[e] = os.path.join(path, snake_to_kebab(func_name))
+                    name_index[e] = f"{dotted_path}.{func_name}"
+
+            for klass_name, klass in inspect.getmembers(mod, inspect.isclass):
+                for e in (
+                    f"{mod_name}.{klass_name}",
+                    f"{dotted_path}.{klass_name}",
+                    f"{klass.__module__}.{klass_name}",
+                ):
+                    path_index[e] = os.path.join(path, klass_name)
+                    name_index[e] = f"{dotted_path}.{klass_name}"
+
+            for submod_name, submod in inspect.getmembers(mod, inspect.ismodule):
+                if submod_name not in mod.__all__ or submod_name == "typing":
+                    continue
+                for e in (f"{mod_name}.{submod_name}", f"{dotted_path}.{submod_name}"):
+                    path_index[e] = os.path.join(path, snake_to_kebab(submod_name))
+
+                # Recurse
+                index_module(submod_name, submod, path=path)
+
+        for mod_name, mod in modules.items():
+            index_module(mod_name, mod, path="")
+
+        # Prepend {package} to each index entry
+        for k in list(path_index.keys()):
+            path_index[f"{package}.{k}"] = path_index[k]
+        for k in list(name_index.keys()):
+            name_index[f"{package}.{k}"] = name_index[k]
+
+        self.path_index = path_index
+        self.name_index = name_index
+
+    def linkify(self, text, use_fences, depth):
+        path = self.path_index.get(text)
+        name = self.name_index.get(text)
+        if path and name:
+            backwards = "../" * (depth + 1)
+            if use_fences:
+                return f"[`{name}`]({backwards}{path})"
+            return f"[{name}]({backwards}{path})"
+        return None
+
+    def linkify_fences(self, text, depth):
+        between_fences = re.compile("`[\w\.]+\.\w+`")
+        return between_fences.sub(
+            lambda x: self.linkify(x.group().strip("`"), True, depth) or x.group(), text
+        )
+
+    def linkify_dotted(self, text, depth):
+        dotted = re.compile("\w+\.[\.\w]+")
+        return dotted.sub(
+            lambda x: self.linkify(x.group(), False, depth) or x.group(), text
+        )
+
+
+def concat_lines(lines):
+    return inspect.cleandoc(" ".join("\n\n" if line == "" else line for line in lines))
+
+
+def print_docstring(obj, file, depth):
+    """Prints a classes's docstring to a file."""
+
+    doc = ClassDoc(obj) if inspect.isclass(obj) else FunctionDoc(obj)
+
+    printf = functools.partial(print, file=file)
+
+    printf(h1(obj.__name__))
+    printf(linkifier.linkify_fences(paragraph(concat_lines(doc["Summary"])), depth))
+    printf(
+        linkifier.linkify_fences(
+            paragraph(concat_lines(doc["Extended Summary"])), depth
+        )
+    )
+
+    # We infer the type annotations from the signatures, and therefore rely on the signature
+    # instead of the docstring for documenting parameters
+    try:
+        signature = inspect.signature(obj)
+    except ValueError:
+        signature = (
+            inspect.Signature()
+        )  # TODO: this is necessary for Cython classes, but it's not correct
+    params_desc = {param.name: " ".join(param.desc) for param in doc["Parameters"]}
+
+    # Parameters
+    if signature.parameters:
+        printf(h2("Parameters"))
+    for param in signature.parameters.values():
+        # Name
+        printf(f"- **{param.name}**", end="")
+        # Type annotation
+        if param.annotation is not param.empty:
+            anno = inspect.formatannotation(param.annotation)
+            anno = linkifier.linkify_dotted(anno, depth)
+            printf(f" (*{anno}*)", end="")
+        # Default value
+        if param.default is not param.empty:
+            printf(f" – defaults to `{param.default}`", end="")
+        printf("\n", file=file)
+        # Description
+        if param.name in params_desc:
+            desc = params_desc[param.name]
+            if desc:
+                printf(f"    {desc}\n")
+    printf("")
+
+    # Attributes
+    if doc["Attributes"]:
+        printf(h2("Attributes"))
+    for attr in doc["Attributes"]:
+        # Name
+        printf(f"- **{attr.name}**", end="")
+        # Type annotation
+        if attr.type:
+            printf(f" (*{attr.type}*)", end="")
+        printf("\n", file=file)
+        # Description
+        desc = " ".join(attr.desc)
+        if desc:
+            printf(f"    {desc}\n")
+    printf("")
+
+    # Examples
+    if doc["Examples"]:
+        printf(h2("Examples"))
+
+        in_code = False
+        after_space = False
+
+        for line in inspect.cleandoc("\n".join(doc["Examples"])).splitlines():
+            if (
+                in_code
+                and after_space
+                and line
+                and not line.startswith(">>>")
+                and not line.startswith("...")
+            ):
+                printf("```\n")
+                in_code = False
+                after_space = False
+
+            if not in_code and line.startswith(">>>"):
+                printf("```python")
+                in_code = True
+
+            after_space = False
+            if not line:
+                after_space = True
+
+            printf(line)
+
+        if in_code:
+            printf("```")
+    printf("")
+
+    # Methods
+    if inspect.isclass(obj) and doc["Methods"]:
+        printf(h2("Methods"))
+        printf_indent = lambda x, **kwargs: printf(f"    {x}", **kwargs)  # noqa: E731
+
+        for meth in doc["Methods"]:
+            printf(paragraph(f'???- note "{meth.name}"'))
+
+            # Parse method docstring
+            docstring = inherit_docstring(c=obj, meth=meth.name)
+            if not docstring:
+                continue
+            meth_doc = FunctionDoc(func=None, doc=docstring)
+
+            printf_indent(paragraph(" ".join(meth_doc["Summary"])))
+            if meth_doc["Extended Summary"]:
+                printf_indent(paragraph(" ".join(meth_doc["Extended Summary"])))
+
+            # We infer the type annotations from the signatures, and therefore rely on the signature
+            # instead of the docstring for documenting parameters
+            signature = inherit_signature(obj, meth.name)
+            params_desc = {
+                param.name: " ".join(param.desc) for param in doc["Parameters"]
+            }
+
+            # Parameters
+            if (
+                len(signature.parameters) > 1
+            ):  # signature is never empty, but self doesn't count
+                printf_indent("**Parameters**\n")
+            for param in signature.parameters.values():
+                if param.name == "self":
+                    continue
+                # Name
+                printf_indent(f"- **{param.name}**", end="")
+                # Type annotation
+                if param.annotation is not param.empty:
+                    printf_indent(
+                        f" (*{inspect.formatannotation(param.annotation)}*)", end=""
+                    )
+                # Default value
+                if param.default is not param.empty:
+                    printf_indent(f" – defaults to `{param.default}`", end="")
+                printf_indent("", file=file)
+                # Description
+                desc = params_desc.get(param.name)
+                if desc:
+                    printf_indent(f"    {desc}")
+            printf_indent("")
+
+            # Returns
+            if meth_doc["Returns"]:
+                printf_indent("**Returns**\n")
+                return_val = meth_doc["Returns"][0]
+                if signature.return_annotation is not inspect._empty:
+                    if inspect.isclass(signature.return_annotation):
+                        printf_indent(
+                            f"*{signature.return_annotation.__name__}*: ", end=""
+                        )
+                    else:
+                        printf_indent(f"*{signature.return_annotation}*: ", end="")
+                printf_indent(return_val.type)
+                printf_indent("")
+
+    # Notes
+    if doc["Notes"]:
+        printf(h2("Notes"))
+        printf(paragraph("\n".join(doc["Notes"])))
+
+    # References
+    if doc["References"]:
+        printf(h2("References"))
+        printf(paragraph("\n".join(doc["References"])))
+
+
+def print_module(mod, path, overview, is_submodule=False):
+    mod_name = mod.__name__.split(".")[-1]
+
+    # Create a directory for the module
+    mod_slug = snake_to_kebab(mod_name)
+    mod_path = path.joinpath(mod_slug)
+    mod_short_path = str(mod_path).replace("docs/api/", "")
+    os.makedirs(mod_path, exist_ok=True)
+    with open(mod_path.joinpath(".pages"), "w") as f:
+        f.write(f"title: {mod_name}")
+
+    # Add the module to the overview
+    if is_submodule:
+        print(h3(mod_name), file=overview)
+    else:
+        print(h2(mod_name), file=overview)
+    if mod.__doc__:
+        print(paragraph(mod.__doc__), file=overview)
+
+    # Extract all public classes and functions
+    ispublic = lambda x: x.__name__ in mod.__all__ and not x.__name__.startswith("_")  # noqa: E731
+    classes = inspect.getmembers(mod, lambda x: inspect.isclass(x) and ispublic(x))
+    funcs = inspect.getmembers(mod, lambda x: inspect.isfunction(x) and ispublic(x))
+
+    # Classes
+
+    if classes and funcs:
+        print("\n**Classes**\n", file=overview)
+
+    for _, c in classes:
+        print(f"{mod_name}.{c.__name__}")
+
+        # Add the class to the overview
+        slug = snake_to_kebab(c.__name__)
+        print(
+            li(link(c.__name__, f"../{mod_short_path}/{slug}")), end="", file=overview
+        )
+
+        # Write down the class' docstring
+        with open(mod_path.joinpath(slug).with_suffix(".md"), "w") as file:
+            print_docstring(obj=c, file=file, depth=mod_short_path.count("/") + 1)
+
+    # Functions
+
+    if classes and funcs:
+        print("\n**Functions**\n", file=overview)
+
+    for _, f in funcs:
+        print(f"{mod_name}.{f.__name__}")
+
+        # Add the function to the overview
+        slug = snake_to_kebab(f.__name__)
+        print(
+            li(link(f.__name__, f"../{mod_short_path}/{slug}")), end="", file=overview
+        )
+
+        # Write down the function' docstring
+        with open(mod_path.joinpath(slug).with_suffix(".md"), "w") as file:
+            print_docstring(obj=f, file=file, depth=mod_short_path.count(".") + 1)
+
+    # Sub-modules
+    for name, submod in inspect.getmembers(mod, inspect.ismodule):
+        # We only want to go through the public submodules, such as optim.schedulers
+        if (
+            name in ("tags", "typing", "inspect", "skmultiflow_utils")
+            or name not in mod.__all__
+            or name.startswith("_")
+        ):
+            continue
+        print_module(mod=submod, path=mod_path, overview=overview, is_submodule=True)
+
+    print("", file=overview)
+
+
+if __name__ == "__main__":
+    api_path = pathlib.Path("docs/api")
+
+    # Create a directory for the API reference
+    shutil.rmtree(api_path, ignore_errors=True)
+    os.makedirs(api_path, exist_ok=True)
+    with open(api_path.joinpath(".pages"), "w") as f:
+        f.write("title: API reference\narrange:\n  - overview.md\n  - ...\n")
+
+    overview = open(api_path.joinpath("overview.md"), "w")
+    print(h1("Overview"), file=overview)
+
+    linkifier = Linkifier()
+
+    for mod_name, mod in inspect.getmembers(
+        importlib.import_module(f"{package}"), inspect.ismodule
+    ):
+        if mod_name.startswith("_"):
+            continue
+        print(mod_name)
+        print_module(mod, path=api_path, overview=overview)
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
new file mode 100644
index 0000000..95ec558
--- /dev/null
+++ b/docs/stylesheets/extra.css
@@ -0,0 +1,13 @@
+.md-typeset h2 {
+    margin: 1.5em 0;
+    padding-bottom: .4rem;
+    border-bottom: .04rem solid var(--md-default-fg-color--lighter);
+}
+
+.md-footer {
+    margin-top: 2em;
+}
+
+.md-typeset pre>code {
+    border-radius: 0.5em;
+}
\ No newline at end of file
diff --git a/evaluation/beir.py b/evaluation/scifact.py
similarity index 100%
rename from evaluation/beir.py
rename to evaluation/scifact.py
diff --git a/mkdocs.yml b/mkdocs.yml
new file mode 100644
index 0000000..d6d5297
--- /dev/null
+++ b/mkdocs.yml
@@ -0,0 +1,93 @@
+# Project information
+site_name: pylate
+site_description: Neural Search 
+site_author: Raphael Sourty
+site_url: https://lightonai.github.io/pylate
+
+# Repository
+repo_name: lighton/pylate
+repo_url: https://github.com/lightonai/pylate
+edit_uri: ""
+
+# Copyright
+copyright: Copyright &copy; 2023
+
+# Configuration
+theme:
+  name: material
+  custom_dir: docs
+  language: en
+
+  palette:
+    - scheme: default
+      primary: blue
+      accent: blue
+      toggle:
+        icon: material/brightness-7
+        name: Switch to dark mode
+    - scheme: slate
+      primary: blue
+      accent: blue
+      toggle:
+        icon: material/brightness-4
+        name: Switch to light mode
+  
+  font:
+    text: Fira Sans
+    code: Fira Code
+  logo: img/favicon.png
+  favicon: img/favicon.ico
+  features:
+    - content.code.copy
+    - navigation.tabs
+    - navigation.instant
+    - navigation.indexes
+    - navigation.prune
+
+# Extras
+extra:
+    social:
+      - icon: fontawesome/brands/github-alt
+        link: https://github.com/lightonai/pylate
+
+# Extensions
+markdown_extensions:
+  - admonition
+  - footnotes
+  - tables
+  - toc:
+      permalink: true
+      toc_depth: "1-3"
+  - pymdownx.details
+  - pymdownx.arithmatex:
+      generic: true
+  - pymdownx.highlight:
+      pygments_lang_class: true
+  - pymdownx.inlinehilite
+  - pymdownx.tabbed:
+      alternate_style: true
+  - pymdownx.superfences:
+      custom_fences:
+        - name: vegalite
+          class: vegalite
+          format: !!python/name:mkdocs_charts_plugin.fences.fence_vegalite
+
+
+plugins:
+  - search
+  - awesome-pages
+  - mkdocs-jupyter
+
+extra_javascript:
+  - javascripts/config.js
+  - https://cdn.jsdelivr.net/npm/mathjax@3.2/es5/tex-mml-chtml.js
+  - https://cdn.jsdelivr.net/npm/vega@5
+  - https://cdn.jsdelivr.net/npm/vega-lite@5
+  - https://cdn.jsdelivr.net/npm/vega-embed@6
+  - https://unpkg.com/tablesort@5.3.0/dist/tablesort.min.js
+  - javascripts/tablesort.js
+
+extra_css:
+  - stylesheets/extra.css
+  - css/version-select.css
+
diff --git a/pylate/evaluation/beir.py b/pylate/evaluation/beir.py
index baafcc4..2392eab 100644
--- a/pylate/evaluation/beir.py
+++ b/pylate/evaluation/beir.py
@@ -166,6 +166,29 @@ def evaluate(
     metrics
         Metrics to compute.
 
+    Examples
+    --------
+    >>> from pylate import evaluation
+
+    >>> scores = [
+    ...     [{"id": "1", "score": 0.9}, {"id": "2", "score": 0.8}],
+    ...     [{"id": "3", "score": 0.7}, {"id": "4", "score": 0.6}],
+    ... ]
+
+    >>> qrels = {
+    ...     "query1": {"1": True, "2": True},
+    ...     "query2": {"3": True, "4": True},
+    ... }
+
+    >>> queries = ["query1", "query2"]
+
+    >>> results = evaluation.evaluate(
+    ...     scores=scores,
+    ...     qrels=qrels,
+    ...     queries=queries,
+    ...     metrics=["ndcg@10", "hits@1"],
+    ... )
+
     """
     from ranx import Qrels, Run, evaluate
 
diff --git a/pylate/scores/__init__.py b/pylate/scores/__init__.py
index f29e636..e80a640 100644
--- a/pylate/scores/__init__.py
+++ b/pylate/scores/__init__.py
@@ -1,3 +1,3 @@
-from .scores import colbert_scores, colbert_scores_pairwise, colbert_kd_scores
+from .scores import colbert_kd_scores, colbert_scores, colbert_scores_pairwise
 
-__all__ = ["colbert_scores", "colbert_scores_pairwise", "colbert_kd_scores"]
\ No newline at end of file
+__all__ = ["colbert_scores", "colbert_scores_pairwise", "colbert_kd_scores"]
diff --git a/pylate/utils/iter_batch.py b/pylate/utils/iter_batch.py
index f0fbfd0..37e737b 100644
--- a/pylate/utils/iter_batch.py
+++ b/pylate/utils/iter_batch.py
@@ -7,7 +7,7 @@ def iter_batch(
     """Iterate over a list of elements by batch.
 
     Examples
-    -------
+    --------
     >>> from pylate import utils
 
     >>> X = [
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..224a779
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,2 @@
+[metadata]
+description-file = README.md
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 86f6c05..af43791 100644
--- a/setup.py
+++ b/setup.py
@@ -11,10 +11,21 @@
     "accelerate >= 0.31.0",
     "voyager >= 2.0.9",
     "sqlitedict >= 2.1.0",
+    "pandas >= 2.2.1",
 ]
 
 
-dev = ["ruff >= 0.4.9", "pytest-cov >= 5.0.0", "pytest >= 8.2.1", "pandas >= 2.2.1"]
+dev = [
+    "ruff >= 0.4.9",
+    "pytest-cov >= 5.0.0",
+    "pytest >= 8.2.1",
+    "pandas >= 2.2.1",
+    "mkdocs-material == 9.5.32",
+    "mkdocs-awesome-pages-plugin == 2.9.3",
+    "mkdocs-jupyter == 0.24.8",
+    "mkdocs_charts_plugin == 0.0.10",
+    "numpydoc == 1.8.0",
+]
 
 eval = ["ranx >= 0.3.16", "beir >= 2.0.0"]