Skip to content

Commit

Permalink
feat: Adding billing_project_id in BigQueryOfflineStoreConfig (#3253)
Browse files Browse the repository at this point in the history
* adding_billing_project_in_config

Signed-off-by: Varun <[email protected]>

* Fix lint

Signed-off-by: Danny Chiao <[email protected]>

Signed-off-by: Varun <[email protected]>
Signed-off-by: Danny Chiao <[email protected]>
Co-authored-by: Danny Chiao <[email protected]>
  • Loading branch information
vmallya-123 and adchia authored Oct 1, 2022
1 parent 53dc811 commit f80f05f
Showing 1 changed file with 44 additions and 16 deletions.
60 changes: 44 additions & 16 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pandas as pd
import pyarrow
import pyarrow.parquet
from pydantic import StrictStr
from pydantic import StrictStr, validator
from pydantic.typing import Literal
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed

Expand Down Expand Up @@ -83,7 +83,8 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):

project_id: Optional[StrictStr] = None
""" (optional) GCP project name used for the BigQuery offline store """

billing_project_id: Optional[StrictStr] = None
""" (optional) GCP project in which the BigQuery jobs are run """
location: Optional[StrictStr] = None
""" (optional) GCP location name used for the BigQuery offline store.
Examples of location names include ``US``, ``EU``, ``us-central1``, ``us-west4``.
Expand All @@ -94,6 +95,14 @@ class BigQueryOfflineStoreConfig(FeastConfigBaseModel):
gcs_staging_location: Optional[str] = None
""" (optional) GCS location used for offloading BigQuery results as parquet files."""

@validator("billing_project_id")
def project_id_exists(cls, v, values, **kwargs):
    """Require ``project_id`` whenever ``billing_project_id`` is set.

    Args:
        v: The candidate ``billing_project_id`` value.
        values: Previously validated fields of the config; ``project_id``
            is declared before ``billing_project_id`` and so is looked up
            here.
        **kwargs: Extra validator arguments; unused.

    Returns:
        ``v`` unchanged when the check passes.

    Raises:
        ValueError: If ``billing_project_id`` is truthy but no
            ``project_id`` is available.
    """
    # ``values`` only holds fields that passed their own validation, so
    # ``project_id`` can be missing entirely (e.g. it was rejected by
    # StrictStr). Use .get() so that case is reported as a validation
    # error instead of escaping as a bare KeyError.
    if v and not values.get("project_id"):
        raise ValueError(
            "please specify project_id if billing_project_id is specified"
        )
    return v


class BigQueryOfflineStore(OfflineStore):
@staticmethod
Expand Down Expand Up @@ -122,9 +131,11 @@ def pull_latest_from_table_or_query(
timestamps.append(created_timestamp_column)
timestamp_desc_string = " DESC, ".join(timestamps) + " DESC"
field_string = ", ".join(join_key_columns + feature_name_columns + timestamps)

project_id = (
config.offline_store.billing_project_id or config.offline_store.project_id
)
client = _get_bigquery_client(
project=config.offline_store.project_id,
project=project_id,
location=config.offline_store.location,
)
query = f"""
Expand Down Expand Up @@ -162,9 +173,11 @@ def pull_all_from_table_or_query(
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
assert isinstance(data_source, BigQuerySource)
from_expression = data_source.get_table_query_string()

project_id = (
config.offline_store.billing_project_id or config.offline_store.project_id
)
client = _get_bigquery_client(
project=config.offline_store.project_id,
project=project_id,
location=config.offline_store.location,
)
field_string = ", ".join(
Expand Down Expand Up @@ -197,17 +210,22 @@ def get_historical_features(
assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)
for fv in feature_views:
assert isinstance(fv.batch_source, BigQuerySource)

project_id = (
config.offline_store.billing_project_id or config.offline_store.project_id
)
client = _get_bigquery_client(
project=config.offline_store.project_id,
project=project_id,
location=config.offline_store.location,
)

assert isinstance(config.offline_store, BigQueryOfflineStoreConfig)

if config.offline_store.billing_project_id:
dataset_project = str(config.offline_store.project_id)
else:
dataset_project = client.project
table_reference = _get_table_reference_for_new_entity(
client,
client.project,
dataset_project,
config.offline_store.dataset,
config.offline_store.location,
)
Expand Down Expand Up @@ -295,9 +313,11 @@ def write_logged_features(
):
destination = logging_config.destination
assert isinstance(destination, BigQueryLoggingDestination)

project_id = (
config.offline_store.billing_project_id or config.offline_store.project_id
)
client = _get_bigquery_client(
project=config.offline_store.project_id,
project=project_id,
location=config.offline_store.location,
)

Expand Down Expand Up @@ -353,9 +373,11 @@ def offline_write_batch(

if table.schema != pa_schema:
table = table.cast(pa_schema)

project_id = (
config.offline_store.billing_project_id or config.offline_store.project_id
)
client = _get_bigquery_client(
project=config.offline_store.project_id,
project=project_id,
location=config.offline_store.location,
)

Expand Down Expand Up @@ -451,7 +473,10 @@ def to_bigquery(
if not job_config:
today = date.today().strftime("%Y%m%d")
rand_id = str(uuid.uuid4())[:7]
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
if self.config.offline_store.billing_project_id:
path = f"{self.config.offline_store.project_id}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
else:
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
job_config = bigquery.QueryJobConfig(destination=path)

if not job_config.dry_run and self.on_demand_feature_views:
Expand Down Expand Up @@ -525,7 +550,10 @@ def to_remote_storage(self) -> List[str]:

bucket: str
prefix: str
storage_client = StorageClient(project=self.client.project)
if self.config.offline_store.billing_project_id:
storage_client = StorageClient(project=self.config.offline_store.project_id)
else:
storage_client = StorageClient(project=self.client.project)
bucket, prefix = self._gcs_path[len("gs://") :].split("/", 1)
prefix = prefix.rsplit("/", 1)[0]
if prefix.startswith("/"):
Expand Down

0 comments on commit f80f05f

Please sign in to comment.