Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
krshrimali committed May 6, 2022
1 parent 39f4c95 commit 8787f63
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions flash/tabular/classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def from_dict(
Examples
________
.. testsetup::
>>> train_data = {
... "animal": ["cat", "dog", "cat"],
... "friendly": ["yes", "yes", "no"],
Expand All @@ -371,16 +372,19 @@ def from_dict(
... "friendly": ["yes", "no", "yes"],
... "weight": [7, 12, 5],
... }
We have dictionaries ``train_data`` and ``predict_data``.
.. doctest::
>>> from flash import Trainer
>>> from flash.tabular import TabularClassifier, TabularClassificationData
>>> datamodule = TabularClassificationData.from_dict(
... "friendly",
... "weight",
... "animal",
... train_data_frame=train_data,
... predict_data_frame=predict_data,
... train_data=train_data,
... predict_data=predict_data,
... batch_size=4,
... )
>>> datamodule.num_classes
Expand All @@ -393,7 +397,9 @@ def from_dict(
Training...
>>> trainer.predict(model, datamodule=datamodule) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
Predicting...
.. testcleanup::
>>> del train_data
>>> del predict_data
"""
Expand Down

0 comments on commit 8787f63

Please sign in to comment.