diff --git a/requirements.txt b/requirements.txt index 3c85381..f2f8498 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,5 @@ tflite-runtime-nightly onnxruntime>=1.10.0,<2 wyoming==1.5.3 +scikit-learn>=1,<2 +openwakeword>=0.6.0 diff --git a/wyoming_openwakeword/__main__.py b/wyoming_openwakeword/__main__.py index ef80601..0ac9da3 100644 --- a/wyoming_openwakeword/__main__.py +++ b/wyoming_openwakeword/__main__.py @@ -31,6 +31,11 @@ async def main() -> None: default=[], help="Path to directory with custom wake word models", ) + parser.add_argument( + "--custom-verifiers-dir", + default=None, + help="Path to directory with custom verifiers named after the model files with .pkl ending", + ) parser.add_argument( "--preload-model", action="append", @@ -43,6 +48,12 @@ async def main() -> None: default=0.5, help="Wake word model threshold (0.0-1.0, default: 0.5)", ) + parser.add_argument( + "--custom-verifier-threshold", + type=float, + default=0.3, + help="Threshold of the OpenWakeWord model when custom verifiers are active (0.0-1.0, default: 0.3)" + ) parser.add_argument( "--trigger-level", type=int, @@ -95,6 +106,7 @@ async def main() -> None: state = State( models_dir=Path(args.models_dir), custom_model_dirs=[Path(d) for d in args.custom_model_dir], + custom_verifiers_dir=Path(args.custom_verifiers_dir) if args.custom_verifiers_dir else None, debug_probability=args.debug_probability, output_dir=args.output_dir, ) @@ -105,6 +117,7 @@ async def main() -> None: args.preload_model, threshold=args.threshold, trigger_level=args.trigger_level, + custom_verifier_threshold=args.custom_verifier_threshold, vad_threshold=args.vad_threshold, ) diff --git a/wyoming_openwakeword/const.py b/wyoming_openwakeword/const.py index b0f47ed..b9a626e 100644 --- a/wyoming_openwakeword/const.py +++ b/wyoming_openwakeword/const.py @@ -45,6 +45,7 @@ class WakeWordData: activations: int = 0 threshold: float = 0.5 trigger_level: int = 1 + custom_verifier_threshold: float = 0.3 is_processing: bool = False def reset(self) -> None: diff --git a/wyoming_openwakeword/handler.py b/wyoming_openwakeword/handler.py index d37c253..81086ab 100644 --- a/wyoming_openwakeword/handler.py +++ b/wyoming_openwakeword/handler.py @@ -75,6 +75,7 @@ async def handle_event(self, event: Event) -> bool: detect.names, threshold=self.cli_args.threshold, trigger_level=self.cli_args.trigger_level, + custom_verifier_threshold=self.cli_args.custom_verifier_threshold, vad_threshold=self.cli_args.vad_threshold, ) @@ -214,7 +215,7 @@ def _get_info(self) -> Info: # ----------------------------------------------------------------------------- -def ensure_loaded(state: State, names: List[str], threshold: float, trigger_level: int, vad_threshold: float): +def ensure_loaded(state: State, names: List[str], threshold: float, trigger_level: int, custom_verifier_threshold: float, vad_threshold: float): """Ensure wake words are loaded by name.""" with state.clients_lock, state.ww_threads_lock: for model_name in names: @@ -243,6 +244,19 @@ def ensure_loaded(state: State, names: List[str], threshold: float, trigger_leve if model_path is None: raise ValueError(f"Wake word model not found: {model_name}") + custom_verifier_paths = _get_custom_verifier_files(state) + custom_verifier_path: Optional[Path] = None + + for maybe_custom_verifier_path in custom_verifier_paths: + if norm_model_name == _normalize_key(maybe_custom_verifier_path.stem): + custom_verifier_path = maybe_custom_verifier_path + break + + if match := _WAKE_WORD_WITH_VERSION.match(model_name): + if _normalize_key(maybe_custom_verifier_path.stem) == _normalize_key(match.group(1)): + custom_verifier_path = maybe_custom_verifier_path + break + # Start thread for model model_key = model_path.stem state.wake_words[model_key] = WakeWordState() @@ -253,6 +267,7 @@ def ensure_loaded(state: State, names: List[str], threshold: float, trigger_leve state, model_key, model_path, + custom_verifier_path, asyncio.get_running_loop(), vad_threshold, ), @@ -263,6 +278,7 @@ def ensure_loaded(state: State, names: List[str], threshold: float, trigger_leve client_data.wake_words[model_key] = WakeWordData( threshold=threshold, trigger_level=trigger_level, + custom_verifier_threshold=custom_verifier_threshold ) _LOGGER.debug("Started thread for %s", model_key) @@ -285,6 +301,13 @@ def _get_wake_word_files(state: State) -> List[Path]: return model_paths +def _get_custom_verifier_files(state: State) -> List[Path]: + """Get paths to all available custom verifier files.""" + if state.custom_verifiers_dir: + return state.custom_verifiers_dir.glob("*.pkl") + return [] + + def _normalize_key(model_key: str) -> str: """Normalize model key for comparison.""" return model_key.lower().replace("_", " ").strip() diff --git a/wyoming_openwakeword/openwakeword.py b/wyoming_openwakeword/openwakeword.py index f6110f2..e4e1a03 100644 --- a/wyoming_openwakeword/openwakeword.py +++ b/wyoming_openwakeword/openwakeword.py @@ -2,6 +2,7 @@ import logging from datetime import datetime from typing import Dict, List, Optional, TextIO +import pickle import numpy as np @@ -242,6 +243,7 @@ def ww_proc( state: State, ww_model_key: str, ww_model_path: str, + custom_verifier_path: Optional[str], loop: asyncio.AbstractEventLoop, vad_threshold: float, ): @@ -256,6 +258,11 @@ def ww_proc( ww_output_index = ww_model.get_output_details()[0]["index"] # ww = [batch x window x features (96)] => [batch x probability] + custom_verifier = None + if custom_verifier_path: + _LOGGER.info("Loading custom verifier %s", custom_verifier_path) + custom_verifier = pickle.load(open(custom_verifier_path, 'rb')) + client: Optional[ClientData] = None ww_state = state.wake_words[ww_model_key] @@ -317,6 +324,10 @@ def ww_proc( probabilities = ww_model.get_tensor(ww_output_index) probability = probabilities[0] + custom_verifier_probability = 1.0 + if custom_verifier: + custom_verifier_probability = custom_verifier.predict_proba(embeddings_tensor)[0][-1] + coros = [] with state.clients_lock: client = state.clients.get(client_id) @@ -329,11 +340,12 @@ def ww_proc( voice_detected = (vad_threshold <= 0.0 or vad_max_score >= vad_threshold) if state.debug_probability: _LOGGER.debug( - "client=%s, wake_word=%s, probability=%s, vad_probability=%s", + "client=%s, wake_word=%s, probability=%s, vad_probability=%s, custom_verifier_probability=%s", client_id, ww_model_key, probability.item(), vad_max_score, + custom_verifier_probability, ) prob_file: Optional[TextIO] = None @@ -352,8 +364,13 @@ def ww_proc( ) client_data = client.wake_words[ww_model_key] + probability = probability.item() + + if custom_verifier: + if probability >= client_data.custom_verifier_threshold: + probability = custom_verifier_probability - if probability.item() >= client_data.threshold and voice_detected: + if probability >= client_data.threshold and voice_detected: # Increase activation client_data.activations += 1 diff --git a/wyoming_openwakeword/state.py b/wyoming_openwakeword/state.py index 392cd27..87a1c28 100644 --- a/wyoming_openwakeword/state.py +++ b/wyoming_openwakeword/state.py @@ -20,6 +20,9 @@ class State: custom_model_dirs: List[Path] = field(default_factory=list) """Directories with custom wake word models.""" + custom_verifiers_dir: Path = None + """Directory with custom verifiers""" + ww_threads: Dict[str, Thread] = field(default_factory=dict) ww_threads_lock: Lock = field(default_factory=Lock)