Hello,
Thank you so much for this amazing work!
I am trying to calculate the FLOPs for Whisper base (`openai/whisper-base`) with `FlopCountAnalysis`, but I get an error saying that either `decoder_input_ids` or `decoder_inputs_embeds` is required. Could you please guide me on how to resolve this?
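```python
import torch
from fvcore.nn import FlopCountAnalysis
from transformers import WhisperModel, WhisperTokenizer

model_name = "openai/whisper-base"  # Replace with the specific model you are using
model = WhisperModel.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name)

batch_size = 1
sequence_length = 3000  # number of log-mel frames (30 s of audio)
feature_dim = 80  # number of log-mel bins for whisper-base
encoder_input_tensor = torch.randn(batch_size, feature_dim, sequence_length)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
encoder_input_tensor = encoder_input_tensor.to(device)

# Only the encoder input (input_features) is passed to the model here
flops = FlopCountAnalysis(model, encoder_input_tensor)
print(f"Total FLOPs: {flops.total()}")
```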
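I also tried creating `decoder_input_ids`:

```python
import torch
from fvcore.nn import FlopCountAnalysis
from transformers import WhisperModel, WhisperTokenizer

model_name = "openai/whisper-base"
model = WhisperModel.from_pretrained(model_name)
tokenizer = WhisperTokenizer.from_pretrained(model_name)

batch_size = 1
sequence_length = 3000  # Typical length for an audio sequence (30 s of log-mel frames)
feature_dim = 80  # Feature dimension (log-mel bins) for the Whisper model
encoder_input_tensor = torch.randn(batch_size, feature_dim, sequence_length)
decoder_input_ids = torch.tensor([[tokenizer.pad_token_id]], dtype=torch.long).repeat(batch_size, 1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
encoder_input_tensor = encoder_input_tensor.to(device)

# Note: decoder_input_ids is built above but never moved to `device` or passed
# to FlopCountAnalysis, so the model still receives only input_features
flops = FlopCountAnalysis(model, encoder_input_tensor)
print(f"Total FLOPs: {flops.total()}")
```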
```
ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds
```