From 9436f0ff672a5525ca215abf7d9d020f9b6c07c7 Mon Sep 17 00:00:00 2001 From: Roman Isecke <136338424+rbiseck3@users.noreply.github.com> Date: Fri, 10 Jan 2025 11:03:31 -0500 Subject: [PATCH] feat/add error handling to databricks volume connector (#328) * add error handling to databricks volume connector * add error wrapping to uploader * make sure os is cleared when running pat int tests * remove use of old error handling * bump changelog --- CHANGELOG.md | 2 +- .../databricks/test_volumes_native.py | 43 +++++++ unstructured_ingest/__version__.py | 2 +- .../connectors/databricks/volumes.py | 114 +++++++++++------- 4 files changed, 116 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fcd6c1d07..5af7f05e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## 0.3.13-dev2 +## 0.3.13-dev3 ### Fixes diff --git a/test/integration/connectors/databricks/test_volumes_native.py b/test/integration/connectors/databricks/test_volumes_native.py index 82e78c9e4..0be6d0d4d 100644 --- a/test/integration/connectors/databricks/test_volumes_native.py +++ b/test/integration/connectors/databricks/test_volumes_native.py @@ -16,6 +16,7 @@ source_connector_validation, ) from test.integration.utils import requires_env +from unstructured_ingest.v2.errors import UserAuthError, UserError from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers from unstructured_ingest.v2.processes.connectors.databricks.volumes_native import ( CONNECTOR_TYPE, @@ -143,6 +144,48 @@ async def test_volumes_native_source_pat(tmp_path: Path): ) +@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG) +@requires_env("DATABRICKS_HOST", "DATABRICKS_PAT", "DATABRICKS_CATALOG") +def test_volumes_native_source_pat_invalid_catalog(): + env_data = get_pat_env_data() + with mock.patch.dict(os.environ, clear=True): + indexer_config = DatabricksNativeVolumesIndexerConfig( + recursive=True, + volume="test-platform", + volume_path="databricks-volumes-test-input", + catalog="fake_catalog", + ) + indexer = DatabricksNativeVolumesIndexer( + connection_config=env_data.get_connection_config(), index_config=indexer_config + ) + with pytest.raises(UserError): + _ = list(indexer.run()) + + +@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG) +@requires_env("DATABRICKS_HOST") +def test_volumes_native_source_pat_invalid_pat(): + host = os.environ["DATABRICKS_HOST"] + with mock.patch.dict(os.environ, clear=True): + indexer_config = DatabricksNativeVolumesIndexerConfig( + recursive=True, + volume="test-platform", + volume_path="databricks-volumes-test-input", + catalog="fake_catalog", + ) + connection_config = DatabricksNativeVolumesConnectionConfig( + host=host, + access_config=DatabricksNativeVolumesAccessConfig( + token="invalid-token", + ), + ) + indexer = DatabricksNativeVolumesIndexer( + connection_config=connection_config, index_config=indexer_config + ) + with pytest.raises(UserAuthError): + _ = list(indexer.run()) + + def _get_volume_path(catalog: str, volume: str, volume_path: str): return f"/Volumes/{catalog}/default/{volume}/{volume_path}" diff --git a/unstructured_ingest/__version__.py b/unstructured_ingest/__version__.py index dff5c63d1..f53758827 100644 --- a/unstructured_ingest/__version__.py +++ b/unstructured_ingest/__version__.py @@ -1 +1 @@ -__version__ = "0.3.13-dev2" # pragma: no cover +__version__ = "0.3.13-dev3" # pragma: no cover diff --git a/unstructured_ingest/v2/processes/connectors/databricks/volumes.py b/unstructured_ingest/v2/processes/connectors/databricks/volumes.py index a4d5326ec..caa4a272c 100644 --- a/unstructured_ingest/v2/processes/connectors/databricks/volumes.py +++ b/unstructured_ingest/v2/processes/connectors/databricks/volumes.py @@ -7,12 +7,13 @@ from pydantic import BaseModel, Field -from unstructured_ingest.error import ( - DestinationConnectionError, - SourceConnectionError, - SourceConnectionNetworkError, -) from unstructured_ingest.utils.dep_check import requires_dependencies +from unstructured_ingest.v2.errors import ( + ProviderError, + RateLimitError, + UserAuthError, + UserError, +) from unstructured_ingest.v2.interfaces import ( AccessConfig, ConnectionConfig, @@ -65,6 +66,29 @@ class DatabricksVolumesConnectionConfig(ConnectionConfig, ABC): "Databricks accounts endpoint.", ) + def wrap_error(self, e: Exception) -> Exception: + from databricks.sdk.errors.base import DatabricksError + from databricks.sdk.errors.platform import STATUS_CODE_MAPPING + + if isinstance(e, ValueError): + error_message = e.args[0] + message_split = error_message.split(":") + if message_split[0].endswith("auth"): + return UserAuthError(e) + if isinstance(e, DatabricksError): + reverse_mapping = {v: k for k, v in STATUS_CODE_MAPPING.items()} + if status_code := reverse_mapping.get(type(e)): + if status_code in [401, 403]: + return UserAuthError(e) + if status_code == 429: + return RateLimitError(e) + if 400 <= status_code < 500: + return UserError(e) + if 500 <= status_code < 600: + return ProviderError(e) + logger.error(f"unhandled exception from databricks: {e}", exc_info=True) + return e + @requires_dependencies(dependencies=["databricks.sdk"], extras="databricks-volumes") def get_client(self) -> "WorkspaceClient": from databricks.sdk import WorkspaceClient @@ -88,32 +112,37 @@ def precheck(self) -> None: try: self.connection_config.get_client() except Exception as e: - logger.error(f"failed to validate connection: {e}", exc_info=True) - raise SourceConnectionError(f"failed to validate connection: {e}") + raise self.connection_config.wrap_error(e=e) from e def run(self, **kwargs: Any) -> Generator[FileData, None, None]: - for file_info in self.connection_config.get_client().dbfs.list( - path=self.index_config.path, recursive=self.index_config.recursive - ): - if file_info.is_dir: - continue - rel_path = file_info.path.replace(self.index_config.path, "") - if rel_path.startswith("/"): - rel_path = rel_path[1:] - filename = Path(file_info.path).name - yield FileData( - identifier=str(uuid5(NAMESPACE_DNS, file_info.path)), - connector_type=self.connector_type, - source_identifiers=SourceIdentifiers( - filename=filename, - rel_path=rel_path, - fullpath=file_info.path, - ), - additional_metadata={"catalog": self.index_config.catalog, "path": file_info.path}, - metadata=FileDataSourceMetadata( - url=file_info.path, date_modified=str(file_info.modification_time) - ), - ) + try: + for file_info in self.connection_config.get_client().dbfs.list( + path=self.index_config.path, recursive=self.index_config.recursive + ): + if file_info.is_dir: + continue + rel_path = file_info.path.replace(self.index_config.path, "") + if rel_path.startswith("/"): + rel_path = rel_path[1:] + filename = Path(file_info.path).name + yield FileData( + identifier=str(uuid5(NAMESPACE_DNS, file_info.path)), + connector_type=self.connector_type, + source_identifiers=SourceIdentifiers( + filename=filename, + rel_path=rel_path, + fullpath=file_info.path, + ), + additional_metadata={ + "catalog": self.index_config.catalog, + "path": file_info.path, + }, + metadata=FileDataSourceMetadata( + url=file_info.path, date_modified=str(file_info.modification_time) + ), + ) + except Exception as e: + raise self.connection_config.wrap_error(e=e) class DatabricksVolumesDownloaderConfig(DownloaderConfig): @@ -129,8 +158,7 @@ def precheck(self) -> None: try: self.connection_config.get_client() except Exception as e: - logger.error(f"failed to validate connection: {e}", exc_info=True) - raise SourceConnectionError(f"failed to validate connection: {e}") + raise self.connection_config.wrap_error(e=e) def get_download_path(self, file_data: FileData) -> Path: return self.download_config.download_dir / Path(file_data.source_identifiers.relative_path) @@ -143,12 +171,10 @@ def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse: try: with self.connection_config.get_client().dbfs.download(path=volumes_path) as c: read_content = c._read_handle.read() - with open(download_path, "wb") as f: - f.write(read_content) except Exception as e: - logger.error(f"failed to download file {file_data.identifier}: {e}", exc_info=True) - raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}") - + raise self.connection_config.wrap_error(e=e) + with open(download_path, "wb") as f: + f.write(read_content) return self.generate_download_response(file_data=file_data, download_path=download_path) @@ -165,16 +191,18 @@ def precheck(self) -> None: try: assert self.connection_config.get_client().current_user.me().active except Exception as e: - logger.error(f"failed to validate connection: {e}", exc_info=True) - raise DestinationConnectionError(f"failed to validate connection: {e}") + raise self.connection_config.wrap_error(e=e) def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None: output_path = os.path.join( self.upload_config.path, f"{file_data.source_identifiers.filename}.json" ) with open(path, "rb") as elements_file: - self.connection_config.get_client().files.upload( - file_path=output_path, - contents=elements_file, - overwrite=True, - ) + try: + self.connection_config.get_client().files.upload( + file_path=output_path, + contents=elements_file, + overwrite=True, + ) + except Exception as e: + raise self.connection_config.wrap_error(e=e)