CutSet multiplexing #565
@@ -1,8 +1,9 @@
 import itertools
 import json
+import random
 import warnings
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterable, Optional, Type, Union
+from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union

 import yaml
@@ -527,6 +528,75 @@ def __add__(self, other) -> "LazyIteratorChain":
         return LazyIteratorChain(self, other)


+class LazyIteratorMultiplexer:
+    """
+    A wrapper over multiple iterators that enables combining lazy manifests in Lhotse.
+    During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer`
+    at each step randomly selects the iterable used to yield an item.
+
+    Since the iterables might have different lengths, we provide a ``weights`` parameter
+    to let the user decide which iterables should be sampled more frequently than others.
+    When an iterable is exhausted, we keep sampling from the remaining iterables
+    until they are all exhausted.
+    """
+
+    def __init__(
+        self,
+        *iterators: Iterable,
+        weights: Optional[List[Union[int, float]]] = None,
+        seed: int = 0,
+    ) -> None:
+        self.iterators = list(iterators)
+        self.seed = seed
+
+        assert (
+            len(self.iterators) > 1
+        ), "There have to be at least two iterables to multiplex."
+
+        if weights is None:
+            self.weights = [1] * len(self.iterators)
+        else:
+            self.weights = weights
+
+        assert len(self.iterators) == len(self.weights)
+
+    def __iter__(self):
+        rng = random.Random(self.seed)
+        iters = [iter(it) for it in self.iterators]
+        exhausted = [False for _ in range(len(iters))]
+        while not all(exhausted):
Review comment: Also, can we support specifying that as soon as a given cut set is exhausted, it breaks the while loop?

Reply: Sure, see #585.
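A minimal standalone sketch of that behavior, for illustration only (this is an assumption about the approach, not the actual implementation in #585): the loop condition would effectively change from "all exhausted" to "any exhausted".

import random

def mux_stop_early(*iterables, seed=0):
    # Yield items from randomly chosen sources, but stop as soon as ANY source
    # is exhausted (contrast with LazyIteratorMultiplexer, which keeps going
    # until every source is exhausted).
    rng = random.Random(seed)
    iters = [iter(it) for it in iterables]
    while True:
        try:
            yield next(rng.choice(iters))
        except StopIteration:
            return

print(list(mux_stop_early(range(10), range(900, 903), seed=0)))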
+            active_indexes, active_weights = zip(
+                *[
+                    (i, w)
+                    for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights))
+                    if not is_exhausted
+                ]
+            )
+            idx = rng.choices(active_indexes, weights=active_weights, k=1)[0]
+            selected = iters[idx]
+            try:
+                item = next(selected)
+                yield item
Review comment: Is it possible to also return the idx? We may need to select a different network to process the returned item. For instance, when combining GigaSpeech and LibriSpeech in transducer training, the encoder network is shared and there are two separate decoder+joiner networks, one for each dataset. If the returned item is from GigaSpeech, we would run the decoder+joiner for GigaSpeech; if the returned item is from LibriSpeech, we would run the decoder+joiner for LibriSpeech. If this function returns only the item […]

Reply: Hmm, that would be problematic; it would break the API that CutSet and CutSampler depend on. But there are two other ways that would be much easier: […]
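Purely as an illustration (an assumption, not necessarily one of the alternatives the reply refers to), the training loop can recover the source dataset from the yielded cuts themselves, for example from their ids; the manifests and branch names below are hypothetical placeholders.

from lhotse import CutSet
from lhotse.testing.dummies import DummyManifest

# Hypothetical stand-ins for two corpora; real usage would load e.g. GigaSpeech
# and LibriSpeech cut sets whose ids identify the source dataset.
giga = DummyManifest(CutSet, begin_id=0, end_id=5)
libri = DummyManifest(CutSet, begin_id=1000, end_id=1005)
giga_ids = set(c.id for c in giga)

for cut in CutSet.mux(giga, libri, seed=0):
    # Dispatch to the matching decoder+joiner branch based on the cut's origin.
    branch = "giga_decoder_joiner" if cut.id in giga_ids else "libri_decoder_joiner"
    # ... run the shared encoder, then the selected branch, on this cut ...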
+            except StopIteration:
+                exhausted[idx] = True
+                continue
+
+    def values(self):
+        yield from self
+
+    def keys(self):
+        return (item.id for item in self)
+
+    def items(self):
+        return ((item.id, item) for item in self)
+
+    def __len__(self) -> int:
+        return sum(len(it) for it in self.iterators)
+
+    def __add__(self, other) -> "LazyIteratorChain":
+        return LazyIteratorChain(self, other)
+
+
 def deserialize_item(data: dict) -> Any:
     # Figures out what type of manifest is being decoded with some heuristics
     # and returns a Lhotse manifest object rather than a raw dict.
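For context, here is a short usage sketch of the new class, based on its docstring and the tests below (the values are illustrative only).

from lhotse.serialization import LazyIteratorMultiplexer

# Interleave two sources at random; the second is sampled roughly 3x as often
# per draw until it runs out, after which only the first source remains.
mux = LazyIteratorMultiplexer(range(10), range(900, 903), weights=[1, 3], seed=0)

print(list(mux))  # a seed-deterministic interleaving of 0..9 and 900..902
print(len(mux))   # 13, i.e. the sum of the source lengths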
@@ -0,0 +1,74 @@
+import pickle
+import random
+
+from lhotse import CutSet
+from lhotse.serialization import LazyIteratorMultiplexer
+from lhotse.testing.dummies import DummyManifest
+
+
+def test_multiplexer():
+    mux = LazyIteratorMultiplexer(range(10), range(900, 903), seed=0)  # len 10, len 3
+
+    assert sorted(list(mux)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 900, 901, 902]
+    assert sorted(list(mux)) != list(mux)
+
+
+def test_multiplexer_deterministic():
+    # seed given
+    mux = LazyIteratorMultiplexer(
+        range(1000), range(900000, 901000), seed=0  # len 1000, len 1000
+    )
+    assert list(mux) == list(mux)
+
+
+def test_multiplexer_weights():
+    mux_uniform = LazyIteratorMultiplexer(
+        range(10), range(900, 903), seed=0  # len 10, len 3
+    )
+    mux_weighted = LazyIteratorMultiplexer(
+        range(10),  # len 10
+        range(900, 903),  # len 3
+        seed=0,
+        weights=[10, 3],
+    )
+
+    assert sorted(list(mux_weighted)) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 900, 901, 902]
+    assert sorted(list(mux_weighted)) != list(mux_weighted)
+    assert list(mux_weighted) != list(mux_uniform)
+
+
+def test_cut_set_mux():
+    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=10)
+    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1005)
+
+    cuts_mux = CutSet.mux(cuts1, cuts2, seed=0)
+
+    def cid(i: int) -> str:
+        return f"dummy-cut-{i:04d}"
+
+    assert sorted([c.id for c in cuts_mux]) == [
+        cid(i) for i in (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1000, 1001, 1002, 1003, 1004)
+    ]
+    assert sorted([c.id for c in cuts_mux]) != [c.id for c in cuts_mux]
+
+
+def test_multiplexer_pickling():
+    mux = LazyIteratorMultiplexer(
+        list(range(100)), list(range(10)), weights=[2, 3], seed=0
+    )
+
+    data = pickle.dumps(mux)
+    mux_rec = pickle.loads(data)
+
+    assert list(mux) == list(mux_rec)
+
+
+def test_multiplexer_with_cuts_pickling():
+    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=10)
+    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1005)
+    mux = LazyIteratorMultiplexer(cuts1, cuts2, seed=0)
+
+    data = pickle.dumps(mux)
+    mux_rec = pickle.loads(data)
+
+    assert list(mux) == list(mux_rec)
Review comment: @pzelasko why not set the weights proportional to the cut set sizes by default? That way you would deplete all cut sets at the same time on average. If we keep it uniform, we risk that a small cut set A is depleted fast, and for the rest of the epoch there will only be the larger cut set B.

Reply: Leaving it up to the user, in case the cut sets are very large and opened lazily (I can also imagine len() not being available for some types of lazy manifests in the future).
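For reference, a small sketch of the size-proportional weighting suggested above, done on the user side via the weights argument (assuming len() is available for both cut sets; the dummy manifests are placeholders).

from lhotse import CutSet
from lhotse.serialization import LazyIteratorMultiplexer
from lhotse.testing.dummies import DummyManifest

# Dummy stand-ins for a smaller and a larger corpus.
cuts_a = DummyManifest(CutSet, begin_id=0, end_id=10)       # 10 cuts
cuts_b = DummyManifest(CutSet, begin_id=1000, end_id=1005)  # 5 cuts

# Weights proportional to the sizes, so on average both sources are
# depleted around the same point in the epoch.
mux = LazyIteratorMultiplexer(
    cuts_a, cuts_b, weights=[len(cuts_a), len(cuts_b)], seed=0
)

assert len(list(mux)) == len(cuts_a) + len(cuts_b)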