Skip to content

Commit

Permalink
Add MLLama (huggingface#33703)
Browse files Browse the repository at this point in the history
* current changes

* nit

* Add cross_attenttion_mask to processor

* multi-image fixed

* Add cross_attenttion_mask to processor

* cross attn works in all cases

* WIP refactoring function for image processor

* WIP refactoring image processor functions

* Refactor preprocess to use global loops instead of list nested list comps

* Docstrings

* Add channels unification

* fix dtype issues

* Update docsrings and format

* Consistent max_image_tiles

* current script

* updates

* Add convert to rgb

* Add image processor tests

* updates!

* update

* god damn it I am dumb sometimes

* Precompute aspect ratios

* now this works, full match

* fix 😉

* nits

* style

* fix model and conversion

* nit

* nit

* kinda works

* hack for sdpa non-contiguous bias

* nits here and there

* latest c hanges

* merge?

* run forward

* Add aspect_ratio_mask

* vision attention mask

* update script and config variable names

* nit

* nits

* be able to load

* style

* nits

* there

* nits

* make forward run

* small update

* enable generation multi-turn

* nit

* nit

* Clean up a bit for errors and typos

* A bit more constant fixes

* 90B keys and shapes match

* Fix for 11B model

* Fixup, remove debug part

* Docs

* Make max_aspect_ratio_id to be minimal

* Update image processing code to match new implementation

* Adjust conversion for final checkpoint state

* Change dim in repeat_interleave (accordig to meta code)

* tmp fix for num_tiles

* Fix for conversion (gate<->up, q/k_proj rope permute)

* nits

* codestyle

* Vision encoder fixes

* pass cross attn mask further

* Refactor aspect ratio mask

* Disable text-only generation

* Fix cross attention layers order, remove q/k norm rotation for cross atention layers

* Refactor gated position embeddings

* fix bugs but needs test with new weights

* rope scaling should be llama3

* Fix rope scaling name

* Remove debug for linear layer

* fix copies

* Make mask prepare private func

* Remove linear patch embed

* Make precomputed embeddings as nn.Embedding module

* MllamaPrecomputedAspectRatioEmbedding with config init

* Remove unused self.output_dim

* nit, intermediate layers

* Rename ln and pos_embed

* vision_chunk_size -> image_size

* return_intermediate -> intermediate_layers_indices

* vision_input_dim -> hidden_size

* Fix copied from statements

* fix most tests

* Fix more copied from

* layer_id->layer_idx

* Comment

* Fix tests for processor

* Copied from for _prepare_4d_causal_attention_mask_with_cache_position

* Style fix

* Add MllamaForCausalLM

* WIP fixing tests

* Remove duplicated layers

* Remove dummy file

* Fix style

* Fix consistency

* Fix some TODOs

* fix language_model instantiation, add docstring

* Move docstring, remove todos for precomputed embeds (we cannot init them properly)

* Add initial docstrings

* Fix

* fix some tests

* lets skip these

* nits, remove print, style

* Add one more copied from

* Improve test message

* Make validate func private

* Fix dummy objects

* Refactor `data_format` a bit + add comment

* typos/nits

Co-authored-by: Pablo Montalvo <[email protected]>

* fix dummy objects and imports

* Add chat template config json

* remove num_kv_heads from vision attention

* fix

* move some commits and add more tests

* fix test

* Remove `update_key_name` from modeling utils

* remove num-kv-heads again

* some prelimiary docs

* Update chat template + tests

* nit, conversion script max_num_tiles from params

* Fix warning for text-only generation

* Update conversion script for instruct models

* Update chat template in converstion + test

* add tests for CausalLM model

* model_max_length, avoid null chat_template

* Refactor conversion script

* Fix forward

* Fix integration tests

* Refactor vision config + docs

* Fix default

* Refactor text config

* Doc fixes

* Remove unused args, fix docs example

* Squashed commit of the following:

commit b51ce5a2efffbecdefbf6fc92ee87372ec9d8830
Author: qubvel <[email protected]>
Date:   Wed Sep 18 13:39:15 2024 +0000

    Move model + add output hidden states and output attentions

* Fix num_channels

* Add mllama text and mllama vision models

* Fixing repo consistency

* Style fix

* Fixing repo consistency

* Fixing unused config params

* Fix failed tests after refactoring

* hidden_activation -> hidden_act  for text mlp

* Remove from_pretrained from sub-configs

* Apply suggestions from code review

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/mllama/convert_mllama_weights_to_hf.py

Co-authored-by: Arthur <[email protected]>

* Reuse lambda in conversion script

* Remove run.py

* Update docs/source/en/model_doc/mllama.md

Co-authored-by: Arthur <[email protected]>

* Update src/transformers/models/mllama/processing_mllama.py

Co-authored-by: Arthur <[email protected]>

* Remove unused LlamaTokenizerFast

* Fix logging

* Refactor gating

* Remove cycle for collecting intermediate states

* Refactor text-only check, add integration test for text-only

* Revert from pretrained to configs

* Fix example

* Add auto `bos_token` adding in processor

* Fix tips

* Update src/transformers/models/auto/tokenization_auto.py

Co-authored-by: Arthur <[email protected]>

* Enable supports_gradient_checkpointing model flag

* add eager/sdpa options

* don't skip attn tests and bring back GC skips (did i really remove those?)

* Fix signature, but get error with None gradient

* Fix output attention tests

* Disable GC back

* Change no split modules

* Fix dropout

* Style

* Add Mllama to sdpa list

* Add post init for vision model

* Refine config for MllamaForCausalLMModelTest and skipped tests for CausalLM model

* if skipped, say it, don't pass

* Clean vision tester config

* Doc for args

* Update tests/models/mllama/test_modeling_mllama.py

Co-authored-by: Arthur <[email protected]>

* Add cross_attention_mask to test

* typehint

* Remove todo

* Enable gradient checkpointing

* Docstring

* Style

* Fixing and skipping some tests for new cache

* Mark flaky test

* Skip `test_sdpa_can_compile_dynamic` test

* Fixing some offload tests

* Add direct GenerationMixin inheritance

* Remove unused code

* Add initializer_range to vision config

* update the test to make sure we show if split

* fix gc?

* Fix repo consistency

* Undo modeling utils debug changes

* Fix link

* mllama -> Mllama

* [mllama] -> [Mllama]

* Enable compile test for CausalLM model (text-only)

* Fix TextModel prefix

* Update doc

* Docs for forward, type hints, and vision model prefix

* make sure to reset

* fix init

* small script refactor and styling

* nit

* updates!

* some nits

* Interpolate embeddings for 560 size and update integration tests

* nit

* does not suppor static cache!

* update

* fix

* nit2

* this?

* Fix conversion

* Style

* 4x memory improvement with image cache AFAIK

* Token decorator for tests

* Skip failing tests

* update processor errors

* fix split issues

* style

* weird

* style

* fix failing tests

* update

* nit fixing the whisper tests

* fix path

* update

---------

Co-authored-by: raushan <[email protected]>
Co-authored-by: pavel <[email protected]>
Co-authored-by: qubvel <[email protected]>
Co-authored-by: Pablo Montalvo <[email protected]>
Co-authored-by: ydshieh <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
  • Loading branch information
7 people authored and dataKim1201 committed Oct 7, 2024
1 parent 9abf440 commit e66c200
Show file tree
Hide file tree
Showing 31 changed files with 6,183 additions and 98 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,8 @@
title: MatCha
- local: model_doc/mgp-str
title: MGP-STR
- local: model_doc/mllama
title: mllama
- local: model_doc/nougat
title: Nougat
- local: model_doc/omdet-turbo
Expand Down
1 change: 1 addition & 0 deletions docs/source/en/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ Flax), PyTorch, and/or TensorFlow.
| [Mimi](model_doc/mimi) ||||
| [Mistral](model_doc/mistral) ||||
| [Mixtral](model_doc/mixtral) ||||
| [Mllama](model_doc/mllama) ||||
| [mLUKE](model_doc/mluke) ||||
| [MMS](model_doc/mms) ||||
| [MobileBERT](model_doc/mobilebert) ||||
Expand Down
124 changes: 124 additions & 0 deletions docs/source/en/model_doc/mllama.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Mllama

