From 495894d02be2119dda009bd7dd82b80389542c74 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Sat, 17 Sep 2022 01:13:25 +0800 Subject: [PATCH] Overwrite SQLite3 Task image (#1165) * Overwrite SQLite3Task image Signed-off-by: Kevin Su * fix lint error Signed-off-by: Kevin Su * remove comment Signed-off-by: Kevin Su * nit Signed-off-by: Kevin Su Signed-off-by: Kevin Su --- flytekit/extras/sqlite3/task.py | 7 ++++--- tests/flytekit/unit/extras/sqlite3/test_task.py | 8 +++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 1018b5254b..45c23da0ae 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -9,7 +9,7 @@ import pandas as pd from flytekit import FlyteContext, kwtypes -from flytekit.configuration import SerializationSettings +from flytekit.configuration import DefaultImages, SerializationSettings from flytekit.core.base_sql_task import SQLTask from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor @@ -79,6 +79,7 @@ def __init__( inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, task_config: typing.Optional[SQLite3Config] = None, output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + container_image: typing.Optional[str] = None, **kwargs, ): if task_config is None or task_config.uri is None: @@ -87,8 +88,8 @@ def __init__( super().__init__( name=name, task_config=task_config, - # If you make changes to this task itself, you'll have to bump this image to what the release _will_ be. - container_image="ghcr.io/flyteorg/flytekit:v0.19.0", + # if you use your own image, keep in mind to specify the container image here + container_image=container_image or DefaultImages.default_image(), executor_type=SQLite3TaskExecutor, task_type=self._SQLITE_TASK_TYPE, query_template=query_template, diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index f586d94a16..5e729d9e7a 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,6 +1,7 @@ import pandas from flytekit import kwtypes, task, workflow +from flytekit.configuration import DefaultImages from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task # https://www.sqlitetutorial.net/sqlite-sample-database/ @@ -99,4 +100,9 @@ def test_task_serialization(): ] assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}" - assert tt.container.image != "" + assert tt.container.image == DefaultImages.default_image() + + image = "xyz.io/docker2:latest" + sql_task._container_image = image + tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS) + assert tt.container.image == image