Skip to content

Commit

Permalink
Implement primitive custom verifier support
Browse files Browse the repository at this point in the history
  • Loading branch information
jhbruhn committed Apr 2, 2024
1 parent 8e679a5 commit 1c84f07
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 3 deletions.
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
tflite-runtime-nightly
onnxruntime>=1.10.0,<2
wyoming==1.5.3
scikit-learn>=1,<2
openwakeword>=0.6.0
13 changes: 13 additions & 0 deletions wyoming_openwakeword/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ async def main() -> None:
default=[],
help="Path to directory with custom wake word models",
)
parser.add_argument(
"--custom-verifiers-dir",
default=None,
help="Path to directory with custom verifiers named after the model files with .pkl ending",
)
parser.add_argument(
"--preload-model",
action="append",
Expand All @@ -43,6 +48,12 @@ async def main() -> None:
default=0.5,
help="Wake word model threshold (0.0-1.0, default: 0.5)",
)
parser.add_argument(
"--custom-verifier-threshold",
type=float,
default=0.3,
help="Threshold of the OpenWakeWord model when custom verifiers are active (0.0-1.0, default: 0.3)"
)
parser.add_argument(
"--trigger-level",
type=int,
Expand Down Expand Up @@ -95,6 +106,7 @@ async def main() -> None:
state = State(
models_dir=Path(args.models_dir),
custom_model_dirs=[Path(d) for d in args.custom_model_dir],
custom_verifiers_dir=Path(args.custom_verifiers_dir) if args.custom_verifiers_dir else None,
debug_probability=args.debug_probability,
output_dir=args.output_dir,
)
Expand All @@ -105,6 +117,7 @@ async def main() -> None:
args.preload_model,
threshold=args.threshold,
trigger_level=args.trigger_level,
custom_verifier_threshold=args.custom_verifier_threshold,
vad_threshold=args.vad_threshold,
)

Expand Down
1 change: 1 addition & 0 deletions wyoming_openwakeword/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class WakeWordData:
activations: int = 0
threshold: float = 0.5
trigger_level: int = 1
custom_verifier_threshold: float = 0.3
is_processing: bool = False

def reset(self) -> None:
Expand Down
25 changes: 24 additions & 1 deletion wyoming_openwakeword/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def handle_event(self, event: Event) -> bool:
detect.names,
threshold=self.cli_args.threshold,
trigger_level=self.cli_args.trigger_level,
custom_verifier_threshold=self.cli_args.custom_verifier_threshold,
vad_threshold=self.cli_args.vad_threshold,
)

Expand Down Expand Up @@ -214,7 +215,7 @@ def _get_info(self) -> Info:
# -----------------------------------------------------------------------------


