Skip to content

Commit

Permalink
Overwrite SQLite3 Task image (#1165)
Browse files Browse the repository at this point in the history
* Overwrite SQLite3Task image

Signed-off-by: Kevin Su <[email protected]>

* fix lint error

Signed-off-by: Kevin Su <[email protected]>

* remove comment

Signed-off-by: Kevin Su <[email protected]>

* nit

Signed-off-by: Kevin Su <[email protected]>

Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored Sep 16, 2022
1 parent aedcfd4 commit 495894d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
7 changes: 4 additions & 3 deletions flytekit/extras/sqlite3/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pandas as pd

from flytekit import FlyteContext, kwtypes
from flytekit.configuration import SerializationSettings
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask
from flytekit.core.shim_task import ShimTaskExecutor
Expand Down Expand Up @@ -79,6 +79,7 @@ def __init__(
inputs: typing.Optional[typing.Dict[str, typing.Type]] = None,
task_config: typing.Optional[SQLite3Config] = None,
output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None,
container_image: typing.Optional[str] = None,
**kwargs,
):
if task_config is None or task_config.uri is None:
Expand All @@ -87,8 +88,8 @@ def __init__(
super().__init__(
name=name,
task_config=task_config,
# If you make changes to this task itself, you'll have to bump this image to what the release _will_ be.
container_image="ghcr.io/flyteorg/flytekit:v0.19.0",
# if you use your own image, keep in mind to specify the container image here
container_image=container_image or DefaultImages.default_image(),
executor_type=SQLite3TaskExecutor,
task_type=self._SQLITE_TASK_TYPE,
query_template=query_template,
Expand Down
8 changes: 7 additions & 1 deletion tests/flytekit/unit/extras/sqlite3/test_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas

from flytekit import kwtypes, task, workflow
from flytekit.configuration import DefaultImages
from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task

# https://www.sqlitetutorial.net/sqlite-sample-database/
Expand Down Expand Up @@ -99,4 +100,9 @@ def test_task_serialization():
]

assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}"
assert tt.container.image != ""
assert tt.container.image == DefaultImages.default_image()

image = "xyz.io/docker2:latest"
sql_task._container_image = image
tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS)
assert tt.container.image == image

0 comments on commit 495894d

Please sign in to comment.