Feat: Multiple perf improvements and viztracer performance profiling #298

Merged · 18 commits · Jul 17, 2024
Changes from 11 commits
8 changes: 8 additions & 0 deletions .viztracerrc
@@ -0,0 +1,8 @@
; # Default config settings for viztracer
; # https://viztracer.readthedocs.io/en/latest/basic_usage.html#configuration-file

[default]
max_stack_depth = 20
unique_output_file = True
output_file = viztracer_report.json
tracer_entries = 5_000_000
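
For reference, the same knobs can also be set programmatically when profiling a single block of code rather than a whole script. A minimal sketch, assuming the standard `viztracer` Python API (`VizTracer` accepts `max_stack_depth`, `tracer_entries`, and `output_file`); `run_workload()` is a hypothetical stand-in for the code under test:

```
from viztracer import VizTracer

# Mirrors the config values above; unique_output_file is omitted here because it
# is a CLI/config-level option rather than a constructor argument (assumption).
with VizTracer(
    max_stack_depth=20,
    tracer_entries=5_000_000,
    output_file="viztracer_report.json",
):
    run_workload()  # hypothetical: replace with the code to profile
```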
14 changes: 13 additions & 1 deletion airbyte/_future_cdk/sql_processor.py
@@ -161,6 +161,8 @@ def __init__(
)
self.type_converter = self.type_converter_class()
self._cached_table_definitions: dict[str, sqlalchemy.Table] = {}

self._known_schemas_list: list[str] = []
self._ensure_schema_exists()

# Public interface:
@@ -305,6 +307,10 @@ def _ensure_schema_exists(
) -> None:
"""Return a new (unique) temporary table name."""
schema_name = self.sql_config.schema_name

if self._known_schemas_list and self.sql_config.schema_name in self._known_schemas_list:
return # Already exists

if schema_name in self._get_schemas_list():
return

@@ -372,17 +378,23 @@ def _get_tables_list(
def _get_schemas_list(
self,
database_name: str | None = None,
*,
force_refresh: bool = False,
) -> list[str]:
"""Return a list of all tables in the database."""
if not force_refresh and self._known_schemas_list:
return self._known_schemas_list

inspector: Inspector = sqlalchemy.inspect(self.get_sql_engine())
database_name = database_name or self.database_name
found_schemas = inspector.get_schema_names()
return [
self._known_schemas_list = [
found_schema.split(".")[-1].strip('"')
for found_schema in found_schemas
if "." not in found_schema
or (found_schema.split(".")[0].lower().strip('"') == database_name.lower())
]
return self._known_schemas_list

def _ensure_final_table_exists(
self,
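The intent of this change is to avoid a SQLAlchemy `inspect()` round trip on every `_ensure_schema_exists()` call by caching the schema listing. A minimal, standalone sketch of the same memoize-with-`force_refresh` pattern (the class and callables below are hypothetical stand-ins, not PyAirbyte APIs):

```
class SchemaCache:
    """Cache a database schema listing, refreshing only on demand."""

    def __init__(self, fetch_schemas) -> None:
        self._fetch_schemas = fetch_schemas  # callable that queries the database
        self._known_schemas: list[str] = []

    def get_schemas(self, *, force_refresh: bool = False) -> list[str]:
        # Reuse the cached listing unless a refresh is explicitly requested.
        if not force_refresh and self._known_schemas:
            return self._known_schemas
        self._known_schemas = self._fetch_schemas()
        return self._known_schemas

    def ensure_schema(self, name: str, create_schema) -> None:
        # Fast path: no database call if the schema has already been seen.
        if name in self.get_schemas():
            return
        create_schema(name)
        self.get_schemas(force_refresh=True)
```

The trade-off is staleness: schemas created outside the current process are only picked up after a `force_refresh=True` call.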
2 changes: 1 addition & 1 deletion airbyte/_processors/file/base.py
@@ -23,7 +23,7 @@
)


DEFAULT_BATCH_SIZE = 10000
DEFAULT_BATCH_SIZE = 100_000


class FileWriterBase(abc.ABC):
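For context, this constant governs how many records are buffered before a batch file is flushed, so a larger value means fewer, larger files at the cost of more memory held per batch. A rough sketch of that trade-off using a hypothetical batching helper (not the actual `FileWriterBase` internals):

```
from collections.abc import Iterable, Iterator

DEFAULT_BATCH_SIZE = 100_000


def batch_records(
    records: Iterable[dict],
    batch_size: int = DEFAULT_BATCH_SIZE,
) -> Iterator[list[dict]]:
    """Yield lists of up to `batch_size` records, each destined for one file."""
    batch: list[dict] = []
    for record in records:
        batch.append(record)
        if len(batch) >= batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # flush the final partial batch
```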
10 changes: 9 additions & 1 deletion airbyte/_processors/sql/snowflake.py
@@ -3,6 +3,7 @@

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor
from textwrap import dedent, indent
from typing import TYPE_CHECKING

@@ -26,6 +27,9 @@
from sqlalchemy.engine import Connection


MAX_UPLOAD_THREADS = 8


class SnowflakeConfig(SqlConfig):
"""Configuration for the Snowflake cache."""

@@ -120,10 +124,14 @@ def _write_files_to_new_table(
def path_str(path: Path) -> str:
return str(path.absolute()).replace("\\", "\\\\")

for file_path in files:
def upload_file(file_path: str) -> None:
query = f"PUT 'file://{path_str(file_path)}' {internal_sf_stage_name};"
self._execute_sql(query)

# Upload files in parallel
with ThreadPoolExecutor(max_workers=MAX_UPLOAD_THREADS) as executor:
executor.map(upload_file, files)

columns_list = [
self._quote_identifier(c)
for c in list(self._get_sql_column_definitions(stream_name).keys())
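A standalone sketch of the same fan-out pattern, with one caveat worth noting: `executor.map()` returns a lazy iterator, so exceptions raised inside workers only surface when its results are consumed. Submitting futures and calling `result()` (or wrapping the map in `list(...)`) makes a failed `PUT` raise instead of passing silently. The `execute_sql` callable below is a hypothetical stand-in for the processor's SQL execution method:

```
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path

MAX_UPLOAD_THREADS = 8


def upload_files_in_parallel(files: list[Path], stage_name: str, execute_sql) -> None:
    """Run one PUT statement per file, at most MAX_UPLOAD_THREADS at a time."""

    def upload_file(file_path: Path) -> None:
        path_str = str(file_path.absolute()).replace("\\", "\\\\")
        execute_sql(f"PUT 'file://{path_str}' {stage_name};")

    with ThreadPoolExecutor(max_workers=MAX_UPLOAD_THREADS) as executor:
        futures = [executor.submit(upload_file, f) for f in files]
        for future in as_completed(futures):
            future.result()  # re-raise any upload error instead of swallowing it
```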
2 changes: 2 additions & 0 deletions airbyte/_util/name_normalizers.py
@@ -4,6 +4,7 @@
from __future__ import annotations

import abc
import functools
import re
from typing import TYPE_CHECKING

@@ -48,6 +49,7 @@ class LowerCaseNormalizer(NameNormalizerBase):
"""A name normalizer that converts names to lower case."""

@staticmethod
@functools.cache
def normalize(name: str) -> str:
"""Return the normalized name.

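Because `normalize()` is a pure function of the input string, `functools.cache` turns repeated normalization of the same stream and column names into dictionary lookups. A minimal sketch of the effect (the regex is illustrative, not the exact PyAirbyte normalizer):

```
import functools
import re


@functools.cache
def normalize(name: str) -> str:
    """Lower-case the name and collapse non-alphanumeric runs to underscores."""
    return re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")


normalize("Customer ID")  # computed once
normalize("Customer ID")  # served from the cache on every later call
```

The cache is keyed on the raw string and never evicted, which is acceptable here because the set of distinct names in a sync is small.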
19 changes: 11 additions & 8 deletions airbyte/progress.py
@@ -85,20 +85,20 @@ def _to_time_str(timestamp: float) -> str:
return datetime_obj.strftime("%H:%M:%S")


def _get_elapsed_time_str(seconds: int) -> str:
def _get_elapsed_time_str(seconds: float) -> str:
"""Return duration as a string.

Seconds are included until 10 minutes is exceeded.
Minutes are always included after 1 minute elapsed.
Hours are always included after 1 hour elapsed.
"""
if seconds <= 60: # noqa: PLR2004 # Magic numbers OK here.
return f"{seconds} seconds"
return f"{seconds:.0f} seconds"

if seconds < 60 * 10:
minutes = seconds // 60
seconds %= 60
return f"{minutes}min {seconds}s"
return f"{minutes}min {seconds:.0f}s"

if seconds < 60 * 60:
minutes = seconds // 60
@@ -280,12 +280,12 @@ def elapsed_seconds_since_last_update(self) -> float | None:
return time.time() - self.last_update_time

@property
def elapsed_read_seconds(self) -> int:
def elapsed_read_seconds(self) -> float:
"""Return the number of seconds elapsed since the read operation started."""
if self.read_end_time is None:
return int(time.time() - self.read_start_time)
return time.time() - self.read_start_time

return int(self.read_end_time - self.read_start_time)
return self.read_end_time - self.read_start_time

@property
def elapsed_read_time_string(self) -> str:
@@ -366,7 +366,7 @@ def update_display(self, *, force_refresh: bool = False) -> None:
if (
not force_refresh
and self.last_update_time # if not set, then we definitely need to update
and cast(float, self.elapsed_seconds_since_last_update) < 0.5 # noqa: PLR2004
and cast(float, self.elapsed_seconds_since_last_update) < 0.8 # noqa: PLR2004
):
return

@@ -396,7 +396,10 @@ def _get_status_message(self) -> str:
start_time_str = _to_time_str(self.read_start_time)
records_per_second: float = 0.0
if self.elapsed_read_seconds > 0:
records_per_second = round(self.total_records_read / self.elapsed_read_seconds, 1)
records_per_second = round(
float(self.total_records_read) / self.elapsed_read_seconds,
ndigits=1,
)
status_message = (
f"## Read Progress\n\n"
f"Started reading at {start_time_str}.\n\n"
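The switch from `int` to `float` seconds keeps sub-second precision until formatting time, and `:.0f` still renders whole seconds for display. A small sketch of the resulting behavior, assuming the same thresholds as the function above (the hours branch is simplified):

```
def elapsed_str(seconds: float) -> str:
    if seconds <= 60:
        return f"{seconds:.0f} seconds"
    if seconds < 60 * 10:
        minutes = seconds // 60
        seconds %= 60
        return f"{minutes:.0f}min {seconds:.0f}s"
    return f"{seconds / 3600:.2f} hours"  # simplified tail case for this sketch


print(elapsed_str(59.6))   # "60 seconds" (rounded rather than truncated to 59)
print(elapsed_str(125.0))  # "2min 5s"
```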
2 changes: 1 addition & 1 deletion examples/run_bigquery_faker.py
@@ -49,7 +49,7 @@ def main() -> None:

cache = BigQueryCache(
project_name=bigquery_destination_secret["project_id"],
dataset_name=bigquery_destination_secret["dataset_id"],
dataset_name=bigquery_destination_secret.get("dataset_id", "pyairbyte_integtest"),
credentials_path=temp.name,
)

51 changes: 51 additions & 0 deletions examples/run_perf_test_cache_snowflake.py
@@ -0,0 +1,51 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""
Simple script to get performance profile of a Snowflake cache.

For performance profiling:
```
poetry run viztracer --open -- ./examples/run_perf_test_cache_snowflake.py
```
"""

from __future__ import annotations

import airbyte as ab
from airbyte.caches import SnowflakeCache
from airbyte.secrets.google_gsm import GoogleGSMSecretManager


AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing"
secret_mgr = GoogleGSMSecretManager(
project=AIRBYTE_INTERNAL_GCP_PROJECT,
credentials_json=ab.get_secret("GCP_GSM_CREDENTIALS"),
)

secret = secret_mgr.get_secret(
secret_name="AIRBYTE_LIB_SNOWFLAKE_CREDS",
)
assert secret is not None, "Secret not found."
secret_config = secret.parse_json()


cache = SnowflakeCache(
account=secret_config["account"],
username=secret_config["username"],
password=secret_config["password"],
database=secret_config["database"],
warehouse=secret_config["warehouse"],
role=secret_config["role"],
)

source = ab.get_source(
"source-pokeapi",
config={"pokemon_name": "bulbasaur"},
source_manifest=True,
)
source.check()

source.select_streams(["pokemon"])
result = source.read(cache)

for name in ["pokemon"]:
print(f"Stream {name}: {len(list(result[name]))} records")
119 changes: 119 additions & 0 deletions examples/run_perf_test_reads.py
@@ -0,0 +1,119 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
"""
Simple script to get performance profile of read throughput.

This script accepts a single argument `-e=SCALE` as a power of 10.

-e=4 is equivalent to 50_000 records.
-e=5 is equivalent to 500_000 records.
-e=6 is equivalent to 5_000_000 records.

Use smaller values of `e` (2-3) to understand read and overhead costs.
Use larger values of `e` (4-5) to understand write throughput at scale.

For performance profiling:
```
poetry run viztracer --open -- ./examples/run_perf_test_reads.py -e=3
poetry run viztracer --open -- ./examples/run_perf_test_reads.py -e=5
```
"""

from __future__ import annotations

import argparse
import tempfile

import airbyte as ab
from airbyte.caches import BigQueryCache, CacheBase, SnowflakeCache
from airbyte.secrets.google_gsm import GoogleGSMSecretManager


AIRBYTE_INTERNAL_GCP_PROJECT = "dataline-integration-testing"


def get_gsm_secret_json(secret_name: str) -> dict:
secret_mgr = GoogleGSMSecretManager(
project=AIRBYTE_INTERNAL_GCP_PROJECT,
credentials_json=ab.get_secret("GCP_GSM_CREDENTIALS"),
)
secret = secret_mgr.get_secret(
secret_name=secret_name,
)
assert secret is not None, "Secret not found."
return secret.parse_json()


def main(
e: int = 4,
cache_type: str = "duckdb",
) -> None:
e = e or 4
cache_type = cache_type or "duckdb"

cache: CacheBase
if cache_type == "duckdb":
cache = ab.new_local_cache()

elif cache_type == "snowflake":
secret_config = get_gsm_secret_json(
secret_name="AIRBYTE_LIB_SNOWFLAKE_CREDS",
)
cache = SnowflakeCache(
account=secret_config["account"],
username=secret_config["username"],
password=secret_config["password"],
database=secret_config["database"],
warehouse=secret_config["warehouse"],
role=secret_config["role"],
)

elif cache_type == "bigquery":
temp = tempfile.NamedTemporaryFile(mode="w+", delete=False, encoding="utf-8")
secret_config = get_gsm_secret_json(
secret_name="SECRET_DESTINATION-BIGQUERY_CREDENTIALS__CREDS",
)
try:
# Write credentials to the temp file
temp.write(secret_config["credentials_json"])
temp.flush()
finally:
temp.close()

cache = BigQueryCache(
project_name=secret_config["project_id"],
dataset_name=secret_config.get("dataset_id", "pyairbyte_integtest"),
credentials_path=temp.name,
)

source = ab.get_source(
"source-faker",
config={"count": 5 * (10**e)},
install_if_missing=False,
streams=["purchases"],
)
source.check()

source.read(cache)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run performance test reads.")
parser.add_argument(
"-e",
type=int,
help=(
"The scale, as a power of 10. "
"Recommended values: 2-3 (500 or 5_000) for read and overhead costs, "
" 4-6 (50K or 5MM) for write throughput."
),
)
parser.add_argument(
"--cache",
type=str,
help="The cache type to use.",
choices=["duckdb", "snowflake", "bigquery"],
default="duckdb",
)
args = parser.parse_args()

main(e=args.e, cache_type=args.cache)