diff --git a/flytekit/extras/sqlalchemy/task.py b/flytekit/extras/sqlalchemy/task.py index 59a9d2ae6d..7def055c64 100644 --- a/flytekit/extras/sqlalchemy/task.py +++ b/flytekit/extras/sqlalchemy/task.py @@ -4,7 +4,7 @@ import pandas as pd from sqlalchemy import create_engine -from flytekit import current_context, kwtypes +from flytekit import current_context, kwtypes, Secret from flytekit.core.base_sql_task import SQLTask from flytekit.core.python_function_task import PythonInstanceTask from flytekit.types.schema import FlyteSchema @@ -29,7 +29,7 @@ class SQLAlchemyConfig(object): uri: str connect_args: typing.Optional[typing.Dict[str, typing.Any]] = None - secret_connect_args: typing.Optional[typing.Dict[str, typing.Dict[str, typing.Any]]] = None + secret_connect_args: typing.Optional[typing.Dict[str, Secret]] = None class SQLAlchemyTask(PythonInstanceTask[SQLAlchemyConfig], SQLTask[SQLAlchemyConfig]): @@ -74,9 +74,8 @@ def output_columns(self) -> typing.Optional[typing.List[str]]: def execute(self, **kwargs) -> typing.Any: if self._secret_connect_args is not None: for key, secret in self._secret_connect_args.items(): - if "name" in secret and "group" in secret: - value = current_context().secrets.get(secret["group"], secret["name"]) - self._connect_args[key] = value + value = current_context().secrets.get(secret.group, secret.key) + self._connect_args[key] = value engine = create_engine(self._uri, connect_args=self._connect_args, echo=False) print(f"Connecting to db {self._uri}") with engine.begin() as connection: diff --git a/tests/flytekit/unit/extras/sqlalchemy/test_task.py b/tests/flytekit/unit/extras/sqlalchemy/test_task.py index 3e5ecc7533..50275f162c 100644 --- a/tests/flytekit/unit/extras/sqlalchemy/test_task.py +++ b/tests/flytekit/unit/extras/sqlalchemy/test_task.py @@ -74,6 +74,7 @@ def my_task(df: pandas.DataFrame) -> int: return len(df[df.columns[0]]) os.environ[current_context().secrets.get_secrets_env_var("group", "key")] = "root" + user_secret = Secret(group="group", key="key") sql_task = SQLAlchemyTask( "test", query_template="select * from tracks limit {{.inputs.limit}}", @@ -82,13 +83,10 @@ def my_task(df: pandas.DataFrame) -> int: uri=BAD_EXAMPLE_DB, connect_args=dict(port=3307), secret_connect_args=dict( - user=dict( - group="group", - name="key", - ), + user=user_secret, ), ), - secret_requests=[Secret("group", key="key")], + secret_requests=[user_secret], ) @workflow