-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Sets matplotlib backend to agg before rendering math (#4029)
* matplotlib-agg * fix * context manager * Update CHANGELOG.md * update demos * linting * removed warning * fix test * fixes * fix tests
- Loading branch information
1 parent
aca91b5
commit 92c95b6
Showing
21 changed files
with
153 additions
and
132 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: blocks_interpretation"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio shap matplotlib transformers torch"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import shap\n", "from transformers import pipeline\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "matplotlib.use('Agg')\n", "\n", "\n", "sentiment_classifier = pipeline(\"text-classification\", return_all_scores=True)\n", "\n", "\n", "def classifier(text):\n", " pred = sentiment_classifier(text)\n", " return {p[\"label\"]: p[\"score\"] for p in pred[0]}\n", "\n", "\n", "def interpretation_function(text):\n", " explainer = shap.Explainer(sentiment_classifier)\n", " shap_values = explainer([text])\n", " # Dimensions are (batch size, text size, number of classes)\n", " # Since we care about positive sentiment, use index 1\n", " scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))\n", "\n", " scores_desc = sorted(scores, key=lambda t: t[1])[::-1]\n", "\n", " # Filter out empty string added by shap\n", " scores_desc = [t for t in scores_desc if t[0] != \"\"]\n", "\n", " fig_m = plt.figure()\n", " plt.bar(x=[s[0] for s in scores_desc[:5]],\n", " height=[s[1] for s in scores_desc[:5]])\n", " plt.title(\"Top words contributing to positive sentiment\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Word\")\n", " return {\"original\": text, \"interpretation\": scores}, fig_m\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " input_text = gr.Textbox(label=\"Input Text\")\n", " with gr.Row():\n", " classify = gr.Button(\"Classify Sentiment\")\n", " interpret = gr.Button(\"Interpret\")\n", " with gr.Column():\n", " label = gr.Label(label=\"Predicted Sentiment\")\n", " with gr.Column():\n", " with gr.Tab(\"Display interpretation with built-in component\"):\n", " interpretation = gr.components.Interpretation(input_text)\n", " with gr.Tab(\"Display interpretation with plot\"):\n", " interpretation_plot = gr.Plot()\n", "\n", " classify.click(classifier, input_text, label)\n", " interpret.click(interpretation_function, input_text, [interpretation, interpretation_plot])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} | ||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: blocks_interpretation"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio shap matplotlib transformers torch"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import shap\n", "from transformers import pipeline\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "sentiment_classifier = pipeline(\"text-classification\", return_all_scores=True)\n", "\n", "\n", "def classifier(text):\n", " pred = sentiment_classifier(text)\n", " return {p[\"label\"]: p[\"score\"] for p in pred[0]}\n", "\n", "\n", "def interpretation_function(text):\n", " explainer = shap.Explainer(sentiment_classifier)\n", " shap_values = explainer([text])\n", " # Dimensions are (batch size, text size, number of classes)\n", " # Since we care about positive sentiment, use index 1\n", " scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))\n", "\n", " scores_desc = sorted(scores, key=lambda t: t[1])[::-1]\n", "\n", " # Filter out empty string added by shap\n", " scores_desc = [t for t in scores_desc if t[0] != \"\"]\n", "\n", " fig_m = plt.figure()\n", " plt.bar(x=[s[0] for s in scores_desc[:5]],\n", " height=[s[1] for s in scores_desc[:5]])\n", " plt.title(\"Top words contributing to positive sentiment\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Word\")\n", " return {\"original\": text, \"interpretation\": scores}, fig_m\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " input_text = gr.Textbox(label=\"Input Text\")\n", " with gr.Row():\n", " classify = gr.Button(\"Classify Sentiment\")\n", " interpret = gr.Button(\"Interpret\")\n", " with gr.Column():\n", " label = gr.Label(label=\"Predicted Sentiment\")\n", " with gr.Column():\n", " with gr.Tab(\"Display interpretation with built-in component\"):\n", " interpretation = gr.components.Interpretation(input_text)\n", " with gr.Tab(\"Display interpretation with plot\"):\n", " interpretation_plot = gr.Plot()\n", "\n", " classify.click(classifier, input_text, label)\n", " interpret.click(interpretation_function, input_text, [interpretation, interpretation_plot])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import matplotlib\n", "import pandas as pd\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.BarPlot(\n", " x=\"rideable_type\",\n", " y='n',\n", " title=\"Number of rides per bicycle type\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Bicycle Type\",\n", " vertical=False,\n", " tooltip=['rideable_type', \"n\"],\n", " height=300,\n", " width=300,\n", " )\n", " station = gr.BarPlot(\n", " x='station',\n", " y='n',\n", " title=\"Most Popular Stations\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Station Name\",\n", " vertical=False,\n", " tooltip=['station', 'n'],\n", " height=300,\n", " width=300\n", " )\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} | ||
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import pandas as pd\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.BarPlot(\n", " x=\"rideable_type\",\n", " y='n',\n", " title=\"Number of rides per bicycle type\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Bicycle Type\",\n", " vertical=False,\n", " tooltip=['rideable_type', \"n\"],\n", " height=300,\n", " width=300,\n", " )\n", " station = gr.BarPlot(\n", " x='station',\n", " y='n',\n", " title=\"Most Popular Stations\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Station Name\",\n", " vertical=False,\n", " tooltip=['station', 'n'],\n", " height=300,\n", " width=300\n", " )\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.