Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added stratify option to train_test_split function. #4322

Merged
merged 16 commits into from
May 25, 2022

Conversation

nandwalritik
Copy link
Contributor

This PR adds stratify option to train_test_split method. I took reference from scikit-learn's StratifiedShuffleSplit class for implementing stratified split and integrated the changes as were suggested by @lhoestq.

It fixes #3452.

@lhoestq Please review and let me know, if any changes are required.

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice thank you ! This will be super useful :)

Could you also add some tests in test_arrow_dataset.py and add an example of usage in the Example: section of the train_test_split docstring ?

src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/splits.py Outdated Show resolved Hide resolved
src/datasets/splits.py Outdated Show resolved Hide resolved
src/datasets/splits.py Outdated Show resolved Hide resolved
@nandwalritik
Copy link
Contributor Author

Nice thank you ! This will be super useful :)

Could you also add some tests in test_arrow_dataset.py and add an example of usage in the Example: section of the train_test_split docstring ?

I will try to do it, is there any documentation for adding test cases? I have never done it before.

@lhoestq
Copy link
Member

lhoestq commented May 12, 2022

Thanks for the changes !

I will try to do it, is there any documentation for adding test cases? I have never done it before.

You can just add a function test_train_test_split_startify in test_arrow_dataset.py.

In this function you can define a dataset and make sure that train_test_split with the stratify argument works as expected.

You can do pytest tests/test_arrow_dataset.py::test_train_test_split_startify to run your test.

Feel free to get some inspiration from other tests like test_interleave_datasets for example

@nandwalritik
Copy link
Contributor Author

I have added tests for stratified train_test_split in test_arrow_dataset.py file inside test_train_test_split_startify function. I have also added example usage with stratify arg in Example: section of the train_test_split docstring.
Results of tests:

(data) nandwalritik@hp:~/datasets$ pytest tests/test_arrow_dataset.py::test_train_test_split_startify -W ignore
============================================================================ test session starts ============================================================================
platform linux -- Python 3.9.5, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/nandwalritik/datasets
plugins: datadir-1.3.1, forked-1.4.0, xdist-2.5.0
collected 1 item                                                                                                                                                            

tests/test_arrow_dataset.py .                                                                                                                                         [100%]

============================================================================= 1 passed in 0.12s =============================================================================

Copy link
Collaborator

@mariosasko mariosasko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice!

PS: Maybe we should put the _approximate_mode and _stratified_shuffle_split_generate_indices in a separate module @lhoestq (utils/stratify.py for instance)? And if we keep them here as methods, we should at least mark them with the @staticmethod decorator.

src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
@lhoestq
Copy link
Member

lhoestq commented May 16, 2022

Thanks a lot !

utils/stratify.py sounds good yes :)

Also feel free to merge master into your branch to fix the CI ;)

@nandwalritik nandwalritik force-pushed the add_stratify_train_test_split branch from 57bd246 to 1b60d4b Compare May 16, 2022 18:11
@nandwalritik nandwalritik force-pushed the add_stratify_train_test_split branch from 1b60d4b to 049b51b Compare May 17, 2022 02:42
@nandwalritik
Copy link
Contributor Author

Added all the changes as were suggested and rebased with main.

Copy link
Collaborator

@mariosasko mariosasko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A few more nits. Other than that looks good!

src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
src/datasets/utils/stratify.py Outdated Show resolved Hide resolved
tests/test_arrow_dataset.py Outdated Show resolved Hide resolved
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented May 18, 2022

The documentation is not available anymore as the PR was closed or merged.

Copy link
Collaborator

@mariosasko mariosasko left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! LGTM!

Pinging @lhoestq for the final review.

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot !

src/datasets/arrow_dataset.py Outdated Show resolved Hide resolved
@lhoestq lhoestq merged commit 961e596 into huggingface:master May 25, 2022
@Damon03
Copy link

Damon03 commented Nov 22, 2022

Hi, I encounter an error when I try to specify the stratify_by_column. However, I have a columns which specific the label of the row as a string. But an error showed when I try to do it. "ValueError: Stratifying by column is only supported for ClassLabel column, and column code is Value.".

@lhoestq
Copy link
Member

lhoestq commented Nov 22, 2022

Hi @Damon03 , you can change the type of your column to ClassLabel using

ds = ds.class_encode_column(column_name)

then you'll be free to use stratify :)

@Damon03
Copy link

Damon03 commented Nov 22, 2022

Thank you so much. It worked.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

why the stratify option is omitted from test_train_split function?
5 participants