Skip to content

Commit

Permalink
add async batch sink mode
Browse files Browse the repository at this point in the history
  • Loading branch information
hsheth2 committed Jun 18, 2024
1 parent e180313 commit d44be98
Show file tree
Hide file tree
Showing 3 changed files with 304 additions and 21 deletions.
12 changes: 5 additions & 7 deletions metadata-ingestion/src/datahub/ingestion/graph/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,17 @@ def _post_generic(self, url: str, payload_dict: Dict) -> Dict:
def _make_rest_sink_config(self) -> "DatahubRestSinkConfig":
from datahub.ingestion.sink.datahub_rest import (
DatahubRestSinkConfig,
SyncOrAsync,
RestSinkMode,
)

# This is a bit convoluted - this DataHubGraph class is a subclass of DatahubRestEmitter,
# but initializing the rest sink creates another rest emitter.
# TODO: We should refactor out the multithreading functionality of the sink
# into a separate class that can be used by both the sink and the graph client
# e.g. a DatahubBulkRestEmitter that both the sink and the graph client use.
return DatahubRestSinkConfig(**self.config.dict(), mode=SyncOrAsync.ASYNC)
return DatahubRestSinkConfig(
**self.config.dict(), mode=RestSinkMode.ASYNC_BATCH
)

@contextlib.contextmanager
def make_rest_sink(
Expand Down Expand Up @@ -252,14 +254,10 @@ def emit_all(
) -> None:
"""Emit all items in the iterable using multiple threads."""

# The context manager also ensures that we raise an error if a failure occurs.
with self.make_rest_sink(run_id=run_id) as sink:
for item in items:
sink.emit_async(item)
if sink.report.failures:
raise OperationalError(
f"Failed to emit {len(sink.report.failures)} records",
info=sink.report.as_obj(),
)

def get_aspect(
self,
Expand Down
78 changes: 67 additions & 11 deletions metadata-ingestion/src/datahub/ingestion/sink/datahub_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import threading
import uuid
from enum import auto
from typing import Optional, Union
from typing import List, Optional, Tuple, Union

from datahub.cli.cli_utils import set_env_variables_override_config
from datahub.configuration.common import (
Expand All @@ -16,6 +16,7 @@
OperationalError,
)
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import mcps_from_mce
from datahub.emitter.rest_emitter import DataHubRestEmitter
from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
from datahub.ingestion.api.sink import (
Expand All @@ -30,7 +31,10 @@
MetadataChangeEvent,
MetadataChangeProposal,
)
from datahub.utilities.advanced_thread_executor import PartitionExecutor
from datahub.utilities.advanced_thread_executor import (
BatchPartitionExecutor,
PartitionExecutor,
)
from datahub.utilities.perf_timer import PerfTimer
from datahub.utilities.server_config_util import set_gms_config

Expand All @@ -41,15 +45,20 @@
)


class SyncOrAsync(ConfigEnum):
class RestSinkMode(ConfigEnum):
SYNC = auto()
ASYNC = auto()

# Uses the new ingestProposalBatch endpoint. Significantly more efficient than the other modes,
# but requires a server version that supports it.
# https://github.com/datahub-project/datahub/pull/10706
ASYNC_BATCH = auto()


class DatahubRestSinkConfig(DatahubClientConfig):
mode: SyncOrAsync = SyncOrAsync.ASYNC
mode: RestSinkMode = RestSinkMode.ASYNC

# These only apply in async mode.
# These only apply in async modes.
max_threads: int = DEFAULT_REST_SINK_MAX_THREADS
max_pending_requests: int = 2000

Expand Down Expand Up @@ -111,10 +120,19 @@ def __post_init__(self) -> None:
set_env_variables_override_config(self.config.server, self.config.token)
logger.debug("Setting gms config")
set_gms_config(gms_config)
self.executor = PartitionExecutor(
max_workers=self.config.max_threads,
max_pending=self.config.max_pending_requests,
)

if self.config.mode == RestSinkMode.ASYNC_BATCH:
self.executor = BatchPartitionExecutor(
max_workers=self.config.max_threads,
max_pending=self.config.max_pending_requests,
process_batch=self._emit_batch_wrapper,
# TODO: make other things configurable
)
else:
self.executor = PartitionExecutor(
max_workers=self.config.max_threads,
max_pending=self.config.max_pending_requests,
)

@classmethod
def _make_emitter(cls, config: DatahubRestSinkConfig) -> DataHubRestEmitter:
Expand Down Expand Up @@ -189,6 +207,7 @@ def _write_done_callback(
self.report.report_warning({"warning": e.message, "info": e.info})
write_callback.on_failure(record_envelope, e, e.info)
else:
logger.exception(f"Failure: {e}", exc_info=e)
self.report.report_failure({"e": e})
write_callback.on_failure(record_envelope, Exception(e), {})

Expand All @@ -203,6 +222,30 @@ def _emit_wrapper(
# TODO: Add timing metrics
self.emitter.emit(record)

def _emit_batch_wrapper(
self,
records: List[
Tuple[
Union[
MetadataChangeEvent,
MetadataChangeProposal,
MetadataChangeProposalWrapper,
],
]
],
) -> None:
events = []
for record in records:
event = record[0]
if isinstance(event, MetadataChangeEvent):
# Unpack MCEs into MCPs.
mcps = mcps_from_mce(event)
events.extend(mcps)
else:
events.append(event)

self.emitter.emit_mcps(events)

def write_record_async(
self,
record_envelope: RecordEnvelope[
Expand All @@ -218,7 +261,8 @@ def write_record_async(
# should only have a high value if the sink is actually a bottleneck.
with self.report.main_thread_blocking_timer:
record = record_envelope.record
if self.config.mode == SyncOrAsync.ASYNC:
if self.config.mode == RestSinkMode.ASYNC:
assert isinstance(self.executor, PartitionExecutor)
partition_key = _get_partition_key(record_envelope)
self.executor.submit(
partition_key,
Expand All @@ -229,6 +273,17 @@ def write_record_async(
),
)
self.report.pending_requests += 1
elif self.config.mode == RestSinkMode.ASYNC_BATCH:
assert isinstance(self.executor, BatchPartitionExecutor)
partition_key = _get_partition_key(record_envelope)
self.executor.submit(
partition_key,
record,
done_callback=functools.partial(
self._write_done_callback, record_envelope, write_callback
),
)
self.report.pending_requests += 1
else:
# execute synchronously
try:
Expand All @@ -249,7 +304,8 @@ def emit_async(
)

def close(self):
self.executor.shutdown()
with self.report.main_thread_blocking_timer:
self.executor.shutdown()

def __repr__(self) -> str:
return self.emitter.__repr__()
Expand Down
Loading

0 comments on commit d44be98

Please sign in to comment.