Skip to content

Commit

Permalink
ContrastiveDistillationDataset iter refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
DemirTonchev committed Dec 25, 2024
1 parent 2d5e29b commit 1c905b1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 19 deletions.
28 changes: 10 additions & 18 deletions src/setfit/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,6 @@ def __init__(
sentences: List[str],
cos_sim_matrix: torch.Tensor,
num_iterations: Optional[None] = None,
sampling_strategy: str = "oversampling",
max_pairs: int = -1,
) -> None:
self.cos_sim_matrix = cos_sim_matrix
Expand All @@ -194,23 +193,16 @@ def __init__(
[0] * len(sentences),
multilabel=False,
num_iterations=num_iterations,
sampling_strategy=sampling_strategy,
sampling_strategy=SamplingStrategy.UNIQUE, # use unique to create all pos pairs, implementation choice to use generate_positive_pair method (*)
max_pairs=max_pairs,
)
# Internally we store all pairs in pos_pairs, regardless of sampling strategy.
# After all, without labels, there isn't much of a strategy.
self.sentence_labels = list(enumerate(self.sentences))

self.len_neg_pairs = 0
if num_iterations is not None and num_iterations > 0:
self.len_pos_pairs = num_iterations * len(self.sentences)
else:
self.len_pos_pairs = len(self.pos_pairs)

def generate_pairs(self) -> None:
for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
self.pos_pairs.append(
{"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]}
)
if self.max_pos_or_neg != -1 and len(self.pos_pairs) > self.max_pos_or_neg:
break
self.sentence_labels = list(zip(self.sentences, range(len(self.sentences))))

# (*) Internally we use generate_positive_pair
def generate_positive_pair(self) -> Generator[Dict[str, Union[str, float]], None, None]:
    """Yield sentence pairs labelled with their entry in ``self.cos_sim_matrix``.

    Pairs are taken from ``shuffle_combinations(self.sentence_labels)``. When one
    pass is exhausted a fresh pass is started, so the generator keeps producing
    pairs for as long as the consumer pulls from it; the consumer is expected to
    decide when to stop.

    If a pass produces no pairs at all (for example, no sentences were given),
    the generator ends instead of looping forever without yielding.

    Yields:
        A dict with the keys ``"sentence_1"``, ``"sentence_2"`` and ``"label"``.
        ``label`` is ``self.cos_sim_matrix[id_one][id_two]``, where the ids are
        the second elements of the two ``self.sentence_labels`` entries.
        NOTE(review): the annotation says ``float``; if ``cos_sim_matrix`` is a
        tensor the indexed value is presumably a 0-d tensor -- confirm that the
        consumer accepts it.
    """
    while True:
        produced_any = False
        for (text_one, id_one), (text_two, id_two) in shuffle_combinations(self.sentence_labels):
            produced_any = True
            yield {"sentence_1": text_one, "sentence_2": text_two, "label": self.cos_sim_matrix[id_one][id_two]}
        if not produced_any:
            # Nothing to pair up: restarting would spin here forever and hang the caller.
            return
2 changes: 1 addition & 1 deletion src/setfit/trainer_distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def get_dataset(
data_sampler = ContrastiveDistillationDataset(
x, cos_sim_matrix, args.num_iterations, args.sampling_strategy, max_pairs=max_pairs
)
dataset = Dataset.from_list(list(data_sampler))
dataset = Dataset.from_generator(data_sampler.__iter__)
loss = args.loss(self.model.model_body)
return dataset, loss

Expand Down

0 comments on commit 1c905b1

Please sign in to comment.