Medium-scale refactoring #408

Merged: 35 commits, Mar 29, 2020
Changes from 1 commit

Commits (35)
fcd49f6  Medium-scale refactoring (mpenkov, Jan 11, 2020)
656d2b4  more refactoring (mpenkov, Jan 11, 2020)
af35d24  automate docstrings (mpenkov, Jan 12, 2020)
57c459f  link to extending.md from README.rst (mpenkov, Jan 12, 2020)
3961dbb  fixup (mpenkov, Jan 12, 2020)
463b060  improve my_urlsplit function name (mpenkov, Jan 12, 2020)
4f287df  improve docstring (mpenkov, Jan 12, 2020)
3dcb71a  remove unused variable (mpenkov, Jan 12, 2020)
4ee4490  fixup (mpenkov, Jan 12, 2020)
2c8a4f2  disable docstring tweaking on Py2 (mpenkov, Jan 12, 2020)
f489689  more Py27 goodness (mpenkov, Jan 12, 2020)
b22e3b0  add section to extending.md (mpenkov, Jan 12, 2020)
b09e03f  Merge remote-tracking branch 'upstream/master' into uri (mpenkov, Jan 30, 2020)
1cc60ea  improving transport submodule registration (mpenkov, Jan 30, 2020)
4d3b1a7  integrating gcs into new design (mpenkov, Jan 30, 2020)
64f43f0  disable moto server by default (mpenkov, Jan 30, 2020)
6110269  import submodules via importlib for flexibility (mpenkov, Mar 27, 2020)
9070547  Merge remote-tracking branch 'upstream/master' into uri (mpenkov, Mar 27, 2020)
abf4fef  move tweak function to doctools (mpenkov, Mar 27, 2020)
110a557  split out separate transport.py submodule (mpenkov, Mar 27, 2020)
7d67db8  Merge remote-tracking branch 'upstream/master' into uri (mpenkov, Mar 27, 2020)
12605ab  fixup (mpenkov, Mar 27, 2020)
64b2fdd  get rid of Py2 (mpenkov, Mar 27, 2020)
98ded35  get rid of Py2, for real this time (mpenkov, Mar 27, 2020)
903bfd0  get rid of unused imports (mpenkov, Mar 27, 2020)
f5dc67f  still more Py2 removal (mpenkov, Mar 27, 2020)
b309d58  remove unused imports (mpenkov, Mar 27, 2020)
a936bea  warn on missing docstrings (mpenkov, Mar 27, 2020)
0720cfc  docstring before and after newline (mpenkov, Mar 27, 2020)
caf5a42  add doc links to submodules (mpenkov, Mar 27, 2020)
df7aee7  remove useless comment in setup.py (mpenkov, Mar 27, 2020)
c1be8de  improve examples (mpenkov, Mar 27, 2020)
0f6d5e4  split out utils and constants submodules (mpenkov, Mar 28, 2020)
6d7a73a  split out concurrency submodule (mpenkov, Mar 28, 2020)
f7a4df0  update extending.md (mpenkov, Mar 29, 2020)
109 changes: 109 additions & 0 deletions smart_open/compression.py
@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) 2020 Radim Rehurek <[email protected]>
#
# This code is distributed under the terms and conditions
# from the MIT License (MIT).
#
import io
import logging
import os.path
import warnings

import six

logger = logging.getLogger(__name__)

#
# Streaming gzip is unavailable in some cases; see the issue below (URL
# reconstructed from the issue number in this constant's name).
#
_ISSUE_189_URL = 'https://github.com/RaRe-Technologies/smart_open/issues/189'

_COMPRESSOR_REGISTRY = {}


def get_supported_extensions():
return sorted(_COMPRESSOR_REGISTRY.keys())


def register_compressor(ext, callback):
"""Register a callback for transparently decompressing files with a specific extension.

Parameters
----------
ext: str
The extension.
callback: callable
The callback. It must accept two positional arguments, file_obj and mode.

Examples
--------

Instruct smart_open to use the identity function whenever opening a file
with a .xz extension (see README.rst for the complete example showing I/O):

>>> def _handle_xz(file_obj, mode):
... import lzma
... return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)
>>>
>>> register_compressor('.xz', _handle_xz)

"""
if not (ext and ext[0] == '.'):
raise ValueError('ext must be a string starting with ., not %r' % ext)
if ext in _COMPRESSOR_REGISTRY:
logger.warning('overriding existing compression handler for %r', ext)
_COMPRESSOR_REGISTRY[ext] = callback


def _handle_bz2(file_obj, mode):
if six.PY2:
from bz2file import BZ2File
else:
from bz2 import BZ2File
return BZ2File(file_obj, mode)


def _handle_gzip(file_obj, mode):
import gzip
return gzip.GzipFile(fileobj=file_obj, mode=mode)


def compression_wrapper(file_obj, filename, mode):
"""
This function will wrap the file_obj with an appropriate
[de]compression mechanism based on the extension of the filename.

file_obj must either be a filehandle object, or a class which behaves
like one.

If the filename extension isn't recognized, the original file_obj is
returned unchanged.
"""
_, ext = os.path.splitext(filename)

if _need_to_buffer(file_obj, mode, ext):
warnings.warn('streaming gzip support unavailable, see %s' % _ISSUE_189_URL)
file_obj = io.BytesIO(file_obj.read())
if ext in _COMPRESSOR_REGISTRY and mode.endswith('+'):
raise ValueError('transparent (de)compression unsupported for mode %r' % mode)

try:
callback = _COMPRESSOR_REGISTRY[ext]
except KeyError:
return file_obj
else:
return callback(file_obj, mode)
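#
# Example (illustrative only, not part of this module): dispatch happens
# purely by file extension, and an unrecognized extension returns the object
# unchanged. The file names below are made up.
#
#   >>> import io
#   >>> wrapped = compression_wrapper(io.BytesIO(), 'archive.gz', 'wb')
#   >>> type(wrapped).__name__
#   'GzipFile'
#   >>> buf = io.BytesIO()
#   >>> compression_wrapper(buf, 'notes.txt', 'wb') is buf
#   True
#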


def _need_to_buffer(file_obj, mode, ext):
"""Returns True if we need to buffer the whole file in memory in order to proceed."""
try:
is_seekable = file_obj.seekable()
except AttributeError:
#
# Under Py2, built-in file objects returned by open do not have
# .seekable, but have a .seek method instead.
#
is_seekable = hasattr(file_obj, 'seek')
is_compressed = ext in _COMPRESSOR_REGISTRY
return six.PY2 and mode.startswith('r') and is_compressed and not is_seekable


#
# NB. avoid using lambda here to make stack traces more readable.
#
register_compressor('.bz2', _handle_bz2)
register_compressor('.gz', _handle_gzip)
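
For reference, a minimal end-to-end sketch of how the new registration hook is
meant to be used, based on the .xz example in the docstring above. The import
path smart_open.compression comes from this diff; the sample file name and the
use of smart_open.open are assumptions made for illustration.

import lzma

import smart_open
from smart_open.compression import register_compressor


def _handle_xz(file_obj, mode):
    # Wrap the raw file object so reads and writes are transparently (de)compressed.
    return lzma.LZMAFile(filename=file_obj, mode=mode, format=lzma.FORMAT_XZ)


register_compressor('.xz', _handle_xz)

# Once registered, the handler is selected automatically from the extension.
with smart_open.open('example.txt.xz', 'rb') as fin:
    print(fin.read())
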
18 changes: 18 additions & 0 deletions smart_open/hdfs.py
@@ -18,8 +18,26 @@
import logging
import subprocess

import smart_open.uri

from six.moves.urllib import parse as urlparse

logger = logging.getLogger(__name__)

HDFS_SCHEME = 'hdfs'


def parse_uri(uri_as_string):
split_uri = urlparse.urlsplit(uri_as_string)
assert split_uri.scheme == HDFS_SCHEME

uri_path = split_uri.netloc + split_uri.path
if not uri_path:
    raise RuntimeError("invalid HDFS URI: %r" % uri_as_string)
uri_path = "/" + uri_path.lstrip("/")

return smart_open.uri.Uri(scheme=HDFS_SCHEME, uri_path=uri_path)
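#
# Example (illustrative only, not part of this module): any netloc is folded
# back into the path. This assumes smart_open.uri.Uri exposes its fields as
# attributes; the paths below are made up.
#
#   >>> parse_uri('hdfs://host/user/hadoop/file.txt').uri_path
#   '/host/user/hadoop/file.txt'
#   >>> parse_uri('hdfs:///user/hadoop/file.txt').uri_path
#   '/user/hadoop/file.txt'
#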


def open(uri, mode):
if mode == 'rb':
12 changes: 12 additions & 0 deletions smart_open/http.py
@@ -10,11 +10,14 @@
import io
import logging

from six.moves.urllib import parse as urlparse
import requests

from smart_open import bytebuffer, s3
import smart_open.uri

DEFAULT_BUFFER_SIZE = 128 * 1024
SUPPORTED_SCHEMES = ('http', 'https')

logger = logging.getLogger(__name__)

@@ -28,6 +31,15 @@
"""


def parse_uri(uri_as_string):
split_uri = urlparse.urlsplit(uri_as_string)
assert split_uri.scheme in SUPPORTED_SCHEMES

uri_path = split_uri.netloc + split_uri.path
uri_path = "/" + uri_path.lstrip("/")
return smart_open.uri.Uri(scheme=split_uri.scheme, uri_path=uri_path)


def open(uri, mode, kerberos=False, user=None, password=None, headers=None):
"""Implement streamed reader from a web site.

142 changes: 141 additions & 1 deletion smart_open/s3.py
@@ -13,13 +13,17 @@
import logging
import warnings

import boto
import boto3
import botocore.client
import six

from six.moves.urllib import parse as urlparse
from botocore.exceptions import IncompleteReadError

import smart_open.bytebuffer
import smart_open.uri

from botocore.exceptions import IncompleteReadError

logger = logging.getLogger(__name__)

@@ -46,6 +50,8 @@
BINARY_NEWLINE = b'\n'

SUPPORTED_SCHEMES = ("s3", "s3n", 's3u', "s3a")
DEFAULT_PORT = 443
DEFAULT_HOST = 's3.amazonaws.com'

DEFAULT_BUFFER_SIZE = 128 * 1024

@@ -55,6 +61,140 @@
WHENCE_CHOICES = [START, CURRENT, END]


def _my_urlsplit(url):
"""This is a hack to prevent the regular urlsplit from splitting around question marks.

A question mark (?) in a URL typically indicates the start of a
querystring, and the standard library's urlparse function handles the
querystring separately. Unfortunately, question marks can also appear
_inside_ the actual URL for some schemes like S3.

Replaces question marks with newlines prior to splitting. This is safe because:

1. The standard library's urlsplit completely ignores newlines
2. Raw newlines will never occur in innocuous URLs. They are always URL-encoded.

See Also
--------
https://github.com/python/cpython/blob/3.7/Lib/urllib/parse.py
https://github.com/RaRe-Technologies/smart_open/issues/285
"""
sr = urlparse.urlsplit(url.replace('?', '\n'), allow_fragments=False)
return urlparse.SplitResult(sr.scheme, sr.netloc, sr.path.replace('\n', '?'), '', '')


def parse_uri(uri_as_string):
#
# Restrictions on bucket names and labels:
#
# - Bucket names must be at least 3 and no more than 63 characters long.
# - Bucket names must be a series of one or more labels.
# - Adjacent labels are separated by a single period (.).
# - Bucket names can contain lowercase letters, numbers, and hyphens.
# - Each label must start and end with a lowercase letter or a number.
#
# We use the above as a guide only, and do not perform any validation. We
# let boto3 take care of that for us.
#
split_uri = _my_urlsplit(uri_as_string)
assert split_uri.scheme in SUPPORTED_SCHEMES

port = DEFAULT_PORT
host = boto.config.get('s3', 'host', DEFAULT_HOST)
ordinary_calling_format = False
#
# These defaults tell boto3 to look for credentials elsewhere
#
access_id, access_secret = None, None

#
# Common URI template: [access_id:access_secret@][host[:port]@]bucket/key
#
# The urlparse function doesn't handle the above format, so we have to do
# it ourselves.
#
uri = split_uri.netloc + split_uri.path

if '@' in uri and ':' in uri.split('@')[0]:
auth, uri = uri.split('@', 1)
access_id, access_secret = auth.split(':')

head, key_id = uri.split('/', 1)
if '@' in head and ':' in head:
ordinary_calling_format = True
host_port, bucket_id = head.split('@')
host, port = host_port.split(':', 1)
port = int(port)
elif '@' in head:
ordinary_calling_format = True
host, bucket_id = head.split('@')
else:
bucket_id = head

return smart_open.uri.Uri(
scheme=split_uri.scheme,
bucket_id=bucket_id,
key_id=key_id,
port=port,
host=host,
ordinary_calling_format=ordinary_calling_format,
access_id=access_id,
access_secret=access_secret,
)
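#
# Example (illustrative only, not part of this module), using made-up
# credentials, host and key. This assumes smart_open.uri.Uri exposes its
# fields as attributes.
#
#   >>> uri = parse_uri('s3u://id:secret@storage.example.com:9000@mybucket/my/key.txt')
#   >>> uri.scheme, uri.bucket_id, uri.key_id
#   ('s3u', 'mybucket', 'my/key.txt')
#   >>> uri.host, uri.port, uri.ordinary_calling_format
#   ('storage.example.com', 9000, True)
#   >>> uri.access_id, uri.access_secret
#   ('id', 'secret')
#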


def consolidate_params(uri, transport_params):
"""Consolidates the parsed Uri with the additional parameters.

This is necessary because the user can pass some of the parameters in two
different ways:

1) Via the URI itself
2) Via the transport parameters

These are not mutually exclusive, but we have to pick one over the other
in a sensible way in order to proceed.

"""
transport_params = dict(transport_params)

session = transport_params.get('session')
if session is not None and (uri.access_id or uri.access_secret):
logger.warning(
'ignoring credentials parsed from URL because they conflict with '
'transport_params.session. Set transport_params.session to None '
'to suppress this warning.'
)
uri = uri._replace(access_id=None, access_secret=None)
elif (uri.access_id and uri.access_secret):
transport_params['session'] = boto3.Session(
aws_access_key_id=uri.access_id,
aws_secret_access_key=uri.access_secret,
)
uri = uri._replace(access_id=None, access_secret=None)

if uri.host != DEFAULT_HOST:
endpoint_url = 'https://%s:%d' % (uri.host, uri.port)
_override_endpoint_url(transport_params, endpoint_url)

return uri, transport_params
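#
# Example (illustrative only, not part of this module): credentials embedded
# in the URI are moved into a boto3 session, unless the caller already
# supplied a session via transport_params. Values below are made up.
#
#   >>> uri = parse_uri('s3://id:secret@mybucket/my/key.txt')
#   >>> uri, params = consolidate_params(uri, {})
#   >>> uri.access_id, uri.access_secret
#   (None, None)
#   >>> 'session' in params
#   True
#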


def _override_endpoint_url(transport_params, url):
try:
resource_kwargs = transport_params['resource_kwargs']
except KeyError:
resource_kwargs = transport_params['resource_kwargs'] = {}

if resource_kwargs.get('endpoint_url'):
logger.warning(
'ignoring endpoint_url parsed from URL because it conflicts '
'with transport_params.resource_kwargs.endpoint_url. '
)
else:
resource_kwargs.update(endpoint_url=url)


def clamp(value, minval, maxval):
return max(min(value, maxval), minval)
