to_tf_dataset rewrite #4170

Merged: 56 commits merged into master on Jun 6, 2022
Conversation

@Rocketknight1 (Member) commented Apr 14, 2022

This PR rewrites almost all of to_tf_dataset(), which makes it kind of hard to list all the changes, but the most critical ones are:

  • Much better stability and no more dropping unexpected column names (Sorry @NielsRogge)
  • Doesn't clobber custom transforms on the data (Sorry @NielsRogge again)
  • Much better handling of the situation when the collate_fn adds columns that aren't in the dataset.
  • Better inference of shapes and data types
  • Lots of hacky special-casing code removed
  • Can return string columns (as tf.string)
  • Most arguments now have default values, so calling the method should be much simpler (see the usage sketch below)
  • Can accept a model argument and only return columns that are valid inputs to that model
  • Drops the dummy_labels argument - this was a workaround for Keras issues that have been resolved by changes in transformers. It has also been removed from the tests and the Overview notebook.

I still have a couple of TODOs remaining and some testing to do, so don't merge yet, but it should be mostly ready for review at this point!
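
For illustration, here is a rough usage sketch of the simplified calling convention (not taken from this PR's diff); the toy dataset, column names, and exact defaults below are assumptions based on the bullet list above.

```python
from datasets import Dataset

# Toy dataset with pre-tokenized, fixed-length inputs (illustrative only)
ds = Dataset.from_dict(
    {"input_ids": [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], "label": [0, 1, 0, 1]}
)

# With defaults for most arguments the call stays short: no dummy_labels,
# no custom collate_fn (the minimal default collator is used), and unrequested
# columns are simply left out rather than causing surprises.
tf_ds = ds.to_tf_dataset(
    columns=["input_ids"],
    label_cols=["label"],
    batch_size=2,
    shuffle=True,
)

for batch in tf_ds.take(1):
    print(batch)  # a (features, labels) pair
```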

@HuggingFaceDocBuilderDev commented Apr 14, 2022

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Contributor) left a comment

The fact that your changes make this very entangled with Transformers code makes me think you are trying to do too much here. In particular, the code for guessing the labels or the numpy data collator from Transformers is very much aimed at Transformer models, while the method in Datasets would be expected to work on any model. Copy-pasting code from Transformers is kind of a red flag ;-)

I think we need to rethink the design and make the base to_tf_dataset more basic and less magic, then have another method in Transformers that will do the magic guessing of defaults using the model (and use a different default collate).

@Rocketknight1 (Member Author)

Magic is now banned by decree of @sgugger. This is honestly much cleaner, and the functionality will make much more sense in transformers anyway!

@sgugger (Contributor) left a comment

Better for me! Are there any real changes in the notebook? All I see in the diff is PyCharm adding useless metadata.

@gante (Member) left a comment

Looks good! Added a few minor comments, and I'd like to make a request as well: since we now have default arguments for everything, can we add a test case with no passed arguments?

(I don't want to approve because I don't have much datasets knowledge :) )

@Rocketknight1 (Member Author)

@gante I renamed the default collator to minimal_tf_collate_fn!
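
For readers following along, here is a rough sketch of what such a minimal default collator might do (an assumption about its behavior, not the code added in this PR): stack each feature across the examples in the batch as numpy arrays, assuming equal-length features and no padding.

```python
import numpy as np

def minimal_collate_sketch(features):
    """Hypothetical stand-in for the default collator: no padding, no renaming,
    just stack each column of the batch into a numpy array."""
    batch = {}
    for key in features[0]:
        batch[key] = np.stack([np.asarray(example[key]) for example in features])
    return batch
```

Anything that needs padding or renaming would still be handled by passing a custom collate_fn.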

@lhoestq (Member) left a comment

Awesome thank you !

I added a few comments, and also found a bug when you don't specify the columns explicitly (a None is passed to _get_output_signature).

Feel free to also add some tests to make sure that the new default behavior works as expected !

Quoted diff context:

    """
    # TODO Try an Image dataset and see if we can do the conversion
A Member commented on this snippet:
Image datasets return PIL.Image objects by default, which are not supported by TF.

However we can improve the .with_format("numpy") so that it returns a numpy array instead of the PIL Image, so that you don't have to deal with custom types by yourself in to_tf_dataset and always assume that you'll get numpy arrays that work with TF :)

cc @mariosasko

@Rocketknight1 (Member Author) replied:

Now that we're not clobbering with_transform, users can also just do that!
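
As a hedged illustration of that point (the dataset name and column names below are assumptions, not something from this PR), a user-supplied transform can convert PIL images to numpy arrays before to_tf_dataset ever sees them:

```python
import numpy as np
from datasets import load_dataset

# "beans" is just an example image dataset; its "image"/"labels" columns are
# assumptions about that dataset rather than anything defined in this PR.
image_dataset = load_dataset("beans", split="train")

def pil_to_numpy(batch):
    # Applied lazily by with_transform, so PIL images become numpy arrays
    # only when examples are actually read
    batch["image"] = [np.asarray(img) for img in batch["image"]]
    return batch

image_dataset = image_dataset.with_transform(pil_to_numpy)
tf_ds = image_dataset.to_tf_dataset(columns=["image"], label_cols=["labels"], batch_size=4)
```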

@lhoestq (Member) left a comment

I think it's still not working, see my comment. Can you please add some tests ?

"encoded_tf_dataset = encoded_dataset['train'].to_tf_dataset(\n",
" columns=columns,\n",
" collate_fn=collate_fn,\n",
" batch_size=8,\n",
" shuffle=True,\n",
" dummy_labels=True\n",
A Member commented on this snippet:

It was pretty hard to find these little changes in this big diff xD

@Rocketknight1 (Member Author) replied:

Unfortunately, I don't know how to stop Jupyter saving all that junk when I just want to change one line!

@Rocketknight1 (Member Author)

@lhoestq @sgugger @gante

I think this should now be ready; it looks good in testing! I'll try a few more notebooks today and tomorrow to be sure before I merge. Key changes are:

  • No column autodetection magic (will make a separate PR to add this as a transformers function)
  • Drops non-numerical features automatically (this is more of a 'DataLoader' method; we'll have a separate method to expose 'raw' datasets to tf.data)
  • Better autodetection of numerical features.
  • Shouldn't randomly crash mid-function 💀

We definitely have some questions still to resolve about how to handle making a 'DataLoader' dataset versus a 'raw' dataset - see the Notion doc if you're interested. Still, since this PR is just fixes/improvements to an existing method which never supported non-numerical features anyway, we can merge it before we've resolved those issues, and then think about how to name and split things afterwards.

@Rocketknight1 (Member Author)

P.S. I'll take out the region comments at the end before I merge, I promise! They're just helpful while I'm editing it

@gante (Member) left a comment

Looks good! 👍

Can we also add a test case for .to_tf_dataset(), i.e. using defaults for all arguments?
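
A hedged sketch of the kind of test being asked for here (not the test that ended up in the PR; whether every argument, including batch_size, has a default in this version is not clear from the thread, so batch_size is still passed explicitly below):

```python
from datasets import Dataset

def test_to_tf_dataset_defaults_sketch():
    ds = Dataset.from_dict({"col_1": [1, 2, 3, 4], "col_2": [0.5, 1.5, 2.5, 3.5]})
    # Everything except batch_size is left at its default value
    tf_ds = ds.to_tf_dataset(batch_size=2)
    batch = next(iter(tf_ds))
    # With no columns/label_cols requested, all columns should come back as a dict
    assert set(batch.keys()) == {"col_1", "col_2"}
    assert batch["col_1"].shape == (2,)
```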

@lhoestq (Member) commented May 2, 2022

+1 for the tests

> Drops non-numerical features automatically

Can you give more details on how this works, and the rationale as well? This is not explained in the docs.

Also, why are you adding error_on_missing and auto_fix_label_names? The rationale is not clear to me. In particular, I think it is sensible enough to expect users not to ask for columns that don't exist, and to rename a label column when required.

@Rocketknight1 (Member Author)

@lhoestq I rewrote those parts - they were causing some other issues too! error_on_missing and auto_fix_label_names have been removed. The new logic is to simply drop (before batch collation) all columns the user doesn't ask for, but not to raise errors if the user asked for columns not in the dataset, as they may be added by the collator. Hopefully this cleans it up and matches the documentation better!

Comment on lines 400 to 412:

    # Following the logic in `transformers.Trainer`, we do not drop `label_ids` or `label` even if they
    # are not in the list of requested columns, because the collator may rename them
    # This might work better if moved to a method attached to our transformers Model objects, but doing so
    # could break backward compatibility
    unwanted_columns = [
        col
        for col in self.features.keys()
        if col not in columns and col not in label_cols and col not in ("label_ids", "label")
    ]
    dataset = dataset.remove_columns(unwanted_columns)
@Rocketknight1 (Member Author) commented on this snippet:

One thing I'm still not sure about is that we special-case "label_ids" and "label", because those columns are frequently renamed to "labels" in transformers and we don't want to drop them. In PyTorch this special-casing occurs too, but it happens inside Trainer, so it doesn't mix things up between the transformers and datasets libraries. We might have to leave it here for now until I can make the 'magic method' on our transformers models!

A Member replied:

I'd like this part to be removed since it's transformers.Trainer-specific. What needs to be backward compatible?

@Rocketknight1 (Member Author) replied:

The standard workflow for transformers is that the column in the dataset is called label, but the argument to the model is called labels, and the renaming is done by the collate_fn. Therefore, we must be careful not to drop columns called label or label_ids, even if the user doesn't request them, because the collate_fn might need them.

In PyTorch, columns are dropped by the Trainer, and it makes sure not to drop these. However, in TensorFlow, columns are dropped by the to_tf_dataset() method, and therefore this code needs to be in there.

I think a good solution would be to make a method in transformers that calls to_tf_dataset() and converts datasets for training, and then we could move the label and label_ids special casing in there. However, until we do that, we'll need to keep the special casing in to_tf_dataset() or all our examples will break!
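
To make the point concrete, here is an illustrative collate_fn of the kind described (a sketch, not the transformers implementation): it is the collator, not the dataset, that produces the labels key the model expects, which is why to_tf_dataset() cannot safely drop label / label_ids up front.

```python
import numpy as np

def collate_and_rename(features):
    # Stack each column across the batch...
    batch = {key: np.stack([np.asarray(f[key]) for f in features]) for key in features[0]}
    # ...and rename "label" to "labels", the argument name the model expects.
    # If to_tf_dataset() had already dropped "label", this step would have
    # nothing to rename and the model would get no labels at all.
    if "label" in batch:
        batch["labels"] = batch.pop("label")
    return batch
```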

@lhoestq (Member) left a comment

Cool, let me know if you need help for the tests :)

I think we need several additional tests, especially since the function is quite long and not trivial to maintain. In particular it would be nice to check the default behavior, and check the behavior of the main parameters.

@Rocketknight1 (Member Author)

@lhoestq New tests are now in!

@Rocketknight1 (Member Author)

Seeing some other random tests failing that don't look to be associated with this PR.

@Rocketknight1 (Member Author)

@lhoestq I can't figure out these test failures! They don't seem related to this PR at all, but I rebased to the latest version and they keep happening, even though they're not visible on master.

@lhoestq (Member) commented May 19, 2022

Thanks for the ping, will take a look tomorrow :)

Maybe the rebase didn't go well for the code recently merged about label alignment from #4277 ?

@Rocketknight1 (Member Author) commented May 20, 2022

@lhoestq Got it! It was caused by a name collision - I was importing typing.Sequence, but the code also needed features.Sequence. The tests from that PR were expecting the latter but got the former, and then crashed.
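
For context, the fix presumably takes a shape like the following (an assumption; the actual diff isn't shown in this thread): alias the typing import so it can no longer shadow the Datasets feature type.

```python
# Hypothetical sketch of the disambiguation, not the exact change from the PR
from typing import Sequence as TypingSequence  # generic typing alias for hints

from datasets.features import Sequence  # the feature type the tests expect
```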

@lhoestq (Member) left a comment

Oh good catch ! And thanks for the tests :)

I think we just need to update the docstring, and I also added a question about the default batch_size (sorry for not raising the question earlier, I missed it)

@lhoestq (Member) left a comment

Thank you! I added more comments about the default for shuffle and drop_remainder, and some suggestions.

@lhoestq (Member) left a comment

Cool ! This is a HUGE step !! Thanks a lot @Rocketknight1 :)

Here are my final comments, then I think we can merge 🚀

@Rocketknight1 (Member Author)

@lhoestq Thanks! Also, when you're ready, don't merge it immediately! I'd like to do a quick round of manual testing with the very final build once you're happy to make sure it still works in our notebooks and examples.

@lhoestq (Member) left a comment

Thanks ! Feel free to run all your tests and merge yourself when you're good and if the CI is green :)

@Rocketknight1 (Member Author)

@lhoestq Tests look good to me, merging now!

@Rocketknight1 merged commit e3f2bbb into master on Jun 6, 2022
@Rocketknight1 deleted the to_tf_dataset_tpu_warning branch on Jun 6, 2022 (14:22)