Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add verify_ssl option to gluon.utils.download #11546

Merged
merged 5 commits into from
Jul 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,7 @@ try {
bat """xcopy C:\\mxnet\\data data /E /I /Y
xcopy C:\\mxnet\\model model /E /I /Y
call activate py2
pip install mock
set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_cpu\\python
del /S /Q ${env.WORKSPACE}\\pkg_vc14_cpu\\python\\*.pyc
C:\\mxnet\\test_cpu.bat"""
Expand Down Expand Up @@ -893,6 +894,7 @@ try {
bat """xcopy C:\\mxnet\\data data /E /I /Y
xcopy C:\\mxnet\\model model /E /I /Y
call activate py2
pip install mock
set PYTHONPATH=${env.WORKSPACE}\\pkg_vc14_gpu\\python
del /S /Q ${env.WORKSPACE}\\pkg_vc14_gpu\\python\\*.pyc
C:\\mxnet\\test_gpu.bat"""
Expand Down
14 changes: 12 additions & 2 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def check_sha1(filename, sha1_hash):
return sha1.hexdigest() == sha1_hash


def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
"""Download an given URL

Parameters
Expand All @@ -192,6 +192,8 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
but doesn't match.
retries : integer, default 5
The number of times to attempt the download in case of failure or non 200 return codes
verify_ssl : bool, default True
Verify SSL certificates.

Returns
-------
Expand All @@ -200,6 +202,9 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
"""
if path is None:
fname = url.split('/')[-1]
# Empty filenames are invalid
assert fname, 'Can\'t construct file-name from this URL. ' \
'Please set the `path` option manually.'
else:
path = os.path.expanduser(path)
if os.path.isdir(path):
Expand All @@ -208,6 +213,11 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
fname = path
assert retries >= 0, "Number of retries should be at least 0"

if not verify_ssl:
warnings.warn(
'Unverified HTTPS request is being made (verify_ssl=False). '
'Adding certificate verification is strongly advised.')

if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
if not os.path.exists(dirname):
Expand All @@ -217,7 +227,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5):
# pylint: disable=W0703
try:
print('Downloading %s from %s...'%(fname, url))
r = requests.get(url, stream=True)
r = requests.get(url, stream=True, verify=verify_ssl)
if r.status_code != 200:
raise RuntimeError("Failed downloading url %s"%url)
with open(fname, 'wb') as f:
Expand Down
44 changes: 40 additions & 4 deletions tests/python/unittest/test_gluon_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,56 @@
# specific language governing permissions and limitations
# under the License.

import io
import os
import tempfile
import warnings

try:
from unittest import mock
except ImportError:
import mock
import mxnet as mx
from nose.tools import *
import requests
from nose.tools import raises


class MockResponse(requests.Response):
def __init__(self, status_code, content):
super(MockResponse, self).__init__()
assert isinstance(status_code, int)
self.status_code = status_code
self.raw = io.BytesIO(content.encode('utf-8'))


@raises(Exception)
@mock.patch(
'requests.get', mock.Mock(side_effect=requests.exceptions.ConnectionError))
def test_download_retries():
mx.gluon.utils.download("http://doesnotexist.notfound")


@mock.patch(
'requests.get',
mock.Mock(side_effect=
lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT' * 100)))
def test_download_successful():
tmp = tempfile.mkdtemp()
tmpfile = os.path.join(tmp, 'README.md')
mx.gluon.utils.download("https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100
mx.gluon.utils.download(
"https://raw.githubusercontent.com/apache/incubator-mxnet/master/README.md",
path=tmpfile)
assert os.path.getsize(tmpfile) > 100


@mock.patch(
'requests.get',
mock.Mock(
side_effect=lambda *args, **kwargs: MockResponse(200, 'MOCK CONTENT')))
def test_download_ssl_verify():
with warnings.catch_warnings(record=True) as warnings_:
mx.gluon.utils.download(
"https://mxnet.incubator.apache.org/index.html", verify_ssl=False)
assert any(
str(w.message).startswith('Unverified HTTPS request')
for w in warnings_)