Add fp32 to int8 conversion notebook. Remove imread-from-url library dependency. Minor fixes #2

Open · wants to merge 5 commits into main
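
The new notebook converts the SAM2 encoder and decoder from fp32 to int8 via ONNX Runtime dynamic quantization. As a minimal sketch of the core call (paths are illustrative; the notebook below adds pre-processing and per-op filtering):

```python
# Minimal sketch of the conversion performed by the notebook (paths illustrative).
from onnxruntime import quantization
from onnxruntime.quantization.quant_utils import QuantType

quantization.quantize_dynamic(
    "models/sam2_hiera_tiny_encoder.onnx",        # fp32 input model
    "models/sam2_hiera_tiny_encoder_quant.onnx",  # int8 output model
    weight_type=QuantType.QInt8,
)
```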
3 changes: 3 additions & 0 deletions gitignore.txt → .gitignore
@@ -144,6 +144,9 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
+# vscode
+.vscode
+
 # PyCharm
 # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
38 changes: 20 additions & 18 deletions annotation_app.py
@@ -5,7 +5,6 @@
 from sam2 import SAM2Image, draw_masks, colors
-from imread_from_url import imread_from_url
 
 class ImageAnnotationApp:
     def __init__(self, root, sam2: SAM2Image):
         self.root = root
@@ -47,8 +46,9 @@ def __init__(self, root, sam2: SAM2Image):
         self.add_label(0)  # Add default label 1
         img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/e/e2/Dexter_professionellt_fotograferad.jpg/1280px-Dexter_professionellt_fotograferad.jpg"
         self.image = imread_from_url(img_url)
-        self.mask_image = self.image.copy()
-        self.sam2.set_image(self.image)
+        if self.image is not None:
+            self.mask_image = self.image.copy()
+            self.sam2.set_image(self.image)
         self.display_image()
 
     def browse_image(self):
@@ -64,19 +64,20 @@ def browse_image(self):
         self.display_image()
 
     def display_image(self):
-        if self.mask_image.shape[0] == 0:
-            return
-
-        # Convert the image to RGB (from BGR)
-        rgb_image = cv2.cvtColor(self.mask_image, cv2.COLOR_BGR2RGB)
-        # Convert the image to PIL format
-        pil_image = Image.fromarray(rgb_image)
-        self.tk_image = ImageTk.PhotoImage(image=pil_image)
-        self.canvas.config(width=self.tk_image.width(), height=self.tk_image.height())
-        self.canvas.create_image(0, 0, anchor=tk.NW, image=self.tk_image)
-        self.draw_points()
-
-    def add_label(self, label_id: int = None):
+        if self.mask_image is not None:
+            if self.mask_image.shape[0] == 0:
+                return
+
+            # Convert the image to RGB (from BGR)
+            rgb_image = cv2.cvtColor(self.mask_image, cv2.COLOR_BGR2RGB)
+            # Convert the image to PIL format
+            pil_image = Image.fromarray(rgb_image)
+            self.tk_image = ImageTk.PhotoImage(image=pil_image)
+            self.canvas.config(width=self.tk_image.width(), height=self.tk_image.height())
+            self.canvas.create_image(0, 0, anchor=tk.NW, image=self.tk_image)
+            self.draw_points()
+
+    def add_label(self, label_id: int | None):
         if label_id is None:
             max_label = max(self.label_ids) if self.label_ids else 0
@@ -130,6 +131,7 @@ def remove_label(self):
             self.canvas.delete(point[0])
             self.points.remove(point)
 
+
         masks = self.sam2.update_mask()
         self.mask_image = draw_masks(self.image, masks)
         self.display_image()
@@ -247,8 +249,8 @@ def reset(self):
 if __name__ == "__main__":
     root = tk.Tk()
 
-    encoder_model_path = "models/sam2_hiera_base_plus_encoder.onnx"
-    decoder_model_path = "models/sam2_hiera_base_plus_decoder.onnx"
+    encoder_model_path = "models/sam2_hiera_tiny_encoder.onnx"
+    decoder_model_path = "models/sam2_hiera_tiny_decoder.onnx"
     sam2 = SAM2Image(encoder_model_path, decoder_model_path)
 
     app = ImageAnnotationApp(root, sam2)
27 changes: 17 additions & 10 deletions image_segmentation.py
@@ -14,22 +14,29 @@
 sam2 = SAM2Image(encoder_model_path, decoder_model_path)
 
 # Set image
-sam2.set_image(img)
+if img is not None:
+    sam2.set_image(img)
 
-# Add points
-point_coords = [np.array([[420, 440]]), np.array([[360, 275], [370, 210]]), np.array([[810, 440]]),
-                np.array([[920, 314]])]
-point_labels = [np.array([1]), np.array([1, 1]), np.array([1]), np.array([1])]
-
-for label_id, (point_coord, point_label) in enumerate(zip(point_coords, point_labels)):
-    for i in range(point_label.shape[0]):
-        sam2.add_point((point_coord[i][0], point_coord[i][1]), point_label[i], label_id)
-
-# Decode image
-masks = sam2.update_mask()
-
-# Draw masks
-masked_img = draw_masks(img, masks)
-
-cv2.imshow("masked_img", masked_img)
-if cv2.waitKey(1000) & 0xFF == ord('q'):
+    # Add points
+    point_coords = [np.array([[420, 440]]), np.array([[360, 275], [370, 210]]), np.array([[810, 440]]),
+                    np.array([[920, 314]])]
+    point_labels = [np.array([1]), np.array([1, 1]), np.array([1]), np.array([1])]
+
+    for label_id, (point_coord, point_label) in enumerate(zip(point_coords, point_labels)):
+        for i in range(point_label.shape[0]):
+            sam2.add_point((point_coord[i][0], point_coord[i][1]), point_label[i], label_id)
+
+    # Decode image
+    masks = sam2.update_mask()
+
+    # Draw masks
+    masked_img = draw_masks(img, masks)
+
+    cv2.imshow("masked_img", masked_img)
+    if cv2.waitKey(1000) & 0xFF == ord('q'):
+        break
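
A note on the new guards: truth-testing a NumPy array raises a ValueError, so the checks compare against None instead of relying on `if img:`. A quick demonstration:

```python
import numpy as np

img = np.zeros((4, 4, 3), dtype=np.uint8)
# bool(img) raises ValueError: the truth value of a multi-element array is ambiguous
if img is not None:  # identity comparison with None is the safe guard
    print("image loaded:", img.shape)
```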
3 changes: 3 additions & 0 deletions notebooks/requirements.txt
@@ -0,0 +1,3 @@
onnx==1.16.1
onnxruntime
onnxsim
259 changes: 259 additions & 0 deletions notebooks/sam2_onnx_quant.ipynb
@@ -0,0 +1,259 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Quantizing the model to Int 8"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
"import onnx\n",
"import onnxruntime as ort\n",
"from onnxruntime import quantization\n",
"from onnx.onnx_ml_pb2 import ModelProto\n",
"from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession\n",
"from onnxruntime.quantization.quant_utils import QuantType\n",
"\n",
"from pathlib import Path\n",
"import os\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"MODEL_DIR: Path = Path(Path.cwd()).parent.joinpath(\n",
" \"models\"\n",
") # save the original fp32 models in the models directory"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"decoder_onnx_path: Path = MODEL_DIR.joinpath(\"sam2_hiera_tiny_decoder.onnx\")\n",
"encoder_onnx_path: Path = MODEL_DIR.joinpath(\"sam2_hiera_tiny_encoder.onnx\")\n",
"\n",
"decoder_onnx: ModelProto = onnx.load(decoder_onnx_path)\n",
"encoder_onnx: ModelProto = onnx.load(encoder_onnx_path)\n",
"onnx.checker.check_model(encoder_onnx)\n",
"onnx.checker.check_model(decoder_onnx)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"ort_provider: list[str] = [\"CPUExecutionProvider\"]\n",
"\n",
"ort_sess_encoder: InferenceSession = ort.InferenceSession(\n",
" decoder_onnx_path, providers=ort_provider\n",
")\n",
"ort_sess_decoder: InferenceSession = ort.InferenceSession(\n",
" encoder_onnx_path, providers=ort_provider\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"encoder_prep_path = MODEL_DIR.joinpath(\"sam2_hiera_tiny_encoder.onnx\")\n",
"decoder_prep_path = MODEL_DIR.joinpath(\"sam2_hiera_tiny_decoder.onnx\")\n",
"\n",
"quantization.shape_inference.quant_pre_process(\n",
" encoder_onnx_path, encoder_prep_path, skip_symbolic_shape=False\n",
")\n",
"quantization.shape_inference.quant_pre_process(\n",
" decoder_onnx_path, decoder_prep_path, skip_symbolic_shape=True\n",
") # skippinng symbolic shape as it is mostly useful for transformers based models"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"ops_dec = set()\n",
"for node in decoder_onnx.graph.node:\n",
" ops_dec.add(node.op_type)\n",
"\n",
"ops_enc = set()\n",
"for node in encoder_onnx.graph.node:\n",
" ops_enc.add(node.op_type)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"op_to_quantize_dec = [x for x in list(ops_enc) if not x.lower().startswith(\"conv\")]\n",
"# removing conv layer as it is giving an error while converting it int8 because of an ONNX issue.\n",
"\n",
"\n",
"op_to_quantize_enc = [x for x in list(ops_dec) if not x.lower().startswith(\"conv\")]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"quantized_encoder_path: Path = MODEL_DIR.joinpath(\"sam2_hiera_tiny_encoder_quant.onnx\")\n",
"quantized_decoder_path: Path = MODEL_DIR.joinpath(\"sam2_hiera_tiny_decoder_quant.onnx\")\n",
"\n",
"quantization.quantize_dynamic(\n",
" encoder_prep_path,\n",
" quantized_encoder_path,\n",
" weight_type=QuantType.QInt8,\n",
" op_types_to_quantize=op_to_quantize_enc,\n",
") # Make weight_type = QuantType.QUInt8 if you do not wish to leave the conv layers in \"unquantized\"\n",
"quantization.quantize_dynamic(\n",
" decoder_prep_path,\n",
" quantized_decoder_path,\n",
" weight_type=QuantType.QInt8,\n",
" op_types_to_quantize=op_to_quantize_dec,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Simplify quantized models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# TODO: Simplify quantized models using onnxsim"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Loading quantized model"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"decoder_quant_onnx: ModelProto = onnx.load(quantized_decoder_path)\n",
"encoder_quant_onnx: ModelProto = onnx.load(quantized_encoder_path)\n",
"onnx.checker.check_model(decoder_quant_onnx)\n",
"onnx.checker.check_model(encoder_quant_onnx)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [],
"source": [
"ort_provider: list[str] = [\"CPUExecutionProvider\"]\n",
"\n",
"ort_sess_encoder: InferenceSession = ort.InferenceSession(\n",
" quantized_encoder_path, providers=ort_provider\n",
")\n",
"ort_sess_decoder: InferenceSession = ort.InferenceSession(\n",
" quantized_decoder_path, providers=ort_provider\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Comparing size reduction of fp32 to int8 models"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Size of int8 encoder: 52776 KB\n",
"Size of fp32 encoder: 131115 KB\n",
"Size of int8 encoder: 8597 KB\n",
"Size of fp32 encoder: 20152 KB\n",
"Reduction in size in percentage for encoder: 59.748 %\n",
"Reduction in size in percentage for decoder: 57.339 %\n"
]
}
],
"source": [
"size_encoder_quant: int = os.path.getsize(quantized_encoder_path) // 1024\n",
"size_encoder: int = os.path.getsize(encoder_onnx_path) // 1024\n",
"\n",
"size_decoder_quant: int = os.path.getsize(quantized_decoder_path) // 1024\n",
"size_decoder: int = os.path.getsize(decoder_onnx_path) // 1024\n",
"\n",
"\n",
"print(f\"Size of int8 encoder: {size_encoder_quant} KB\")\n",
"print(f\"Size of fp32 encoder: {size_encoder} KB\")\n",
"\n",
"print(f\"Size of int8 encoder: {size_decoder_quant} KB\")\n",
"print(f\"Size of fp32 encoder: {size_decoder} KB\")\n",
"\n",
"reduction_enc: float = round(\n",
" ((size_encoder - size_encoder_quant) / size_encoder) * 100, 3\n",
")\n",
"reduction_dec: float = round(\n",
" ((size_decoder - size_decoder_quant) / size_decoder) * 100, 3\n",
")\n",
"print(f\"Reduction in size in percentage for encoder: {reduction_enc} %\")\n",
"print(f\"Reduction in size in percentage for decoder: {reduction_dec} %\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
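
Once exported, the int8 models can replace the fp32 ones in the existing scripts; a hedged sketch, assuming `SAM2Image` accepts arbitrary encoder/decoder ONNX paths as in `annotation_app.py`:

```python
# Sketch: load the quantized models with the repo's SAM2Image wrapper (paths assumed).
from sam2 import SAM2Image

encoder_model_path = "models/sam2_hiera_tiny_encoder_quant.onnx"
decoder_model_path = "models/sam2_hiera_tiny_decoder_quant.onnx"
sam2 = SAM2Image(encoder_model_path, decoder_model_path)
```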
6 changes: 3 additions & 3 deletions requirements.txt
@@ -1,4 +1,4 @@
 opencv-python
-imread-from-url
-onnxruntime-gpu
-cap-from-youtube
+Pillow
+onnxruntime
+imread_from_url