diff --git a/CHANGELOG.md b/CHANGELOG.md index 656267508e281..c7935af5d873a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ## Bug Fixes: +- Fixes issue with `matplotlib` not rendering correctly if the backend was not set to `Agg` by [@abidlabs](https://github.com/abidlabs) in [PR 4029](https://github.com/gradio-app/gradio/pull/4029) - Fixes bug where rendering the same `gr.State` across different Interfaces/Blocks within a larger Blocks would not work by [@abidlabs](https://github.com/abidlabs) in [PR 4030](https://github.com/gradio-app/gradio/pull/4030) @@ -20,11 +21,11 @@ No changes to highlight. ## Breaking Changes: -No changes to highlight. +- `gr.HuggingFaceDatasetSaver` behavior changed internally. The `flagging/` folder is not a `.git/` folder anymore when using it. `organization` parameter is now ignored in favor of passing a full dataset id as `dataset_name` (e.g. `"username/my-dataset"`). ## Full Changelog: -No changes to highlight. +- Safer version of `gr.HuggingFaceDatasetSaver` using HTTP methods instead of git pull/push by [@Wauplin](https://github.com/Wauplin) in [PR 3973](https://github.com/gradio-app/gradio/pull/3973) ## Contributors Shoutout: @@ -94,7 +95,6 @@ No changes to highlight. ## Breaking Changes: -- `gr.HuggingFaceDatasetSaver` behavior changed internally. The `flagging/` folder is not a `.git/` folder anymore when using it. `organization` parameter is now ignored in favor of passing a full dataset id as `dataset_name` (e.g. `"username/my-dataset"`). - Some re-exports in `gradio.themes` utilities (introduced in 3.24.0) have been eradicated. By [@akx](https://github.com/akx) in [PR 3958](https://github.com/gradio-app/gradio/pull/3958) @@ -102,7 +102,6 @@ No changes to highlight. - Add DESCRIPTION.md to image_segmentation demo by [@aliabd](https://github.com/aliabd) in [PR 3866](https://github.com/gradio-app/gradio/pull/3866) - Fix error in running `gr.themes.builder()` by [@deepkyu](https://github.com/deepkyu) in [PR 3869](https://github.com/gradio-app/gradio/pull/3869) -- Safer version of `gr.HuggingFaceDatasetSaver` using HTTP methods instead of git pull/push by [@Wauplin](https://github.com/Wauplin) in [PR 3973](https://github.com/gradio-app/gradio/pull/3973) - Fixed a JavaScript TypeError when loading custom JS with `_js` and setting `outputs` to `None` in `gradio.Blocks()` by [@DavG25](https://github.com/DavG25) in [PR 3883](https://github.com/gradio-app/gradio/pull/3883) - Fixed bg_background_fill theme property to expand to whole background, block_radius to affect form elements as well, and added block_label_shadow theme property by [@aliabid94](https://github.com/aliabid94) in [PR 3590](https://github.com/gradio-app/gradio/pull/3590) diff --git a/demo/blocks_interpretation/run.ipynb b/demo/blocks_interpretation/run.ipynb index c0bc95066bc99..da7d3388b667e 100644 --- a/demo/blocks_interpretation/run.ipynb +++ b/demo/blocks_interpretation/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: blocks_interpretation"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio shap matplotlib transformers torch"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import shap\n", "from transformers import pipeline\n", "import matplotlib\n", 
"import matplotlib.pyplot as plt\n", "matplotlib.use('Agg')\n", "\n", "\n", "sentiment_classifier = pipeline(\"text-classification\", return_all_scores=True)\n", "\n", "\n", "def classifier(text):\n", " pred = sentiment_classifier(text)\n", " return {p[\"label\"]: p[\"score\"] for p in pred[0]}\n", "\n", "\n", "def interpretation_function(text):\n", " explainer = shap.Explainer(sentiment_classifier)\n", " shap_values = explainer([text])\n", " # Dimensions are (batch size, text size, number of classes)\n", " # Since we care about positive sentiment, use index 1\n", " scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))\n", "\n", " scores_desc = sorted(scores, key=lambda t: t[1])[::-1]\n", "\n", " # Filter out empty string added by shap\n", " scores_desc = [t for t in scores_desc if t[0] != \"\"]\n", "\n", " fig_m = plt.figure()\n", " plt.bar(x=[s[0] for s in scores_desc[:5]],\n", " height=[s[1] for s in scores_desc[:5]])\n", " plt.title(\"Top words contributing to positive sentiment\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Word\")\n", " return {\"original\": text, \"interpretation\": scores}, fig_m\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " input_text = gr.Textbox(label=\"Input Text\")\n", " with gr.Row():\n", " classify = gr.Button(\"Classify Sentiment\")\n", " interpret = gr.Button(\"Interpret\")\n", " with gr.Column():\n", " label = gr.Label(label=\"Predicted Sentiment\")\n", " with gr.Column():\n", " with gr.Tab(\"Display interpretation with built-in component\"):\n", " interpretation = gr.components.Interpretation(input_text)\n", " with gr.Tab(\"Display interpretation with plot\"):\n", " interpretation_plot = gr.Plot()\n", "\n", " classify.click(classifier, input_text, label)\n", " interpret.click(interpretation_function, input_text, [interpretation, interpretation_plot])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: blocks_interpretation"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio shap matplotlib transformers torch"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import shap\n", "from transformers import pipeline\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "sentiment_classifier = pipeline(\"text-classification\", return_all_scores=True)\n", "\n", "\n", "def classifier(text):\n", " pred = sentiment_classifier(text)\n", " return {p[\"label\"]: p[\"score\"] for p in pred[0]}\n", "\n", "\n", "def interpretation_function(text):\n", " explainer = shap.Explainer(sentiment_classifier)\n", " shap_values = explainer([text])\n", " # Dimensions are (batch size, text size, number of classes)\n", " # Since we care about positive sentiment, use index 1\n", " scores = list(zip(shap_values.data[0], shap_values.values[0, :, 1]))\n", "\n", " scores_desc = sorted(scores, key=lambda t: t[1])[::-1]\n", "\n", " # Filter out empty string added by shap\n", " scores_desc = [t for t in scores_desc if t[0] != \"\"]\n", "\n", " fig_m = plt.figure()\n", " plt.bar(x=[s[0] for s in scores_desc[:5]],\n", " height=[s[1] for s in scores_desc[:5]])\n", " plt.title(\"Top words 
contributing to positive sentiment\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Word\")\n", " return {\"original\": text, \"interpretation\": scores}, fig_m\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row():\n", " with gr.Column():\n", " input_text = gr.Textbox(label=\"Input Text\")\n", " with gr.Row():\n", " classify = gr.Button(\"Classify Sentiment\")\n", " interpret = gr.Button(\"Interpret\")\n", " with gr.Column():\n", " label = gr.Label(label=\"Predicted Sentiment\")\n", " with gr.Column():\n", " with gr.Tab(\"Display interpretation with built-in component\"):\n", " interpretation = gr.components.Interpretation(input_text)\n", " with gr.Tab(\"Display interpretation with plot\"):\n", " interpretation_plot = gr.Plot()\n", "\n", " classify.click(classifier, input_text, label)\n", " interpret.click(interpretation_function, input_text, [interpretation, interpretation_plot])\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/blocks_interpretation/run.py b/demo/blocks_interpretation/run.py index 4aa254d788bd3..467b4474cb68d 100644 --- a/demo/blocks_interpretation/run.py +++ b/demo/blocks_interpretation/run.py @@ -1,9 +1,7 @@ import gradio as gr import shap from transformers import pipeline -import matplotlib import matplotlib.pyplot as plt -matplotlib.use('Agg') sentiment_classifier = pipeline("text-classification", return_all_scores=True) diff --git a/demo/chicago-bikeshare-dashboard/run.ipynb b/demo/chicago-bikeshare-dashboard/run.ipynb index f5ec617ea9de8..f936d3be1f983 100644 --- a/demo/chicago-bikeshare-dashboard/run.ipynb +++ b/demo/chicago-bikeshare-dashboard/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import matplotlib\n", "import pandas as pd\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql 
client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.BarPlot(\n", " x=\"rideable_type\",\n", " y='n',\n", " title=\"Number of rides per bicycle type\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Bicycle Type\",\n", " vertical=False,\n", " tooltip=['rideable_type', \"n\"],\n", " height=300,\n", " width=300,\n", " )\n", " station = gr.BarPlot(\n", " x='station',\n", " y='n',\n", " title=\"Most Popular Stations\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Station Name\",\n", " vertical=False,\n", " tooltip=['station', 'n'],\n", " height=300,\n", " width=300\n", " )\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chicago-bikeshare-dashboard"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio psycopg2 matplotlib SQLAlchemy "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import os\n", "import gradio as gr\n", "import pandas as pd\n", "\n", "DB_USER = os.getenv(\"DB_USER\")\n", "DB_PASSWORD = os.getenv(\"DB_PASSWORD\")\n", "DB_HOST = os.getenv(\"DB_HOST\")\n", "PORT = 8080\n", "DB_NAME = \"bikeshare\"\n", "\n", "connection_string = (\n", " f\"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}?port={PORT}&dbname={DB_NAME}\"\n", ")\n", "\n", "\n", "def get_count_ride_type():\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, rideable_type\n", " FROM rides\n", " GROUP BY rideable_type\n", " ORDER BY n DESC\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "def get_most_popular_stations():\n", "\n", " df = pd.read_sql(\n", " \"\"\"\n", " SELECT COUNT(ride_id) as n, MAX(start_station_name) as station\n", " FROM RIDES\n", " WHERE start_station_name is NOT NULL\n", " GROUP BY start_station_id\n", " ORDER BY n DESC\n", " LIMIT 5\n", " \"\"\",\n", " con=connection_string,\n", " )\n", " return df\n", "\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\n", " \"\"\"\n", " # Chicago Bike Share Dashboard\n", " \n", " This demo pulls Chicago bike share data for March 2022 from a postgresql database hosted on AWS.\n", " This demo uses psycopg2 but any postgresql client library (SQLAlchemy)\n", " is compatible with gradio.\n", " \n", " Connection credentials are handled by environment variables\n", " defined as secrets in the Space.\n", "\n", " If data were added to the database, the plots in this demo would update\n", " whenever the webpage is reloaded.\n", " \n", " This demo serves as a starting point for your database-connected apps!\n", " \"\"\"\n", " )\n", " with gr.Row():\n", " bike_type = gr.BarPlot(\n", " x=\"rideable_type\",\n", " y='n',\n", " title=\"Number of rides per bicycle 
type\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Bicycle Type\",\n", " vertical=False,\n", " tooltip=['rideable_type', \"n\"],\n", " height=300,\n", " width=300,\n", " )\n", " station = gr.BarPlot(\n", " x='station',\n", " y='n',\n", " title=\"Most Popular Stations\",\n", " y_title=\"Number of Rides\",\n", " x_title=\"Station Name\",\n", " vertical=False,\n", " tooltip=['station', 'n'],\n", " height=300,\n", " width=300\n", " )\n", "\n", " demo.load(get_count_ride_type, inputs=None, outputs=bike_type)\n", " demo.load(get_most_popular_stations, inputs=None, outputs=station)\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chicago-bikeshare-dashboard/run.py b/demo/chicago-bikeshare-dashboard/run.py index b3862021aa186..79bc1f9464ce0 100644 --- a/demo/chicago-bikeshare-dashboard/run.py +++ b/demo/chicago-bikeshare-dashboard/run.py @@ -1,10 +1,7 @@ import os import gradio as gr -import matplotlib import pandas as pd -matplotlib.use("Agg") - DB_USER = os.getenv("DB_USER") DB_PASSWORD = os.getenv("DB_PASSWORD") DB_HOST = os.getenv("DB_HOST") diff --git a/demo/kitchen_sink_random/constants.py b/demo/kitchen_sink_random/constants.py index ba48a845d6fbe..bda41f4f76c8a 100644 --- a/demo/kitchen_sink_random/constants.py +++ b/demo/kitchen_sink_random/constants.py @@ -1,7 +1,4 @@ import numpy as np -import matplotlib - -matplotlib.use("Agg") import matplotlib.pyplot as plt import random import os diff --git a/demo/outbreak_forecast/run.ipynb b/demo/outbreak_forecast/run.ipynb index 816159e1e27bb..dc96e2440a9cd 100644 --- a/demo/outbreak_forecast/run.ipynb +++ b/demo/outbreak_forecast/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib\n", "\n", "matplotlib.use(\"Agg\")\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " 
yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: outbreak_forecast\n", "### Generate a plot based on 5 inputs.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib bokeh plotly altair"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import altair\n", "\n", "import gradio as gr\n", "from math import sqrt\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import plotly.express as px\n", "import pandas as pd\n", "\n", "\n", "def outbreak(plot_type, r, month, countries, social_distancing):\n", " months = [\"January\", \"February\", \"March\", \"April\", \"May\"]\n", " m = months.index(month)\n", " start_day = 30 * m\n", " final_day = 30 * (m + 1)\n", " x = np.arange(start_day, final_day + 1)\n", " pop_count = {\"USA\": 350, \"Canada\": 40, \"Mexico\": 300, \"UK\": 120}\n", " if social_distancing:\n", " r = sqrt(r)\n", " df = pd.DataFrame({\"day\": x})\n", " for country in countries:\n", " df[country] = x ** (r) * (pop_count[country] + 1)\n", "\n", " if plot_type == \"Matplotlib\":\n", " fig = plt.figure()\n", " plt.plot(df[\"day\"], df[countries].to_numpy())\n", " plt.title(\"Outbreak in \" + month)\n", " plt.ylabel(\"Cases\")\n", " plt.xlabel(\"Days since Day 0\")\n", " plt.legend(countries)\n", " return fig\n", " elif plot_type == \"Plotly\":\n", " fig = px.line(df, x=\"day\", y=countries)\n", " fig.update_layout(\n", " title=\"Outbreak in \" + month,\n", " xaxis_title=\"Cases\",\n", " yaxis_title=\"Days Since Day 0\",\n", " )\n", " return fig\n", " elif plot_type == \"Altair\":\n", " df = df.melt(id_vars=\"day\").rename(columns={\"variable\": \"country\"})\n", " fig = altair.Chart(df).mark_line().encode(x=\"day\", y='value', color='country')\n", " return fig\n", " else:\n", " raise ValueError(\"A plot type must be selected\")\n", "\n", "\n", "inputs = [\n", " gr.Dropdown([\"Matplotlib\", \"Plotly\", \"Altair\"], label=\"Plot Type\"),\n", " gr.Slider(1, 4, 3.2, label=\"R\"),\n", " 
gr.Dropdown([\"January\", \"February\", \"March\", \"April\", \"May\"], label=\"Month\"),\n", " gr.CheckboxGroup(\n", " [\"USA\", \"Canada\", \"Mexico\", \"UK\"], label=\"Countries\", value=[\"USA\", \"Canada\"]\n", " ),\n", " gr.Checkbox(label=\"Social Distancing?\"),\n", "]\n", "outputs = gr.Plot()\n", "\n", "demo = gr.Interface(\n", " fn=outbreak,\n", " inputs=inputs,\n", " outputs=outputs,\n", " examples=[\n", " [\"Matplotlib\", 2, \"March\", [\"Mexico\", \"UK\"], True],\n", " [\"Altair\", 2, \"March\", [\"Mexico\", \"Canada\"], True],\n", " [\"Plotly\", 3.6, \"February\", [\"Canada\", \"Mexico\", \"UK\"], False],\n", " ],\n", " cache_examples=True,\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/outbreak_forecast/run.py b/demo/outbreak_forecast/run.py index 2d0101de675c8..5d95edba40365 100644 --- a/demo/outbreak_forecast/run.py +++ b/demo/outbreak_forecast/run.py @@ -2,10 +2,6 @@ import gradio as gr from math import sqrt -import matplotlib - -matplotlib.use("Agg") - import matplotlib.pyplot as plt import numpy as np import plotly.express as px diff --git a/demo/sales_projections/run.ipynb b/demo/sales_projections/run.ipynb index a5b09c927c957..197610cdc48db 100644 --- a/demo/sales_projections/run.ipynb +++ b/demo/sales_projections/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: sales_projections"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio pandas numpy matplotlib"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import matplotlib\n", "matplotlib.use('Agg')\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import gradio as gr\n", "\n", "\n", "def sales_projections(employee_data):\n", " sales_data = employee_data.iloc[:, 1:4].astype(\"int\").to_numpy()\n", " regression_values = np.apply_along_axis(\n", " lambda row: np.array(np.poly1d(np.polyfit([0, 1, 2], row, 2))), 0, sales_data\n", " )\n", " projected_months = np.repeat(\n", " np.expand_dims(np.arange(3, 12), 0), len(sales_data), axis=0\n", " )\n", " projected_values = np.array(\n", " [\n", " month * month * regression[0] + month * regression[1] + regression[2]\n", " for month, regression in zip(projected_months, regression_values)\n", " ]\n", " )\n", " plt.plot(projected_values.T)\n", " plt.legend(employee_data[\"Name\"])\n", " return employee_data, plt.gcf(), regression_values\n", "\n", "\n", "demo = gr.Interface(\n", " sales_projections,\n", " gr.Dataframe(\n", " headers=[\"Name\", \"Jan Sales\", \"Feb Sales\", \"Mar Sales\"],\n", " value=[[\"Jon\", 12, 14, 18], [\"Alice\", 14, 17, 2], [\"Sana\", 8, 9.5, 12]],\n", " ),\n", " [\"dataframe\", \"plot\", \"numpy\"],\n", " description=\"Enter sales figures for employees to predict sales trajectory over year.\",\n", ")\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: sales_projections"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": 
["!pip install -q gradio pandas numpy matplotlib"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import gradio as gr\n", "\n", "\n", "def sales_projections(employee_data):\n", " sales_data = employee_data.iloc[:, 1:4].astype(\"int\").to_numpy()\n", " regression_values = np.apply_along_axis(\n", " lambda row: np.array(np.poly1d(np.polyfit([0, 1, 2], row, 2))), 0, sales_data\n", " )\n", " projected_months = np.repeat(\n", " np.expand_dims(np.arange(3, 12), 0), len(sales_data), axis=0\n", " )\n", " projected_values = np.array(\n", " [\n", " month * month * regression[0] + month * regression[1] + regression[2]\n", " for month, regression in zip(projected_months, regression_values)\n", " ]\n", " )\n", " plt.plot(projected_values.T)\n", " plt.legend(employee_data[\"Name\"])\n", " return employee_data, plt.gcf(), regression_values\n", "\n", "\n", "demo = gr.Interface(\n", " sales_projections,\n", " gr.Dataframe(\n", " headers=[\"Name\", \"Jan Sales\", \"Feb Sales\", \"Mar Sales\"],\n", " value=[[\"Jon\", 12, 14, 18], [\"Alice\", 14, 17, 2], [\"Sana\", 8, 9.5, 12]],\n", " ),\n", " [\"dataframe\", \"plot\", \"numpy\"],\n", " description=\"Enter sales figures for employees to predict sales trajectory over year.\",\n", ")\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/sales_projections/run.py b/demo/sales_projections/run.py index 4e07c8235a6bf..33cdea5ae0afd 100644 --- a/demo/sales_projections/run.py +++ b/demo/sales_projections/run.py @@ -1,5 +1,3 @@ -import matplotlib -matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np diff --git a/demo/stock_forecast/run.ipynb b/demo/stock_forecast/run.ipynb index 5da2cf3b10549..dc0b0dc6bc3e5 100644 --- a/demo/stock_forecast/run.ipynb +++ b/demo/stock_forecast/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stock_forecast"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import matplotlib\n", "matplotlib.use('Agg')\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import gradio as gr\n", "\n", "\n", "def plot_forecast(final_year, companies, noise, show_legend, point_style):\n", " start_year = 2020\n", " x = np.arange(start_year, final_year + 1)\n", " year_count = x.shape[0]\n", " plt_format = ({\"cross\": \"X\", \"line\": \"-\", \"circle\": \"o--\"})[point_style]\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", " for i, company in enumerate(companies):\n", " series = np.arange(0, year_count, dtype=float)\n", " series = series**2 * (i + 1)\n", " series += np.random.rand(year_count) * noise\n", " ax.plot(x, series, plt_format)\n", " if show_legend:\n", " plt.legend(companies)\n", " return fig\n", "\n", "\n", "demo = gr.Interface(\n", " plot_forecast,\n", " [\n", " gr.Radio([2025, 2030, 2035, 2040], label=\"Project to:\"),\n", " gr.CheckboxGroup([\"Google\", \"Microsoft\", \"Gradio\"], label=\"Company Selection\"),\n", " gr.Slider(1, 100, label=\"Noise Level\"),\n", " 
gr.Checkbox(label=\"Show Legend\"),\n", " gr.Dropdown([\"cross\", \"line\", \"circle\"], label=\"Style\"),\n", " ],\n", " gr.Plot(label=\"forecast\"),\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: stock_forecast"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio numpy matplotlib"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "import gradio as gr\n", "\n", "\n", "def plot_forecast(final_year, companies, noise, show_legend, point_style):\n", " start_year = 2020\n", " x = np.arange(start_year, final_year + 1)\n", " year_count = x.shape[0]\n", " plt_format = ({\"cross\": \"X\", \"line\": \"-\", \"circle\": \"o--\"})[point_style]\n", " fig = plt.figure()\n", " ax = fig.add_subplot(111)\n", " for i, company in enumerate(companies):\n", " series = np.arange(0, year_count, dtype=float)\n", " series = series**2 * (i + 1)\n", " series += np.random.rand(year_count) * noise\n", " ax.plot(x, series, plt_format)\n", " if show_legend:\n", " plt.legend(companies)\n", " return fig\n", "\n", "\n", "demo = gr.Interface(\n", " plot_forecast,\n", " [\n", " gr.Radio([2025, 2030, 2035, 2040], label=\"Project to:\"),\n", " gr.CheckboxGroup([\"Google\", \"Microsoft\", \"Gradio\"], label=\"Company Selection\"),\n", " gr.Slider(1, 100, label=\"Noise Level\"),\n", " gr.Checkbox(label=\"Show Legend\"),\n", " gr.Dropdown([\"cross\", \"line\", \"circle\"], label=\"Style\"),\n", " ],\n", " gr.Plot(label=\"forecast\"),\n", ")\n", "\n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/stock_forecast/run.py b/demo/stock_forecast/run.py index e6758abcfd506..2f1318c71db59 100644 --- a/demo/stock_forecast/run.py +++ b/demo/stock_forecast/run.py @@ -1,5 +1,3 @@ -import matplotlib -matplotlib.use('Agg') import matplotlib.pyplot as plt import numpy as np diff --git a/demo/xgboost-income-prediction-with-explainability/run.ipynb b/demo/xgboost-income-prediction-with-explainability/run.ipynb index 43d526982ab01..3879709527432 100644 --- a/demo/xgboost-income-prediction-with-explainability/run.ipynb +++ b/demo/xgboost-income-prediction-with-explainability/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: xgboost-income-prediction-with-explainability\n", "### This demo takes in 12 inputs from the user in dropdowns and sliders and predicts income. 
It also has a separate button for explaining the prediction.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio matplotlib shap xgboost pandas datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import shap\n", "import xgboost as xgb\n", "from datasets import load_dataset\n", "\n", "\n", "matplotlib.use(\"Agg\")\n", "dataset = load_dataset(\"scikit-learn/adult-census-income\")\n", "X_train = dataset[\"train\"].to_pandas()\n", "_ = X_train.pop(\"fnlwgt\")\n", "_ = X_train.pop(\"race\")\n", "y_train = X_train.pop(\"income\")\n", "y_train = (y_train == \">50K\").astype(int)\n", "categorical_columns = [\n", " \"workclass\",\n", " \"education\",\n", " \"marital.status\",\n", " \"occupation\",\n", " \"relationship\",\n", " \"sex\",\n", " \"native.country\",\n", "]\n", "X_train = X_train.astype({col: \"category\" for col in categorical_columns})\n", "data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)\n", "model = xgb.train(params={\"objective\": \"binary:logistic\"}, dtrain=data)\n", "explainer = shap.TreeExplainer(model)\n", "\n", "def predict(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))\n", " return {\">50K\": float(pos_pred[0]), \"<=50K\": 1 - float(pos_pred[0])}\n", "\n", "\n", "def interpret(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))\n", " scores_desc = list(zip(shap_values[0], X_train.columns))\n", " scores_desc = sorted(scores_desc)\n", " fig_m = plt.figure(tight_layout=True)\n", " plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])\n", " plt.title(\"Feature Shap Values\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Feature\")\n", " plt.tight_layout()\n", " return fig_m\n", "\n", "\n", "unique_class = sorted(X_train[\"workclass\"].unique())\n", "unique_education = sorted(X_train[\"education\"].unique())\n", "unique_marital_status = sorted(X_train[\"marital.status\"].unique())\n", "unique_relationship = sorted(X_train[\"relationship\"].unique())\n", "unique_occupation = sorted(X_train[\"occupation\"].unique())\n", "unique_sex = sorted(X_train[\"sex\"].unique())\n", "unique_country = sorted(X_train[\"native.country\"].unique())\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"\n", " **Income Classification with XGBoost \ud83d\udcb0**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. 
The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column():\n", " age = gr.Slider(label=\"Age\", minimum=17, maximum=90, step=1, randomize=True)\n", " work_class = gr.Dropdown(\n", " label=\"Workclass\",\n", " choices=unique_class,\n", " value=lambda: random.choice(unique_class),\n", " )\n", " education = gr.Dropdown(\n", " label=\"Education Level\",\n", " choices=unique_education,\n", " value=lambda: random.choice(unique_education),\n", " )\n", " years = gr.Slider(\n", " label=\"Years of schooling\",\n", " minimum=1,\n", " maximum=16,\n", " step=1,\n", " randomize=True,\n", " )\n", " marital_status = gr.Dropdown(\n", " label=\"Marital Status\",\n", " choices=unique_marital_status,\n", " value=lambda: random.choice(unique_marital_status),\n", " )\n", " occupation = gr.Dropdown(\n", " label=\"Occupation\",\n", " choices=unique_occupation,\n", " value=lambda: random.choice(unique_occupation),\n", " )\n", " relationship = gr.Dropdown(\n", " label=\"Relationship Status\",\n", " choices=unique_relationship,\n", " value=lambda: random.choice(unique_relationship),\n", " )\n", " sex = gr.Dropdown(\n", " label=\"Sex\", choices=unique_sex, value=lambda: random.choice(unique_sex)\n", " )\n", " capital_gain = gr.Slider(\n", " label=\"Capital Gain\",\n", " minimum=0,\n", " maximum=100000,\n", " step=500,\n", " randomize=True,\n", " )\n", " capital_loss = gr.Slider(\n", " label=\"Capital Loss\", minimum=0, maximum=10000, step=500, randomize=True\n", " )\n", " hours_per_week = gr.Slider(\n", " label=\"Hours Per Week Worked\", minimum=1, maximum=99, step=1\n", " )\n", " country = gr.Dropdown(\n", " label=\"Native Country\",\n", " choices=unique_country,\n", " value=lambda: random.choice(unique_country),\n", " )\n", " with gr.Column():\n", " label = gr.Label()\n", " plot = gr.Plot()\n", " with gr.Row():\n", " predict_btn = gr.Button(value=\"Predict\")\n", " interpret_btn = gr.Button(value=\"Explain\")\n", " predict_btn.click(\n", " predict,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[label],\n", " )\n", " interpret_btn.click(\n", " interpret,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[plot],\n", " )\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: xgboost-income-prediction-with-explainability\n", "### This demo takes in 12 inputs from the user in dropdowns and sliders and predicts income. 
It also has a separate button for explaining the prediction.\n", " "]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio matplotlib shap xgboost pandas datasets"]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import random\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import shap\n", "import xgboost as xgb\n", "from datasets import load_dataset\n", "\n", "\n", "dataset = load_dataset(\"scikit-learn/adult-census-income\")\n", "X_train = dataset[\"train\"].to_pandas()\n", "_ = X_train.pop(\"fnlwgt\")\n", "_ = X_train.pop(\"race\")\n", "y_train = X_train.pop(\"income\")\n", "y_train = (y_train == \">50K\").astype(int)\n", "categorical_columns = [\n", " \"workclass\",\n", " \"education\",\n", " \"marital.status\",\n", " \"occupation\",\n", " \"relationship\",\n", " \"sex\",\n", " \"native.country\",\n", "]\n", "X_train = X_train.astype({col: \"category\" for col in categorical_columns})\n", "data = xgb.DMatrix(X_train, label=y_train, enable_categorical=True)\n", "model = xgb.train(params={\"objective\": \"binary:logistic\"}, dtrain=data)\n", "explainer = shap.TreeExplainer(model)\n", "\n", "def predict(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " pos_pred = model.predict(xgb.DMatrix(df, enable_categorical=True))\n", " return {\">50K\": float(pos_pred[0]), \"<=50K\": 1 - float(pos_pred[0])}\n", "\n", "\n", "def interpret(*args):\n", " df = pd.DataFrame([args], columns=X_train.columns)\n", " df = df.astype({col: \"category\" for col in categorical_columns})\n", " shap_values = explainer.shap_values(xgb.DMatrix(df, enable_categorical=True))\n", " scores_desc = list(zip(shap_values[0], X_train.columns))\n", " scores_desc = sorted(scores_desc)\n", " fig_m = plt.figure(tight_layout=True)\n", " plt.barh([s[1] for s in scores_desc], [s[0] for s in scores_desc])\n", " plt.title(\"Feature Shap Values\")\n", " plt.ylabel(\"Shap Value\")\n", " plt.xlabel(\"Feature\")\n", " plt.tight_layout()\n", " return fig_m\n", "\n", "\n", "unique_class = sorted(X_train[\"workclass\"].unique())\n", "unique_education = sorted(X_train[\"education\"].unique())\n", "unique_marital_status = sorted(X_train[\"marital.status\"].unique())\n", "unique_relationship = sorted(X_train[\"relationship\"].unique())\n", "unique_occupation = sorted(X_train[\"occupation\"].unique())\n", "unique_sex = sorted(X_train[\"sex\"].unique())\n", "unique_country = sorted(X_train[\"native.country\"].unique())\n", "\n", "with gr.Blocks() as demo:\n", " gr.Markdown(\"\"\"\n", " **Income Classification with XGBoost \ud83d\udcb0**: This demo uses an XGBoost classifier predicts income based on demographic factors, along with Shapley value-based *explanations*. 
The [source code for this Gradio demo is here](https://huggingface.co/spaces/gradio/xgboost-income-prediction-with-explainability/blob/main/app.py).\n", " \"\"\")\n", " with gr.Row():\n", " with gr.Column():\n", " age = gr.Slider(label=\"Age\", minimum=17, maximum=90, step=1, randomize=True)\n", " work_class = gr.Dropdown(\n", " label=\"Workclass\",\n", " choices=unique_class,\n", " value=lambda: random.choice(unique_class),\n", " )\n", " education = gr.Dropdown(\n", " label=\"Education Level\",\n", " choices=unique_education,\n", " value=lambda: random.choice(unique_education),\n", " )\n", " years = gr.Slider(\n", " label=\"Years of schooling\",\n", " minimum=1,\n", " maximum=16,\n", " step=1,\n", " randomize=True,\n", " )\n", " marital_status = gr.Dropdown(\n", " label=\"Marital Status\",\n", " choices=unique_marital_status,\n", " value=lambda: random.choice(unique_marital_status),\n", " )\n", " occupation = gr.Dropdown(\n", " label=\"Occupation\",\n", " choices=unique_occupation,\n", " value=lambda: random.choice(unique_occupation),\n", " )\n", " relationship = gr.Dropdown(\n", " label=\"Relationship Status\",\n", " choices=unique_relationship,\n", " value=lambda: random.choice(unique_relationship),\n", " )\n", " sex = gr.Dropdown(\n", " label=\"Sex\", choices=unique_sex, value=lambda: random.choice(unique_sex)\n", " )\n", " capital_gain = gr.Slider(\n", " label=\"Capital Gain\",\n", " minimum=0,\n", " maximum=100000,\n", " step=500,\n", " randomize=True,\n", " )\n", " capital_loss = gr.Slider(\n", " label=\"Capital Loss\", minimum=0, maximum=10000, step=500, randomize=True\n", " )\n", " hours_per_week = gr.Slider(\n", " label=\"Hours Per Week Worked\", minimum=1, maximum=99, step=1\n", " )\n", " country = gr.Dropdown(\n", " label=\"Native Country\",\n", " choices=unique_country,\n", " value=lambda: random.choice(unique_country),\n", " )\n", " with gr.Column():\n", " label = gr.Label()\n", " plot = gr.Plot()\n", " with gr.Row():\n", " predict_btn = gr.Button(value=\"Predict\")\n", " interpret_btn = gr.Button(value=\"Explain\")\n", " predict_btn.click(\n", " predict,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[label],\n", " )\n", " interpret_btn.click(\n", " interpret,\n", " inputs=[\n", " age,\n", " work_class,\n", " education,\n", " years,\n", " marital_status,\n", " occupation,\n", " relationship,\n", " sex,\n", " capital_gain,\n", " capital_loss,\n", " hours_per_week,\n", " country,\n", " ],\n", " outputs=[plot],\n", " )\n", "\n", "demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/xgboost-income-prediction-with-explainability/run.py b/demo/xgboost-income-prediction-with-explainability/run.py index cbc1f44fcd39c..27ef4a3f1de40 100644 --- a/demo/xgboost-income-prediction-with-explainability/run.py +++ b/demo/xgboost-income-prediction-with-explainability/run.py @@ -1,6 +1,5 @@ import gradio as gr import random -import matplotlib import matplotlib.pyplot as plt import pandas as pd import shap @@ -8,7 +7,6 @@ from datasets import load_dataset -matplotlib.use("Agg") dataset = load_dataset("scikit-learn/adult-census-income") X_train = dataset["train"].to_pandas() _ = X_train.pop("fnlwgt") diff --git a/gradio/components.py b/gradio/components.py index 38bc98f374427..6ee217027a06c 100644 --- a/gradio/components.py +++ 
b/gradio/components.py @@ -23,7 +23,6 @@ import aiofiles import altair as alt -import matplotlib.figure import numpy as np import pandas as pd import PIL @@ -4892,6 +4891,8 @@ def postprocess(self, y) -> Dict[str, str] | None: Returns: plot type mapped to plot base64 data """ + import matplotlib.figure + if y is None: return None if isinstance(y, (ModuleType, matplotlib.figure.Figure)): # type: ignore diff --git a/gradio/helpers.py b/gradio/helpers.py index e201804385a2b..32b81252847fe 100644 --- a/gradio/helpers.py +++ b/gradio/helpers.py @@ -14,7 +14,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple -import matplotlib import matplotlib.pyplot as plt import numpy as np import PIL @@ -309,9 +308,13 @@ async def cache(self) -> None: processed_input = self.processed_examples[example_id] if self.batch: processed_input = [[value] for value in processed_input] - prediction = await Context.root_block.process_api( - fn_index=fn_index, inputs=processed_input, request=None, state={} - ) + with utils.MatplotlibBackendMananger(): + prediction = await Context.root_block.process_api( + fn_index=fn_index, + inputs=processed_input, + request=None, + state={}, + ) output = prediction["data"] if self.batch: output = [value[0] for value in output] @@ -749,58 +752,60 @@ def get_color_gradient(c1, c2, n): samples = np.abs(samples) samples = np.max(samples, 1) - matplotlib.use("Agg") - plt.clf() - # Plot waveform - color = ( - bars_color - if isinstance(bars_color, str) - else get_color_gradient(bars_color[0], bars_color[1], bar_count) - ) - plt.bar( - np.arange(0, bar_count), - samples * 2, - bottom=(-1 * samples), - width=bar_width, - color=color, - ) - plt.axis("off") - plt.margins(x=0) - tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - savefig_kwargs: Dict[str, Any] = {"bbox_inches": "tight"} - if bg_image is not None: - savefig_kwargs["transparent"] = True - else: - savefig_kwargs["facecolor"] = bg_color - plt.savefig(tmp_img.name, **savefig_kwargs) - waveform_img = PIL.Image.open(tmp_img.name) - waveform_img = waveform_img.resize((1000, 200)) - - # Composite waveform with background image - if bg_image is not None: - waveform_array = np.array(waveform_img) - waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha - waveform_img = PIL.Image.fromarray(waveform_array) - - bg_img = PIL.Image.open(bg_image) - waveform_width, waveform_height = waveform_img.size - bg_width, bg_height = bg_img.size - if waveform_width != bg_width: - bg_img = bg_img.resize( - (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2)) - ) - bg_width, bg_height = bg_img.size - composite_height = max(bg_height, waveform_height) - composite = PIL.Image.new("RGBA", (waveform_width, composite_height), "#FFFFFF") - composite.paste(bg_img, (0, composite_height - bg_height)) - composite.paste( - waveform_img, (0, composite_height - waveform_height), waveform_img + with utils.MatplotlibBackendMananger(): + plt.clf() + # Plot waveform + color = ( + bars_color + if isinstance(bars_color, str) + else get_color_gradient(bars_color[0], bars_color[1], bar_count) ) - composite.save(tmp_img.name) - img_width, img_height = composite.size - else: - img_width, img_height = waveform_img.size - waveform_img.save(tmp_img.name) + plt.bar( + np.arange(0, bar_count), + samples * 2, + bottom=(-1 * samples), + width=bar_width, + color=color, + ) + plt.axis("off") + plt.margins(x=0) + tmp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + savefig_kwargs: 
Dict[str, Any] = {"bbox_inches": "tight"} + if bg_image is not None: + savefig_kwargs["transparent"] = True + else: + savefig_kwargs["facecolor"] = bg_color + plt.savefig(tmp_img.name, **savefig_kwargs) + waveform_img = PIL.Image.open(tmp_img.name) + waveform_img = waveform_img.resize((1000, 200)) + + # Composite waveform with background image + if bg_image is not None: + waveform_array = np.array(waveform_img) + waveform_array[:, :, 3] = waveform_array[:, :, 3] * fg_alpha + waveform_img = PIL.Image.fromarray(waveform_array) + + bg_img = PIL.Image.open(bg_image) + waveform_width, waveform_height = waveform_img.size + bg_width, bg_height = bg_img.size + if waveform_width != bg_width: + bg_img = bg_img.resize( + (waveform_width, 2 * int(bg_height * waveform_width / bg_width / 2)) + ) + bg_width, bg_height = bg_img.size + composite_height = max(bg_height, waveform_height) + composite = PIL.Image.new( + "RGBA", (waveform_width, composite_height), "#FFFFFF" + ) + composite.paste(bg_img, (0, composite_height - bg_height)) + composite.paste( + waveform_img, (0, composite_height - waveform_height), waveform_img + ) + composite.save(tmp_img.name) + img_width, img_height = composite.size + else: + img_width, img_height = waveform_img.size + waveform_img.save(tmp_img.name) # Convert waveform to video with ffmpeg output_mp4 = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) diff --git a/gradio/routes.py b/gradio/routes.py index 640995f9f8335..8fc784960c8be 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -396,15 +396,16 @@ async def run_predict( if not (body.batched) and batch: raw_input = [raw_input] try: - output = await app.get_blocks().process_api( - fn_index=fn_index_inferred, - inputs=raw_input, - request=request, - state=session_state, - iterators=iterators, - event_id=event_id, - event_data=event_data, - ) + with utils.MatplotlibBackendMananger(): + output = await app.get_blocks().process_api( + fn_index=fn_index_inferred, + inputs=raw_input, + request=request, + state=session_state, + iterators=iterators, + event_id=event_id, + event_data=event_data, + ) iterator = output.pop("iterator", None) if hasattr(body, "session_hash"): if fn_index in app.iterators[body.session_hash]["should_reset"]: diff --git a/gradio/utils.py b/gradio/utils.py index 0d82afeedc2c0..ac9400cbe81cc 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -38,7 +38,7 @@ import aiohttp import httpx -import matplotlib.pyplot as plt +import matplotlib import requests from markdown_it import MarkdownIt from mdit_py_plugins.dollarmath.index import dollarmath_plugin @@ -899,34 +899,48 @@ def __str__(self): return "" +class MatplotlibBackendMananger: + def __enter__(self): + self._original_backend = matplotlib.get_backend() + matplotlib.use("agg") + + def __exit__(self, exc_type, exc_val, exc_tb): + matplotlib.use(self._original_backend) + + def tex2svg(formula, *args): - FONTSIZE = 20 - DPI = 300 - plt.rc("mathtext", fontset="cm") - fig = plt.figure(figsize=(0.01, 0.01)) - fig.text(0, 0, rf"${formula}$", fontsize=FONTSIZE) - output = BytesIO() - fig.savefig( - output, - dpi=DPI, - transparent=True, - format="svg", - bbox_inches="tight", - pad_inches=0.0, - ) - plt.close(fig) - output.seek(0) - xml_code = output.read().decode("utf-8") - svg_start = xml_code.index(".*<\/metadata>", "", svg_code, flags=re.DOTALL) - svg_code = re.sub(r' width="[^"]+"', "", svg_code) - height_match = re.search(r'height="([\d.]+)pt"', svg_code) - if height_match: - height = float(height_match.group(1)) - new_height = height / FONTSIZE # 
conversion from pt to em - svg_code = re.sub(r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code) - copy_code = f"{formula}" + with MatplotlibBackendMananger(): + import matplotlib.pyplot as plt + + FONTSIZE = 20 + DPI = 300 + plt.rc("mathtext", fontset="cm") + fig = plt.figure(figsize=(0.01, 0.01)) + fig.text(0, 0, rf"${formula}$", fontsize=FONTSIZE) + output = BytesIO() + fig.savefig( + output, + dpi=DPI, + transparent=True, + format="svg", + bbox_inches="tight", + pad_inches=0.0, + ) + plt.close(fig) + output.seek(0) + xml_code = output.read().decode("utf-8") + svg_start = xml_code.index(".*<\/metadata>", "", svg_code, flags=re.DOTALL) + svg_code = re.sub(r' width="[^"]+"', "", svg_code) + height_match = re.search(r'height="([\d.]+)pt"', svg_code) + if height_match: + height = float(height_match.group(1)) + new_height = height / FONTSIZE # conversion from pt to em + svg_code = re.sub( + r'height="[\d.]+pt"', f'height="{new_height}em"', svg_code + ) + copy_code = f"{formula}" return f"{copy_code}{svg_code}" diff --git a/test/test_components.py b/test/test_components.py index a7558139ee3f3..2b0dde62d49bd 100644 --- a/test/test_components.py +++ b/test/test_components.py @@ -15,8 +15,6 @@ from pathlib import Path from unittest.mock import MagicMock, patch -import matplotlib -import matplotlib.pyplot as plt import numpy as np import pandas as pd import PIL @@ -27,10 +25,9 @@ from scipy.io import wavfile import gradio as gr -from gradio import processing_utils +from gradio import processing_utils, utils os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" -matplotlib.use("Agg") class TestComponent: @@ -734,12 +731,15 @@ async def test_in_interface_as_output(self): """ def plot(num): + import matplotlib.pyplot as plt + fig = plt.figure() plt.plot(range(num), range(num)) return fig iface = gr.Interface(plot, "slider", "plot") - output = await iface.process_api(fn_index=0, inputs=[10], state={}) + with utils.MatplotlibBackendMananger(): + output = await iface.process_api(fn_index=0, inputs=[10], state={}) assert output["data"][0]["type"] == "matplotlib" assert output["data"][0]["plot"].startswith("data:image/png;base64") @@ -747,8 +747,11 @@ def test_static(self): """ postprocess """ - fig = plt.figure() - plt.plot([1, 2, 3], [1, 2, 3]) + with utils.MatplotlibBackendMananger(): + import matplotlib.pyplot as plt + + fig = plt.figure() + plt.plot([1, 2, 3], [1, 2, 3]) component = gr.Plot(fig) assert component.get_config().get("value") is not None diff --git a/test/test_processing_utils.py b/test/test_processing_utils.py index f9b30b9e76627..fca2fdaa1b8f4 100644 --- a/test/test_processing_utils.py +++ b/test/test_processing_utils.py @@ -6,13 +6,12 @@ from unittest.mock import patch import ffmpy -import matplotlib.pyplot as plt import numpy as np import pytest from gradio_client import media_data from PIL import Image -from gradio import processing_utils +from gradio import processing_utils, utils os.environ["GRADIO_ANALYTICS_ENABLED"] = "False" @@ -32,8 +31,11 @@ def test_decode_base64_to_image(self): assert output_image == output_image_without_header def test_encode_plot_to_base64(self): - plt.plot([1, 2, 3, 4]) - output_base64 = processing_utils.encode_plot_to_base64(plt) + with utils.MatplotlibBackendMananger(): + import matplotlib.pyplot as plt + + plt.plot([1, 2, 3, 4]) + output_base64 = processing_utils.encode_plot_to_base64(plt) assert output_base64.startswith( "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAo" ) diff --git a/test/test_utils.py b/test/test_utils.py index 
afe3114b63849..984b0f677376c 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -40,6 +40,7 @@
     sagemaker_check,
     sanitize_list_for_csv,
     sanitize_value_for_csv,
+    tex2svg,
     validate_url,
     version_check,
 )
@@ -653,3 +654,16 @@ def f(s: str, evt: EventData):
 
     for x in test_objs:
         check_function_inputs_match(x, [None], False)
+
+
+def test_tex2svg_preserves_matplotlib_backend():
+    import matplotlib
+
+    matplotlib.use("svg")
+    tex2svg("1+1=2")
+    assert matplotlib.get_backend() == "svg"
+    with pytest.raises(
+        Exception  # specifically a pyparsing.ParseException but not important here
+    ):
+        tex2svg("$$$1+1=2$$$")
+    assert matplotlib.get_backend() == "svg"
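
Note on the pattern this patch introduces: instead of forcing the global `Agg` backend at import time (the removed `matplotlib.use("Agg")` calls above), Gradio now wraps rendering in the `MatplotlibBackendMananger` context manager added to `gradio/utils.py`, switching to `agg` only for the duration of the call and restoring whatever backend the caller had configured. A minimal, self-contained sketch of that pattern follows; the class is respelled `MatplotlibBackendManager` here, the `__main__` demo is illustrative rather than part of the patch, and the final assert mirrors the new `test_tex2svg_preserves_matplotlib_backend`.

```python
import matplotlib


class MatplotlibBackendManager:
    """Temporarily render with the headless "agg" backend, then restore the original."""

    def __enter__(self):
        # Remember whichever backend the caller (or their notebook) selected.
        self._original_backend = matplotlib.get_backend()
        matplotlib.use("agg")

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Put the caller's backend back, even if plotting raised.
        matplotlib.use(self._original_backend)


if __name__ == "__main__":
    matplotlib.use("svg")  # stand-in for a backend chosen by user code

    with MatplotlibBackendManager():
        import matplotlib.pyplot as plt

        fig = plt.figure()  # drawn with "agg"; no display is needed
        plt.plot([1, 2, 3], [1, 2, 3])
        plt.close(fig)

    assert matplotlib.get_backend() == "svg"  # the caller's choice survives
```

Because `__exit__` always runs, user scripts that set their own backend keep it after Gradio renders a plot, which is the behavior the changelog entry for [PR 4029](https://github.com/gradio-app/gradio/pull/4029) describes.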