## Overview

The Llama 3.2-Vision collection of multimodal large language models (LLMs) is a collection of pretrained and instruction-tuned image reasoning generative models in 11B and 90B sizes (text \+ images in / text out). The Llama 3.2-Vision instruction-tuned models are optimized for visual recognition, image reasoning, captioning, and answering general questions about an image.

**Model Architecture:** Llama 3.2-Vision is built on top of Llama 3.1 text-only model, which is an auto-regressive language model that uses an optimized transformer architecture. The tuned versions use supervised fine-tuning (SFT) and reinforcement learning with human feedback (RLHF) to align with human preferences for helpfulness and safety. To support image recognition tasks, the Llama 3.2-Vision model uses a separately trained vision adapter that integrates with the pre-trained Llama 3.1 language model. The adapter consists of a series of cross-attention layers that feed image encoder representations into the core LLM.

## Usage Tips

- For image+text and text inputs use `MllamaForConditionalGeneration`.
- For text-only inputs use `MllamaForCausalLM` for generation to avoid loading vision tower.
- Each sample can contain multiple images, and the number of images can vary between samples. The processor will pad the inputs to the maximum number of images across samples and to a maximum number of tiles within each image.
- The text passed to the processor should have the `"<|image|>"` tokens where the images should be inserted.
- The processor has its own `apply_chat_template` method to convert chat messages to text that can then be passed as text to the processor.

