Skip to content

Commit

Permalink
use iter instead of list
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijianma committed Nov 17, 2023
1 parent 9619176 commit 942b6f5
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions data_juicer/format/mixture_formatter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from itertools import chain, repeat
from typing import List, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -76,7 +77,7 @@ def _get_weight(self, data_prefix):

for i in range(len(data_prefix)):
try:
value = float(data_prefix[i])
value = max(float(data_prefix[i]), 0.0)
weights.append(value)
except: # noqa: E722
value = data_prefix[i].strip()
Expand Down Expand Up @@ -108,15 +109,13 @@ def _random_sample(self, dataset, weight=1.0, sample_number=0, seed=None):
if sample_number == ds_samples:
return dataset

num_epochs = int(np.ceil(sample_number / ds_samples)) - 1
sample_index = range(sample_number)

if num_epochs > 0:
remain_samples = sample_number - num_epochs * ds_samples
sample_index = list(range(ds_samples)) * num_epochs + list(
range(remain_samples))
else:
remain_samples = sample_number
sample_index = list(range(remain_samples))
n_repeat = int(np.ceil(sample_number / ds_samples)) - 1
if n_repeat > 0:
remain_samples = sample_number - n_repeat * ds_samples
sample_index = chain(*repeat(range(ds_samples), n_repeat),
range(remain_samples))

return dataset.shuffle(seed=seed).select(sample_index)

Expand Down

0 comments on commit 942b6f5

Please sign in to comment.