Skip to content

Commit

Permalink
Secrets to dataclass
Browse files Browse the repository at this point in the history
  • Loading branch information
max-hoffman committed Apr 16, 2021
1 parent 18a3c2d commit f010509
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
9 changes: 4 additions & 5 deletions flytekit/extras/sqlalchemy/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pandas as pd
from sqlalchemy import create_engine

from flytekit import current_context, kwtypes
from flytekit import current_context, kwtypes, Secret
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_function_task import PythonInstanceTask
from flytekit.types.schema import FlyteSchema
Expand All @@ -29,7 +29,7 @@ class SQLAlchemyConfig(object):

uri: str
connect_args: typing.Optional[typing.Dict[str, typing.Any]] = None
secret_connect_args: typing.Optional[typing.Dict[str, typing.Dict[str, typing.Any]]] = None
secret_connect_args: typing.Optional[typing.Dict[str, Secret]] = None


class SQLAlchemyTask(PythonInstanceTask[SQLAlchemyConfig], SQLTask[SQLAlchemyConfig]):
Expand Down Expand Up @@ -74,9 +74,8 @@ def output_columns(self) -> typing.Optional[typing.List[str]]:
def execute(self, **kwargs) -> typing.Any:
if self._secret_connect_args is not None:
for key, secret in self._secret_connect_args.items():
if "name" in secret and "group" in secret:
value = current_context().secrets.get(secret["group"], secret["name"])
self._connect_args[key] = value
value = current_context().secrets.get(secret.group, secret.key)
self._connect_args[key] = value
engine = create_engine(self._uri, connect_args=self._connect_args, echo=False)
print(f"Connecting to db {self._uri}")
with engine.begin() as connection:
Expand Down
8 changes: 3 additions & 5 deletions tests/flytekit/unit/extras/sqlalchemy/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def my_task(df: pandas.DataFrame) -> int:
return len(df[df.columns[0]])

os.environ[current_context().secrets.get_secrets_env_var("group", "key")] = "root"
user_secret = Secret(group="group", key="key")
sql_task = SQLAlchemyTask(
"test",
query_template="select * from tracks limit {{.inputs.limit}}",
Expand All @@ -82,13 +83,10 @@ def my_task(df: pandas.DataFrame) -> int:
uri=BAD_EXAMPLE_DB,
connect_args=dict(port=3307),
secret_connect_args=dict(
user=dict(
group="group",
name="key",
),
user=user_secret,
),
),
secret_requests=[Secret("group", key="key")],
secret_requests=[user_secret],
)

@workflow
Expand Down

0 comments on commit f010509

Please sign in to comment.