class DallEFunction(AbstractFunction):
    """EvaDB function that generates one image per text prompt via OpenAI.

    Every value in the first column of the input frame is sent to
    ``openai.Image.create``; the URL found at ``["data"][0]["url"]`` of each
    reply is returned in a single ``response`` column.
    """

    @property
    def name(self) -> str:
        """Return the function's name, ``DallE``."""
        return "DallE"

    def setup(self) -> None:
        """Do nothing; the OpenAI client is configured inside ``forward``."""
        pass

    @forward(
        input_signatures=[
            PandasDataframe(
                columns=["prompt"],
                column_types=[
                    NdArrayType.STR,
                ],
                column_shapes=[(None,)],
            )
        ],
        output_signatures=[
            PandasDataframe(
                columns=["response"],
                column_types=[
                    NdArrayType.STR,
                ],
                column_shapes=[(1,)],
            )
        ],
    )
    def forward(self, text_df):
        """Generate an image for each prompt and return the image URLs.

        Args:
            text_df: DataFrame whose first column holds the text prompts.

        Returns:
            A DataFrame with one ``response`` column containing, per prompt,
            the URL of the first image in the OpenAI reply.

        Raises:
            AssertionError: if no API key is found in the configuration
                (``third_party`` / ``OPENAI_KEY``) or in the ``OPENAI_KEY``
                environment variable.
        """
        try_to_import_openai()
        import openai

        # Configuration manager first, environment variable as the fallback.
        # Truthiness (rather than len()) also covers a lookup returning None.
        api_key = ConfigurationManager().get_value("third_party", "OPENAI_KEY")
        if not api_key:
            api_key = os.environ.get("OPENAI_KEY", "")
        if not api_key:
            # Raised explicitly so the check survives `python -O`; the type
            # stays AssertionError so existing callers keep working.
            raise AssertionError(
                "Please set your OpenAI API key in evadb.yml file "
                "(third_party, OPENAI_KEY) or environment variable (OPENAI_KEY)"
            )
        openai.api_key = api_key

        prompts = text_df[text_df.columns[0]]
        urls = [
            openai.Image.create(prompt=prompt, n=1, size="1024x1024")["data"][0][
                "url"
            ]
            for prompt in prompts
        ]
        return pd.DataFrame({"response": urls})
class StableDiffusion(AbstractFunction):
    """EvaDB function that generates one image per text prompt via Replicate.

    Every value in the first column of the input frame is passed to
    ``replicate.run`` against ``"stability-ai/" + self.model``; the first
    element of each result is returned in a single ``response`` column.
    """

    @property
    def name(self) -> str:
        """Return the function's name, ``StableDiffusion``."""
        return "StableDiffusion"

    def setup(
        self,
        model="sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
    ) -> None:
        """Select the Stable Diffusion model version to run.

        Args:
            model: ``name:version`` identifier; must be one of the entries in
                ``_VALID_STABLE_DIFFUSION_MODEL``.

        Raises:
            AssertionError: if ``model`` is not a supported identifier.
        """
        if model not in _VALID_STABLE_DIFFUSION_MODEL:
            # Raised explicitly so the check survives `python -O`; the type
            # stays AssertionError so existing callers keep working.
            raise AssertionError(f"Unsupported Stable Diffusion {model}")
        self.model = model

    @forward(
        input_signatures=[
            PandasDataframe(
                columns=["prompt"],
                column_types=[
                    NdArrayType.STR,
                ],
                column_shapes=[(None,)],
            )
        ],
        output_signatures=[
            PandasDataframe(
                columns=["response"],
                column_types=[
                    NdArrayType.STR,
                ],
                column_shapes=[(1,)],
            )
        ],
    )
    def forward(self, text_df):
        """Generate an image for each prompt and return the results.

        Args:
            text_df: DataFrame whose first column holds the text prompts.

        Returns:
            A DataFrame with one ``response`` column containing, per prompt,
            the first element of the ``replicate.run`` output.
        """
        try_to_import_replicate()
        import replicate

        # Credentials are left entirely to the caller's environment. This
        # method used to inject a token hard-coded in source whenever
        # REPLICATE_API_TOKEN was unset, which publishes a secret and spends
        # an account the user does not own.
        # NOTE(review): the replicate client is expected to read
        # REPLICATE_API_TOKEN itself and report a missing token - confirm.
        prompts = text_df[text_df.columns[0]]
        images = [
            replicate.run("stability-ai/" + self.model, input={"prompt": prompt})[0]
            for prompt in prompts
        ]
        return pd.DataFrame({"response": images})
