Skip to content

Commit

Permalink
Overwrite SQLite3Task image
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Sep 15, 2022
1 parent 2ccaed7 commit 82582a9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
5 changes: 3 additions & 2 deletions flytekit/extras/sqlite3/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd

from flytekit import FlyteContext, kwtypes
from flytekit.configuration import SerializationSettings
from flytekit.configuration import SerializationSettings, DefaultImages
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask
from flytekit.core.shim_task import ShimTaskExecutor
Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
task_config: typing.Optional[SQLite3Config] = None,
output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None,
container_image: typing.Optional[str] = None,
**kwargs,
):
if task_config is None or task_config.uri is None:
Expand All @@ -88,7 +89,7 @@ def __init__(
name=name,
task_config=task_config,
# If you make changes to this task itself, you'll have to bump this image to what the release _will_ be.
container_image="ghcr.io/flyteorg/flytekit:v0.19.0",
container_image=container_image or DefaultImages.default_image(),
executor_type=SQLite3TaskExecutor,
task_type=self._SQLITE_TASK_TYPE,
query_template=query_template,
Expand Down
8 changes: 7 additions & 1 deletion tests/flytekit/unit/extras/sqlite3/test_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas

from flytekit import kwtypes, task, workflow
from flytekit.configuration import DefaultImages
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task

# https://www.sqlitetutorial.net/sqlite-sample-database/
Expand Down Expand Up @@ -99,4 +100,9 @@ def test_task_serialization():
]

assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}"
assert tt.container.image != ""
assert tt.container.image == DefaultImages.default_image()

image = "xyz.io/docker2:latest"
sql_task._container_image = image
tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)
assert tt.container.image == image

0 comments on commit 82582a9

Please sign in to comment.