Conversation
Left a few comments on the model definitions.
Will follow up on the core implementation once we synchronize the branches and align with the current multimodal implementation upstream.
self.scaling = self.head_dim**-0.5

self.k_proj = ColumnParallelLinear(self.d_model,
Can we use QKVColumnLinear for this?
Or no because of the bias?
Maybe we should extend QKVParallelLinear for this case?
Could also do this in a follow up PR
@robertgshaw2-neuralmagic we could use QKVParallelLinear in two scenarios:
- encoder attention
- decoder self-attention

We cannot use it for decoder cross-attention, as the queries come from a different source than the keys and values.
I decided not to introduce the new module, since our current implementation (which is very close to T5's) is still failing and we are not sure yet why. I'll start refactoring once we fix the current issues, as adding more moving parts may obscure the path to the solution.
@robertgshaw2-neuralmagic but nothing prevents us from adding QKVParallelLinear
support for the T5 model (once again, only for encoder attention and decoder self-attention)!
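To illustrate the point above: a fused QKV projection performs one matmul over the concatenated weights of a single input, so it only fits attention variants where queries, keys, and values share the same source. A minimal pure-Python sketch (toy matrices, not vLLM's actual API):

```python
# Toy dense matmul over lists of rows (illustrative only).
def matmul(x, w):
    return [[sum(xi * w[k][j] for k, xi in enumerate(row))
             for j in range(len(w[0]))] for row in x]

def fused_qkv(x, w_q, w_k, w_v):
    # One GEMM over the concatenated weights, then split the output.
    w_fused = [wq + wk + wv for wq, wk, wv in zip(w_q, w_k, w_v)]
    out = matmul(x, w_fused)
    d = len(w_q[0])
    q = [row[:d] for row in out]
    k = [row[d:2 * d] for row in out]
    v = [row[2 * d:] for row in out]
    return q, k, v

# Self-attention / encoder attention: q, k, v all derive from x -> fusable.
x = [[1.0, 2.0], [3.0, 4.0]]
w_q = [[1.0, 0.0], [0.0, 1.0]]
w_k = [[2.0, 0.0], [0.0, 2.0]]
w_v = [[0.0, 1.0], [1.0, 0.0]]
q, k, v = fused_qkv(x, w_q, w_k, w_v)
assert q == matmul(x, w_q) and k == matmul(x, w_k) and v == matmul(x, w_v)

# Decoder cross-attention: q comes from the decoder states, while k and v
# come from the ENCODER output -- two different inputs, so a single fused
# projection over one input cannot produce all three.
```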
stride=2,
padding=1)

self.embed_positions = nn.Embedding(self.max_source_positions,
Could this be VocabParallelEmbedding?
Yes, this part is definitely parallelizable; my omission, I had introduced VocabParallelEmbedding here before 🥴
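For context, a vocab-parallel embedding shards the embedding table by vocabulary rows across tensor-parallel ranks: each rank looks up only the ids it owns, out-of-range ids contribute zeros, and an all-reduce restores the full lookup. A toy sketch of that idea (illustrative pure Python, not vLLM's VocabParallelEmbedding):

```python
def shard_lookup(table, token_ids, rank, world_size):
    # Each rank owns a contiguous slice of vocabulary rows.
    vocab = len(table)
    per_rank = vocab // world_size
    lo, hi = rank * per_rank, (rank + 1) * per_rank
    dim = len(table[0])
    out = []
    for t in token_ids:
        if lo <= t < hi:
            out.append(table[t])     # this rank owns the row
        else:
            out.append([0.0] * dim)  # zeros; fixed up by the reduce
    return out

table = [[float(i), float(i) + 0.5] for i in range(8)]  # vocab=8, dim=2
ids = [0, 3, 7]
world_size = 2
partials = [shard_lookup(table, ids, r, world_size)
            for r in range(world_size)]
# "All-reduce": element-wise sum across ranks recovers the full lookup.
full = [[sum(p[i][j] for p in partials) for j in range(2)]
        for i in range(len(ids))]
assert full == [table[t] for t in ids]
```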
super().__init__()
self.d_model = config.d_model

self.embed_tokens = nn.Embedding(config.vocab_size, self.d_model)
Can these be VocabParallelEmbedding?
Perhaps, let me check
# TODO: For now we are not implementing the sampling method
return hidden_states

def load_weights(
Can you match the style of llama in this function? They have been very consistent with this logic across functions
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py#L363
Looking good, seems like we could adhere to their convention
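The convention being referenced maps per-projection checkpoint names (q_proj/k_proj/v_proj) onto a single fused parameter plus a shard id, via a stacked_params_mapping table. A hedged sketch of just that name-routing step (parameter names below are illustrative, not the actual Whisper ones):

```python
# (param_name_in_model, weight_name_in_checkpoint, shard_id)
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def route_weight(ckpt_name):
    """Return (model_param_name, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in ckpt_name:
            # Rewrite the checkpoint name onto the fused parameter; the
            # shard id tells the weight loader which slice to fill.
            return ckpt_name.replace(weight_name, param_name), shard_id
    return ckpt_name, None  # unfused parameter, loaded as-is

assert route_weight("decoder.layers.0.self_attn.k_proj.weight") == (
    "decoder.layers.0.self_attn.qkv_proj.weight", "k")
assert route_weight("decoder.layers.0.fc1.weight") == (
    "decoder.layers.0.fc1.weight", None)
```

In the real load_weights loop, the routed name is looked up in the model's params dict and the tensor handed to that parameter's weight_loader together with the shard id.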
if kv_caches[0][0] is None:
    hidden_states = None
else:
    hidden_states = self.decoder(input_ids=decoder_input_ids,
Is the decoder_input_ids here a mistake? When we go into the else branch, there is no decoder_input_ids variable defined.
return hidden_states


class WhisperDecoderBlock(nn.Module):
Why not just keep the same name, WhisperDecoderLayer, as in the transformers library?
self.prefix = prefix
self.multi_modal_data = multi_modal_data
self.multi_modal_data = multi_modal_data