diff --git a/src/setfit/sampler.py b/src/setfit/sampler.py index 1f7727c9..16bc9938 100644 --- a/src/setfit/sampler.py +++ b/src/setfit/sampler.py @@ -61,7 +61,7 @@ def __init__( self.sentences = sentences self.labels = labels self.sentence_labels = list(zip(self.sentences, self.labels)) - self.max_pairs = max_pairs + self.max_pos_or_neg = -1 if max_pairs == -1 else max_pairs // 2 if multilabel: self.generate_multilabel_pairs() @@ -90,8 +90,8 @@ def __init__( def generate_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): is_positive = _label == label - is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs - is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs + is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg + is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg if is_positive: if not is_positive_full: @@ -106,8 +106,8 @@ def generate_multilabel_pairs(self) -> None: for (_text, _label), (text, label) in shuffle_combinations(self.sentence_labels): # logical_and checks if labels are both set for each class is_positive = any(np.logical_and(_label, label)) - is_positive_full = self.max_pairs != -1 and len(self.pos_pairs) >= self.max_pairs - is_negative_full = self.max_pairs != -1 and len(self.neg_pairs) >= self.max_pairs + is_positive_full = self.max_pos_or_neg != -1 and len(self.pos_pairs) >= self.max_pos_or_neg + is_negative_full = self.max_pos_or_neg != -1 and len(self.neg_pairs) >= self.max_pos_or_neg if is_positive: if not is_positive_full: @@ -180,5 +180,5 @@ def generate_pairs(self) -> None: self.pos_pairs.append( {"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]} ) - if self.max_pairs != -1 and len(self.pos_pairs) > self.max_pairs: + if self.max_pos_or_neg != -1 and len(self.pos_pairs) > self.max_pos_or_neg: break