feat/add error handling to databricks volume connector (#328)
* add error handling to databricks volume connector

* add error wrapping to uploader

* make sure os.environ is cleared when running PAT integration tests

* remove use of old error handling

* bump changelog
rbiseck3 authored Jan 10, 2025
1 parent 1616327 commit 9436f0f
Showing 4 changed files with 116 additions and 45 deletions.
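The substance of the commit is a new wrap_error method on the Databricks connection config that translates raw SDK exceptions into the v2 error taxonomy (UserAuthError, RateLimitError, UserError, ProviderError), plus call sites in the indexer, downloader, and uploader that route every failure through it. The call-site pattern looks like this self-contained toy sketch, where wrap_error is a stand-in for the method added below, not the real implementation:

    # Toy version of the wrap-and-reraise pattern used at each connector boundary.
    def wrap_error(e: Exception) -> Exception:
        # The real method returns UserAuthError / RateLimitError / UserError /
        # ProviderError; note that it returns the new exception rather than raising it.
        return RuntimeError(f"wrapped: {e}")

    try:
        raise TimeoutError("databricks call failed")  # simulates a failing SDK call
    except Exception as e:
        raise wrap_error(e) from e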
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -1,4 +1,4 @@
-## 0.3.13-dev2
+## 0.3.13-dev3

### Fixes

43 changes: 43 additions & 0 deletions test/integration/connectors/databricks/test_volumes_native.py
@@ -16,6 +16,7 @@
     source_connector_validation,
 )
 from test.integration.utils import requires_env
+from unstructured_ingest.v2.errors import UserAuthError, UserError
 from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
 from unstructured_ingest.v2.processes.connectors.databricks.volumes_native import (
     CONNECTOR_TYPE,
@@ -143,6 +144,48 @@ async def test_volumes_native_source_pat(tmp_path: Path):
     )


+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+@requires_env("DATABRICKS_HOST", "DATABRICKS_PAT", "DATABRICKS_CATALOG")
+def test_volumes_native_source_pat_invalid_catalog():
+    env_data = get_pat_env_data()
+    with mock.patch.dict(os.environ, clear=True):
+        indexer_config = DatabricksNativeVolumesIndexerConfig(
+            recursive=True,
+            volume="test-platform",
+            volume_path="databricks-volumes-test-input",
+            catalog="fake_catalog",
+        )
+        indexer = DatabricksNativeVolumesIndexer(
+            connection_config=env_data.get_connection_config(), index_config=indexer_config
+        )
+        with pytest.raises(UserError):
+            _ = list(indexer.run())
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG)
+@requires_env("DATABRICKS_HOST")
+def test_volumes_native_source_pat_invalid_pat():
+    host = os.environ["DATABRICKS_HOST"]
+    with mock.patch.dict(os.environ, clear=True):
+        indexer_config = DatabricksNativeVolumesIndexerConfig(
+            recursive=True,
+            volume="test-platform",
+            volume_path="databricks-volumes-test-input",
+            catalog="fake_catalog",
+        )
+        connection_config = DatabricksNativeVolumesConnectionConfig(
+            host=host,
+            access_config=DatabricksNativeVolumesAccessConfig(
+                token="invalid-token",
+            ),
+        )
+        indexer = DatabricksNativeVolumesIndexer(
+            connection_config=connection_config, index_config=indexer_config
+        )
+        with pytest.raises(UserAuthError):
+            _ = list(indexer.run())
+
+
 def _get_volume_path(catalog: str, volume: str, volume_path: str):
     return f"/Volumes/{catalog}/default/{volume}/{volume_path}"
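
A note on the two new tests: mock.patch.dict(os.environ, clear=True) empties the environment for the duration of the block so the Databricks SDK cannot silently fall back to ambient DATABRICKS_* credentials, which is what the "make sure os.environ is cleared" bullet in the commit message refers to. A self-contained illustration of that behavior:

    import os
    from unittest import mock

    os.environ["DATABRICKS_TOKEN"] = "ambient-credential"  # hypothetical leftover env var
    with mock.patch.dict(os.environ, clear=True):
        # Inside the block the variable is gone, so only the explicitly passed
        # host/token can reach the client under test.
        assert "DATABRICKS_TOKEN" not in os.environ
    assert os.environ["DATABRICKS_TOKEN"] == "ambient-credential"  # restored on exit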

2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
@@ -1 +1 @@
__version__ = "0.3.13-dev2" # pragma: no cover
__version__ = "0.3.13-dev3" # pragma: no cover
114 changes: 71 additions & 43 deletions unstructured_ingest/v2/processes/connectors/databricks/volumes.py
@@ -7,12 +7,13 @@

 from pydantic import BaseModel, Field

-from unstructured_ingest.error import (
-    DestinationConnectionError,
-    SourceConnectionError,
-    SourceConnectionNetworkError,
-)
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.errors import (
+    ProviderError,
+    RateLimitError,
+    UserAuthError,
+    UserError,
+)
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -65,6 +66,29 @@ class DatabricksVolumesConnectionConfig(ConnectionConfig, ABC):
         "Databricks accounts endpoint.",
     )

+    def wrap_error(self, e: Exception) -> Exception:
+        from databricks.sdk.errors.base import DatabricksError
+        from databricks.sdk.errors.platform import STATUS_CODE_MAPPING
+
+        if isinstance(e, ValueError):
+            error_message = e.args[0]
+            message_split = error_message.split(":")
+            if message_split[0].endswith("auth"):
+                return UserAuthError(e)
+        if isinstance(e, DatabricksError):
+            reverse_mapping = {v: k for k, v in STATUS_CODE_MAPPING.items()}
+            if status_code := reverse_mapping.get(type(e)):
+                if status_code in [401, 403]:
+                    return UserAuthError(e)
+                if status_code == 429:
+                    return RateLimitError(e)
+                if 400 <= status_code < 500:
+                    return UserError(e)
+                if 500 <= status_code < 600:
+                    return ProviderError(e)
+        logger.error(f"unhandled exception from databricks: {e}", exc_info=True)
+        return e
+
@requires_dependencies(dependencies=["databricks.sdk"], extras="databricks-volumes")
def get_client(self) -> "WorkspaceClient":
from databricks.sdk import WorkspaceClient
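
Two details of wrap_error worth noting. The ValueError branch appears to be a heuristic for the SDK's credential-resolution failures, whose messages begin with a prefix ending in "auth" (hence message_split[0].endswith("auth")). The DatabricksError branch relies on STATUS_CODE_MAPPING from databricks-sdk, a dict of HTTP status code to exception class; inverting it recovers the status code from the type of a caught exception so it can be bucketed into auth (401/403), rate-limit (429), user (other 4xx), or provider (5xx) errors. A standalone sketch of that inversion (requires databricks-sdk; the printout is purely illustrative):

    from databricks.sdk.errors.platform import STATUS_CODE_MAPPING

    # Invert {status_code: exception_class} into {exception_class: status_code},
    # mirroring the dict comprehension in wrap_error above.
    reverse_mapping = {v: k for k, v in STATUS_CODE_MAPPING.items()}
    for cls, code in sorted(reverse_mapping.items(), key=lambda kv: kv[1]):
        print(f"{code} -> {cls.__name__}")

Anything that matches neither branch is logged and returned unchanged, so unknown failures still propagate with their original type.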
@@ -88,32 +112,37 @@ def precheck(self) -> None:
         try:
             self.connection_config.get_client()
         except Exception as e:
-            logger.error(f"failed to validate connection: {e}", exc_info=True)
-            raise SourceConnectionError(f"failed to validate connection: {e}")
+            raise self.connection_config.wrap_error(e=e) from e

     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
-        for file_info in self.connection_config.get_client().dbfs.list(
-            path=self.index_config.path, recursive=self.index_config.recursive
-        ):
-            if file_info.is_dir:
-                continue
-            rel_path = file_info.path.replace(self.index_config.path, "")
-            if rel_path.startswith("/"):
-                rel_path = rel_path[1:]
-            filename = Path(file_info.path).name
-            yield FileData(
-                identifier=str(uuid5(NAMESPACE_DNS, file_info.path)),
-                connector_type=self.connector_type,
-                source_identifiers=SourceIdentifiers(
-                    filename=filename,
-                    rel_path=rel_path,
-                    fullpath=file_info.path,
-                ),
-                additional_metadata={"catalog": self.index_config.catalog, "path": file_info.path},
-                metadata=FileDataSourceMetadata(
-                    url=file_info.path, date_modified=str(file_info.modification_time)
-                ),
-            )
+        try:
+            for file_info in self.connection_config.get_client().dbfs.list(
+                path=self.index_config.path, recursive=self.index_config.recursive
+            ):
+                if file_info.is_dir:
+                    continue
+                rel_path = file_info.path.replace(self.index_config.path, "")
+                if rel_path.startswith("/"):
+                    rel_path = rel_path[1:]
+                filename = Path(file_info.path).name
+                yield FileData(
+                    identifier=str(uuid5(NAMESPACE_DNS, file_info.path)),
+                    connector_type=self.connector_type,
+                    source_identifiers=SourceIdentifiers(
+                        filename=filename,
+                        rel_path=rel_path,
+                        fullpath=file_info.path,
+                    ),
+                    additional_metadata={
+                        "catalog": self.index_config.catalog,
+                        "path": file_info.path,
+                    },
+                    metadata=FileDataSourceMetadata(
+                        url=file_info.path, date_modified=str(file_info.modification_time)
+                    ),
+                )
+        except Exception as e:
+            raise self.connection_config.wrap_error(e=e)
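
Because run is a generator, the try/except has to wrap the loop itself: the body only executes once a consumer iterates, so the translated error surfaces at consumption time rather than when run() is called. That is also why the new integration tests force iteration with list(indexer.run()) inside pytest.raises. A minimal self-contained sketch of the semantics, with RuntimeError standing in for the wrapped error types:

    from typing import Iterator

    def run() -> Iterator[int]:
        try:
            yield 1
            raise ValueError("boom")  # simulates the SDK failing mid-listing
        except Exception as e:
            raise RuntimeError("wrapped") from e  # stand-in for wrap_error

    gen = run()       # creating the generator raises nothing
    print(next(gen))  # prints 1
    # A further next(gen) raises RuntimeError; list(run()) surfaces it the same way.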


class DatabricksVolumesDownloaderConfig(DownloaderConfig):
@@ -129,8 +158,7 @@ def precheck(self) -> None:
         try:
             self.connection_config.get_client()
         except Exception as e:
-            logger.error(f"failed to validate connection: {e}", exc_info=True)
-            raise SourceConnectionError(f"failed to validate connection: {e}")
+            raise self.connection_config.wrap_error(e=e)

def get_download_path(self, file_data: FileData) -> Path:
return self.download_config.download_dir / Path(file_data.source_identifiers.relative_path)
@@ -143,12 +171,10 @@ def run(self, file_data: FileData, **kwargs: Any) -> DownloadResponse:
         try:
             with self.connection_config.get_client().dbfs.download(path=volumes_path) as c:
                 read_content = c._read_handle.read()
-                with open(download_path, "wb") as f:
-                    f.write(read_content)
         except Exception as e:
-            logger.error(f"failed to download file {file_data.identifier}: {e}", exc_info=True)
-            raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
-
+            raise self.connection_config.wrap_error(e=e)
+        with open(download_path, "wb") as f:
+            f.write(read_content)
         return self.generate_download_response(file_data=file_data, download_path=download_path)


@@ -165,16 +191,18 @@ def precheck(self) -> None:
         try:
             assert self.connection_config.get_client().current_user.me().active
         except Exception as e:
-            logger.error(f"failed to validate connection: {e}", exc_info=True)
-            raise DestinationConnectionError(f"failed to validate connection: {e}")
+            raise self.connection_config.wrap_error(e=e)

     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
         output_path = os.path.join(
             self.upload_config.path, f"{file_data.source_identifiers.filename}.json"
         )
         with open(path, "rb") as elements_file:
-            self.connection_config.get_client().files.upload(
-                file_path=output_path,
-                contents=elements_file,
-                overwrite=True,
-            )
+            try:
+                self.connection_config.get_client().files.upload(
+                    file_path=output_path,
+                    contents=elements_file,
+                    overwrite=True,
+                )
+            except Exception as e:
+                raise self.connection_config.wrap_error(e=e)
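
With every entry point funneling failures through wrap_error, downstream code can branch on the error taxonomy instead of parsing SDK messages. A hedged usage sketch (the handling advice is illustrative, and UserAuthError is tested before UserError in case the taxonomy is hierarchical):

    from unstructured_ingest.v2.errors import (
        ProviderError,
        RateLimitError,
        UserAuthError,
        UserError,
    )

    def describe(e: Exception) -> str:
        # Most specific classes first, in case some subclass the broader ones.
        if isinstance(e, UserAuthError):
            return "credentials rejected (401/403): fix token or host"
        if isinstance(e, RateLimitError):
            return "rate limited (429): back off and retry"
        if isinstance(e, UserError):
            return "bad request (other 4xx), e.g. a nonexistent catalog"
        if isinstance(e, ProviderError):
            return "Databricks-side failure (5xx): retry later"
        return "unrecognized: wrap_error re-raises these unchanged"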
