diff --git a/datarobot_batch_scoring/reader.py b/datarobot_batch_scoring/reader.py index 7e4bd1b4..9b27f7c1 100644 --- a/datarobot_batch_scoring/reader.py +++ b/datarobot_batch_scoring/reader.py @@ -383,7 +383,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.p.terminate() -def sniff_dialect(sample, encoding, sep, skip_dialect, ui): +def sniff_dialect(sample, sep, skip_dialect, ui): t1 = time() try: if skip_dialect: @@ -396,11 +396,11 @@ def sniff_dialect(sample, encoding, sep, skip_dialect, ui): dialect = csv.get_dialect('dataset_dialect') else: sniffer = csv.Sniffer() - dialect = sniffer.sniff(sample.decode(encoding), delimiters=sep) + dialect = sniffer.sniff(sample, delimiters=sep) ui.debug('investigate_encoding_and_dialect - seconds to detect ' 'csv dialect: {}'.format(time() - t1)) except csv.Error: - decoded_one = sample.decode(encoding) + decoded_one = sample t2 = time() detector = Detector() delimiter, resampled = detector.detect(decoded_one) @@ -432,6 +432,19 @@ def sniff_dialect(sample, encoding, sep, skip_dialect, ui): return dialect +def get_opener_and_mode(is_gz, text=False): + mode = 'r' if text else 'rb' + if is_gz: + return (gzip.open, mode) + elif six.PY2: + if text: + from io import open as io_open + return (io_open, 'r') + return (open, 'rU') + else: + return (open, mode) + + def investigate_encoding_and_dialect(dataset, sep, ui, fast=False, encoding=None, skip_dialect=False, output_delimiter=None): @@ -445,10 +458,7 @@ def investigate_encoding_and_dialect(dataset, sep, ui, fast=False, sample_size = DETECT_SAMPLE_SIZE_SLOW is_gz = dataset.endswith('.gz') - opener, mode = ( - (gzip.open, 'rb') if is_gz - else (open, ('rU' if six.PY2 else 'rb')) - ) + opener, mode = get_opener_and_mode(is_gz, text=True) with opener(dataset, mode) as dfile: sample = dfile.read(sample_size) @@ -462,8 +472,12 @@ def investigate_encoding_and_dialect(dataset, sep, ui, fast=False, encoding = encoding.lower() sample[:1000].decode(encoding) # Fail here if the encoding is invalid + opener, mode = get_opener_and_mode(is_gz, text=True) + with opener(dataset, mode, encoding=encoding) as dfile: + sample = dfile.read(sample_size) + try: - dialect = sniff_dialect(sample, encoding, sep, skip_dialect, ui) + dialect = sniff_dialect(sample, sep, skip_dialect, ui) except csv.Error as ex: ui.fatal(ex) if len(sample) < 10: