Skip to content

Commit

Permalink
local librasr import + renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
Daniel Mann committed Aug 29, 2024
1 parent e5b928d commit 4d5e3c2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions i6_models/parts/fsa.py → i6_models/parts/rasr_fsa.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from __future__ import annotations

__all__ = ["TorchFsaBuilder", "WeightedFsa"]
__all__ = ["RasrFsaBuilder", "WeightedFsa"]

from functools import reduce
from typing import Iterable, NamedTuple, Tuple, Union

import numpy as np
import torch
import librasr


class WeightedFsa(NamedTuple):
Expand Down Expand Up @@ -43,11 +42,13 @@ def to(self, device: Union[str, torch.device]) -> WeightedFsa:
return WeightedFsa._make(tensor.to(device) for tensor in self)


class TorchFsaBuilder:
class RasrFsaBuilder:
"""
Builder class that wraps around the librasr.AllophoneStateFsaBuilder,
bringing the FSAs into the correct format for the `i6_native_ops.fbw.fbw_loss`.
Use of this class requires a working installation of the python package `librasr`.
Hence, the package is locally imported in case other classes are accessed from
this module.
This class provides an explicit implementation of the `__getstate__` and `__setstate__`
functions, necessary for pickling as the C++-class `librasr.AllophoneStateFsaBuilder`
is not picklable.
Expand All @@ -57,6 +58,8 @@ class TorchFsaBuilder:
"""

def __init__(self, config_path: str, tdp_scale: float = 1.0):
import librasr

self.config_path = config_path
config = librasr.Configuration()
config.set_from_file(self.config_path)
Expand All @@ -69,6 +72,8 @@ def __getstate__(self):
return state

def __setstate__(self, state):
import librasr

self.__dict__.update(state)
config = librasr.Configuration()
config.set_from_file(self.config_path)
Expand Down

0 comments on commit 4d5e3c2

Please sign in to comment.