Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle partitions natively in W&B IO Manager #15170

Merged
merged 11 commits into from
Aug 9, 2023
62 changes: 55 additions & 7 deletions examples/with_wandb/with_wandb/assets/advanced_example.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import wandb
from dagster import AssetExecutionContext, AssetIn, asset
from dagster_wandb import WandbArtifactConfiguration

import wandb
from wandb import Artifact

wandb_artifact_configuration: WandbArtifactConfiguration = {
"description": "My **Markdown** description",
"aliases": ["first_alias", "second_alias"],
Expand Down Expand Up @@ -32,7 +34,7 @@
compute_kind="wandb",
metadata={"wandb_artifact_configuration": wandb_artifact_configuration},
)
def write_advanced_artifact() -> wandb.wandb_sdk.wandb_artifacts.Artifact:
def write_advanced_artifact() -> Artifact:
"""Example that writes an advanced Artifact.

Here we use the full power of the integration with W&B Artifacts.
Expand All @@ -57,7 +59,7 @@ def write_advanced_artifact() -> wandb.wandb_sdk.wandb_artifacts.Artifact:
- https://docs.wandb.ai/ref/python/artifact#add_reference

Returns:
wandb.Artifact: The Artifact we augment with the integration
Artifact: The Artifact we augment with the integration
"""
artifact = wandb.Artifact(MY_ASSET, "files")
table = wandb.Table(columns=["a", "b", "c"], data=[[1, 2, 3]])
Expand Down Expand Up @@ -132,18 +134,64 @@ def get_path(context: AssetExecutionContext, path: str) -> None:
},
output_required=False,
)
def get_artifact(
context: AssetExecutionContext, artifact: wandb.wandb_sdk.wandb_artifacts.Artifact
) -> None:
def get_artifact(context: AssetExecutionContext, artifact: Artifact) -> None:
"""Example that gets the entire Artifact object.

Args:
context (AssetExecutionContext): Dagster execution context
artifact (wandb.wandb_sdk.wandb_artifacts.Artifact): Downloaded Artifact object
artifact (Artifact): Downloaded Artifact object

Here, we use the integration to collect the entire W&B Artifact object created by the first
asset.

The integration downloads the entire Artifact for us.
"""
context.log.info(f"Result: {artifact.name}") # Result: my_advanced_artifact:v0


@asset(
    compute_kind="wandb",
    ins={
        "artifact": AssetIn(
            asset_key=MY_ASSET,
            metadata={"wandb_artifact_configuration": {"version": "v0"}},
        )
    },
    output_required=False,
)
def get_version(context: AssetExecutionContext, artifact: Artifact) -> None:
    """Example that collects the whole Artifact object, selected by version.

    The input metadata requests version ``v0`` of the upstream asset.

    Args:
        context (AssetExecutionContext): Dagster execution context
        artifact (Artifact): Downloaded Artifact object
    """
    message = f"Result: {artifact.name}"  # Result: my_advanced_artifact:v0
    context.log.info(message)


@asset(
    compute_kind="wandb",
    ins={
        "artifact": AssetIn(
            asset_key=MY_ASSET,
            metadata={"wandb_artifact_configuration": {"alias": "first_alias"}},
        )
    },
    output_required=False,
)
def get_alias(context: AssetExecutionContext, artifact: Artifact) -> None:
    """Example that collects the whole Artifact object, selected by alias.

    The input metadata requests the ``first_alias`` alias of the upstream asset.

    Args:
        context (AssetExecutionContext): Dagster execution context
        artifact (Artifact): Downloaded Artifact object
    """
    message = f"Result: {artifact.name}"  # Result: my_advanced_artifact:first_alias
    context.log.info(message)
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dagster import AssetIn, StaticPartitionsDefinition, asset

import wandb

partitions_def = StaticPartitionsDefinition(["red", "orange", "yellow", "blue", "green"])

ARTIFACT_NAME = "my_advanced_configuration_partitioned_asset"


