Skip to content

Commit

Permalink
fix: Allow copy dataset with empty workspace to the default user work…
Browse files Browse the repository at this point in the history
…space (#2618)

When users have an old dataset without a workspace (`-`) and try to
migrate it to a proper workspace (the user's default one), the server
rejects the copy with a 409 error:

```bash
AlreadyExistsApiError: Argilla server returned an error with http status: 409
Error details: [{'code': 'argilla.api.errors::EntityAlreadyExistsError', 'params': {'name': 'my_dataset', 'type': 'ServiceBaseDataset', 'workspace': 'argilla'}}]
```

This PR fixes the problem by checking whether the dataset already exists
in the target workspace, searching by the provided name + workspace
instead of using a broader lookup.

**Type of change**

(Please delete options that are not relevant. Remember to title the PR
according to the type of change)

- [x] Bug fix (non-breaking change which fixes an issue)

**How Has This Been Tested**

Tests checking this behavior are added.

**Checklist**

- [x] I have merged the original branch into my forked branch
- [x] follows the style guidelines of this project
- [x] I made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my
feature works
  • Loading branch information
frascuchon committed Mar 30, 2023
1 parent c89153e commit 905d4de
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 15 deletions.
22 changes: 7 additions & 15 deletions src/argilla/server/services/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,15 +143,13 @@ def copy_dataset(
dataset_workspace = copy_workspace or dataset.workspace
dataset_workspace = user.check_workspace(dataset_workspace)

self._validate_create_dataset(
name=copy_name,
workspace=dataset_workspace,
user=user,
)
self._validate_copy_dataset(name=copy_name, workspace=dataset_workspace)

copy_dataset = dataset.copy()

copy_dataset.name = copy_name
copy_dataset.workspace = dataset_workspace
copy_dataset.created_by = user.username

date_now = datetime.utcnow()

Expand All @@ -172,16 +170,10 @@ def copy_dataset(

return copy_dataset

def _validate_create_dataset(self, name: str, workspace: str, user: User):
    """Raise if a dataset called ``name`` can be found for ``user`` in ``workspace``.

    The lookup goes through ``self.find_by_name``. When that lookup raises
    ``EntityNotFoundError`` or ``ForbiddenOperationError`` the name is
    treated as available and the method returns without error.

    Raises:
        EntityAlreadyExistsError: if ``find_by_name`` returns a dataset.
    """
    try:
        existing = self.find_by_name(user=user, name=name, workspace=workspace)
        # A successful lookup means the name is taken; the raise is kept
        # inside the ``try`` so exception handling stays exactly as before.
        raise EntityAlreadyExistsError(
            name=existing.name,
            type=existing.__class__,
            workspace=workspace,
        )
    except (EntityNotFoundError, ForbiddenOperationError):
        # Nothing visible under that name/workspace: validation passes.
        return
def _validate_copy_dataset(self, name: str, workspace: str):
    """Raise if the DAO already holds a dataset ``name`` owned by ``workspace``.

    The check queries ``self.__dao__.find_by_name`` directly with the target
    name and ``owner=workspace``.

    Raises:
        EntityAlreadyExistsError: if the DAO lookup returns a truthy result.
    """
    existing = self.__dao__.find_by_name(name=name, owner=workspace)
    if not existing:
        # No dataset with that name/owner pair: the copy may proceed.
        return

    raise EntityAlreadyExistsError(
        name=existing.name,
        type=existing.__class__,
        workspace=workspace,
    )

async def get_settings(
self,
Expand Down
39 changes: 39 additions & 0 deletions tests/server/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
from argilla.server.daos.backend import GenericElasticEngineBackend
from argilla.server.daos.datasets import DatasetsDAO
from argilla.server.daos.records import DatasetRecordsDAO
from argilla.server.services.datasets import DatasetsService


@pytest.fixture(scope="session")
def es():
    """Session-scoped fixture returning ``GenericElasticEngineBackend.get_instance()``."""
    backend = GenericElasticEngineBackend.get_instance()
    return backend


@pytest.fixture(scope="session")
def records_dao(es: GenericElasticEngineBackend):
    """Session-scoped fixture returning the ``DatasetRecordsDAO`` obtained from the ``es`` backend."""
    dao = DatasetRecordsDAO.get_instance(es)
    return dao


@pytest.fixture(scope="session")
def datasets_dao(records_dao: DatasetRecordsDAO, es: GenericElasticEngineBackend):
    """Session-scoped fixture returning the ``DatasetsDAO`` wired to the ``es`` and ``records_dao`` fixtures."""
    dao = DatasetsDAO.get_instance(es=es, records_dao=records_dao)
    return dao


@pytest.fixture(scope="session")
def datasets_service(datasets_dao: DatasetsDAO):
    """Session-scoped fixture returning the ``DatasetsService`` obtained from the ``datasets_dao`` fixture."""
    service = DatasetsService.get_instance(datasets_dao)
    return service
13 changes: 13 additions & 0 deletions tests/server/services/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
44 changes: 44 additions & 0 deletions tests/server/services/test_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla.server.commons.models import TaskType
from argilla.server.daos.datasets import DatasetsDAO
from argilla.server.daos.models.datasets import BaseDatasetDB
from argilla.server.security.model import User
from argilla.server.services.datasets import DatasetsService


def test_copy_dataset_with_no_owner_info(datasets_service: DatasetsService, datasets_dao: DatasetsDAO):
    """Copy a dataset created without an ``owner`` and check the copy belongs to the user.

    The source dataset is built with only ``name`` and ``task``; the copy is
    expected to carry the requesting user's username as both ``created_by``
    and ``owner``.
    """
    source_name = "test_copy_dataset_with_no_owned_dataset"
    target_name = f"{source_name}_copy"

    source = BaseDatasetDB(name=source_name, task=TaskType.text_classification)
    requester = User(username="test-user")
    previous_copy = BaseDatasetDB(
        name=target_name,
        task=TaskType.text_classification,
        owner=requester.username,
    )

    # Delete both datasets up front (source first, then the copy) --
    # presumably to clear anything left behind by an earlier run.
    for leftover in (source, previous_copy):
        datasets_dao.delete_dataset(leftover)

    datasets_dao.create_dataset(source)

    copied = datasets_service.copy_dataset(requester, dataset=source, copy_name=target_name)

    assert copied.created_by == requester.username
    assert copied.name == target_name
    assert copied.owner == requester.username

0 comments on commit 905d4de

Please sign in to comment.