Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CutSet multiplexing #565

Merged
merged 3 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions lhotse/cut.py
Original file line number Diff line number Diff line change
Expand Up @@ -4628,6 +4628,28 @@ def __add__(self, other: "CutSet") -> "CutSet":
)
return CutSet(cuts=merged)

@classmethod
def mux(
    cls,
    *cut_sets: "CutSet",
    weights: Optional[List[Union[int, float]]] = None,
    seed: int = 0,
) -> "CutSet":
    """
    Merges multiple CutSets into a new CutSet by lazily multiplexing them during iteration time.
    If one of the CutSets is exhausted before the others, we will keep iterating until all CutSets
    are exhausted.

    :param cut_sets: cut sets to be multiplexed.
        They can be either lazy or eager, but the resulting manifest will always be lazy.
    :param weights: an optional weight for each CutSet, affects the probability of it being sampled.
        The weights are uniform by default.
    :param seed: the random seed, ensures deterministic order across multiple iterations.
    :return: a new, lazily-evaluated ``CutSet`` that interleaves items from all inputs.
    """
    # Local import avoids a circular dependency between cut.py and serialization.py.
    from lhotse.serialization import LazyIteratorMultiplexer

    return cls(cuts=LazyIteratorMultiplexer(*cut_sets, weights=weights, seed=seed))


def make_windowed_cuts_from_features(
feature_set: FeatureSet,
Expand Down
72 changes: 71 additions & 1 deletion lhotse/serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
import json
import random
import warnings
from pathlib import Path
from typing import Any, Dict, Generator, Iterable, Optional, Type, Union
from typing import Any, Dict, Generator, Iterable, List, Optional, Type, Union

import yaml

Expand Down Expand Up @@ -527,6 +528,75 @@ def __add__(self, other) -> "LazyIteratorChain":
return LazyIteratorChain(self, other)


class LazyIteratorMultiplexer:
    """
    A wrapper over multiple iterators that enables to combine lazy manifests in Lhotse.
    During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer`
    at each step randomly selects the iterable used to yield an item.

    Since the iterables might be of different length, we provide a ``weights`` parameter
    to let the user decide which iterables should be sampled more frequently than others.
    When an iterable is exhausted, we will keep sampling from the other iterables, until
    we exhaust them all.
    """

    def __init__(
        self,
        *iterators: Iterable,
        weights: Optional[List[Union[int, float]]] = None,
        seed: int = 0,
    ) -> None:
        """
        :param iterators: two or more iterables to be multiplexed.
        :param weights: optional sampling weight per iterable; uniform when omitted.
        :param seed: seed for the per-iteration RNG, makes the order deterministic.
        """
        self.iterators = list(iterators)
        self.seed = seed

        assert (
            len(self.iterators) > 1
        ), "There have to be at least two iterables to multiplex."

        if weights is None:
            # Uniform sampling by default.
            self.weights = [1] * len(self.iterators)
        else:
            self.weights = weights

        assert len(self.iterators) == len(
            self.weights
        ), "The number of weights must match the number of iterables."

    def __iter__(self):
        # A fresh RNG seeded identically each time makes every iteration
        # over this object yield the same order (deterministic epochs).
        rng = random.Random(self.seed)
        iters = [iter(it) for it in self.iterators]
        exhausted = [False for _ in range(len(iters))]
        while not all(exhausted):
            # Restrict the draw to iterables that still have items,
            # keeping their original relative weights.
            active_indexes, active_weights = zip(
                *[
                    (i, w)
                    for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights))
                    if not is_exhausted
                ]
            )
            idx = rng.choices(active_indexes, weights=active_weights, k=1)[0]
            selected = iters[idx]
            try:
                item = next(selected)
                yield item
            except StopIteration:
                # Mark this source as depleted and keep drawing from the rest.
                exhausted[idx] = True
                continue

    def values(self):
        yield from self

    def keys(self):
        # Assumes the multiplexed items are manifests exposing an ``.id`` field.
        return (item.id for item in self)

    def items(self):
        return ((item.id, item) for item in self)

    def __len__(self) -> int:
        # Raises TypeError if any underlying iterable has no ``len`` (e.g. a generator).
        return sum(len(it) for it in self.iterators)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)


def deserialize_item(data: dict) -> Any:
# Figures out what type of manifest is being decoded with some heuristics
# and returns a Lhotse manifest object rather than a raw dict.
Expand Down
74 changes: 74 additions & 0 deletions test/test_multipexing_iterables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import pickle
import random

from lhotse import CutSet
from lhotse.serialization import LazyIteratorMultiplexer
from lhotse.testing.dummies import DummyManifest


def test_multiplexer():
    # Two sources: 0..9 (len 10) and 900..902 (len 3).
    mux = LazyIteratorMultiplexer(range(10), range(900, 903), seed=0)

    expected = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 900, 901, 902]
    assert sorted(mux) == expected
    # The interleaved order is shuffled, i.e. not already sorted.
    assert sorted(mux) != list(mux)


def test_multiplexer_deterministic():
    # With a fixed seed, every iteration must yield exactly the same order.
    mux = LazyIteratorMultiplexer(
        range(1000),  # len 1000
        range(900000, 901000),  # len 1000
        seed=0,
    )
    assert list(mux) == list(mux)


def test_multiplexer_weights():
    uniform = LazyIteratorMultiplexer(
        range(10),  # len 10
        range(900, 903),  # len 3
        seed=0,
    )
    weighted = LazyIteratorMultiplexer(
        range(10),  # len 10
        range(900, 903),  # len 3
        seed=0,
        weights=[10, 3],
    )

    expected = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 900, 901, 902]
    # Weights change the sampling order but not the overall contents.
    assert sorted(weighted) == expected
    assert sorted(weighted) != list(weighted)
    assert list(weighted) != list(uniform)


def test_cut_set_mux():
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=10)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1005)

    cuts_mux = CutSet.mux(cuts1, cuts2, seed=0)

    def cid(i: int) -> str:
        # Matches the ID scheme used by DummyManifest.
        return f"dummy-cut-{i:04d}"

    expected_ids = [cid(i) for i in range(10)] + [cid(i) for i in range(1000, 1005)]
    assert sorted(c.id for c in cuts_mux) == expected_ids
    # The multiplexed order is shuffled, i.e. not already sorted.
    assert sorted(c.id for c in cuts_mux) != [c.id for c in cuts_mux]


def test_multiplexer_pickling():
    mux = LazyIteratorMultiplexer(
        list(range(100)), list(range(10)), weights=[2, 3], seed=0
    )

    # A pickle round-trip must preserve the iteration order.
    restored = pickle.loads(pickle.dumps(mux))

    assert list(mux) == list(restored)


def test_multiplexer_with_cuts_pickling():
    cuts1 = DummyManifest(CutSet, begin_id=0, end_id=10)
    cuts2 = DummyManifest(CutSet, begin_id=1000, end_id=1005)
    mux = LazyIteratorMultiplexer(cuts1, cuts2, seed=0)

    # A pickle round-trip must preserve the iteration order.
    restored = pickle.loads(pickle.dumps(mux))

    assert list(mux) == list(restored)