Revert "Wrap with Query to fix doc issue"
This reverts commit 4fd1319.
jhj0517 committed Nov 21, 2024
1 parent 3941bbb commit 2efc746
Showing 1 changed file with 86 additions and 141 deletions.
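
For context, the commit being reverted had swapped every plain pydantic Field(...) declaration for a Field(Query(...)) wrapper, per its title to address a doc issue (the fastapi discussion linked in the changed file covers this pattern); this revert restores the plain form. Below is a minimal sketch of the two styles using the vad_filter field as the example. The class names and the BaseModel base are stand-ins for illustration only, since the repository's BaseParams base class is not shown in this diff.

from fastapi import Query
from pydantic import BaseModel, Field


class VadParamsWrapped(BaseModel):
    # Style removed by this revert: the FastAPI Query object is nested inside
    # Field, carrying the per-field description for query-parameter docs.
    vad_filter: bool = Field(
        Query(default=False, description="Enable voice activity detection to filter out non-speech parts")
    )


class VadParamsPlain(BaseModel):
    # Style restored by this revert: a plain pydantic Field with default and description.
    vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")

Both field declarations are taken verbatim from the diff below; only the wrapper classes are invented for this sketch.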
modules/whisper/data_classes.py (227 changes: 86 additions & 141 deletions)
@@ -88,44 +88,32 @@ def from_list(cls, data_list: List) -> 'BaseParams':
 # More info : https://github.com/fastapi/fastapi/discussions/8634#discussioncomment-5153136
 class VadParams(BaseParams):
     """Voice Activity Detection parameters"""
-    vad_filter: bool = Field(
-        Query(default=False, description="Enable voice activity detection to filter out non-speech parts")
-    )
+    vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
     threshold: float = Field(
-        Query(
-            default=0.5,
-            ge=0.0,
-            le=1.0,
-            description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
-        )
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
     )
     min_speech_duration_ms: int = Field(
-        Query(
-            default=250,
-            ge=0,
-            description="Final speech chunks shorter than this are discarded"
-        )
+        default=250,
+        ge=0,
+        description="Final speech chunks shorter than this are discarded"
     )
     max_speech_duration_s: float = Field(
-        Query(
-            default=float("inf"),
-            gt=0,
-            description="Maximum duration of speech chunks in seconds"
-        )
+        default=float("inf"),
+        gt=0,
+        description="Maximum duration of speech chunks in seconds"
     )
     min_silence_duration_ms: int = Field(
-        Query(
-            default=2000,
-            ge=0,
-            description="Minimum silence duration between speech chunks"
-        )
+        default=2000,
+        ge=0,
+        description="Minimum silence duration between speech chunks"
     )
     speech_pad_ms: int = Field(
-        Query(
-            default=400,
-            ge=0,
-            description="Padding added to each side of speech chunks"
-        )
+        default=400,
+        ge=0,
+        description="Padding added to each side of speech chunks"
     )
 
     @classmethod
@@ -167,17 +155,11 @@ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components
 
 class DiarizationParams(BaseParams):
     """Speaker diarization parameters"""
-    is_diarize: bool = Field(
-        Query(default=False, description="Enable speaker diarization")
-    )
-    diarization_device: str = Field(
-        Query(default="cuda", description="Device to run Diarization model.")
-    )
+    is_diarize: bool = Field(default=False, description="Enable speaker diarization")
+    diarization_device: str = Field(default="cuda", description="Device to run Diarization model.")
     hf_token: str = Field(
-        Query(
-            default="",
-            description="Hugging Face token for downloading diarization models"
-        )
+        default="",
+        description="Hugging Face token for downloading diarization models"
     )
 
     @classmethod
@@ -205,30 +187,24 @@ def to_gradio_inputs(cls,
 
 class BGMSeparationParams(BaseParams):
     """Background music separation parameters"""
-    is_separate_bgm: bool = Field(
-        Query(default=False, description="Enable background music separation")
-    )
+    is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
     uvr_model_size: str = Field(
-        Query(
-            default="UVR-MDX-NET-Inst_HQ_4",
-            description="UVR model size"
-        )
-    )
-    uvr_device: str = Field(
-        Query(default="cuda", description="Device to run UVR model.")
+        default="UVR-MDX-NET-Inst_HQ_4",
+        description="UVR model size"
     )
+    uvr_device: str = Field(default="cuda", description="Device to run UVR model.")
     segment_size: int = Field(
-        Query(
-            default=256,
-            gt=0,
-            description="Segment size for UVR model"
-        )
+        default=256,
+        gt=0,
+        description="Segment size for UVR model"
     )
     save_file: bool = Field(
-        Query(default=False, description="Whether to save separated audio files")
+        default=False,
+        description="Whether to save separated audio files"
     )
     enable_offload: bool = Field(
-        Query(default=True, description="Offload UVR model after transcription")
+        default=True,
+        description="Offload UVR model after transcription"
     )
 
     @classmethod
@@ -274,115 +250,84 @@ def to_gradio_input(cls,
 
 class WhisperParams(BaseParams):
     """Whisper parameters"""
-    model_size: str = Field(
-        Query(default="large-v2", description="Whisper model size")
-    )
-    lang: Optional[str] = Field(
-        Query(default=None, description="Source language of the file to transcribe")
-    )
-    is_translate: bool = Field(
-        Query(default=False, description="Translate speech to English end-to-end")
-    )
-    beam_size: int = Field(
-        Query(default=5, ge=1, description="Beam size for decoding")
-    )
+    model_size: str = Field(default="large-v2", description="Whisper model size")
+    lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
+    is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
+    beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
     log_prob_threshold: float = Field(
-        Query(default=-1.0, description="Threshold for average log probability of sampled tokens")
+        default=-1.0,
+        description="Threshold for average log probability of sampled tokens"
     )
     no_speech_threshold: float = Field(
-        Query(
-            default=0.6,
-            ge=0.0,
-            le=1.0,
-            description="Threshold for detecting silence"
-        )
-    )
-    compute_type: str = Field(
-        Query(default="float16", description="Computation type for transcription")
-    )
-    best_of: int = Field(
-        Query(default=5, ge=1, description="Number of candidates when sampling")
-    )
-    patience: float = Field(
-        Query(default=1.0, gt=0, description="Beam search patience factor")
-    )
+        default=0.6,
+        ge=0.0,
+        le=1.0,
+        description="Threshold for detecting silence"
+    )
+    compute_type: str = Field(default="float16", description="Computation type for transcription")
+    best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
+    patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
     condition_on_previous_text: bool = Field(
-        Query(default=True, description="Use previous output as prompt for next window")
+        default=True,
+        description="Use previous output as prompt for next window"
     )
     prompt_reset_on_temperature: float = Field(
-        Query(
-            default=0.5,
-            ge=0.0,
-            le=1.0,
-            description="Temperature threshold for resetting prompt"
-        )
-    )
-    initial_prompt: Optional[str] = Field(
-        Query(default=None, description="Initial prompt for first window")
+        default=0.5,
+        ge=0.0,
+        le=1.0,
+        description="Temperature threshold for resetting prompt"
     )
+    initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
     temperature: float = Field(
-        Query(default=0.0, ge=0.0, description="Temperature for sampling")
+        default=0.0,
+        ge=0.0,
+        description="Temperature for sampling"
     )
     compression_ratio_threshold: float = Field(
-        Query(default=2.4, gt=0, description="Threshold for gzip compression ratio")
-    )
-    length_penalty: float = Field(
-        Query(default=1.0, gt=0, description="Exponential length penalty")
-    )
-    repetition_penalty: float = Field(
-        Query(default=1.0, gt=0, description="Penalty for repeated tokens")
-    )
-    no_repeat_ngram_size: int = Field(
-        Query(default=0, ge=0, description="Size of n-grams to prevent repetition")
-    )
-    prefix: Optional[str] = Field(
-        Query(default=None, description="Prefix text for first window")
-    )
+        default=2.4,
+        gt=0,
+        description="Threshold for gzip compression ratio"
+    )
+    length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
+    repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
+    no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
+    prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
     suppress_blank: bool = Field(
-        Query(default=True, description="Suppress blank outputs at start of sampling")
-    )
-    suppress_tokens: Optional[Union[List[int], str]] = Field(
-        Query(default=[-1], description="Token IDs to suppress")
+        default=True,
+        description="Suppress blank outputs at start of sampling"
    )
+    suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress")
     max_initial_timestamp: float = Field(
-        Query(default=1.0, ge=0.0, description="Maximum initial timestamp")
-    )
-    word_timestamps: bool = Field(
-        Query(default=False, description="Extract word-level timestamps")
+        default=1.0,
+        ge=0.0,
+        description="Maximum initial timestamp"
     )
+    word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
     prepend_punctuations: Optional[str] = Field(
-        Query(
-            default="\"'“¿([{-",
-            description="Punctuations to merge with next word"
-        )
+        default="\"'“¿([{-",
+        description="Punctuations to merge with next word"
     )
     append_punctuations: Optional[str] = Field(
-        Query(
-            default="\"'.。,,!!??::”)]}、",
-            description="Punctuations to merge with previous word"
-        )
-    )
-    max_new_tokens: Optional[int] = Field(
-        Query(default=None, description="Maximum number of new tokens per chunk")
-    )
-    chunk_length: Optional[int] = Field(
-        Query(default=30, description="Length of audio segments in seconds")
+        default="\"'.。,,!!??::”)]}、",
+        description="Punctuations to merge with previous word"
     )
+    max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
+    chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
     hallucination_silence_threshold: Optional[float] = Field(
-        Query(default=None, description="Threshold for skipping silent periods in hallucination detection")
-    )
-    hotwords: Optional[str] = Field(
-        Query(default=None, description="Hotwords/hint phrases for the model")
+        default=None,
+        description="Threshold for skipping silent periods in hallucination detection"
     )
+    hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
     language_detection_threshold: Optional[float] = Field(
-        Query(default=None, description="Threshold for language detection probability")
+        default=None,
+        description="Threshold for language detection probability"
     )
     language_detection_segments: int = Field(
-        Query(default=1, gt=0, description="Number of segments for language detection")
-    )
-    batch_size: int = Field(
-        Query(default=24, gt=0, description="Batch size for processing")
+        default=1,
+        gt=0,
+        description="Number of segments for language detection"
     )
+    batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
 
     @field_validator('lang')
     def validate_lang(cls, v):
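The practical difference between the two declaration styles only shows up when FastAPI consumes these models. Below is a hedged sketch of that consumption, assuming the models are injected as query parameters with Depends() in the way the linked fastapi discussion describes; the /vad route and the app object are hypothetical and not part of this diff.

from fastapi import Depends, FastAPI

from modules.whisper.data_classes import VadParams

app = FastAPI()


@app.post("/vad")
async def run_vad(params: VadParams = Depends()):
    # FastAPI maps the model's fields to query parameters here. Wrapping each
    # default in Query(...) was the reverted commit's attempt to fix how these
    # fields render in the generated OpenAPI docs (per its title and the linked
    # discussion); with the plain Field(...) form restored by this revert, the
    # endpoint still functions, and only the documented metadata differs.
    return {"vad_filter": params.vad_filter, "threshold": params.threshold}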
