From dff6e5ad635db50c7be47e308d0584862c6cca42 Mon Sep 17 00:00:00 2001 From: Oleksii Moskalenko Date: Mon, 16 May 2022 15:25:54 -0700 Subject: [PATCH] random port allocation Signed-off-by: Oleksii Moskalenko --- sdk/python/tests/conftest.py | 27 ++++++++++++------- .../feature_repos/repo_configuration.py | 20 -------------- .../online_store/test_universal_online.py | 15 +++++++++-- 3 files changed, 31 insertions(+), 31 deletions(-) diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py index 092f5a0787..627fda524d 100644 --- a/sdk/python/tests/conftest.py +++ b/sdk/python/tests/conftest.py @@ -273,12 +273,19 @@ def pytest_generate_tests(metafunc: pytest.Metafunc): @pytest.fixture(scope="session") -def python_server(environment): - assert not _check_port_open("localhost", environment.get_local_server_port()) +def feature_server_endpoint(environment): + if ( + not environment.python_feature_server + or environment.test_repo_config.provider != "local" + ): + yield environment.feature_store.get_feature_server_endpoint() + return + + port = _free_port() proc = Process( target=start_test_local_server, - args=(environment.feature_store.repo_path, environment.get_local_server_port()), + args=(environment.feature_store.repo_path, port), ) if ( environment.python_feature_server @@ -287,14 +294,10 @@ def python_server(environment): proc.start() # Wait for server to start wait_retry_backoff( - lambda: ( - None, - _check_port_open("localhost", environment.get_local_server_port()), - ), - timeout_secs=10, + lambda: (None, _check_port_open("localhost", port)), timeout_secs=10, ) - yield + yield f"http://localhost:{port}" if proc.is_alive(): proc.kill() @@ -314,6 +317,12 @@ def _check_port_open(host, port) -> bool: return sock.connect_ex((host, port)) == 0 +def _free_port(): + sock = socket.socket() + sock.bind(("", 0)) + return sock.getsockname()[1] + + @pytest.fixture(scope="session") def universal_data_sources(environment) -> TestData: return construct_universal_test_data(environment) diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py index 27cf1a52e9..f4c9bed92a 100644 --- a/sdk/python/tests/integration/feature_repos/repo_configuration.py +++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py @@ -2,7 +2,6 @@ import importlib import json import os -import re import tempfile import uuid from dataclasses import dataclass @@ -328,29 +327,10 @@ class Environment: worker_id: str online_store_creator: Optional[OnlineStoreCreator] = None - next_id = 0 - def __post_init__(self): self.end_date = datetime.utcnow().replace(microsecond=0, second=0, minute=0) self.start_date: datetime = self.end_date - timedelta(days=3) - Environment.next_id += 1 - self.id = Environment.next_id - - def get_feature_server_endpoint(self) -> str: - if self.python_feature_server and self.test_repo_config.provider == "local": - return f"http://localhost:{self.get_local_server_port()}" - return self.feature_store.get_feature_server_endpoint() - - def get_local_server_port(self) -> int: - # Heuristic when running with xdist to extract unique ports for each worker - parsed_worker_id = re.findall("gw(\\d+)", self.worker_id) - if len(parsed_worker_id) != 0: - worker_id_num = int(parsed_worker_id[0]) - else: - worker_id_num = 0 - return 6000 + 100 * worker_id_num + self.id - def table_name_from_data_source(ds: DataSource) -> Optional[str]: if hasattr(ds, "table_ref"): diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index 259a094426..b3115dcb3d 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -288,6 +288,7 @@ def _get_online_features_dict_remotely( def get_online_features_dict( environment: Environment, + endpoint: str, features: Union[List[str], FeatureService], entity_rows: List[Dict[str, Any]], full_feature_names: bool = False, @@ -305,7 +306,6 @@ def get_online_features_dict( assertpy.assert_that(online_features).is_not_none() dict1 = online_features.to_dict() - endpoint = environment.get_feature_server_endpoint() # If endpoint is None, it means that a local / remote feature server aren't configured if endpoint is not None: dict2 = _get_online_features_dict_remotely( @@ -447,7 +447,7 @@ def test_online_retrieval_with_event_timestamps( @pytest.mark.goserver @pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) def test_online_retrieval( - environment, universal_data_sources, python_server, full_feature_names + environment, universal_data_sources, feature_server_endpoint, full_feature_names ): fs = environment.feature_store entities, datasets, data_sources = universal_data_sources @@ -547,6 +547,7 @@ def test_online_retrieval( online_features_dict = get_online_features_dict( environment=environment, + endpoint=feature_server_endpoint, features=feature_refs, entity_rows=entity_rows, full_feature_names=full_feature_names, @@ -556,6 +557,7 @@ def test_online_retrieval( # feature isn't requested. online_features_no_conv_rate = get_online_features_dict( environment=environment, + endpoint=feature_server_endpoint, features=[ref for ref in feature_refs if ref != "driver_stats:conv_rate"], entity_rows=entity_rows, full_feature_names=full_feature_names, @@ -616,6 +618,7 @@ def test_online_retrieval( # Check what happens for missing values missing_responses_dict = get_online_features_dict( environment=environment, + endpoint=feature_server_endpoint, features=feature_refs, entity_rows=[{"driver_id": 0, "customer_id": 0, "val_to_add": 100}], full_feature_names=full_feature_names, @@ -635,6 +638,7 @@ def test_online_retrieval( with pytest.raises(RequestDataNotFoundInEntityRowsException): get_online_features_dict( environment=environment, + endpoint=feature_server_endpoint, features=feature_refs, entity_rows=[{"driver_id": 0, "customer_id": 0}], full_feature_names=full_feature_names, @@ -642,6 +646,7 @@ def test_online_retrieval( assert_feature_service_correctness( environment, + feature_server_endpoint, feature_service, entity_rows, full_feature_names, @@ -659,6 +664,7 @@ def test_online_retrieval( ] assert_feature_service_entity_mapping_correctness( environment, + feature_server_endpoint, feature_service_entity_mapping, entity_rows, full_feature_names, @@ -856,6 +862,7 @@ def get_latest_feature_values_for_location_df(entity_row, origin_df, destination def assert_feature_service_correctness( environment, + endpoint, feature_service, entity_rows, full_feature_names, @@ -866,6 +873,7 @@ def assert_feature_service_correctness( ): feature_service_online_features_dict = get_online_features_dict( environment=environment, + endpoint=endpoint, features=feature_service, entity_rows=entity_rows, full_feature_names=full_feature_names, @@ -905,6 +913,7 @@ def assert_feature_service_correctness( def assert_feature_service_entity_mapping_correctness( environment, + endpoint, feature_service, entity_rows, full_feature_names, @@ -914,6 +923,7 @@ def assert_feature_service_entity_mapping_correctness( if full_feature_names: feature_service_online_features_dict = get_online_features_dict( environment=environment, + endpoint=endpoint, features=feature_service, entity_rows=entity_rows, full_feature_names=full_feature_names, @@ -948,6 +958,7 @@ def assert_feature_service_entity_mapping_correctness( with pytest.raises(FeatureNameCollisionError): get_online_features_dict( environment=environment, + endpoint=endpoint, features=feature_service, entity_rows=entity_rows, full_feature_names=full_feature_names,