Skip to content

Commit

Permalink
feat: add the latest tag to the docker image (#79)
Browse files Browse the repository at this point in the history
This PR also fixes the client-demo when the server is run in
 a Docker or Kubernetes (k8s) environment.
  • Loading branch information
prmoore77 authored Aug 22, 2024
1 parent c5e8943 commit 7e482a5
Show file tree
Hide file tree
Showing 7 changed files with 173 additions and 12 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
helm-chart
tls
.github
Dockerfile
.gitignore
3 changes: 3 additions & 0 deletions .github/workflows/build_docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ jobs:
uses: docker/metadata-action@v5
with:
images: ${{ env.IMAGE_NAME }}
tags: |
# set latest tag for default branch
type=raw,value=latest,enable={{is_default_branch}}
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
Expand Down
4 changes: 3 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,6 @@ RUN pip install .
# Expose the gRPC port
EXPOSE 50051

ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "spark-substrait-gateway-env", "python", "src/gateway/server.py"]
ENV GENERATE_CLIENT_DEMO_DATA="true"

ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "spark-substrait-gateway-env", "scripts/start_demo_server.sh"]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,4 @@ spark-substrait-gateway-server = "gateway.server:click_serve"
spark-substrait-client-demo = "gateway.demo.client_demo:click_run_demo"
spark-substrait-create-tls-keypair = "gateway.setup.tls_utilities:click_create_tls_keypair"
spark-substrait-create-jwt = "gateway.utilities.create_jwt:main"
spark-substrait-create-client-demo-data = "gateway.demo.generate_tpch_parquet_data:click_generate_tpch_parquet_data"
14 changes: 14 additions & 0 deletions scripts/start_demo_server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#!/bin/bash
# SPDX-License-Identifier: Apache-2.0

# This script starts the Spark Substrait Gateway demo server.
# It will create demo TPC-H (Scale Factor 1GB) data, and start the server.

set -e

# Quote the command substitution and use the POSIX "=" comparison: with the
# original unquoted $(...) and "==", an unset or empty GENERATE_CLIENT_DEMO_DATA
# made "[" see too few words and abort the script (set -e).
if [ "$(echo "${GENERATE_CLIENT_DEMO_DATA}" | tr '[:upper:]' '[:lower:]')" = "true" ]; then
  echo "Generating client demo TPC-H data..."
  spark-substrait-create-client-demo-data
fi

spark-substrait-gateway-server
33 changes: 22 additions & 11 deletions src/gateway/demo/client_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import os
import sys
from pathlib import Path

import click
Expand All @@ -11,30 +12,40 @@

from gateway.config import SERVER_PORT

_LOGGER = logging.getLogger(__name__)
# Setup logging
logging.basicConfig(
format="%(asctime)s - %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S %Z",
level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
stream=sys.stdout,
)
_LOGGER = logging.getLogger()

# Constants
CLIENT_DEMO_DATA_LOCATION = Path("data") / "tpch" / "parquet"


def find_tpch(raise_error_if_not_exists: bool) -> Path:
    """Find the location of the TPCH dataset.

    Args:
        raise_error_if_not_exists: When True, raise if the dataset directory
            is missing locally; when False, return the expected path anyway
            (e.g. the gateway server may host the data remotely).

    Returns:
        Path of the TPC-H parquet dataset directory.

    Raises:
        ValueError: If the dataset directory does not exist and
            raise_error_if_not_exists is True.
    """
    # The scraped diff interleaved the removed implementation (hard-coded
    # third_party path + unconditional raise) with the new one; only the
    # post-change logic below is kept.
    location = CLIENT_DEMO_DATA_LOCATION
    if raise_error_if_not_exists and not location.exists():
        raise ValueError("TPCH dataset not found")
    return location


# pylint: disable=fixme
def get_customer_database(spark_session: SparkSession, use_gateway: bool) -> DataFrame:
    """Register the TPC-H customer database.

    Args:
        spark_session: Spark session used to read the parquet files.
        use_gateway: True when connected to the gateway server; the data may
            then live server-side, so a missing local copy is not an error.

    Returns:
        A DataFrame over the TPC-H customer table.
    """
    # Only insist that the dataset exists locally when running without the
    # gateway (local Spark must read the files itself).
    location_customer = str(find_tpch(raise_error_if_not_exists=(not use_gateway)) / "customer")

    return spark_session.read.parquet(location_customer, mergeSchema=False)


# pylint: disable=fixme
# ruff: noqa: T201
def execute_query(spark_session: SparkSession) -> None:
def execute_query(spark_session: SparkSession, use_gateway: bool) -> None:
"""Run a single sample query against the gateway."""
df_customer = get_customer_database(spark_session)
df_customer = get_customer_database(spark_session=spark_session, use_gateway=use_gateway)

df_customer.createOrReplaceTempView("customer")

