diff --git a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
index 338801e..0abc80a 100644
--- a/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
+++ b/Examples/WhisperAX/WhisperAX/Views/ContentView.swift
@@ -681,26 +681,6 @@ struct ContentView: View {
         }
     }
 
-    func requestMicrophoneIfNeeded() async -> Bool {
-        let microphoneStatus = AVCaptureDevice.authorizationStatus(for: .audio)
-
-        switch microphoneStatus {
-        case .notDetermined:
-            return await withCheckedContinuation { continuation in
-                AVCaptureDevice.requestAccess(for: .audio) { granted in
-                    continuation.resume(returning: granted)
-                }
-            }
-        case .restricted, .denied:
-            print("Microphone access denied")
-            return false
-        case .authorized:
-            return true
-        @unknown default:
-            fatalError("Unknown authorization status")
-        }
-    }
-
     func loadModel(_ model: String, redownload: Bool = false) {
         print("Selected Model: \(UserDefaults.standard.string(forKey: "selectedModel") ?? "nil")")
 
@@ -872,7 +852,7 @@ struct ContentView: View {
     func startRecording(_ loop: Bool) {
         if let audioProcessor = whisperKit?.audioProcessor {
             Task(priority: .userInitiated) {
-                guard await requestMicrophoneIfNeeded() else {
+                guard await AudioProcessor.requestRecordPermission() else {
                     print("Microphone access was not granted.")
                     return
                 }
diff --git a/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift b/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift
index f58b8b9..5e20678 100644
--- a/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift
+++ b/Examples/WhisperAX/WhisperAXWatchApp/WhisperAXExampleView.swift
@@ -665,26 +665,6 @@ struct WhisperAXWatchView: View {
             }
         }
     }
-
-//    func requestMicrophoneIfNeeded() async -> Bool {
-//        let microphoneStatus = AVCaptureDevice.authorizationStatus(for: .audio)
-//
-//        switch microphoneStatus {
-//        case .notDetermined:
-//            return await withCheckedContinuation { continuation in
-//                AVCaptureDevice.requestAccess(for: .audio) { granted in
-//                    continuation.resume(returning: granted)
-//                }
-//            }
-//        case .restricted, .denied:
-//            print("Microphone access denied")
-//            return false
-//        case .authorized:
-//            return true
-//        @unknown default:
-//            fatalError("Unknown authorization status")
-//        }
-//    }
 }
 
 #Preview {
diff --git a/README.md b/README.md
index b5c8383..ce3a088 100644
--- a/README.md
+++ b/README.md
@@ -124,7 +124,11 @@ You can then run them via the CLI with:
 swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}"
 ```
 
-Which should print a transcription of the audio file.
+Which should print a transcription of the audio file. If you would like to stream the audio directly from a microphone, use:
+
+```bash
+swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --stream
+```
 
 ## Contributing & Roadmap
 
diff --git a/Sources/WhisperKit/Core/AudioProcessor.swift b/Sources/WhisperKit/Core/AudioProcessor.swift
index c2cafc2..fcc6d6a 100644
--- a/Sources/WhisperKit/Core/AudioProcessor.swift
+++ b/Sources/WhisperKit/Core/AudioProcessor.swift
@@ -302,6 +302,10 @@ public class AudioProcessor: NSObject, AudioProcessing {
         return convertedArray
     }
 
+    public static func requestRecordPermission() async -> Bool {
+        await AVAudioApplication.requestRecordPermission()
+    }
+
     deinit {
         stopRecording()
     }
diff --git a/Sources/WhisperKit/Core/AudioStreamTranscriber.swift b/Sources/WhisperKit/Core/AudioStreamTranscriber.swift
new file mode 100644
index 0000000..ed460f6
--- /dev/null
+++ b/Sources/WhisperKit/Core/AudioStreamTranscriber.swift
@@ -0,0 +1,225 @@
+//  For licensing see accompanying LICENSE.md file.
+//  Copyright © 2024 Argmax, Inc. All rights reserved.
+
+import Foundation
+
+extension AudioStreamTranscriber {
+    public struct State {
+        public var isRecording: Bool = false
+        public var currentFallbacks: Int = 0
+        public var lastBufferSize: Int = 0
+        public var lastConfirmedSegmentEndSeconds: Float = 0
+        public var bufferEnergy: [Float] = []
+        public var currentText: String = ""
+        public var confirmedSegments: [TranscriptionSegment] = []
+        public var unconfirmedSegments: [TranscriptionSegment] = []
+        public var unconfirmedText: [String] = []
+    }
+}
+
+public typealias AudioStreamTranscriberCallback = (AudioStreamTranscriber.State, AudioStreamTranscriber.State) -> Void
+
+/// Responsible for streaming audio from the microphone, processing it, and transcribing it in real-time.
+public actor AudioStreamTranscriber {
+    private var state: AudioStreamTranscriber.State = .init() {
+        didSet {
+            stateChangeCallback?(oldValue, state)
+        }
+    }
+    private let stateChangeCallback: AudioStreamTranscriberCallback?
+
+    private let requiredSegmentsForConfirmation: Int
+    private let useVAD: Bool
+    private let silenceThreshold: Float
+    private let compressionCheckWindow: Int
+    private let audioProcessor: any AudioProcessing
+    private let transcriber: any Transcriber
+    private let decodingOptions: DecodingOptions
+
+    public init(
+        audioProcessor: any AudioProcessing,
+        transcriber: any Transcriber,
+        decodingOptions: DecodingOptions,
+        requiredSegmentsForConfirmation: Int = 2,
+        silenceThreshold: Float = 0.3,
+        compressionCheckWindow: Int = 20,
+        useVAD: Bool = true,
+        stateChangeCallback: AudioStreamTranscriberCallback?
+    ) {
+        self.audioProcessor = audioProcessor
+        self.transcriber = transcriber
+        self.decodingOptions = decodingOptions
+        self.requiredSegmentsForConfirmation = requiredSegmentsForConfirmation
+        self.silenceThreshold = silenceThreshold
+        self.compressionCheckWindow = compressionCheckWindow
+        self.useVAD = useVAD
+        self.stateChangeCallback = stateChangeCallback
+    }
+
+    public func startStreamTranscription() async throws {
+        guard !state.isRecording else { return }
+        guard await AudioProcessor.requestRecordPermission() else {
+            Logging.error("Microphone access was not granted.")
+            return
+        }
+        state.isRecording = true
+        try audioProcessor.startRecordingLive { [weak self] _ in
+            Task { [weak self] in
+                await self?.onAudioBufferCallback()
+            }
+        }
+        await realtimeLoop()
+        Logging.info("Realtime transcription has started")
+    }
+
+    public func stopStreamTranscription() {
+        state.isRecording = false
+        audioProcessor.stopRecording()
+        Logging.info("Realtime transcription has ended")
+    }
+
+    private func realtimeLoop() async {
+        while state.isRecording {
+            do {
+                try await transcribeCurrentBuffer()
+            } catch {
+                Logging.error("Error: \(error.localizedDescription)")
+                break
+            }
+        }
+    }
+
+    private func onAudioBufferCallback() {
+        state.bufferEnergy = audioProcessor.relativeEnergy
+    }
+
+    private func onProgressCallback(_ progress: TranscriptionProgress) {
+        let fallbacks = Int(progress.timings.totalDecodingFallbacks)
+        if progress.text.count < state.currentText.count {
+            if fallbacks == state.currentFallbacks {
+                state.unconfirmedText.append(state.currentText)
+            } else {
+                Logging.info("Fallback occurred: \(fallbacks)")
+            }
+        }
+        state.currentText = progress.text
+        state.currentFallbacks = fallbacks
+    }
+
+    private func transcribeCurrentBuffer() async throws {
+        // Retrieve the current audio buffer from the audio processor
+        let currentBuffer = audioProcessor.audioSamples
+
+        // Calculate the size and duration of the next buffer segment
+        let nextBufferSize = currentBuffer.count - state.lastBufferSize
+        let nextBufferSeconds = Float(nextBufferSize) / Float(WhisperKit.sampleRate)
+
+        // Only run the transcribe if the next buffer has at least 1 second of audio
+        guard nextBufferSeconds > 1 else {
+            if state.currentText == "" {
+                state.currentText = "Waiting for speech..."
+            }
+            return try await Task.sleep(nanoseconds: 100_000_000) // sleep for 100ms for next buffer
+        }
+
+        if useVAD {
+            // Retrieve the current relative energy values from the audio processor
+            let currentRelativeEnergy = audioProcessor.relativeEnergy
+
+            // Calculate the number of energy values to consider based on the duration of the next buffer
+            // Each energy value corresponds to 1 buffer length (100ms of audio), hence we divide by 0.1
+            let energyValuesToConsider = Int(nextBufferSeconds / 0.1)
+
+            // Extract the relevant portion of energy values from the currentRelativeEnergy array
+            let nextBufferEnergies = currentRelativeEnergy.suffix(energyValuesToConsider)
+
+            // Determine the number of energy values to check for voice presence
+            // Considering up to the last 1 second of audio, which translates to 10 energy values
+            let numberOfValuesToCheck = max(10, nextBufferEnergies.count - 10)
+
+            // Check if any of the energy values in the considered range exceed the silence threshold
+            // This indicates the presence of voice in the buffer
+            let voiceDetected = nextBufferEnergies.prefix(numberOfValuesToCheck).contains { $0 > Float(silenceThreshold) }
+
+            // Only run the transcribe if the next buffer has voice
+            if !voiceDetected {
+                Logging.debug("No voice detected, skipping transcribe")
+                if state.currentText == "" {
+                    state.currentText = "Waiting for speech..."
+                }
+                // Sleep for 100ms and check the next buffer
+                return try await Task.sleep(nanoseconds: 100_000_000)
+            }
+        }
+
+        // Run transcribe
+        state.lastBufferSize = currentBuffer.count
+
+        let transcription = try await transcribeAudioSamples(Array(currentBuffer))
+
+        state.currentText = ""
+        state.unconfirmedText = []
+        guard let segments = transcription?.segments else {
+            return
+        }
+
+        // Logic for moving segments to confirmedSegments
+        if segments.count > requiredSegmentsForConfirmation {
+            // Calculate the number of segments to confirm
+            let numberOfSegmentsToConfirm = segments.count - requiredSegmentsForConfirmation
+
+            // Confirm the required number of segments
+            let confirmedSegmentsArray = Array(segments.prefix(numberOfSegmentsToConfirm))
+            let remainingSegments = Array(segments.suffix(requiredSegmentsForConfirmation))
+
+            // Update lastConfirmedSegmentEnd based on the last confirmed segment
+            if let lastConfirmedSegment = confirmedSegmentsArray.last, lastConfirmedSegment.end > state.lastConfirmedSegmentEndSeconds {
+                state.lastConfirmedSegmentEndSeconds = lastConfirmedSegment.end
+
+                // Add confirmed segments to the confirmedSegments array
+                if !state.confirmedSegments.contains(confirmedSegmentsArray) {
+                    state.confirmedSegments.append(contentsOf: confirmedSegmentsArray)
+                }
+            }
+
+            // Update transcriptions to reflect the remaining segments
+            state.unconfirmedSegments = remainingSegments
+        } else {
+            // Handle the case where segments are fewer or equal to required
+            state.unconfirmedSegments = segments
+        }
+    }
+
+    private func transcribeAudioSamples(_ samples: [Float]) async throws -> TranscriptionResult? {
+        var options = decodingOptions
+        options.clipTimestamps = [state.lastConfirmedSegmentEndSeconds]
+        let checkWindow = compressionCheckWindow
+        return try await transcriber.transcribe(audioArray: samples, decodeOptions: options) { [weak self] progress in
+            Task { [weak self] in
+                await self?.onProgressCallback(progress)
+            }
+            return AudioStreamTranscriber.shouldStopEarly(progress: progress, options: options, compressionCheckWindow: checkWindow)
+        }
+    }
+
+    private static func shouldStopEarly(
+        progress: TranscriptionProgress,
+        options: DecodingOptions,
+        compressionCheckWindow: Int
+    ) -> Bool? {
+        let currentTokens = progress.tokens
+        if currentTokens.count > compressionCheckWindow {
+            let checkTokens: [Int] = currentTokens.suffix(compressionCheckWindow)
+            let compressionRatio = compressionRatio(of: checkTokens)
+            if compressionRatio > options.compressionRatioThreshold ?? 0.0 {
+                return false
+            }
+        }
+        if let avgLogprob = progress.avgLogprob, let logProbThreshold = options.logProbThreshold {
+            if avgLogprob < logProbThreshold {
+                return false
+            }
+        }
+        return nil
+    }
+}
diff --git a/Sources/WhisperKit/Core/WhisperKit.swift b/Sources/WhisperKit/Core/WhisperKit.swift
index aadba12..171a161 100644
--- a/Sources/WhisperKit/Core/WhisperKit.swift
+++ b/Sources/WhisperKit/Core/WhisperKit.swift
@@ -9,8 +9,13 @@ import Hub
 import TensorUtils
 import Tokenizers
 
+public protocol Transcriber {
+    func transcribe(audioPath: String, decodeOptions: DecodingOptions?, callback: TranscriptionCallback) async throws -> TranscriptionResult?
+    func transcribe(audioArray: [Float], decodeOptions: DecodingOptions?, callback: TranscriptionCallback) async throws -> TranscriptionResult?
+}
+
 @available(macOS 14, iOS 17, watchOS 10, visionOS 1, *)
-public class WhisperKit {
+public class WhisperKit: Transcriber {
     // Models
     public var modelVariant: ModelVariant = .tiny
     public var modelState: ModelState = .unloaded
diff --git a/Sources/WhisperKitCLI/transcribe.swift b/Sources/WhisperKitCLI/transcribe.swift
index 08450a3..95022f3 100644
--- a/Sources/WhisperKitCLI/transcribe.swift
+++ b/Sources/WhisperKitCLI/transcribe.swift
@@ -10,6 +10,7 @@ import WhisperKit
 @available(macOS 14, iOS 17, watchOS 10, visionOS 1, *)
 @main
 struct WhisperKitCLI: AsyncParsableCommand {
+
     @Option(help: "Path to audio file")
     var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"
 
@@ -73,6 +74,9 @@ struct WhisperKitCLI: AsyncParsableCommand {
     @Option(help: "Directory to save the report")
     var reportPath: String = "."
 
+    @Flag(help: "Process audio directly from the microphone")
+    var stream: Bool = false
+
     func transcribe(audioPath: String, modelPath: String) async throws {
         let resolvedModelPath = resolveAbsolutePath(modelPath)
         guard FileManager.default.fileExists(atPath: resolvedModelPath) else {
@@ -89,12 +93,14 @@ struct WhisperKitCLI: AsyncParsableCommand {
             textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
         )
 
+        print("Initializing models...")
         let whisperKit = try await WhisperKit(
             modelFolder: modelPath,
             computeOptions: computeOptions,
             verbose: verbose,
             logLevel: .debug
         )
+        print("Models initialized")
 
         let options = DecodingOptions(
             verbose: verbose,
@@ -153,14 +159,76 @@ struct WhisperKitCLI: AsyncParsableCommand {
         }
     }
 
-    func run() async throws {
-        let audioURL = URL(fileURLWithPath: audioPath)
+    func transcribeStream(modelPath: String) async throws {
+        let computeOptions = ModelComputeOptions(
+            audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
+            textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
+        )
 
-        if verbose {
-            print("Transcribing audio at \(audioURL)")
+        print("Initializing models...")
+        let whisperKit = try await WhisperKit(
+            modelFolder: modelPath,
+            computeOptions: computeOptions,
+            verbose: verbose,
+            logLevel: .debug
+        )
+        print("Models initialized")
+
+        let decodingOptions = DecodingOptions(
+            verbose: verbose,
+            task: .transcribe,
+            language: language,
+            temperature: temperature,
+            temperatureIncrementOnFallback: temperatureIncrementOnFallback,
+            temperatureFallbackCount: 3, // limit fallbacks for realtime
+            sampleLength: 224, // reduced sample length for realtime
+            topK: bestOf,
+            usePrefillPrompt: usePrefillPrompt,
+            usePrefillCache: usePrefillCache,
+            skipSpecialTokens: skipSpecialTokens,
+            withoutTimestamps: withoutTimestamps,
+            clipTimestamps: [],
+            suppressBlank: false,
+            supressTokens: supressTokens,
+            compressionRatioThreshold: compressionRatioThreshold ?? 2.4,
+            logProbThreshold: logprobThreshold ?? -1.0,
+            noSpeechThreshold: noSpeechThreshold ?? 0.6
+        )
+
+        let audioStreamTranscriber = AudioStreamTranscriber(
+            audioProcessor: whisperKit.audioProcessor,
+            transcriber: whisperKit,
+            decodingOptions: decodingOptions
+        ) { oldState, newState in
+            guard oldState.currentText != newState.currentText ||
+                oldState.unconfirmedSegments != newState.unconfirmedSegments ||
+                oldState.confirmedSegments != newState.confirmedSegments else {
+                return
+            }
+            // TODO: Print only net new text without any repeats
+            print("---")
+            for segment in newState.confirmedSegments {
+                print("Confirmed segment: \(segment.text)")
+            }
+            for segment in newState.unconfirmedSegments {
+                print("Unconfirmed segment: \(segment.text)")
+            }
+            print("Current text: \(newState.currentText)")
         }
+        print("Transcribing audio stream, press Ctrl+C to stop.")
+        try await audioStreamTranscriber.startStreamTranscription()
+    }
 
-        try await transcribe(audioPath: audioPath, modelPath: modelPath)
+    mutating func run() async throws {
+        if stream {
+            try await transcribeStream(modelPath: modelPath)
+        } else {
+            let audioURL = URL(fileURLWithPath: audioPath)
+            if verbose {
+                print("Transcribing audio at \(audioURL)")
+            }
+            try await transcribe(audioPath: audioPath, modelPath: modelPath)
+        }
     }
 }
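For library users (outside the CLI), the same streaming pipeline introduced by this patch can be driven directly from Swift. The sketch below is illustrative only and is not part of the diff: the `streamFromMicrophone` helper, the model folder value, and the use of default `DecodingOptions()` are assumptions, while the `AudioStreamTranscriber` initializer and `startStreamTranscription()` match the API added above.

```swift
import WhisperKit

// Minimal sketch (not part of this patch): drive the new streaming API from library code.
@available(macOS 14, iOS 17, watchOS 10, visionOS 1, *)
func streamFromMicrophone(modelFolder: String) async throws {
    // Load the pipeline from a local model folder, as the CLI's transcribeStream() does above.
    let whisperKit = try await WhisperKit(modelFolder: modelFolder, verbose: true, logLevel: .debug)

    // WhisperKit now conforms to the new Transcriber protocol, so the pipeline itself
    // can be handed to AudioStreamTranscriber along with its audio processor.
    let streamTranscriber = AudioStreamTranscriber(
        audioProcessor: whisperKit.audioProcessor,
        transcriber: whisperKit,
        decodingOptions: DecodingOptions() // assumption: default decoding options are acceptable here
    ) { _, newState in
        // The state callback fires on every state change; print the in-progress hypothesis.
        print("Current text: \(newState.currentText)")
    }

    // Requests microphone permission via AudioProcessor.requestRecordPermission(),
    // then transcribes in a loop until stopStreamTranscription() is called.
    try await streamTranscriber.startStreamTranscription()
}
```

Calling `stopStreamTranscription()` on the actor ends the realtime loop and stops the audio processor, which is what the CLI otherwise relies on Ctrl+C for.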