Skip to content

Commit

Permalink
Refine mine_hard_negatives arguments (#2977)
Browse files Browse the repository at this point in the history
* refine mine_hard_negatives arguments to support column names + min_score to accept negative samples

* check if the tensor is empty before applying the reduction functions

* rename `query_column_name` to `anchor_column_name` for `mine_hard_negatives` function

* append pre-commit refinements

---------

Co-authored-by: Abu Bakr <[email protected]>
Co-authored-by: Abu Bakr Soliman <[email protected]>
  • Loading branch information
3 people authored Oct 10, 2024
1 parent 7855327 commit f85b502
Showing 1 changed file with 47 additions and 15 deletions.
62 changes: 47 additions & 15 deletions sentence_transformers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,11 +520,14 @@ def semantic_search(
def mine_hard_negatives(
dataset: Dataset,
model: SentenceTransformer,
anchor_column_name: str | None = None,
positive_column_name: str | None = None,
corpus: list[str] | None = None,
cross_encoder: CrossEncoder | None = None,
range_min: int = 0,
range_max: int | None = None,
max_score: float | None = None,
min_score: float | None = None,
margin: float | None = None,
num_negatives: int = 3,
sampling_strategy: Literal["random", "top"] = "top",
Expand Down Expand Up @@ -554,6 +557,7 @@ def mine_hard_negatives(
to sample negatives from. A lower value makes processing faster, but may result in less candidate negatives that
satisfy the margin or max_score conditions.
- **max_score**: Maximum score to consider as a negative: useful to skip candidates that are too similar to the anchor.
- **min_score**: Minimum score to consider as a negative: useful to skip candidates that are too dissimilar to the anchor.
- **margin**: Margin for hard negative mining: useful to skip candidates negatives whose similarity to the anchor is
within a certain margin of the positive pair. A value of 0 can be used to enforce that the negative is always
further away from the anchor than the positive.
Expand Down Expand Up @@ -621,13 +625,16 @@ def mine_hard_negatives(
Args:
dataset (Dataset): A dataset containing (anchor, positive) pairs.
model (SentenceTransformer): A SentenceTransformer model to use for embedding the sentences.
anchor_column_name (str, optional): The column name in `dataset` that contains the anchor/query. Defaults to None, in which case the first column in `dataset` will be used.
positive_column_name (str, optional): The column name in `dataset` that contains the positive candidates. Defaults to None, in which case the second column in `dataset` will be used.
corpus (List[str], optional): A list containing documents as strings that will be used as candidate negatives
in addition to the second column in `dataset`. Defaults to None, in which case the second column in
`dataset` will exclusively be used as the negative candidate corpus.
cross_encoder (CrossEncoder, optional): A CrossEncoder model to use for rescoring the candidates. Defaults to None.
range_min (int): Minimum rank of the closest matches to consider as negatives. Defaults to 0.
range_max (int, optional): Maximum rank of the closest matches to consider as negatives. Defaults to None.
max_score (float, optional): Maximum score to consider as a negative. Defaults to None.
min_score (float, optional): Minimum score to consider as a negative. Defaults to None.
margin (float, optional): Margin for hard negative mining. Defaults to None.
num_negatives (int): Number of negatives to sample. Defaults to 3.
sampling_strategy (Literal["random", "top"]): Sampling strategy for negatives: "top" or "random". Defaults to "top".
Expand All @@ -648,11 +655,20 @@ def mine_hard_negatives(

# If a dataset has duplicate queries, assume that all duplicates are positive pairs.
columns = dataset.column_names
if len(columns) != 2:

if not anchor_column_name or anchor_column_name not in columns:
anchor_column_name = columns[0]

if not positive_column_name or positive_column_name not in columns:
positive_column_name = columns[1]

if not anchor_column_name and not positive_column_name and len(columns) != 2:
raise ValueError("Dataset must contain exactly two columns.")

# To avoid re-embedding the same query multiple times, we keep a counter of the number of positives per query
positives_per_query = list(dataset.to_pandas().groupby(columns[0]).count().to_dict()[columns[1]].values())
positives_per_query = list(
dataset.to_pandas().groupby(anchor_column_name).count().to_dict()[positive_column_name].values()
)
max_positives = max(positives_per_query)

if range_max is None:
Expand All @@ -671,8 +687,8 @@ def mine_hard_negatives(
print(f"Setting range_max to {range_max} based on the provided parameters.")

log_counters = {}
queries = dataset[columns[0]]
positives = dataset[columns[1]]
queries = dataset[anchor_column_name]
positives = dataset[positive_column_name]
separate_corpus = corpus is not None
if not separate_corpus:
corpus = positives
Expand Down Expand Up @@ -835,6 +851,18 @@ def mine_hard_negatives(
"ratio": num_skipped / num_candidates,
}

# Remove based on min_score
if min_score is not None:
removed_indices = scores < min_score
scores[removed_indices] = -float("inf")

num_skipped = removed_indices.sum().item()
if num_skipped:
log_counters["min_score"] = {
"skipped": num_skipped,
"ratio": num_skipped / num_candidates,
}

# Grab the top negative candidates and remove the first range_min candidates
negative_scores, local_indices = torch.topk(scores, k=range_max, dim=1)
indices = indices[batch_idx, local_indices]
Expand Down Expand Up @@ -888,14 +916,14 @@ def mine_hard_negatives(
positive_indices = pos_indices[indices_to_keep]

triplets_data = {
columns[0]: [],
columns[1]: [],
anchor_column_name: [],
positive_column_name: [],
"negative": [],
}

for anchor_idx, negative_idx, positive_idx in zip(anchor_indices, indices, positive_indices):
triplets_data[columns[0]].append(queries[anchor_idx])
triplets_data[columns[1]].append(corpus[positive_idx])
triplets_data[anchor_column_name].append(queries[anchor_idx])
triplets_data[positive_column_name].append(corpus[positive_idx])
triplets_data["negative"].append(corpus[negative_idx])
difference_scores = positive_scores.repeat(num_negatives, 1).T[indices_to_keep] - negative_scores

Expand All @@ -906,8 +934,8 @@ def mine_hard_negatives(
indices = indices[indices_to_keep]

triplets_data = {
columns[0]: [all_queries[idx] for idx, keep in enumerate(indices_to_keep) if keep],
columns[1]: [positives[idx] for idx, keep in enumerate(indices_to_keep) if keep],
anchor_column_name: [all_queries[idx] for idx, keep in enumerate(indices_to_keep) if keep],
positive_column_name: [positives[idx] for idx, keep in enumerate(indices_to_keep) if keep],
**{
f"negative_{i}": [corpus[neg_idx] for neg_idx in neg_indices]
for i, neg_indices in enumerate(indices.T, start=1)
Expand Down Expand Up @@ -937,11 +965,11 @@ def mine_hard_negatives(
("mean", torch.mean),
("median", torch.median),
("std", torch.std),
("min", torch.min),
("25%", lambda scores: torch.quantile(scores.float(), q=0.25)),
("50%", lambda scores: torch.quantile(scores.float(), q=0.5)),
("75%", lambda scores: torch.quantile(scores.float(), q=0.75)),
("max", torch.max),
("min", lambda scores: torch.min(scores) if scores.numel() > 0 else float("inf")),
("25%", lambda scores: torch.quantile(scores.float(), q=0.25) if scores.numel() > 0 else float("inf")),
("50%", lambda scores: torch.quantile(scores.float(), q=0.5) if scores.numel() > 0 else float("inf")),
("75%", lambda scores: torch.quantile(scores.float(), q=0.75) if scores.numel() > 0 else float("inf")),
("max", lambda scores: torch.max(scores) if scores.numel() > 0 else float("-inf")),
]:
print(
row_format.format(
Expand All @@ -960,6 +988,10 @@ def mine_hard_negatives(
print(
f"Skipped {log_counters['max_score']['skipped']} potential negatives ({log_counters['max_score']['ratio']:.2%}) due to the maximum score of {max_score}."
)
if "min_score" in log_counters:
print(
f"Skipped {log_counters['min_score']['skipped']} potential negatives ({log_counters['min_score']['ratio']:.2%}) due to the minimum score of {min_score}."
)

missing_negatives = (num_negatives * len(dataset)) - len(negative_scores)
if missing_negatives > 0:
Expand Down

0 comments on commit f85b502

Please sign in to comment.