From e83e328b7ce7c7435660ab9e6a9dd281be74b5c4 Mon Sep 17 00:00:00 2001 From: Nathaniel Rindlaub Date: Thu, 10 Nov 2022 23:42:19 -0800 Subject: [PATCH] add notebook for evaluating megadetector against real data --- .gitignore | 3 + api/megadetectorv5/README.md | 3 +- notebooks/evaluate-megadetector.ipynb | 542 ++++++++++++++++++++++++++ requirements.txt | 28 +- 4 files changed, 570 insertions(+), 6 deletions(-) create mode 100644 notebooks/evaluate-megadetector.ipynb diff --git a/.gitignore b/.gitignore index ba0f33e..9176a85 100644 --- a/.gitignore +++ b/.gitignore @@ -135,6 +135,9 @@ dmypy.json models/*/* !models/*/code +# ignore tmp +tmp/ + # container log files log.txt .DS_Store diff --git a/api/megadetectorv5/README.md b/api/megadetectorv5/README.md index 28e52cc..6a8f513 100644 --- a/api/megadetectorv5/README.md +++ b/api/megadetectorv5/README.md @@ -1,9 +1,8 @@ # Setup Instructions ## Download weights and torchscript model - +From root directory, run: ``` -cd animl-ml/api/megadetectorv5 aws s3 sync s3://animl-model-zoo/megadetectorv5/ models/megadetectorv5/ ``` diff --git a/notebooks/evaluate-megadetector.ipynb b/notebooks/evaluate-megadetector.ipynb new file mode 100644 index 0000000..ec6ff41 --- /dev/null +++ b/notebooks/evaluate-megadetector.ipynb @@ -0,0 +1,542 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluate Megadetector v5a hosted on Sagemaker Serverless\n", + "\n", + "Test running images through Megadetector v5a to obtain object bounding boxes and filter results at different confidence thresholds.\n", + "\n", + "*NOTE: This notebook is intended to be run locally, and assumes the following:*\n", + "- you are currently running a virtual env with Python 3.9\n", + "- you have configured the awscli with an account called \"animl\" with the requisite permissions to read from S3 and invoke Sagemaker endpoints\n", + "- you have a MongoDB Atlas URL/connection string with read permissions stored in a .env file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### MongoDB Atlas Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext dotenv\n", + "%dotenv\n", + "\n", + "import os\n", + "from pymongo import MongoClient\n", + "\n", + "MONGODB_URL = os.getenv(\"MONGODB_URL\")\n", + "\n", + "db_client = MongoClient(MONGODB_URL)\n", + "db = db_client[\"animl-prod\"]\n", + "images = db[\"images\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### AWS Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import boto3, time, json\n", + "import sagemaker\n", + "import os\n", + "\n", + "os.environ['AWS_PROFILE'] = \"animl\"\n", + "os.environ['AWS_DEFAULT_REGION'] = \"us-west-2\"\n", + "\n", + "sess = boto3.Session()\n", + "sm = sess.client(\"sagemaker\")\n", + "region = sess.region_name\n", + "account = boto3.client(\"sts\").get_caller_identity().get(\"Account\")\n", + "\n", + "img_bucket = \"animl-images-serving-prod\"\n", + "class_map = { 1: \"animal\", 2: \"person\", 3: \"vehicle\" }" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Check status of SageMaker endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%time\n", + "endpoint_name = \"megadetectorv5-torchserve-serverless-prod\"\n", + "resp = sm.describe_endpoint(EndpointName=endpoint_name)\n", + "status = resp[\"EndpointStatus\"]\n", + "print(\"Status: \" + status)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Query variables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from datetime import datetime\n", + "\n", + "project = 'sci_biosecurity'\n", + "start = datetime(2022, 7, 16)\n", + "end = datetime(2022, 11, 1)\n", + "category = 'rodent'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from io import BytesIO\n", + "from PIL import Image, ImageDraw\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "def get_image_records(q):\n", + " img_count = images.count_documents(q)\n", + " print(f'found {img_count} image records')\n", + " img_rcrds = list(images.find(q))\n", + " return img_rcrds\n", + "\n", + "def download_image_files(img_rcrds):\n", + " print('Downloading image files to memory...')\n", + " ret = []\n", + " for rec in img_rcrds:\n", + " key = f'original/{rec[\"_id\"]}-original.jpg'\n", + " img = boto3.client(\"s3\").get_object(Bucket=img_bucket, Key=key)['Body'].read()\n", + " ret.append({ \"name\": rec[\"_id\"], \"data\": img })\n", + " print(f'Downloaded {len(ret)} images to memory')\n", + " return ret\n", + "\n", + "def detect_objects(imgs):\n", + " print('Submitting images to endpoint for object detection...')\n", + " client = boto3.client(\"runtime.sagemaker\")\n", + " ret = []\n", + " for i in range(len(imgs)):\n", + " response = client.invoke_endpoint(\n", + " EndpointName = endpoint_name,\n", + " ContentType = \"application/x-image\",\n", + " Body = imgs[i]['data']\n", + " )\n", + " response = json.loads(response[\"Body\"].read())\n", + " ret.append({ \"name\": imgs[i][\"name\"], \"objects\": response })\n", + " if i % 5 == 0:\n", + " print(f'successfully detected objects in image {i + 1}/{len(imgs)}')\n", + " return ret\n", + "\n", + "def filter_dets(imgs, conf, classes):\n", + " print(f'filtering detections below confidence threshold {conf}')\n", + " def func(obj): \n", + " if obj[\"confidence\"] < conf or obj[\"class\"] not in classes:\n", + " return False\n", + " else:\n", + " return True\n", + " for img in imgs:\n", + " img[\"filtered_objects\"] = list(filter(func, img[\"objects\"]))\n", + " return imgs\n", + "\n", + "def draw_bounding_box_on_image(image,ymin,xmin,ymax,xmax,classification):\n", + " color_map = { 1: 'red', 2: 'blue', 3: 'yellow' }\n", + " color = color_map.get(classification)\n", + " draw = ImageDraw.Draw(image)\n", + " im_width, im_height = image.size\n", + " (left, right, top, bottom) = (xmin * im_width, xmax * im_width,\n", + " ymin * im_height, ymax * im_height)\n", + " draw.line([(left, top), (left, bottom), (right, bottom),\n", + " (right, top), (left, top)], width=4, fill=color)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# functions for sequence grouping\n", + "\n", + "import uuid\n", + "\n", + "def stage_for_grouping(delta_index, index_array):\n", + " for i in [delta_index, delta_index + 1]:\n", + " if i not in index_array: \n", + " index_array.append(i)\n", + "\n", + "def group_as_sequence(dep_img_indexes, dep_df, images_df):\n", + " # use indices to get image ids from deployments DataFrame\n", + " img_ids = dep_df.iloc[dep_img_indexes]\n", + " img_ids = img_ids[\"_id\"].tolist()\n", + " # find the corresponding images records in the images DataFrame\n", + " # and assign them the same burstId\n", + " burstId = uuid.uuid4()\n", + " images_df.loc[images_df['_id'].isin(img_ids), 'burstId'] = burstId" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Associate image records with burst Ids\n", + " - pull all image records (for a specific project & within date range) into a DataFrame\n", + " - split out by deployment\n", + " - sort each deployment's image records chronologically\n", + " - create array of time deltas between each image\n", + " - iterate deltas, if the delta is <= some fixed delta limit (say, 2 seconds), treat them as being in the same burst\n", + " - as a sanity check, print out a list of all the images in chronological order along side an \"image is in burst\" or \"image is not in burst\" evaluation... the images IN bursts should be clustered together chronologically (assuming that setting could get turned on/off)\n", + " - other interesting stats would be: avg number of images in bursts, count of outliers (e.g. bursts w/ 4+ images or 2 images)\n", + "\n", + "End goal is be able to map an image to a burst, and get the rest of the images in that burst" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = { \n", + " 'projectId': project,\n", + " 'dateAdded': { '$gt': start, '$lt': end }\n", + "}\n", + "\n", + "# read image records into DataFrame\n", + "raw_img_rcrds = get_image_records(query)\n", + "images_df = pd.DataFrame(raw_img_rcrds)\n", + "\n", + "# add burstId column, parse dateTimeOriginal values as datetime64, sort chronologically\n", + "images_df['burstId'] = None\n", + "images_df['dateTimeOriginal'] = images_df['dateTimeOriginal'].apply(pd.to_datetime)\n", + "images_df.sort_values('dateTimeOriginal', inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Pull out all possible dep_ids\n", + "deploymentIds = np.unique(images_df['deploymentId'].values)\n", + "print(f'identified {len(deploymentIds)} deployment(s)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Iterate over deployments and group images into sequences\n", + "max_delta = 2 # seconds\n", + "\n", + "for deploymentId in deploymentIds:\n", + " # create deployment DataFrame\n", + " dep_df = images_df.loc[images_df['deploymentId'] == deploymentId]\n", + "\n", + " # get time deltas (as timedelta64's)\n", + " deltas = np.diff(dep_df['dateTimeOriginal']).astype('float64')\n", + " \n", + " # iterate over the deltas and group images by sequence\n", + " img_indexes_to_sequence = []\n", + " for i, delta in enumerate(deltas):\n", + " if delta/1e9 <= max_delta:\n", + " # the two images are part of same sequence\n", + " stage_for_grouping(i, img_indexes_to_sequence)\n", + " else:\n", + " # this is a gap between sequences\n", + " if len(img_indexes_to_sequence) > 0:\n", + " group_as_sequence(img_indexes_to_sequence, dep_df, images_df)\n", + " img_indexes_to_sequence = []\n", + "\n", + " if i == len(deltas) - 1:\n", + " # we've reached the last delta in the array, \n", + " # so group the last staged sequence if there is one\n", + " if len(img_indexes_to_sequence) > 0:\n", + " group_as_sequence(img_indexes_to_sequence, dep_df, images_df)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for deploymentId in deploymentIds:\n", + " dep_df = images_df.loc[images_df['deploymentId'] == deploymentId]\n", + " dep_df.to_csv(f'imgs_with_burst_ids-{deploymentId}.csv', index = True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Find false negatives" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### MongoDB query\n", + "This query is an attempt to Id Megadetector v5a false negatives. For more info: https://docs.google.com/spreadsheets/d/1xaMsICF-e97Ndgm8A9hkrxNRQkJofPQSGOgO9ML8wHU/edit#gid=0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "query = {\n", + " 'projectId': project,\n", + " 'dateAdded': { '$gt': start, '$lt': end },\n", + " 'objects': {\n", + " '$elemMatch': {\n", + " '$and': [\n", + " {'locked': True},\n", + " {'labels': {\n", + " '$elemMatch': {\n", + " '$and': [\n", + " {'type': 'ml'},\n", + " {'mlModel': 'megadetector'},\n", + " {'validation.validated': False},\n", + " {'category':'empty'}\n", + " ]\n", + " }\n", + " }}\n", + " ]\n", + " }\n", + " },\n", + " 'objects.labels': {\n", + " '$elemMatch': {\n", + " '$and': [\n", + " {'type': 'manual'},\n", + " {'validation.validated': True},\n", + " {'category': category}\n", + " ]\n", + " }\n", + " }\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Read image records & image files into memory, submit to MDv5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_rcrds = get_image_records(query)\n", + "imgs = download_image_files(img_rcrds)\n", + "img_detections = detect_objects(imgs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Filter detections below confidence threshold" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " # class schema we use is 1 for animal, 2 for person, 3 for vehicle\n", + "conf = 0.1\n", + "classes_to_include = [1,2] # supress vehicles\n", + "\n", + "imgs_with_filtered_detections = filter_dets(img_detections, conf, classes_to_include)\n", + "\n", + "count = 0 \n", + "imgs_that_would_have_had_detections_if_conf_was_lower = []\n", + "for i, img in enumerate(imgs_with_filtered_detections):\n", + " if len(img[\"filtered_objects\"]) > 0:\n", + " imgs_that_would_have_had_detections_if_conf_was_lower.append(img[\"name\"])\n", + " for obj in img[\"filtered_objects\"]:\n", + " print(f'{i} --- {img[\"name\"]} --- {obj[\"class\"]} --- {obj[\"confidence\"]}')\n", + " count = count + 1\n", + "\n", + "print(f'found {count} objects with detections above {conf}')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Check false negatives\n", + "for true positivies in their respective bursts" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "imgs_to_check_bursts = []\n", + "for rec in img_rcrds:\n", + " if rec[\"_id\"] not in imgs_that_would_have_had_detections_if_conf_was_lower:\n", + " imgs_to_check_bursts.append(rec)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def check_img_for_true_positive(img):\n", + " # return true if image has an object w/ a megadetector label AND\n", + " # a validated label of our desired class\n", + " ret = False\n", + " for obj in img.objects:\n", + " has_md_label = False\n", + " has_manual_label = False\n", + " for lbl in obj[\"labels\"]:\n", + " if (lbl[\"type\"] == \"ml\" and \n", + " lbl[\"mlModel\"] == \"megadetector\"):\n", + " has_md_label = True\n", + " if (lbl[\"category\"] == category and \n", + " \"validation\" in lbl and \n", + " lbl[\"validation\"][\"validated\"] == True):\n", + " has_manual_label = True\n", + " if has_md_label and has_manual_label:\n", + " ret = True\n", + " return ret\n", + "\n", + "def check_burst_for_true_positives(img_rcrd):\n", + " # print(f'checking img {img_rcrd[\"_id\"]}')\n", + " # find img's burstId,\n", + " burstId = images_df.loc[images_df['_id'] == img_rcrd['_id'], 'burstId'].tolist()\n", + " # print(f'burstId: {burstId[0]}')\n", + "\n", + " # find rest of images in burst, filter out this img\n", + " imgs_in_burst = images_df.loc[images_df['burstId'] == burstId[0]]\n", + " # print(f'images in burst: \\n{imgs_in_burst[\"_id\"]}')\n", + "\n", + " # for each remaining image, check for true positive\n", + " burst_has_true_positive = False\n", + " for row in imgs_in_burst.itertuples():\n", + " has_true_postitive = check_img_for_true_positive(row)\n", + " if has_true_postitive:\n", + " burst_has_true_positive = True\n", + " \n", + " return burst_has_true_positive\n", + "\n", + "# check the bursts of all remaining false negatives\n", + "# (i.e., those that would have still been missed even with a lower conf. threshold)\n", + "# for true positives\n", + "detection_found_in_burst_count = 0\n", + "for img in imgs_to_check_bursts:\n", + " burst_has_true_positive = check_burst_for_true_positives(img)\n", + " if burst_has_true_positive:\n", + " detection_found_in_burst_count = detection_found_in_burst_count + 1\n", + "\n", + "print(f'found {detection_found_in_burst_count} true positives in bursts of {len(imgs_to_check_bursts)} images with false negatives')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Spot-check individual images & objects" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "img_index = 153\n", + "img_to_draw = imgs_with_filtered_detections[img_index]\n", + "image = Image.open(BytesIO(imgs[img_index]['data']))\n", + "\n", + "print(f'{img_index} --- {img_to_draw[\"name\"]}')\n", + "for obj in img_to_draw[\"filtered_objects\"]:\n", + " print(f'object --- class: {obj[\"class\"]} ({class_map[obj[\"class\"]]}), confidence: {obj[\"confidence\"]}')\n", + " draw_bounding_box_on_image(image, obj[\"y1\"], obj[\"x1\"], obj[\"y2\"], obj[\"x2\"], obj[\"class\"])\n", + "image" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.9.0 ('env')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.0" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "9f4c0846dd0921c49a84287b2417a4baea3589afd0a49b45186be44f93694c14" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.txt b/requirements.txt index 8b37e46..3c52cfc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,17 +1,23 @@ appnope==0.1.2 +attrs==22.1.0 azure-core==1.9.0 azure-storage-blob==12.6.0 backcall==0.2.0 -boto3==1.12.26 -botocore==1.15.49 +boto3==1.26.4 +botocore==1.29.4 certifi==2019.11.28 cffi==1.14.4 chardet==3.0.4 +contextlib2==21.6.0 cryptography==3.3.2 cycler==0.10.0 decorator==4.4.2 +dill==0.3.6 +dnspython==2.2.1 docutils==0.15.2 +google-pasta==0.2.0 idna==2.9 +importlib-metadata==4.13.0 ipykernel==5.3.4 ipython==7.19.0 ipython-genutils==0.2.0 @@ -23,26 +29,40 @@ jupyter-core==4.7.0 kiwisolver==1.1.0 matplotlib==3.2.1 msrest==0.6.19 -numpy==1.18.2 +multiprocess==0.70.14 +numpy==1.23.4 oauthlib==3.1.0 +packaging==21.3 +pandas==1.5.1 parso==0.7.1 +pathos==0.3.0 pexpect==4.8.0 pickleshare==0.7.5 Pillow==8.3.2 +pox==0.3.2 +ppft==1.7.6.6 prompt-toolkit==3.0.8 +protobuf==3.20.3 +protobuf3-to-dict==0.1.5 ptyprocess==0.6.0 pycparser==2.20 Pygments==2.7.4 +pymongo==4.3.2 pyparsing==2.4.6 python-dateutil==2.8.1 python-dotenv==0.15.0 +pytz==2022.6 pyzmq==20.0.0 requests==2.25.0 requests-oauthlib==1.3.0 requests-toolbelt==0.9.1 -s3transfer==0.3.3 +s3transfer==0.6.0 +sagemaker==2.116.0 +schema==0.7.5 six==1.14.0 +smdebug-rulesconfig==1.0.1 tornado==6.1 traitlets==5.0.5 urllib3==1.25.11 wcwidth==0.2.5 +zipp==3.10.0