finished helper functions #44

Merged
merged 1 commit into from Feb 5, 2024
228 changes: 126 additions & 102 deletions tutorials/W1D1_Generalization/W1D1_Tutorial1.ipynb
@@ -8,21 +8,130 @@
"outputs": [],
"source": [
"# Install \n",
"!pip install transformers gradio sentencepiece\n",
"!pip install numpy \n",
"\n"
"#!pip install transformers gradio sentencepiece numpy torch torchvision trdg Pillow==9.5.0\n",
"\n",
"# Core Python Data Science and Image Processing Libraries\n",
"import numpy as np\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Deep Learning and Model Specific Libraries\n",
"import torch\n",
"from torchvision import transforms\n",
"from transformers import TrOCRProcessor, VisionEncoderDecoderModel\n",
"\n",
"# Utility and Interface Libraries\n",
"import gradio as gr\n",
"from IPython.display import IFrame\n",
"from trdg.generators import GeneratorFromStrings\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b00326ca-807d-460f-adf4-767e94bc0ccc",
"id": "1bf34b9a-1dd5-458a-b390-0fa12609d532",
"metadata": {},
"outputs": [],
"source": [
"from transformers import TrOCRProcessor, VisionEncoderDecoderModel\n",
"import gradio as gr\n",
"# @title Plotting functions\n",
"\n",
"def display_image(image_path):\n",
" \"\"\"Display an image from a given file path.\n",
"\n",
" Args:\n",
" image_path (str): The path to the image file.\n",
" \"\"\"\n",
" # Open the image\n",
" image = Image.open(image_path)\n",
" if image.mode != 'RGB':\n",
" image = image.convert('RGB')\n",
"\n",
" # Display the image\n",
" plt.imshow(image)\n",
" plt.axis('off') # Turn off the axis\n",
" plt.show()\n",
"\n",
"def display_transformed_images(image, transformations):\n",
" \"\"\"\n",
" Apply a list of transformations to an image and display them.\n",
"\n",
" Args:\n",
" image (Tensor): The input image as a tensor.\n",
" transformations (list): A list of torchvision transformations to apply.\n",
" \"\"\"\n",
" # Convert tensor image to PIL Image for display\n",
" pil_image = transforms.ToPILImage()(image)\n",
"\n",
" fig, axs = plt.subplots(len(transformations) + 1, 1, figsize=(5, 15))\n",
" axs[0].imshow(pil_image)\n",
" axs[0].set_title('Original')\n",
" axs[0].axis('off')\n",
"\n",
" for i, transform in enumerate(transformations):\n",
" # Apply transformation if it's not the placeholder\n",
" if transform != \"Custom ElasticTransform Placeholder\":\n",
" transformed_image = transform(image)\n",
" # Convert transformed tensor image to PIL Image for display\n",
" display_image = transforms.ToPILImage()(transformed_image)\n",
" axs[i+1].imshow(display_image)\n",
" axs[i+1].set_title(transform.__class__.__name__)\n",
" axs[i+1].axis('off')\n",
" else:\n",
" axs[i+1].text(0.5, 0.5, 'ElasticTransform Placeholder', ha='center')\n",
" axs[i+1].axis('off')\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"def display_original_and_transformed_images(original_tensor, transformed_tensor):\n",
" \"\"\"\n",
" Display the original and transformed images side by side.\n",
"\n",
" Args:\n",
" original_tensor (Tensor): The original image as a tensor.\n",
" transformed_tensor (Tensor): The transformed image as a tensor.\n",
" \"\"\"\n",
" fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
"\n",
" # Display original image\n",
" original_image = original_tensor.permute(1, 2, 0) # Convert from (C, H, W) to (H, W, C)\n",
" axs[0].imshow(original_image)\n",
" axs[0].set_title('Original')\n",
" axs[0].axis('off')\n",
"\n",
" # Display transformed image\n",
" transformed_image = transformed_tensor.permute(1, 2, 0) # Convert from (C, H, W) to (H, W, C)\n",
" axs[1].imshow(transformed_image)\n",
" axs[1].set_title('Transformed')\n",
" axs[1].axis('off')\n",
"\n",
" plt.show()\n",
"\n",
"def display_generated_images(generator):\n",
" \"\"\"\n",
" Display images generated from strings.\n",
"\n",
" Args:\n",
" generator (GeneratorFromStrings): A generator that produces images from strings.\n",
" \"\"\"\n",
" plt.figure(figsize=(15, 3))\n",
" for i, (text_img, lbl) in enumerate(generator, 1):\n",
" ax = plt.subplot(1, len(generator.strings) * generator.count // len(generator.strings), i)\n",
" plt.imshow(text_img)\n",
" plt.title(f\"Example {i}\")\n",
" plt.axis('off')\n",
" \n",
" plt.tight_layout()\n",
" plt.show()\n"
]
},
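A minimal usage sketch for these helpers, assuming an image tensor is already in hand (the path and transform list below are illustrative, not part of the diff):

from PIL import Image
from torchvision import transforms

# Any RGB image tensor of shape (C, H, W) works as input
img_tensor = transforms.ToTensor()(Image.open('../static/img_1235.jpg'))

# The string sentinel is skipped and rendered as a text panel,
# matching the placeholder branch in display_transformed_images
demo_transforms = [transforms.RandomAffine(degrees=10),
                   "Custom ElasticTransform Placeholder"]
display_transformed_images(img_tensor, demo_transforms)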
{
"cell_type": "code",
"execution_count": null,
"id": "b00326ca-807d-460f-adf4-767e94bc0ccc",
"metadata": {},
"outputs": [],
"source": [
"# Load the pre-trained TrOCR model and processor\n",
"model = VisionEncoderDecoderModel.from_pretrained(\"microsoft/trocr-small-handwritten\")\n",
"processor = TrOCRProcessor.from_pretrained(\"microsoft/trocr-small-handwritten\")\n",
@@ -43,7 +152,6 @@
" description=\"Demo for Microsoft’s TrOCR, an encoder-decoder model for OCR on single-text line images.\",\n",
")\n",
"\n",
"\n",
"# Launch the interface\n",
"interface.launch()\n"
]
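The function the interface wraps is elided from this hunk. A typical single-image TrOCR pipeline with this processor/model pair is sketched below; the function name and preprocessing are assumptions, not the PR's code:

from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")

def recognize_text(image: Image.Image) -> str:  # hypothetical helper
    # Preprocess to pixel values, generate token ids, decode to a string
    pixel_values = processor(images=image.convert("RGB"), return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]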
@@ -55,9 +163,6 @@
"metadata": {},
"outputs": [],
"source": [
"#Imports\n",
"from transformers import TrOCRProcessor, VisionEncoderDecoderModel\n",
"\n",
"# Load the model\n",
"model = VisionEncoderDecoderModel.from_pretrained(\"microsoft/trocr-small-handwritten\")\n"
]
@@ -101,18 +206,7 @@
"# Count parameters in the decoder\n",
"decoder_params = count_parameters(model.decoder)\n",
"\n",
"encoder_params, decoder_params\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bf37c12f-77c3-4b90-841d-901b0ee4bb8c",
"metadata": {},
"outputs": [],
"source": [
"# Imports\n",
"import numpy as np"
"encoder_params, decoder_params"
]
},
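count_parameters is defined outside this hunk; a common implementation, shown here as an assumption, is:

def count_parameters(module):
    # Total number of trainable elements across the module's parameter tensors
    return sum(p.numel() for p in module.parameters() if p.requires_grad)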
{
@@ -253,41 +347,16 @@
"time_to_write_lifetimes_llama = calculate_writing_time(total_words_llama2, words_per_day, days_per_week, weeks_per_year, average_human_lifespan)"
]
},
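calculate_writing_time is also defined in elided lines. Given the argument names, it presumably performs arithmetic along these lines (a sketch, not the PR's code):

def calculate_writing_time(total_words, words_per_day, days_per_week,
                           weeks_per_year, lifespan_years):
    # Words written per year at the given pace
    words_per_year = words_per_day * days_per_week * weeks_per_year
    # Years needed for total_words, expressed in human lifetimes
    return (total_words / words_per_year) / lifespan_years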
{
"cell_type": "code",
"execution_count": null,
"id": "76cafcd3-4c19-4a5d-a9b5-c0c65e1c777a",
"metadata": {},
"outputs": [],
"source": [
"#Install and imports\n",
"!pip install torch torchvision trdg\n",
"!pip install Pillow==9.5.0\n",
"import torch\n",
"from torchvision import transforms\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "78127c97-955e-46e0-869d-7f7285449f70",
"metadata": {},
"outputs": [],
"source": [
"# Path to the image\n",
"# Usage\n",
"image_path = '../static/img_1235.jpg'\n",
"\n",
"# Open the image\n",
"image = Image.open(image_path)\n",
"if image.mode != 'RGB':\n",
" image = image.convert('RGB')\n",
"\n",
"# Display the image\n",
"plt.imshow(image)\n",
"plt.axis('off') # Turn off the axis\n",
"plt.show()"
"display_image(image_path)"
]
},
{
@@ -298,8 +367,8 @@
"outputs": [],
"source": [
"# Convert PIL Image to Tensor\n",
"to_tensor = transforms.ToTensor()\n",
"image = to_tensor(image)\n",
"image = Image.open(image_path)\n",
"image = transforms.ToTensor()(image)\n",
"\n",
"# Define each transformation separately\n",
"# RandomAffine: applies rotations, translations, scaling. Here, rotates by up to ±15 degrees,\n",
@@ -321,26 +390,7 @@
"# A list of all transformations for iteration\n",
"transformations = [affine, elastic, perspective, erasing, gaussian_blur]\n",
"\n",
"# Vertical stacking of transformations\n",
"fig, axs = plt.subplots(len(transformations) + 1, 1, figsize=(5, 15)) \n",
"\n",
"# Permute the image dimensions from (C, H, W) to (H, W, C) for display\n",
"display_image = image.permute(1, 2, 0)\n",
"\n",
"axs[0].imshow(display_image)\n",
"axs[0].set_title('Original')\n",
"axs[0].axis('off')\n",
"\n",
"for i, transform in enumerate(transformations, 1):\n",
" augmented_image = transforms.Compose([transform])(image)\n",
" # Permute the augmented image dimensions from (C, H, W) to (H, W, C) for display\n",
" display_augmented = augmented_image.permute(1, 2, 0)\n",
"\n",
" axs[i].imshow(display_augmented)\n",
" axs[i].set_title(transform.__class__.__name__)\n",
" axs[i].axis('off')\n",
"\n",
"plt.show()"
"display_transformed_images(image, transformations)"
]
},
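Only the RandomAffine definition survives in this hunk; elastic, perspective, erasing, and gaussian_blur are defined in elided lines. Plausible torchvision definitions (parameter values are assumptions):

elastic = transforms.ElasticTransform(alpha=250.0)  # local pixel displacements
perspective = transforms.RandomPerspective(distortion_scale=0.5, p=1.0)
erasing = transforms.RandomErasing(p=1.0)           # occludes a random rectangle
gaussian_blur = transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0))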
{
@@ -357,8 +407,7 @@
"image = Image.open(image_path)\n",
"\n",
"# Convert PIL Image to Tensor\n",
"to_tensor = transforms.ToTensor()\n",
"image_tensor = to_tensor(image)\n",
"image_tensor = transforms.ToTensor()(image)\n",
"\n",
"# Define your transformations here\n",
"affine = transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1))\n",
@@ -379,20 +428,8 @@
"# Apply combined transformation\n",
"augmented_image_tensor = all_transforms(image_tensor)\n",
"\n",
"# Set up matplotlib subplots\n",
"fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
"\n",
"# Display original image\n",
"axs[0].imshow(image_tensor.permute(1, 2, 0)) # Change to (H, W, C) for display\n",
"axs[0].set_title('Original')\n",
"axs[0].axis('off')\n",
"\n",
"# Display augmented image\n",
"axs[1].imshow(augmented_image_tensor.permute(1, 2, 0)) # Change to (H, W, C) for display\n",
"axs[1].set_title('Transformed')\n",
"axs[1].axis('off')\n",
"\n",
"plt.show()\n"
"# Assuming 'image_tensor' and 'augmented_image_tensor' are defined as in your snippet\n",
"display_original_and_transformed_images(image_tensor, augmented_image_tensor)\n"
]
},
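The Compose call itself is elided; combining the individual transforms is a one-liner, with the list order determining application order (the ordering here is an assumption):

all_transforms = transforms.Compose(
    [affine, elastic, perspective, erasing, gaussian_blur]
)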
{
@@ -402,9 +439,6 @@
"metadata": {},
"outputs": [],
"source": [
"from trdg.generators import GeneratorFromStrings\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Define your strings\n",
"strings = ['Hello', 'This is Patrick', 'From NMA'] # Update this list as needed\n",
"\n",
Expand All @@ -417,16 +451,8 @@
" fonts=['Purisa'] # Update or add more fonts as needed\n",
")\n",
"\n",
"# Setup matplotlib figure and display images\n",
"plt.figure(figsize=(15, 3))\n",
"for i, (text_img, lbl) in enumerate(generator, 1):\n",
" ax = plt.subplot(1, 5, i) # Adjust the number of subplots if needed\n",
" plt.imshow(text_img)\n",
" plt.title(f\"Example {i}\")\n",
" plt.axis('off')\n",
"\n",
"plt.tight_layout()\n",
"plt.show()\n"
"# Call the function with your generator\n",
"display_generated_images(generator)"
]
},
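The generator's remaining keyword arguments sit in elided lines. A construction consistent with the visible fonts argument might look like this (the count and blur values are assumptions):

generator = GeneratorFromStrings(
    strings,
    count=6,           # total images to yield
    blur=1,            # base blur radius
    random_blur=True,  # vary blur per sample
    fonts=['Purisa'],
)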
{
@@ -436,8 +462,6 @@
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import IFrame\n",
"\n",
"IFrame(\"https://www.calligrapher.ai/\", width=800, height=600)"
]
}