def ensure_loaded(state: State, names: List[str], threshold: float, trigger_level: int, vad_threshold: float):
def ensure_loaded(state: State, names: List[str], threshold: float, trigger_level: int, custom_verifier_threshold: float, vad_threshold: float):
"""Ensure wake words are loaded by name."""
with state.clients_lock, state.ww_threads_lock:
for model_name in names:
Expand Down Expand Up @@ -243,6 +244,19 @@ def ensure_loaded(state: State, names: List[str], threshold: float, trigger_leve
if model_path is None:
raise ValueError(f"Wake word model not found: {model_name}")

custom_verifier_paths = _get_custom_verifier_files(state)
custom_verifier_path: Optional[Path] = None

for maybe_custom_verifier_path in custom_verifier_paths:
if norm_model_name == _normalize_key(maybe_custom_verifier_path.stem):
custom_verifier_path = maybe_custom_verifier_path
break

if match := _WAKE_WORD_WITH_VERSION.match(model_name):
if _normalize_key(maybe_custom_verifier_path.stem) == _normalize_key(match.group(1)):
custom_verifier_path = maybe_custom_verifier_path
break

# Start thread for model
model_key = model_path.stem
state.wake_words[model_key] = WakeWordState()
Expand All @@ -253,6 +267,7 @@ def ensure_loaded(state: State, names: List[str], threshold: float, trigger_leve
state,
model_key,
model_path,
custom_verifier_path,
asyncio.get_running_loop(),
vad_threshold,
),
Expand All @@ -263,6 +278,7 @@ def ensure_loaded(state: State, names: List[str], threshold: float, trigger_leve
client_data.wake_words[model_key] = WakeWordData(
threshold=threshold,
trigger_level=trigger_level,
custom_verifier_threshold=custom_verifier_threshold
)

_LOGGER.debug("Started thread for %s", model_key)
Expand All @@ -285,6 +301,13 @@ def _get_wake_word_files(state: State) -> List[Path]:
return model_paths


def _get_custom_verifier_files(state: State) -> List[Path]:
"""Get paths to all available custom verifier files."""
if state.custom_verifiers_dir:
return state.custom_verifiers_dir.glob("*.pkl")
return []


def _normalize_key(model_key: str) -> str:
"""Normalize model key for comparison."""
return model_key.lower().replace("_", " ").strip()
Expand Down
21 changes: 19 additions & 2 deletions wyoming_openwakeword/openwakeword.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from datetime import datetime
from typing import Dict, List, Optional, TextIO
import pickle

import numpy as np

Expand Down Expand Up @@ -242,6 +243,7 @@ def ww_proc(
state: State,
ww_model_key: str,
ww_model_path: str,
custom_verifier_path: Optional[str],
loop: asyncio.AbstractEventLoop,
vad_threshold: float,
):
Expand All @@ -256,6 +258,11 @@ def ww_proc(
ww_output_index = ww_model.get_output_details()[0]["index"]
# ww = [batch x window x features (96)] => [batch x probability]

custom_verifier = None
if custom_verifier_path:
_LOGGER.info("Loading custom verifier %s", custom_verifier_path)
custom_verifier = pickle.load(open(custom_verifier_path, 'rb'))

client: Optional[ClientData] = None

ww_state = state.wake_words[ww_model_key]
Expand Down Expand Up @@ -317,6 +324,10 @@ def ww_proc(
probabilities = ww_model.get_tensor(ww_output_index)
probability = probabilities[0]

custom_verifier_probability = 1.0
if custom_verifier:
custom_verifier_probability = custom_verifier.predict_proba(embeddings_tensor)[0][-1]

coros = []
with state.clients_lock:
client = state.clients.get(client_id)
Expand All @@ -329,11 +340,12 @@ def ww_proc(
voice_detected = (vad_threshold <= 0.0 or vad_max_score >= vad_threshold)
if state.debug_probability:
_LOGGER.debug(
"client=%s, wake_word=%s, probability=%s, vad_probability=%s",
"client=%s, wake_word=%s, probability=%s, vad_probability=%s, custom_verifier_probability=%s",
client_id,
ww_model_key,
probability.item(),
vad_max_score,
custom_verifier_probability,
)

prob_file: Optional[TextIO] = None
Expand All @@ -352,8 +364,13 @@ def ww_proc(
)

client_data = client.wake_words[ww_model_key]
probability = probability.item()

if custom_verifier:
if probability >= client_data.custom_verifier_threshold:
probability = custom_verifier_probability

if probability.item() >= client_data.threshold and voice_detected:
if probability >= client_data.threshold and voice_detected:
# Increase activation
client_data.activations += 1

Expand Down
3 changes: 3 additions & 0 deletions wyoming_openwakeword/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ class State:
custom_model_dirs: List[Path] = field(default_factory=list)
"""Directories with custom wake word models."""

custom_verifiers_dir: Path = None
"""Directory with custom verifiers"""

ww_threads: Dict[str, Thread] = field(default_factory=dict)
ww_threads_lock: Lock = field(default_factory=Lock)

Expand Down

0 comments on commit 1c84f07

Please sign in to comment.