Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Update QA to new inputs and datamodule #1045

Merged
merged 8 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ __MACOSX
*-v2.0.json
cifar-10*
mini-imagenet*
squad_tiny*

docs/source/_static/images/course_UvA-DL
docs/source/_static/images/lightning_examples
Expand Down
16 changes: 7 additions & 9 deletions docs/source/api/text.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,13 @@ __________________
~question_answering.model.QuestionAnsweringTask
~question_answering.data.QuestionAnsweringData

question_answering.data.QuestionAnsweringBackboneState
question_answering.data.QuestionAnsweringCSVInput
question_answering.data.QuestionAnsweringInput
question_answering.data.QuestionAnsweringDictionaryInput
question_answering.data.QuestionAnsweringFileInput
question_answering.data.QuestionAnsweringJSONInput
question_answering.data.QuestionAnsweringOutputTransform
question_answering.data.QuestionAnsweringInputTransform
question_answering.data.SQuADInput
question_answering.input.QuestionAnsweringInputBase
question_answering.input.QuestionAnsweringCSVInput
question_answering.input.QuestionAnsweringJSONInput
question_answering.input.QuestionAnsweringSQuADInput
question_answering.input.QuestionAnsweringDictionaryInput
question_answering.input_transform.QuestionAnsweringInputTransform
question_answering.output_transform.QuestionAnsweringOutputTransform

Summarization
_____________
Expand Down
17 changes: 10 additions & 7 deletions flash/core/integrations/transformers/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
import torch

from flash.core.data.input_transform import InputTransform


def to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
    """Convert every value of ``sample`` to a tensor, in place.

    Args:
        sample: Mapping whose values are all accepted by ``torch.as_tensor``.

    Returns:
        The same ``sample`` object, with each value replaced by the result of
        ``torch.as_tensor``.
    """
    # Only existing keys are reassigned, so iterating the dict while writing
    # to it is safe (its size never changes).
    for key in sample:
        sample[key] = torch.as_tensor(sample[key])
    return sample
from flash.core.data.io.input import DataKeys


@dataclass
class TransformersInputTransform(InputTransform):
    """Input transform whose per-sample step converts sample values to tensors."""

    @staticmethod
    def to_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the values of ``sample`` to tensors, in place.

        Args:
            sample: Mapping of keys to values accepted by ``torch.as_tensor``.

        Returns:
            The same ``sample`` object; every value except the one stored under
            ``DataKeys.METADATA`` is replaced by the result of
            ``torch.as_tensor``.
        """
        for key in sample:
            if key is DataKeys.METADATA:
                # The metadata entry is passed through untouched.
                continue
            sample[key] = torch.as_tensor(sample[key])
        return sample

    def per_sample_transform(self) -> Callable:
        """Return the callable that is applied to each individual sample."""
        # A single return: previously ``return to_tensor`` preceded this line,
        # which made it unreachable and left the metadata-aware static method
        # above unused.
        return self.to_tensor
10 changes: 2 additions & 8 deletions flash/text/question_answering/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flash.core.data.utils import download_data
from flash.core.utilities.flash_cli import FlashCLI
from flash.text import QuestionAnsweringData, QuestionAnsweringTask
Expand All @@ -20,21 +19,17 @@


# NOTE(review): this block reads like the removed and added lines of a diff shown
# together without +/- markers, and its indentation has been lost — TODO confirm
# against the real file before relying on it.
def from_squad(
backbone: str = "distilbert-base-uncased",
batch_size: int = 4,
num_workers: int = 0,
**input_transform_kwargs,
# NOTE(review): a second ``**`` parameter is not valid Python; presumably only one
# of ``**input_transform_kwargs`` / ``**data_module_kwargs`` belongs here — verify.
**data_module_kwargs,
) -> QuestionAnsweringData:
"""Downloads and loads a tiny subset of the squad V2 data set."""
# Fetch the archive into ./data/; the train/val paths below point inside it.
download_data("https://pl-flash-data.s3.amazonaws.com/squad_tiny.zip", "./data/")

return QuestionAnsweringData.from_squad_v2(
train_file="./data/squad_tiny/train.json",
val_file="./data/squad_tiny/val.json",
backbone=backbone,
batch_size=batch_size,
num_workers=num_workers,
**input_transform_kwargs,
**data_module_kwargs,
)


Expand All @@ -48,7 +43,6 @@ def question_answering():
"trainer.max_epochs": 3,
"model.backbone": "distilbert-base-uncased",
},
legacy=True,
)

cli.trainer.save_checkpoint("question_answering_model.pt")
Expand Down
Loading