def try_to_import_replicate():
    """Import the ``replicate`` package or fail with install instructions.

    Raises:
        ValueError: if ``replicate`` cannot be imported. The original
            ImportError is attached as the cause.
    """
    try:
        import replicate  # noqa: F401
    except ImportError as err:
        # Adjacent literals keep the message free of the newline and source
        # indentation a triple-quoted string would embed.
        raise ValueError(
            "Could not import replicate python package. "
            "Please install it with `pip install replicate`."
        ) from err


def is_replicate_available():
    """Return True when ``replicate`` can be imported, False otherwise."""
    try:
        try_to_import_replicate()
    except ValueError:
        return False
    return True
class DallEFunctionTest(unittest.TestCase):
    """Integration test for the DallE function with the OpenAI call mocked."""

    def setUp(self) -> None:
        """Reset the catalog and load the prompts into the ImageGen table."""
        self.evadb = get_evadb_for_testing()
        self.evadb.catalog().reset()
        create_table_query = """CREATE TABLE IF NOT EXISTS ImageGen (
             prompt TEXT(100));
                """
        execute_query_fetch_all(self.evadb, create_table_query)

        test_prompts = ["a surreal painting of a cat"]

        for prompt in test_prompts:
            insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
            execute_query_fetch_all(self.evadb, insert_query)

    def tearDown(self) -> None:
        """Drop the table created in setUp."""
        execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

    @patch.dict("os.environ", {"OPENAI_KEY": "mocked_openai_key"})
    @patch("openai.Image.create", return_value={"data": [{"url": "mocked_url"}]})
    def test_dalle_image_generation(self, mock_openai_create):
        """Run DallE over the table and check the column name and API call."""
        function_name = "DallE"

        execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

        # The space before the name matters: without it the statement reads
        # "IF NOT EXISTSDallE".
        create_function_query = f"""CREATE FUNCTION IF NOT EXISTS {function_name}
            IMPL 'evadb/functions/dalle.py';
        """
        execute_query_fetch_all(self.evadb, create_function_query)

        gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
        output_batch = execute_query_fetch_all(self.evadb, gpt_query)

        self.assertEqual(output_batch.columns, ["dalle.response"])
        mock_openai_create.assert_called_once_with(
            prompt="a surreal painting of a cat", n=1, size="1024x1024"
        )
class StableDiffusionTest(unittest.TestCase):
    """Integration test for StableDiffusion with the Replicate call mocked."""

    def setUp(self) -> None:
        """Reset the catalog and load the prompts into the ImageGen table."""
        self.evadb = get_evadb_for_testing()
        self.evadb.catalog().reset()
        create_table_query = """CREATE TABLE IF NOT EXISTS ImageGen (
             prompt TEXT(100));
                """
        execute_query_fetch_all(self.evadb, create_table_query)

        test_prompts = ["pink cat riding a rocket to the moon"]

        for prompt in test_prompts:
            insert_query = f"""INSERT INTO ImageGen (prompt) VALUES ('{prompt}')"""
            execute_query_fetch_all(self.evadb, insert_query)

    def tearDown(self) -> None:
        """Drop the table created in setUp."""
        execute_query_fetch_all(self.evadb, "DROP TABLE IF EXISTS ImageGen;")

    @stable_diffusion_skip_marker
    # Pin the token for the duration of the test so the function under test
    # neither depends on nor modifies the real process environment.
    @patch.dict("os.environ", {"REPLICATE_API_TOKEN": "mocked_replicate_token"})
    @patch("replicate.run", return_value=[{"response": "mocked response"}])
    def test_stable_diffusion_image_generation(self, mock_replicate_run):
        """Run StableDiffusion over the table; check column name and API call."""
        function_name = "StableDiffusion"

        execute_query_fetch_all(self.evadb, f"DROP FUNCTION IF EXISTS {function_name};")

        # The space before the name matters: without it the statement reads
        # "IF NOT EXISTSStableDiffusion".
        create_function_query = f"""CREATE FUNCTION IF NOT EXISTS {function_name}
            IMPL 'evadb/functions/stable_diffusion.py';
        """
        execute_query_fetch_all(self.evadb, create_function_query)

        gpt_query = f"SELECT {function_name}(prompt) FROM ImageGen;"
        output_batch = execute_query_fetch_all(self.evadb, gpt_query)

        self.assertEqual(output_batch.columns, ["stablediffusion.response"])
        mock_replicate_run.assert_called_once_with(
            "stability-ai/sdxl:af1a68a271597604546c09c64aabcd7782c114a63539a4a8d14d1eeda5630c33",
            input={"prompt": "pink cat riding a rocket to the moon"},
        )
