diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 03e311a8f5..75a21b8ef6 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -362,6 +362,7 @@ def from_dict( Examples ________ .. testsetup:: + >>> train_data = { ... "animal": ["cat", "dog", "cat"], ... "friendly": ["yes", "yes", "no"], @@ -371,16 +372,19 @@ def from_dict( ... "friendly": ["yes", "no", "yes"], ... "weight": [7, 12, 5], ... } + We have dictionaries ``train_data`` and ``predict_data``. + .. doctest:: + >>> from flash import Trainer >>> from flash.tabular import TabularClassifier, TabularClassificationData >>> datamodule = TabularClassificationData.from_dict( ... "friendly", ... "weight", ... "animal", - ... train_data_frame=train_data, - ... predict_data_frame=predict_data, + ... train_data=train_data, + ... predict_data=predict_data, ... batch_size=4, ... ) >>> datamodule.num_classes @@ -393,7 +397,9 @@ def from_dict( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + .. testcleanup:: + >>> del train_data >>> del predict_data """