# NOTE: ``mux`` is a method of ``CutSet`` (lhotse/cut.py). The enclosing class
# header lies outside this excerpt, which is why it is shown here unindented.
@classmethod
def mux(
    cls,
    *cut_sets: "CutSet",
    weights: Optional[List[Union[int, float]]] = None,
    seed: int = 0,
) -> "CutSet":
    """
    Merges multiple CutSets into a new CutSet by lazily multiplexing them during iteration time.
    If one of the CutSets is exhausted before the others, we will keep iterating until all CutSets
    are exhausted.

    :param cut_sets: cut sets to be multiplexed.
        They can be either lazy or eager, but the resulting manifest will always be lazy.
    :param weights: an optional weight for each CutSet, affects the probability of it being sampled.
        The weights are uniform by default.
    :param seed: the random seed, ensures deterministic order across multiple iterations.
    :return: a new, lazy CutSet backed by a :class:`~lhotse.serialization.LazyIteratorMultiplexer`.
    """
    # Imported locally, exactly as in the original method.
    from lhotse.serialization import LazyIteratorMultiplexer

    return cls(cuts=LazyIteratorMultiplexer(*cut_sets, weights=weights, seed=seed))


class LazyIteratorMultiplexer:
    """
    A wrapper over multiple iterators that enables to combine lazy manifests in Lhotse.
    During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer`
    at each step randomly selects the iterable used to yield an item.

    Since the iterables might be of different length, we provide a ``weights`` parameter
    to let the user decide which iterables should be sampled more frequently than others.
    When an iterable is exhausted, we will keep sampling from the other iterables, until
    we exhaust them all.

    A weight of zero means the iterable is never selected while any iterable with
    a positive weight still has items. Once only zero-weight iterables remain, they
    are drained with uniform probability, so that every item is eventually yielded
    and the number of yielded items agrees with ``len()``.

    :param iterators: two or more iterables to draw items from.
    :param weights: an optional non-negative weight for each iterable (uniform by default).
    :param seed: the random seed; every iteration restarts from it, so the order
        of items is the same each time the multiplexer is iterated.
    :raises AssertionError: when fewer than two iterables are given, or when the
        number of weights differs from the number of iterables.
    :raises ValueError: when any weight is negative.
    """

    def __init__(
        self,
        *iterators: Iterable,
        weights: Optional[List[Union[int, float]]] = None,
        seed: int = 0,
    ) -> None:
        self.iterators = list(iterators)
        self.seed = seed

        assert (
            len(self.iterators) > 1
        ), "There have to be at least two iterables to multiplex."

        if weights is None:
            self.weights = [1] * len(self.iterators)
        else:
            # Copy, so that later changes to the caller's sequence cannot
            # silently alter the sampling distribution.
            self.weights = list(weights)

        assert len(self.iterators) == len(self.weights), (
            f"Expected one weight per iterable, but got {len(self.weights)} "
            f"weights for {len(self.iterators)} iterables."
        )

        # ``random.choices`` assumes non-negative weights; fail early and clearly
        # instead of sampling from a meaningless distribution.
        if any(w < 0 for w in self.weights):
            raise ValueError(
                f"The weights must be non-negative (got: {self.weights})."
            )

    def __iter__(self):
        # A fresh RNG per iteration makes the order reproducible across epochs.
        rng = random.Random(self.seed)
        iters = [iter(it) for it in self.iterators]
        active = list(range(len(iters)))

        while active:
            # The sampling distribution only changes when a source runs out,
            # so it is computed once per change rather than once per item.
            weights = [self.weights[i] for i in active]
            if sum(weights) <= 0:
                # Only zero-weight sources are left: ``random.choices`` would
                # raise on a zero total, so drain them uniformly instead.
                weights = None

            while True:
                idx = rng.choices(active, weights=weights, k=1)[0]
                # Keep the ``try`` body minimal: only exhaustion of the selected
                # source should be interpreted as "this source is done".
                try:
                    item = next(iters[idx])
                except StopIteration:
                    active.remove(idx)
                    break
                yield item

    def values(self):
        """Iterate over the multiplexed items (dict-like API)."""
        yield from self

    def keys(self):
        """Iterate over the ``id`` attribute of each multiplexed item."""
        return (item.id for item in self)

    def items(self):
        """Iterate over ``(item.id, item)`` pairs of the multiplexed items."""
        return ((item.id, item) for item in self)

    def __len__(self) -> int:
        # Requires every underlying iterable to support ``len()``.
        return sum(len(it) for it in self.iterators)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)
"""Tests for lazily multiplexing several iterables / CutSets into one stream."""
# NOTE(review): this file is added as ``test/test_multipexing_iterables.py``;
# the name looks like a typo for ``test_multiplexing_iterables.py``.
import pickle
import random

from lhotse import CutSet
from lhotse.serialization import LazyIteratorMultiplexer
from lhotse.testing.dummies import DummyManifest


def test_multiplexer():
    # Sources of length 10 and 3.
    mux = LazyIteratorMultiplexer(range(10), range(900, 903), seed=0)

    items = list(mux)
    # Every item is yielded exactly once ...
    assert sorted(items) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 900, 901, 902]
    # ... but the two sources are interleaved rather than chained.
    assert sorted(items) != items


def test_multiplexer_deterministic():
    # Two sources of length 1000 each; with a fixed seed, iterating twice
    # has to produce the same order.
    mux = LazyIteratorMultiplexer(range(1000), range(900000, 901000), seed=0)
    assert list(mux) == list(mux)


def test_multiplexer_weights():
    # Same sources (length 10 and 3) and seed; only the weights differ.
    mux_uniform = LazyIteratorMultiplexer(range(10), range(900, 903), seed=0)
    mux_weighted = LazyIteratorMultiplexer(
        range(10),
        range(900, 903),
        seed=0,
        weights=[10, 3],
    )

    weighted_items = list(mux_weighted)
    assert sorted(weighted_items) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 900, 901, 902]
    assert sorted(weighted_items) != weighted_items
    # The weights must actually influence the sampling order.
    assert weighted_items != list(mux_uniform)


def test_cut_set_mux():
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=10)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1005)

    cuts_mux = CutSet.mux(cuts1, cuts2, seed=0)

    def cid(i: int) -> str:
        return f"dummy-cut-{i:04d}"

    cut_ids = [c.id for c in cuts_mux]
    assert sorted(cut_ids) == [
        cid(i) for i in (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1000, 1001, 1002, 1003, 1004)
    ]
    assert sorted(cut_ids) != cut_ids


def test_multiplexer_pickling():
    mux = LazyIteratorMultiplexer(
        list(range(100)), list(range(10)), weights=[2, 3], seed=0
    )

    data = pickle.dumps(mux)
    mux_rec = pickle.loads(data)

    # The restored multiplexer has to reproduce the original order.
    assert list(mux) == list(mux_rec)


def test_multiplexer_with_cuts_pickling():
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=10)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1005)
    mux = LazyIteratorMultiplexer(cuts1, cuts2, seed=0)

    data = pickle.dumps(mux)
    mux_rec = pickle.loads(data)

    assert list(mux) == list(mux_rec)