From 3c9f7cc572d1d2a3a70385eb71a03a3edfb13b93 Mon Sep 17 00:00:00 2001
From: Robbe Sneyders
Date: Wed, 6 Mar 2024 13:00:45 +0100
Subject: [PATCH] Fix tests using single-threaded client

---
 src/fondant/component/data_io.py |  7 -------
 tests/component/test_data_io.py  | 28 +++++++++++++++++++++-------
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/fondant/component/data_io.py b/src/fondant/component/data_io.py
index bbec6529..5d7f442e 100644
--- a/src/fondant/component/data_io.py
+++ b/src/fondant/component/data_io.py
@@ -5,7 +5,6 @@
 
 import dask.dataframe as dd
 import dask.distributed
-import fsspec
 import pyarrow as pa
 from dask.diagnostics import ProgressBar
 from dask.distributed import as_completed
@@ -206,12 +205,6 @@ def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
             f"{self.manifest.run_id}/{self.operation_spec.component_name}"
         )
 
-        # Create directory the dataframe will be written to, since this is not handled by Pandas
-        # `to_parquet` method.
-        protocol = fsspec.utils.get_protocol(location)
-        fs = fsspec.get_filesystem_class(protocol)
-        fs().makedirs(location)
-
         schema = {
             field.name: field.type.value
             for field in self.operation_spec.produces_to_dataset.values()
diff --git a/tests/component/test_data_io.py b/tests/component/test_data_io.py
index 3db90272..fd823b19 100644
--- a/tests/component/test_data_io.py
+++ b/tests/component/test_data_io.py
@@ -4,7 +4,6 @@
 import dask.dataframe as dd
 import pyarrow as pa
 import pytest
-from dask.distributed import Client
 from fondant.component.data_io import DaskDataLoader, DaskDataWriter
 from fondant.core.component_spec import ComponentSpec, OperationSpec
 from fondant.core.manifest import Manifest
@@ -51,8 +50,22 @@ def dataframe(manifest, component_spec):
 
 
 @pytest.fixture()
-def client():
-    return Client()
+async def client():
+    """Start a Dask client running everything in a single thread. This is necessary to work with
+    temp directories, which are not available to other processes.
+    """
+    from dask.distributed import Client, Scheduler, Worker
+    from tornado.concurrent import DummyExecutor
+    from tornado.ioloop import IOLoop
+
+    loop = IOLoop()
+    e = DummyExecutor()
+    s = Scheduler(loop=loop)
+    await s.start()
+    w = Worker(s.address, loop=loop, executor=e)
+    loop.add_callback(w._start)
+
+    return Client(s.address)
 
 
 def test_load_dataframe(manifest, component_spec):
@@ -114,7 +127,7 @@ def test_load_dataframe_rows(manifest, component_spec):
     assert dataframe.npartitions == expected_partitions
 
 
-def test_write_dataset(
+async def test_write_dataset(
     tmp_path_factory,
     dataframe,
     manifest,
@@ -145,11 +158,12 @@
     assert dataframe.index.name == "id"
 
 
-def test_write_dataset_custom_produces(
+async def test_write_dataset_custom_produces(
     tmp_path_factory,
     dataframe,
     manifest,
     component_spec_produces,
+    client,
 ):
     """Test writing out subsets."""
     produces = {
@@ -186,7 +200,7 @@
 
 
 # TODO: check if this is still needed?
-def test_write_reset_index(
+async def test_write_reset_index(
     tmp_path_factory,
     dataframe,
     manifest,
@@ -210,7 +224,7 @@
 
 
 @pytest.mark.parametrize("partitions", list(range(1, 5)))
-def test_write_divisions(  # noqa: PLR0913
+async def test_write_divisions(  # noqa: PLR0913
     tmp_path_factory,
     dataframe,
     manifest,
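
Note (editorial addition, not part of the patch): the new `client` fixture wires a
`Scheduler`, a `Worker` backed by Tornado's `DummyExecutor`, and a `Client` together
inside one process, so tasks execute inline on the event-loop thread rather than in a
worker subprocess or thread pool. The sketch below shows the same single-process,
single-thread setup using `dask.distributed`'s async context managers instead of the
fixture's manual `IOLoop` wiring; it is illustrative only, and assumes a `distributed`
version where `Scheduler`, `Worker`, and `Client(asynchronous=True)` support `async with`.

    import asyncio

    from dask.distributed import Client, Scheduler, Worker
    from tornado.concurrent import DummyExecutor


    async def main() -> None:
        # Scheduler, worker, and client all live in this single process.
        # port=0 / dashboard_address=":0" pick free ports to avoid clashes.
        async with Scheduler(port=0, dashboard_address=":0") as scheduler:
            # DummyExecutor runs submitted tasks inline on the event-loop
            # thread instead of handing them to a thread pool, mirroring
            # the executor used by the fixture in the patch.
            async with Worker(scheduler.address, executor=DummyExecutor()):
                async with Client(scheduler.address, asynchronous=True) as client:
                    future = client.submit(sum, [1, 2, 3])
                    assert await future == 6  # task ran inline, same process


    if __name__ == "__main__":
        asyncio.run(main())

Because everything shares one process and thread, locally created state such as the
pytest temp directories behaves as the fixture's docstring describes, which is the
patch's stated reason for replacing the default multiprocess `Client()`.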