From 57cce54bb73e4d3ce4ccf9a6e8f6676be31e00f7 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 24 Dec 2021 00:23:38 +0800 Subject: [PATCH] update get_custom Signed-off-by: Kevin Su --- plugins/flytekit-bigquery/README.md | 4 ---- .../flytekitplugins/bigquery/task.py | 20 +++++++++---------- .../flytekit-bigquery/tests/test_bigquery.py | 5 ++--- 3 files changed, 12 insertions(+), 17 deletions(-) diff --git a/plugins/flytekit-bigquery/README.md b/plugins/flytekit-bigquery/README.md index 3b66cb9de1..7b8468ffc2 100644 --- a/plugins/flytekit-bigquery/README.md +++ b/plugins/flytekit-bigquery/README.md @@ -9,7 +9,3 @@ pip install flytekitplugins-bigquery ``` To configure BigQuery in the Flyte deployment's backend, follow the [configuration guide](https://docs.flyte.org/en/latest/deployment/plugin_setup/gcp/bigquery.html#deployment-plugin-setup-gcp-bigquery). - -TODO: Add example - -An [example]() showcasing BigQuery service can be found in the documentation. diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py index 18e047ca07..59c769ce2e 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Type from google.cloud import bigquery +from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct from flytekit.extend import SerializationSettings, SQLTask @@ -15,8 +16,8 @@ class BigQueryConfig(object): BigQueryConfig should be used to configure a BigQuery Task. """ - Location: Optional[str] - ProjectID: Optional[str] + ProjectID: str + Location: Optional[str] = None QueryJobConfig: Optional[bigquery.QueryJobConfig] = None @@ -65,16 +66,15 @@ def __init__( self._output_schema_type = output_schema_type def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - config: dict = self.task_config.QueryJobConfig.to_api_repr()["query"] - config.update( - { - "Location": self.task_config.Location, - "ProjectID": self.task_config.ProjectID, - } - ) + config = { + "Location": self.task_config.Location, + "ProjectID": self.task_config.ProjectID, + } + if self.task_config.QueryJobConfig is not None: + config.update(self.task_config.QueryJobConfig.to_api_repr()["query"]) s = Struct() s.update(config) - return s + return json_format.MessageToDict(s) def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: sql = _task_model.Sql(statement=self.query_template, dialect=_task_model.Sql.Dialect.ANSI) diff --git a/plugins/flytekit-bigquery/tests/test_bigquery.py b/plugins/flytekit-bigquery/tests/test_bigquery.py index dbde23cf7c..95a482091d 100644 --- a/plugins/flytekit-bigquery/tests/test_bigquery.py +++ b/plugins/flytekit-bigquery/tests/test_bigquery.py @@ -3,6 +3,7 @@ import pytest from flytekitplugins.bigquery import BigQueryConfig, BigQueryTask from google.cloud.bigquery import QueryJobConfig +from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct from flytekit import kwtypes, workflow @@ -25,7 +26,6 @@ def test_serialization(): ProjectID="Flyte", Location="Asia", QueryJobConfig=QueryJobConfig(allow_large_results=True) ), query_template=query_template, - # the schema literal's backend uri will be equal to the value of .raw_output_data output_schema_type=FlyteSchema, ) @@ -49,7 +49,7 @@ def my_wf(ds: str) -> FlyteSchema: assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI s = Struct() s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True}) - assert task_spec.template.custom == s + assert task_spec.template.custom == json_format.MessageToDict(s) assert len(task_spec.template.interface.inputs) == 1 assert len(task_spec.template.interface.outputs) == 1 @@ -66,7 +66,6 @@ def test_local_exec(): inputs=kwtypes(ds=str), query_template=query_template, task_config=BigQueryConfig(ProjectID="Flyte", Location="Asia"), - # the schema literal's backend uri will be equal to the value of .raw_output_data output_schema_type=FlyteSchema, )