Upgrade the version of transformers #7343
Merged: 3 commits merged on Dec 29, 2023
49 changes: 13 additions & 36 deletions monai/networks/nets/transchex.py
@@ -12,20 +12,17 @@
 from __future__ import annotations

 import math
-import os
-import shutil
-import tarfile
-import tempfile
 from collections.abc import Sequence

 import torch
 from torch import nn

+from monai.config.type_definitions import PathLike
 from monai.utils import optional_import

 transformers = optional_import("transformers")
 load_tf_weights_in_bert = optional_import("transformers", name="load_tf_weights_in_bert")[0]
-cached_path = optional_import("transformers.file_utils", name="cached_path")[0]
+cached_file = optional_import("transformers.utils", name="cached_file")[0]
 BertEmbeddings = optional_import("transformers.models.bert.modeling_bert", name="BertEmbeddings")[0]
 BertLayer = optional_import("transformers.models.bert.modeling_bert", name="BertLayer")[0]

@@ -63,44 +60,16 @@ def from_pretrained(
         state_dict=None,
         cache_dir=None,
         from_tf=False,
+        path_or_repo_id="bert-base-uncased",
+        filename="pytorch_model.bin",
         *inputs,
         **kwargs,
     ):
-        archive_file = "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz"
-        resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
-        tempdir = None
-        if os.path.isdir(resolved_archive_file) or from_tf:
-            serialization_dir = resolved_archive_file
-        else:
-            tempdir = tempfile.mkdtemp()
-            with tarfile.open(resolved_archive_file, "r:gz") as archive:
-
-                def is_within_directory(directory, target):
-                    abs_directory = os.path.abspath(directory)
-                    abs_target = os.path.abspath(target)
-
-                    prefix = os.path.commonprefix([abs_directory, abs_target])
-
-                    return prefix == abs_directory
-
-                def safe_extract(tar, path=".", members=None, *, numeric_owner=False):
-                    for member in tar.getmembers():
-                        member_path = os.path.join(path, member.name)
-                        if not is_within_directory(path, member_path):
-                            raise Exception("Attempted Path Traversal in Tar File")
-
-                    tar.extractall(path, members, numeric_owner=numeric_owner)
-
-                safe_extract(archive, tempdir)
-            serialization_dir = tempdir
+        weights_path = cached_file(path_or_repo_id, filename, cache_dir=cache_dir)
         model = cls(num_language_layers, num_vision_layers, num_mixed_layers, bert_config, *inputs, **kwargs)
         if state_dict is None and not from_tf:
-            weights_path = os.path.join(serialization_dir, "pytorch_model.bin")
             state_dict = torch.load(weights_path, map_location="cpu" if not torch.cuda.is_available() else None)
-        if tempdir:
-            shutil.rmtree(tempdir)
         if from_tf:
-            weights_path = os.path.join(serialization_dir, "model.ckpt")
             return load_tf_weights_in_bert(model, weights_path)
         old_keys = []
         new_keys = []
@@ -304,6 +273,8 @@ def __init__(
         chunk_size_feed_forward: int = 0,
         is_decoder: bool = False,
         add_cross_attention: bool = False,
+        path_or_repo_id: str | PathLike = "bert-base-uncased",
+        filename: str = "pytorch_model.bin",
     ) -> None:
         """
         Args:
@@ -315,6 +286,10 @@
             num_vision_layers: number of vision transformer layers.
             num_mixed_layers: number of mixed transformer layers.
             drop_out: fraction of the input units to drop.
+            path_or_repo_id: This can be either:
+                - a string, the *model id* of a model repo on huggingface.co.
+                - a path to a *directory* potentially containing the file.
+            filename: The name of the file to locate in `path_or_repo`.

             The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.

@@ -369,6 +344,8 @@ def __init__(
             num_vision_layers=num_vision_layers,
             num_mixed_layers=num_mixed_layers,
             bert_config=bert_config,
+            path_or_repo_id=path_or_repo_id,
+            filename=filename,
         )

         self.patch_size = patch_size
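The heart of this file's change is swapping the old S3 tarball download (`cached_path` plus manual `tarfile` extraction) for `cached_file` from `transformers.utils`, which resolves a single weights file through the Hugging Face cache. A minimal sketch (not part of the diff) of that loading path, assuming transformers>=4.36 is installed and the default `bert-base-uncased` repo id and weights filename are used:

# Sketch only: mirrors the new from_pretrained resolution step under the
# assumption that transformers>=4.36 exposes cached_file in transformers.utils.
import torch
from transformers.utils import cached_file

# Download (or reuse from the local HF cache) the weights file, then load it on CPU.
weights_path = cached_file("bert-base-uncased", "pytorch_model.bin", cache_dir=None)
state_dict = torch.load(weights_path, map_location="cpu")
print(weights_path, len(state_dict))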
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -33,7 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin"
 pandas
 requests
 einops
-transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157
+transformers>=4.36.0
 mlflow>=1.28.0
 clearml>=1.10.0rc0
 matplotlib!=3.5.0
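The dependency pin flips from a `<4.22` cap to a `>=4.36.0` floor. A quick sketch for checking that a local environment satisfies the new floor, assuming the `packaging` library is available:

# Sketch: verify the installed transformers meets the new minimum from requirements-dev.txt.
import transformers
from packaging.version import Version

assert Version(transformers.__version__) >= Version("4.36.0"), transformers.__version__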
3 changes: 1 addition & 2 deletions tests/test_transchex.py
@@ -18,7 +18,7 @@

 from monai.networks import eval_mode
 from monai.networks.nets.transchex import Transchex
-from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick
+from tests.utils import skip_if_quick

 TEST_CASE_TRANSCHEX = []
 for drop_out in [0.4]:
@@ -46,7 +46,6 @@


 @skip_if_quick
-@SkipIfAtLeastPyTorchVersion((1, 10))
 class TestTranschex(unittest.TestCase):
     @parameterized.expand(TEST_CASE_TRANSCHEX)
     def test_shape(self, input_param, expected_shape):
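For reference, a hypothetical construction of `Transchex` with the two keyword arguments this PR introduces; the other argument names and values below are illustrative assumptions patterned on the test cases, not taken from the diff:

# Illustrative sketch only: argument values are assumptions, not part of the PR.
from monai.networks.nets.transchex import Transchex

# Constructing the network triggers MultiModal.from_pretrained, which now
# fetches the weights file via cached_file from the given repo id or local directory.
net = Transchex(
    in_channels=3,
    img_size=(224, 224),
    patch_size=(32, 32),
    num_classes=2,
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
    drop_out=0.4,
    path_or_repo_id="bert-base-uncased",  # new argument: Hub model id or local directory
    filename="pytorch_model.bin",  # new argument: weights file to locate
)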