diff --git a/datarobot_batch_scoring/reader.py b/datarobot_batch_scoring/reader.py index 7e4bd1b4..b79aebec 100644 --- a/datarobot_batch_scoring/reader.py +++ b/datarobot_batch_scoring/reader.py @@ -17,7 +17,7 @@ REPORT_INTERVAL, ProgressQueueMsg) from datarobot_batch_scoring.detect import Detector -from datarobot_batch_scoring.utils import get_rusage, SerializableDialect +from datarobot_batch_scoring.utils import get_rusage, gzip_with_encoding, SerializableDialect if six.PY2: @@ -383,7 +383,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.p.terminate() -def sniff_dialect(sample, encoding, sep, skip_dialect, ui): +def sniff_dialect(sample, sep, skip_dialect, ui): t1 = time() try: if skip_dialect: @@ -396,11 +396,11 @@ def sniff_dialect(sample, encoding, sep, skip_dialect, ui): dialect = csv.get_dialect('dataset_dialect') else: sniffer = csv.Sniffer() - dialect = sniffer.sniff(sample.decode(encoding), delimiters=sep) + dialect = sniffer.sniff(sample, delimiters=sep) ui.debug('investigate_encoding_and_dialect - seconds to detect ' 'csv dialect: {}'.format(time() - t1)) except csv.Error: - decoded_one = sample.decode(encoding) + decoded_one = sample t2 = time() detector = Detector() delimiter, resampled = detector.detect(decoded_one) @@ -432,6 +432,23 @@ def sniff_dialect(sample, encoding, sep, skip_dialect, ui): return dialect +def get_opener_and_mode(is_gz, text=False): + mode = 'rt' if text else 'rb' + if is_gz: + if six.PY2: + return (gzip_with_encoding, mode) + else: + return (gzip.open, mode) + else: + if six.PY2: + if text: + from io import open as io_open + return (io_open, 'r') + return (open, 'rU') + else: + return (open, mode) + + def investigate_encoding_and_dialect(dataset, sep, ui, fast=False, encoding=None, skip_dialect=False, output_delimiter=None): @@ -443,12 +460,8 @@ def investigate_encoding_and_dialect(dataset, sep, ui, fast=False, sample_size = DETECT_SAMPLE_SIZE_FAST else: sample_size = DETECT_SAMPLE_SIZE_SLOW - is_gz = 
dataset.endswith('.gz') - opener, mode = ( - (gzip.open, 'rb') if is_gz - else (open, ('rU' if six.PY2 else 'rb')) - ) + opener, mode = get_opener_and_mode(is_gz) with opener(dataset, mode) as dfile: sample = dfile.read(sample_size) @@ -462,8 +475,12 @@ def investigate_encoding_and_dialect(dataset, sep, ui, fast=False, encoding = encoding.lower() sample[:1000].decode(encoding) # Fail here if the encoding is invalid + opener, mode = get_opener_and_mode(is_gz, text=True) + with opener(dataset, mode, encoding=encoding) as dfile: + sample = dfile.read(sample_size) + try: - dialect = sniff_dialect(sample, encoding, sep, skip_dialect, ui) + dialect = sniff_dialect(sample, sep, skip_dialect, ui) except csv.Error as ex: ui.fatal(ex) if len(sample) < 10: @@ -518,15 +535,11 @@ def auto_sampler(dataset, encoding, ui): sample_size = AUTO_SAMPLE_SIZE is_gz = dataset.endswith('.gz') - opener, mode = (gzip.open, 'rb') if is_gz else (open, 'rU') - with opener(dataset, mode) as dfile: + opener, mode = get_opener_and_mode(is_gz, text=True) + with opener(dataset, mode, encoding=encoding) as dfile: sample = dfile.read(sample_size) - if six.PY3 and not is_gz: - sample = sample.encode(encoding or 'utf-8') - - ingestable_sample = sample.decode(encoding) - size_bytes = sys.getsizeof(ingestable_sample.encode('utf-8')) + size_bytes = sys.getsizeof(sample.encode('utf-8')) if size_bytes < (sample_size * 0.75): # if dataset is tiny, don't bother auto sampling. 
ui.info('auto_sampler: total time seconds - {}'.format(time() - t0)) @@ -536,15 +549,16 @@ def auto_sampler(dataset, encoding, ui): if six.PY3: buf = io.StringIO() - buf.write(ingestable_sample) + buf.write(sample) else: buf = StringIO.StringIO() - buf.write(sample) + buf.write(sample.encode('utf-8')) buf.seek(0) + file_lines, csv_lines = 0, 0 dialect = csv.get_dialect('dataset_dialect') - fd = Recoder(buf, encoding) - reader = csv.reader(fd, dialect=dialect, delimiter=dialect.delimiter) + # Recoder(buf, encoding) is no longer used; csv.reader reads buf directly + reader = csv.reader(buf, dialect=dialect, delimiter=dialect.delimiter) line_pos = [] for _ in buf: file_lines += 1 @@ -559,7 +573,6 @@ def auto_sampler(dataset, encoding, ui): # PRED-1240 there's no guarantee that we got _any_ fully formed lines. # If so, the dataset is super wide, so we only send 10 rows at a time return AUTO_SAMPLE_FALLBACK - try: for _ in reader: csv_lines += 1 diff --git a/datarobot_batch_scoring/utils.py b/datarobot_batch_scoring/utils.py index 8263bbef..8930296d 100644 --- a/datarobot_batch_scoring/utils.py +++ b/datarobot_batch_scoring/utils.py @@ -1,9 +1,12 @@ +import codecs import csv import getpass +import gzip import io import logging import os import sys +from contextlib import contextmanager from collections import namedtuple from functools import partial from gzip import GzipFile @@ -518,3 +521,16 @@ def state(self, status): def state_name(self, s=None): return self.state_names[s or self.state] + + +@contextmanager +def gzip_with_encoding(data, mode, encoding=None): + """ Context manager adding encoding support to gzip.open for PY2 + """ + if encoding is not None: + reader = codecs.getreader(encoding) + with gzip.open(data, mode) as f: + yield reader(f) + else: + with gzip.open(data, mode) as f: + yield f diff --git a/tests/test_utils.py b/tests/test_utils.py index a909d009..b28bddbd 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -210,9 +210,9 @@ def test_investigate_encoding_and_dialect_substitute_delimiter(): data = 
'tests/fixtures/windows_encoded.csv' encoding = investigate_encoding_and_dialect(data, '|', ui, fast=False, - encoding='utf-8', + encoding='', skip_dialect=True) - assert encoding == 'utf-8' # Intentionally wrong + assert encoding == 'windows-1252' assert not sn.called dialect = csv.get_dialect('dataset_dialect') assert dialect.delimiter == '|'