Skip to content
This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[WiP] Whisper Implementation #147

Closed
wants to merge 5 commits into from
Closed

Conversation

dbogunowicz
Copy link

No description provided.

Copy link
Collaborator

@robertgshaw2-neuralmagic robertgshaw2-neuralmagic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left a few comments on the model definitions

Will follow up on the core implementation once we synchronize the branches and sync up with the current multimodal implementation upstream


self.scaling = self.head_dim**-0.5

self.k_proj = ColumnParallelLinear(self.d_model,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use QKVColumnLinear for this?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or no because of the bias?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should extend QKVParallelLinear for this case?

Could also do this in a follow up PR

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-neuralmagic we could use the QKVParallelLinear in two scenarios:

  • encoder attention
  • decoder self-attention

we cannot use it for decoder cross attention, as queries are coming from a different source then keys and values.

I decided not to introduce the new module, since our current implementation (that is very close to T5) is still failing and we are not sure yet why. I'll start refactoring once we fix the current issues, as adding more moving parts may obfuscate the path to the solution.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@robertgshaw2-neuralmagic but nothing prevents us from adding QKVParallelLinear support for the T5 model (once again, only for encoder attention and decoder self-attention)!

stride=2,
padding=1)

self.embed_positions = nn.Embedding(self.max_source_positions,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could this be VocabParallelEmbedding?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this part is definitely parallelizable, my omission, I did introduce the VocabParallelEmbedding here before 🥴

super().__init__()
self.d_model = config.d_model

self.embed_tokens = nn.Embedding(config.vocab_size, self.d_model)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can these be VocabParallelEmbedding?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps, let me check

# TODO: For now we are not implementing the sampling method
return hidden_states

def load_weights(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you match the style of llama in this function? They have been very consistent with this logic across functions

https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L363

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, seems like we could adhere to their convention

if kv_caches[0][0] is None:
hidden_states = None
else:
hidden_states = self.decoder(input_ids=decoder_input_ids,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the decoder_input_ids is something wrong? when we go into the else branch, there is no decoder_input_ids variable.

return hidden_states


class WhisperDecoderBlock(nn.Module):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just keep the same name with the WhisperDecoderLayer as in transformers library?

@dbogunowicz dbogunowicz changed the title Whisper Implementation [WiP] Whisper Implementation Apr 2, 2024
self.prefix = prefix
self.multi_modal_data = multi_modal_data
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.multi_modal_data = multi_modal_data

@andy-neuma
Copy link
Member

stale

@andy-neuma andy-neuma closed this Aug 12, 2024
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants