Skip to content

Commit

Permalink
update get_custom
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw committed Dec 28, 2021
1 parent 09a1de3 commit 57cce54
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 17 deletions.
4 changes: 0 additions & 4 deletions plugins/flytekit-bigquery/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,3 @@ pip install flytekitplugins-bigquery
```

To configure BigQuery in the Flyte deployment's backend, follow the [configuration guide](https://docs.flyte.org/en/latest/deployment/plugin_setup/gcp/bigquery.html#deployment-plugin-setup-gcp-bigquery).

TODO: Add example

An [example]() showcasing the BigQuery service can be found in the documentation.
20 changes: 10 additions & 10 deletions plugins/flytekit-bigquery/flytekitplugins/bigquery/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, Optional, Type

from google.cloud import bigquery
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit.extend import SerializationSettings, SQLTask
Expand All @@ -15,8 +16,8 @@ class BigQueryConfig(object):
BigQueryConfig should be used to configure a BigQuery Task.
"""

Location: Optional[str]
ProjectID: Optional[str]
ProjectID: str
Location: Optional[str] = None
QueryJobConfig: Optional[bigquery.QueryJobConfig] = None


Expand Down Expand Up @@ -65,16 +66,15 @@ def __init__(
self._output_schema_type = output_schema_type

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
config: dict = self.task_config.QueryJobConfig.to_api_repr()["query"]
config.update(
{
"Location": self.task_config.Location,
"ProjectID": self.task_config.ProjectID,
}
)
config = {
"Location": self.task_config.Location,
"ProjectID": self.task_config.ProjectID,
}
if self.task_config.QueryJobConfig is not None:
config.update(self.task_config.QueryJobConfig.to_api_repr()["query"])
s = Struct()
s.update(config)
return s
return json_format.MessageToDict(s)

def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]:
sql = _task_model.Sql(statement=self.query_template, dialect=_task_model.Sql.Dialect.ANSI)
Expand Down
5 changes: 2 additions & 3 deletions plugins/flytekit-bigquery/tests/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from flytekitplugins.bigquery import BigQueryConfig, BigQueryTask
from google.cloud.bigquery import QueryJobConfig
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Struct

from flytekit import kwtypes, workflow
Expand All @@ -25,7 +26,6 @@ def test_serialization():
ProjectID="Flyte", Location="Asia", QueryJobConfig=QueryJobConfig(allow_large_results=True)
),
query_template=query_template,
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

Expand All @@ -49,7 +49,7 @@ def my_wf(ds: str) -> FlyteSchema:
assert task_spec.template.sql.dialect == task_spec.template.sql.Dialect.ANSI
s = Struct()
s.update({"ProjectID": "Flyte", "Location": "Asia", "allowLargeResults": True})
assert task_spec.template.custom == s
assert task_spec.template.custom == json_format.MessageToDict(s)
assert len(task_spec.template.interface.inputs) == 1
assert len(task_spec.template.interface.outputs) == 1

Expand All @@ -66,7 +66,6 @@ def test_local_exec():
inputs=kwtypes(ds=str),
query_template=query_template,
task_config=BigQueryConfig(ProjectID="Flyte", Location="Asia"),
# the schema literal's backend uri will be equal to the value of .raw_output_data
output_schema_type=FlyteSchema,
)

Expand Down

0 comments on commit 57cce54

Please sign in to comment.