-
Notifications
You must be signed in to change notification settings - Fork 2.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added stratify option to train_test_split function. #4322
Added stratify option to train_test_split function. #4322
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice thank you ! This will be super useful :)
Could you also add some tests in test_arrow_dataset.py and add an example of usage in the Example:
section of the train_test_split
docstring ?
I will try to do it, is there any documentation for adding test cases? I have never done it before. |
Thanks for the changes !
You can just add a function In this function you can define a dataset and make sure that You can do Feel free to get some inspiration from other tests like |
I have added tests for stratified train_test_split in
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice!
PS: Maybe we should put the _approximate_mode
and _stratified_shuffle_split_generate_indices
in a separate module @lhoestq (utils/stratify.py
for instance)? And if we keep them here as methods, we should at least mark them with the @staticmethod
decorator.
Thanks a lot !
Also feel free to merge |
57bd246
to
1b60d4b
Compare
1b60d4b
to
049b51b
Compare
Added all the changes as were suggested and rebased with |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few more nits. Other than that looks good!
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! LGTM!
Pinging @lhoestq for the final review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot !
Hi, I encounter an error when I try to specify the stratify_by_column. However, I have a columns which specific the label of the row as a string. But an error showed when I try to do it. "ValueError: Stratifying by column is only supported for ClassLabel column, and column code is Value.". |
Hi @Damon03 , you can change the type of your column to ClassLabel using ds = ds.class_encode_column(column_name) then you'll be free to use |
Thank you so much. It worked. |
This PR adds
stratify
option totrain_test_split
method. I took reference from scikit-learn'sStratifiedShuffleSplit
class for implementing stratified split and integrated the changes as were suggested by @lhoestq.It fixes #3452.
@lhoestq Please review and let me know, if any changes are required.