Expand Down Expand Up @@ -76,7 +87,7 @@ def run_demo(
spark = SparkSession.builder.remote(f"sc://{host}:{port}/{uri_parameters}").getOrCreate()
else:
spark = SparkSession.builder.master("local").getOrCreate()
execute_query(spark)
execute_query(spark_session=spark, use_gateway=use_gateway)


@click.command()
Expand Down
128 changes: 128 additions & 0 deletions src/gateway/demo/generate_tpch_parquet_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# SPDX-License-Identifier: Apache-2.0
"""A utility module for generating TPC-H parquet data for the client demo."""

import logging
import os
import shutil
import sys
from pathlib import Path

import click
import duckdb

from .client_demo import CLIENT_DEMO_DATA_LOCATION

# Setup logging
# Root-logger configuration, done once at import time: timestamped lines to
# stdout, with the level taken from the LOG_LEVEL environment variable
# (defaults to INFO).
logging.basicConfig(
format="%(asctime)s - %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S %Z",
level=getattr(logging, os.getenv("LOG_LEVEL", "INFO")),
stream=sys.stdout,
)
# NOTE: this fetches the ROOT logger (no name argument), matching the
# basicConfig call above, rather than a module-named logger.
_LOGGER = logging.getLogger()


def execute_query(conn: duckdb.DuckDBPyConnection, query: str):
    """Run *query* on the given DuckDB connection, logging it first."""
    # Surface the exact SQL at INFO level before it runs, for traceability.
    message = f"Executing SQL: '{query}'"
    _LOGGER.info(msg=message)
    conn.execute(query=query)


def get_printable_number(num: float):
    """Render *num* compactly, keeping at most nine significant digits."""
    # format(x, spec) is equivalent to f"{x:spec}".
    return format(num, ".9g")


def generate_tpch_parquet_data(
    tpch_scale_factor: int, data_directory: str, overwrite: bool
) -> Path:
    """Generate a TPC-H parquet dataset.

    Args:
        tpch_scale_factor: Scale factor handed to DuckDB's dbgen().
        data_directory: Directory the parquet files are exported into.
        overwrite: Whether an existing target directory may be removed first.

    Returns:
        The directory the parquet data was written to.

    Raises:
        RuntimeError: If the target directory exists and overwrite is False.
    """
    _LOGGER.info(
        msg=(
            "Creating a TPC-H parquet dataset - with parameters: "
            f"--tpch-scale-factor={tpch_scale_factor} "
            f"--data-directory='{data_directory}' "
            f"--overwrite={overwrite}"
        )
    )

    # Record which DuckDB produced the data.
    _LOGGER.info(msg=f"Using DuckDB Version: {duckdb.__version__}")

    # An in-memory database suffices - everything is exported to parquet.
    conn = duckdb.connect()

    # The TPCH extension supplies the dbgen() generator.
    conn.load_extension(extension="tpch")

    execute_query(conn=conn, query=f"CALL dbgen(sf={tpch_scale_factor})")

    target_directory = Path(data_directory)

    # Guard clause: refuse to clobber an existing directory unless asked to.
    if target_directory.exists():
        if not overwrite:
            raise RuntimeError(f"Directory: {target_directory.as_posix()} exists, aborting.")
        _LOGGER.warning(msg=f"Directory: {target_directory.as_posix()} exists, removing...")
        shutil.rmtree(path=target_directory.as_posix())

    target_directory.mkdir(parents=True, exist_ok=True)
    execute_query(
        conn=conn, query=f"EXPORT DATABASE '{target_directory.as_posix()}' (FORMAT PARQUET)"
    )

    _LOGGER.info(msg=f"Wrote out parquet data to path: '{target_directory.as_posix()}'")

    # Move each exported file into a directory named after its table, since
    # the consumer reads a directory-per-table parquet layout.
    for parquet_file in target_directory.glob(pattern="*.parquet"):
        table_name = parquet_file.stem
        table_directory = target_directory / table_name
        table_directory.mkdir(parents=True, exist_ok=True)

        # nation/region keep their original name; every other table gets a
        # ".1" partition suffix.
        if parquet_file.name in ("nation.parquet", "region.parquet"):
            new_file_name = parquet_file.name
        else:
            new_file_name = f"{table_name}.1.parquet"

        parquet_file.rename(target=table_directory / new_file_name)

    _LOGGER.info(msg="All done.")

    return target_directory


@click.command()
@click.option(
    "--tpch-scale-factor",
    type=float,
    default=1,
    show_default=True,
    required=True,
    help="The TPC-H scale factor to generate.",
)
@click.option(
    "--data-directory",
    type=str,
    default=CLIENT_DEMO_DATA_LOCATION.as_posix(),
    show_default=True,
    required=True,
    help="The target output data directory to put the files into",
)
@click.option(
    "--overwrite/--no-overwrite",
    type=bool,
    default=False,
    show_default=True,
    required=True,
    help="Can we overwrite the target directory if it already exists...",
)
def click_generate_tpch_parquet_data(tpch_scale_factor: int, data_directory: str, overwrite: bool):
    """Provide a click interface for generating TPC-H parquet data."""
    # Pass the arguments explicitly instead of **locals(): locals() silently
    # includes any local variable defined before the call, so a future edit
    # could break the call with an unexpected-keyword error.
    generate_tpch_parquet_data(
        tpch_scale_factor=tpch_scale_factor,
        data_directory=data_directory,
        overwrite=overwrite,
    )


if __name__ == "__main__":
    click_generate_tpch_parquet_data()

0 comments on commit 7e482a5

Please sign in to comment.