Skip to content

Commit

Permalink
Fix tests using single threaded client
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Mar 6, 2024
1 parent 3f2a1a7 commit 3c9f7cc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
7 changes: 0 additions & 7 deletions src/fondant/component/data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import dask.dataframe as dd
import dask.distributed
import fsspec
import pyarrow as pa
from dask.diagnostics import ProgressBar
from dask.distributed import as_completed
Expand Down Expand Up @@ -206,12 +205,6 @@ def _write_dataframe(self, dataframe: dd.DataFrame) -> dd.core.Scalar:
f"{self.manifest.run_id}/{self.operation_spec.component_name}"
)

# Create directory the dataframe will be written to, since this is not handled by Pandas
# `to_parquet` method.
protocol = fsspec.utils.get_protocol(location)
fs = fsspec.get_filesystem_class(protocol)
fs().makedirs(location)

schema = {
field.name: field.type.value
for field in self.operation_spec.produces_to_dataset.values()
Expand Down
28 changes: 21 additions & 7 deletions tests/component/test_data_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import dask.dataframe as dd
import pyarrow as pa
import pytest
from dask.distributed import Client
from fondant.component.data_io import DaskDataLoader, DaskDataWriter
from fondant.core.component_spec import ComponentSpec, OperationSpec
from fondant.core.manifest import Manifest
Expand Down Expand Up @@ -51,8 +50,22 @@ def dataframe(manifest, component_spec):


@pytest.fixture()
def client():
return Client()
async def client():
    """Start a Dask client running everything in a single thread. This is necessary to work with
    temp directories, which are not available to other processes.
    """
    # Imported inside the fixture so the single-threaded wiring stays local to it.
    from dask.distributed import Client, Scheduler, Worker
    from tornado.concurrent import DummyExecutor
    from tornado.ioloop import IOLoop

    loop = IOLoop()
    # DummyExecutor executes submitted work synchronously on the calling thread,
    # so the worker runs its tasks in this very process (no thread/process pool).
    e = DummyExecutor()
    s = Scheduler(loop=loop)
    await s.start()
    w = Worker(s.address, loop=loop, executor=e)
    # NOTE(review): `Worker._start` is a private API that has been removed in
    # newer `distributed` releases — confirm this works with the pinned version.
    loop.add_callback(w._start)

    # NOTE(review): the fixture *returns* (does not yield), so there is no
    # teardown — the scheduler, worker, and client are never closed. Presumably
    # acceptable for a short test session; verify nothing leaks across tests.
    return Client(s.address)


def test_load_dataframe(manifest, component_spec):
Expand Down Expand Up @@ -114,7 +127,7 @@ def test_load_dataframe_rows(manifest, component_spec):
assert dataframe.npartitions == expected_partitions


def test_write_dataset(
async def test_write_dataset(
tmp_path_factory,
dataframe,
manifest,
Expand Down Expand Up @@ -145,11 +158,12 @@ def test_write_dataset(
assert dataframe.index.name == "id"


def test_write_dataset_custom_produces(
async def test_write_dataset_custom_produces(
tmp_path_factory,
dataframe,
manifest,
component_spec_produces,
client,
):
"""Test writing out subsets."""
produces = {
Expand Down Expand Up @@ -186,7 +200,7 @@ def test_write_dataset_custom_produces(


# TODO: check if this is still needed?
def test_write_reset_index(
async def test_write_reset_index(
tmp_path_factory,
dataframe,
manifest,
Expand All @@ -210,7 +224,7 @@ def test_write_reset_index(


@pytest.mark.parametrize("partitions", list(range(1, 5)))
def test_write_divisions( # noqa: PLR0913
async def test_write_divisions( # noqa: PLR0913
tmp_path_factory,
dataframe,
manifest,
Expand Down

0 comments on commit 3c9f7cc

Please sign in to comment.