## Usage Example

#### Instruct model
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

messages = [
[
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": "What does the image show?"}
]
}
],
]
text = processor.apply_chat_template(messages, add_generation_prompt=True)

url = "https://llava-vl.github.io/static/images/view.jpg"
image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=text, images=image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=25)
print(processor.decode(output[0]))
```

#### Base model
```python
import requests
import torch
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor

model_id = "meta-llama/Llama-3.2-11B-Vision"
model = MllamaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

prompt = "<|image|>If I had to write a haiku for this one"
url = "https://llava-vl.github.io/static/images/view.jpg"
raw_image = Image.open(requests.get(url, stream=True).raw)

inputs = processor(text=prompt, images=raw_image, return_tensors="pt").to(model.device)
output = model.generate(**inputs, do_sample=False, max_new_tokens=25)
print(processor.decode(output[0], skip_special_tokens=True))
```


## MllamaConfig

[[autodoc]] MllamaConfig

## MllamaProcessor

[[autodoc]] MllamaProcessor


## MllamaImageProcessor

[[autodoc]] MllamaImageProcessor

## MllamaForConditionalGeneration

[[autodoc]] MllamaForConditionalGeneration
- forward

## MllamaForCausalLM

[[autodoc]] MllamaForCausalLM
- forward

## MllamaTextModel

[[autodoc]] MllamaTextModel
- forward

## MllamaForCausalLM

[[autodoc]] MllamaForCausalLM
- forward

## MllamaVisionModel

[[autodoc]] MllamaVisionModel
- forward
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ For now, Transformers supports SDPA inference and training for the following arc
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100#transformers.M2M100Model)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mllama](https://huggingface.co/docs/transformers/model_doc/mllama#transformers.MllamaForConditionalGeneration)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
Expand Down
28 changes: 28 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,10 @@
"models.mimi": ["MimiConfig"],
"models.mistral": ["MistralConfig"],
"models.mixtral": ["MixtralConfig"],
"models.mllama": [
"MllamaConfig",
"MllamaProcessor",
],
"models.mluke": [],
"models.mobilebert": [
"MobileBertConfig",
Expand Down Expand Up @@ -1199,6 +1203,7 @@
)
_import_structure["models.mask2former"].append("Mask2FormerImageProcessor")
_import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"])
_import_structure["models.mllama"].extend(["MllamaImageProcessor"])
_import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"])
_import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"])
_import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"])
Expand Down Expand Up @@ -2704,6 +2709,16 @@
"MixtralPreTrainedModel",
]
)
_import_structure["models.mllama"].extend(
[
"MllamaForCausalLM",
"MllamaForConditionalGeneration",
"MllamaPreTrainedModel",
"MllamaProcessor",
"MllamaTextModel",
"MllamaVisionModel",
]
)
_import_structure["models.mobilebert"].extend(
[
"MobileBertForMaskedLM",
Expand Down Expand Up @@ -5377,6 +5392,10 @@
)
from .models.mistral import MistralConfig
from .models.mixtral import MixtralConfig
from .models.mllama import (
MllamaConfig,
MllamaProcessor,
)
from .models.mobilebert import (
MobileBertConfig,
MobileBertTokenizer,
Expand Down Expand Up @@ -6037,6 +6056,7 @@
MaskFormerFeatureExtractor,
MaskFormerImageProcessor,
)
from .models.mllama import MllamaImageProcessor
from .models.mobilenet_v1 import (
MobileNetV1FeatureExtractor,
MobileNetV1ImageProcessor,
Expand Down Expand Up @@ -7270,6 +7290,14 @@
MixtralModel,
MixtralPreTrainedModel,
)
from .models.mllama import (
MllamaForCausalLM,
MllamaForConditionalGeneration,
MllamaPreTrainedModel,
MllamaProcessor,
MllamaTextModel,
MllamaVisionModel,
)
from .models.mobilebert import (
MobileBertForMaskedLM,
MobileBertForMultipleChoice,
Expand Down
81 changes: 53 additions & 28 deletions src/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,12 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
if self.key_cache[layer_idx] != []:
device = self.key_cache[layer_idx].device
self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
if self.value_cache[layer_idx] != []:
device = self.value_cache[layer_idx].device
self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

@property
def seen_tokens(self):
Expand Down Expand Up @@ -358,10 +360,14 @@ class DynamicCache(Cache):
```
"""

def __init__(self) -> None:
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
if num_hidden_layers is None:
self.key_cache: List[torch.Tensor] = []
self.value_cache: List[torch.Tensor] = []
else:
self.key_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self.value_cache: List[torch.Tensor] = [[] for _ in range(num_hidden_layers)]
self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen

def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
Expand Down Expand Up @@ -420,6 +426,11 @@ def update(
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
# content on layer cache can be a tensor and checking not tensor causes errors
# so we explicitly check for the empty list
elif self.key_cache[layer_idx] == []:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
Expand All @@ -429,7 +440,7 @@ def update(
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
# TODO: deprecate this function in favor of `cache_position`
if len(self.key_cache) <= layer_idx:
if len(self.key_cache) <= layer_idx or (len(self.key_cache) > layer_idx and self.key_cache[layer_idx] == []):
return 0
return self.key_cache[layer_idx].shape[-2]

Expand All @@ -446,10 +457,12 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
return legacy_cache

@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls()
cache = cls(num_hidden_layers)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
Expand All @@ -468,30 +481,34 @@ def crop(self, max_length: int):

self._seen_tokens = max_length
for idx in range(len(self.key_cache)):
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
if self.key_cache[idx] != []:
self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicCache()
current_split = DynamicCache(num_hidden_layers)
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int) -> "DynamicCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls()
cache = cls(num_hidden_layers)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
cache.update(layer_keys, layer_values, idx)
key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
if key_cache != []:
layer_keys = torch.cat(key_cache, dim=0)
layer_values = torch.cat(value_cache, dim=0)
cache.update(layer_keys, layer_values, idx)
return cache

def batch_repeat_interleave(self, repeats: int):
Expand Down Expand Up @@ -1391,10 +1408,13 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:

@classmethod
def from_legacy_cache(
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
) -> "EncoderDecoderCache":
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
cache = cls(
self_attention_cache=DynamicCache(num_hidden_layers),
cross_attention_cache=DynamicCache(num_hidden_layers),
)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx][:2]
Expand All @@ -1407,7 +1427,10 @@ def from_legacy_cache(

def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
if len(self.self_attention_cache.key_cache) <= layer_idx:
# check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
if self.self_attention_cache.key_cache == []:
return 0
if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
return 0
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

Expand Down Expand Up @@ -1448,24 +1471,26 @@ def crop(self, maximum_length: int):
self.check_dynamic_cache(self.crop.__name__)
self.self_attention_cache.crop(maximum_length)

def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
def batch_split(
self, full_batch_size: int, split_size: int, num_hidden_layers: int
) -> "List[EncoderDecoderCache]":
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
self.check_dynamic_cache(self.batch_split.__name__)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size, num_hidden_layers)

out = []
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
out.append(EncoderDecoderCache(self_attn, cross_attn))
return out

@classmethod
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
def from_batch_splits(cls, splits: List["EncoderDecoderCache"], num_hidden_layers: int) -> "EncoderDecoderCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
self_attention_cache = DynamicCache()
cross_attention_cache = DynamicCache()
self_attention_cache = DynamicCache(num_hidden_layers)
cross_attention_cache = DynamicCache(num_hidden_layers)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
Expand Down
Loading

0 comments on commit e66c200

Please sign in to comment.