diff --git a/examples/dna_language_models/dna_lm.ipynb b/examples/dna_language_models/dna_lm.ipynb
new file mode 100644
index 0000000000..b58ffd3d21
--- /dev/null
+++ b/examples/dna_language_models/dna_lm.ipynb
@@ -0,0 +1,2858 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "db4dc272-88fe-47ad-98fd-b94d4f840dca",
+ "metadata": {
+ "id": "db4dc272-88fe-47ad-98fd-b94d4f840dca"
+ },
+ "source": [
+ "# PEFT with DNA Language Models"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d381f473-0d37-4b5b-ae9e-d2b32bab7c04",
+ "metadata": {
+ "id": "d381f473-0d37-4b5b-ae9e-d2b32bab7c04"
+ },
+ "source": [
+ "This notebook demonstrates how to utilize parameter-efficient fine-tuning techniques (PEFT) from the PEFT library to fine-tune a DNA Language Model (DNA-LM). The fine-tuned DNA-LM will be applied to solve a task from the nucleotide benchmark dataset. Parameter-efficient fine-tuning (PEFT) techniques are crucial for adapting large pre-trained models to specific tasks with limited computational resources."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23f460c3-d7e5-437f-a5e9-d029cd225bf8",
+ "metadata": {
+ "id": "23f460c3-d7e5-437f-a5e9-d029cd225bf8"
+ },
+ "source": [
+ "### 1. Import relevant libraries"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "29a35f95-738a-4f5e-88ce-dc5f8f9be5dc",
+ "metadata": {
+ "id": "29a35f95-738a-4f5e-88ce-dc5f8f9be5dc"
+ },
+ "source": [
+ "We'll start by importing the required libraries, including the PEFT library and other dependencies."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "0a40abdf-ca1c-436f-a2af-603cd67a45a4",
+ "metadata": {
+ "id": "0a40abdf-ca1c-436f-a2af-603cd67a45a4"
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/opt/homebrew/anaconda3/envs/peft/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import transformers\n",
+ "import peft\n",
+ "import tqdm\n",
+ "import numpy as np"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a445f8be-545d-4085-a5f9-c64983655224",
+ "metadata": {
+ "id": "a445f8be-545d-4085-a5f9-c64983655224"
+ },
+ "source": [
+ "### 2. Load models\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "63782b55-1c38-4e44-b003-e57daa813bed",
+ "metadata": {
+ "id": "63782b55-1c38-4e44-b003-e57daa813bed"
+ },
+ "source": [
+ "We'll load a pre-trained DNA Language Model, \"SpeciesLM\", that serves as the base for fine-tuning. This is done using the transformers library from HuggingFace.\n",
+ "\n",
+ "The tokenizer and the model come from the paper, \"Species-aware DNA language models capture regulatory elements and their evolution\". [Paper Link](https://www.biorxiv.org/content/10.1101/2023.01.26.525670v2), [Code Link](https://github.com/gagneurlab/SpeciesLM). The authors introduce a species-aware DNA language model, which is trained on more than 800 species spanning over 500 million years of evolution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "dac961f4-c450-4124-923e-f4ba9bbd5e07",
+ "metadata": {
+ "id": "dac961f4-c450-4124-923e-f4ba9bbd5e07"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import AutoTokenizer, AutoModelForMaskedLM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e73fae58-03e9-4acc-b0fc-9bc810c7d366",
+ "metadata": {
+ "id": "e73fae58-03e9-4acc-b0fc-9bc810c7d366"
+ },
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(\"gagneurlab/SpeciesLM\", revision = \"downstream_species_lm\")\n",
+ "lm = AutoModelForMaskedLM.from_pretrained(\"gagneurlab/SpeciesLM\", revision = \"downstream_species_lm\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "ca43b893-2d66-4e93-a08f-b17a92040709",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ca43b893-2d66-4e93-a08f-b17a92040709",
+ "outputId": "ccbac964-a329-414d-f537-3cae7da66cf2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BertForMaskedLM(\n",
+ " (bert): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-11): 12 x BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSdpaSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (cls): BertOnlyMLMHead(\n",
+ " (predictions): BertLMPredictionHead(\n",
+ " (transform): BertPredictionHeadTransform(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (transform_act_fn): GELUActivation()\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " )\n",
+ " (decoder): Linear(in_features=768, out_features=5504, bias=True)\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "lm.eval()\n",
+ "lm.to(\"cuda\");"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c1bda6f2-34bb-4ce2-aa3f-3013548b0a28",
+ "metadata": {
+ "id": "c1bda6f2-34bb-4ce2-aa3f-3013548b0a28"
+ },
+ "source": [
+ "### 3. Prepare datasets"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "f4c61e59-457c-47d9-8929-5e8cd32d3125",
+ "metadata": {
+ "id": "f4c61e59-457c-47d9-8929-5e8cd32d3125"
+ },
+ "source": [
+ "We'll load the `nucleotide_transformer_downstream_tasks` dataset, which contains 18 downstream tasks from the Nucleotide Transformer paper. This dataset provides a consistent genomics benchmark with binary classification tasks."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "f5c0b3df-911a-4645-9140-99ee489515e8",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 145,
+ "referenced_widgets": [
+ "03bba232d3974119acf8031bc086a072",
+ "9107f7bfc8d3483390f802b0458e9380",
+ "f5c80fa70ead4c86aa3b2a046061b901",
+ "57966a469ca1458daab74e81672ae855",
+ "1464502dc3dd46308be8b4fcc9d5ddb9",
+ "92f64c7e088342b9b3c070ba7a295ed0",
+ "ab0aa8af3816422e9d97934f12af842c",
+ "ff89a891bd9c42a8be164587a94ccac1",
+ "e113a50f8ed2410ca12ce7cb38a1681d",
+ "1afa6e9b69c74136863b7747e62a0608",
+ "0838d19b226d486285a26ce0b04d7e15",
+ "7bdab33f4b244fc89408b91755bf17c5",
+ "4d4ce0d35c124690b3427e84a9a128b1",
+ "33be6b0ca8fd44188f834a48a9574a72",
+ "74e9bc1ead434ae78077df6b85f1df58",
+ "e1acc6e70b9246a5b063b3e262f01c81",
+ "078c6877377a491d97d6fadd27064a76",
+ "d46ee1c39bac44c2b541a88c883de1cb",
+ "12f1de7122a7471e90f01d9e7be81178",
+ "dad286d42a514c9ca6bb01bfe9e9c4be",
+ "c028ed977b5e479fbd93b8add588a6dc",
+ "6d80dec073e449efba272fa9f3527922",
+ "c311b777514f41ef986756a386c0bb34",
+ "e2e4bf053ce442f6aee6ffab5f76525f",
+ "c88cf701e20b4354a63ac7d8645d1df9",
+ "f71c252ada474be882b0335ed9a0a1c3",
+ "e059c665229e46ea905dcbd6fc179c88",
+ "bd5273325a4b453e8053d98a09fe9493",
+ "8f20ed2b74d84e80a8d403793354adea",
+ "57c9af47364d48ffbb4ffbdd2c951ede",
+ "fa9d75fcb1d5400c8ca1d1d13d28d0c7",
+ "682644a713b145f0b2dcff99790c6d4d",
+ "9b9b9d573d44464f9a6f5030a40245fe",
+ "ec165fdbe87a4b00a6c288ef1e85c0a9",
+ "17859b793a304e389d1ea0b9ccc3646f",
+ "34921fd116cc42b7b530174d9f61e71e",
+ "2d5466a5e98849c5a09f16faa98f91da",
+ "952397f9c91c480184fa57e175ab1b4c",
+ "86bcccb842244f4f9add58f62facaace",
+ "78b5bbf4c8ac4fe5961776fded4d5798",
+ "c80062a855cb41a28ac625ab03635da2",
+ "aecd740c17c84d45b0615d4fc4196035",
+ "39640709e7174f84a50da05764abbf99",
+ "7114a029e75c4ed5b966eddd3a3c919d"
+ ]
+ },
+ "id": "f5c0b3df-911a-4645-9140-99ee489515e8",
+ "outputId": "15315be1-9d07-4c46-acda-c65cb5a05250"
+ },
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "03bba232d3974119acf8031bc086a072",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading data: 0%| | 0.00/3.50M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7bdab33f4b244fc89408b91755bf17c5",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Downloading data: 0%| | 0.00/391k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c311b777514f41ef986756a386c0bb34",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating train split: 0%| | 0/13468 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "ec165fdbe87a4b00a6c288ef1e85c0a9",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Generating test split: 0%| | 0/1497 [00:00, ? examples/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from datasets import load_dataset\n",
+ "\n",
+ "raw_data = load_dataset(\"InstaDeepAI/nucleotide_transformer_downstream_tasks\", \"H3\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bbb527c5-8077-4ce4-b093-ae627a5f253c",
+ "metadata": {
+ "id": "bbb527c5-8077-4ce4-b093-ae627a5f253c"
+ },
+ "source": [
+ "We'll use the \"H3\" subset of this dataset, which contains a total of 13,468 rows in the training data and 1,497 rows in the test data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "efef4bb2-60d8-40d1-8777-2b665a87059c",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "efef4bb2-60d8-40d1-8777-2b665a87059c",
+ "outputId": "1c8526ce-0fcb-4fbc-d592-f9a6eae6ebdb"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['sequence', 'name', 'label'],\n",
+ " num_rows: 13468\n",
+ " })\n",
+ " test: Dataset({\n",
+ " features: ['sequence', 'name', 'label'],\n",
+ " num_rows: 1497\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "raw_data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "aafd37c8-6830-4070-a73b-cf62e72e901c",
+ "metadata": {
+ "id": "aafd37c8-6830-4070-a73b-cf62e72e901c"
+ },
+ "source": [
+ "The dataset consists of three columns, ```sequence```, ```name``` and ```label```. A row in this dataset looks like:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "eecd39d8-c073-4d3e-940e-fd83d46f83ab",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "eecd39d8-c073-4d3e-940e-fd83d46f83ab",
+ "outputId": "0b5f8800-eb5d-4c41-a2bc-73e4f837a4d8",
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'sequence': 'TCACTTCGATTATTGAGGCAGTCTTCATTAAAGTTTATTACAATGGATATGGTATCACCAGTCTTGAACCTACAATCATCTATTTTAGGTGAGCTCGTAGGCATTATTGGAAAAGTGTTCTTTCTCTTAATAGAAGAGATTAAATACCCGATAATCACACCCAAAATTATTGTGGATGCCCAGATATCTTCTTGGTCATTGTTTTTTTTCGCTTCAATCTGTAATCTCTCTGCAAAATTTCGGGAGCCAATAGTGACAACATCGTCAATAATAAGTTTGATGGAATCGGAAAAAGATCTTAAAAATGTAAATGAGTATTTCCAAATAATGGCCAAAATGCTCTTTATATTGGAAAATAAAATAGTTGTTTCGCTCTTCGTAGTATTTAACATTTCCGTTCTTATCATTGTAAAGTCTGAGCCATATTCATATGGAAAAGTGCTTTTTAAACCTAGTTCCTCCATATTTTAGTTTTTTATCGATATTGGAAAAAAAAGAGC',\n",
+ " 'name': 'YBR063C_YBR063C_367930|0',\n",
+ " 'label': 0}"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "raw_data['train'][0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "92eccf3e-e846-4c59-af56-0e336ac5a1cd",
+ "metadata": {
+ "id": "92eccf3e-e846-4c59-af56-0e336ac5a1cd"
+ },
+ "source": [
+ "We split our dataset into training, validation, and test sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "f0649bbd-e74e-4dd6-a564-c4d65e46dbbf",
+ "metadata": {
+ "id": "f0649bbd-e74e-4dd6-a564-c4d65e46dbbf"
+ },
+ "outputs": [],
+ "source": [
+ "from datasets import Dataset, DatasetDict\n",
+ "\n",
+ "train_valid_split = raw_data['train'].train_test_split(test_size=0.15, seed=42)\n",
+ "\n",
+ "train_valid_split = DatasetDict({\n",
+ " 'train': train_valid_split['train'],\n",
+ " 'validation': train_valid_split['test']\n",
+ "})\n",
+ "\n",
+ "ds = DatasetDict({\n",
+ " 'train': train_valid_split['train'],\n",
+ " 'validation': train_valid_split['validation'],\n",
+ " 'test': raw_data['test']\n",
+ "})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5424726f-a7ba-45d5-b449-36be9a98b8e6",
+ "metadata": {
+ "id": "5424726f-a7ba-45d5-b449-36be9a98b8e6"
+ },
+ "source": [
+ "Then, we use the tokenizer and a utility function we created, ```get_kmers```, to generate the final data and labels. The ```get_kmers``` function is essential for generating overlapping 6-mers needed by the language model (LM). By using k=6 and stride=1, we ensure that the model receives continuous and overlapping subsequences, capturing the local context within the biological sequence for more effective analysis and prediction.\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "75f267a9-82d1-4343-982e-9b1ea542a330",
+ "metadata": {
+ "id": "75f267a9-82d1-4343-982e-9b1ea542a330"
+ },
+ "outputs": [],
+ "source": [
+ "def get_kmers(seq, k=6, stride=1):\n",
+ " return [seq[i:i + k] for i in range(0, len(seq), stride) if i + k <= len(seq)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "efa9441d-f44c-4ca3-b24c-fa5c853896cd",
+ "metadata": {
+ "id": "efa9441d-f44c-4ca3-b24c-fa5c853896cd"
+ },
+ "outputs": [],
+ "source": [
+ "test_sequences = []\n",
+ "train_sequences = []\n",
+ "val_sequences = []\n",
+ "\n",
+ "dataset_limit = 200 # NOTE: This dataset limit is set to 200, so that the training runs faster. It can be set to None to use the\n",
+ " # entire dataset\n",
+ "\n",
+ "for i in range(0, len(ds['train'])):\n",
+ "\n",
+ " if dataset_limit and i == dataset_limit:\n",
+ " break\n",
+ "\n",
+ " sequence = ds['train'][i]['sequence']\n",
+ " sequence = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
+ " sequence = tokenizer(sequence)[\"input_ids\"]\n",
+ " train_sequences.append(sequence)\n",
+ "\n",
+ "\n",
+ "for i in range(0, len(ds['validation'])):\n",
+ " if dataset_limit and i == dataset_limit:\n",
+ " break\n",
+ " sequence = ds['validation'][i]['sequence']\n",
+ " sequence = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
+ " sequence = tokenizer(sequence)[\"input_ids\"]\n",
+ " val_sequences.append(sequence)\n",
+ "\n",
+ "\n",
+ "for i in range(0, len(ds['test'])):\n",
+ " if dataset_limit and i == dataset_limit:\n",
+ " break\n",
+ " sequence = ds['test'][i]['sequence']\n",
+ " sequence = \"candida_glabrata \" + \" \".join(get_kmers(sequence))\n",
+ " sequence = tokenizer(sequence)[\"input_ids\"]\n",
+ " test_sequences.append(sequence)\n",
+ "\n",
+ "\n",
+ "train_labels = ds['train']['label']\n",
+ "test_labels = ds['test']['label']\n",
+ "val_labels = ds['validation']['label']\n",
+ "\n",
+ "if dataset_limit:\n",
+ " train_labels = train_labels[0:dataset_limit]\n",
+ " test_labels = test_labels[0:dataset_limit]\n",
+ " val_labels = val_labels[0:dataset_limit]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0686955c-201a-427b-8bef-5c663edb85b8",
+ "metadata": {
+ "id": "0686955c-201a-427b-8bef-5c663edb85b8"
+ },
+ "source": [
+ "Finally, we create a Dataset object for each of our sets."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "445b4279-2446-46d6-af2a-ceb2638955c7",
+ "metadata": {
+ "id": "445b4279-2446-46d6-af2a-ceb2638955c7"
+ },
+ "outputs": [],
+ "source": [
+ "from datasets import Dataset\n",
+ "\n",
+ "train_dataset = Dataset.from_dict({\"input_ids\": train_sequences, \"labels\": train_labels})\n",
+ "val_dataset = Dataset.from_dict({\"input_ids\": val_sequences, \"labels\": val_labels})\n",
+ "test_dataset = Dataset.from_dict({\"input_ids\": test_sequences, \"labels\": test_labels})"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d05d51a7-b933-4793-95df-af7d4d510b13",
+ "metadata": {
+ "id": "d05d51a7-b933-4793-95df-af7d4d510b13"
+ },
+ "source": [
+ "### 4. Train model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b5ce1985-c24e-4feb-a6d4-aacb909536f0",
+ "metadata": {
+ "id": "b5ce1985-c24e-4feb-a6d4-aacb909536f0"
+ },
+ "source": [
+ "Now, we'll train our DNA Language Model with the training dataset. We'll add a linear layer on top of the final layer of our language model, and then train all the parameters of our model with the training dataset."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "3a34b2c0-6205-4d48-b1a6-371b50ca42de",
+ "metadata": {
+ "id": "3a34b2c0-6205-4d48-b1a6-371b50ca42de"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import DataCollatorWithPadding\n",
+ "\n",
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "700540f4-0ab8-4f8a-a75c-416a6908af47",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "700540f4-0ab8-4f8a-a75c-416a6908af47",
+ "outputId": "9e16c1e9-4676-4cdf-b2a9-d785773b1c8d"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DNA_LM(\n",
+ " (model): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-11): 12 x BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSdpaSelfAttention(\n",
+ " (query): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (key): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (value): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from torch import nn\n",
+ "\n",
+ "class DNA_LM(nn.Module):\n",
+ " def __init__(self, model, num_labels):\n",
+ " super(DNA_LM, self).__init__()\n",
+ " self.model = model.bert\n",
+ " self.in_features = model.config.hidden_size\n",
+ " self.out_features = num_labels\n",
+ " self.classifier = nn.Linear(self.in_features, self.out_features)\n",
+ "\n",
+ " def forward(self, input_ids, attention_mask=None, labels=None):\n",
+ " outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n",
+ " sequence_output = outputs.hidden_states[-1]\n",
+ " # Use the [CLS] token for classification\n",
+ " cls_output = sequence_output[:, 0, :]\n",
+ " logits = self.classifier(cls_output)\n",
+ "\n",
+ " loss = None\n",
+ " if labels is not None:\n",
+ " loss_fct = nn.CrossEntropyLoss()\n",
+ " loss = loss_fct(logits.view(-1, self.out_features), labels.view(-1))\n",
+ "\n",
+ " return (loss, logits) if loss is not None else logits\n",
+ "\n",
+ "# Number of classes for your classification task\n",
+ "num_labels = 2\n",
+ "classification_model = DNA_LM(lm, num_labels)\n",
+ "classification_model.to('cuda');"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "0af97341-8f95-41d9-9d91-1eb64da4b516",
+ "metadata": {
+ "id": "0af97341-8f95-41d9-9d91-1eb64da4b516"
+ },
+ "outputs": [],
+ "source": [
+ "from transformers import DataCollatorWithPadding\n",
+ "\n",
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "d9ce6bc3-4f63-4b7b-b28d-d2553002e6db",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 268
+ },
+ "id": "d9ce6bc3-4f63-4b7b-b28d-d2553002e6db",
+ "outputId": "0c8fdbad-f34d-492b-e146-db6c2064e7c5"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " [65/65 01:43, Epoch 5/5]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.887400 | \n",
+ " 0.685295 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.644700 | \n",
+ " 0.682495 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.599600 | \n",
+ " 0.680431 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.892800 | \n",
+ " 0.679170 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.663800 | \n",
+ " 0.678761 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=65, training_loss=0.7263066686116733, metrics={'train_runtime': 104.8696, 'train_samples_per_second': 9.536, 'train_steps_per_second': 0.62, 'total_flos': 0.0, 'train_loss': 0.7263066686116733, 'epoch': 5.0})"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from transformers import Trainer, TrainingArguments\n",
+ "\n",
+ "\n",
+ "# Define training arguments\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir='./results',\n",
+ " eval_strategy=\"epoch\",\n",
+ " learning_rate=2e-5,\n",
+ " per_device_train_batch_size=16,\n",
+ " per_device_eval_batch_size=16,\n",
+ " num_train_epochs=5,\n",
+ " weight_decay=0.01,\n",
+ " eval_steps=1,\n",
+ " logging_steps=1,\n",
+ ")\n",
+ "\n",
+ "# Initialize Trainer\n",
+ "trainer = Trainer(\n",
+ " model=classification_model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=val_dataset,\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=data_collator,\n",
+ ")\n",
+ "\n",
+ "# Train the model\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ebc7e33a-caad-4412-84e3-3e1ce7d02ccd",
+ "metadata": {
+ "id": "ebc7e33a-caad-4412-84e3-3e1ce7d02ccd"
+ },
+ "source": [
+ "### 5. Evaluation"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "38eb0273-ce7e-4770-8457-2f9609f6843b",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 124
+ },
+ "id": "38eb0273-ce7e-4770-8457-2f9609f6843b",
+ "outputId": "2b0b93c9-0199-4e71-9825-9f6a2bd199d0"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[0 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 0 1 1 0 1 1 0 1 1 1 1 0 1 0 0 0 1 1 0 1 1\n",
+ " 1 1 1 0 1 1 1 0 1 1 0 1 1 1 1 1 1 1 1 0 1 0 0 1 1 1 1 1 0 0 0 1 0 1 1 0 1\n",
+ " 0 1 1 0 1 1 1 0 0 1 0 1 0 1 0 1 1 1 0 1 1 1 1 0 1 0 0 0 0 1 0 1 0 0 1 1 1\n",
+ " 1 0 1 1 0 0 1 1 1 0 1 1 1 1 0 0 1 1 1 1 0 0 1 1 1 0 0 1 1 0 1 1 0 1 1 0 1\n",
+ " 1 1 1 1 1 1 0 0 1 1 1 1 1 1 1 1 0 0 1 0 1 1 1 1 1 1 1 0 1 1 1 0 0 1 1 1 1\n",
+ " 0 1 1 1 1 0 1 1 0 0 1 0 1 1 0]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Generate predictions\n",
+ "\n",
+ "predictions = trainer.predict(test_dataset)\n",
+ "logits = predictions.predictions\n",
+ "predicted_labels = logits.argmax(axis=-1)\n",
+ "print(predicted_labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ae4c7bca",
+ "metadata": {
+ "id": "ae4c7bca"
+ },
+ "source": [
+ "Then, we create a function to calculate the accuracy from the test and predicted labels."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "327a1c3b-88d6-4430-8978-73a7cbdbb697",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "327a1c3b-88d6-4430-8978-73a7cbdbb697",
+ "outputId": "f03ad54d-d35f-4fcc-e709-c24d14906e25"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy: 0.53\n"
+ ]
+ }
+ ],
+ "source": [
+ "def calculate_accuracy(true_labels, predicted_labels):\n",
+ "\n",
+ " assert len(true_labels) == len(predicted_labels), \"Arrays must have the same length\"\n",
+ " correct_predictions = np.sum(true_labels == predicted_labels)\n",
+ " accuracy = correct_predictions / len(true_labels)\n",
+ "\n",
+ " return accuracy\n",
+ "\n",
+ "accuracy = calculate_accuracy(test_labels, predicted_labels)\n",
+ "print(f\"Accuracy: {accuracy:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9p0fFXKTZz9Q",
+ "metadata": {
+ "id": "9p0fFXKTZz9Q"
+ },
+ "source": [
+ "The results aren't that good, which we can attribute to the small dataset size."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e681864c-f15a-40a6-ac34-0e631d68d5c8",
+ "metadata": {
+ "id": "e681864c-f15a-40a6-ac34-0e631d68d5c8"
+ },
+ "source": [
+ "### 6. Parameter Efficient Fine-Tuning Techniques"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9141fabe-417b-4fbb-bd3e-244ad84e3010",
+ "metadata": {
+ "id": "9141fabe-417b-4fbb-bd3e-244ad84e3010"
+ },
+ "source": [
+ "In this section, we demonstrate how to employ parameter-efficient fine-tuning (PEFT) techniques to adapt a pre-trained model for specific genomics tasks using the PEFT library."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "71b8a749-461e-4533-b1d0-cebc924d3dc0",
+ "metadata": {
+ "id": "71b8a749-461e-4533-b1d0-cebc924d3dc0"
+ },
+ "source": [
+ "The LoraConfig object is instantiated to configure the PEFT parameters:\n",
+ "\n",
+ "- task_type: Specifies the type of task, for example sequence classification (SEQ_CLS). Note that the configuration cell below does not pass this argument.\n",
+ "- r: The rank of the LoRA matrices.\n",
+ "- lora_alpha: Scaling factor for adaptive re-parameterization.\n",
+ "- target_modules: Modules within the model to apply PEFT re-parameterization (query, key, value in this example).\n",
+ "- lora_dropout: Dropout rate used during PEFT fine-tuning."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "021641ae-f604-4d69-8724-743b7d7c613c",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "021641ae-f604-4d69-8724-743b7d7c613c",
+ "outputId": "d7c41fca-1c6b-46fd-9116-01f42d1d6ddf"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DNA_LM(\n",
+ " (model): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-11): 12 x BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSdpaSelfAttention(\n",
+ " (query): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (key): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (value): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Number of classes for your classification task\n",
+ "num_labels = 2\n",
+ "classification_model = DNA_LM(lm, num_labels)\n",
+ "classification_model.to('cuda');"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 41,
+ "id": "6c223937-86ea-42ef-991a-050f23b21ef9",
+ "metadata": {
+ "id": "6c223937-86ea-42ef-991a-050f23b21ef9"
+ },
+ "outputs": [],
+ "source": [
+ "from peft import LoraConfig, TaskType\n",
+ "\n",
+ "peft_config = LoraConfig(\n",
+ " r=8,\n",
+ " lora_alpha=32,\n",
+ " target_modules=[\"query\", \"key\", \"value\"],\n",
+ " lora_dropout=0.01,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "e7a9fe7d-e3ac-4ffa-9a9b-2067fb09b885",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "e7a9fe7d-e3ac-4ffa-9a9b-2067fb09b885",
+ "outputId": "02a6c65f-7474-4bc1-bfab-c05532e350a5"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "trainable params: 442,368 || all params: 90,121,730 || trainable%: 0.4909\n"
+ ]
+ }
+ ],
+ "source": [
+ "from peft import get_peft_model\n",
+ "\n",
+ "peft_model = get_peft_model(classification_model, peft_config)\n",
+ "peft_model.print_trainable_parameters()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "22064519-eaab-4142-8618-d1210d05c6bd",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "22064519-eaab-4142-8618-d1210d05c6bd",
+ "outputId": "ca3f764d-cdb4-4525-c541-8eabfb4cde57"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PeftModel(\n",
+ " (base_model): LoraModel(\n",
+ " (model): DNA_LM(\n",
+ " (model): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(5504, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0-11): 12 x BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSdpaSelfAttention(\n",
+ " (query): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (key): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (value): lora.Linear(\n",
+ " (base_layer): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (lora_dropout): ModuleDict(\n",
+ " (default): Dropout(p=0.01, inplace=False)\n",
+ " )\n",
+ " (lora_A): ModuleDict(\n",
+ " (default): Linear(in_features=768, out_features=8, bias=False)\n",
+ " )\n",
+ " (lora_B): ModuleDict(\n",
+ " (default): Linear(in_features=8, out_features=768, bias=False)\n",
+ " )\n",
+ " (lora_embedding_A): ParameterDict()\n",
+ " (lora_embedding_B): ParameterDict()\n",
+ " )\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " (intermediate_act_fn): GELUActivation()\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
+ " )\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 43,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "peft_model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "d3812e96-6b49-4911-8b21-d8871b7c06a5",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 268
+ },
+ "id": "d3812e96-6b49-4911-8b21-d8871b7c06a5",
+ "outputId": "8d497e30-1d3f-457a-f62a-244731698cb2"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
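+ "\n",
+ "    <div>\n",
+ "      \n",
+ "      <progress value='65' max='65' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [65/65 01:39, Epoch 5/5]\n",
+ "    </div>\n",
+ "    <table border=\"1\" class=\"dataframe\">\n",
+ "  <thead>\n",
+ " <tr style=\"text-align: left;\">\n",
+ "      <th>Epoch</th>\n",
+ "      <th>Training Loss</th>\n",
+ "      <th>Validation Loss</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr>\n",
+ "      <td>1</td>\n",
+ "      <td>0.625700</td>\n",
+ "      <td>0.777132</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <td>2</td>\n",
+ "      <td>0.717200</td>\n",
+ "      <td>0.773871</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <td>3</td>\n",
+ "      <td>0.768200</td>\n",
+ "      <td>0.771541</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <td>4</td>\n",
+ "      <td>0.687400</td>\n",
+ "      <td>0.769679</td>\n",
+ "    </tr>\n",
+ "    <tr>\n",
+ "      <td>5</td>\n",
+ "      <td>0.552000</td>\n",
+ "      <td>0.768947</td>\n",
+ "    </tr>\n",
+ "  </tbody>\n",
+ "</table><p>"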
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "TrainOutput(global_step=65, training_loss=0.74742647592838, metrics={'train_runtime': 100.8429, 'train_samples_per_second': 9.916, 'train_steps_per_second': 0.645, 'total_flos': 0.0, 'train_loss': 0.74742647592838, 'epoch': 5.0})"
+ ]
+ },
+ "execution_count": 45,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# Define training arguments\n",
+ "training_args = TrainingArguments(\n",
+ " output_dir='./results',\n",
+ " eval_strategy=\"epoch\",\n",
+ " learning_rate=2e-5,\n",
+ " per_device_train_batch_size=16,\n",
+ " per_device_eval_batch_size=16,\n",
+ " num_train_epochs=5,\n",
+ " weight_decay=0.01,\n",
+ " eval_steps=1,\n",
+ " logging_steps=1,\n",
+ ")\n",
+ "\n",
+ "# Initialize Trainer\n",
+ "trainer = Trainer(\n",
+ " model=peft_model.model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=val_dataset,\n",
+ " tokenizer=tokenizer,\n",
+ " data_collator=data_collator,\n",
+ ")\n",
+ "\n",
+ "# Train the model\n",
+ "trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "76dbd948-d919-4ade-a405-cec297979577",
+ "metadata": {
+ "id": "76dbd948-d919-4ade-a405-cec297979577"
+ },
+ "source": [
+ "### 8. Evaluate PEFT Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "id": "58cf70ba-47d5-4111-bb12-830ae04c6285",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 124
+ },
+ "id": "58cf70ba-47d5-4111-bb12-830ae04c6285",
+ "outputId": "0abc56a9-bd68-4e4e-9f13-756e8c9ffa3e"
+ },
+ "outputs": [
+ {
+ "data": {
+ "text/html": [],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[1 0 1 0 0 1 1 0 1 1 1 1 0 1 1 1 0 1 0 0 1 0 0 0 0 1 0 0 0 0 0 1 1 0 0 1 1\n",
+ " 1 1 1 0 1 1 0 1 0 0 1 0 0 1 1 0 1 1 0 0 1 1 0 0 1 1 0 0 0 0 0 0 0 1 1 0 1\n",
+ " 1 0 1 0 0 1 1 0 1 0 1 0 1 0 0 1 1 0 0 0 1 1 1 0 1 1 0 1 0 0 1 1 0 1 1 1 0\n",
+ " 1 1 0 0 1 0 1 1 1 0 1 1 0 1 1 0 0 0 0 1 1 0 1 1 1 1 1 0 1 0 1 0 1 1 0 1 1\n",
+ " 0 1 1 1 1 1 1 1 0 1 1 0 1 0 0 0 0 0 0 1 1 0 0 0 1 1 1 1 1 0 0 1 0 1 0 1 0\n",
+ " 0 1 1 0 0 0 1 0 1 1 1 0 1 1 0]\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Generate predictions\n",
+ "\n",
+ "predictions = trainer.predict(test_dataset)\n",
+ "logits = predictions.predictions\n",
+ "predicted_labels = logits.argmax(axis=-1)\n",
+ "print(predicted_labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "id": "4bd38fe5-6513-4c88-afee-0cc4e1781fdd",
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4bd38fe5-6513-4c88-afee-0cc4e1781fdd",
+ "outputId": "a50a91d0-d04d-4620-9006-868716bb992d"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Accuracy: 0.52\n"
+ ]
+ }
+ ],
+ "source": [
+ "def calculate_accuracy(true_labels, predicted_labels):\n",
+ "\n",
+ " assert len(true_labels) == len(predicted_labels), \"Arrays must have the same length\"\n",
+ " correct_predictions = np.sum(true_labels == predicted_labels)\n",
+ " accuracy = correct_predictions / len(true_labels)\n",
+ "\n",
+ " return accuracy\n",
+ "\n",
+ "accuracy = calculate_accuracy(test_labels, predicted_labels)\n",
+ "print(f\"Accuracy: {accuracy:.2f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "4ba5af69",
+ "metadata": {},
+ "source": [
+ "As we can see, the PEFT model achieves similar performance to the baseline model, demonstrating the effectiveness of PEFT in adapting pre-trained models to specific tasks with limited computational resources.\n",
+ "\n",
+ "With PEFT, we only train 442,368 parameters, which is 0.49% of the total parameters in the model. This is a significant reduction in computational resources compared to fine-tuning all of the model's parameters.\n",
+ "\n",
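+ "We can likely improve the results by using a larger dataset, fine-tuning the model for more epochs, or changing the hyperparameters (rank, learning rate, etc.). Note that with the `LoraConfig` used above only the LoRA adapter weights are trained (the 442,368 trainable parameters are exactly the rank-8 adapters on `query`, `key`, and `value`), so the classification head stays frozen; passing `modules_to_save=[\"classifier\"]` to `LoraConfig` makes it trainable as well.\n"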
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "gpuType": "T4",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ },
+ "widgets": {
+ "application/vnd.jupyter.widget-state+json": {
+ "03bba232d3974119acf8031bc086a072": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_9107f7bfc8d3483390f802b0458e9380",
+ "IPY_MODEL_f5c80fa70ead4c86aa3b2a046061b901",
+ "IPY_MODEL_57966a469ca1458daab74e81672ae855"
+ ],
+ "layout": "IPY_MODEL_1464502dc3dd46308be8b4fcc9d5ddb9"
+ }
+ },
+ "078c6877377a491d97d6fadd27064a76": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "0838d19b226d486285a26ce0b04d7e15": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "12f1de7122a7471e90f01d9e7be81178": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "1464502dc3dd46308be8b4fcc9d5ddb9": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "17859b793a304e389d1ea0b9ccc3646f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_86bcccb842244f4f9add58f62facaace",
+ "placeholder": "",
+ "style": "IPY_MODEL_78b5bbf4c8ac4fe5961776fded4d5798",
+ "value": "Generating test split: 100%"
+ }
+ },
+ "1afa6e9b69c74136863b7747e62a0608": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "2d5466a5e98849c5a09f16faa98f91da": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_39640709e7174f84a50da05764abbf99",
+ "placeholder": "",
+ "style": "IPY_MODEL_7114a029e75c4ed5b966eddd3a3c919d",
+ "value": " 1497/1497 [00:00<00:00, 41394.98 examples/s]"
+ }
+ },
+ "33be6b0ca8fd44188f834a48a9574a72": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_12f1de7122a7471e90f01d9e7be81178",
+ "max": 390606,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_dad286d42a514c9ca6bb01bfe9e9c4be",
+ "value": 390606
+ }
+ },
+ "34921fd116cc42b7b530174d9f61e71e": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c80062a855cb41a28ac625ab03635da2",
+ "max": 1497,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_aecd740c17c84d45b0615d4fc4196035",
+ "value": 1497
+ }
+ },
+ "39640709e7174f84a50da05764abbf99": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "4d4ce0d35c124690b3427e84a9a128b1": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_078c6877377a491d97d6fadd27064a76",
+ "placeholder": "",
+ "style": "IPY_MODEL_d46ee1c39bac44c2b541a88c883de1cb",
+ "value": "Downloading data: 100%"
+ }
+ },
+ "57966a469ca1458daab74e81672ae855": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_1afa6e9b69c74136863b7747e62a0608",
+ "placeholder": "",
+ "style": "IPY_MODEL_0838d19b226d486285a26ce0b04d7e15",
+ "value": " 3.50M/3.50M [00:00<00:00, 26.3MB/s]"
+ }
+ },
+ "57c9af47364d48ffbb4ffbdd2c951ede": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "682644a713b145f0b2dcff99790c6d4d": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "6d80dec073e449efba272fa9f3527922": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7114a029e75c4ed5b966eddd3a3c919d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "74e9bc1ead434ae78077df6b85f1df58": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_c028ed977b5e479fbd93b8add588a6dc",
+ "placeholder": "",
+ "style": "IPY_MODEL_6d80dec073e449efba272fa9f3527922",
+ "value": " 391k/391k [00:00<00:00, 3.34MB/s]"
+ }
+ },
+ "78b5bbf4c8ac4fe5961776fded4d5798": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "7bdab33f4b244fc89408b91755bf17c5": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_4d4ce0d35c124690b3427e84a9a128b1",
+ "IPY_MODEL_33be6b0ca8fd44188f834a48a9574a72",
+ "IPY_MODEL_74e9bc1ead434ae78077df6b85f1df58"
+ ],
+ "layout": "IPY_MODEL_e1acc6e70b9246a5b063b3e262f01c81"
+ }
+ },
+ "86bcccb842244f4f9add58f62facaace": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "8f20ed2b74d84e80a8d403793354adea": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "9107f7bfc8d3483390f802b0458e9380": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_92f64c7e088342b9b3c070ba7a295ed0",
+ "placeholder": "",
+ "style": "IPY_MODEL_ab0aa8af3816422e9d97934f12af842c",
+ "value": "Downloading data: 100%"
+ }
+ },
+ "92f64c7e088342b9b3c070ba7a295ed0": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "952397f9c91c480184fa57e175ab1b4c": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "9b9b9d573d44464f9a6f5030a40245fe": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "ab0aa8af3816422e9d97934f12af842c": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "aecd740c17c84d45b0615d4fc4196035": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "bd5273325a4b453e8053d98a09fe9493": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c028ed977b5e479fbd93b8add588a6dc": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c311b777514f41ef986756a386c0bb34": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_e2e4bf053ce442f6aee6ffab5f76525f",
+ "IPY_MODEL_c88cf701e20b4354a63ac7d8645d1df9",
+ "IPY_MODEL_f71c252ada474be882b0335ed9a0a1c3"
+ ],
+ "layout": "IPY_MODEL_e059c665229e46ea905dcbd6fc179c88"
+ }
+ },
+ "c80062a855cb41a28ac625ab03635da2": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "c88cf701e20b4354a63ac7d8645d1df9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_57c9af47364d48ffbb4ffbdd2c951ede",
+ "max": 13468,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_fa9d75fcb1d5400c8ca1d1d13d28d0c7",
+ "value": 13468
+ }
+ },
+ "d46ee1c39bac44c2b541a88c883de1cb": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "DescriptionStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "DescriptionStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "description_width": ""
+ }
+ },
+ "dad286d42a514c9ca6bb01bfe9e9c4be": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "e059c665229e46ea905dcbd6fc179c88": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e113a50f8ed2410ca12ce7cb38a1681d": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "e1acc6e70b9246a5b063b3e262f01c81": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ },
+ "e2e4bf053ce442f6aee6ffab5f76525f": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_bd5273325a4b453e8053d98a09fe9493",
+ "placeholder": "",
+ "style": "IPY_MODEL_8f20ed2b74d84e80a8d403793354adea",
+ "value": "Generating train split: 100%"
+ }
+ },
+ "ec165fdbe87a4b00a6c288ef1e85c0a9": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HBoxModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HBoxModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HBoxView",
+ "box_style": "",
+ "children": [
+ "IPY_MODEL_17859b793a304e389d1ea0b9ccc3646f",
+ "IPY_MODEL_34921fd116cc42b7b530174d9f61e71e",
+ "IPY_MODEL_2d5466a5e98849c5a09f16faa98f91da"
+ ],
+ "layout": "IPY_MODEL_952397f9c91c480184fa57e175ab1b4c"
+ }
+ },
+ "f5c80fa70ead4c86aa3b2a046061b901": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "FloatProgressModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "FloatProgressModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "ProgressView",
+ "bar_style": "success",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_ff89a891bd9c42a8be164587a94ccac1",
+ "max": 3495021,
+ "min": 0,
+ "orientation": "horizontal",
+ "style": "IPY_MODEL_e113a50f8ed2410ca12ce7cb38a1681d",
+ "value": 3495021
+ }
+ },
+ "f71c252ada474be882b0335ed9a0a1c3": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "HTMLModel",
+ "state": {
+ "_dom_classes": [],
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "HTMLModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/controls",
+ "_view_module_version": "1.5.0",
+ "_view_name": "HTMLView",
+ "description": "",
+ "description_tooltip": null,
+ "layout": "IPY_MODEL_682644a713b145f0b2dcff99790c6d4d",
+ "placeholder": "",
+ "style": "IPY_MODEL_9b9b9d573d44464f9a6f5030a40245fe",
+ "value": " 13468/13468 [00:00<00:00, 193879.37 examples/s]"
+ }
+ },
+ "fa9d75fcb1d5400c8ca1d1d13d28d0c7": {
+ "model_module": "@jupyter-widgets/controls",
+ "model_module_version": "1.5.0",
+ "model_name": "ProgressStyleModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/controls",
+ "_model_module_version": "1.5.0",
+ "_model_name": "ProgressStyleModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "StyleView",
+ "bar_color": null,
+ "description_width": ""
+ }
+ },
+ "ff89a891bd9c42a8be164587a94ccac1": {
+ "model_module": "@jupyter-widgets/base",
+ "model_module_version": "1.2.0",
+ "model_name": "LayoutModel",
+ "state": {
+ "_model_module": "@jupyter-widgets/base",
+ "_model_module_version": "1.2.0",
+ "_model_name": "LayoutModel",
+ "_view_count": null,
+ "_view_module": "@jupyter-widgets/base",
+ "_view_module_version": "1.2.0",
+ "_view_name": "LayoutView",
+ "align_content": null,
+ "align_items": null,
+ "align_self": null,
+ "border": null,
+ "bottom": null,
+ "display": null,
+ "flex": null,
+ "flex_flow": null,
+ "grid_area": null,
+ "grid_auto_columns": null,
+ "grid_auto_flow": null,
+ "grid_auto_rows": null,
+ "grid_column": null,
+ "grid_gap": null,
+ "grid_row": null,
+ "grid_template_areas": null,
+ "grid_template_columns": null,
+ "grid_template_rows": null,
+ "height": null,
+ "justify_content": null,
+ "justify_items": null,
+ "left": null,
+ "margin": null,
+ "max_height": null,
+ "max_width": null,
+ "min_height": null,
+ "min_width": null,
+ "object_fit": null,
+ "object_position": null,
+ "order": null,
+ "overflow": null,
+ "overflow_x": null,
+ "overflow_y": null,
+ "padding": null,
+ "right": null,
+ "top": null,
+ "visibility": null,
+ "width": null
+ }
+ }
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}