Flamingo Implementation #23063

Closed · wants to merge 4 commits
50 changes: 50 additions & 0 deletions src/transformers/models/flamingo/__init__.py
@@ -0,0 +1,50 @@
from typing import TYPE_CHECKING

from transformers.utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_torch_available,
)


_import_structure = {
    "configuration_flamingo": [
        "FlamingoConfig",
    ],
}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    pass
else:
    _import_structure["modeling_flamingo"] = [
        "FlamingoModel",
        "FlamingoPreTrainedModel",
        "FlamingoForConditionalGeneration",
    ]

if TYPE_CHECKING:
    from .configuration_flamingo import (
        FlamingoConfig,
    )

    # from .processing_flamingo import FlamingoProcessor

    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        pass
    else:
        from .modeling_flamingo import (
            FlamingoForConditionalGeneration,
            FlamingoModel,
            FlamingoPreTrainedModel,
        )

else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
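
With this lazy structure, `transformers.models.flamingo` resolves its symbols on first attribute access instead of importing torch-dependent modules eagerly. A minimal sketch of what that enables, assuming a transformers build that includes this PR:

```python
# Sketch (assumes a transformers installation that includes this PR).
# _LazyModule resolves these names on first access:
from transformers.models.flamingo import FlamingoConfig

config = FlamingoConfig()
print(config.model_type)  # "flamingo"

# Torch-only symbols are registered only when torch is available:
from transformers.models.flamingo import FlamingoForConditionalGeneration
```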
21 changes: 21 additions & 0 deletions src/transformers/models/flamingo/config.json
@@ -0,0 +1,21 @@
{
    "model_type": "flamingo",
    "cross_attn_every_n_layers": 4,
    "tie_word_embeddings": false,
    "use_media_placement_augmentation": true,
    "only_attend_previous": true,
    "text_config": {
        "_name_or_path": "luodian/llama-7b-hf",
        "model_type": "llama"
    },
    "vision_config": {
        "_name_or_path": "openai/clip-vit-large-patch14",
        "model_type": "clip_vision_model",
        "hidden_size": 1024,
        "intermediate_size": 4096,
        "num_attention_heads": 16,
        "num_hidden_layers": 24,
        "image_size": 224,
        "patch_size": 14
    }
}
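
This JSON mirrors the keyword arguments of `FlamingoConfig` below, so the stock `PretrainedConfig.from_json_file` classmethod can load it directly. A small sketch; the path is a placeholder for wherever config.json sits on disk:

```python
# Sketch: load the reference config shipped with this PR (path is a placeholder).
from transformers.models.flamingo.configuration_flamingo import FlamingoConfig

config = FlamingoConfig.from_json_file("src/transformers/models/flamingo/config.json")
print(config.cross_attn_every_n_layers)  # 4
print(config.text_config.model_type)     # "llama"
print(config.vision_config.hidden_size)  # 1024
```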
89 changes: 89 additions & 0 deletions src/transformers/models/flamingo/configuration_flamingo.py
@@ -0,0 +1,89 @@
import copy

from ...configuration_utils import PretrainedConfig
from ...utils import logging
from ...models.auto import CONFIG_MAPPING
from ..clip import CLIPVisionConfig

logger = logging.get_logger(__name__)


class FlamingoConfig(PretrainedConfig):
    r"""
    [`FlamingoConfig`] is the configuration class to store the configuration of a [`FlamingoForConditionalGeneration`].
    It is used to instantiate a Flamingo model according to the specified arguments, defining the vision model and
    language model configs. Instantiating a configuration with the defaults will yield a configuration similar to that
    of the Flamingo architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vision_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize the vision encoder config ([`CLIPVisionConfig`]).
        text_config (`dict`, *optional*):
            Dictionary of configuration options used to initialize the language model config. Its `model_type` key
            selects the config class from `CONFIG_MAPPING` (defaults to `"llama"`).
        cross_attn_every_n_layers (`int`, *optional*, defaults to 4):
            Insert a gated cross-attention layer after every `cross_attn_every_n_layers` decoder layers.
        use_media_placement_augmentation (`bool`, *optional*, defaults to `True`):
            Whether to use media placement augmentation during training.
        only_attend_previous (`bool`, *optional*, defaults to `True`):
            Whether text tokens attend only to the most recent preceding media.
        kwargs (*optional*):
            Dictionary of keyword arguments.

    Example:

    ```python
    >>> from transformers import FlamingoConfig, FlamingoForConditionalGeneration

    >>> # Initializing a FlamingoConfig with default vision and text sub-configurations
    >>> configuration = FlamingoConfig()

    >>> # Initializing a FlamingoForConditionalGeneration (with random weights) from that configuration
    >>> model = FlamingoForConditionalGeneration(configuration)
    ```"""
    model_type = "flamingo"
    is_composition = True

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        cross_attn_every_n_layers: int = 4,
        use_media_placement_augmentation: bool = True,
        only_attend_previous: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}
            logger.info("vision_config is None. Initializing the vision config with default values.")

        if text_config is None:
            text_config = {}
            logger.info("text_config is None. Initializing the text config with default values.")

        self.vision_config = CLIPVisionConfig(**vision_config)
        # Fall back to "llama" (the text backbone used in config.json) so that
        # FlamingoConfig() can be constructed without arguments.
        text_model_type = text_config.pop("model_type", "llama")
        self.text_config = CONFIG_MAPPING[text_model_type](**text_config)
        self.cross_attn_every_n_layers = cross_attn_every_n_layers
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.only_attend_previous = only_attend_previous

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary. Overrides the default [`~PretrainedConfig.to_dict`].

        Returns:
            `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance.
        """
        output = copy.deepcopy(self.__dict__)
        output["vision_config"] = self.vision_config.to_dict()
        output["text_config"] = self.text_config.to_dict()
        output["model_type"] = self.__class__.model_type
        output["cross_attn_every_n_layers"] = self.cross_attn_every_n_layers
        output["use_media_placement_augmentation"] = self.use_media_placement_augmentation
        output["only_attend_previous"] = self.only_attend_previous
        return output
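
Because `is_composition = True`, the config is typically built from two sub-config dicts, and `to_dict` serializes both back out. A hedged round-trip sketch:

```python
# Sketch: compose a FlamingoConfig from explicit sub-config dicts and round-trip it.
from transformers.models.flamingo.configuration_flamingo import FlamingoConfig

config = FlamingoConfig(
    vision_config={"hidden_size": 1024, "num_hidden_layers": 24, "patch_size": 14},
    text_config={"model_type": "llama"},
    cross_attn_every_n_layers=4,
)

d = config.to_dict()
assert d["model_type"] == "flamingo"
assert d["vision_config"]["patch_size"] == 14

restored = FlamingoConfig(
    vision_config=d["vision_config"],
    text_config=d["text_config"],
    cross_attn_every_n_layers=d["cross_attn_every_n_layers"],
)
assert restored.text_config.model_type == "llama"
```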
126 changes: 126 additions & 0 deletions src/transformers/models/flamingo/converting_flamingo_to_hf.py
@@ -0,0 +1,126 @@
import argparse
import os
import re
from typing import Dict

import torch
import torch.nn as nn

from ...models.clip import CLIPVisionModel
from ..auto import AutoModelForCausalLM, AutoTokenizer
from .configuration_flamingo import FlamingoConfig
from .modeling_flamingo import (
    FlamingoLMMixin,
    FlamingoPerceiverResampler,
    FlamingoPreTrainedModel,
    _infer_decoder_layers_attr_name,
    extend_instance,
)


class FlamingoModel(FlamingoPreTrainedModel):
    config_class = FlamingoConfig

    def __init__(
        self,
        config: FlamingoConfig,
    ):
        super().__init__(config)
        lang_encoder = AutoModelForCausalLM.from_pretrained(config.text_config._name_or_path)
        text_tokenizer = AutoTokenizer.from_pretrained(config.text_config._name_or_path)
        vision_encoder = CLIPVisionModel.from_pretrained(config.vision_config._name_or_path)

        # Flamingo marks media positions and chunk boundaries with dedicated special tokens
        text_tokenizer.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
        if text_tokenizer.pad_token is None:
            text_tokenizer.add_special_tokens({"pad_token": "<PAD>"})
        self.text_tokenizer = text_tokenizer
        self.eoc_token_id = text_tokenizer.encode("<|endofchunk|>")[-1]
        self.media_token_id = text_tokenizer.encode("<image>")[-1]

        # Graft the Flamingo cross-attention machinery onto the pretrained decoder
        extend_instance(lang_encoder, FlamingoLMMixin)
        decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_encoder)
        lang_encoder.set_decoder_layers_attr_name(decoder_layers_attr_name)
        lang_encoder.resize_token_embeddings(len(text_tokenizer))
        self.lang_encoder = lang_encoder

        self.cross_attn_every_n_layers = config.cross_attn_every_n_layers
        self.use_media_placement_augmentation = config.use_media_placement_augmentation

        vision_encoder.output_tokens = True
        self.vision_encoder = vision_encoder

        # The perceiver resampler compresses the vision tokens to a fixed number of latents
        self.vis_dim = 1024
        self.perceiver = FlamingoPerceiverResampler(dim=self.vis_dim)

        self.lang_encoder.init_flamingo(
            media_token_id=self.media_token_id,
            vis_hidden_size=self.vis_dim,
            cross_attn_every_n_layers=self.cross_attn_every_n_layers,
            use_media_placement_augmentation=self.use_media_placement_augmentation,
        )

    def get_input_embeddings(self) -> nn.Module:
        return self.lang_encoder.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        self.lang_encoder.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self) -> nn.Module:
        return self.lang_encoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.lang_encoder.set_output_embeddings(new_embeddings)
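
The special-token bookkeeping above is worth seeing in isolation: the media token id is simply the last id produced by encoding the newly added `<image>` token. A standalone sketch, using the gpt2 tokenizer purely for illustration (the PR resolves the real tokenizer from `config.text_config._name_or_path`):

```python
# Sketch of the special-token bookkeeping; gpt2 is a stand-in tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.add_special_tokens({"additional_special_tokens": ["<|endofchunk|>", "<image>"]})
if tok.pad_token is None:
    tok.add_special_tokens({"pad_token": "<PAD>"})

media_token_id = tok.encode("<image>")[-1]
eoc_token_id = tok.encode("<|endofchunk|>")[-1]
# The embedding matrix must then be resized to len(tok), exactly as
# lang_encoder.resize_token_embeddings(len(text_tokenizer)) does above.
```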


def rename_flamingo_checkpoint(old_ckpt: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rename keys in the public OpenFlamingo checkpoint to match the HF module layout."""
    perceiver_pattern1 = re.compile(r"perceiver\.layers\.[0-9]\.0")
    perceiver_pattern2 = re.compile(r"perceiver\.layers\.[0-9]\.1")
    new_ckpt = old_ckpt.copy()
    for key, value in old_ckpt.items():
        if re.match(perceiver_pattern1, key):
            # "perceiver.layers.<n>.0.*" -> "perceiver.layers.<n>.*" (attention sub-module)
            new_key = re.sub(r"([0-9])\.0", r"\1", key)
            new_ckpt.pop(key)
            new_ckpt[new_key] = value
        elif re.match(perceiver_pattern2, key):
            # "perceiver.layers.<n>.1.*" -> "perceiver.layers.<n>.feed_forward.*"
            new_key = re.sub(r"([0-9])\.1", r"\1.feed_forward", key)
            new_ckpt.pop(key)
            new_ckpt[new_key] = value
        elif key.startswith("lang_encoder.gated_cross_attn_layers."):
            # Dropped here; the model is loaded with strict=False below
            new_ckpt.pop(key)
        elif key.startswith("lang_encoder.") and "ff_gate" not in key:
            new_key = key.replace("ff", "feed_forward")
            new_ckpt.pop(key)
            new_ckpt[new_key] = value

    return new_ckpt
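
The renaming logic is easiest to sanity-check on synthetic keys. A minimal sketch; the key names below are made up to exercise each branch:

```python
# Sketch: exercise each rename branch with made-up keys (assumes this PR is installed).
import torch

from transformers.models.flamingo.converting_flamingo_to_hf import rename_flamingo_checkpoint

fake_ckpt = {
    "perceiver.layers.0.0.norm_media.weight": torch.zeros(1),           # attention sub-module
    "perceiver.layers.0.1.0.weight": torch.zeros(1),                    # feed-forward sub-module
    "lang_encoder.gated_cross_attn_layers.0.attn_gate": torch.zeros(1),
    "lang_encoder.model.layers.0.ff.weight": torch.zeros(1),
}
renamed = rename_flamingo_checkpoint(fake_ckpt)
print(sorted(renamed))
# ['lang_encoder.model.layers.0.feed_forward.weight',
#  'perceiver.layers.0.feed_forward.0.weight',
#  'perceiver.layers.0.norm_media.weight']
# (the gated_cross_attn key is dropped)
```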


@torch.no_grad()
def dump_hf_model(old_ckpt_path: str, new_folder_path: str) -> None:
    os.makedirs(new_folder_path, exist_ok=True)
    old_ckpt = torch.load(old_ckpt_path, map_location="cpu")
    # Resolve config.json next to this file instead of relying on the working directory
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "config.json")
    config = FlamingoConfig.from_json_file(config_path)
    model = FlamingoModel(config)
    new_ckpt = rename_flamingo_checkpoint(old_ckpt)
    model.load_state_dict(new_ckpt, strict=False)
    model.save_pretrained(new_folder_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--old_ckpt_path",
        "-old",
        type=str,
        required=True,
        help="Path to the OpenFlamingo checkpoint",
    )
    parser.add_argument(
        "--new_hf_path",
        "-new",
        type=str,
        required=True,
        help="Path to the HF folder",
    )
    args = parser.parse_args()
    dump_hf_model(args.old_ckpt_path, args.new_hf_path)
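
Because the script lives inside the package and uses relative imports, it should be run with `-m` rather than as a bare file; programmatically, the entry point is just `dump_hf_model`. A hedged usage sketch with placeholder paths:

```python
# Sketch: programmatic conversion; paths are placeholders.
# Equivalent CLI (the -m form is required by the relative imports):
#   python -m transformers.models.flamingo.converting_flamingo_to_hf \
#       --old_ckpt_path checkpoint.pt --new_hf_path ./flamingo-hf
from transformers.models.flamingo.converting_flamingo_to_hf import dump_hf_model

dump_hf_model("checkpoint.pt", "./flamingo-hf")
```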