@asset(
    group_name="partitions",
    partitions_def=partitions_def,
    name=ARTIFACT_NAME,
    compute_kind="wandb",
    metadata={"wandb_artifact_configuration": {"aliases": ["special_alias"]}},
)
def write_advanced_artifact(context):
    """Example writing an Artifact with partitions and custom metadata.

    What is returned depends on the partition being materialized:

    - "red": the plain string "red"
    - "orange": a raw W&B Table
    - "yellow": an Artifact holding a table stored as "custom_table_name"
    - anything else: an Artifact holding a table stored as "default_table_name"
    """
    artifact = wandb.Artifact(ARTIFACT_NAME, "dataset")
    key = context.asset_partition_key_for_output()

    # These two partitions do not use the Artifact created above.
    if key == "red":
        return "red"
    if key == "orange":
        return wandb.Table(columns=["color"], data=[["orange"]])

    # Remaining partitions differ only in the table contents and its name.
    if key == "yellow":
        columns, rows, table_name = ["color"], [["yellow"]], "custom_table_name"
    else:
        columns, rows, table_name = ["color", "value"], [[key, 1]], "default_table_name"

    artifact.add(wandb.Table(columns=columns, data=rows), table_name)
    return artifact


@asset(
group_name="partitions",
compute_kind="wandb",
ins={
"partitions": AssetIn(
asset_key=ARTIFACT_NAME,
metadata={
"wandb_artifact_configuration": {
"partitions": {
# The wildcard "*" means "all non-configured partitions"
"*": {
"get": "default_table_name",
},
# You can override the wildcard for specific partition using their key
"yellow": {
"get": "custom_table_name",
},
# You can collect a specific Artifact version
"orange": {
"version": "v0",
},
# You can collect a specific alias, note you must specify the 'get' value.
# This is because the wildcard is only applied to partitions that haven't
# been configured.
"blue": {
"alias": "special_alias",
"get": "default_table_name",
},
},
},
},
)
},
output_required=False,
)
def read_objects_directly(context, partitions):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trying to wrap my head around here. where does the partitions come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the AssetIn

"""Example reading all Artifact partitions from the previous asset."""
for partition, content in partitions.items():
context.log.info(f"partition={partition}, type={type(content)}")
if partition == "red":
context.log.info(content)
elif partition == "orange":
# The orange partition was a raw W&B Table, the IO Manager wrapped that Table in an
# Artifact. The default name for the table is 'Table'. We could have also set
# the partition 'get' config to receive the table directly instead of the Artifact.
context.log.info(content.get("Table").get_column("color"))
else:
context.log.info(content.get_column("color"))
3 changes: 2 additions & 1 deletion examples/with_wandb/with_wandb/assets/example/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import wandb
from fashion_data import fashion
from torch.autograd import Variable

import wandb