\n", + " Run on Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "


" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Import dependencies\n", + "import os\n", + "from IPython.display import Image" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Connect to EvaDB" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "source": [ + "%pip install --quiet \"evadb[document,notebook]\"\n", + "import evadb\n", + "cursor = evadb.connect().cursor()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get Input Prompt from User" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# to collect all user prompts\n", + "def get_user_input():\n", + " print('Welcome to EvaDB!')\n", + " print('Enter your image prompts one by one; type \\'exit\\' to stop entering prompts.')\n", + " print('========================================')\n", + " prompts = []\n", + " prompt=None\n", + "\n", + " # receive all prompts from user\n", + " while True:\n", + " prompt = input(\n", + " 'Enter prompt: '\n", + " ).strip()\n", + " if prompt in ['Exit', 'exit', 'EXIT']:\n", + " break\n", + " prompts.append(prompt)\n", + " print(prompt)\n", + "\n", + " return prompts" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Welcome to EvaDB!\n", + "Enter your image prompts one by one; type 'exit' to stop entering prompts.\n", + "========================================\n", + "Brown elephant riding a rocket to the moon\n" + ] + } + ], + "source": [ + "# getting user input\n", + "prompts = get_user_input()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Set API Token 
Environment Variable" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# to set the replicate API token environment variable\n", + "def set_replicate_token() -> None:\n", + " key = input('Enter your Replicate API Token: ').strip()\n", + "\n", + " try:\n", + " os.environ['REPLICATE_API_TOKEN'] = key\n", + " print('Environment variable set successfully.')\n", + " except Exception as e:\n", + " print(\"❗️ Session ended with an error.\")\n", + " print(e)\n", + " print(\"===========================================\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Environment variable set successfully.\n" + ] + } + ], + "source": [ + "# setting api token as env variable\n", + "set_replicate_token()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load the Stable Diffusion UDF" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# set up the stable diffusion UDF available at functions/stable_diffusion.py\n", + "cursor.query(\"\"\"CREATE FUNCTION IF NOT EXISTS StableDiffusion\n", + " IMPL '../evadb/functions/stable_diffusion.py';\n", + " \"\"\").execute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Table" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# delete the table if it already exists\n", + "cursor.query(\"\"\"DROP TABLE IF EXISTS ImageGen\n", + " \"\"\").execute()\n", + "\n", + "# create the table specifying the type of the prompt column\n", 
+ "cursor.query(\"\"\"CREATE TABLE IF NOT EXISTS ImageGen (\n", + " prompt TEXT(100))\n", + " \"\"\").execute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load Prompts into Table" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "# insert the prompts into the table\n", + "for prompt in prompts:\n", + " cursor.query(f\"\"\"INSERT INTO ImageGen (prompt) VALUES ('{prompt}')\"\"\").execute()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run Stable Diffusion on the Prompts" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# run stable diffusion on the prompts\n", + "table = cursor.table(\"ImageGen\").select(\"StableDiffusion(prompt)\").df()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# list of generated images\n", + "generated_images = list(table[table.columns[0]])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualize the Generated Image(s)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# visualize the generated image\n", + "Image(url=generated_images[0])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "eva", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}