diff --git a/clients/client-python/build.gradle.kts b/clients/client-python/build.gradle.kts
index 3d77039d60e..625e28c1468 100644
--- a/clients/client-python/build.gradle.kts
+++ b/clients/client-python/build.gradle.kts
@@ -223,6 +223,7 @@ tasks {
         "START_EXTERNAL_GRAVITINO" to "true",
         "DOCKER_TEST" to dockerTest.toString(),
         "GRAVITINO_CI_HIVE_DOCKER_IMAGE" to "apache/gravitino-ci:hive-0.1.13",
+        "GRAVITINO_OAUTH2_SAMPLE_SERVER" to "datastrato/sample-authorization-server:0.3.0",
         // Set the PYTHONPATH to the client-python directory, make sure the tests can import the
         // modules from the client-python directory.
         "PYTHONPATH" to "${project.rootDir.path}/clients/client-python"
diff --git a/clients/client-python/gravitino/dto/responses/oauth2_token_response.py b/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
index 37071b72372..2b81cd54f35 100644
--- a/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
+++ b/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
@@ -44,8 +44,6 @@ def validate(self):
         Raise:
             IllegalArgumentException If the response is invalid, this exception is thrown.
         """
-        super().validate()
-
         if self._access_token is None:
             raise IllegalArgumentException("Invalid access token: None")

diff --git a/clients/client-python/gravitino/utils/http_client.py b/clients/client-python/gravitino/utils/http_client.py
index 678942bb4e7..696fe415cce 100644
--- a/clients/client-python/gravitino/utils/http_client.py
+++ b/clients/client-python/gravitino/utils/http_client.py
@@ -140,7 +140,7 @@ def _make_request(self, opener, request, timeout=None) -> Tuple[bool, Response]:
         except HTTPError as err:
             err_body = err.read()

-            if err_body is None:
+            if err_body is None or len(err_body) == 0:
                 return (
                     False,
                     ErrorResponse.generate_error_response(RESTException, err.reason),
diff --git a/clients/client-python/requirements-dev.txt b/clients/client-python/requirements-dev.txt
index e91d966a4cd..002e08964ca 100644
--- a/clients/client-python/requirements-dev.txt
+++ b/clients/client-python/requirements-dev.txt
@@ -29,3 +29,4 @@ cachetools==5.3.3
 readerwriterlock==1.0.9
 docker==7.1.0
 pyjwt[crypto]==2.8.0
+jwcrypto==1.5.6
diff --git a/clients/client-python/tests/integration/auth/__init__.py b/clients/client-python/tests/integration/auth/__init__.py
new file mode 100644
index 00000000000..c206137f175
--- /dev/null
+++ b/clients/client-python/tests/integration/auth/__init__.py
@@ -0,0 +1,18 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
diff --git a/clients/client-python/tests/integration/test_simple_auth_client.py b/clients/client-python/tests/integration/auth/test_auth_common.py
similarity index 86%
rename from clients/client-python/tests/integration/test_simple_auth_client.py
rename to clients/client-python/tests/integration/auth/test_auth_common.py
index 062e03e9dd4..57c377f788d 100644
--- a/clients/client-python/tests/integration/test_simple_auth_client.py
+++ b/clients/client-python/tests/integration/auth/test_auth_common.py
@@ -29,15 +29,17 @@
     Catalog,
     Fileset,
 )
-from gravitino.auth.simple_auth_provider import SimpleAuthProvider
 from gravitino.exceptions.base import GravitinoRuntimeException
-from tests.integration.integration_test_env import IntegrationTestEnv

 logger = logging.getLogger(__name__)


-class TestSimpleAuthClient(IntegrationTestEnv):
-    creator: str = "test_client"
+class TestCommonAuth:
+    """
+    A common test set for AuthProvider Integration Tests
+    """
+
+    creator: str = "test"
     metalake_name: str = "TestClient_metalake" + str(randint(1, 10000))
     catalog_name: str = "fileset_catalog"
     catalog_location_prop: str = "location"  # Fileset Catalog must set `location`
@@ -63,16 +65,8 @@ class TestSimpleAuthClient(IntegrationTestEnv):
         metalake_name, catalog_name, schema_name
     )
     fileset_ident: NameIdentifier = NameIdentifier.of(schema_name, fileset_name)
-
-    def setUp(self):
-        os.environ["GRAVITINO_USER"] = self.creator
-        self.gravitino_admin_client = GravitinoAdminClient(
-            uri="http://localhost:8090", auth_data_provider=SimpleAuthProvider()
-        )
-        self.init_test_env()
-
-    def tearDown(self):
-        self.clean_test_data()
+    gravitino_admin_client: GravitinoAdminClient
+    gravitino_client: GravitinoClient

     def clean_test_data(self):
         catalog = self.gravitino_client.load_catalog(name=self.catalog_name)
@@ -117,14 +111,7 @@ def clean_test_data(self):
         os.environ["GRAVITINO_USER"] = ""

     def init_test_env(self):
-        self.gravitino_admin_client.create_metalake(
-            self.metalake_name, comment="", properties={}
-        )
-        self.gravitino_client = GravitinoClient(
-            uri="http://localhost:8090",
-            metalake_name=self.metalake_name,
-            auth_data_provider=SimpleAuthProvider(),
-        )
+
         catalog = self.gravitino_client.create_catalog(
             name=self.catalog_name,
             catalog_type=Catalog.Type.FILESET,
diff --git a/clients/client-python/tests/integration/auth/test_oauth2_client.py b/clients/client-python/tests/integration/auth/test_oauth2_client.py
new file mode 100644
index 00000000000..3db7232b3a8
--- /dev/null
+++ b/clients/client-python/tests/integration/auth/test_oauth2_client.py
@@ -0,0 +1,178 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import os
+import subprocess
+import logging
+import unittest
+import sys
+import requests
+from jwcrypto import jwk
+
+from gravitino.auth.auth_constants import AuthConstants
+from gravitino.auth.default_oauth2_token_provider import DefaultOAuth2TokenProvider
+from gravitino import GravitinoAdminClient, GravitinoClient
+from gravitino.exceptions.base import GravitinoRuntimeException
+
+from tests.integration.auth.test_auth_common import TestCommonAuth
+from tests.integration.integration_test_env import (
+    IntegrationTestEnv,
+    check_gravitino_server_status,
+)
+from tests.integration.containers.oauth2_container import OAuth2Container
+
+logger = logging.getLogger(__name__)
+
+DOCKER_TEST = os.environ.get("DOCKER_TEST")
+
+
+@unittest.skipIf(
+    DOCKER_TEST == "false",
+    "Skipping tests when DOCKER_TEST=false",
+)
+class TestOAuth2(IntegrationTestEnv, TestCommonAuth):
+
+    oauth2_container: OAuth2Container = None
+
+    @classmethod
+    def setUpClass(cls):
+
+        cls._get_gravitino_home()
+
+        cls.oauth2_container = OAuth2Container()
+        cls.oauth2_container_ip = cls.oauth2_container.get_ip()
+
+        cls.oauth2_server_uri = f"http://{cls.oauth2_container_ip}:8177"
+
+        # Get PEM from OAuth Server
+        default_sign_key = cls._get_default_sign_key()
+
+        cls.config = {
+            "gravitino.authenticators": "oauth",
+            "gravitino.authenticator.oauth.serviceAudience": "test",
+            "gravitino.authenticator.oauth.defaultSignKey": default_sign_key,
+            "gravitino.authenticator.oauth.serverUri": cls.oauth2_server_uri,
+            "gravitino.authenticator.oauth.tokenPath": "/oauth2/token",
+        }
+
+        cls.oauth2_conf_path = f"{cls.gravitino_home}/conf/gravitino.conf"
+
+        # append the oauth2 conf to server
+        cls._append_conf(cls.config, cls.oauth2_conf_path)
+        # restart the server
+        cls._restart_server_with_oauth()
+
+    @classmethod
+    def tearDownClass(cls):
+        try:
+            # reset server conf
+            cls._reset_conf(cls.config, cls.oauth2_conf_path)
+            # restart server
+            cls.restart_server()
+        finally:
+            # close oauth2 container
+            cls.oauth2_container.close()
+
+    @classmethod
+    def _get_default_sign_key(cls) -> str:
+
+        jwk_uri = f"{cls.oauth2_server_uri}/oauth2/jwks"
+
+        # Get JWK from OAuth2 Server
+        res = requests.get(jwk_uri).json()
+        key = res["keys"][0]
+
+        # Convert JWK to PEM
+        pem = jwk.JWK(**key).export_to_pem().decode("utf-8")
+
+        default_sign_key = "".join(pem.split("\n")[1:-2])
+
+        return default_sign_key
+
+    @classmethod
+    def _restart_server_with_oauth(cls):
+        logger.info("Restarting Gravitino server...")
+        gravitino_home = os.environ.get("GRAVITINO_HOME")
+        gravitino_startup_script = os.path.join(gravitino_home, "bin/gravitino.sh")
+        if not os.path.exists(gravitino_startup_script):
+            raise GravitinoRuntimeException(
+                f"Can't find Gravitino startup script: {gravitino_startup_script}, "
+                "Please execute `./gradlew compileDistribution -x test` in the Gravitino "
+                "project root directory."
+            )
+
+        # Restart Gravitino Server
+        env_vars = os.environ.copy()
+        result = subprocess.run(
+            [gravitino_startup_script, "restart"],
+            env=env_vars,
+            capture_output=True,
+            text=True,
+            check=False,
+        )
+        if result.stdout:
+            logger.info("stdout: %s", result.stdout)
+        if result.stderr:
+            logger.info("stderr: %s", result.stderr)
+
+        oauth2_token_provider = DefaultOAuth2TokenProvider(
+            f"{cls.oauth2_server_uri}", "test:test", "test", "oauth2/token"
+        )
+
+        auth_header = {
+            AuthConstants.HTTP_HEADER_AUTHORIZATION: oauth2_token_provider.get_token_data().decode(
+                "utf-8"
+            )
+        }
+
+        if not check_gravitino_server_status(headers=auth_header):
+            logger.error("ERROR: Can't start Gravitino server!")
+            sys.exit(0)
+
+    def setUp(self):
+        oauth2_token_provider = DefaultOAuth2TokenProvider(
+            f"{self.oauth2_server_uri}", "test:test", "test", "oauth2/token"
+        )
+
+        self.gravitino_admin_client = GravitinoAdminClient(
+            uri="http://localhost:8090", auth_data_provider=oauth2_token_provider
+        )
+
+        self.init_test_env()
+
+    def init_test_env(self):
+
+        self.gravitino_admin_client.create_metalake(
+            self.metalake_name, comment="", properties={}
+        )
+
+        oauth2_token_provider = DefaultOAuth2TokenProvider(
+            f"{self.oauth2_server_uri}", "test:test", "test", "oauth2/token"
+        )
+
+        self.gravitino_client = GravitinoClient(
+            uri="http://localhost:8090",
+            metalake_name=self.metalake_name,
+            auth_data_provider=oauth2_token_provider,
+        )
+
+        super().init_test_env()
+
+    def tearDown(self):
+        self.clean_test_data()
diff --git a/clients/client-python/tests/integration/auth/test_simple_auth_client.py b/clients/client-python/tests/integration/auth/test_simple_auth_client.py
new file mode 100644
index 00000000000..6533e49a7a4
--- /dev/null
+++ b/clients/client-python/tests/integration/auth/test_simple_auth_client.py
@@ -0,0 +1,58 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import logging
+import os
+
+from gravitino import (
+    GravitinoClient,
+    GravitinoAdminClient,
+)
+from gravitino.auth.simple_auth_provider import SimpleAuthProvider
+
+from tests.integration.auth.test_auth_common import TestCommonAuth
+from tests.integration.integration_test_env import IntegrationTestEnv
+
+logger = logging.getLogger(__name__)
+
+
+class TestSimpleAuthClient(IntegrationTestEnv, TestCommonAuth):
+
+    def setUp(self):
+        os.environ["GRAVITINO_USER"] = self.creator
+        self.gravitino_admin_client = GravitinoAdminClient(
+            uri="http://localhost:8090", auth_data_provider=SimpleAuthProvider()
+        )
+
+        self.init_test_env()
+
+    def init_test_env(self):
+        self.gravitino_admin_client.create_metalake(
+            self.metalake_name, comment="", properties={}
+        )
+        self.gravitino_client = GravitinoClient(
+            uri="http://localhost:8090",
+            metalake_name=self.metalake_name,
+            auth_data_provider=SimpleAuthProvider(),
+        )
+
+        super().init_test_env()
+
+    def tearDown(self):
+        self.clean_test_data()
diff --git a/clients/client-python/tests/integration/hdfs_container.py b/clients/client-python/tests/integration/containers/base_container.py
similarity index 59%
rename from clients/client-python/tests/integration/hdfs_container.py
rename to clients/client-python/tests/integration/containers/base_container.py
index 16cb2a80cc6..83d03b277cd 100644
--- a/clients/client-python/tests/integration/hdfs_container.py
+++ b/clients/client-python/tests/integration/containers/base_container.py
@@ -17,72 +17,33 @@
 under the License.
 """

-import asyncio
 import logging
-import os
-import time
+from typing import Dict

 import docker
 from docker import types as tp
-from docker.errors import NotFound, DockerException
+from docker.errors import NotFound

 from gravitino.exceptions.base import GravitinoRuntimeException
-from gravitino.exceptions.base import InternalError

 logger = logging.getLogger(__name__)


-async def check_hdfs_status(hdfs_container):
-    retry_limit = 15
-    for _ in range(retry_limit):
-        try:
-            command_and_args = ["bash", "/tmp/check-status.sh"]
-            exec_result = hdfs_container.exec_run(command_and_args)
-            if exec_result.exit_code != 0:
-                message = (
-                    f"Command {command_and_args} exited with {exec_result.exit_code}"
-                )
-                logger.warning(message)
-                logger.warning("output: %s", exec_result.output)
-                output_status_command = ["hdfs", "dfsadmin", "-report"]
-                exec_result = hdfs_container.exec_run(output_status_command)
-                logger.info("HDFS report, output: %s", exec_result.output)
-            else:
-                logger.info("HDFS startup successfully!")
-                return True
-        except DockerException as e:
-            logger.error(
-                "Exception occurred while checking HDFS container status: %s", e
-            )
-        time.sleep(10)
-    return False
-
-
-async def check_hdfs_container_status(hdfs_container):
-    timeout_sec = 150
-    try:
-        result = await asyncio.wait_for(
-            check_hdfs_status(hdfs_container), timeout=timeout_sec
-        )
-        if not result:
-            raise InternalError("HDFS container startup failed!")
-    except asyncio.TimeoutError as e:
-        raise GravitinoRuntimeException(
-            "Timeout occurred while waiting for checking HDFS container status."
-        ) from e
-
-
-class HDFSContainer:
+class BaseContainer:
     _docker_client = None
     _container = None
     _network = None
     _ip = ""
     _network_name = "python-net"
-    _container_name = "python-hdfs"
+    _container_name: str

-    def __init__(self):
+    def __init__(
+        self, container_name: str, image_name: str, environment: Dict = None, **kwargs
+    ):
+        self._container_name = container_name
         self._docker_client = docker.from_env()
         self._create_networks()
+
         try:
             container = self._docker_client.containers.get(self._container_name)
             if container is not None:
@@ -90,23 +51,19 @@ def __init__(self):
                 container.restart()
             self._container = container
         except NotFound:
-            logger.warning("Cannot find hdfs container in docker env, skip remove.")
+            logger.warning(
+                "Cannot find the container %s in docker env, skip remove.",
+                self._container_name,
+            )

         if self._container is None:
-            image_name = os.environ.get("GRAVITINO_CI_HIVE_DOCKER_IMAGE")
-            if image_name is None:
-                raise GravitinoRuntimeException(
-                    "GRAVITINO_CI_HIVE_DOCKER_IMAGE env variable is not set."
-                )
             self._container = self._docker_client.containers.run(
                 image=image_name,
                 name=self._container_name,
                 detach=True,
-                environment={"HADOOP_USER_NAME": "anonymous"},
+                environment=environment,
                 network=self._network_name,
+                **kwargs,
             )
-        asyncio.run(check_hdfs_container_status(self._container))
-
-        self._fetch_ip()

     def _create_networks(self):
@@ -123,7 +80,9 @@ def _create_networks(self):

     def _fetch_ip(self):
         if self._container is None:
-            raise GravitinoRuntimeException("The HDFS container has not init.")
+            raise GravitinoRuntimeException(
+                f"The container {self._container_name} has not init."
+            )

         container_info = self._docker_client.api.inspect_container(self._container.id)
         self._ip = container_info["NetworkSettings"]["Networks"][self._network_name][
diff --git a/clients/client-python/tests/integration/containers/hdfs_container.py b/clients/client-python/tests/integration/containers/hdfs_container.py
new file mode 100644
index 00000000000..8676f673627
--- /dev/null
+++ b/clients/client-python/tests/integration/containers/hdfs_container.py
@@ -0,0 +1,88 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import asyncio
+import logging
+import os
+import time
+
+from docker.errors import DockerException
+from gravitino.exceptions.base import GravitinoRuntimeException
+from gravitino.exceptions.base import InternalError
+
+from tests.integration.containers.base_container import BaseContainer
+
+logger = logging.getLogger(__name__)
+
+
+async def check_hdfs_status(hdfs_container):
+    retry_limit = 15
+    for _ in range(retry_limit):
+        try:
+            command_and_args = ["bash", "/tmp/check-status.sh"]
+            exec_result = hdfs_container.exec_run(command_and_args)
+            if exec_result.exit_code != 0:
+                message = (
+                    f"Command {command_and_args} exited with {exec_result.exit_code}"
+                )
+                logger.warning(message)
+                logger.warning("output: %s", exec_result.output)
+                output_status_command = ["hdfs", "dfsadmin", "-report"]
+                exec_result = hdfs_container.exec_run(output_status_command)
+                logger.info("HDFS report, output: %s", exec_result.output)
+            else:
+                logger.info("HDFS startup successfully!")
+                return True
+        except DockerException as e:
+            logger.error(
+                "Exception occurred while checking HDFS container status: %s", e
+            )
+        time.sleep(10)
+    return False
+
+
+async def check_hdfs_container_status(hdfs_container):
+    timeout_sec = 150
+    try:
+        result = await asyncio.wait_for(
+            check_hdfs_status(hdfs_container), timeout=timeout_sec
+        )
+        if not result:
+            raise InternalError("HDFS container startup failed!")
+    except asyncio.TimeoutError as e:
+        raise GravitinoRuntimeException(
+            "Timeout occurred while waiting for checking HDFS container status."
+        ) from e
+
+
+class HDFSContainer(BaseContainer):
+
+    def __init__(self):
+        container_name = "python-hdfs"
+        image_name = os.environ.get("GRAVITINO_CI_HIVE_DOCKER_IMAGE")
+        if image_name is None:
+            raise GravitinoRuntimeException(
+                "GRAVITINO_CI_HIVE_DOCKER_IMAGE env variable is not set."
+            )
+        environment = {"HADOOP_USER_NAME": "anonymous"}
+
+        super().__init__(container_name, image_name, environment)
+
+        asyncio.run(check_hdfs_container_status(self._container))
+        self._fetch_ip()
diff --git a/clients/client-python/tests/integration/containers/oauth2_container.py b/clients/client-python/tests/integration/containers/oauth2_container.py
new file mode 100644
index 00000000000..763e23619fe
--- /dev/null
+++ b/clients/client-python/tests/integration/containers/oauth2_container.py
@@ -0,0 +1,69 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import asyncio
+import os
+import time
+
+from gravitino.exceptions.base import GravitinoRuntimeException
+
+from tests.integration.containers.base_container import BaseContainer
+
+TIMEOUT_SEC = 5
+RETRY_LIMIT = 30
+
+
+async def check_oauth2_container_status(oauth2_container: "OAuth2Container"):
+    for _ in range(RETRY_LIMIT):
+        if oauth2_container.health_check():
+            return True
+        time.sleep(TIMEOUT_SEC)
+    return False
+
+
+class OAuth2Container(BaseContainer):
+
+    def __init__(self):
+        container_name = "sample-auth-server"
+        image_name = os.environ.get("GRAVITINO_OAUTH2_SAMPLE_SERVER")
+        if image_name is None:
+            raise GravitinoRuntimeException(
+                "GRAVITINO_OAUTH2_SAMPLE_SERVER env variable is not set."
+            )
+
+        healthcheck = {
+            "test": [
+                "CMD-SHELL",
+                "wget -qO - http://localhost:8177/oauth2/jwks || exit 1",
+            ],
+            "interval": TIMEOUT_SEC * 1000000000,
+            "retries": RETRY_LIMIT,
+        }
+
+        super().__init__(container_name, image_name, healthcheck=healthcheck)
+        asyncio.run(check_oauth2_container_status(self))
+        self._fetch_ip()
+
+    def health_check(self) -> bool:
+        return (
+            self._docker_client.api.inspect_container(self._container_name)["State"][
+                "Health"
+            ]["Status"]
+            == "healthy"
+        )
diff --git a/clients/client-python/tests/integration/integration_test_env.py b/clients/client-python/tests/integration/integration_test_env.py
index 7b8c05f538b..4263d5af436 100644
--- a/clients/client-python/tests/integration/integration_test_env.py
+++ b/clients/client-python/tests/integration/integration_test_env.py
@@ -31,9 +31,9 @@
 logger = logging.getLogger(__name__)


-def get_gravitino_server_version():
+def get_gravitino_server_version(**kwargs):
     try:
-        response = requests.get("http://localhost:8090/api/version")
+        response = requests.get("http://localhost:8090/api/version", **kwargs)
         response.raise_for_status()  # raise an exception for bad status codes
         response.close()
         return True
@@ -42,11 +42,11 @@ def get_gravitino_server_version():
         return False


-def check_gravitino_server_status() -> bool:
+def check_gravitino_server_status(**kwargs) -> bool:
     gravitino_server_running = False
     for i in range(5):
         logger.info("Monitoring Gravitino server status. Attempt %s", i + 1)
-        if get_gravitino_server_version():
+        if get_gravitino_server_version(**kwargs):
             logger.debug("Gravitino Server is running")
             gravitino_server_running = True
             break
@@ -69,14 +69,10 @@ def setUpClass(cls):
             logger.error("ERROR: Can't find online Gravitino server!")
             return

-        gravitino_home = os.environ.get("GRAVITINO_HOME")
-        if gravitino_home is None:
-            logger.error(
-                "Gravitino Python client integration test must configure `GRAVITINO_HOME`"
-            )
-            sys.exit(0)
-
-        cls.gravitino_startup_script = os.path.join(gravitino_home, "bin/gravitino.sh")
+        cls._get_gravitino_home()
+        cls.gravitino_startup_script = os.path.join(
+            cls.gravitino_home, "bin/gravitino.sh"
+        )
         if not os.path.exists(cls.gravitino_startup_script):
             logger.error(
                 "Can't find Gravitino startup script: %s, "
@@ -166,37 +162,35 @@ def restart_server(cls):
         raise GravitinoRuntimeException("ERROR: Can't start Gravitino server!")

     @classmethod
-    def _append_catalog_hadoop_conf(cls, config):
-        logger.info("Append catalog hadoop conf.")
+    def _get_gravitino_home(cls):
         gravitino_home = os.environ.get("GRAVITINO_HOME")
         if gravitino_home is None:
-            raise GravitinoRuntimeException("Cannot find GRAVITINO_HOME env.")
-        hadoop_conf_path = f"{gravitino_home}/catalogs/hadoop/conf/hadoop.conf"
-        if not os.path.exists(hadoop_conf_path):
-            raise GravitinoRuntimeException(
-                f"Hadoop conf file is not found at `{hadoop_conf_path}`."
+            logger.error(
+                "Gravitino Python client integration test must configure `GRAVITINO_HOME`"
             )
+            sys.exit(0)
+
+        cls.gravitino_home = gravitino_home

-        with open(hadoop_conf_path, mode="a", encoding="utf-8") as f:
+    @classmethod
+    def _append_conf(cls, config, conf_path):
+        logger.info("Append %s.", conf_path)
+        if not os.path.exists(conf_path):
+            raise GravitinoRuntimeException(f"Conf file is not found at `{conf_path}`.")
+
+        with open(conf_path, mode="a", encoding="utf-8") as f:
             for key, value in config.items():
                 f.write(f"\n{key} = {value}")

     @classmethod
-    def _reset_catalog_hadoop_conf(cls, config):
-        logger.info("Reset catalog hadoop conf.")
-        gravitino_home = os.environ.get("GRAVITINO_HOME")
-        if gravitino_home is None:
-            raise GravitinoRuntimeException("Cannot find GRAVITINO_HOME env.")
-        hadoop_conf_path = f"{gravitino_home}/catalogs/hadoop/conf/hadoop.conf"
-        if not os.path.exists(hadoop_conf_path):
-            raise GravitinoRuntimeException(
-                f"Hadoop conf file is not found at `{hadoop_conf_path}`."
-            )
+    def _reset_conf(cls, config, conf_path):
+        logger.info("Reset %s.", conf_path)
+        if not os.path.exists(conf_path):
+            raise GravitinoRuntimeException(f"Conf file is not found at `{conf_path}`.")

         filtered_lines = []
-        with open(hadoop_conf_path, mode="r", encoding="utf-8") as file:
+        with open(conf_path, mode="r", encoding="utf-8") as file:
             origin_lines = file.readlines()
-            existed_config = {}
             for line in origin_lines:
                 line = line.strip()
                 if line.startswith("#"):
@@ -205,16 +199,16 @@ def _reset_catalog_hadoop_conf(cls, config):
                 else:
                     try:
                         key, value = line.split("=")
-                        existed_config[key.strip()] = value.strip()
+                        key = key.strip()
+                        value = value.strip()
+                        if key not in config:
+                            append_line = f"{key} = {value}\n"
+                            filtered_lines.append(append_line)
+
                     except ValueError:
                         # cannot split to key, value, so just append
                         filtered_lines.append(line + "\n")

-        for key, value in existed_config.items():
-            if config[key] is None:
-                append_line = f"{key} = {value}\n"
-                filtered_lines.append(append_line)
-
-        with open(hadoop_conf_path, mode="w", encoding="utf-8") as file:
+        with open(conf_path, mode="w", encoding="utf-8") as file:
             for line in filtered_lines:
                 file.write(line)
diff --git a/clients/client-python/tests/integration/test_gvfs_with_hdfs.py b/clients/client-python/tests/integration/test_gvfs_with_hdfs.py
index 93682fa575e..af73b354d7b 100644
--- a/clients/client-python/tests/integration/test_gvfs_with_hdfs.py
+++ b/clients/client-python/tests/integration/test_gvfs_with_hdfs.py
@@ -46,7 +46,7 @@
 from gravitino.auth.auth_constants import AuthConstants
 from gravitino.exceptions.base import GravitinoRuntimeException
 from tests.integration.integration_test_env import IntegrationTestEnv
-from tests.integration.hdfs_container import HDFSContainer
+from tests.integration.containers.hdfs_container import HDFSContainer
 from tests.integration.base_hadoop_env import BaseHadoopEnvironment

 logger = logging.getLogger(__name__)
@@ -94,6 +94,9 @@ class TestGvfsWithHDFS(IntegrationTestEnv):

     @classmethod
     def setUpClass(cls):
+
+        cls._get_gravitino_home()
+
         cls.hdfs_container = HDFSContainer()
         hdfs_container_ip = cls.hdfs_container.get_ip()
         # init hadoop env
@@ -101,8 +104,11 @@ def setUpClass(cls):
         cls.config = {
             "gravitino.bypass.fs.defaultFS": f"hdfs://{hdfs_container_ip}:9000"
         }
+
+        cls.hadoop_conf_path = f"{cls.gravitino_home}/catalogs/hadoop/conf/hadoop.conf"
+
         # append the hadoop conf to server
-        cls._append_catalog_hadoop_conf(cls.config)
+        cls._append_conf(cls.config, cls.hadoop_conf_path)
         # restart the server
         cls.restart_server()
         # create entity
@@ -113,7 +119,7 @@ def tearDownClass(cls):
         try:
             cls._clean_test_data()
             # reset server conf
-            cls._reset_catalog_hadoop_conf(cls.config)
+            cls._reset_conf(cls.config, cls.hadoop_conf_path)
             # restart server
             cls.restart_server()
             # clear hadoop env