Disdat is a Python (3.6+) package for data versioning and pipeline authoring that allows data scientists to create, share, and track data products. Disdat-step-function is a plugin built on Disdat that enables data versioning for AWS Step Functions.
More specifically, this plugin does the following:
- Caching: Captures all intermediate outputs and reuses cached data based on task parameters.
- Data Versioning: All state artifacts are versioned as bundles on S3.
- Minimum Intrusion: Existing projects can be refactored at the pipeline level; users don't need to modify any state definitions.
- Share Datasets: Intermediate artifacts can be easily shared between teams through standardized APIs.
Install the package; pip will also install the core disdat package if you haven't done so already.
pip install disdat-step-function
Get Started with the tutorial notebook! Check out how easy it is to version an AWS Step Functions state machine in simple_cached_sm.ipynb.
Used to create and configure the data versioning parameters that Disdat should use. You can either create this object and share it between different components (e.g., if all data go to the same location), or create one object for each component.
Args
- caching_lambda_name: str, the Lambda function you created with the lambda stub generated by disdat-step-function. See LambdaGenerator.
- context_name: str, the Disdat context in which the artifacts reside.
- s3_bucket_url: str, URL of the S3 bucket, for instance s3://my-bucket.
- force_rerun: bool, force rerun of all components if set to True.
- use_verbose: bool, emit more verbose logs if set to True.
- state_machine_name: str, name of the state machine; reserved for future use.
Given a user state, wrap it with dynamically generated states that implement data versioning and smart caching.
Args
- user_step: the step object (stepfunctions.steps.states) to cache.
- bundle_name: str, optional, the name of the bundle to create. Defaults to the state name.
- force_rerun: bool, overrides the pipeline-level force_rerun. Set to True to force a rerun of this state even when cached data exists.
Return
stepfunctions.steps.Chain: a state machine (Chain) with the user's state embedded in it.
from stepfunctions.steps import states
from disdat_step_function.caching_wrapper import Caching
# shared caching configuration; can be reused across states
caching = Caching(caching_lambda_name="lambda_func_name",
                  s3_bucket_url="s3://hello",
                  context_name="simple_demo",
                  verbose=True)

# wrap a single state with data versioning and caching
user_task = states.Pass(state_id='user')
cached_task = caching.cache_step(user_task, bundle_name="simple_task", force_rerun=False)
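The Chain returned by cache_step() can be used like any other state machine definition. Below is a minimal sketch of deploying it with the AWS Step Functions Data Science SDK's Workflow class; the workflow name and execution role ARN are placeholders you would replace with your own:

from stepfunctions.workflow import Workflow

# placeholder name and IAM role ARN for the Step Functions execution role
workflow = Workflow(
    name="simple-cached-demo",
    definition=cached_task,  # the Chain returned by cache_step()
    role="arn:aws:iam::123456789012:role/StepFunctionsExecutionRole",
)
workflow.create()                             # register the state machine in AWS
workflow.execute(inputs={"param": "value"})   # run it with some input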
Used to refactor an existing pipeline given its definition. PipelineCaching finds every Task state in the definition and replaces it with Caching().cache_step(task) (a steps.Chain object).
Args
- definition: Union[stepfunctions.steps.states.Chain, stepfunctions.steps.states.State], the state machine to refactor.
- caching: caching_wrapper.Caching, used to cache the individual states.
Modifies the state machine definition in place. The refactored state machine supports data versioning.
Args
None
Return
None
from stepfunctions.steps import states
from disdat_step_function.caching_wrapper import Caching, PipelineCaching
caching = Caching(caching_lambda_name="lambda_func_name",
                  s3_bucket_url="s3://hello",
                  context_name="simple_demo",
                  verbose=True)

# define your graph as usual
user_task_1 = states.Pass(state_id='task1')
user_task_2 = states.Pass(state_id='task2')
graph = states.Chain([user_task_1, user_task_2])

# one-button refactoring: every Task state in the graph gets caching
PipelineCaching(graph, caching).cache()
Generates code and a Lambda layer that you can use to create an AWS Python Lambda function, which is called by the augmented state machine to record artifacts.
Args
- root: Union[str, pathlib.Path], where to dump the generated code.
- force_rerun: bool, re-generate everything and overwrite existing code and zip files.
Return
None
Artifacts
This function generates a folder of artifacts that you can use to create the caching Lambda function:
Root
|- /cache_lambda
|    |- lambda_for_user.py: copy this file into a Lambda function
|
|- /dependency: dependencies installed in an amazonlinux container to ensure compatibility
|
|- disdat_caching_layer.zip: the Lambda layer with all necessary dependencies
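A minimal sketch of calling the generator follows; the import path and entry-point name are assumptions based on the LambdaGenerator reference and the root/force_rerun arguments above, so check the installed package for the exact API:

from pathlib import Path
# assumed module path; the generator may live elsewhere in disdat_step_function
from disdat_step_function.lambda_generator import LambdaGenerator

# write the Lambda stub, dependencies, and layer zip under ./disdat_lambda
LambdaGenerator.generate(root=Path("./disdat_lambda"), force_rerun=False)

After generation, copy cache_lambda/lambda_for_user.py into a new Python Lambda function and attach disdat_caching_layer.zip as a layer.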
Since AWS Step Functions is essentially an orchestrator of tasks with heterogeneous runtimes, we must make sure that all tasks, not just Python code, get to enjoy the benefits of data versioning.
Hence, disdat-step-function injects states before and after the user's state to pull data from and push data to S3 (from now on these are called caching steps).
To enable data versioning and caching for a state, simply use the cache_step() wrapper and pass in the state object. If you have a complex state machine with many states, you can also use PipelineCaching().cache() to refactor the whole workflow in one step.
As you can see in the figure above, disdat-step-function injects states around the user's designated task. The component name and input parameters are used to uniquely identify an execution. Note that you should not use disdat-step-function for tasks that are not idempotent.
With some high-level understanding of how disdat-step-function works, let's dive into the actual diagram.
cache_pull takes the input JSON and extracts the parameters using the input_path field selector specified by the user's task.
Once it has the parameters, a proc_name is calculated from the bundle name and the hash of the parameters. This key is used to pull data from the remote Disdat bundle on S3.
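The exact key derivation happens inside the caching Lambda, but conceptually proc_name is a deterministic function of the bundle name and the selected parameters. The sketch below only illustrates that idea; the function name and hashing details are assumptions, not the library's implementation:

import hashlib
import json

def make_proc_name(bundle_name: str, params: dict) -> str:
    """Illustrative only: derive a stable cache key from bundle name + parameters."""
    # canonical JSON so identical parameters always hash to the same value
    param_hash = hashlib.sha256(
        json.dumps(params, sort_keys=True).encode("utf-8")
    ).hexdigest()[:16]
    return f"{bundle_name}_{param_hash}"

# e.g. make_proc_name("simple_task", {"epoch": 3}) always yields the same key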
The output of cache_pull has the following format:
Output
{
    # the full params will be passed to the user state
    "_full_params": input_data,
    # the cache params will be passed to cache_push
    "_cache_params": input_data[field_selector],
    # data to pass on to the next state if there is a cache hit
    "_data": cached_data/null,
    # whether to execute the user state or not
    "_use_cache": false/true
}
In an AWS state machine, a state can only take parameters from the state immediately before it. In order for the input parameters to reach cache_push, we must create a bypass. The output of param_pass has the following format:
Output
{
"_cache_params": input_data[field_selector]
}
Since the user state expects its input data to be the parameters that cache_pull received, we must unwrap the output of cache_pull using the field selector $._full_params. Hence, param_resolve has the following output format:
Output
input_data
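As an illustration of this field selection (not the library's actual param_resolve implementation), a Pass state in the Step Functions Data Science SDK can forward only the sub-object under _full_params by setting its input_path:

from stepfunctions.steps import states

# stand-in for param_resolve: forwards only the original user parameters
# that cache_pull stored under "_full_params"
param_resolve = states.Pass(state_id="param_resolve", input_path="$._full_params")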
If cache_push receives data from cache_pull (i.e., the user state was not executed), it simply passes the _data field on to the following state.
If it receives data from the execution branch, it uses _cache_params to calculate the proc_name and pushes the data to S3.
Output
data output by the user task
Data provenance is an important feature of data versioning. This is especially true for orchestration workflows such as AWS Step Functions. However, due to the design of Step Functions, it is very hard for a state to figure out what its parent is. We leverage the execution logs and build a graph in which each node is an execution event. Unfortunately, the execution flow is hard to reconstruct from raw logs (especially for Parallel and Map states). Right now we support partial lineage: for every cached state, we guarantee to find at least one parent (some states may have multiple parents, such as the state following a Parallel state with two branches).
Important Tips
- If you need to enable lineage capture, please make sure your Lambda can perform list_state_machines, list_executions, and get_execution_history on state machines (see the IAM policy sketch below).
- Lineage capture only supports default bundle names. This is because we can only recover the bundle name from the execution event name when bundle name = task name.
Warning
Lineage capture is still in alpha; use it at your own risk (captured dependencies are correct, but some will be missed). By default this feature is disabled. To enable it, modify the lambda_for_user.py code:
parent = cache.get_lineage()
# parent = None
return cache.cache_push(event, parent)