This repository has been archived by the owner on Sep 9, 2020. It is now read-only.

[PRED-2644] Fix decoding error during dialect detection
We open the file in binary mode and read the first N bytes to detect its
encoding. Later we reuse those bytes for dialect detection, first decode()-ing
them into a (unicode) string. Because N is a fixed constant, the read can cut
a multi-byte character in half at the end of the sample, and decode() then
fails on the truncated character.
falkerson committed Jul 15, 2019
1 parent 13da21e commit 78b621c
Showing 2 changed files with 52 additions and 22 deletions.
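
The failure is easy to reproduce. Below is a minimal sketch of the broken path, assuming UTF-8 input; the byte values are illustrative and not taken from the commit:

# A fixed-size binary read tears the last multi-byte character of the sample.
data = u'abcd\u00e9'.encode('utf-8')  # b'abcd\xc3\xa9': the accented character takes two bytes
sample = data[:5]                     # read N=5 bytes; only the first of its two bytes survives
try:
    sample.decode('utf-8')            # what the old sniff_dialect() path effectively did
except UnicodeDecodeError as err:
    print(err)  # 'utf-8' codec can't decode byte 0xc3 in position 4: unexpected end of data

The commit sidesteps this by re-reading the sample in text mode with the already-detected encoding, so read(n) counts decoded characters rather than raw bytes and can never stop mid-character.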
58 changes: 36 additions & 22 deletions datarobot_batch_scoring/reader.py
@@ -17,7 +17,7 @@
                          REPORT_INTERVAL,
                          ProgressQueueMsg)
 from datarobot_batch_scoring.detect import Detector
-from datarobot_batch_scoring.utils import get_rusage, SerializableDialect
+from datarobot_batch_scoring.utils import get_rusage, gzip_with_encoding, SerializableDialect
 
 
 if six.PY2:
@@ -383,7 +383,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         self.p.terminate()
 
 
-def sniff_dialect(sample, encoding, sep, skip_dialect, ui):
+def sniff_dialect(sample, sep, skip_dialect, ui):
     t1 = time()
     try:
         if skip_dialect:
@@ -396,11 +396,11 @@ def sniff_dialect(sample, encoding, sep, skip_dialect, ui):
             dialect = csv.get_dialect('dataset_dialect')
         else:
             sniffer = csv.Sniffer()
-            dialect = sniffer.sniff(sample.decode(encoding), delimiters=sep)
+            dialect = sniffer.sniff(sample, delimiters=sep)
         ui.debug('investigate_encoding_and_dialect - seconds to detect '
                  'csv dialect: {}'.format(time() - t1))
     except csv.Error:
-        decoded_one = sample.decode(encoding)
+        decoded_one = sample
         t2 = time()
         detector = Detector()
         delimiter, resampled = detector.detect(decoded_one)
@@ -432,6 +432,23 @@ def sniff_dialect(sample, encoding, sep, skip_dialect, ui):
     return dialect
 
 
+def get_opener_and_mode(is_gz, text=False):
+    mode = 'r' if text else 'rb'
+    if is_gz:
+        if six.PY2:
+            return (gzip_with_encoding, mode)
+        else:
+            return (gzip.open, 'rt' if text else 'rb')
+    else:
+        if six.PY2:
+            if text:
+                from io import open as io_open
+                return (io_open, 'r')
+            return (open, 'rU')
+        else:
+            return (open, mode)
+
+
 def investigate_encoding_and_dialect(dataset, sep, ui, fast=False,
                                      encoding=None, skip_dialect=False,
                                      output_delimiter=None):
@@ -443,12 +460,8 @@ def investigate_encoding_and_dialect(dataset, sep, ui, fast=False,
         sample_size = DETECT_SAMPLE_SIZE_FAST
     else:
         sample_size = DETECT_SAMPLE_SIZE_SLOW
-
     is_gz = dataset.endswith('.gz')
-    opener, mode = (
-        (gzip.open, 'rb') if is_gz
-        else (open, ('rU' if six.PY2 else 'rb'))
-    )
+    opener, mode = get_opener_and_mode(is_gz)
     with opener(dataset, mode) as dfile:
         sample = dfile.read(sample_size)
 
@@ -462,8 +475,12 @@ def investigate_encoding_and_dialect(dataset, sep, ui, fast=False,
         encoding = encoding.lower()
     sample[:1000].decode(encoding)  # Fail here if the encoding is invalid
 
+    opener, mode = get_opener_and_mode(is_gz, text=True)
+    with opener(dataset, mode, encoding=encoding) as dfile:
+        sample = dfile.read(sample_size)
+
     try:
-        dialect = sniff_dialect(sample, encoding, sep, skip_dialect, ui)
+        dialect = sniff_dialect(sample, sep, skip_dialect, ui)
     except csv.Error as ex:
         ui.fatal(ex)
     if len(sample) < 10:
@@ -518,15 +535,11 @@ def auto_sampler(dataset, encoding, ui):
 
     sample_size = AUTO_SAMPLE_SIZE
     is_gz = dataset.endswith('.gz')
-    opener, mode = (gzip.open, 'rb') if is_gz else (open, 'rU')
-    with opener(dataset, mode) as dfile:
+    opener, mode = get_opener_and_mode(is_gz, text=True)
+    with opener(dataset, mode, encoding=encoding) as dfile:
         sample = dfile.read(sample_size)
-        if six.PY3 and not is_gz:
-            sample = sample.encode(encoding or 'utf-8')
-
-    ingestable_sample = sample.decode(encoding)
-    size_bytes = sys.getsizeof(ingestable_sample.encode('utf-8'))
-
+    size_bytes = sys.getsizeof(sample.encode('utf-8'))
+
     if size_bytes < (sample_size * 0.75):
         # if dataset is tiny, don't bother auto sampling.
         ui.info('auto_sampler: total time seconds - {}'.format(time() - t0))
@@ -536,15 +549,16 @@ def auto_sampler(dataset, encoding, ui):
 
     if six.PY3:
         buf = io.StringIO()
-        buf.write(ingestable_sample)
+        buf.write(sample)
     else:
         buf = StringIO.StringIO()
-        buf.write(sample)
+        buf.write(sample.encode('utf-8'))
+    buf.seek(0)
 
     file_lines, csv_lines = 0, 0
     dialect = csv.get_dialect('dataset_dialect')
-    fd = Recoder(buf, encoding)
-    reader = csv.reader(fd, dialect=dialect, delimiter=dialect.delimiter)
+    reader = csv.reader(buf, dialect=dialect, delimiter=dialect.delimiter)
     line_pos = []
     for _ in buf:
         file_lines += 1
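
For reference, a sketch of how the new get_opener_and_mode() helper is used by the two read passes above; dataset, sample_size, and encoding are assumed to be in scope, as inside investigate_encoding_and_dialect():

# Hypothetical usage mirroring the call sites in this diff; not part of the commit.
is_gz = dataset.endswith('.gz')

# Pass 1: raw bytes, used only to guess the charset.
opener, mode = get_opener_and_mode(is_gz)              # e.g. (open, 'rb') on PY3
with opener(dataset, mode) as dfile:
    raw_sample = dfile.read(sample_size)               # bytes; may end mid-character

# Pass 2: text mode with the detected encoding; read() now counts decoded
# characters, so the sample can no longer end in a torn character.
opener, mode = get_opener_and_mode(is_gz, text=True)   # e.g. (open, 'r') on PY3
with opener(dataset, mode, encoding=encoding) as dfile:
    sample = dfile.read(sample_size)                   # str/unicode; always well-formed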
16 changes: 16 additions & 0 deletions datarobot_batch_scoring/utils.py
@@ -1,9 +1,12 @@
+import codecs
 import csv
 import getpass
+import gzip
 import io
 import logging
 import os
 import sys
+from contextlib import contextmanager
 from collections import namedtuple
 from functools import partial
 from gzip import GzipFile
@@ -518,3 +521,16 @@ def state(self, status):
 
     def state_name(self, s=None):
         return self.state_names[s or self.state]
+
+
+@contextmanager
+def gzip_with_encoding(data, mode, encoding=None):
+    """ Context manager to support encoding for gzip in PY2
+    """
+    if encoding is not None:
+        reader = codecs.getreader(encoding)
+        with gzip.open(data, mode) as f:
+            yield reader(f)
+    else:
+        with gzip.open(data, mode) as f:
+            yield f
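A short usage sketch for the PY2 gzip helper above; the file name and encoding are placeholders, not from the commit:

# gzip.open() has no encoding argument on PY2; the context manager wraps
# the binary stream in a codecs reader that yields unicode.
with gzip_with_encoding('data.csv.gz', 'r', encoding='utf-8') as f:
    sample = f.read(1024)  # decoded text, never a torn multi-byte character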