From 72705a1844457aa9a9b8cfe7c2cc9770aab29a8c Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 1 Jul 2022 00:30:16 +0800 Subject: [PATCH] Add support insert in SQLAlchemyTask (#1070) Signed-off-by: Kevin Su --- .../flytekitplugins/sqlalchemy/task.py | 15 +++++++++++---- plugins/flytekit-sqlalchemy/tests/test_task.py | 10 +++++++++- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index 1150c8a941..88e4ef41c0 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import pandas as pd +from pandas.io.sql import pandasSQL_builder from sqlalchemy import create_engine # type: ignore from flytekit import current_context, kwtypes @@ -82,12 +83,14 @@ def __init__( query_template: str, task_config: SQLAlchemyConfig, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, - output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = FlyteSchema, container_image: str = SQLAlchemyDefaultImages.default_image(), **kwargs, ): - output_schema = output_schema_type if output_schema_type else FlyteSchema - outputs = kwtypes(results=output_schema) + if output_schema_type: + outputs = kwtypes(results=output_schema_type) + else: + outputs = None super().__init__( name=name, @@ -128,5 +131,9 @@ def execute_from_model(self, tt: task_models.TaskTemplate, **kwargs) -> typing.A interpolated_query = SQLAlchemyTask.interpolate_query(tt.custom["query_template"], **kwargs) print(f"Interpolated query {interpolated_query}") with engine.begin() as connection: - df = pd.read_sql_query(interpolated_query, connection) + df = None + if tt.interface.outputs: + df = pd.read_sql_query(interpolated_query, connection) + else: + pandasSQL_builder(connection).execute(interpolated_query) return df diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 167c8e796d..6d20027b2a 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -75,6 +75,13 @@ def test_workflow(sql_server): def my_task(df: pandas.DataFrame) -> int: return len(df[df.columns[0]]) + insert_task = SQLAlchemyTask( + "test", + query_template="insert into tracks values (5, 'flyte')", + output_schema_type=None, + task_config=SQLAlchemyConfig(uri=sql_server), + ) + sql_task = SQLAlchemyTask( "test", query_template="select * from tracks limit {{.inputs.limit}}", @@ -84,9 +91,10 @@ def my_task(df: pandas.DataFrame) -> int: @workflow def wf(limit: int) -> int: + insert_task() return my_task(df=sql_task(limit=limit)) - assert wf(limit=5) == 5 + assert wf(limit=10) == 6 def test_task_serialization(sql_server):