Skip to content

Commit

Permalink
rm boto3 dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
julien-c committed Apr 27, 2020
1 parent 4e817ff commit 97a3754
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/github-torch-hub.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install dependencies
run: |
pip install torch
pip install numpy tokenizers boto3 filelock requests tqdm regex sentencepiece sacremoses
pip install numpy tokenizers filelock requests tqdm regex sentencepiece sacremoses
- name: Torch hub list
run: |
Expand Down
2 changes: 1 addition & 1 deletion hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


dependencies = ["torch", "numpy", "tokenizers", "boto3", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]
dependencies = ["torch", "numpy", "tokenizers", "filelock", "requests", "tqdm", "regex", "sentencepiece", "sacremoses"]


@add_start_docstrings(AutoConfig.__doc__)
Expand Down
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,6 @@
"tokenizers == 0.7.0",
# dataclasses for Python versions that don't have it
"dataclasses;python_version<'3.7'",
# accessing files from S3 directly
"boto3",
# filesystem locks e.g. to prevent parallel downloads
"filelock",
# for downloading models over HTTPS
Expand Down
80 changes: 9 additions & 71 deletions src/transformers/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import boto3
import requests
from botocore.config import Config
from botocore.exceptions import ClientError
from filelock import FileLock
from tqdm.auto import tqdm

Expand Down Expand Up @@ -144,7 +141,7 @@ def docstring_decorator(fn):

def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https", "s3")
return parsed.scheme in ("http", "https")


def hf_bucket_url(identifier, postfix=None, cdn=False) -> str:
Expand Down Expand Up @@ -297,55 +294,6 @@ def cached_path(
return output_path


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    pieces = urlparse(url)
    if not pieces.netloc or not pieces.path:
        raise ValueError("bad s3 path {}".format(url))
    # urlparse keeps the leading '/' on the path component; drop exactly one.
    key = pieces.path[1:] if pieces.path.startswith("/") else pieces.path
    return pieces.netloc, key


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # Anything other than a 404 is unexpected; re-raise untouched.
            if int(exc.response["Error"]["Code"]) != 404:
                raise
            # A 404 from S3 means the object is missing — surface it as a
            # clearer, file-oriented error.
            raise EnvironmentError("file {} not found".format(url))

    return wrapper


@s3_request
def s3_etag(url, proxies=None):
    """Check ETag on S3 object."""
    bucket_name, s3_path = split_s3_path(url)
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    # e_tag is fetched lazily by boto3 when the attribute is accessed.
    return resource.Object(bucket_name, s3_path).e_tag


@s3_request
def s3_get(url, temp_file, proxies=None):
    """Pull a file directly from S3."""
    bucket_name, s3_path = split_s3_path(url)
    resource = boto3.resource("s3", config=Config(proxies=proxies))
    # Stream the object straight into the provided file handle.
    resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)


def http_get(url, temp_file, proxies=None, resume_size=0, user_agent=None):
ua = "transformers/{}; python/{}".format(__version__, sys.version.split()[0])
if is_torch_available():
Expand Down Expand Up @@ -406,17 +354,13 @@ def get_from_cache(

etag = None
if not local_files_only:
# Get eTag to add to filename, if it exists.
if url.startswith("s3://"):
etag = s3_etag(url, proxies=proxies)
else:
try:
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
if response.status_code == 200:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
# etag is already None
pass
try:
response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
if response.status_code == 200:
etag = response.headers.get("ETag")
except (EnvironmentError, requests.exceptions.Timeout):
# etag is already None
pass

filename = url_to_filename(url, etag)

Expand Down Expand Up @@ -483,13 +427,7 @@ def _resumable_file_manager():
with temp_file_manager() as temp_file:
logger.info("%s not found in cache or force_download set to True, downloading to %s", url, temp_file.name)

# GET file object
if url.startswith("s3://"):
if resume_download:
logger.warn('Warning: resumable downloads are not implemented for "s3://" urls')
s3_get(url, temp_file, proxies=proxies)
else:
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)
http_get(url, temp_file, proxies=proxies, resume_size=resume_size, user_agent=user_agent)

logger.info("storing %s in cache at %s", url, cache_path)
os.replace(temp_file.name, cache_path)
Expand Down

0 comments on commit 97a3754

Please sign in to comment.