Skip to content

Commit

Permalink
Merge pull request #73 from lightonai/fix_loading
Browse files Browse the repository at this point in the history
Add safetensor OR bin loading logic + add loading tests
  • Loading branch information
raphaelsty authored Nov 29, 2024
2 parents 0af6c4c + 6399410 commit d12825a
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 19 deletions.
62 changes: 43 additions & 19 deletions pylate/models/Dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os

import torch
from safetensors import safe_open
from safetensors.torch import load_model as load_safetensors_model
from sentence_transformers.models import Dense as DenseSentenceTransformer
from sentence_transformers.util import import_from_string
Expand Down Expand Up @@ -110,26 +111,49 @@ def from_stanford_weights(
"""
# Check if the model is locally available
if not (os.path.exists(os.path.join(model_name_or_path))):
# Else download the model/use the cached version
model_name_or_path = cached_file(
model_name_or_path,
filename="pytorch_model.bin",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
# If the model a local folder, load the PyTorch model
# Else download the model/use the cached version. We first try to use the safetensors version and fall back to bin if not existing. All the recent stanford-nlp models are safetensors but we keep bin for compatibility.
try:
model_name_or_path = cached_file(
model_name_or_path,
filename="model.safetensors",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
except EnvironmentError:
print("No safetensor model found, falling back to bin.")
model_name_or_path = cached_file(
model_name_or_path,
filename="pytorch_model.bin",
cache_dir=cache_folder,
revision=revision,
local_files_only=local_files_only,
token=token,
use_auth_token=use_auth_token,
)
# If the model is a local folder, load the weights directly from it.
# Again, we first try the safetensors file and fall back to the bin file if it does not exist.
else:
if os.path.exists(os.path.join(model_name_or_path, "model.safetensors")):
model_name_or_path = os.path.join(
model_name_or_path, "model.safetensors"
)
else:
print("No safetensor model found, falling back to bin.")
model_name_or_path = os.path.join(
model_name_or_path, "pytorch_model.bin"
)
if model_name_or_path.endswith("safetensors"):
with safe_open(model_name_or_path, framework="pt", device="cpu") as f:
state_dict = {"linear.weight": f.get_tensor("linear.weight")}
else:
model_name_or_path = os.path.join(model_name_or_path, "pytorch_model.bin")

# Load the state dict using torch.load instead of safe_open
state_dict = {
"linear.weight": torch.load(model_name_or_path, map_location="cpu")[
"linear.weight"
]
}
state_dict = {
"linear.weight": torch.load(model_name_or_path, map_location="cpu")[
"linear.weight"
]
}

# Determine input and output dimensions
in_features = state_dict["linear.weight"].shape[1]
Expand Down
90 changes: 90 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import math

import torch

from pylate import models, rank


def _assert_rerank_scores(
    model_name_or_path: str,
    query: list,
    documents: list,
    expected_scores: tuple,
) -> None:
    """Load a ColBERT model and check the scores of the reranked documents.

    The model is built with ``models.ColBERT``, the query and the documents
    are encoded, and ``rank.rerank`` is run on the two candidate documents
    (ids ``"1"`` and ``"2"``). The score found at each position of the first
    (and only) query's ranking is compared to ``expected_scores``.

    Args:
        model_name_or_path: Identifier passed to ``models.ColBERT``.
        query: Query sentences, encoded with ``is_query=True``.
        documents: Per-query lists of documents, encoded with
            ``is_query=False``.
        expected_scores: Expected score at rank 0 and rank 1, in that order.

    Raises:
        AssertionError: If a score is not close to its expected value
            (``rel_tol=0.01``, ``abs_tol=0.01``).
    """
    model = models.ColBERT(model_name_or_path=model_name_or_path)
    queries_embeddings = model.encode(sentences=query, is_query=True)
    documents_embeddings = model.encode(sentences=documents, is_query=False)
    reranked_documents = rank.rerank(
        documents_ids=[["1", "2"]],
        queries_embeddings=queries_embeddings,
        documents_embeddings=documents_embeddings,
    )

    for position, expected in enumerate(expected_scores):
        score = reranked_documents[0][position]["score"]
        # The message names the checkpoint and position so that a failure is
        # attributable without re-running the test under a debugger.
        assert math.isclose(score, expected, rel_tol=0.01, abs_tol=0.01), (
            f"{model_name_or_path}: score at rank {position} is {score}, "
            f"expected about {expected}"
        )


def test_model_creation(**kwargs) -> None:
    """Test the creation of different models.

    Covers creation from a base encoder, from a base sentence-transformer,
    from stanford-nlp style checkpoints (safetensors and bin weights) and
    from a PyLate checkpoint. ``kwargs`` is accepted for backward
    compatibility and is not used.
    """
    query = ["fruits are healthy."]
    documents = [["fruits are healthy.", "fruits are good for health."]]
    torch.manual_seed(42)

    # Creation from a base encoder and from a base sentence-transformer.
    # Only the construction is exercised here.
    # TODO: also check the reranking scores once the initialisation of newly
    # created models is deterministic (the disabled assertions expected about
    # 25.92 / 23.7 for bert-base-uncased and 18.77 / 18.63 for all-MiniLM-L6-v2).
    models.ColBERT(model_name_or_path="bert-base-uncased")
    models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")

    # Pre-trained checkpoints: (model identifier, expected scores at rank 0 and 1).
    scored_checkpoints = [
        # Creation from stanford-nlp (safetensor)
        ("answerdotai/answerai-colbert-small-v1", (31.71, 31.64)),
        # Creation from stanford-nlp (bin)
        ("Crystalcareai/Colbertv2", (31.15, 30.61)),
        # Creation from PyLate
        ("lightonai/colbertv2.0", (30.01, 26.98)),
    ]
    for model_name_or_path, expected_scores in scored_checkpoints:
        _assert_rerank_scores(
            model_name_or_path=model_name_or_path,
            query=query,
            documents=documents,
            expected_scores=expected_scores,
        )

0 comments on commit d12825a

Please sign in to comment.