diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 70f7d3dcb7..4b8200d96f 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -287,7 +287,11 @@ def _list_feature_views( for fv in self._registry.list_feature_views( self.project, allow_cache=allow_cache ): - if hide_dummy_entity and fv.entities[0] == DUMMY_ENTITY_NAME: + if ( + hide_dummy_entity + and fv.entities + and fv.entities[0] == DUMMY_ENTITY_NAME + ): fv.entities = [] fv.entity_columns = [] feature_views.append(fv) diff --git a/sdk/python/feast/infra/online_stores/contrib/hbase_online_store/hbase.py b/sdk/python/feast/infra/online_stores/contrib/hbase_online_store/hbase.py index 2636cf95e2..1da9de89a8 100644 --- a/sdk/python/feast/infra/online_stores/contrib/hbase_online_store/hbase.py +++ b/sdk/python/feast/infra/online_stores/contrib/hbase_online_store/hbase.py @@ -3,14 +3,16 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple -from happybase import Connection +from happybase import ConnectionPool +from happybase.connection import DEFAULT_PROTOCOL, DEFAULT_TRANSPORT +from pydantic import StrictStr from pydantic.typing import Literal from feast import Entity from feast.feature_view import FeatureView from feast.infra.online_stores.helpers import compute_entity_id from feast.infra.online_stores.online_store import OnlineStore -from feast.infra.utils.hbase_utils import HbaseConstants, HbaseUtils +from feast.infra.utils.hbase_utils import HBaseConnector, HbaseConstants from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -23,35 +25,20 @@ class HbaseOnlineStoreConfig(FeastConfigBaseModel): type: Literal["hbase"] = "hbase" """Online store type selector""" - host: str + host: StrictStr """Hostname of Hbase Thrift server""" - port: str + port: StrictStr """Port in which Hbase Thrift server is running""" + connection_pool_size: int = 4 + """Number of connections to Hbase Thrift server to keep in the connection pool""" -class HbaseConnection: - """ - Hbase connecttion to connect to hbase. - - Attributes: - store_config: Online store config for Hbase store. - """ + protocol: StrictStr = DEFAULT_PROTOCOL + """Protocol used to communicate with Hbase Thrift server""" - def __init__(self, store_config: HbaseOnlineStoreConfig): - self._store_config = store_config - self._real_conn = Connection( - host=store_config.host, port=int(store_config.port) - ) - - @property - def real_conn(self) -> Connection: - """Stores the real happybase Connection to connect to hbase.""" - return self._real_conn - - def close(self) -> None: - """Close the happybase connection.""" - self.real_conn.close() + transport: StrictStr = DEFAULT_TRANSPORT + """Transport used to communicate with Hbase Thrift server""" class HbaseOnlineStore(OnlineStore): @@ -62,7 +49,7 @@ class HbaseOnlineStore(OnlineStore): _conn: Happybase Connection to connect to hbase thrift server. """ - _conn: Connection = None + _conn: ConnectionPool = None def _get_conn(self, config: RepoConfig): """ @@ -76,7 +63,13 @@ def _get_conn(self, config: RepoConfig): assert isinstance(store_config, HbaseOnlineStoreConfig) if not self._conn: - self._conn = Connection(host=store_config.host, port=int(store_config.port)) + self._conn = ConnectionPool( + host=store_config.host, + port=int(store_config.port), + size=int(store_config.connection_pool_size), + protocol=store_config.protocol, + transport=store_config.transport, + ) return self._conn @log_exceptions_and_usage(online_store="hbase") @@ -102,7 +95,7 @@ def online_write_batch( the online store. Can be used to display progress. """ - hbase = HbaseUtils(self._get_conn(config)) + hbase = HBaseConnector(self._get_conn(config)) project = config.project table_name = self._table_id(project, table) @@ -154,7 +147,7 @@ def online_read( entity_keys: a list of entity keys that should be read from the FeatureStore. requested_features: a list of requested feature names. """ - hbase = HbaseUtils(self._get_conn(config)) + hbase = HBaseConnector(self._get_conn(config)) project = config.project table_name = self._table_id(project, table) @@ -206,7 +199,7 @@ def update( tables_to_delete: Tables to delete from the Hbase Online Store. tables_to_keep: Tables to keep in the Hbase Online Store. """ - hbase = HbaseUtils(self._get_conn(config)) + hbase = HBaseConnector(self._get_conn(config)) project = config.project # We don't create any special state for the entites in this implementation. @@ -232,7 +225,7 @@ def teardown( config: The RepoConfig for the current FeatureStore. tables: Tables to delete from the feature repo. """ - hbase = HbaseUtils(self._get_conn(config)) + hbase = HBaseConnector(self._get_conn(config)) project = config.project for table in tables: diff --git a/sdk/python/feast/infra/utils/hbase_utils.py b/sdk/python/feast/infra/utils/hbase_utils.py index 4816a60087..d44f93f161 100644 --- a/sdk/python/feast/infra/utils/hbase_utils.py +++ b/sdk/python/feast/infra/utils/hbase_utils.py @@ -1,9 +1,6 @@ from typing import List -from happybase import Connection - -from feast.infra.key_encoding_utils import serialize_entity_key -from feast.protos.feast.types.EntityKey_pb2 import EntityKey +from happybase import ConnectionPool class HbaseConstants: @@ -28,7 +25,7 @@ def get_col_from_feature(feature): return HbaseConstants.DEFAULT_COLUMN_FAMILY + ":" + feature -class HbaseUtils: +class HBaseConnector: """ Utils class to manage different Hbase operations. @@ -40,14 +37,22 @@ class HbaseUtils: """ def __init__( - self, conn: Connection = None, host: str = None, port: int = None, timeout=None + self, + pool: ConnectionPool = None, + host: str = None, + port: int = None, + connection_pool_size: int = 4, ): - if conn is None: + if pool is None: self.host = host self.port = port - self.conn = Connection(host=host, port=port, timeout=timeout) + self.pool = ConnectionPool( + host=host, + port=port, + size=connection_pool_size, + ) else: - self.conn = conn + self.pool = pool def create_table(self, table_name: str, colm_family: List[str]): """ @@ -60,7 +65,9 @@ def create_table(self, table_name: str, colm_family: List[str]): cf_dict: dict = {} for cf in colm_family: cf_dict[cf] = dict() - return self.conn.create_table(table_name, cf_dict) + + with self.pool.connection() as conn: + return conn.create_table(table_name, cf_dict) def create_table_with_default_cf(self, table_name: str): """ @@ -69,7 +76,8 @@ def create_table_with_default_cf(self, table_name: str): Arguments: table_name: Name of the Hbase table. """ - return self.conn.create_table(table_name, {"default": dict()}) + with self.pool.connection() as conn: + return conn.create_table(table_name, {"default": dict()}) def check_if_table_exist(self, table_name: str): """ @@ -78,16 +86,18 @@ def check_if_table_exist(self, table_name: str): Arguments: table_name: Name of the Hbase table. """ - return bytes(table_name, "utf-8") in self.conn.tables() + with self.pool.connection() as conn: + return bytes(table_name, "utf-8") in conn.tables() def batch(self, table_name: str): """ - Returns a 'Batch' instance that can be used for mass data manipulation in the hbase table. + Returns a "Batch" instance that can be used for mass data manipulation in the hbase table. Arguments: table_name: Name of the Hbase table. """ - return self.conn.table(table_name).batch() + with self.pool.connection() as conn: + return conn.table(table_name).batch() def put(self, table_name: str, row_key: str, data: dict): """ @@ -98,8 +108,9 @@ def put(self, table_name: str, row_key: str, data: dict): row_key: Row key of the row to be inserted to hbase table. data: Mapping of column family name:column name to column values """ - table = self.conn.table(table_name) - table.put(row_key, data) + with self.pool.connection() as conn: + table = conn.table(table_name) + table.put(row_key, data) def row( self, @@ -119,8 +130,9 @@ def row( timestamp: timestamp specifies the maximum version the cells can have. include_timestamp: specifies if (column, timestamp) to be return instead of only column. """ - table = self.conn.table(table_name) - return table.row(row_key, columns, timestamp, include_timestamp) + with self.pool.connection() as conn: + table = conn.table(table_name) + return table.row(row_key, columns, timestamp, include_timestamp) def rows( self, @@ -140,52 +152,69 @@ def rows( timestamp: timestamp specifies the maximum version the cells can have. include_timestamp: specifies if (column, timestamp) to be return instead of only column. """ - table = self.conn.table(table_name) - return table.rows(row_keys, columns, timestamp, include_timestamp) + with self.pool.connection() as conn: + table = conn.table(table_name) + return table.rows(row_keys, columns, timestamp, include_timestamp) def print_table(self, table_name): """Prints the table scanning all the rows of the hbase table.""" - table = self.conn.table(table_name) - scan_data = table.scan() - for row_key, cols in scan_data: - print(row_key.decode("utf-8"), cols) + with self.pool.connection() as conn: + table = conn.table(table_name) + scan_data = table.scan() + for row_key, cols in scan_data: + print(row_key.decode("utf-8"), cols) def delete_table(self, table: str): """Deletes the hbase table given the table name.""" if self.check_if_table_exist(table): - self.conn.delete_table(table, disable=True) + with self.pool.connection() as conn: + conn.delete_table(table, disable=True) def close_conn(self): """Closes the happybase connection.""" - self.conn.close() + with self.pool.connection() as conn: + conn.close() def main(): + from feast.infra.key_encoding_utils import serialize_entity_key + from feast.protos.feast.types.EntityKey_pb2 import EntityKey from feast.protos.feast.types.Value_pb2 import Value - connection = Connection(host="localhost", port=9090) - table = connection.table("test_hbase_driver_hourly_stats") - row_keys = [ - serialize_entity_key( - EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1004)]), - entity_key_serialization_version=2, - ).hex(), - serialize_entity_key( - EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1005)]), - entity_key_serialization_version=2, - ).hex(), - serialize_entity_key( - EntityKey(join_keys=["driver_id"], entity_values=[Value(int64_val=1024)]), - entity_key_serialization_version=2, - ).hex(), - ] - rows = table.rows(row_keys) - - for row_key, row in rows: - for key, value in row.items(): - col_name = bytes.decode(key, "utf-8").split(":")[1] - print(col_name, value) - print() + pool = ConnectionPool( + host="localhost", + port=9090, + size=2, + ) + with pool.connection() as connection: + table = connection.table("test_hbase_driver_hourly_stats") + row_keys = [ + serialize_entity_key( + EntityKey( + join_keys=["driver_id"], entity_values=[Value(int64_val=1004)] + ), + entity_key_serialization_version=2, + ).hex(), + serialize_entity_key( + EntityKey( + join_keys=["driver_id"], entity_values=[Value(int64_val=1005)] + ), + entity_key_serialization_version=2, + ).hex(), + serialize_entity_key( + EntityKey( + join_keys=["driver_id"], entity_values=[Value(int64_val=1024)] + ), + entity_key_serialization_version=2, + ).hex(), + ] + rows = table.rows(row_keys) + + for _, row in rows: + for key, value in row.items(): + col_name = bytes.decode(key, "utf-8").split(":")[1] + print(col_name, value) + print() if __name__ == "__main__":