Merge pull request #68 from jkrukowski/cli-cleanup
WhisperKit CLI cleanup
ZachNagengast authored Mar 12, 2024
2 parents 9c4d8e0 + e556132 commit eca4a2e
Showing 4 changed files with 157 additions and 138 deletions.
4 changes: 1 addition & 3 deletions Sources/WhisperKit/Core/AudioProcessor.swift
@@ -479,9 +479,7 @@ public extension AudioProcessor {
             &inputDeviceID,
             UInt32(MemoryLayout<AudioDeviceID>.size)
         )
-
-        let format = inputNode.outputFormat(forBus: 0)
-
+
         if error != noErr {
             Logging.error("Error setting Audio Unit property: \(error)")
         } else {
72 changes: 72 additions & 0 deletions Sources/WhisperKitCLI/CLIArguments.swift
@@ -0,0 +1,72 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import ArgumentParser

struct CLIArguments: ParsableArguments {
@Option(help: "Path to audio file")
var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"

@Option(help: "Path of model files")
var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"

@Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine

@Option(help: "Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var textDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine

@Flag(help: "Verbose mode")
var verbose: Bool = false

@Option(help: "Language spoken in the audio")
var language: String?

@Option(help: "Temperature to use for sampling")
var temperature: Float = 0

@Option(help: "Temperature to increase on fallbacks during decoding")
var temperatureIncrementOnFallback: Float = 0.2

@Option(help: "Number of times to increase temperature when falling back during decoding")
var temperatureFallbackCount: Int = 5

@Option(help: "Number of candidates when sampling with non-zero temperature")
var bestOf: Int = 5

@Flag(help: "Force initial prompt tokens based on language, task, and timestamp options")
var usePrefillPrompt: Bool = false

@Flag(help: "Use decoder prefill data for faster initial decoding")
var usePrefillCache: Bool = false

@Flag(help: "Skip special tokens in the output")
var skipSpecialTokens: Bool = false

@Flag(help: "Force no timestamps when decoding")
var withoutTimestamps: Bool = false

@Flag(help: "Add timestamps for each word in the output")
var wordTimestamps: Bool = false

    @Argument(help: "Suppress given tokens in the output")
var supressTokens: [Int] = []

@Option(help: "Gzip compression ratio threshold for decoding failure")
var compressionRatioThreshold: Float?

@Option(help: "Average log probability threshold for decoding failure")
var logprobThreshold: Float?

@Option(help: "Probability threshold to consider a segment as silence")
var noSpeechThreshold: Float?

@Flag(help: "Output a report of the results")
var report: Bool = false

@Option(help: "Directory to save the report")
var reportPath: String = "."

@Flag(help: "Process audio directly from the microphone")
var stream: Bool = false
}
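
Since CLIArguments conforms to ParsableArguments, it can also be parsed on its own, independent of the command that embeds it, which is handy in tests. A minimal sketch of how swift-argument-parser treats this struct (illustrative only, not part of this commit; the argument values are made up):

import ArgumentParser

// ArgumentParser derives kebab-case names from the property names,
// so `audioPath` is exposed as `--audio-path` and `verbose` as `--verbose`.
let args = try CLIArguments.parse([
    "--audio-path", "recording.wav",  // hypothetical input file
    "--temperature", "0.7",
    "--verbose",
])
print(args.audioPath, args.temperature, args.verbose)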
20 changes: 20 additions & 0 deletions Sources/WhisperKitCLI/CLIUtils.swift
@@ -0,0 +1,20 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import ArgumentParser
import CoreML
import Foundation
import WhisperKit

enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine, random
var asMLComputeUnits: MLComputeUnits {
switch self {
case .all: return .all
case .cpuAndGPU: return .cpuAndGPU
case .cpuOnly: return .cpuOnly
case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
case .random: return Bool.random() ? .cpuAndGPU : .cpuAndNeuralEngine
}
}
}
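
Because ComputeUnits is a String-backed enum, ExpressibleByArgument supplies init?(argument:) through its default RawRepresentable support, so command-line values are matched against the case names listed in the help text. A brief illustration (assumed behavior of swift-argument-parser, not code from this diff):

// A matching raw value parses to its case; anything else yields nil,
// which ArgumentParser reports as a validation error with usage help.
let units = ComputeUnits(argument: "cpuAndNeuralEngine")  // Optional(.cpuAndNeuralEngine)
let bad = ComputeUnits(argument: "gpuOnly")               // nil

Note that .random defers the choice to asMLComputeUnits, so each evaluation of that property independently picks .cpuAndGPU or .cpuAndNeuralEngine.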
199 changes: 64 additions & 135 deletions Sources/WhisperKitCLI/transcribe.swift
@@ -4,82 +4,33 @@
 import ArgumentParser
 import CoreML
 import Foundation
-
 import WhisperKit
 
 @available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
 @main
 struct WhisperKitCLI: AsyncParsableCommand {
-    @Option(help: "Path to audio file")
-    var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"
-
-    @Option(help: "Path of model files")
-    var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"
-
-    @Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
-    var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
-
-    @Option(help: "Compute units for text decoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
-    var textDecoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
-
-    @Flag(help: "Verbose mode")
-    var verbose: Bool = false
-
-    @Option(help: "Task to perform (transcribe or translate)")
-    var task: String = "transcribe"
-
-    @Option(help: "Language spoken in the audio")
-    var language: String?
-
-    @Option(help: "Temperature to use for sampling")
-    var temperature: Float = 0
-
-    @Option(help: "Temperature to increase on fallbacks during decoding")
-    var temperatureIncrementOnFallback: Float = 0.2
-
-    @Option(help: "Number of times to increase temperature when falling back during decoding")
-    var temperatureFallbackCount: Int = 5
-
-    @Option(help: "Number of candidates when sampling with non-zero temperature")
-    var bestOf: Int = 5
-
-    @Flag(help: "Force initial prompt tokens based on language, task, and timestamp options")
-    var usePrefillPrompt: Bool = false
-
-    @Flag(help: "Use decoder prefill data for faster initial decoding")
-    var usePrefillCache: Bool = false
-
-    @Flag(help: "Skip special tokens in the output")
-    var skipSpecialTokens: Bool = false
-
-    @Flag(help: "Force no timestamps when decoding")
-    var withoutTimestamps: Bool = false
-
-    @Flag(help: "Add timestamps for each word in the output")
-    var wordTimestamps: Bool = false
-
-    @Argument(help: "Supress given tokens in the output")
-    var supressTokens: [Int] = []
-
-    @Option(help: "Gzip compression ratio threshold for decoding failure")
-    var compressionRatioThreshold: Float?
-
-    @Option(help: "Average log probability threshold for decoding failure")
-    var logprobThreshold: Float?
-
-    @Option(help: "Probability threshold to consider a segment as silence")
-    var noSpeechThreshold: Float?
+    static let configuration = CommandConfiguration(
+        commandName: "transcribe",
+        abstract: "WhisperKit Transcribe CLI",
+        discussion: "Swift native speech recognition with Whisper for Apple Silicon"
+    )
 
-    @Flag(help: "Output a report of the results")
-    var report: Bool = false
+    @OptionGroup
+    var cliArguments: CLIArguments
 
-    @Option(help: "Directory to save the report")
-    var reportPath: String = "."
-
-    @Flag(help: "Process audio directly from the microphone")
-    var stream: Bool = false
+    mutating func run() async throws {
+        if cliArguments.stream {
+            try await transcribeStream(modelPath: cliArguments.modelPath)
+        } else {
+            let audioURL = URL(fileURLWithPath: cliArguments.audioPath)
+            if cliArguments.verbose {
+                print("Transcribing audio at \(audioURL)")
+            }
+            try await transcribe(audioPath: cliArguments.audioPath, modelPath: cliArguments.modelPath)
+        }
+    }
 
-    func transcribe(audioPath: String, modelPath: String) async throws {
+    private func transcribe(audioPath: String, modelPath: String) async throws {
         let resolvedModelPath = resolveAbsolutePath(modelPath)
         guard FileManager.default.fileExists(atPath: resolvedModelPath) else {
            fatalError("Model path does not exist \(resolvedModelPath)")
@@ -91,49 +42,52 @@ struct WhisperKitCLI: AsyncParsableCommand {
         }
 
         let computeOptions = ModelComputeOptions(
-            audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
-            textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
+            audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
+            textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
         )
 
         print("Initializing models...")
         let whisperKit = try await WhisperKit(
             modelFolder: modelPath,
             computeOptions: computeOptions,
-            verbose: verbose,
+            verbose: cliArguments.verbose,
             logLevel: .debug
         )
         print("Models initialized")
 
         let options = DecodingOptions(
-            verbose: verbose,
+            verbose: cliArguments.verbose,
             task: .transcribe,
-            language: language,
-            temperature: temperature,
-            temperatureIncrementOnFallback: temperatureIncrementOnFallback,
-            temperatureFallbackCount: temperatureFallbackCount,
-            topK: bestOf,
-            usePrefillPrompt: usePrefillPrompt,
-            usePrefillCache: usePrefillCache,
-            skipSpecialTokens: skipSpecialTokens,
-            withoutTimestamps: withoutTimestamps,
-            wordTimestamps: wordTimestamps,
-            supressTokens: supressTokens,
-            compressionRatioThreshold: compressionRatioThreshold,
-            logProbThreshold: logprobThreshold,
-            noSpeechThreshold: noSpeechThreshold
+            language: cliArguments.language,
+            temperature: cliArguments.temperature,
+            temperatureIncrementOnFallback: cliArguments.temperatureIncrementOnFallback,
+            temperatureFallbackCount: cliArguments.temperatureFallbackCount,
+            topK: cliArguments.bestOf,
+            usePrefillPrompt: cliArguments.usePrefillPrompt,
+            usePrefillCache: cliArguments.usePrefillCache,
+            skipSpecialTokens: cliArguments.skipSpecialTokens,
+            withoutTimestamps: cliArguments.withoutTimestamps,
+            wordTimestamps: cliArguments.wordTimestamps,
+            supressTokens: cliArguments.supressTokens,
+            compressionRatioThreshold: cliArguments.compressionRatioThreshold,
+            logProbThreshold: cliArguments.logprobThreshold,
+            noSpeechThreshold: cliArguments.noSpeechThreshold
         )
 
-        let transcribeResult = try await whisperKit.transcribe(audioPath: resolvedAudioPath, decodeOptions: options)
+        let transcribeResult = try await whisperKit.transcribe(
+            audioPath: resolvedAudioPath,
+            decodeOptions: options
+        )
 
         let transcription = transcribeResult?.text ?? "Transcription failed"
 
-        if report, let result = transcribeResult {
+        if cliArguments.report, let result = transcribeResult {
             let audioFileName = URL(fileURLWithPath: audioPath).lastPathComponent.components(separatedBy: ".").first!
 
             // Write SRT (SubRip Subtitle Format) for the transcription
-            let srtReportWriter = WriteSRT(outputDir: reportPath)
+            let srtReportWriter = WriteSRT(outputDir: cliArguments.reportPath)
             let savedSrtReport = srtReportWriter.write(result: result, to: audioFileName)
-            if verbose {
+            if cliArguments.verbose {
                 switch savedSrtReport {
                 case let .success(reportPath):
                     print("\n\nSaved SRT Report: \n\n\(reportPath)\n")
@@ -143,9 +97,9 @@ struct WhisperKitCLI: AsyncParsableCommand {
                 }
 
             // Write JSON for all metadata
-            let jsonReportWriter = WriteJSON(outputDir: reportPath)
+            let jsonReportWriter = WriteJSON(outputDir: cliArguments.reportPath)
             let savedJsonReport = jsonReportWriter.write(result: result, to: audioFileName)
-            if verbose {
+            if cliArguments.verbose {
                 switch savedJsonReport {
                 case let .success(reportPath):
                     print("\n\nSaved JSON Report: \n\n\(reportPath)\n")
@@ -155,47 +109,47 @@ struct WhisperKitCLI: AsyncParsableCommand {
                 }
             }
         }
 
-        if verbose {
+        if cliArguments.verbose {
             print("\n\nTranscription: \n\n\(transcription)\n")
         } else {
             print(transcription)
         }
     }
 
-    func transcribeStream(modelPath: String) async throws {
+    private func transcribeStream(modelPath: String) async throws {
         let computeOptions = ModelComputeOptions(
-            audioEncoderCompute: audioEncoderComputeUnits.asMLComputeUnits,
-            textDecoderCompute: textDecoderComputeUnits.asMLComputeUnits
+            audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
+            textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
        )
 
         print("Initializing models...")
         let whisperKit = try await WhisperKit(
             modelFolder: modelPath,
             computeOptions: computeOptions,
-            verbose: verbose,
+            verbose: cliArguments.verbose,
             logLevel: .debug
         )
         print("Models initialized")
 
         let decodingOptions = DecodingOptions(
-            verbose: verbose,
+            verbose: cliArguments.verbose,
             task: .transcribe,
-            language: language,
-            temperature: temperature,
-            temperatureIncrementOnFallback: temperatureIncrementOnFallback,
+            language: cliArguments.language,
+            temperature: cliArguments.temperature,
+            temperatureIncrementOnFallback: cliArguments.temperatureIncrementOnFallback,
             temperatureFallbackCount: 3, // limit fallbacks for realtime
             sampleLength: 224, // reduced sample length for realtime
-            topK: bestOf,
-            usePrefillPrompt: usePrefillPrompt,
-            usePrefillCache: usePrefillCache,
-            skipSpecialTokens: skipSpecialTokens,
-            withoutTimestamps: withoutTimestamps,
+            topK: cliArguments.bestOf,
+            usePrefillPrompt: cliArguments.usePrefillPrompt,
+            usePrefillCache: cliArguments.usePrefillCache,
+            skipSpecialTokens: cliArguments.skipSpecialTokens,
+            withoutTimestamps: cliArguments.withoutTimestamps,
             clipTimestamps: [],
             suppressBlank: false,
-            supressTokens: supressTokens,
-            compressionRatioThreshold: compressionRatioThreshold ?? 2.4,
-            logProbThreshold: logprobThreshold ?? -1.0,
-            noSpeechThreshold: noSpeechThreshold ?? 0.6
+            supressTokens: cliArguments.supressTokens,
+            compressionRatioThreshold: cliArguments.compressionRatioThreshold ?? 2.4,
+            logProbThreshold: cliArguments.logprobThreshold ?? -1.0,
+            noSpeechThreshold: cliArguments.noSpeechThreshold ?? 0.6
         )
 
         let audioStreamTranscriber = AudioStreamTranscriber(
@@ -222,29 +176,4 @@ struct WhisperKitCLI: AsyncParsableCommand {
         print("Transcribing audio stream, press Ctrl+C to stop.")
         try await audioStreamTranscriber.startStreamTranscription()
     }
-
-    mutating func run() async throws {
-        if stream {
-            try await transcribeStream(modelPath: modelPath)
-        } else {
-            let audioURL = URL(fileURLWithPath: audioPath)
-            if verbose {
-                print("Transcribing audio at \(audioURL)")
-            }
-            try await transcribe(audioPath: audioPath, modelPath: modelPath)
-        }
-    }
 }
-
-enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
-    case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine, random
-    var asMLComputeUnits: MLComputeUnits {
-        switch self {
-        case .all: return .all
-        case .cpuAndGPU: return .cpuAndGPU
-        case .cpuOnly: return .cpuOnly
-        case .cpuAndNeuralEngine: return .cpuAndNeuralEngine
-        case .random: return Bool.random() ? .cpuAndGPU : .cpuAndNeuralEngine
-        }
-    }
-}
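
With the CommandConfiguration above, the tool presents itself as a transcribe command. A hypothetical invocation of the new flag set (the executable wrapper and model path are assumptions, not taken from this diff; flag names follow ArgumentParser's kebab-case convention):

transcribe --model-path Models/whisperkit-coreml/openai_whisper-tiny --audio-path Tests/WhisperKitTests/Resources/jfk.wav --word-timestamps --report

transcribe --stream --verbose  # microphone streaming mode, no audio file needed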
