From 5374f0cd0a7a8bc89662e3d6c6054f6a0d6ce8d0 Mon Sep 17 00:00:00 2001 From: sabonerune <102559104+sabonerune@users.noreply.github.com> Date: Mon, 4 Nov 2024 11:49:29 +0900 Subject: [PATCH] ENH: Add mutex for global instance --- pyopenjtalk/__init__.py | 131 +++++++++++++++++++++++----------------- 1 file changed, 76 insertions(+), 55 deletions(-) diff --git a/pyopenjtalk/__init__.py b/pyopenjtalk/__init__.py index abf503a..656c508 100644 --- a/pyopenjtalk/__init__.py +++ b/pyopenjtalk/__init__.py @@ -1,10 +1,15 @@ +from __future__ import annotations + import atexit import os import sys import tarfile import tempfile -from contextlib import ExitStack +from collections.abc import Callable, Generator +from contextlib import ExitStack, contextmanager from os.path import exists +from threading import Lock +from typing import TypeVar from urllib.request import urlopen if sys.version_info >= (3, 9): @@ -44,14 +49,6 @@ ) ).encode("utf-8") -# Global instance of OpenJTalk -_global_jtalk = None -# Global instance of HTSEngine -# mei_normal.voice is used as default -_global_htsengine = None -# Global instance of Marine -_global_marine = None - def _extract_dic(): from tqdm.auto import tqdm @@ -78,6 +75,49 @@ def _lazy_init(): _extract_dic() +_T = TypeVar("_T") + + +def _global_instance_manager( + instance_factory: Callable[[], _T] | None = None, instance: _T | None = None +) -> Callable[[], Generator[_T, None, None]]: + assert instance_factory is not None or instance is not None + _instance = instance + mutex = Lock() + + @contextmanager + def manager() -> Generator[_T, None, None]: + nonlocal _instance + with mutex: + if _instance is None: + _instance = instance_factory() + yield _instance + + return manager + + +def _jtalk_factory() -> OpenJTalk: + _lazy_init() + return OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) + + +def _marine_factory(): + try: + from marine.predict import Predictor + except ImportError: + raise ImportError("Please install marine by `pip install pyopenjtalk[marine]`") + return Predictor() + + +# Global instance of OpenJTalk +_global_jtalk = _global_instance_manager(_jtalk_factory) +# Global instance of HTSEngine +# mei_normal.voice is used as default +_global_htsengine = _global_instance_manager(lambda: HTSEngine(DEFAULT_HTS_VOICE)) +# Global instance of Marine +_global_marine = _global_instance_manager(_marine_factory) + + def g2p(*args, **kwargs): """Grapheme-to-phoeneme (G2P) conversion @@ -93,11 +133,8 @@ def g2p(*args, **kwargs): Returns: str or list: G2P result in 1) str if join is True 2) list if join is False. """ - global _global_jtalk - if _global_jtalk is None: - _lazy_init() - _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) - return _global_jtalk.g2p(*args, **kwargs) + with _global_jtalk() as jtalk: + return jtalk.g2p(*args, **kwargs) def estimate_accent(njd_features): @@ -111,21 +148,13 @@ def estimate_accent(njd_features): Returns: list: features for NJDNode with estimation results by marine. """ - global _global_marine - if _global_marine is None: - try: - from marine.predict import Predictor - except BaseException: - raise ImportError( - "Please install marine by `pip install pyopenjtalk[marine]`" - ) - _global_marine = Predictor() - from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature + with _global_marine() as marine: + from marine.utils.openjtalk_util import convert_njd_feature_to_marine_feature - marine_feature = convert_njd_feature_to_marine_feature(njd_features) - marine_results = _global_marine.predict( - [marine_feature], require_open_jtalk_format=True - ) + marine_feature = convert_njd_feature_to_marine_feature(njd_features) + marine_results = marine.predict( + [marine_feature], require_open_jtalk_format=True + ) njd_features = merge_njd_marine_features(njd_features, marine_results) return njd_features @@ -164,13 +193,11 @@ def synthesize(labels, speed=1.0, half_tone=0.0): if isinstance(labels, tuple) and len(labels) == 2: labels = labels[1] - global _global_htsengine - if _global_htsengine is None: - _global_htsengine = HTSEngine(DEFAULT_HTS_VOICE) - sr = _global_htsengine.get_sampling_frequency() - _global_htsengine.set_speed(speed) - _global_htsengine.add_half_tone(half_tone) - return _global_htsengine.synthesize(labels), sr + with _global_htsengine() as htsengine: + sr = htsengine.get_sampling_frequency() + htsengine.set_speed(speed) + htsengine.add_half_tone(half_tone) + return htsengine.synthesize(labels), sr def tts(text, speed=1.0, half_tone=0.0, run_marine=False): @@ -202,11 +229,8 @@ def run_frontend(text): Returns: list: features for NJDNode. """ - global _global_jtalk - if _global_jtalk is None: - _lazy_init() - _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) - return _global_jtalk.run_frontend(text) + with _global_jtalk() as jtalk: + return jtalk.run_frontend(text) def make_label(njd_features): @@ -218,11 +242,8 @@ def make_label(njd_features): Returns: list: full-context labels. """ - global _global_jtalk - if _global_jtalk is None: - _lazy_init() - _global_jtalk = OpenJTalk(dn_mecab=OPEN_JTALK_DICT_DIR) - return _global_jtalk.make_label(njd_features) + with _global_jtalk() as jtalk: + return jtalk.make_label(njd_features) def mecab_dict_index(path, out_path, dn_mecab=None): @@ -233,12 +254,11 @@ def mecab_dict_index(path, out_path, dn_mecab=None): out_path (str): path to output dictionary dn_mecab (optional. str): path to mecab dictionary """ - global _global_jtalk - if _global_jtalk is None: - _lazy_init() if not exists(path): raise FileNotFoundError("no such file or directory: %s" % path) if dn_mecab is None: + with _global_jtalk(): # call _lazy_init() + pass dn_mecab = OPEN_JTALK_DICT_DIR r = _mecab_dict_index(dn_mecab, path.encode("utf-8"), out_path.encode("utf-8")) @@ -257,10 +277,11 @@ def update_global_jtalk_with_user_dict(path): path (str): path to user dictionary """ global _global_jtalk - if _global_jtalk is None: - _lazy_init() - if not exists(path): - raise FileNotFoundError("no such file or directory: %s" % path) - _global_jtalk = OpenJTalk( - dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8") - ) + with _global_jtalk(): + if not exists(path): + raise FileNotFoundError("no such file or directory: %s" % path) + _global_jtalk = _global_instance_manager( + instance=OpenJTalk( + dn_mecab=OPEN_JTALK_DICT_DIR, userdic=path.encode("utf-8") + ) + )