hyperparameter_defaults = dict(
dropout=0.5,
channels_one=16,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import wandb
from dagster import AssetIn, Config, asset

import wandb
from wandb import Artifact

MODEL_NAME = "my_model"


@asset(
name=MODEL_NAME,
compute_kind="wandb",
)
def write_model() -> wandb.wandb_sdk.wandb_artifacts.Artifact:
def write_model() -> Artifact:
"""Write your model.

Here, we're creating a very simple Artifact with the integration.
Expand All @@ -32,7 +34,7 @@ class PromoteBestModelToProductionConfig(Config):
output_required=False,
)
def promote_best_model_to_production(
artifact: wandb.wandb_sdk.wandb_artifacts.Artifact,
artifact: Artifact,
config: PromoteBestModelToProductionConfig,
):
"""Example that links a model stored in a W&B Artifact to the Model Registry.
Expand Down
3 changes: 2 additions & 1 deletion examples/with_wandb/with_wandb/assets/multi_asset_example.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Tuple

import wandb
from dagster import AssetOut, multi_asset

import wandb


@multi_asset(
name="write_multiple_artifacts",
Expand Down
62 changes: 62 additions & 0 deletions examples/with_wandb/with_wandb/assets/multi_partitions_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dagster import (
AssetIn,
DailyPartitionsDefinition,
MultiPartitionsDefinition,
StaticPartitionsDefinition,
asset,
)

import wandb

partitions_def = MultiPartitionsDefinition(
{
"date": DailyPartitionsDefinition(start_date="2023-01-01", end_date="2023-01-05"),
"color": StaticPartitionsDefinition(["red", "yellow", "blue"]),
}
)


@asset(
    group_name="partitions",
    partitions_def=partitions_def,
    name="my_multi_partitioned_asset",
    compute_kind="wandb",
    metadata={
        "wandb_artifact_configuration": {
            "type": "dataset",
        }
    },
)
def create_my_multi_partitioned_asset(context):
    """Example writing an Artifact with multi partitions and custom metadata.

    Args:
        context: Dagster execution context; used to read the partition key being
            materialized and to log.

    Returns:
        For the "red|2023-01-02" partition, a W&B Artifact holding a single table stored
        under "custom_table_name". For every other partition, the partition key string.
    """
    partition_key = context.asset_partition_key_for_output()
    context.log.info(f"Creating partitioned asset for {partition_key}")
    if partition_key == "red|2023-01-02":
        artifact = wandb.Artifact("my_multi_partitioned_asset", "dataset")
        table = wandb.Table(columns=["color"], data=[[partition_key]])
        # The table name matches the 'get' value the downstream asset configures for this
        # partition ("custom_table_name").
        artifact.add(table, "custom_table_name")
        # Return the Artifact itself: `Artifact.add` returns the manifest entry for the added
        # object, not the Artifact.
        return artifact
    return partition_key  # e.g. "blue|2023-01-04"


@asset(
    group_name="partitions",
    compute_kind="wandb",
    ins={
        "my_multi_partitioned_asset": AssetIn(
            metadata={
                "wandb_artifact_configuration": {
                    "partitions": {"red|2023-01-02": {"get": "custom_table_name"}},
                },
            },
        )
    },
    output_required=False,
)
def read_all_multi_partitions(context, my_multi_partitioned_asset):
    """Example that logs every partition received from the multi-partitioned asset.

    The input is iterated with ``.items()``, yielding one (partition, content) pair per
    entry.
    """
    entries = my_multi_partitioned_asset.items()
    for partition, content in entries:
        context.log.info(f"partition={partition}, content={content}")
66 changes: 66 additions & 0 deletions examples/with_wandb/with_wandb/assets/simple_partitions_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import random

from dagster import (
AssetIn,
DailyPartitionsDefinition,
TimeWindowPartitionMapping,
asset,
)

partitions_def = DailyPartitionsDefinition(start_date="2023-01-01", end_date="2023-02-01")


@asset(
    group_name="partitions",
    partitions_def=partitions_def,
    name="my_daily_partitioned_asset",
    compute_kind="wandb",
    metadata={"wandb_artifact_configuration": {"type": "dataset"}},
)
def create_my_daily_partitioned_asset(context):
    """Example writing an Artifact with daily partitions and custom metadata.

    Logs which partition (or partition range) is being created and returns a random
    integer between 0 and 100 inclusive.
    """
    if context.has_partition_key:
        # The asset is materialized in multiple runs (one per partition).
        partition_key = context.asset_partition_key_for_output()
        message = f"Creating partitioned asset for {partition_key}"
    else:
        # The asset is materialized in a single run.
        # Important: this will throw an error because we don't support materializing a
        # partitioned asset in a single run.
        partition_key_range = context.asset_partition_key_range
        message = f"Creating partitioned assets for window {partition_key_range}"

    context.log.info(message)
    return random.randint(0, 100)


@asset(
    group_name="partitions",
    compute_kind="wandb",
    output_required=False,
    ins={"my_daily_partitioned_asset": AssetIn()},
)
def read_all_partitions(context, my_daily_partitioned_asset):
    """Example that logs every partition received from the daily partitioned asset."""
    entries = my_daily_partitioned_asset.items()
    for partition, content in entries:
        context.log.info(f"partition={partition}, content={content}")


@asset(
    group_name="partitions",
    partitions_def=partitions_def,
    compute_kind="wandb",
    output_required=False,
    ins={
        "my_daily_partitioned_asset": AssetIn(
            partition_mapping=TimeWindowPartitionMapping(start_offset=-1),
        )
    },
)
def read_specific_partitions(context, my_daily_partitioned_asset):
    """Example that logs the partitions selected by the input's partition mapping.

    The input uses a ``TimeWindowPartitionMapping`` with ``start_offset=-1``.
    """
    entries = my_daily_partitioned_asset.items()
    for partition, content in entries:
        context.log.info(f"partition={partition}, content={content}")
4 changes: 1 addition & 3 deletions examples/with_wandb/with_wandb/ops/simple_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,7 @@ def create_my_final_list(downloaded_artifact: List[int]) -> List[int]:
project=str,
),
"wandb_resource": wandb_resource.configured({"api_key": {"env": "WANDB_API_KEY"}}),
"io_manager": wandb_artifacts_io_manager.configured(
{"wandb_run_id": "my_resumable_run_id"}
),
"io_manager": wandb_artifacts_io_manager.configured({"run_id": "my_resumable_run_id"}),
}
)
def simple_job_example():
Expand Down
2 changes: 2 additions & 0 deletions examples/with_wandb/workspace.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
load_from:
- python_module: with_wandb
Loading