Skip to content

Commit

Permalink
[embedded-elt][sling] passing translator and replication config from …
Browse files Browse the repository at this point in the history
…decorator using metadata (#20564)

## Summary & Motivation

The current implementation of the Sling integration requires passing the
_DagsterSlingTranslator_ and replication config to both the
`@sling_assets` decorator, and the `replicate` method in the function
body.

Following the approach of `dagster-dbt` and the soon-to-be _dlt_
integration. this pull request demonstrates how we can pass the
translator to the `replicate` method through the use of metadata. This
workaround leads to a more intuitive end-user experience, albeit with
more complex implementation.

The negative to this approach is that the _translator_ and
_replication-config_ remain on the metadata, see the modified unit test.

## How I Tested These Changes

Ran unit tests.
  • Loading branch information
cmpadden authored and PedramNavid committed Mar 28, 2024
1 parent 2fb2a6c commit f18b798
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 70 deletions.
18 changes: 3 additions & 15 deletions docs/content/integrations/embedded-elt.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,7 @@ Now you can define a Sling asset using the <PyObject module="dagster_embedded_el
Each stream will render two assets, one for the source stream and one for the target destination. You may override how assets are named by passing in a custom <PyObject module="dagster_embedded_elt.sling" object="DagsterSlingTranslator" /> object.

```python file=/integrations/embedded_elt/sling_dagster_translator.py
from dagster_embedded_elt import sling
from dagster_embedded_elt.sling import (
DagsterSlingTranslator,
SlingResource,
sling_assets,
)
Expand All @@ -137,10 +135,7 @@ sling_resource = SlingResource(connections=[...]) # Add connections here

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)
for row in sling.stream_raw_logs():
context.log.info(row)

Expand All @@ -167,7 +162,6 @@ This is an example of how to setup a Sling sync between two databases such as Po

```python file=/integrations/embedded_elt/postgres_snowflake.py
from dagster_embedded_elt.sling import (
DagsterSlingTranslator,
SlingConnectionResource,
SlingResource,
sling_assets,
Expand Down Expand Up @@ -215,10 +209,7 @@ replication_config = {

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)
```

## Example 2: File to Database
Expand Down Expand Up @@ -251,10 +242,7 @@ replication_config = {

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)
```

---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# pyright: reportOptionalMemberAccess=none

from dagster_embedded_elt.sling import (
DagsterSlingTranslator,
SlingConnectionResource,
SlingResource,
sling_assets,
Expand Down Expand Up @@ -50,7 +49,4 @@

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# pyright: reportOptionalMemberAccess=none

from dagster_embedded_elt.sling import (
DagsterSlingTranslator,
SlingConnectionResource,
SlingResource,
sling_assets,
Expand Down Expand Up @@ -47,10 +46,7 @@

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)


# end_storage_config
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from dagster_embedded_elt import sling
from dagster_embedded_elt.sling import (
DagsterSlingTranslator,
SlingResource,
sling_assets,
)
Expand All @@ -13,10 +11,7 @@

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)
for row in sling.stream_raw_logs():
context.log.info(row)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from dagster import Definitions, file_relative_path
from dagster_embedded_elt.sling import (
DagsterSlingTranslator,
sling_assets,
)
from dagster_embedded_elt.sling.resources import (
Expand Down Expand Up @@ -28,10 +27,7 @@

@sling_assets(replication_config=replication_config)
def my_assets(context, sling: SlingResource):
yield from sling.replicate(
replication_config=replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
)
yield from sling.replicate(context=context)
for row in sling.stream_raw_logs():
context.log.info(row)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ export const HIDDEN_METADATA_ENTRY_LABELS = new Set([
'dagster-dbt/exclude',
'dagster_dbt/manifest',
'dagster_dbt/dagster_dbt_translator',
'dagster_embedded_elt/dagster_sling_translator',
'dagster_embedded_elt/sling_replication_config',
]);

export const LogRowStructuredContentTable = ({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from dagster_embedded_elt.sling.dagster_sling_translator import DagsterSlingTranslator
from dagster_embedded_elt.sling.sling_replication import SlingReplicationParam, validate_replication

METADATA_KEY_TRANSLATOR = "dagster_embedded_elt/dagster_sling_translator"
METADATA_KEY_REPLICATION_CONFIG = "dagster_embedded_elt/sling_replication_config"


def get_streams_from_replication(
replication_config: Mapping[str, Any],
Expand Down Expand Up @@ -71,11 +74,7 @@ def sling_assets(
config_path = "/path/to/replication.yaml"
@sling_assets(replication_config=config_path)
def my_assets(context, sling: SlingResource):
for lines in sling.replicate(
replication_config=config_path,
dagster_sling_translator=DagsterSlingTranslator(),
):
context.log.info(lines)
yield from sling.replicate(context=context)
"""
replication_config = validate_replication(replication_config)
streams = get_streams_from_replication(replication_config)
Expand All @@ -88,7 +87,11 @@ def my_assets(context, sling: SlingResource):
key=dagster_sling_translator.get_asset_key(stream),
deps=dagster_sling_translator.get_deps_asset_key(stream),
description=dagster_sling_translator.get_description(stream),
metadata=dagster_sling_translator.get_metadata(stream),
metadata={ # type: ignore
**dagster_sling_translator.get_metadata(stream),
METADATA_KEY_TRANSLATOR: dagster_sling_translator,
METADATA_KEY_REPLICATION_CONFIG: replication_config,
},
group_name=dagster_sling_translator.get_group_name(stream),
freshness_policy=dagster_sling_translator.get_freshness_policy(stream),
auto_materialize_policy=dagster_sling_translator.get_auto_materialize_policy(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
import uuid
from enum import Enum
from subprocess import PIPE, STDOUT, Popen
from typing import IO, Any, AnyStr, Dict, Generator, Iterator, List, Optional
from typing import IO, Any, AnyStr, Dict, Generator, Iterator, List, Optional, Union

import sling
from dagster import (
AssetExecutionContext,
AssetMaterialization,
ConfigurableResource,
EnvVar,
MaterializeResult,
OpExecutionContext,
PermissiveConfig,
get_dagster_logger,
)
Expand All @@ -23,7 +26,11 @@
from dagster._utils.warnings import deprecation_warning
from pydantic import Field

from dagster_embedded_elt.sling.asset_decorator import get_streams_from_replication
from dagster_embedded_elt.sling.asset_decorator import (
METADATA_KEY_REPLICATION_CONFIG,
METADATA_KEY_TRANSLATOR,
get_streams_from_replication,
)
from dagster_embedded_elt.sling.dagster_sling_translator import DagsterSlingTranslator
from dagster_embedded_elt.sling.sling_replication import SlingReplicationParam, validate_replication

Expand Down Expand Up @@ -355,24 +362,35 @@ def sync(

yield from self._exec_sling_cmd(cmd, encoding=encoding)

@public
def replicate(
self,
*,
replication_config: SlingReplicationParam,
dagster_sling_translator: DagsterSlingTranslator,
context: Union[OpExecutionContext, AssetExecutionContext],
replication_config: Optional[SlingReplicationParam] = None,
dagster_sling_translator: Optional[DagsterSlingTranslator] = None,
debug: bool = False,
) -> Generator[MaterializeResult, None, None]:
) -> Generator[Union[MaterializeResult, AssetMaterialization], None, None]:
"""Runs a Sling replication from the given replication config.
Args:
context: Asset or Op execution context.
replication_config: The Sling replication config to use for the replication.
dagster_sling_translator: The translator to use for the replication.
debug: Whether to run the replication in debug mode.
Returns:
Optional[Generator[MaterializeResult, None, None]]: A generator of MaterializeResult
Generator[Union[MaterializeResult, AssetMaterialization], None, None]: A generator of MaterializeResult or AssetMaterialization
"""
# attempt to retrieve params from asset context if not passed as a parameter
if not (replication_config or dagster_sling_translator):
metadata_by_key = context.assets_def.metadata_by_key
first_asset_metadata = next(iter(metadata_by_key.values()))
dagster_sling_translator = first_asset_metadata.get(METADATA_KEY_TRANSLATOR)
replication_config = first_asset_metadata.get(METADATA_KEY_REPLICATION_CONFIG)

# if translator has not been defined on metadata _or_ through param, then use the default constructor
dagster_sling_translator = dagster_sling_translator or DagsterSlingTranslator()

replication_config = validate_replication(replication_config)
stream_definition = get_streams_from_replication(replication_config)

Expand Down Expand Up @@ -408,11 +426,18 @@ def replicate(

end_time = time.time()

has_asset_def: bool = bool(context and context.has_assets_def)

for stream in stream_definition:
output_name = dagster_sling_translator.get_asset_key(stream)
yield MaterializeResult(
asset_key=output_name, metadata={"elapsed_time": end_time - start_time}
)
if has_asset_def:
yield MaterializeResult(
asset_key=output_name, metadata={"elapsed_time": end_time - start_time}
)
else:
yield AssetMaterialization(
asset_key=output_name, metadata={"elapsed_time": end_time - start_time}
)

def stream_raw_logs(self) -> Generator[str, None, None]:
"""Returns a generator of raw logs from the Sling CLI."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import lru_cache
from pathlib import Path
from typing import Any, Mapping, Union, cast
from typing import Any, Mapping, Optional, Union, cast

import dagster._check as check
import yaml
Expand All @@ -18,7 +18,8 @@ def read_replication_path(replication_path: Path) -> Mapping[str, Any]:
return cast(Mapping[str, Any], yaml.safe_load(replication_path.read_bytes()))


def validate_replication(replication: SlingReplicationParam) -> Mapping[str, Any]:
def validate_replication(replication: Optional[SlingReplicationParam]) -> Mapping[str, Any]:
replication = replication or {}
check.inst_param(replication, "manifest", (Path, str, dict))

if isinstance(replication, str):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging
import sqlite3
from pathlib import Path

import pytest
import yaml
from dagster import (
AssetExecutionContext,
AssetKey,
FreshnessPolicy,
JsonMetadataValue,
Expand Down Expand Up @@ -59,17 +61,12 @@ def my_sling_assets(): ...

def test_runs_base_sling_config(
csv_to_sqlite_replication_config: SlingReplicationParam,
path_to_test_csv: str,
path_to_temp_sqlite_db: str,
sqlite_connection: sqlite3.Connection,
):
@sling_assets(replication_config=csv_to_sqlite_replication_config)
def my_sling_assets(sling: SlingResource):
for row in sling.replicate(
replication_config=csv_to_sqlite_replication_config,
dagster_sling_translator=DagsterSlingTranslator(),
):
logging.info(row)
def my_sling_assets(context: AssetExecutionContext, sling: SlingResource):
yield from sling.replicate(context=context)

sling_resource = SlingResource(
connections=[
Expand All @@ -82,7 +79,10 @@ def my_sling_assets(sling: SlingResource):
]
)
res = materialize([my_sling_assets], resources={"sling": sling_resource})

assert res.success
assert len(res.get_asset_materialization_events()) == 1

counts = sqlite_connection.execute("SELECT count(1) FROM main.tbl").fetchone()[0]
assert counts == 3

Expand All @@ -105,11 +105,12 @@ def my_third_sling_assets(): ...


def test_base_with_meta_config_translator():
@sling_assets(
replication_config=file_relative_path(
__file__, "replication_configs/base_with_meta_config/replication.yaml"
)
replication_config_path = file_relative_path(
__file__, "replication_configs/base_with_meta_config/replication.yaml"
)
replication_config = yaml.safe_load(Path(replication_config_path).read_bytes())

@sling_assets(replication_config=replication_config_path)
def my_sling_assets(): ...

assert my_sling_assets.keys == {
Expand All @@ -134,7 +135,13 @@ def my_sling_assets(): ...
}

assert my_sling_assets.metadata_by_key == {
AssetKey(["target", "public", "accounts"]): {"stream_config": JsonMetadataValue(data=None)},
AssetKey(["target", "public", "accounts"]): {
"stream_config": JsonMetadataValue(data=None),
"dagster_embedded_elt/dagster_sling_translator": DagsterSlingTranslator(
target_prefix="target"
),
"dagster_embedded_elt/sling_replication_config": replication_config,
},
AssetKey(["target", "departments"]): {
"stream_config": JsonMetadataValue(
data={
Expand All @@ -152,7 +159,11 @@ def my_sling_assets(): ...
}
},
}
)
),
"dagster_embedded_elt/dagster_sling_translator": DagsterSlingTranslator(
target_prefix="target"
),
"dagster_embedded_elt/sling_replication_config": replication_config,
},
AssetKey(["target", "public", "transactions"]): {
"stream_config": JsonMetadataValue(
Expand All @@ -167,15 +178,23 @@ def my_sling_assets(): ...
}
},
}
)
),
"dagster_embedded_elt/dagster_sling_translator": DagsterSlingTranslator(
target_prefix="target"
),
"dagster_embedded_elt/sling_replication_config": replication_config,
},
AssetKey(["target", "public", "all_users"]): {
"stream_config": JsonMetadataValue(
data={
"sql": 'select all_user_id, name \nfrom public."all_Users"\n',
"object": "public.all_users",
}
)
),
"dagster_embedded_elt/dagster_sling_translator": DagsterSlingTranslator(
target_prefix="target"
),
"dagster_embedded_elt/sling_replication_config": replication_config,
},
}

Expand Down
Loading

0 comments on commit f18b798

Please sign in to comment.