Skip to content

Commit

Permalink
Prevent sampling 2x more than requested when max_steps is set (#556)
Browse files: browse the repository at this point in the history
  • Loading branch information
tomaarsen authored Sep 18, 2024
1 parent e1e8e3e commit 223afb6
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/setfit/sampler.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(
self.sentences = sentences
self.labels = labels
self.sentence_labels = list(zip(self.sentences, self.labels))
self.max_pairs = max_pairs
self.max_pos_or_neg = -1 if max_pairs == -1 else max_pairs // 2

if multilabel:
self.generate_multilabel_pairs()
Expand Down Expand Up @@ -90,8 +90,8 @@ def __init__(
def generate_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
is_positive = _label == label
is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs
is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs
is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg
is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg

if is_positive:
if not is_positive_full:
Expand All @@ -106,8 +106,8 @@ def generate_multilabel_pairs(self) -> None:
for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels):
# logical_and checks if labels are both set for each class
is_positive = any(np.logical_and(_label, label))
is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs
is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs
is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg
is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg

if is_positive:
if not is_positive_full:
Expand Down Expand Up @@ -180,5 +180,5 @@ def generate_pairs(self) -> None:
self.pos_pairs.append(
{"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]}
)
if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs:
if self.max_pos_or_neg != -1 and len(self.pos_pairs) > self.max_pos_or_neg:
break

0 comments on commit 223afb6

Please sign in to comment.