From d673f1b68bf11050aaf3bc998a199e4594591986 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 16 Mar 2023 09:31:31 -0700 Subject: [PATCH 01/19] implemented query class --- google/cloud/bigtable/read_rows_query.py | 210 +++++++++++++++++++++-- 1 file changed, 198 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 64583b2d7..3f4ef1ebb 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -13,36 +13,162 @@ # limitations under the License. # from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any +from .row_response import row_key +from dataclasses import dataclass +from google.cloud.bigtable.row_filters import RowFilter if TYPE_CHECKING: - from google.cloud.bigtable.row_filters import RowFilter from google.cloud.bigtable import RowKeySamples +@dataclass +class _RangePoint: + # model class for a point in a row range + key: row_key + is_inclusive: bool + + class ReadRowsQuery: """ Class to encapsulate details of a read row request """ def __init__( - self, row_keys: list[str | bytes] | str | bytes | None = None, limit=None + self, + row_keys: list[str | bytes] | str | bytes | None = None, + limit: int | None = None, + row_filter: RowFilter | dict[str, Any] | None = None, ): - pass + """ + Create a new ReadRowsQuery - def set_limit(self, limit: int) -> ReadRowsQuery: - raise NotImplementedError + Args: + - row_keys: a list of row keys to include in the query + - limit: the maximum number of rows to return. None or 0 means no limit + default: None (no limit) + - row_filter: a RowFilter to apply to the query + """ + self.row_keys: set[bytes] = set() + self.row_ranges: list[tuple[_RangePoint | None, _RangePoint | None]] = [] + if row_keys: + self.add_rows(row_keys) + self.limit: int | None = limit + self.filter: RowFilter | dict[str, Any] = row_filter - def set_filter(self, filter: "RowFilter") -> ReadRowsQuery: - raise NotImplementedError + def set_limit(self, new_limit: int | None): + """ + Set the maximum number of rows to return by this query. - def add_rows(self, row_id_list: list[str]) -> ReadRowsQuery: - raise NotImplementedError + None or 0 means no limit + + Args: + - new_limit: the new limit to apply to this query + Returns: + - a reference to this query for chaining + Raises: + - ValueError if new_limit is < 0 + """ + if new_limit is not None and new_limit < 0: + raise ValueError("limit must be >= 0") + self._limit = new_limit + return self + + def set_filter( + self, row_filter: RowFilter | dict[str, Any] | None + ) -> ReadRowsQuery: + """ + Set a RowFilter to apply to this query + + Args: + - row_filter: a RowFilter to apply to this query + Can be a RowFilter object or a dict representation + Returns: + - a reference to this query for chaining + """ + if not ( + isinstance(row_filter, dict) + or isinstance(row_filter, RowFilter) + or row_filter is None + ): + raise ValueError( + "row_filter must be a RowFilter or corresponding dict representation" + ) + self._filter = row_filter + return self + + def add_rows(self, row_keys: list[str | bytes] | str | bytes) -> ReadRowsQuery: + """ + Add a list of row keys to this query + + Args: + - row_keys: a list of row keys to add to this query + Returns: + - a reference to this query for chaining + Raises: + - ValueError if an input is not a string or bytes + """ + if not isinstance(row_keys, list): + row_keys = [row_keys] + update_set = set() + for k in row_keys: + if isinstance(k, str): + k = k.encode() + elif not isinstance(k, bytes): + raise ValueError("row_keys must be strings or bytes") + update_set.add(k) + self.row_keys.update(update_set) + return self def add_range( - self, start_key: str | bytes | None = None, end_key: str | bytes | None = None + self, + start_key: str | bytes | None = None, + end_key: str | bytes | None = None, + start_is_inclusive: bool | None = None, + end_is_inclusive: bool | None = None, ) -> ReadRowsQuery: - raise NotImplementedError + """ + Add a range of row keys to this query. + + Args: + - start_key: the start of the range + if None, start_key is interpreted as the empty string, inclusive + - end_key: the end of the range + if None, end_key is interpreted as the infinite row key, exclusive + - start_is_inclusive: if True, the start key is included in the range + defaults to True if None. Must not be included if start_key is None + - end_is_inclusive: if True, the end key is included in the range + defaults to False if None. Must not be included if end_key is None + """ + # check for invalid combinations of arguments + if start_is_inclusive is None: + start_is_inclusive = True + elif start_key is None: + raise ValueError( + "start_is_inclusive must not be included if start_key is None" + ) + if end_is_inclusive is None: + end_is_inclusive = False + elif end_key is None: + raise ValueError("end_is_inclusive must not be included if end_key is None") + # ensure that start_key and end_key are bytes + if isinstance(start_key, str): + start_key = start_key.encode() + elif start_key is not None and not isinstance(start_key, bytes): + raise ValueError("start_key must be a string or bytes") + if isinstance(end_key, str): + end_key = end_key.encode() + elif end_key is not None and not isinstance(end_key, bytes): + raise ValueError("end_key must be a string or bytes") + + start_pt = ( + _RangePoint(start_key, start_is_inclusive) + if start_key is not None + else None + ) + end_pt = _RangePoint(end_key, end_is_inclusive) if end_key is not None else None + self.row_ranges.append((start_pt, end_pt)) + return self def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]: """ @@ -54,3 +180,63 @@ def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery query (if possible) """ raise NotImplementedError + + def to_dict(self) -> dict[str, Any]: + """ + Convert this query into a dictionary that can be used to construct a + ReadRowsRequest protobuf + """ + ranges = [] + for start, end in self.row_ranges: + new_range = {} + if start is not None: + key = "start_key_closed" if start.is_inclusive else "start_key_open" + new_range[key] = start.key + if end is not None: + key = "end_key_closed" if end.is_inclusive else "end_key_open" + new_range[key] = end.key + ranges.append(new_range) + row_keys = list(self.row_keys) + row_keys.sort() + row_set = {"row_keys": row_keys, "row_ranges": ranges} + final_dict: dict[str, Any] = { + "rows": row_set, + } + dict_filter = ( + self.filter.to_dict() if isinstance(self.filter, RowFilter) else self.filter + ) + if dict_filter: + final_dict["filter"] = dict_filter + if self.limit is not None: + final_dict["rows_limit"] = self.limit + return final_dict + + # Support limit and filter as properties + + @property + def limit(self) -> int | None: + """ + Getter implementation for limit property + """ + return self._limit + + @limit.setter + def limit(self, new_limit: int | None): + """ + Setter implementation for limit property + """ + self.set_limit(new_limit) + + @property + def filter(self): + """ + Getter implemntation for filter property + """ + return self._filter + + @filter.setter + def filter(self, row_filter: RowFilter | dict[str, Any] | None): + """ + Setter implementation for filter property + """ + self.set_filter(row_filter) From 5fe8c778948c8db5b2f134049900ef1b1013b0dc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 16 Mar 2023 09:31:56 -0700 Subject: [PATCH 02/19] added query tests --- tests/unit/test_read_rows_query.py | 284 +++++++++++++++++++++++++++++ 1 file changed, 284 insertions(+) create mode 100644 tests/unit/test_read_rows_query.py diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py new file mode 100644 index 000000000..eb924edaa --- /dev/null +++ b/tests/unit/test_read_rows_query.py @@ -0,0 +1,284 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +TEST_ROWS = [ + "row_key_1", + b"row_key_2", +] + + +class TestReadRowsQuery(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.read_rows_query import ReadRowsQuery + + return ReadRowsQuery + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor_defaults(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + self.assertEqual(query.row_ranges, []) + self.assertEqual(query.filter, None) + self.assertEqual(query.limit, None) + + def test_ctor_explicit(self): + from google.cloud.bigtable.row_filters import RowFilterChain + + filter_ = RowFilterChain() + query = self._make_one(["row_key_1", "row_key_2"], limit=10, row_filter=filter_) + self.assertEqual(len(query.row_keys), 2) + self.assertIn("row_key_1".encode(), query.row_keys) + self.assertIn("row_key_2".encode(), query.row_keys) + self.assertEqual(query.row_ranges, []) + self.assertEqual(query.filter, filter_) + self.assertEqual(query.limit, 10) + + def test_ctor_invalid_limit(self): + with self.assertRaises(ValueError): + self._make_one(limit=-1) + + def test_set_filter(self): + from google.cloud.bigtable.row_filters import RowFilterChain + + filter1 = RowFilterChain() + query = self._make_one() + self.assertEqual(query.filter, None) + result = query.set_filter(filter1) + self.assertEqual(query.filter, filter1) + self.assertEqual(result, query) + filter2 = RowFilterChain() + result = query.set_filter(filter2) + self.assertEqual(query.filter, filter2) + result = query.set_filter(None) + self.assertEqual(query.filter, None) + self.assertEqual(result, query) + query.filter = RowFilterChain() + self.assertEqual(query.filter, RowFilterChain()) + with self.assertRaises(ValueError): + query.filter = 1 + + def test_set_filter_dict(self): + from google.cloud.bigtable.row_filters import RowSampleFilter + from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest + + filter1 = RowSampleFilter(0.5) + filter1_dict = filter1.to_dict() + query = self._make_one() + self.assertEqual(query.filter, None) + result = query.set_filter(filter1_dict) + self.assertEqual(query.filter, filter1_dict) + self.assertEqual(result, query) + output = query.to_dict() + self.assertEqual(output["filter"], filter1_dict) + proto_output = ReadRowsRequest(**output) + self.assertEqual(proto_output.filter, filter1.to_pb()) + + query.filter = None + self.assertEqual(query.filter, None) + + def test_set_limit(self): + query = self._make_one() + self.assertEqual(query.limit, None) + result = query.set_limit(10) + self.assertEqual(query.limit, 10) + self.assertEqual(result, query) + query.limit = 9 + self.assertEqual(query.limit, 9) + result = query.set_limit(0) + self.assertEqual(query.limit, 0) + self.assertEqual(result, query) + with self.assertRaises(ValueError): + query.set_limit(-1) + with self.assertRaises(ValueError): + query.limit = -100 + + def test_add_rows_str(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + input_str = "test_row" + result = query.add_rows(input_str) + self.assertEqual(len(query.row_keys), 1) + self.assertIn(input_str.encode(), query.row_keys) + self.assertEqual(result, query) + input_str2 = "test_row2" + result = query.add_rows(input_str2) + self.assertEqual(len(query.row_keys), 2) + self.assertIn(input_str.encode(), query.row_keys) + self.assertIn(input_str2.encode(), query.row_keys) + self.assertEqual(result, query) + + def test_add_rows_bytes(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + input_bytes = b"test_row" + result = query.add_rows(input_bytes) + self.assertEqual(len(query.row_keys), 1) + self.assertIn(input_bytes, query.row_keys) + self.assertEqual(result, query) + input_bytes2 = b"test_row2" + result = query.add_rows(input_bytes2) + self.assertEqual(len(query.row_keys), 2) + self.assertIn(input_bytes, query.row_keys) + self.assertIn(input_bytes2, query.row_keys) + self.assertEqual(result, query) + + def test_add_rows_batch(self): + query = self._make_one() + self.assertEqual(query.row_keys, set()) + input_batch = ["test_row", b"test_row2", "test_row3"] + result = query.add_rows(input_batch) + self.assertEqual(len(query.row_keys), 3) + self.assertIn(b"test_row", query.row_keys) + self.assertIn(b"test_row2", query.row_keys) + self.assertIn(b"test_row3", query.row_keys) + self.assertEqual(result, query) + # test adding another batch + query.add_rows(["test_row4", b"test_row5"]) + self.assertEqual(len(query.row_keys), 5) + self.assertIn(input_batch[0].encode(), query.row_keys) + self.assertIn(input_batch[1], query.row_keys) + self.assertIn(input_batch[2].encode(), query.row_keys) + self.assertIn(b"test_row4", query.row_keys) + self.assertIn(b"test_row5", query.row_keys) + + def test_add_rows_invalid(self): + query = self._make_one() + with self.assertRaises(ValueError): + query.add_rows(1) + with self.assertRaises(ValueError): + query.add_rows(["s", 0]) + + def test_duplicate_rows(self): + # should only hold one of each input key + key_1 = b"test_row" + key_2 = b"test_row2" + query = self._make_one(row_keys=[key_1, key_1, key_2]) + self.assertEqual(len(query.row_keys), 2) + self.assertIn(key_1, query.row_keys) + self.assertIn(key_2, query.row_keys) + key_3 = "test_row3" + query.add_rows([key_3 for _ in range(10)]) + self.assertEqual(len(query.row_keys), 3) + + def test_add_range(self): + # test with start and end keys + query = self._make_one() + self.assertEqual(query.row_ranges, []) + result = query.add_range("test_row", "test_row2") + self.assertEqual(len(query.row_ranges), 1) + self.assertEqual(query.row_ranges[0][0].key, "test_row".encode()) + self.assertEqual(query.row_ranges[0][1].key, "test_row2".encode()) + self.assertEqual(query.row_ranges[0][0].is_inclusive, True) + self.assertEqual(query.row_ranges[0][1].is_inclusive, False) + self.assertEqual(result, query) + # test with start key only + result = query.add_range("test_row3") + self.assertEqual(len(query.row_ranges), 2) + self.assertEqual(query.row_ranges[1][0].key, "test_row3".encode()) + self.assertEqual(query.row_ranges[1][1], None) + self.assertEqual(result, query) + # test with end key only + result = query.add_range(start_key=None, end_key="test_row5") + self.assertEqual(len(query.row_ranges), 3) + self.assertEqual(query.row_ranges[2][0], None) + self.assertEqual(query.row_ranges[2][1].key, "test_row5".encode()) + self.assertEqual(query.row_ranges[2][1].is_inclusive, False) + # test with start and end keys and inclusive flags + result = query.add_range(b"test_row6", b"test_row7", False, True) + self.assertEqual(len(query.row_ranges), 4) + self.assertEqual(query.row_ranges[3][0].key, b"test_row6") + self.assertEqual(query.row_ranges[3][1].key, b"test_row7") + self.assertEqual(query.row_ranges[3][0].is_inclusive, False) + self.assertEqual(query.row_ranges[3][1].is_inclusive, True) + # test with nothing passed + result = query.add_range() + self.assertEqual(len(query.row_ranges), 5) + self.assertEqual(query.row_ranges[4][0], None) + self.assertEqual(query.row_ranges[4][1], None) + # test with inclusive flags only + with self.assertRaises(ValueError): + query.add_range(start_is_inclusive=True, end_is_inclusive=True) + with self.assertRaises(ValueError): + query.add_range(start_is_inclusive=False, end_is_inclusive=False) + with self.assertRaises(ValueError): + query.add_range(start_is_inclusive=False) + with self.assertRaises(ValueError): + query.add_range(end_is_inclusive=True) + + def test_to_dict_rows_default(self): + # dictionary should be in rowset proto format + from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest + + query = self._make_one() + output = query.to_dict() + self.assertTrue(isinstance(output, dict)) + self.assertEqual(len(output.keys()), 1) + expected = {"rows": {"row_keys": [], "row_ranges": []}} + self.assertEqual(output, expected) + + request_proto = ReadRowsRequest(**output) + self.assertEqual(request_proto.rows.row_keys, []) + self.assertEqual(request_proto.rows.row_ranges, []) + self.assertFalse(request_proto.filter) + self.assertEqual(request_proto.rows_limit, 0) + + def test_to_dict_rows_populated(self): + # dictionary should be in rowset proto format + from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest + from google.cloud.bigtable.row_filters import PassAllFilter + + row_filter = PassAllFilter(False) + query = self._make_one(limit=100, row_filter=row_filter) + query.add_range("test_row", "test_row2") + query.add_range("test_row3") + query.add_range(start_key=None, end_key="test_row5") + query.add_range(b"test_row6", b"test_row7", False, True) + query.add_range() + query.add_rows(["test_row", b"test_row2", "test_row3"]) + query.add_rows(["test_row3", b"test_row4"]) + output = query.to_dict() + self.assertTrue(isinstance(output, dict)) + request_proto = ReadRowsRequest(**output) + rowset_proto = request_proto.rows + # check rows + self.assertEqual(len(rowset_proto.row_keys), 4) + self.assertEqual(rowset_proto.row_keys[0], b"test_row") + self.assertEqual(rowset_proto.row_keys[1], b"test_row2") + self.assertEqual(rowset_proto.row_keys[2], b"test_row3") + self.assertEqual(rowset_proto.row_keys[3], b"test_row4") + # check ranges + self.assertEqual(len(rowset_proto.row_ranges), 5) + self.assertEqual(rowset_proto.row_ranges[0].start_key_closed, b"test_row") + self.assertEqual(rowset_proto.row_ranges[0].end_key_open, b"test_row2") + self.assertEqual(rowset_proto.row_ranges[1].start_key_closed, b"test_row3") + self.assertEqual(rowset_proto.row_ranges[1].end_key_open, b"") + self.assertEqual(rowset_proto.row_ranges[2].start_key_closed, b"") + self.assertEqual(rowset_proto.row_ranges[2].end_key_open, b"test_row5") + self.assertEqual(rowset_proto.row_ranges[3].start_key_open, b"test_row6") + self.assertEqual(rowset_proto.row_ranges[3].end_key_closed, b"test_row7") + self.assertEqual(rowset_proto.row_ranges[4].start_key_closed, b"") + self.assertEqual(rowset_proto.row_ranges[4].end_key_open, b"") + # check limit + self.assertEqual(request_proto.rows_limit, 100) + # check filter + filter_proto = request_proto.filter + self.assertEqual(filter_proto, row_filter.to_pb()) + + def test_shard(self): + pass From 08cab3877a8f4b7656e7960007f55ab15656b813 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Mar 2023 14:32:46 -0700 Subject: [PATCH 03/19] added assertions for exception messages --- google/cloud/bigtable/read_rows_query.py | 10 ++--- tests/unit/test_read_rows_query.py | 50 +++++++++++++++++++----- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 3f4ef1ebb..fb7a4174b 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -91,9 +91,7 @@ def set_filter( or isinstance(row_filter, RowFilter) or row_filter is None ): - raise ValueError( - "row_filter must be a RowFilter or corresponding dict representation" - ) + raise ValueError("row_filter must be a RowFilter or dict") self._filter = row_filter return self @@ -144,13 +142,11 @@ def add_range( if start_is_inclusive is None: start_is_inclusive = True elif start_key is None: - raise ValueError( - "start_is_inclusive must not be included if start_key is None" - ) + raise ValueError("start_is_inclusive must not be set without start_key") if end_is_inclusive is None: end_is_inclusive = False elif end_key is None: - raise ValueError("end_is_inclusive must not be included if end_key is None") + raise ValueError("end_is_inclusive must not be set without end_key") # ensure that start_key and end_key are bytes if isinstance(start_key, str): start_key = start_key.encode() diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py index eb924edaa..569e97f17 100644 --- a/tests/unit/test_read_rows_query.py +++ b/tests/unit/test_read_rows_query.py @@ -50,8 +50,9 @@ def test_ctor_explicit(self): self.assertEqual(query.limit, 10) def test_ctor_invalid_limit(self): - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as exc: self._make_one(limit=-1) + self.assertEqual(exc.exception.args, ("limit must be >= 0",)) def test_set_filter(self): from google.cloud.bigtable.row_filters import RowFilterChain @@ -70,8 +71,11 @@ def test_set_filter(self): self.assertEqual(result, query) query.filter = RowFilterChain() self.assertEqual(query.filter, RowFilterChain()) - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as exc: query.filter = 1 + self.assertEqual( + exc.exception.args, ("row_filter must be a RowFilter or dict",) + ) def test_set_filter_dict(self): from google.cloud.bigtable.row_filters import RowSampleFilter @@ -103,10 +107,12 @@ def test_set_limit(self): result = query.set_limit(0) self.assertEqual(query.limit, 0) self.assertEqual(result, query) - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as exc: query.set_limit(-1) - with self.assertRaises(ValueError): + self.assertEqual(exc.exception.args, ("limit must be >= 0",)) + with self.assertRaises(ValueError) as exc: query.limit = -100 + self.assertEqual(exc.exception.args, ("limit must be >= 0",)) def test_add_rows_str(self): query = self._make_one() @@ -159,10 +165,12 @@ def test_add_rows_batch(self): def test_add_rows_invalid(self): query = self._make_one() - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as exc: query.add_rows(1) - with self.assertRaises(ValueError): + self.assertEqual(exc.exception.args, ("row_keys must be strings or bytes",)) + with self.assertRaises(ValueError) as exc: query.add_rows(["s", 0]) + self.assertEqual(exc.exception.args, ("row_keys must be strings or bytes",)) def test_duplicate_rows(self): # should only hold one of each input key @@ -212,14 +220,36 @@ def test_add_range(self): self.assertEqual(query.row_ranges[4][0], None) self.assertEqual(query.row_ranges[4][1], None) # test with inclusive flags only - with self.assertRaises(ValueError): + with self.assertRaises(ValueError) as exc: query.add_range(start_is_inclusive=True, end_is_inclusive=True) - with self.assertRaises(ValueError): + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must not be set without start_key",), + ) + with self.assertRaises(ValueError) as exc: query.add_range(start_is_inclusive=False, end_is_inclusive=False) - with self.assertRaises(ValueError): + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must not be set without start_key",), + ) + with self.assertRaises(ValueError) as exc: query.add_range(start_is_inclusive=False) - with self.assertRaises(ValueError): + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must not be set without start_key",), + ) + with self.assertRaises(ValueError) as exc: query.add_range(end_is_inclusive=True) + self.assertEqual( + exc.exception.args, ("end_is_inclusive must not be set without end_key",) + ) + # test with invalid keys + with self.assertRaises(ValueError) as exc: + query.add_range(1, "2") + self.assertEqual(exc.exception.args, ("start_key must be a string or bytes",)) + with self.assertRaises(ValueError) as exc: + query.add_range("1", 2) + self.assertEqual(exc.exception.args, ("end_key must be a string or bytes",)) def test_to_dict_rows_default(self): # dictionary should be in rowset proto format From 7bf6c7be8ced1b0dad2617ba778a4be9312edcb3 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 22 Mar 2023 14:32:58 -0700 Subject: [PATCH 04/19] added to_dict stub --- google/cloud/bigtable/row_filters.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/google/cloud/bigtable/row_filters.py b/google/cloud/bigtable/row_filters.py index 53192acc8..696e76c8c 100644 --- a/google/cloud/bigtable/row_filters.py +++ b/google/cloud/bigtable/row_filters.py @@ -35,6 +35,14 @@ class RowFilter(object): This class is a do-nothing base class for all row filters. """ + def to_dict(self): + """Convert the filter to a dictionary. + + :rtype: dict + :returns: The dictionary representation of this filter. + """ + raise NotImplementedError + class _BoolFilter(RowFilter): """Row filter that uses a boolean flag. From d40128d2a61cb31570d999b3e3c9b78b54cb4052 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 23 Mar 2023 13:55:58 -0700 Subject: [PATCH 05/19] fixed failing tests --- tests/unit/test_read_rows_query.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py index 569e97f17..2f0261eba 100644 --- a/tests/unit/test_read_rows_query.py +++ b/tests/unit/test_read_rows_query.py @@ -91,7 +91,7 @@ def test_set_filter_dict(self): output = query.to_dict() self.assertEqual(output["filter"], filter1_dict) proto_output = ReadRowsRequest(**output) - self.assertEqual(proto_output.filter, filter1.to_pb()) + self.assertEqual(proto_output.filter, filter1._to_pb()) query.filter = None self.assertEqual(query.filter, None) @@ -308,7 +308,7 @@ def test_to_dict_rows_populated(self): self.assertEqual(request_proto.rows_limit, 100) # check filter filter_proto = request_proto.filter - self.assertEqual(filter_proto, row_filter.to_pb()) + self.assertEqual(filter_proto, row_filter._to_pb()) def test_shard(self): pass From b8a6218cc5a22826397ea6c2e808862cf3a22502 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Mar 2023 11:43:31 -0700 Subject: [PATCH 06/19] update error text Co-authored-by: Mattie Fu --- google/cloud/bigtable/read_rows_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index fb7a4174b..f92e588bf 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -142,7 +142,7 @@ def add_range( if start_is_inclusive is None: start_is_inclusive = True elif start_key is None: - raise ValueError("start_is_inclusive must not be set without start_key") + raise ValueError("start_is_inclusive must be set with start_key") if end_is_inclusive is None: end_is_inclusive = False elif end_key is None: From fd1038da6343267075dd980f8254f7aa4f31cec4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 27 Mar 2023 11:44:09 -0700 Subject: [PATCH 07/19] update error text Co-authored-by: Mattie Fu --- google/cloud/bigtable/read_rows_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index f92e588bf..f16de7b14 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -146,7 +146,7 @@ def add_range( if end_is_inclusive is None: end_is_inclusive = False elif end_key is None: - raise ValueError("end_is_inclusive must not be set without end_key") + raise ValueError("end_is_inclusive must be set with end_key") # ensure that start_key and end_key are bytes if isinstance(start_key, str): start_key = start_key.encode() From 00be65a8d330d80f4efad87980ba378dd5c484b7 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Wed, 29 Mar 2023 16:40:00 -0700 Subject: [PATCH 08/19] fixed broken test --- tests/unit/test_read_rows_query.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py index 2f0261eba..822a75384 100644 --- a/tests/unit/test_read_rows_query.py +++ b/tests/unit/test_read_rows_query.py @@ -224,24 +224,24 @@ def test_add_range(self): query.add_range(start_is_inclusive=True, end_is_inclusive=True) self.assertEqual( exc.exception.args, - ("start_is_inclusive must not be set without start_key",), + ("start_is_inclusive must be set with start_key",), ) with self.assertRaises(ValueError) as exc: query.add_range(start_is_inclusive=False, end_is_inclusive=False) self.assertEqual( exc.exception.args, - ("start_is_inclusive must not be set without start_key",), + ("start_is_inclusive must be set with start_key",), ) with self.assertRaises(ValueError) as exc: query.add_range(start_is_inclusive=False) self.assertEqual( exc.exception.args, - ("start_is_inclusive must not be set without start_key",), + ("start_is_inclusive must be set with start_key",), ) with self.assertRaises(ValueError) as exc: query.add_range(end_is_inclusive=True) self.assertEqual( - exc.exception.args, ("end_is_inclusive must not be set without end_key",) + exc.exception.args, ("end_is_inclusive must be set with end_key",) ) # test with invalid keys with self.assertRaises(ValueError) as exc: From 8f15e9c2b80e9ff28ad622ed81c2609f950a2caf Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Thu, 30 Mar 2023 20:53:50 +0000 Subject: [PATCH 09/19] Update docstring Co-authored-by: Mariatta Wijaya --- google/cloud/bigtable/read_rows_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index f16de7b14..24f85f622 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -24,7 +24,7 @@ @dataclass class _RangePoint: - # model class for a point in a row range + """Model class for a point in a row range""" key: row_key is_inclusive: bool From 68a5a0f9167ea47e6f9957cb8951d0035673c963 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 08:58:42 -0700 Subject: [PATCH 10/19] added RowRange object --- google/cloud/bigtable/read_rows_query.py | 75 +++++++++++++++--------- 1 file changed, 48 insertions(+), 27 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 24f85f622..0cec29718 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -28,6 +28,46 @@ class _RangePoint: key: row_key is_inclusive: bool +@dataclass +class RowRange + start: _RangePoint | None + end: _RangePoint | None + + def __init__(self, + start_key: str | bytes | None = None, + end_key: str | bytes | None = None, + start_is_inclusive: bool | None = None, + end_is_inclusive: bool | None = None, + ): + # check for invalid combinations of arguments + if start_is_inclusive is None: + start_is_inclusive = True + elif start_key is None: + raise ValueError("start_is_inclusive must be set with start_key") + if end_is_inclusive is None: + end_is_inclusive = False + elif end_key is None: + raise ValueError("end_is_inclusive must be set with end_key") + # ensure that start_key and end_key are bytes + if isinstance(start_key, str): + start_key = start_key.encode() + elif start_key is not None and not isinstance(start_key, bytes): + raise ValueError("start_key must be a string or bytes") + if isinstance(end_key, str): + end_key = end_key.encode() + elif end_key is not None and not isinstance(end_key, bytes): + raise ValueError("end_key must be a string or bytes") + + self.start = ( + _RangePoint(start_key, start_is_inclusive) + if start_key is not None + else None + ) + self.end = ( + _RangePoint(end_key, end_is_inclusive) + if end_key is not None + else None + ) class ReadRowsQuery: """ @@ -37,6 +77,7 @@ class ReadRowsQuery: def __init__( self, row_keys: list[str | bytes] | str | bytes | None = None, + row_ranges: list[RowRange] | RowRange | None = None, limit: int | None = None, row_filter: RowFilter | dict[str, Any] | None = None, ): @@ -50,7 +91,9 @@ def __init__( - row_filter: a RowFilter to apply to the query """ self.row_keys: set[bytes] = set() - self.row_ranges: list[tuple[_RangePoint | None, _RangePoint | None]] = [] + self.row_ranges: list[RowRange] = [] + for range in row_ranges: + self.row_ranges.append(range) if row_keys: self.add_rows(row_keys) self.limit: int | None = limit @@ -138,32 +181,10 @@ def add_range( - end_is_inclusive: if True, the end key is included in the range defaults to False if None. Must not be included if end_key is None """ - # check for invalid combinations of arguments - if start_is_inclusive is None: - start_is_inclusive = True - elif start_key is None: - raise ValueError("start_is_inclusive must be set with start_key") - if end_is_inclusive is None: - end_is_inclusive = False - elif end_key is None: - raise ValueError("end_is_inclusive must be set with end_key") - # ensure that start_key and end_key are bytes - if isinstance(start_key, str): - start_key = start_key.encode() - elif start_key is not None and not isinstance(start_key, bytes): - raise ValueError("start_key must be a string or bytes") - if isinstance(end_key, str): - end_key = end_key.encode() - elif end_key is not None and not isinstance(end_key, bytes): - raise ValueError("end_key must be a string or bytes") - - start_pt = ( - _RangePoint(start_key, start_is_inclusive) - if start_key is not None - else None + new_range = RowRange( + start_key, end_key, start_is_inclusive, end_is_inclusive ) - end_pt = _RangePoint(end_key, end_is_inclusive) if end_key is not None else None - self.row_ranges.append((start_pt, end_pt)) + self.row_ranges.append(new_range) return self def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]: @@ -226,7 +247,7 @@ def limit(self, new_limit: int | None): @property def filter(self): """ - Getter implemntation for filter property + Getter implementation for filter property """ return self._filter From 1fba6eaa7b3d121f93b5a3aadc9631e09844b19e Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 11:32:44 -0700 Subject: [PATCH 11/19] updated add_keys --- google/cloud/bigtable/read_rows_query.py | 32 ++++++++++++------------ 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 0cec29718..2d8e5d895 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -76,7 +76,7 @@ class ReadRowsQuery: def __init__( self, - row_keys: list[str | bytes] | str | bytes | None = None, + row_key: list[str | bytes] | str | bytes | None = None, row_ranges: list[RowRange] | RowRange | None = None, limit: int | None = None, row_filter: RowFilter | dict[str, Any] | None = None, @@ -85,7 +85,9 @@ def __init__( Create a new ReadRowsQuery Args: - - row_keys: a list of row keys to include in the query + - row_keys: row keys to include in the query + a query can contain multiple keys, but ranges should be preferred + - row_ranges: ranges of rows to include in the query - limit: the maximum number of rows to return. None or 0 means no limit default: None (no limit) - row_filter: a RowFilter to apply to the query @@ -95,7 +97,8 @@ def __init__( for range in row_ranges: self.row_ranges.append(range) if row_keys: - self.add_rows(row_keys) + for k in row_keys: + self.add_key(k) self.limit: int | None = limit self.filter: RowFilter | dict[str, Any] = row_filter @@ -138,27 +141,24 @@ def set_filter( self._filter = row_filter return self - def add_rows(self, row_keys: list[str | bytes] | str | bytes) -> ReadRowsQuery: + def add_key(self, row_key: str | bytes) -> ReadRowsQuery: """ - Add a list of row keys to this query + Add a row key to this query + + A query can contain multiple keys, but ranges should be preferred Args: - - row_keys: a list of row keys to add to this query + - row_key: a key to add to this query Returns: - a reference to this query for chaining Raises: - ValueError if an input is not a string or bytes """ - if not isinstance(row_keys, list): - row_keys = [row_keys] - update_set = set() - for k in row_keys: - if isinstance(k, str): - k = k.encode() - elif not isinstance(k, bytes): - raise ValueError("row_keys must be strings or bytes") - update_set.add(k) - self.row_keys.update(update_set) + if isinstance(row_key, str): + row_key = row_key.encode() + elif not isinstance(row_key, bytes): + raise ValueError("row_key must be string or bytes") + self.row_keys.add(row_key) return self def add_range( From c4f82b049ff05dc226193ca83bbec2e4cb3534fc Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 11:35:29 -0700 Subject: [PATCH 12/19] removed chaining --- google/cloud/bigtable/read_rows_query.py | 53 +++++------------------- 1 file changed, 11 insertions(+), 42 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 2d8e5d895..1df66658b 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -29,12 +29,13 @@ class _RangePoint: is_inclusive: bool @dataclass -class RowRange +class RowRange: start: _RangePoint | None end: _RangePoint | None - def __init__(self, - start_key: str | bytes | None = None, + def __init__( + self, + start_key: str | bytes | None = None, end_key: str | bytes | None = None, start_is_inclusive: bool | None = None, end_is_inclusive: bool | None = None, @@ -102,7 +103,8 @@ def __init__( self.limit: int | None = limit self.filter: RowFilter | dict[str, Any] = row_filter - def set_limit(self, new_limit: int | None): + @property + def limit(self, new_limit: int | None): """ Set the maximum number of rows to return by this query. @@ -118,11 +120,11 @@ def set_limit(self, new_limit: int | None): if new_limit is not None and new_limit < 0: raise ValueError("limit must be >= 0") self._limit = new_limit - return self - def set_filter( + @property + def filter( self, row_filter: RowFilter | dict[str, Any] | None - ) -> ReadRowsQuery: + ): """ Set a RowFilter to apply to this query @@ -139,9 +141,8 @@ def set_filter( ): raise ValueError("row_filter must be a RowFilter or dict") self._filter = row_filter - return self - def add_key(self, row_key: str | bytes) -> ReadRowsQuery: + def add_key(self, row_key: str | bytes): """ Add a row key to this query @@ -159,7 +160,6 @@ def add_key(self, row_key: str | bytes) -> ReadRowsQuery: elif not isinstance(row_key, bytes): raise ValueError("row_key must be string or bytes") self.row_keys.add(row_key) - return self def add_range( self, @@ -167,7 +167,7 @@ def add_range( end_key: str | bytes | None = None, start_is_inclusive: bool | None = None, end_is_inclusive: bool | None = None, - ) -> ReadRowsQuery: + ): """ Add a range of row keys to this query. @@ -185,7 +185,6 @@ def add_range( start_key, end_key, start_is_inclusive, end_is_inclusive ) self.row_ranges.append(new_range) - return self def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]: """ @@ -227,33 +226,3 @@ def to_dict(self) -> dict[str, Any]: if self.limit is not None: final_dict["rows_limit"] = self.limit return final_dict - - # Support limit and filter as properties - - @property - def limit(self) -> int | None: - """ - Getter implementation for limit property - """ - return self._limit - - @limit.setter - def limit(self, new_limit: int | None): - """ - Setter implementation for limit property - """ - self.set_limit(new_limit) - - @property - def filter(self): - """ - Getter implementation for filter property - """ - return self._filter - - @filter.setter - def filter(self, row_filter: RowFilter | dict[str, Any] | None): - """ - Setter implementation for filter property - """ - self.set_filter(row_filter) From caca14ccf10567df2400822abb9d481869b8e3b4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 11:42:54 -0700 Subject: [PATCH 13/19] improved to_dicts --- google/cloud/bigtable/read_rows_query.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 1df66658b..8d64371c3 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -70,6 +70,17 @@ def __init__( else None ) + def _to_dict(self) -> dict[str, bytes]: + """Converts this object to a dictionary""" + output = {} + if self.start is not None: + key = "start_key_closed" if self.start.is_inclusive else "start_key_open" + output[key] = self.start.key + if self.end is not None: + key = "end_key_closed" if self.end.is_inclusive else "end_key_open" + output[key] = self.end.key + return output + class ReadRowsQuery: """ Class to encapsulate details of a read row request @@ -77,7 +88,7 @@ class ReadRowsQuery: def __init__( self, - row_key: list[str | bytes] | str | bytes | None = None, + row_keys: list[str | bytes] | str | bytes | None = None, row_ranges: list[RowRange] | RowRange | None = None, limit: int | None = None, row_filter: RowFilter | dict[str, Any] | None = None, @@ -197,21 +208,12 @@ def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery """ raise NotImplementedError - def to_dict(self) -> dict[str, Any]: + def _to_dict(self) -> dict[str, Any]: """ Convert this query into a dictionary that can be used to construct a ReadRowsRequest protobuf """ - ranges = [] - for start, end in self.row_ranges: - new_range = {} - if start is not None: - key = "start_key_closed" if start.is_inclusive else "start_key_open" - new_range[key] = start.key - if end is not None: - key = "end_key_closed" if end.is_inclusive else "end_key_open" - new_range[key] = end.key - ranges.append(new_range) + ranges = [r._to_dict() for r in self.row_ranges] row_keys = list(self.row_keys) row_keys.sort() row_set = {"row_keys": row_keys, "row_ranges": ranges} From 5f9ce85b27ab09d950d053430d6eba5802cc594b Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 11:59:28 -0700 Subject: [PATCH 14/19] improving row_ranges --- google/cloud/bigtable/read_rows_query.py | 28 +++++++++--------------- 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 8d64371c3..4d8eae42a 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -174,28 +174,17 @@ def add_key(self, row_key: str | bytes): def add_range( self, - start_key: str | bytes | None = None, - end_key: str | bytes | None = None, - start_is_inclusive: bool | None = None, - end_is_inclusive: bool | None = None, + row_range: RowRange | dict[str, bytes], ): """ Add a range of row keys to this query. Args: - - start_key: the start of the range - if None, start_key is interpreted as the empty string, inclusive - - end_key: the end of the range - if None, end_key is interpreted as the infinite row key, exclusive - - start_is_inclusive: if True, the start key is included in the range - defaults to True if None. Must not be included if start_key is None - - end_is_inclusive: if True, the end key is included in the range - defaults to False if None. Must not be included if end_key is None + - row_range: a range of row keys to add to this query + Can be a RowRange object or a dict representation in + RowRange proto format """ - new_range = RowRange( - start_key, end_key, start_is_inclusive, end_is_inclusive - ) - self.row_ranges.append(new_range) + self.row_ranges.append(row_range) def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]: """ @@ -213,10 +202,13 @@ def _to_dict(self) -> dict[str, Any]: Convert this query into a dictionary that can be used to construct a ReadRowsRequest protobuf """ - ranges = [r._to_dict() for r in self.row_ranges] + row_ranges = [] + for r in self.row_ranges: + dict_range = r._to_dict() if isinstance(r, RowRange) else r + row_ranges.append(dict_range) row_keys = list(self.row_keys) row_keys.sort() - row_set = {"row_keys": row_keys, "row_ranges": ranges} + row_set = {"row_keys": row_keys, "row_ranges": row_ranges} final_dict: dict[str, Any] = { "rows": row_set, } From 8e5f60a5e25dcd2b5eb900e209c74b0fd7d50682 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 12:03:29 -0700 Subject: [PATCH 15/19] fixed properties --- google/cloud/bigtable/read_rows_query.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 4d8eae42a..d8237f9f2 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -115,6 +115,10 @@ def __init__( self.filter: RowFilter | dict[str, Any] = row_filter @property + def limit(self) -> int | None: + return self._limit + + @property.setter def limit(self, new_limit: int | None): """ Set the maximum number of rows to return by this query. @@ -133,6 +137,10 @@ def limit(self, new_limit: int | None): self._limit = new_limit @property + def filter(self) -> RowFilter | dict[str, Any]: + return self._filter + + @property.setter def filter( self, row_filter: RowFilter | dict[str, Any] | None ): From 57184c18fec7e455e99bbbc2017f64c0b91edff4 Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 12:03:37 -0700 Subject: [PATCH 16/19] added type checking to range --- google/cloud/bigtable/read_rows_query.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index d8237f9f2..96e4bdc4a 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -192,6 +192,12 @@ def add_range( Can be a RowRange object or a dict representation in RowRange proto format """ + if not ( + isinstance(row_range, dict) + or isinstance(row_range, RowRange) + or row_range is None + ): + raise ValueError("row_range must be a RowRange or dict") self.row_ranges.append(row_range) def shard(self, shard_keys: "RowKeySamples" | None = None) -> list[ReadRowsQuery]: From 3eda7f4a6c8f97e549f75aaf2118c24658c24fde Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 13:04:45 -0700 Subject: [PATCH 17/19] got tests passing --- google/cloud/bigtable/__init__.py | 2 + google/cloud/bigtable/read_rows_query.py | 17 +- tests/unit/test_read_rows_query.py | 245 +++++++++++++---------- 3 files changed, 156 insertions(+), 108 deletions(-) diff --git a/google/cloud/bigtable/__init__.py b/google/cloud/bigtable/__init__.py index daa562c0c..251e41e42 100644 --- a/google/cloud/bigtable/__init__.py +++ b/google/cloud/bigtable/__init__.py @@ -22,6 +22,7 @@ from google.cloud.bigtable.client import Table from google.cloud.bigtable.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.read_rows_query import RowRange from google.cloud.bigtable.row_response import RowResponse from google.cloud.bigtable.row_response import CellResponse @@ -43,6 +44,7 @@ "Table", "RowKeySamples", "ReadRowsQuery", + "RowRange", "MutationsBatcher", "Mutation", "BulkMutationsEntry", diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 96e4bdc4a..01b62507f 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -91,7 +91,7 @@ def __init__( row_keys: list[str | bytes] | str | bytes | None = None, row_ranges: list[RowRange] | RowRange | None = None, limit: int | None = None, - row_filter: RowFilter | dict[str, Any] | None = None, + row_filter: RowFilter | None = None, ): """ Create a new ReadRowsQuery @@ -105,10 +105,15 @@ def __init__( - row_filter: a RowFilter to apply to the query """ self.row_keys: set[bytes] = set() - self.row_ranges: list[RowRange] = [] - for range in row_ranges: - self.row_ranges.append(range) + self.row_ranges: list[RowRange | dict[str, bytes]] = [] + if row_ranges: + if isinstance(row_ranges, RowRange): + row_ranges = [row_ranges] + for r in row_ranges: + self.add_range(r) if row_keys: + if not isinstance(row_keys, list): + row_keys = [row_keys] for k in row_keys: self.add_key(k) self.limit: int | None = limit @@ -118,7 +123,7 @@ def __init__( def limit(self) -> int | None: return self._limit - @property.setter + @limit.setter def limit(self, new_limit: int | None): """ Set the maximum number of rows to return by this query. @@ -140,7 +145,7 @@ def limit(self, new_limit: int | None): def filter(self) -> RowFilter | dict[str, Any]: return self._filter - @property.setter + @filter.setter def filter( self, row_filter: RowFilter | dict[str, Any] | None ): diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py index 822a75384..b4954b261 100644 --- a/tests/unit/test_read_rows_query.py +++ b/tests/unit/test_read_rows_query.py @@ -19,6 +19,95 @@ b"row_key_2", ] +class TestRowRange(unittest.TestCase): + @staticmethod + def _get_target_class(): + from google.cloud.bigtable.read_rows_query import RowRange + return RowRange + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def test_ctor_start_end(self): + row_range = self._make_one("test_row", "test_row2") + self.assertEqual(row_range.start.key, "test_row".encode()) + self.assertEqual(row_range.end.key, "test_row2".encode()) + self.assertEqual(row_range.start.is_inclusive, True) + self.assertEqual(row_range.end.is_inclusive, False) + + def test_ctor_start_only(self): + row_range = self._make_one("test_row3") + self.assertEqual(row_range.start.key, "test_row3".encode()) + self.assertEqual(row_range.start.is_inclusive, True) + self.assertEqual(row_range.end, None) + + def test_ctor_end_only(self): + row_range = self._make_one(end_key="test_row4") + self.assertEqual(row_range.end.key, "test_row4".encode()) + self.assertEqual(row_range.end.is_inclusive, False) + self.assertEqual(row_range.start, None) + + def test_ctor_inclusive_flags(self): + row_range = self._make_one("test_row5", "test_row6", False, True) + self.assertEqual(row_range.start.key, "test_row5".encode()) + self.assertEqual(row_range.end.key, "test_row6".encode()) + self.assertEqual(row_range.start.is_inclusive, False) + self.assertEqual(row_range.end.is_inclusive, True) + + def test_ctor_defaults(self): + row_range = self._make_one() + self.assertEqual(row_range.start, None) + self.assertEqual(row_range.end, None) + + def test_ctor_flags_only(self): + with self.assertRaises(ValueError) as exc: + self._make_one(start_is_inclusive=True, end_is_inclusive=True) + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must be set with start_key",), + ) + with self.assertRaises(ValueError) as exc: + self._make_one(start_is_inclusive=False, end_is_inclusive=False) + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must be set with start_key",), + ) + with self.assertRaises(ValueError) as exc: + self._make_one(start_is_inclusive=False) + self.assertEqual( + exc.exception.args, + ("start_is_inclusive must be set with start_key",), + ) + with self.assertRaises(ValueError) as exc: + self._make_one(end_is_inclusive=True) + self.assertEqual( + exc.exception.args, ("end_is_inclusive must be set with end_key",) + ) + + def test_ctor_invalid_keys(self): + # test with invalid keys + with self.assertRaises(ValueError) as exc: + self._make_one(1, "2") + self.assertEqual(exc.exception.args, ("start_key must be a string or bytes",)) + with self.assertRaises(ValueError) as exc: + self._make_one("1", 2) + self.assertEqual(exc.exception.args, ("end_key must be a string or bytes",)) + + def test__to_dict_defaults(self): + row_range = self._make_one("test_row", "test_row2") + expected = { + "start_key_closed": b"test_row", + "end_key_open": b"test_row2", + } + self.assertEqual(row_range._to_dict(), expected) + + def test__to_dict_inclusive_flags(self): + row_range = self._make_one("test_row", "test_row2", False, True) + expected = { + "start_key_open": b"test_row", + "end_key_closed": b"test_row2", + } + self.assertEqual(row_range._to_dict(), expected) class TestReadRowsQuery(unittest.TestCase): @staticmethod @@ -60,15 +149,13 @@ def test_set_filter(self): filter1 = RowFilterChain() query = self._make_one() self.assertEqual(query.filter, None) - result = query.set_filter(filter1) + query.filter = filter1 self.assertEqual(query.filter, filter1) - self.assertEqual(result, query) filter2 = RowFilterChain() - result = query.set_filter(filter2) + query.filter = filter2 self.assertEqual(query.filter, filter2) - result = query.set_filter(None) + query.filter = None self.assertEqual(query.filter, None) - self.assertEqual(result, query) query.filter = RowFilterChain() self.assertEqual(query.filter, RowFilterChain()) with self.assertRaises(ValueError) as exc: @@ -85,10 +172,9 @@ def test_set_filter_dict(self): filter1_dict = filter1.to_dict() query = self._make_one() self.assertEqual(query.filter, None) - result = query.set_filter(filter1_dict) + query.filter = filter1_dict self.assertEqual(query.filter, filter1_dict) - self.assertEqual(result, query) - output = query.to_dict() + output = query._to_dict() self.assertEqual(output["filter"], filter1_dict) proto_output = ReadRowsRequest(**output) self.assertEqual(proto_output.filter, filter1._to_pb()) @@ -99,63 +185,58 @@ def test_set_filter_dict(self): def test_set_limit(self): query = self._make_one() self.assertEqual(query.limit, None) - result = query.set_limit(10) + query.limit = 10 self.assertEqual(query.limit, 10) - self.assertEqual(result, query) query.limit = 9 self.assertEqual(query.limit, 9) - result = query.set_limit(0) + query.limit = 0 self.assertEqual(query.limit, 0) - self.assertEqual(result, query) with self.assertRaises(ValueError) as exc: - query.set_limit(-1) + query.limit = -1 self.assertEqual(exc.exception.args, ("limit must be >= 0",)) with self.assertRaises(ValueError) as exc: query.limit = -100 self.assertEqual(exc.exception.args, ("limit must be >= 0",)) - def test_add_rows_str(self): + def test_add_key_str(self): query = self._make_one() self.assertEqual(query.row_keys, set()) input_str = "test_row" - result = query.add_rows(input_str) + query.add_key(input_str) self.assertEqual(len(query.row_keys), 1) self.assertIn(input_str.encode(), query.row_keys) - self.assertEqual(result, query) input_str2 = "test_row2" - result = query.add_rows(input_str2) + query.add_key(input_str2) self.assertEqual(len(query.row_keys), 2) self.assertIn(input_str.encode(), query.row_keys) self.assertIn(input_str2.encode(), query.row_keys) - self.assertEqual(result, query) - def test_add_rows_bytes(self): + def test_add_key_bytes(self): query = self._make_one() self.assertEqual(query.row_keys, set()) input_bytes = b"test_row" - result = query.add_rows(input_bytes) + query.add_key(input_bytes) self.assertEqual(len(query.row_keys), 1) self.assertIn(input_bytes, query.row_keys) - self.assertEqual(result, query) input_bytes2 = b"test_row2" - result = query.add_rows(input_bytes2) + query.add_key(input_bytes2) self.assertEqual(len(query.row_keys), 2) self.assertIn(input_bytes, query.row_keys) self.assertIn(input_bytes2, query.row_keys) - self.assertEqual(result, query) def test_add_rows_batch(self): query = self._make_one() self.assertEqual(query.row_keys, set()) input_batch = ["test_row", b"test_row2", "test_row3"] - result = query.add_rows(input_batch) + for k in input_batch: + query.add_key(k) self.assertEqual(len(query.row_keys), 3) self.assertIn(b"test_row", query.row_keys) self.assertIn(b"test_row2", query.row_keys) self.assertIn(b"test_row3", query.row_keys) - self.assertEqual(result, query) # test adding another batch - query.add_rows(["test_row4", b"test_row5"]) + for k in ['test_row4', b"test_row5"]: + query.add_key(k) self.assertEqual(len(query.row_keys), 5) self.assertIn(input_batch[0].encode(), query.row_keys) self.assertIn(input_batch[1], query.row_keys) @@ -163,14 +244,14 @@ def test_add_rows_batch(self): self.assertIn(b"test_row4", query.row_keys) self.assertIn(b"test_row5", query.row_keys) - def test_add_rows_invalid(self): + def test_add_key_invalid(self): query = self._make_one() with self.assertRaises(ValueError) as exc: - query.add_rows(1) - self.assertEqual(exc.exception.args, ("row_keys must be strings or bytes",)) + query.add_key(1) + self.assertEqual(exc.exception.args, ("row_key must be string or bytes",)) with self.assertRaises(ValueError) as exc: - query.add_rows(["s", 0]) - self.assertEqual(exc.exception.args, ("row_keys must be strings or bytes",)) + query.add_key(["s"]) + self.assertEqual(exc.exception.args, ("row_key must be string or bytes",)) def test_duplicate_rows(self): # should only hold one of each input key @@ -181,82 +262,38 @@ def test_duplicate_rows(self): self.assertIn(key_1, query.row_keys) self.assertIn(key_2, query.row_keys) key_3 = "test_row3" - query.add_rows([key_3 for _ in range(10)]) + for i in range(10): + query.add_key(key_3) self.assertEqual(len(query.row_keys), 3) def test_add_range(self): - # test with start and end keys + from google.cloud.bigtable.read_rows_query import RowRange query = self._make_one() self.assertEqual(query.row_ranges, []) - result = query.add_range("test_row", "test_row2") + input_range = RowRange(start_key=b"test_row") + query.add_range(input_range) self.assertEqual(len(query.row_ranges), 1) - self.assertEqual(query.row_ranges[0][0].key, "test_row".encode()) - self.assertEqual(query.row_ranges[0][1].key, "test_row2".encode()) - self.assertEqual(query.row_ranges[0][0].is_inclusive, True) - self.assertEqual(query.row_ranges[0][1].is_inclusive, False) - self.assertEqual(result, query) - # test with start key only - result = query.add_range("test_row3") + self.assertEqual(query.row_ranges[0], input_range) + input_range2 = RowRange(start_key=b"test_row2") + query.add_range(input_range2) self.assertEqual(len(query.row_ranges), 2) - self.assertEqual(query.row_ranges[1][0].key, "test_row3".encode()) - self.assertEqual(query.row_ranges[1][1], None) - self.assertEqual(result, query) - # test with end key only - result = query.add_range(start_key=None, end_key="test_row5") - self.assertEqual(len(query.row_ranges), 3) - self.assertEqual(query.row_ranges[2][0], None) - self.assertEqual(query.row_ranges[2][1].key, "test_row5".encode()) - self.assertEqual(query.row_ranges[2][1].is_inclusive, False) - # test with start and end keys and inclusive flags - result = query.add_range(b"test_row6", b"test_row7", False, True) - self.assertEqual(len(query.row_ranges), 4) - self.assertEqual(query.row_ranges[3][0].key, b"test_row6") - self.assertEqual(query.row_ranges[3][1].key, b"test_row7") - self.assertEqual(query.row_ranges[3][0].is_inclusive, False) - self.assertEqual(query.row_ranges[3][1].is_inclusive, True) - # test with nothing passed - result = query.add_range() - self.assertEqual(len(query.row_ranges), 5) - self.assertEqual(query.row_ranges[4][0], None) - self.assertEqual(query.row_ranges[4][1], None) - # test with inclusive flags only - with self.assertRaises(ValueError) as exc: - query.add_range(start_is_inclusive=True, end_is_inclusive=True) - self.assertEqual( - exc.exception.args, - ("start_is_inclusive must be set with start_key",), - ) - with self.assertRaises(ValueError) as exc: - query.add_range(start_is_inclusive=False, end_is_inclusive=False) - self.assertEqual( - exc.exception.args, - ("start_is_inclusive must be set with start_key",), - ) - with self.assertRaises(ValueError) as exc: - query.add_range(start_is_inclusive=False) - self.assertEqual( - exc.exception.args, - ("start_is_inclusive must be set with start_key",), - ) - with self.assertRaises(ValueError) as exc: - query.add_range(end_is_inclusive=True) - self.assertEqual( - exc.exception.args, ("end_is_inclusive must be set with end_key",) - ) - # test with invalid keys - with self.assertRaises(ValueError) as exc: - query.add_range(1, "2") - self.assertEqual(exc.exception.args, ("start_key must be a string or bytes",)) - with self.assertRaises(ValueError) as exc: - query.add_range("1", 2) - self.assertEqual(exc.exception.args, ("end_key must be a string or bytes",)) + self.assertEqual(query.row_ranges[0], input_range) + self.assertEqual(query.row_ranges[1], input_range2) + + def test_add_range_dict(self): + query = self._make_one() + self.assertEqual(query.row_ranges, []) + input_range = {"start_key_closed": b"test_row"} + query.add_range(input_range) + self.assertEqual(len(query.row_ranges), 1) + self.assertEqual(query.row_ranges[0], input_range) def test_to_dict_rows_default(self): # dictionary should be in rowset proto format from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest query = self._make_one() - output = query.to_dict() + output = query._to_dict() self.assertTrue(isinstance(output, dict)) self.assertEqual(len(output.keys()), 1) expected = {"rows": {"row_keys": [], "row_ranges": []}} @@ -272,17 +309,21 @@ def test_to_dict_rows_populated(self): # dictionary should be in rowset proto format from google.cloud.bigtable_v2.types.bigtable import ReadRowsRequest from google.cloud.bigtable.row_filters import PassAllFilter + from google.cloud.bigtable.read_rows_query import RowRange row_filter = PassAllFilter(False) query = self._make_one(limit=100, row_filter=row_filter) - query.add_range("test_row", "test_row2") - query.add_range("test_row3") - query.add_range(start_key=None, end_key="test_row5") - query.add_range(b"test_row6", b"test_row7", False, True) - query.add_range() - query.add_rows(["test_row", b"test_row2", "test_row3"]) - query.add_rows(["test_row3", b"test_row4"]) - output = query.to_dict() + query.add_range(RowRange("test_row", "test_row2")) + query.add_range(RowRange("test_row3")) + query.add_range(RowRange(start_key=None, end_key="test_row5")) + query.add_range(RowRange(b"test_row6", b"test_row7", False, True)) + query.add_range(RowRange()) + query.add_key("test_row") + query.add_key(b"test_row2") + query.add_key("test_row3") + query.add_key(b"test_row3") + query.add_key(b"test_row4") + output = query._to_dict() self.assertTrue(isinstance(output, dict)) request_proto = ReadRowsRequest(**output) rowset_proto = request_proto.rows From 65f5a2ae0eb2e08dcb98e17542bef33f5feea51a Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 13:08:18 -0700 Subject: [PATCH 18/19] blacken, mypy --- google/cloud/bigtable/read_rows_query.py | 16 +++++++--------- tests/unit/test_read_rows_query.py | 6 +++++- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 01b62507f..9704606d1 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -25,9 +25,11 @@ @dataclass class _RangePoint: """Model class for a point in a row range""" + key: row_key is_inclusive: bool + @dataclass class RowRange: start: _RangePoint | None @@ -65,9 +67,7 @@ def __init__( else None ) self.end = ( - _RangePoint(end_key, end_is_inclusive) - if end_key is not None - else None + _RangePoint(end_key, end_is_inclusive) if end_key is not None else None ) def _to_dict(self) -> dict[str, bytes]: @@ -81,6 +81,7 @@ def _to_dict(self) -> dict[str, bytes]: output[key] = self.end.key return output + class ReadRowsQuery: """ Class to encapsulate details of a read row request @@ -117,7 +118,7 @@ def __init__( for k in row_keys: self.add_key(k) self.limit: int | None = limit - self.filter: RowFilter | dict[str, Any] = row_filter + self.filter: RowFilter | dict[str, Any] | None = row_filter @property def limit(self) -> int | None: @@ -142,13 +143,11 @@ def limit(self, new_limit: int | None): self._limit = new_limit @property - def filter(self) -> RowFilter | dict[str, Any]: + def filter(self) -> RowFilter | dict[str, Any] | None: return self._filter @filter.setter - def filter( - self, row_filter: RowFilter | dict[str, Any] | None - ): + def filter(self, row_filter: RowFilter | dict[str, Any] | None): """ Set a RowFilter to apply to this query @@ -200,7 +199,6 @@ def add_range( if not ( isinstance(row_range, dict) or isinstance(row_range, RowRange) - or row_range is None ): raise ValueError("row_range must be a RowRange or dict") self.row_ranges.append(row_range) diff --git a/tests/unit/test_read_rows_query.py b/tests/unit/test_read_rows_query.py index b4954b261..aa690bc86 100644 --- a/tests/unit/test_read_rows_query.py +++ b/tests/unit/test_read_rows_query.py @@ -19,10 +19,12 @@ b"row_key_2", ] + class TestRowRange(unittest.TestCase): @staticmethod def _get_target_class(): from google.cloud.bigtable.read_rows_query import RowRange + return RowRange def _make_one(self, *args, **kwargs): @@ -109,6 +111,7 @@ def test__to_dict_inclusive_flags(self): } self.assertEqual(row_range._to_dict(), expected) + class TestReadRowsQuery(unittest.TestCase): @staticmethod def _get_target_class(): @@ -235,7 +238,7 @@ def test_add_rows_batch(self): self.assertIn(b"test_row2", query.row_keys) self.assertIn(b"test_row3", query.row_keys) # test adding another batch - for k in ['test_row4', b"test_row5"]: + for k in ["test_row4", b"test_row5"]: query.add_key(k) self.assertEqual(len(query.row_keys), 5) self.assertIn(input_batch[0].encode(), query.row_keys) @@ -268,6 +271,7 @@ def test_duplicate_rows(self): def test_add_range(self): from google.cloud.bigtable.read_rows_query import RowRange + query = self._make_one() self.assertEqual(query.row_ranges, []) input_range = RowRange(start_key=b"test_row") From 3e724dbcb00bc4baaab1504a2d96690dc258588f Mon Sep 17 00:00:00 2001 From: Daniel Sanche Date: Mon, 3 Apr 2023 13:23:08 -0700 Subject: [PATCH 19/19] ran blacken --- google/cloud/bigtable/read_rows_query.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/google/cloud/bigtable/read_rows_query.py b/google/cloud/bigtable/read_rows_query.py index 9704606d1..9fd349d5f 100644 --- a/google/cloud/bigtable/read_rows_query.py +++ b/google/cloud/bigtable/read_rows_query.py @@ -196,10 +196,7 @@ def add_range( Can be a RowRange object or a dict representation in RowRange proto format """ - if not ( - isinstance(row_range, dict) - or isinstance(row_range, RowRange) - ): + if not (isinstance(row_range, dict) or isinstance(row_range, RowRange)): raise ValueError("row_range must be a RowRange or dict") self.row_ranges.append(row_range)