From 8787f63160b7f50773f5a0777f217be3da6dd041 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Fri, 6 May 2022 16:48:45 +0530 Subject: [PATCH] fixes --- flash/tabular/classification/data.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/flash/tabular/classification/data.py b/flash/tabular/classification/data.py index 03e311a8f5..75a21b8ef6 100644 --- a/flash/tabular/classification/data.py +++ b/flash/tabular/classification/data.py @@ -362,6 +362,7 @@ def from_dict( Examples ________ .. testsetup:: + >>> train_data = { ... "animal": ["cat", "dog", "cat"], ... "friendly": ["yes", "yes", "no"], @@ -371,16 +372,19 @@ def from_dict( ... "friendly": ["yes", "no", "yes"], ... "weight": [7, 12, 5], ... } + We have dictionaries ``train_data`` and ``predict_data``. + .. doctest:: + >>> from flash import Trainer >>> from flash.tabular import TabularClassifier, TabularClassificationData >>> datamodule = TabularClassificationData.from_dict( ... "friendly", ... "weight", ... "animal", - ... train_data_frame=train_data, - ... predict_data_frame=predict_data, + ... train_data=train_data, + ... predict_data=predict_data, ... batch_size=4, ... ) >>> datamodule.num_classes @@ -393,7 +397,9 @@ def from_dict( Training... >>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE Predicting... + .. testcleanup:: + >>> del train_data >>> del predict_data """