From 2efc746ba2bd8f2186d9a8c24620edfd58a5f43c Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Fri, 22 Nov 2024 00:59:32 +0900 Subject: [PATCH] Revert "Wrap with `Query` to fix doc issue" This reverts commit 4fd1319312c1a19825ea33f6e7ff7af0fcebb1c3. --- modules/whisper/data_classes.py | 227 ++++++++++++-------------------- 1 file changed, 86 insertions(+), 141 deletions(-) diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py index f5d3459..705e4b8 100644 --- a/modules/whisper/data_classes.py +++ b/modules/whisper/data_classes.py @@ -88,44 +88,32 @@ def from_list(cls, data_list: List) -> 'BaseParams': # More info : https://github.com/fastapi/fastapi/discussions/8634#discussioncomment-5153136 class VadParams(BaseParams): """Voice Activity Detection parameters""" - vad_filter: bool = Field( - Query(default=False, description="Enable voice activity detection to filter out non-speech parts") - ) + vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts") threshold: float = Field( - Query( - default=0.5, - ge=0.0, - le=1.0, - description="Speech threshold for Silero VAD. Probabilities above this value are considered speech" - ) + default=0.5, + ge=0.0, + le=1.0, + description="Speech threshold for Silero VAD. Probabilities above this value are considered speech" ) min_speech_duration_ms: int = Field( - Query( - default=250, - ge=0, - description="Final speech chunks shorter than this are discarded" - ) + default=250, + ge=0, + description="Final speech chunks shorter than this are discarded" ) max_speech_duration_s: float = Field( - Query( - default=float("inf"), - gt=0, - description="Maximum duration of speech chunks in seconds" - ) + default=float("inf"), + gt=0, + description="Maximum duration of speech chunks in seconds" ) min_silence_duration_ms: int = Field( - Query( - default=2000, - ge=0, - description="Minimum silence duration between speech chunks" - ) + default=2000, + ge=0, + description="Minimum silence duration between speech chunks" ) speech_pad_ms: int = Field( - Query( - default=400, - ge=0, - description="Padding added to each side of speech chunks" - ) + default=400, + ge=0, + description="Padding added to each side of speech chunks" ) @classmethod @@ -167,17 +155,11 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components class DiarizationParams(BaseParams): """Speaker diarization parameters""" - is_diarize: bool = Field( - Query(default=False, description="Enable speaker diarization") - ) - diarization_device: str = Field( - Query(default="cuda", description="Device to run Diarization model.") - ) + is_diarize: bool = Field(default=False, description="Enable speaker diarization") + diarization_device: str = Field(default="cuda", description="Device to run Diarization model.") hf_token: str = Field( - Query( - default="", - description="Hugging Face token for downloading diarization models" - ) + default="", + description="Hugging Face token for downloading diarization models" ) @classmethod @@ -205,30 +187,24 @@ def to_gradio_inputs(cls, class BGMSeparationParams(BaseParams): """Background music separation parameters""" - is_separate_bgm: bool = Field( - Query(default=False, description="Enable background music separation") - ) + is_separate_bgm: bool = Field(default=False, description="Enable background music separation") uvr_model_size: str = Field( - Query( - default="UVR-MDX-NET-Inst_HQ_4", - description="UVR model size" - ) - ) - uvr_device: str = Field( - Query(default="cuda", description="Device to run UVR model.") + default="UVR-MDX-NET-Inst_HQ_4", + description="UVR model size" ) + uvr_device: str = Field(default="cuda", description="Device to run UVR model.") segment_size: int = Field( - Query( - default=256, - gt=0, - description="Segment size for UVR model" - ) + default=256, + gt=0, + description="Segment size for UVR model" ) save_file: bool = Field( - Query(default=False, description="Whether to save separated audio files") + default=False, + description="Whether to save separated audio files" ) enable_offload: bool = Field( - Query(default=True, description="Offload UVR model after transcription") + default=True, + description="Offload UVR model after transcription" ) @classmethod @@ -274,115 +250,84 @@ def to_gradio_input(cls, class WhisperParams(BaseParams): """Whisper parameters""" - model_size: str = Field( - Query(default="large-v2", description="Whisper model size") - ) - lang: Optional[str] = Field( - Query(default=None, description="Source language of the file to transcribe") - ) - is_translate: bool = Field( - Query(default=False, description="Translate speech to English end-to-end") - ) - beam_size: int = Field( - Query(default=5, ge=1, description="Beam size for decoding") - ) + model_size: str = Field(default="large-v2", description="Whisper model size") + lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe") + is_translate: bool = Field(default=False, description="Translate speech to English end-to-end") + beam_size: int = Field(default=5, ge=1, description="Beam size for decoding") log_prob_threshold: float = Field( - Query(default=-1.0, description="Threshold for average log probability of sampled tokens") + default=-1.0, + description="Threshold for average log probability of sampled tokens" ) no_speech_threshold: float = Field( - Query( - default=0.6, - ge=0.0, - le=1.0, - description="Threshold for detecting silence" - ) - ) - compute_type: str = Field( - Query(default="float16", description="Computation type for transcription") - ) - best_of: int = Field( - Query(default=5, ge=1, description="Number of candidates when sampling") - ) - patience: float = Field( - Query(default=1.0, gt=0, description="Beam search patience factor") - ) + default=0.6, + ge=0.0, + le=1.0, + description="Threshold for detecting silence" + ) + compute_type: str = Field(default="float16", description="Computation type for transcription") + best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling") + patience: float = Field(default=1.0, gt=0, description="Beam search patience factor") condition_on_previous_text: bool = Field( - Query(default=True, description="Use previous output as prompt for next window") + default=True, + description="Use previous output as prompt for next window" ) prompt_reset_on_temperature: float = Field( - Query( - default=0.5, - ge=0.0, - le=1.0, - description="Temperature threshold for resetting prompt" - ) - ) - initial_prompt: Optional[str] = Field( - Query(default=None, description="Initial prompt for first window") + default=0.5, + ge=0.0, + le=1.0, + description="Temperature threshold for resetting prompt" ) + initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window") temperature: float = Field( - Query(default=0.0, ge=0.0, description="Temperature for sampling") + default=0.0, + ge=0.0, + description="Temperature for sampling" ) compression_ratio_threshold: float = Field( - Query(default=2.4, gt=0, description="Threshold for gzip compression ratio") - ) - length_penalty: float = Field( - Query(default=1.0, gt=0, description="Exponential length penalty") - ) - repetition_penalty: float = Field( - Query(default=1.0, gt=0, description="Penalty for repeated tokens") - ) - no_repeat_ngram_size: int = Field( - Query(default=0, ge=0, description="Size of n-grams to prevent repetition") - ) - prefix: Optional[str] = Field( - Query(default=None, description="Prefix text for first window") - ) + default=2.4, + gt=0, + description="Threshold for gzip compression ratio" + ) + length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty") + repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens") + no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition") + prefix: Optional[str] = Field(default=None, description="Prefix text for first window") suppress_blank: bool = Field( - Query(default=True, description="Suppress blank outputs at start of sampling") - ) - suppress_tokens: Optional[Union[List[int], str]] = Field( - Query(default=[-1], description="Token IDs to suppress") + default=True, + description="Suppress blank outputs at start of sampling" ) + suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress") max_initial_timestamp: float = Field( - Query(default=1.0, ge=0.0, description="Maximum initial timestamp") - ) - word_timestamps: bool = Field( - Query(default=False, description="Extract word-level timestamps") + default=1.0, + ge=0.0, + description="Maximum initial timestamp" ) + word_timestamps: bool = Field(default=False, description="Extract word-level timestamps") prepend_punctuations: Optional[str] = Field( - Query( - default="\"'“¿([{-", - description="Punctuations to merge with next word" - ) + default="\"'“¿([{-", + description="Punctuations to merge with next word" ) append_punctuations: Optional[str] = Field( - Query( - default="\"'.。,,!!??::”)]}、", - description="Punctuations to merge with previous word" - ) - ) - max_new_tokens: Optional[int] = Field( - Query(default=None, description="Maximum number of new tokens per chunk") - ) - chunk_length: Optional[int] = Field( - Query(default=30, description="Length of audio segments in seconds") + default="\"'.。,,!!??::”)]}、", + description="Punctuations to merge with previous word" ) + max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk") + chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds") hallucination_silence_threshold: Optional[float] = Field( - Query(default=None, description="Threshold for skipping silent periods in hallucination detection") - ) - hotwords: Optional[str] = Field( - Query(default=None, description="Hotwords/hint phrases for the model") + default=None, + description="Threshold for skipping silent periods in hallucination detection" ) + hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model") language_detection_threshold: Optional[float] = Field( - Query(default=None, description="Threshold for language detection probability") + default=None, + description="Threshold for language detection probability" ) language_detection_segments: int = Field( - Query(default=1, gt=0, description="Number of segments for language detection") - ) - batch_size: int = Field( - Query(default=24, gt=0, description="Batch size for processing") + default=1, + gt=0, + description="Number of segments for language detection" ) + batch_size: int = Field(default=24, gt=0, description="Batch size for processing") @field_validator('lang') def validate_lang(cls, v):