From 82582a9f6caa423072a4cb835e480121129595e2 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Thu, 15 Sep 2022 13:46:29 +0800 Subject: [PATCH] Overwrite SQLite3Task image Signed-off-by: Kevin Su --- flytekit/extras/sqlite3/task.py | 5 +++-- tests/flytekit/unit/extras/sqlite3/test_task.py | 8 +++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index 1018b5254b..a1c6992a40 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -9,7 +9,7 @@ import pandas as pd from flytekit import FlyteContext, kwtypes -from flytekit.configuration import SerializationSettings +from flytekit.configuration import SerializationSettings, DefaultImages from flytekit.core.base_sql_task import SQLTask from flytekit.core.python_customized_container_task import PythonCustomizedContainerTask from flytekit.core.shim_task import ShimTaskExecutor @@ -79,6 +79,7 @@ def __init__( inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, task_config: typing.Optional[SQLite3Config] = None, output_schema_type: typing.Optional[typing.Type[FlyteSchema]] = None, + container_image: typing.Optional[str] = None, **kwargs, ): if task_config is None or task_config.uri is None: @@ -88,7 +89,7 @@ def __init__( name=name, task_config=task_config, # If you make changes to this task itself, you'll have to bump this image to what the release _will_ be. - container_image="ghcr.io/flyteorg/flytekit:v0.19.0", + container_image=container_image or DefaultImages.default_image(), executor_type=SQLite3TaskExecutor, task_type=self._SQLITE_TASK_TYPE, query_template=query_template, diff --git a/tests/flytekit/unit/extras/sqlite3/test_task.py b/tests/flytekit/unit/extras/sqlite3/test_task.py index f586d94a16..5e729d9e7a 100644 --- a/tests/flytekit/unit/extras/sqlite3/test_task.py +++ b/tests/flytekit/unit/extras/sqlite3/test_task.py @@ -1,6 +1,7 @@ import pandas from flytekit import kwtypes, task, workflow +from flytekit.configuration import DefaultImages from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task # https://www.sqlitetutorial.net/sqlite-sample-database/ @@ -99,4 +100,9 @@ def test_task_serialization(): ] assert tt.custom["query_template"] == "select TrackId, Name from tracks limit {{.inputs.limit}}" - assert tt.container.image != "" + assert tt.container.image == DefaultImages.default_image() + + image = "xyz.io/docker2:latest" + sql_task._container_image = image + tt = sql_task.serialize_to_model(sql_task.SERIALIZE_SETTINGS) + assert tt.container.image == image