From 7e482a5f113a8ee90071c08c0109e542fdf0533a Mon Sep 17 00:00:00 2001 From: prmoore77 Date: Thu, 22 Aug 2024 19:11:49 -0400 Subject: [PATCH] feat: add the latest tag to the docker image (#79) This PR also fixes the client-demo when the server is run in a Docker or Kubernetes (k8s) environment. --- .dockerignore | 2 + .github/workflows/build_docker.yml | 3 + Dockerfile | 4 +- pyproject.toml | 1 + scripts/start_demo_server.sh | 14 ++ src/gateway/demo/client_demo.py | 33 +++-- .../demo/generate_tpch_parquet_data.py | 128 ++++++++++++++++++ 7 files changed, 173 insertions(+), 12 deletions(-) create mode 100755 scripts/start_demo_server.sh create mode 100644 src/gateway/demo/generate_tpch_parquet_data.py diff --git a/.dockerignore b/.dockerignore index c4ccd29..060db1b 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,3 +2,5 @@ helm-chart tls .github +Dockerfile +.gitignore diff --git a/.github/workflows/build_docker.yml b/.github/workflows/build_docker.yml index 2e8d0bf..b7ca00d 100644 --- a/.github/workflows/build_docker.yml +++ b/.github/workflows/build_docker.yml @@ -32,6 +32,9 @@ jobs: uses: docker/metadata-action@v5 with: images: ${{ env.IMAGE_NAME }} + tags: | + # set latest tag for default branch + type=raw,value=latest,enable={{is_default_branch}} - name: Set up QEMU uses: docker/setup-qemu-action@v3 diff --git a/Dockerfile b/Dockerfile index 232acd0..509ebd1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -47,4 +47,6 @@ RUN pip install . # Expose the gRPC port EXPOSE 50051 -ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "spark-substrait-gateway-env", "python", "src/gateway/server.py"] +ENV GENERATE_CLIENT_DEMO_DATA="true" + +ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "spark-substrait-gateway-env", "scripts/start_demo_server.sh"] diff --git a/pyproject.toml b/pyproject.toml index 71fcdde..24d3d00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,3 +83,4 @@ spark-substrait-gateway-server = "gateway.server:click_serve" spark-substrait-client-demo = "gateway.demo.client_demo:click_run_demo" spark-substrait-create-tls-keypair = "gateway.setup.tls_utilities:click_create_tls_keypair" spark-substrait-create-jwt = "gateway.utilities.create_jwt:main" +spark-substrait-create-client-demo-data = "gateway.demo.generate_tpch_parquet_data:click_generate_tpch_parquet_data" diff --git a/scripts/start_demo_server.sh b/scripts/start_demo_server.sh new file mode 100755 index 0000000..f305ef2 --- /dev/null +++ b/scripts/start_demo_server.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# SPDX-License-Identifier: Apache-2.0 + +# This script starts the Spark Substrait Gateway demo server. +# It will create demo TPC-H (Scale Factor 1GB) data, and start the server. + +set -e + +if [ $(echo "${GENERATE_CLIENT_DEMO_DATA}" | tr '[:upper:]' '[:lower:]') == "true" ]; then + echo "Generating client demo TPC-H data..." + spark-substrait-create-client-demo-data +fi + +spark-substrait-gateway-server diff --git a/src/gateway/demo/client_demo.py b/src/gateway/demo/client_demo.py index 1e29b49..0cc7b2e 100644 --- a/src/gateway/demo/client_demo.py +++ b/src/gateway/demo/client_demo.py @@ -3,6 +3,7 @@ import logging import os +import sys from pathlib import Path import click @@ -11,30 +12,40 @@ from gateway.config import SERVER_PORT -_LOGGER = logging.getLogger(__name__) +# Setup logging +logging.basicConfig( + format="%(asctime)s - %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S %Z", + level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")), + stream=sys.stdout, +) +_LOGGER = logging.getLogger() + +# Constants +CLIENT_DEMO_DATA_LOCATION = Path("data") / "tpch" / "parquet" -def find_tpch() -> Path: +def find_tpch(raise_error_if_not_exists: bool) -> Path: """Find the location of the TPCH dataset.""" - location = Path("third_party") / "tpch" / "parquet" - if location.exists(): - return location - raise ValueError("TPCH dataset not found") + location = CLIENT_DEMO_DATA_LOCATION + if raise_error_if_not_exists and not location.exists(): + raise ValueError("TPCH dataset not found") + return location # pylint: disable=fixme -def get_customer_database(spark_session: SparkSession) -> DataFrame: +def get_customer_database(spark_session: SparkSession, use_gateway: bool) -> DataFrame: """Register the TPC-H customer database.""" - location_customer = str(find_tpch() / "customer") + location_customer = str(find_tpch(raise_error_if_not_exists=(not use_gateway)) / "customer") return spark_session.read.parquet(location_customer, mergeSchema=False) # pylint: disable=fixme # ruff: noqa: T201 -def execute_query(spark_session: SparkSession) -> None: +def execute_query(spark_session: SparkSession, use_gateway: bool) -> None: """Run a single sample query against the gateway.""" - df_customer = get_customer_database(spark_session) + df_customer = get_customer_database(spark_session=spark_session, use_gateway=use_gateway) df_customer.createOrReplaceTempView("customer") @@ -76,7 +87,7 @@ def run_demo( spark = SparkSession.builder.remote(f"sc://{host}:{port}/{uri_parameters}").getOrCreate() else: spark = SparkSession.builder.master("local").getOrCreate() - execute_query(spark) + execute_query(spark_session=spark, use_gateway=use_gateway) @click.command() diff --git a/src/gateway/demo/generate_tpch_parquet_data.py b/src/gateway/demo/generate_tpch_parquet_data.py new file mode 100644 index 0000000..c9eae49 --- /dev/null +++ b/src/gateway/demo/generate_tpch_parquet_data.py @@ -0,0 +1,128 @@ +# SPDX-License-Identifier: Apache-2.0 +"""A utility module for generating TPC-H parquet data for the client demo.""" + +import logging +import os +import shutil +import sys +from pathlib import Path + +import click +import duckdb + +from .client_demo import CLIENT_DEMO_DATA_LOCATION + +# Setup logging +logging.basicConfig( + format="%(asctime)s - %(levelname)-8s %(message)s", + datefmt="%Y-%m-%d %H:%M:%S %Z", + level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")), + stream=sys.stdout, +) +_LOGGER = logging.getLogger() + + +def execute_query(conn: duckdb.DuckDBPyConnection, query: str): + """Execute and log a query with a DuckDB connection.""" + _LOGGER.info(msg=f"Executing SQL: '{query}'") + conn.execute(query=query) + + +def get_printable_number(num: float): + """Return a number in a printable format.""" + return f"{num:.9g}" + + +def generate_tpch_parquet_data( + tpch_scale_factor: int, data_directory: str, overwrite: bool +) -> Path: + """Generate a TPC-H parquet dataset.""" + _LOGGER.info( + msg=( + "Creating a TPC-H parquet dataset - with parameters: " + f"--tpch-scale-factor={tpch_scale_factor} " + f"--data-directory='{data_directory}' " + f"--overwrite={overwrite}" + ) + ) + + # Output the database version + _LOGGER.info(msg=f"Using DuckDB Version: {duckdb.__version__}") + + # Get an in-memory DuckDB database connection + conn = duckdb.connect() + + # Load the TPCH extension needed to generate the data... + conn.load_extension(extension="tpch") + + # Generate the data + execute_query(conn=conn, query=f"CALL dbgen(sf={tpch_scale_factor})") + + # Export the data + target_directory = Path(data_directory) + + if target_directory.exists(): + if overwrite: + _LOGGER.warning(msg=f"Directory: {target_directory.as_posix()} exists, removing...") + shutil.rmtree(path=target_directory.as_posix()) + else: + raise RuntimeError(f"Directory: {target_directory.as_posix()} exists, aborting.") + + target_directory.mkdir(parents=True, exist_ok=True) + execute_query( + conn=conn, query=f"EXPORT DATABASE '{target_directory.as_posix()}' (FORMAT PARQUET)" + ) + + _LOGGER.info(msg=f"Wrote out parquet data to path: '{target_directory.as_posix()}'") + + # Restructure the contents of the directory so that each file is in its own directory + for filename in target_directory.glob(pattern="*.parquet"): + file = Path(filename) + table_name = file.stem + table_directory = target_directory / table_name + table_directory.mkdir(parents=True, exist_ok=True) + + if file.name not in ("nation.parquet", "region.parquet"): + new_file_name = f"{table_name}.1.parquet" + else: + new_file_name = file.name + + file.rename(target=table_directory / new_file_name) + + _LOGGER.info(msg="All done.") + + return target_directory + + +@click.command() +@click.option( + "--tpch-scale-factor", + type=float, + default=1, + show_default=True, + required=True, + help="The TPC-H scale factor to generate.", +) +@click.option( + "--data-directory", + type=str, + default=CLIENT_DEMO_DATA_LOCATION.as_posix(), + show_default=True, + required=True, + help="The target output data directory to put the files into", +) +@click.option( + "--overwrite/--no-overwrite", + type=bool, + default=False, + show_default=True, + required=True, + help="Can we overwrite the target directory if it already exists...", +) +def click_generate_tpch_parquet_data(tpch_scale_factor: int, data_directory: str, overwrite: bool): + """Provide a click interface for generating TPC-H parquet data.""" + generate_tpch_parquet_data(**locals()) + + +if __name__ == "__main__": + click_generate_tpch_parquet_data()