diff --git a/.travis.yml b/.travis.yml index 3485d4bb..b34d1781 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,6 +6,7 @@ env: - VERSION=20.5.5.74 - VERSION=20.4.9.110 - VERSION=20.3.20.6 + - VERSION=20.3.20.6 USE_NUMPY=1 - VERSION=19.16.17.80 - VERSION=19.15.3.6 - VERSION=19.9.2.4 # allow_suspicious_low_cardinality_types @@ -65,11 +66,23 @@ install: - pip install --upgrade pip setuptools - pip install flake8 flake8-print coveralls cython script: + - if [ -z ${USE_NUMPY+x} ]; then pip uninstall -y numpy pandas; fi - flake8 && coverage run --source=clickhouse_driver setup.py test after_success: coveralls jobs: + # Exclude numpy unsupported versions, + exclude: + - python: 3.4 + env: VERSION=20.3.20.6 USE_NUMPY=1 + - python: 3.9-dev + env: VERSION=20.3.20.6 USE_NUMPY=1 + - python: pypy2.7-5.10.0 + env: VERSION=20.3.20.6 USE_NUMPY=1 + - python: pypy3.5 + env: VERSION=20.3.20.6 USE_NUMPY=1 + include: - stage: valgrind name: Valgrind check @@ -100,6 +113,7 @@ jobs: env: - VERSION=20.3.7.46 + - USE_NUMPY=1 - PYTHONMALLOC=malloc - stage: wheels diff --git a/clickhouse_driver/__init__.py b/clickhouse_driver/__init__.py index 9cc3b61d..0e344f25 100644 --- a/clickhouse_driver/__init__.py +++ b/clickhouse_driver/__init__.py @@ -3,7 +3,7 @@ from .dbapi import connect -VERSION = (0, 1, 5) +VERSION = (0, 1, 6) __version__ = '.'.join(str(x) for x in VERSION) __all__ = ['Client', 'connect'] diff --git a/clickhouse_driver/client.py b/clickhouse_driver/client.py index 8df773e2..aff534bc 100644 --- a/clickhouse_driver/client.py +++ b/clickhouse_driver/client.py @@ -1,3 +1,4 @@ +import re import ssl from time import time import types @@ -33,12 +34,15 @@ class Client(object): * strings_encoding -- specifies string encoding. UTF-8 by default. + * use_numpy -- Use numpy for columns reading. 
+ """ available_client_settings = ( 'insert_block_size', # TODO: rename to max_insert_block_size 'strings_as_bytes', - 'strings_encoding' + 'strings_encoding', + 'use_numpy' ) def __init__(self, *args, **kwargs): @@ -53,9 +57,28 @@ def __init__(self, *args, **kwargs): ), 'strings_encoding': self.settings.pop( 'strings_encoding', defines.STRINGS_ENCODING + ), + 'use_numpy': self.settings.pop( + 'use_numpy', False ) } + if self.client_settings['use_numpy']: + try: + from .numpy.result import ( + NumpyIterQueryResult, NumpyProgressQueryResult, + NumpyQueryResult + ) + self.query_result_cls = NumpyQueryResult + self.iter_query_result_cls = NumpyIterQueryResult + self.progress_query_result_cls = NumpyProgressQueryResult + except ImportError: + raise RuntimeError('Extras for NumPy must be installed') + else: + self.query_result_cls = QueryResult + self.iter_query_result_cls = IterQueryResult + self.progress_query_result_cls = ProgressQueryResult + self.connection = Connection(*args, **kwargs) self.connection.context.settings = self.settings self.connection.context.client_settings = self.client_settings @@ -78,12 +101,12 @@ def receive_result(self, with_column_types=False, progress=False, gen = self.packet_generator() if progress: - return ProgressQueryResult( + return self.progress_query_result_cls( gen, with_column_types=with_column_types, columnar=columnar ) else: - result = QueryResult( + result = self.query_result_cls( gen, with_column_types=with_column_types, columnar=columnar ) return result.get_result() @@ -91,7 +114,11 @@ def receive_result(self, with_column_types=False, progress=False, def iter_receive_result(self, with_column_types=False): gen = self.packet_generator() - for rows in IterQueryResult(gen, with_column_types=with_column_types): + result = self.iter_query_result_cls( + gen, with_column_types=with_column_types + ) + + for rows in result: for row in rows: yield row @@ -318,6 +345,23 @@ def execute_iter( self.disconnect() raise + def 
query_dataframe(self, query, params=None, external_tables=None, + query_id=None, settings=None): + try: + import pandas as pd + except ImportError: + raise RuntimeError('Extras for NumPy must be installed') + + data, columns = self.execute( + query, columnar=True, with_column_types=True, params=params, + external_tables=external_tables, query_id=query_id, + settings=settings + ) + + return pd.DataFrame( + {re.sub(r'\W', '_', col[0]): d for d, col in zip(data, columns)} + ) + def process_ordinary_query_with_progress( self, query, params=None, with_column_types=False, external_tables=None, query_id=None, @@ -487,6 +531,9 @@ def from_url(cls, url): elif name == 'secure': kwargs[name] = asbool(value) + elif name == 'use_numpy': + kwargs[name] = asbool(value) + elif name == 'client_name': kwargs[name] = value diff --git a/clickhouse_driver/columns/floatcolumn.py b/clickhouse_driver/columns/floatcolumn.py index 2cce4b26..37da71ed 100644 --- a/clickhouse_driver/columns/floatcolumn.py +++ b/clickhouse_driver/columns/floatcolumn.py @@ -8,12 +8,12 @@ class FloatColumn(FormatColumn): py_types = (float, ) + compat.integer_types -class Float32(FloatColumn): +class Float32Column(FloatColumn): ch_type = 'Float32' format = 'f' def __init__(self, types_check=False, **kwargs): - super(Float32, self).__init__(types_check=types_check, **kwargs) + super(Float32Column, self).__init__(types_check=types_check, **kwargs) if types_check: # Chop only bytes that fit current type. 
@@ -30,6 +30,6 @@ def before_write_items(items, nulls_map=None): self.before_write_items = before_write_items -class Float64(FloatColumn): +class Float64Column(FloatColumn): ch_type = 'Float64' format = 'd' diff --git a/clickhouse_driver/columns/numpy/__init__.py b/clickhouse_driver/columns/numpy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clickhouse_driver/columns/numpy/base.py b/clickhouse_driver/columns/numpy/base.py new file mode 100644 index 00000000..bc49a578 --- /dev/null +++ b/clickhouse_driver/columns/numpy/base.py @@ -0,0 +1,14 @@ +import numpy as np + +from ..base import Column + + +class NumpyColumn(Column): + dtype = None + + def read_items(self, n_items, buf): + data = buf.read(n_items * self.dtype.itemsize) + return np.frombuffer(data, self.dtype, n_items) + + def write_items(self, items, buf): + raise RuntimeError('Write is not implemented') diff --git a/clickhouse_driver/columns/numpy/datecolumn.py b/clickhouse_driver/columns/numpy/datecolumn.py new file mode 100644 index 00000000..61bd9393 --- /dev/null +++ b/clickhouse_driver/columns/numpy/datecolumn.py @@ -0,0 +1,12 @@ +import numpy as np + +from .base import NumpyColumn + + +class NumpyDateColumn(NumpyColumn): + dtype = np.dtype(np.uint16) + ch_type = 'Date' + + def read_items(self, n_items, buf): + data = super(NumpyDateColumn, self).read_items(n_items, buf) + return data.astype('datetime64[D]') diff --git a/clickhouse_driver/columns/numpy/datetimecolumn.py b/clickhouse_driver/columns/numpy/datetimecolumn.py new file mode 100644 index 00000000..0dea1a77 --- /dev/null +++ b/clickhouse_driver/columns/numpy/datetimecolumn.py @@ -0,0 +1,102 @@ +try: + import numpy as np +except ImportError: + numpy = None + +try: + import pandas as pd +except ImportError: + pandas = None + +from pytz import timezone as get_timezone +from tzlocal import get_localzone + +from .base import NumpyColumn + + +class NumpyDateTimeColumn(NumpyColumn): + dtype = np.dtype(np.uint32) + + def 
__init__(self, timezone=None, offset_naive=True, local_timezone=None, + **kwargs): + self.timezone = timezone + self.offset_naive = offset_naive + self.local_timezone = local_timezone + super(NumpyDateTimeColumn, self).__init__(**kwargs) + + def apply_timezones(self, dt): + ts = pd.to_datetime(dt, utc=True) + timezone = self.timezone if self.timezone else self.local_timezone + + ts = ts.tz_convert(timezone) + if self.offset_naive: + ts = ts.tz_localize(None) + + return ts.to_numpy() + + def read_items(self, n_items, buf): + data = super(NumpyDateTimeColumn, self).read_items(n_items, buf) + dt = data.astype('datetime64[s]') + return self.apply_timezones(dt) + + +class NumpyDateTime64Column(NumpyDateTimeColumn): + dtype = np.dtype(np.uint64) + + max_scale = 6 + + def __init__(self, scale=0, **kwargs): + self.scale = scale + super(NumpyDateTime64Column, self).__init__(**kwargs) + + def read_items(self, n_items, buf): + scale = 10 ** self.scale + frac_scale = 10 ** (self.max_scale - self.scale) + + data = super(NumpyDateTimeColumn, self).read_items(n_items, buf) + seconds = (data // scale).astype('datetime64[s]') + microseconds = ((data % scale) * frac_scale).astype('timedelta64[us]') + + dt = seconds + microseconds + return self.apply_timezones(dt) + + +def create_numpy_datetime_column(spec, column_options): + if spec.startswith('DateTime64'): + cls = NumpyDateTime64Column + spec = spec[11:-1] + params = spec.split(',', 1) + column_options['scale'] = int(params[0]) + if len(params) > 1: + spec = params[1].strip() + ')' + else: + cls = NumpyDateTimeColumn + spec = spec[9:] + + context = column_options['context'] + + tz_name = timezone = None + offset_naive = True + local_timezone = None + + # As Numpy do not use local timezone for converting timestamp to + # datetime we need always detect local timezone for manual converting. + try: + local_timezone = get_localzone().zone + except Exception: + pass + + # Use column's timezone if it's specified. 
+ if spec and spec[-1] == ')': + tz_name = spec[1:-2] + offset_naive = False + else: + if not context.settings.get('use_client_time_zone', False): + if local_timezone != context.server_info.timezone: + tz_name = context.server_info.timezone + + if tz_name: + timezone = get_timezone(tz_name) + + return cls(timezone=timezone, offset_naive=offset_naive, + local_timezone=local_timezone, **column_options) diff --git a/clickhouse_driver/columns/numpy/floatcolumn.py b/clickhouse_driver/columns/numpy/floatcolumn.py new file mode 100644 index 00000000..bbf2db64 --- /dev/null +++ b/clickhouse_driver/columns/numpy/floatcolumn.py @@ -0,0 +1,13 @@ +import numpy as np + +from .base import NumpyColumn + + +class NumpyFloat32Column(NumpyColumn): + dtype = np.dtype(np.float32) + ch_type = 'Float32' + + +class NumpyFloat64Column(NumpyColumn): + dtype = np.dtype(np.float64) + ch_type = 'Float64' diff --git a/clickhouse_driver/columns/numpy/intcolumn.py b/clickhouse_driver/columns/numpy/intcolumn.py new file mode 100644 index 00000000..3da12aa4 --- /dev/null +++ b/clickhouse_driver/columns/numpy/intcolumn.py @@ -0,0 +1,43 @@ +import numpy as np + +from .base import NumpyColumn + + +class NumpyInt8Column(NumpyColumn): + dtype = np.dtype(np.int8) + ch_type = 'Int8' + + +class NumpyUInt8Column(NumpyColumn): + dtype = np.dtype(np.uint8) + ch_type = 'UInt8' + + +class NumpyInt16Column(NumpyColumn): + dtype = np.dtype(np.int16) + ch_type = 'Int16' + + +class NumpyUInt16Column(NumpyColumn): + dtype = np.dtype(np.uint16) + ch_type = 'UInt16' + + +class NumpyInt32Column(NumpyColumn): + dtype = np.dtype(np.int32) + ch_type = 'Int32' + + +class NumpyUInt32Column(NumpyColumn): + dtype = np.dtype(np.uint32) + ch_type = 'UInt32' + + +class NumpyInt64Column(NumpyColumn): + dtype = np.dtype(np.int64) + ch_type = 'Int64' + + +class NumpyUInt64Column(NumpyColumn): + dtype = np.dtype(np.uint64) + ch_type = 'UInt64' diff --git a/clickhouse_driver/columns/numpy/lowcardinalitycolumn.py 
b/clickhouse_driver/columns/numpy/lowcardinalitycolumn.py new file mode 100644 index 00000000..1ab00e06 --- /dev/null +++ b/clickhouse_driver/columns/numpy/lowcardinalitycolumn.py @@ -0,0 +1,56 @@ +import pandas as pd + +from ..lowcardinalitycolumn import LowCardinalityColumn +from ...reader import read_binary_uint64 +from .intcolumn import ( + NumpyUInt8Column, NumpyUInt16Column, NumpyUInt32Column, NumpyUInt64Column +) + + +class NumpyLowCardinalityColumn(LowCardinalityColumn): + int_types = { + 0: NumpyUInt8Column, + 1: NumpyUInt16Column, + 2: NumpyUInt32Column, + 3: NumpyUInt64Column + } + + def __init__(self, nested_column, **kwargs): + super(NumpyLowCardinalityColumn, self).__init__(nested_column, + **kwargs) + + def _read_data(self, n_items, buf, nulls_map=None): + if not n_items: + return tuple() + + serialization_type = read_binary_uint64(buf) + + # Lowest byte contains info about key type. + key_type = serialization_type & 0xf + keys_column = self.int_types[key_type]() + + nullable = self.nested_column.nullable + # Prevent null map reading. Reset nested column nullable flag. 
+ self.nested_column.nullable = False + + index_size = read_binary_uint64(buf) + index = self.nested_column.read_data(index_size, buf) + + read_binary_uint64(buf) # number of keys + keys = keys_column.read_data(n_items, buf) + + if nullable: + # Shift all codes by one ("No value" code is -1 for pandas + # categorical) and drop corresponding first index + # this is analog of original operation: + # index = (None, ) + index[1:] + keys = keys - 1 + index = index[1:] + result = pd.Categorical.from_codes(keys, index) + return result + + +def create_numpy_low_cardinality_column(spec, column_by_spec_getter): + inner = spec[15:-1] + nested = column_by_spec_getter(inner) + return NumpyLowCardinalityColumn(nested) diff --git a/clickhouse_driver/columns/numpy/service.py b/clickhouse_driver/columns/numpy/service.py new file mode 100644 index 00000000..c5ecd313 --- /dev/null +++ b/clickhouse_driver/columns/numpy/service.py @@ -0,0 +1,80 @@ +from ... import errors +from ..arraycolumn import create_array_column +from .datecolumn import NumpyDateColumn +from .datetimecolumn import create_numpy_datetime_column +from ..decimalcolumn import create_decimal_column +from ..enumcolumn import create_enum_column +from .floatcolumn import NumpyFloat32Column, NumpyFloat64Column +from .intcolumn import ( + NumpyInt8Column, NumpyInt16Column, NumpyInt32Column, NumpyInt64Column, + NumpyUInt8Column, NumpyUInt16Column, NumpyUInt32Column, NumpyUInt64Column +) +from .lowcardinalitycolumn import create_numpy_low_cardinality_column +from ..nothingcolumn import NothingColumn +from ..nullcolumn import NullColumn +# from .nullablecolumn import create_nullable_column +from ..simpleaggregatefunctioncolumn import ( + create_simple_aggregate_function_column +) +from .stringcolumn import create_string_column +from ..tuplecolumn import create_tuple_column +from ..uuidcolumn import UUIDColumn +from ..intervalcolumn import ( + IntervalYearColumn, IntervalMonthColumn, IntervalWeekColumn, + IntervalDayColumn, 
IntervalHourColumn, IntervalMinuteColumn, + IntervalSecondColumn +) +from ..ipcolumn import IPv4Column, IPv6Column + +column_by_type = {c.ch_type: c for c in [ + NumpyDateColumn, + NumpyFloat32Column, NumpyFloat64Column, + NumpyInt8Column, NumpyInt16Column, NumpyInt32Column, NumpyInt64Column, + NumpyUInt8Column, NumpyUInt16Column, NumpyUInt32Column, NumpyUInt64Column, + NothingColumn, NullColumn, UUIDColumn, + IntervalYearColumn, IntervalMonthColumn, IntervalWeekColumn, + IntervalDayColumn, IntervalHourColumn, IntervalMinuteColumn, + IntervalSecondColumn, IPv4Column, IPv6Column +]} + + +def get_numpy_column_by_spec(spec, column_options): + def create_column_with_options(x): + return get_numpy_column_by_spec(x, column_options) + + if spec == 'String' or spec.startswith('FixedString'): + return create_string_column(spec, column_options) + + elif spec.startswith('Enum'): + return create_enum_column(spec, column_options) + + elif spec.startswith('DateTime'): + return create_numpy_datetime_column(spec, column_options) + + elif spec.startswith('Decimal'): + return create_decimal_column(spec, column_options) + + elif spec.startswith('Array'): + return create_array_column(spec, create_column_with_options) + + elif spec.startswith('Tuple'): + return create_tuple_column(spec, create_column_with_options) + + # elif spec.startswith('Nullable'): + # return create_nullable_column(spec, create_column_with_options) + + elif spec.startswith('LowCardinality'): + return create_numpy_low_cardinality_column(spec, + create_column_with_options) + + elif spec.startswith('SimpleAggregateFunction'): + return create_simple_aggregate_function_column( + spec, create_column_with_options) + + else: + try: + cls = column_by_type[spec] + return cls(**column_options) + + except KeyError as e: + raise errors.UnknownTypeError('Unknown type {}'.format(e.args[0])) diff --git a/clickhouse_driver/columns/numpy/stringcolumn.py b/clickhouse_driver/columns/numpy/stringcolumn.py new file mode 100644 index 
00000000..a4cbebdf --- /dev/null +++ b/clickhouse_driver/columns/numpy/stringcolumn.py @@ -0,0 +1,62 @@ +import numpy as np + +from ... import defines +from .base import NumpyColumn + + +class NumpyStringColumn(NumpyColumn): + dtype = np.dtype('object') + + default_encoding = defines.STRINGS_ENCODING + + def __init__(self, encoding=default_encoding, **kwargs): + self.encoding = encoding + super(NumpyStringColumn, self).__init__(**kwargs) + + def read_items(self, n_items, buf): + return np.array( + buf.read_strings(n_items, encoding=self.encoding), dtype=self.dtype + ) + + +class NumpyByteStringColumn(NumpyColumn): + def read_items(self, n_items, buf): + return np.array(buf.read_strings(n_items), dtype=self.dtype) + + +class NumpyFixedString(NumpyStringColumn): + def __init__(self, length, **kwargs): + self.length = length + super(NumpyFixedString, self).__init__(**kwargs) + + def read_items(self, n_items, buf): + return np.array(buf.read_fixed_strings( + n_items, self.length, encoding=self.encoding + ), dtype=self.dtype) + + +class NumpyByteFixedString(NumpyByteStringColumn): + def __init__(self, length, **kwargs): + self.length = length + super(NumpyByteFixedString, self).__init__(**kwargs) + + def read_items(self, n_items, buf): + return np.array( + buf.read_fixed_strings(n_items, self.length), dtype=self.dtype + ) + + +def create_string_column(spec, column_options): + client_settings = column_options['context'].client_settings + strings_as_bytes = client_settings['strings_as_bytes'] + encoding = client_settings.get( + 'strings_encoding', NumpyStringColumn.default_encoding + ) + + if spec == 'String': + cls = NumpyByteStringColumn if strings_as_bytes else NumpyStringColumn + return cls(encoding=encoding, **column_options) + else: + length = int(spec[12:-1]) + cls = NumpyByteFixedString if strings_as_bytes else NumpyFixedString + return cls(length, encoding=encoding, **column_options) diff --git a/clickhouse_driver/columns/service.py 
b/clickhouse_driver/columns/service.py index fd34b83b..d97d78df 100644 --- a/clickhouse_driver/columns/service.py +++ b/clickhouse_driver/columns/service.py @@ -5,7 +5,7 @@ from .decimalcolumn import create_decimal_column from . import exceptions as column_exceptions from .enumcolumn import create_enum_column -from .floatcolumn import Float32, Float64 +from .floatcolumn import Float32Column, Float64Column from .intcolumn import ( Int8Column, Int16Column, Int32Column, Int64Column, UInt8Column, UInt16Column, UInt32Column, UInt64Column @@ -29,7 +29,7 @@ column_by_type = {c.ch_type: c for c in [ - DateColumn, Float32, Float64, + DateColumn, Float32Column, Float64Column, Int8Column, Int16Column, Int32Column, Int64Column, UInt8Column, UInt16Column, UInt32Column, UInt64Column, NothingColumn, NullColumn, UUIDColumn, @@ -40,6 +40,13 @@ def get_column_by_spec(spec, column_options): + context = column_options['context'] + use_numpy = context.client_settings['use_numpy'] if context else False + + if use_numpy: + from .numpy.service import get_numpy_column_by_spec + return get_numpy_column_by_spec(spec, column_options) + def create_column_with_options(x): return get_column_by_spec(x, column_options) diff --git a/clickhouse_driver/connection.py b/clickhouse_driver/connection.py index d915c902..f19e3210 100644 --- a/clickhouse_driver/connection.py +++ b/clickhouse_driver/connection.py @@ -55,6 +55,21 @@ def __init__(self, name, version_major, version_minor, version_patch, def version_tuple(self): return self.version_major, self.version_minor, self.version_patch + def __repr__(self): + version = '%s.%s.%s' % ( + self.version_major, self.version_minor, self.version_patch + ) + items = [ + ('name', self.name), + ('version', version), + ('revision', self.revision), + ('timezone', self.timezone), + ('display_name', self.display_name) + ] + + params = ', '.join('{}={}'.format(key, value) for key, value in items) + return '<ServerInfo(%s)>' % (params) + class Connection(object): """ diff --git 
a/clickhouse_driver/context.py b/clickhouse_driver/context.py index 4d1861f7..2c4a7bfa 100644 --- a/clickhouse_driver/context.py +++ b/clickhouse_driver/context.py @@ -29,3 +29,8 @@ def client_settings(self): @client_settings.setter def client_settings(self, value): self._client_settings = value.copy() + + def __repr__(self): + return '<Context(server_info=%s, client_settings=%s, settings=%s)>' % ( + self._server_info, self._client_settings, self._settings + ) diff --git a/clickhouse_driver/numpy/__init__.py b/clickhouse_driver/numpy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/clickhouse_driver/numpy/block.py b/clickhouse_driver/numpy/block.py new file mode 100644 index 00000000..ef5d5f3d --- /dev/null +++ b/clickhouse_driver/numpy/block.py @@ -0,0 +1,8 @@ +import numpy as np + +from ..block import ColumnOrientedBlock + + +class NumpyColumnOrientedBlock(ColumnOrientedBlock): + def transposed(self): + return np.transpose(self.data) diff --git a/clickhouse_driver/numpy/result.py b/clickhouse_driver/numpy/result.py new file mode 100644 index 00000000..69535d9e --- /dev/null +++ b/clickhouse_driver/numpy/result.py @@ -0,0 +1,129 @@ +from itertools import chain + +import numpy as np +import pandas as pd +from pandas.api.types import union_categoricals + +from ..progress import Progress +from ..result import QueryResult + + +class NumpyQueryResult(QueryResult): + """ + Stores query result from multiple blocks as numpy arrays. + """ + + def store(self, packet): + block = getattr(packet, 'block', None) + if block is None: + return + + # Header block contains no rows. Pick columns from it. + if block.num_rows: + if self.columnar: + self.data.append(block.get_columns()) + else: + self.data.extend(block.get_rows()) + + elif not self.columns_with_types: + self.columns_with_types = block.columns_with_types + + def get_result(self): + """ + :return: stored query result. 
+ """ + + for packet in self.packet_generator: + self.store(packet) + + if self.columnar: + data = [] + # Transpose to a list of columns, each column is list of chunks + for column_chunks in zip(*self.data): + # Concatenate chunks for each column + if isinstance(column_chunks[0], np.ndarray): + column = np.concatenate(column_chunks) + elif isinstance(column_chunks[0], pd.Categorical): + column = union_categoricals(column_chunks) + else: + column = tuple(chain.from_iterable(column_chunks)) + data.append(column) + else: + data = self.data + + if self.with_column_types: + return data, self.columns_with_types + else: + return data + + +class NumpyProgressQueryResult(NumpyQueryResult): + """ + Stores query result and progress information from multiple blocks. + Provides iteration over query progress. + """ + + def __init__(self, *args, **kwargs): + self.progress_totals = Progress() + + super(NumpyProgressQueryResult, self).__init__(*args, **kwargs) + + def __iter__(self): + return self + + def next(self): + while True: + packet = next(self.packet_generator) + progress_packet = getattr(packet, 'progress', None) + if progress_packet: + self.progress_totals.increment(progress_packet) + return ( + self.progress_totals.rows, self.progress_totals.total_rows + ) + else: + self.store(packet) + + # For Python 3. + __next__ = next + + def get_result(self): + # Read all progress packets. + for _ in self: + pass + + return super(NumpyProgressQueryResult, self).get_result() + + +class NumpyIterQueryResult(object): + """ + Provides iteration over returned data by chunks (streaming by chunks). 
+ """ + + def __init__( + self, packet_generator, + with_column_types=False): + self.packet_generator = packet_generator + self.with_column_types = with_column_types + + self.first_block = True + super(NumpyIterQueryResult, self).__init__() + + def __iter__(self): + return self + + def next(self): + packet = next(self.packet_generator) + block = getattr(packet, 'block', None) + if block is None: + return [] + + if self.first_block and self.with_column_types: + self.first_block = False + rv = [block.columns_with_types] + rv.extend(block.get_rows()) + return rv + else: + return block.get_rows() + + # For Python 3. + __next__ = next diff --git a/clickhouse_driver/result.py b/clickhouse_driver/result.py index 1238015e..fc126569 100644 --- a/clickhouse_driver/result.py +++ b/clickhouse_driver/result.py @@ -66,14 +66,9 @@ class ProgressQueryResult(QueryResult): Provides iteration over query progress. """ - def __init__( - self, packet_generator, - with_column_types=False, columnar=False): + def __init__(self, *args, **kwargs): self.progress_totals = Progress() - - super(ProgressQueryResult, self).__init__( - packet_generator, with_column_types, columnar - ) + super(ProgressQueryResult, self).__init__(*args, **kwargs) def __iter__(self): return self diff --git a/clickhouse_driver/streams/native.py b/clickhouse_driver/streams/native.py index 24c767d3..3dff3217 100644 --- a/clickhouse_driver/streams/native.py +++ b/clickhouse_driver/streams/native.py @@ -75,7 +75,13 @@ def read(self): self.fin) data.append(column) - block = ColumnOrientedBlock( + if self.context.client_settings['use_numpy']: + from ..numpy.block import NumpyColumnOrientedBlock + block_cls = NumpyColumnOrientedBlock + else: + block_cls = ColumnOrientedBlock + + block = block_cls( columns_with_types=list(zip(names, types)), data=data, info=info, diff --git a/docs/features.rst b/docs/features.rst index f3c5c878..08ff4e9c 100644 --- a/docs/features.rst +++ b/docs/features.rst @@ -349,3 +349,63 @@ managers: >>> 
with conn.cursor() as cursor: >>> cursor.execute('SHOW TABLES') >>> print(cursor.fetchall()) + + +Reading into NumPy arrays +------------------------- + +*New in version 0.1.6.* + +Starting from version 0.1.6 package can return columns as NumPy arrays. +Additional packages are required for :ref:`installation-numpy-support`. + + .. code-block:: python + + >>> client = Client('localhost', settings={'use_numpy': True}) + >>> client.execute( + ... 'SELECT * FROM system.numbers LIMIT 10000', + ... columnar=True + ... ) + [array([ 0, 1, 2, ..., 9997, 9998, 9999], dtype=uint64)] + +Inserting using NumPy arrays currently is not supported. You can insert data +without ``use_numpy`` option. + +Supported types: + + * Float32/64 + * [U]Int8/16/32/64 + * Date/DateTime('timezone')/DateTime64('timezone') + * String/FixedString(N) + * LowCardinality(T) + +NumPy arrays are not used when reading nullable columns and columns of +unsupported types. + +Direct loading into NumPy arrays increases performance and lowers memory +requirements on large amounts of rows. + +Direct loading into pandas dataframe is also supported by using +`query_dataframe`: + + .. code-block:: python + + >>> client = Client('localhost', settings={'use_numpy': True}) + >>> client.query_dataframe( + ... 'SELECT number AS x, (number + 100) AS y ' + ... 'FROM system.numbers LIMIT 10000' + ... ) + x y + 0 0 100 + 1 1 101 + 2 2 102 + 3 3 103 + 4 4 104 + ... ... ... + 9995 9995 10095 + 9996 9996 10096 + 9997 9997 10097 + 9998 9998 10098 + 9999 9999 10099 + + [10000 rows x 2 columns] diff --git a/docs/installation.rst b/docs/installation.rst index 0fd36ffb..6a707320 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -77,6 +77,20 @@ Install LZ4 and ZSTD requirements: pip install clickhouse-driver[lz4,zstd] +.. _installation-numpy-support: + +NumPy support +------------- + +You can install additional packages (NumPy and Pandas) if you need NumPy support: + + .. 
code-block:: bash + + pip install clickhouse-driver[numpy] + +NumPy supported versions are limited by ``numpy`` package python support. + + Installation from github ------------------------ diff --git a/setup.py b/setup.py index a0e32a10..cb401599 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,8 @@ else: USE_CYTHON = True +USE_NUMPY = bool(os.getenv('USE_NUMPY', False)) + here = os.path.abspath(os.path.dirname(__file__)) @@ -55,6 +57,17 @@ def read_version(): extensions, compiler_directives={'language_level': '3'} ) +tests_require = [ + 'nose', + 'mock', + 'freezegun', + 'lz4<=3.0.1', + 'zstd', + 'clickhouse-cityhash>=1.0.2.1' +] + +if USE_NUMPY: + tests_require.extend(['numpy', 'pandas']) setup( name='clickhouse-driver', @@ -122,15 +135,9 @@ def read_version(): ext_modules=extensions, extras_require={ 'lz4': ['lz4<=3.0.1', 'clickhouse-cityhash>=1.0.2.1'], - 'zstd': ['zstd', 'clickhouse-cityhash>=1.0.2.1'] + 'zstd': ['zstd', 'clickhouse-cityhash>=1.0.2.1'], + 'numpy': ['numpy>=1.12.0', 'pandas>=0.24.0'] }, test_suite='nose.collector', - tests_require=[ - 'nose', - 'mock', - 'freezegun', - 'lz4<=3.0.1', - 'zstd', - 'clickhouse-cityhash>=1.0.2.1' - ], + tests_require=tests_require ) diff --git a/tests/columns/test_unknown.py b/tests/columns/test_unknown.py index 5889ff42..ffc98589 100644 --- a/tests/columns/test_unknown.py +++ b/tests/columns/test_unknown.py @@ -10,4 +10,4 @@ def test_get_unknown_column(self): with self.assertRaises(errors.UnknownTypeError) as e: get_column_by_spec('Unicorn', {'context': {}}) - self.assertIn('Unicorn', str(e.exception)) + self.assertIn('Unicorn', str(e.exception)) diff --git a/tests/numpy/__init__.py b/tests/numpy/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/numpy/columns/__init__.py b/tests/numpy/columns/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/numpy/columns/test_datetime.py b/tests/numpy/columns/test_datetime.py new file mode 100644 index 00000000..eca60fac --- 
/dev/null +++ b/tests/numpy/columns/test_datetime.py @@ -0,0 +1,505 @@ +from contextlib import contextmanager +from datetime import date, datetime +import os +from time import tzset + +from mock import patch + +try: + import numpy as np +except ImportError: + np = None + +try: + import pandas as pd +except ImportError: + pd = None + +from pytz import timezone, utc, UnknownTimeZoneError +import tzlocal + +from tests.numpy.testcase import NumpyBaseTestCase +from tests.util import require_server_version + + +class BaseDateTimeTestCase(NumpyBaseTestCase): + def setUp(self): + super(BaseDateTimeTestCase, self).setUp() + # TODO: remove common client when inserts will be implemented + self.common_client = self._create_client() + + # Bust tzlocal cache. + try: + tzlocal.unix._cache_tz = None + except AttributeError: + pass + + try: + tzlocal.win32._cache_tz = None + except AttributeError: + pass + + def tearDown(self): + self.common_client.disconnect() + super(BaseDateTimeTestCase, self).tearDown() + + +class DateTimeTestCase(BaseDateTimeTestCase): + def test_datetime_type(self): + query = 'SELECT now()' + + rv = self.client.execute(query, columnar=True) + self.assertIsInstance(rv[0][0], np.datetime64) + + @require_server_version(20, 1, 2) + def test_datetime64_type(self): + query = 'SELECT now64()' + + rv = self.client.execute(query, columnar=True) + self.assertIsInstance(rv[0][0], np.datetime64) + + def test_simple(self): + with self.create_table('a Date, b DateTime'): + data = [(date(2012, 10, 25), datetime(2012, 10, 25, 14, 7, 19))] + self.common_client.execute( + 'INSERT INTO test (a, b) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, '2012-10-25\t2012-10-25 14:07:19\n') + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], np.array(['2012-10-25'], dtype='datetime64[D]') + ) + self.assertArraysEqual( + inserted[1], + np.array(['2012-10-25T14:07:19'], 
dtype='datetime64[ns]') + ) + + def test_handle_errors_from_tzlocal(self): + with patch('tzlocal.get_localzone') as mocked_get_localzone: + mocked_get_localzone.side_effect = UnknownTimeZoneError() + self.client.execute('SELECT now()') + + @require_server_version(20, 1, 2) + def test_datetime64_frac_trunc(self): + with self.create_table('a DateTime64'): + data = [(datetime(2012, 10, 25, 14, 7, 19, 125600), )] + self.common_client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, '2012-10-25 14:07:19.125\n') + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], + np.array(['2012-10-25T14:07:19.125'], dtype='datetime64[ns]') + ) + + @require_server_version(20, 1, 2) + def test_datetime64_explicit_frac(self): + with self.create_table('a DateTime64(1)'): + data = [(datetime(2012, 10, 25, 14, 7, 19, 125600),)] + self.common_client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, '2012-10-25 14:07:19.1\n') + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], + np.array(['2012-10-25T14:07:19.1'], dtype='datetime64[ns]') + ) + + +class DateTimeTimezonesTestCase(BaseDateTimeTestCase): + dt_type = 'DateTime' + + def make_tz_numpy_array(self, dt, tz_name): + return pd.to_datetime(np.array([dt] * 2, dtype='datetime64[ns]')) \ + .tz_localize(tz_name).to_numpy() + + @contextmanager + def patch_env_tz(self, tz_name): + # Although in many cases, changing the TZ environment variable may + # affect the output of functions like localtime() without calling + # tzset(), this behavior should not be relied on. 
+ # https://docs.python.org/3/library/time.html#time.tzset + with patch.dict(os.environ, {'TZ': tz_name}): + tzset() + yield + + tzset() + + # Asia/Kamchatka = UTC+12 + # Asia/Novosibirsk = UTC+7 + # Europe/Moscow = UTC+3 + + # 1500010800 second since epoch in Europe/Moscow. + # 1500000000 second since epoch in UTC. + dt = datetime(2017, 7, 14, 5, 40) + dt_str = '2017-07-14T05:40:00' + dt_tz = timezone('Asia/Kamchatka').localize(dt) + + col_tz_name = 'Asia/Novosibirsk' + col_tz = timezone(col_tz_name) + + # INSERTs and SELECTs must be the same as clickhouse-client's + # if column has no timezone. + + def table_columns(self, with_tz=False): + if not with_tz: + return 'a {}'.format(self.dt_type) + + return "a {}('{}')".format(self.dt_type, self.col_tz_name) + + def test_use_server_timezone(self): + # Insert datetime with timezone UTC + # into column with no timezone + # using server's timezone (Europe/Moscow) + + # Determine server timezone and calculate expected timestamp. + server_tz_name = self.common_client.execute('SELECT timezone()')[0][0] + offset = timezone(server_tz_name).utcoffset(self.dt).total_seconds() + timestamp = 1500010800 - int(offset) + + with self.patch_env_tz('Asia/Novosibirsk'): + with self.create_table(self.table_columns()): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt, )] + ) + + self.emit_cli( + "INSERT INTO test (a) VALUES ('2017-07-14 05:40:00')" + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query) + self.assertEqual(inserted, '{ts}\n{ts}\n'.format(ts=timestamp)) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual( + inserted, + '2017-07-14 05:40:00\n2017-07-14 05:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], + np.array([self.dt_str] * 2, dtype='datetime64[ns]') + ) + + def test_use_client_timezone(self): + # Insert datetime with timezone UTC + # into column with no timezone + # using 
client's timezone Asia/Novosibirsk + + settings = {'use_client_time_zone': True} + + with self.patch_env_tz('Asia/Novosibirsk'): + with self.create_table(self.table_columns()): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt, )], + settings=settings + ) + + self.emit_cli( + "INSERT INTO test (a) VALUES ('2017-07-14 05:40:00')", + use_client_time_zone=1 + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + # 1499985600 = 1500000000 - 4 * 3600 + self.assertEqual(inserted, '1499985600\n1499985600\n') + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + self.assertEqual( + inserted, + '2017-07-14 05:40:00\n2017-07-14 05:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True, + settings=settings) + self.assertArraysEqual( + inserted[0], + np.array([self.dt_str] * 2, dtype='datetime64[ns]') + ) + + # def test_insert_integers(self): + # settings = {'use_client_time_zone': True} + # + # with self.patch_env_tz('Europe/Moscow'): + # with self.create_table(self.table_columns()): + # self.client.execute( + # 'INSERT INTO test (a) VALUES', [(1530211034, )], + # settings=settings + # ) + # + # query = 'SELECT toUInt32(a), a FROM test' + # inserted = self.emit_cli(query, use_client_time_zone=1) + # self.assertEqual(inserted, + # '1530211034\t2018-06-28 21:37:14\n') + # + # def test_insert_integer_bounds(self): + # with self.create_table('a DateTime'): + # self.client.execute( + # 'INSERT INTO test (a) VALUES', + # [(0, ), (1, ), (1500000000, ), (2**32-1, )] + # ) + # + # query = 'SELECT toUInt32(a) FROM test ORDER BY a' + # inserted = self.emit_cli(query) + # self.assertEqual(inserted, '0\n1\n1500000000\n4294967295\n') + + @require_server_version(1, 1, 54337) + def test_datetime_with_timezone_use_server_timezone(self): + # Insert datetime with timezone Asia/Kamchatka + # into column with no timezone + # using server's timezone (Europe/Moscow) + + 
server_tz_name = self.client.execute('SELECT timezone()')[0][0] + offset = timezone(server_tz_name).utcoffset(self.dt) + + with self.patch_env_tz('Asia/Novosibirsk'): + with self.create_table(self.table_columns()): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt_tz, )] + ) + + self.emit_cli( + "INSERT INTO test (a) VALUES " + "(toDateTime('2017-07-14 05:40:00', 'Asia/Kamchatka'))", + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query) + # 1499967600 = 1500000000 - 12 * 3600 + self.assertEqual(inserted, '1499967600\n1499967600\n') + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + + dt = (self.dt_tz.astimezone(utc) + offset).replace(tzinfo=None) + self.assertEqual(inserted, '{dt}\n{dt}\n'.format(dt=dt)) + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], + np.array([dt.isoformat()] * 2, dtype='datetime64[ns]') + ) + + @require_server_version(1, 1, 54337) + def test_datetime_with_timezone_use_client_timezone(self): + # Insert datetime with timezone Asia/Kamchatka + # into column with no timezone + # using client's timezone Asia/Novosibirsk + + settings = {'use_client_time_zone': True} + + with self.patch_env_tz('Asia/Novosibirsk'): + with self.create_table(self.table_columns()): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt_tz, )], + settings=settings + ) + + self.emit_cli( + "INSERT INTO test (a) VALUES " + "(toDateTime('2017-07-14 05:40:00', 'Asia/Kamchatka'))", + use_client_time_zone=1 + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + # 1499967600 = 1500000000 - 12 * 3600 + self.assertEqual(inserted, '1499967600\n1499967600\n') + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + # 2017-07-14 00:40:00 = 2017-07-14 05:40:00 - 05:00:00 + # (Kamchatka - Novosibirsk) + self.assertEqual( + inserted, + '2017-07-14 00:40:00\n2017-07-14 
00:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True, + settings=settings) + dt = datetime(2017, 7, 14, 0, 40) + self.assertArraysEqual( + inserted[0], + np.array([dt.isoformat()] * 2, dtype='datetime64[ns]') + ) + + @require_server_version(1, 1, 54337) + def test_column_use_server_timezone(self): + # Insert datetime with no timezone + # into column with timezone Asia/Novosibirsk + # using server's timezone (Europe/Moscow) + + with self.patch_env_tz('Europe/Moscow'): + with self.create_table(self.table_columns(with_tz=True)): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt, )] + ) + + self.emit_cli( + "INSERT INTO test (a) VALUES ('2017-07-14 05:40:00')" + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query) + # 1499985600 = 1500000000 - 4 * 3600 + self.assertEqual(inserted, '1499985600\n1499985600\n') + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual( + inserted, + '2017-07-14 05:40:00\n2017-07-14 05:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], + self.make_tz_numpy_array(self.dt, self.col_tz_name) + ) + + @require_server_version(1, 1, 54337) + def test_column_use_client_timezone(self): + # Insert datetime with no timezone + # into column with timezone Asia/Novosibirsk + # using client's timezone Europe/Moscow + + settings = {'use_client_time_zone': True} + + with self.patch_env_tz('Europe/Moscow'): + with self.create_table(self.table_columns(with_tz=True)): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt, )], + settings=settings + ) + self.emit_cli( + "INSERT INTO test (a) VALUES ('2017-07-14 05:40:00')", + use_client_time_zone=1 + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + # 1499985600 = 1500000000 - 4 * 3600 + self.assertEqual(inserted, '1499985600\n1499985600\n') + + query = 'SELECT * FROM test' + inserted = 
self.emit_cli(query, use_client_time_zone=1) + self.assertEqual( + inserted, + '2017-07-14 05:40:00\n2017-07-14 05:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True, + settings=settings) + self.assertArraysEqual( + inserted[0], + self.make_tz_numpy_array(self.dt, self.col_tz_name) + ) + + @require_server_version(1, 1, 54337) + def test_datetime_with_timezone_column_use_server_timezone(self): + # Insert datetime with timezone Asia/Kamchatka + # into column with timezone Asia/Novosibirsk + # using server's timezone (Europe/Moscow) + + with self.patch_env_tz('Europe/Moscow'): + with self.create_table(self.table_columns(with_tz=True)): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt_tz, )] + ) + + self.emit_cli( + "INSERT INTO test (a) VALUES " + "(toDateTime('2017-07-14 05:40:00', 'Asia/Kamchatka'))", + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query) + # 1499967600 = 1500000000 - 12 * 3600 + self.assertEqual(inserted, '1499967600\n1499967600\n') + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + # 2017-07-14 00:40:00 = 2017-07-14 05:40:00 - 05:00:00 + # (Kamchatka - Novosibirsk) + self.assertEqual( + inserted, + '2017-07-14 00:40:00\n2017-07-14 00:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True) + dt = datetime(2017, 7, 14, 0, 40) + self.assertArraysEqual( + inserted[0], self.make_tz_numpy_array(dt, self.col_tz_name) + ) + + @require_server_version(1, 1, 54337) + def test_datetime_with_timezone_column_use_client_timezone(self): + # Insert datetime with timezone Asia/Kamchatka + # into column with timezone Asia/Novosibirsk + # using client's timezone (Europe/Moscow) + + settings = {'use_client_time_zone': True} + + with self.patch_env_tz('Europe/Moscow'): + with self.create_table(self.table_columns(with_tz=True)): + self.common_client.execute( + 'INSERT INTO test (a) VALUES', [(self.dt_tz, )], + settings=settings + ) + + self.emit_cli( + "INSERT INTO test 
(a) VALUES " + "(toDateTime('2017-07-14 05:40:00', 'Asia/Kamchatka'))", + use_client_time_zone=1 + ) + + query = 'SELECT toInt32(a) FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + # 1499967600 = 1500000000 - 12 * 3600 + self.assertEqual(inserted, '1499967600\n1499967600\n') + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query, use_client_time_zone=1) + # 2017-07-14 00:40:00 = 2017-07-14 05:40:00 - 05:00:00 + # (Kamchatka - Novosibirsk) + self.assertEqual( + inserted, + '2017-07-14 00:40:00\n2017-07-14 00:40:00\n' + ) + + inserted = self.client.execute(query, columnar=True, + settings=settings) + dt = datetime(2017, 7, 14, 0, 40) + self.assertArraysEqual( + inserted[0], self.make_tz_numpy_array(dt, self.col_tz_name) + ) + + +class DateTime64TimezonesTestCase(DateTimeTimezonesTestCase): + dt_type = 'DateTime64' + required_server_version = (20, 1, 2) + + def table_columns(self, with_tz=False): + if not with_tz: + return 'a {}(0)'.format(self.dt_type) + + return "a {}(0, '{}')".format(self.dt_type, self.col_tz_name) diff --git a/tests/numpy/columns/test_float.py b/tests/numpy/columns/test_float.py new file mode 100644 index 00000000..6c1d334a --- /dev/null +++ b/tests/numpy/columns/test_float.py @@ -0,0 +1,29 @@ +try: + import numpy as np +except ImportError: + np = None + +from tests.numpy.testcase import NumpyBaseTestCase + + +class FloatTestCase(NumpyBaseTestCase): + n = 10 + + def check_column(self, rv, col_type): + self.assertArraysEqual(rv[0], np.array(range(self.n))) + self.assertIsInstance(rv[0][0], (col_type, )) + + def get_query(self, ch_type): + query = 'SELECT CAST(number AS {}) FROM numbers({})'.format( + ch_type, self.n + ) + + return self.client.execute(query, columnar=True) + + def test_float32(self): + rv = self.get_query('Float32') + self.check_column(rv, np.float32) + + def test_float64(self): + rv = self.get_query('Float64') + self.check_column(rv, np.float64) diff --git a/tests/numpy/columns/test_int.py 
b/tests/numpy/columns/test_int.py new file mode 100644 index 00000000..01408ba7 --- /dev/null +++ b/tests/numpy/columns/test_int.py @@ -0,0 +1,53 @@ +try: + import numpy as np +except ImportError: + np = None + +from tests.numpy.testcase import NumpyBaseTestCase + + +class IntTestCase(NumpyBaseTestCase): + n = 10 + + def check_column(self, rv, col_type): + self.assertArraysEqual(rv[0], np.array(range(self.n))) + self.assertIsInstance(rv[0][0], (col_type, )) + + def get_query(self, ch_type): + query = 'SELECT CAST(number AS {}) FROM numbers({})'.format( + ch_type, self.n + ) + + return self.client.execute(query, columnar=True) + + def test_int8(self): + rv = self.get_query('Int8') + self.check_column(rv, np.int8) + + def test_int16(self): + rv = self.get_query('Int16') + self.check_column(rv, np.int16) + + def test_int32(self): + rv = self.get_query('Int32') + self.check_column(rv, np.int32) + + def test_int64(self): + rv = self.get_query('Int64') + self.check_column(rv, np.int64) + + def test_uint8(self): + rv = self.get_query('UInt8') + self.check_column(rv, np.uint8) + + def test_uint16(self): + rv = self.get_query('UInt16') + self.check_column(rv, np.uint16) + + def test_uint32(self): + rv = self.get_query('UInt32') + self.check_column(rv, np.uint32) + + def test_uint64(self): + rv = self.get_query('UInt64') + self.check_column(rv, np.uint64) diff --git a/tests/numpy/columns/test_low_cardinality.py b/tests/numpy/columns/test_low_cardinality.py new file mode 100644 index 00000000..ddb54611 --- /dev/null +++ b/tests/numpy/columns/test_low_cardinality.py @@ -0,0 +1,186 @@ +try: + import numpy as np +except ImportError: + np = None + +from tests.numpy.testcase import NumpyBaseTestCase + +from datetime import date, timedelta +# from decimal import Decimal + + +class LowCardinalityTestCase(NumpyBaseTestCase): + required_server_version = (19, 3, 3) + stable_support_version = (19, 9, 2) + + def setUp(self): + super(LowCardinalityTestCase, self).setUp() + # TODO: remove 
common client when inserts will be implemented + self.common_client = self._create_client() + + def tearDown(self): + self.common_client.disconnect() + super(LowCardinalityTestCase, self).tearDown() + + def cli_client_kwargs(self): + if self.server_version >= self.stable_support_version: + return {'allow_suspicious_low_cardinality_types': 1} + + def test_uint8(self): + with self.create_table('a LowCardinality(UInt8)'): + data = [(x, ) for x in range(255)] + self.common_client.execute('INSERT INTO test (a) VALUES', data) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual( + inserted, + '\n'.join(str(x[0]) for x in data) + '\n' + ) + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual(inserted[0], np.array(range(255))) + + def test_int8(self): + with self.create_table('a LowCardinality(Int8)'): + data = [(x - 127, ) for x in range(255)] + self.common_client.execute('INSERT INTO test (a) VALUES', data) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual( + inserted, + '\n'.join(str(x[0]) for x in data) + '\n' + + ) + + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], np.array([x - 127 for x in range(255)]) + ) + + # def test_nullable_int8(self): + # with self.create_table('a LowCardinality(Nullable(Int8))'): + # data = [(None, ), (-1, ), (0, ), (1, ), (None, )] + # self.client.execute('INSERT INTO test (a) VALUES', data) + # + # query = 'SELECT * FROM test' + # inserted = self.emit_cli(query) + # self.assertEqual(inserted, '\\N\n-1\n0\n1\n\\N\n') + # + # inserted = self.client.execute(query) + # self.assertEqual(inserted, data) + + def test_date(self): + with self.create_table('a LowCardinality(Date)'): + start = date(1970, 1, 1) + data = [(start + timedelta(x), ) for x in range(300)] + self.common_client.execute('INSERT INTO test (a) VALUES', data) + + query = 'SELECT * FROM test' + inserted = self.client.execute(query, 
columnar=True) + self.assertArraysEqual( + inserted[0], np.array(list(range(300)), dtype='datetime64[D]') + ) + + def test_float(self): + with self.create_table('a LowCardinality(Float)'): + data = [(float(x),) for x in range(300)] + self.common_client.execute('INSERT INTO test (a) VALUES', data) + + query = 'SELECT * FROM test' + inserted = self.client.execute(query, columnar=True) + self.assertArraysEqual( + inserted[0], np.array([float(x) for x in range(300)]) + ) + + # def test_decimal(self): + # with self.create_table('a LowCardinality(Float)'): + # data = [(Decimal(x),) for x in range(300)] + # self.client.execute('INSERT INTO test (a) VALUES', data) + # + # query = 'SELECT * FROM test' + # inserted = self.client.execute(query) + # self.assertEqual(inserted, data) + # + # def test_array(self): + # with self.create_table('a Array(LowCardinality(Int16))'): + # data = [([100, 500], )] + # self.client.execute('INSERT INTO test (a) VALUES', data) + # + # query = 'SELECT * FROM test' + # inserted = self.emit_cli(query) + # self.assertEqual(inserted, '[100,500]\n') + # + # inserted = self.client.execute(query) + # self.assertEqual(inserted, data) + # + # def test_empty_array(self): + # with self.create_table('a Array(LowCardinality(Int16))'): + # data = [([], )] + # self.client.execute('INSERT INTO test (a) VALUES', data) + # + # query = 'SELECT * FROM test' + # inserted = self.emit_cli(query) + # self.assertEqual(inserted, '[]\n') + # + # inserted = self.client.execute(query) + # self.assertEqual(inserted, data) + # + def test_string(self): + with self.create_table('a LowCardinality(String)'): + data = [ + ('test', ), ('low', ), ('cardinality', ), + ('test', ), ('test', ), ('', ) + ] + self.common_client.execute('INSERT INTO test (a) VALUES', data) + + query = 'SELECT * FROM test' + inserted = self.emit_cli(query) + self.assertEqual( + inserted, + 'test\nlow\ncardinality\ntest\ntest\n\n' + ) + + inserted = self.client.execute(query, columnar=True) + 
self.assertArraysEqual(inserted[0], list(list(zip(*data))[0])) + + # def test_fixed_string(self): + # with self.create_table('a LowCardinality(FixedString(12))'): + # data = [ + # ('test', ), ('low', ), ('cardinality', ), + # ('test', ), ('test', ), ('', ) + # ] + # self.client.execute('INSERT INTO test (a) VALUES', data) + # + # query = 'SELECT * FROM test' + # inserted = self.emit_cli(query) + # self.assertEqual( + # inserted, + # 'test\\0\\0\\0\\0\\0\\0\\0\\0\n' + # 'low\\0\\0\\0\\0\\0\\0\\0\\0\\0\n' + # 'cardinality\\0\n' + # 'test\\0\\0\\0\\0\\0\\0\\0\\0\n' + # 'test\\0\\0\\0\\0\\0\\0\\0\\0\n' + # '\\0\\0\\0\\0\\0\\0\\0\\0\\0\\0\\0\\0\n' + # ) + # + # inserted = self.client.execute(query) + # self.assertEqual(inserted, data) + # + # def test_nullable_string(self): + # with self.create_table('a LowCardinality(Nullable(String))'): + # data = [ + # ('test', ), ('', ), (None, ) + # ] + # self.client.execute('INSERT INTO test (a) VALUES', data) + # + # query = 'SELECT * FROM test' + # inserted = self.emit_cli(query) + # self.assertEqual( + # inserted, + # 'test\n\n\\N\n' + # ) + # + # inserted = self.client.execute(query) + # self.assertEqual(inserted, data) diff --git a/tests/numpy/columns/test_other.py b/tests/numpy/columns/test_other.py new file mode 100644 index 00000000..540f3253 --- /dev/null +++ b/tests/numpy/columns/test_other.py @@ -0,0 +1,44 @@ +from clickhouse_driver import errors + +try: + from clickhouse_driver.columns.numpy.service import \ + get_numpy_column_by_spec +except ImportError: + get_numpy_column_by_spec = None + +from clickhouse_driver.context import Context + +from tests.numpy.testcase import NumpyBaseTestCase + + +class OtherColumnsTestCase(NumpyBaseTestCase): + def get_column(self, spec): + ctx = Context() + ctx.client_settings = {'strings_as_bytes': False} + return get_numpy_column_by_spec(spec, {'context': ctx}) + + def test_enum(self): + col = self.get_column("Enum8('hello' = 1, 'world' = 2)") + self.assertIsNotNone(col) + + def 
test_decimal(self): + col = self.get_column('Decimal(8, 4)') + self.assertIsNotNone(col) + + def test_array(self): + col = self.get_column('Array(String)') + self.assertIsNotNone(col) + + def test_tuple(self): + col = self.get_column('Tuple(String)') + self.assertIsNotNone(col) + + def test_simple_aggregation_function(self): + col = self.get_column('SimpleAggregateFunction(any, Int32)') + self.assertIsNotNone(col) + + def test_get_unknown_column(self): + with self.assertRaises(errors.UnknownTypeError) as e: + self.get_column('Unicorn') + + self.assertIn('Unicorn', str(e.exception)) diff --git a/tests/numpy/columns/test_string.py b/tests/numpy/columns/test_string.py new file mode 100644 index 00000000..60d4d4e9 --- /dev/null +++ b/tests/numpy/columns/test_string.py @@ -0,0 +1,52 @@ +try: + import numpy as np +except ImportError: + np = None + +from tests.numpy.testcase import NumpyBaseTestCase + + +class StringTestCase(NumpyBaseTestCase): + def test_string(self): + query = "SELECT arrayJoin(splitByChar(',', 'a,b,c')) AS x" + rv = self.client.execute(query, columnar=True) + + self.assertArraysEqual(rv[0], np.array(['a', 'b', 'c'])) + self.assertIsInstance(rv[0][0], (object, )) + + +class ByteStringTestCase(NumpyBaseTestCase): + client_kwargs = {'settings': {'strings_as_bytes': True, 'use_numpy': True}} + + def test_string(self): + query = "SELECT arrayJoin(splitByChar(',', 'a,b,c')) AS x" + rv = self.client.execute(query, columnar=True) + + self.assertArraysEqual(rv[0], np.array([b'a', b'b', b'c'])) + self.assertIsInstance(rv[0][0], (object, )) + + +class FixedStringTestCase(NumpyBaseTestCase): + def test_string(self): + query = ( + "SELECT CAST(arrayJoin(splitByChar(',', 'a,b,c')) " + "AS FixedString(2)) AS x" + ) + rv = self.client.execute(query, columnar=True) + + self.assertArraysEqual(rv[0], np.array(['a', 'b', 'c'])) + self.assertIsInstance(rv[0][0], (object, )) + + +class ByteFixedStringTestCase(NumpyBaseTestCase): + client_kwargs = {'settings': 
{'strings_as_bytes': True, 'use_numpy': True}} + + def test_string(self): + query = ( + "SELECT CAST(arrayJoin(splitByChar(',', 'a,b,c')) " + "AS FixedString(3)) AS x" + ) + rv = self.client.execute(query, columnar=True) + + self.assertArraysEqual(rv[0], np.array([b'a', b'b', b'c'])) + self.assertIsInstance(rv[0][0], (object, )) diff --git a/tests/numpy/test_generic.py b/tests/numpy/test_generic.py new file mode 100644 index 00000000..3e0fb440 --- /dev/null +++ b/tests/numpy/test_generic.py @@ -0,0 +1,137 @@ +import types + +try: + import numpy as np + import pandas as pd +except ImportError: + np = None + pd = None + +from tests.testcase import BaseTestCase +from tests.numpy.testcase import NumpyBaseTestCase + + +class GenericTestCase(NumpyBaseTestCase): + n = 10 + + def test_columnar(self): + rv = self.client.execute( + 'SELECT number FROM numbers({})'.format(self.n), columnar=True + ) + + self.assertEqual(len(rv), 1) + self.assertIsInstance(rv[0], (np.ndarray, )) + + def test_rowwise(self): + rv = self.client.execute( + 'SELECT number FROM numbers({})'.format(self.n) + ) + + self.assertEqual(len(rv), self.n) + self.assertIsInstance(rv[0], (np.ndarray, )) + + def test_insert_not_supported(self): + data = [(300,)] + + with self.create_table('a Int32'): + with self.assertRaises(RuntimeError) as e: + self.client.execute( + 'INSERT INTO test (a) VALUES', data + ) + + self.assertEqual('Write is not implemented', str(e.exception)) + + def test_with_column_types(self): + rv = self.client.execute( + 'SELECT CAST(2 AS Int32) AS x', with_column_types=True + ) + + self.assertEqual(rv, ([(2, )], [('x', 'Int32')])) + + +class NumpyProgressTestCase(NumpyBaseTestCase): + def test_select_with_progress(self): + progress = self.client.execute_with_progress('SELECT 2') + self.assertEqual( + list(progress), + [(1, 0), (1, 0)] if self.server_version > (20,) else [(1, 0)] + ) + self.assertEqual(progress.get_result(), [(2,)]) + self.assertTrue(self.client.connection.connected) + + def 
test_select_with_progress_no_progress_obtaining(self): + progress = self.client.execute_with_progress('SELECT 2') + self.assertEqual(progress.get_result(), [(2,)]) + + +class NumpyIteratorTestCase(NumpyBaseTestCase): + def test_select_with_iter(self): + result = self.client.execute_iter( + 'SELECT number FROM system.numbers LIMIT 10' + ) + self.assertIsInstance(result, types.GeneratorType) + + self.assertEqual(list(result), list(zip(range(10)))) + self.assertEqual(list(result), []) + + def test_select_with_iter_with_column_types(self): + result = self.client.execute_iter( + 'SELECT CAST(number AS UInt32) as number ' + 'FROM system.numbers LIMIT 10', + with_column_types=True + ) + self.assertIsInstance(result, types.GeneratorType) + + self.assertEqual( + list(result), + [[('number', 'UInt32')]] + list(zip(range(10))) + ) + self.assertEqual(list(result), []) + + +class QueryDataFrameTestCase(NumpyBaseTestCase): + def test_simple(self): + df = self.client.query_dataframe( + 'SELECT CAST(number AS Int64) AS x FROM system.numbers LIMIT 100' + ) + + self.assertTrue(df.equals(pd.DataFrame({'x': range(100)}))) + + def test_replace_whitespace_in_column_names(self): + df = self.client.query_dataframe( + 'SELECT number AS "test me" FROM system.numbers LIMIT 100' + ) + + self.assertIn('test_me', df) + + +class NoNumPyTestCase(BaseTestCase): + def setUp(self): + super(NoNumPyTestCase, self).setUp() + + try: + import numpy # noqa: F401 + import pandas # noqa: F401 + except Exception: + pass + + else: + self.skipTest('NumPy extras are installed') + + def test_runtime_error_without_numpy(self): + with self.assertRaises(RuntimeError) as e: + with self.created_client(settings={'use_numpy': True}) as client: + client.execute('SELECT 1') + + self.assertEqual( + 'Extras for NumPy must be installed', str(e.exception) + ) + + def test_query_dataframe(self): + with self.assertRaises(RuntimeError) as e: + with self.created_client(settings={'use_numpy': True}) as client: + 
client.query_dataframe('SELECT 1 AS x') + + self.assertEqual( + 'Extras for NumPy must be installed', str(e.exception) + ) diff --git a/tests/numpy/testcase.py b/tests/numpy/testcase.py new file mode 100644 index 00000000..aecd257b --- /dev/null +++ b/tests/numpy/testcase.py @@ -0,0 +1,15 @@ +from tests.testcase import BaseTestCase + + +class NumpyBaseTestCase(BaseTestCase): + client_kwargs = {'settings': {'use_numpy': True}} + + def setUp(self): + try: + super(NumpyBaseTestCase, self).setUp() + except RuntimeError as e: + if 'NumPy' in str(e): + self.skipTest('Numpy package is not installed') + + def assertArraysEqual(self, first, second): + return self.assertTrue((first == second).all())