Skip to content

Commit

Permalink
Fix data_utils.py when name ends with .tar.gz
Browse files Browse the repository at this point in the history
1. `get_file` was raising an `UnboundLocalError`: with `untar=True` and an `fname` ending in `.tar.gz`, `fpath` was never assigned.
2. If `fpath` and `untar_fpath` are set to the same value the `untar` doesn't run.
So ensure that `fpath` always ends with `.tar.gz`, and `untar_fpath` never does.

The old behavior (before the previous change) was to always add `.tar.gz`, giving `.tar.gz.tar.gz` files. That seems wrong too.

Fixes tensorflow_docs/site/en/tutorials/load_data/text.ipynb

PiperOrigin-RevId: 380934634
  • Loading branch information
MarkDaoust authored and tensorflower-gardener committed Jun 23, 2021
1 parent 3f90361 commit 2f1149d
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
10 changes: 8 additions & 2 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import hashlib
import multiprocessing.dummy
import os
import pathlib
import queue
import random
import shutil
Expand Down Expand Up @@ -227,9 +228,14 @@ def get_file(fname=None,
raise ValueError("Invalid origin '{}'".format(origin))

if untar:
if fname.endswith('.tar.gz'):
fname = pathlib.Path(fname)
# The 2 `.with_suffix()` are because of `.tar.gz` as pathlib
# considers it as 2 suffixes.
fname = fname.with_suffix('').with_suffix('')
fname = str(fname)
untar_fpath = os.path.join(datadir, fname)
if not untar_fpath.endswith('.tar.gz'):
fpath = untar_fpath + '.tar.gz'
fpath = untar_fpath + '.tar.gz'
else:
fpath = os.path.join(datadir, fname)

Expand Down
24 changes: 23 additions & 1 deletion keras/utils/data_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from keras.utils import data_utils


class TestGetFileAndValidateIt(tf.test.TestCase):
class TestGetFile(tf.test.TestCase):

def test_get_file_and_validate_it(self):
"""Tests get_file from a url, plus extraction and validation.
Expand Down Expand Up @@ -104,6 +104,28 @@ def test_get_file_and_validate_it(self):
with self.assertRaisesRegexp(ValueError, 'Please specify the "origin".*'):
_ = keras.utils.data_utils.get_file()

def test_get_file_with_tgz_extension(self):
"""Tests `get_file` with `untar=True` when `fname` already ends in `.tar.gz`.

Builds a gzip-compressed tar archive holding a single text file, fetches it
through a `file://` origin under the name `test.txt.tar.gz`, and checks that
the returned path ends in `.txt` and exists on disk.
"""
# NOTE(review): both directories come from `self.get_temp_dir()`. If that
# returns the same directory on every call, `test.txt` already exists at the
# asserted location before `get_file` runs -- confirm the existence check
# below really exercises extraction.
dest_dir = self.get_temp_dir()
orig_dir = self.get_temp_dir()

text_file_path = os.path.join(orig_dir, 'test.txt')
tar_file_path = os.path.join(orig_dir, 'test.tar.gz')

# Create the payload file that goes into the archive.
with open(text_file_path, 'w') as text_file:
text_file.write('Float like a butterfly, sting like a bee.')

# Pack the payload into a gzip-compressed tar archive.
with tarfile.open(tar_file_path, 'w:gz') as tar_file:
tar_file.add(text_file_path)

# Turn the archive's absolute path into a URL usable as `origin`.
origin = urllib.parse.urljoin(
'file://', urllib.request.pathname2url(os.path.abspath(tar_file_path)))

# The requested name carries the full `.tar.gz` suffix; the returned path is
# expected to have it stripped.
path = keras.utils.data_utils.get_file(
'test.txt.tar.gz', origin, untar=True, cache_subdir=dest_dir)
self.assertEndsWith(path, '.txt')
self.assertTrue(os.path.exists(path))


class TestSequence(keras.utils.data_utils.Sequence):

Expand Down

0 comments on commit 2f1149d

Please sign in to comment.