Skip to content

Commit

Permalink
Merge pull request #320 from LineaLabs/export-slice-airflow
Browse files Browse the repository at this point in the history
Implement  `--export-slice-airflow`
  • Loading branch information
saulshanabrook authored Oct 25, 2021
2 parents 7dfe4e2 + 20e0899 commit 0ef50e4
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 25 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,9 @@ housing.csv
# Result of --export-slice
sliced_housing.py

# Result of --export-slice-to-airflow-dag
sliced_housing_dag.py

tracer
tracer.pdf
devcontainer.json
13 changes: 13 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,19 @@
"tests/housing.py",
],
},
{
"name": "lineapy --airflow",
"type": "python",
"request": "launch",
"module": "lineapy.cli.cli",
"args": [
"--slice",
"p value",
"--airflow",
"sliced_housing_dag",
"tests/housing.py",
],
},
{
"name": "Python: Current File",
"type": "python",
Expand Down
4 changes: 2 additions & 2 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"python.formatting.provider": "black",
"python.linting.mypyEnabled": true,
"python.linting.enabled": true,
"python.linting.enabled": false,
"python.testing.pytestEnabled": true,
"python.linting.flake8Enabled": true,
"editor.formatOnSave": true,
Expand All @@ -10,4 +10,4 @@
"source.organizeImports": true
}
}
}
}
13 changes: 13 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,19 @@ jupyter nbconvert --to notebook --execute tests/test_notebook.ipynb --inplace --
Or you can open it in a notebook UI (JupyterLab, JupyterNotebook, VS Code, etc.)
and re-run it manually

### Airflow

Sliced code can be exported to an Airflow DAG using the following command:

```
lineapy tests/housing.py --slice "p value" --airflow sliced_housing_dag
```
This creates a `sliced_housing_dag.py` file in the current directory. The resulting DAG can be executed with:

```
airflow db init
airflow dags test sliced_housing_dag_dag $(date '+%Y-%m-%d') -S .
```
## Visual Graphs

Sometimes it's helpful to see a visual representation of the graph
Expand Down
22 changes: 21 additions & 1 deletion lineapy/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from lineapy.db.relational.db import RelationalLineaDB
from lineapy.instrumentation.tracer import Tracer
from lineapy.logging import configure_logging
from lineapy.plugins.airflow import sliced_aiflow_dag
from lineapy.transformer.node_transformer import transform
from lineapy.utils import prettify

Expand Down Expand Up @@ -39,6 +40,12 @@
default=None,
help="Requires --slice. Export the sliced code that {slice} depends on to {export_slice}.py",
)
@click.option(
"--export-slice-to-airflow-dag",
"--airflow",
default=None,
help="Requires --slice. Export the sliced code that {slice} depends on to an Airflow DAG {export_slice}.py",
)
@click.option(
"--print-source", help="Whether to print the source code", is_flag=True
)
Expand Down Expand Up @@ -66,6 +73,7 @@ def linea_cli(
mode,
slice,
export_slice,
export_slice_to_airflow_dag,
print_source,
print_graph,
verbose,
Expand Down Expand Up @@ -94,7 +102,8 @@ def linea_cli(

if visualize:
tracer.visualize()
if slice and not export_slice:

if slice and not export_slice and not export_slice_to_airflow_dag:
tree.add(
rich.console.Group(
f"Slice of {repr(slice)}",
Expand All @@ -109,6 +118,17 @@ def linea_cli(
full_code = tracer.sliced_func(slice, export_slice)
pathlib.Path(f"{export_slice}.py").write_text(full_code)

if export_slice_to_airflow_dag:
if not slice:
print(
"Please specify --slice. It is required for --export-slice-to-airflow-dag"
)
exit(1)
full_code = sliced_aiflow_dag(
tracer, slice, export_slice_to_airflow_dag
)
pathlib.Path(f"{export_slice_to_airflow_dag}.py").write_text(full_code)

tracer.db.close()
if print_graph:
graph_code = prettify(
Expand Down
18 changes: 9 additions & 9 deletions lineapy/db/relational/schema/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def process_result_value(self, value, dialect):
return value


class SessionContextORM(Base):
class SessionContextORM(Base): # type: ignore
__tablename__ = "session_context"
id = Column(String, primary_key=True)
environment_type = Column(Enum(SessionType))
Expand All @@ -108,7 +108,7 @@ class SessionContextORM(Base):
execution_id = Column(String, ForeignKey("execution.id"))


class LibraryORM(Base):
class LibraryORM(Base): # type: ignore
__tablename__ = "library"
__table_args__ = (
UniqueConstraint(
Expand All @@ -125,7 +125,7 @@ class LibraryORM(Base):
path = Column(String)


class ArtifactORM(Base):
class ArtifactORM(Base): # type: ignore
"""
An artifact is a named pointer to a node.
"""
Expand All @@ -140,7 +140,7 @@ class ArtifactORM(Base):
)


class ExecutionORM(Base):
class ExecutionORM(Base): # type: ignore
"""
An execution represents one Python interpreter invocation of some number of nodes
"""
Expand All @@ -150,7 +150,7 @@ class ExecutionORM(Base):
timestamp = Column(DateTime, nullable=True, default=datetime.utcnow)


class NodeValueORM(Base):
class NodeValueORM(Base): # type: ignore
"""
A node value represents the value of a node during some execution.
Expand All @@ -170,7 +170,7 @@ class NodeValueORM(Base):
end_time = Column(DateTime, nullable=True)


class BaseNodeORM(Base):
class BaseNodeORM(Base): # type: ignore
"""
node.source_code has a path value if node.session.environment_type == "script"
otherwise the environment type is "jupyter" and it has a jupyter execution
Expand Down Expand Up @@ -214,7 +214,7 @@ class BaseNodeORM(Base):
}


class SourceCodeORM(Base):
class SourceCodeORM(Base): # type: ignore
__tablename__ = "source_code"

id = Column(String, primary_key=True)
Expand Down Expand Up @@ -261,7 +261,7 @@ class ImportNodeORM(BaseNodeORM):
# https://docs.sqlalchemy.org/en/14/orm/basic_relationships.html#association-object


class PositionalArgORM(Base):
class PositionalArgORM(Base): # type: ignore
__tablename__ = "positional_arg"
call_node_id: str = Column(
ForeignKey("call_node.id"), primary_key=True, nullable=False
Expand All @@ -273,7 +273,7 @@ class PositionalArgORM(Base):
argument = relationship(BaseNodeORM, uselist=False)


class KeywordArgORM(Base):
class KeywordArgORM(Base): # type: ignore
__tablename__ = "keyword_arg"
call_node_id: str = Column(
ForeignKey("call_node.id"), primary_key=True, nullable=False
Expand Down
33 changes: 21 additions & 12 deletions lineapy/instrumentation/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,9 @@ def artifacts(self) -> dict[str, str]:

def sliced_func(self, slice_name: str, func_name: str) -> str:
artifact = self.db.get_artifact_by_name(slice_name)
if not artifact.node:
artifact_var = self.slice_var_name(artifact)
if not artifact_var:
return "Unable to extract the slice"
_line_no = artifact.node.lineno if artifact.node.lineno else 0
artifact_line = str(artifact.node.source_code.code).split("\n")[
_line_no - 1
]
_col_offset = (
artifact.node.col_offset if artifact.node.col_offset else 0
)
if _col_offset < 3:
return "Unable to extract the slice"
artifact_name = artifact_line[: _col_offset - 3]
slice_code = get_program_slice(self.graph, [artifact.id])
# We split the code in import and code blocks and join them to full code test
import_block, code_block, main_block = split_code_blocks(
Expand All @@ -146,7 +137,7 @@ def sliced_func(self, slice_name: str, func_name: str) -> str:
import_block
+ "\n\n"
+ code_block
+ f"\n\treturn {artifact_name}"
+ f"\n\treturn {artifact_var}"
+ "\n\n"
+ main_block
)
Expand All @@ -163,6 +154,24 @@ def slice(self, name: str) -> str:
artifact = self.db.get_artifact_by_name(name)
return get_program_slice(self.graph, [artifact.id])

def slice_var_name(self, artifact: ArtifactORM) -> str:
    """
    Returns the variable name for the given artifact.
    i.e. in lineapy.linea_publish(p, "p value") "p" is returned

    The name is read from the source line of the artifact's node: it is the
    text of that line up to three characters before the node's column
    offset (presumably skipping the " = " of an assignment -- verify
    against how nodes are recorded).

    :param artifact: the artifact whose variable name is wanted.
    :return: the variable name, or an empty string when it cannot be
        determined (no node, no usable line number, line number outside of
        the recorded source, or a column offset smaller than 3).
    """
    node = artifact.node
    if not node:
        return ""

    # A missing line number used to fall back to 0, which indexed line -1
    # (the last line of the source) and produced an unrelated name.
    line_no = node.lineno
    if not line_no or line_no < 1:
        return ""

    source_lines = str(node.source_code.code).split("\n")
    if line_no > len(source_lines):
        # Stale or inconsistent position: report failure instead of raising.
        return ""

    col_offset = node.col_offset or 0
    if col_offset < 3:
        return ""
    return source_lines[line_no - 1][: col_offset - 3]

def visualize(
self,
filename="tracer",
Expand Down
63 changes: 63 additions & 0 deletions lineapy/plugins/airflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from black import FileMode, format_str

from lineapy.graph_reader.program_slice import (
get_program_slice,
split_code_blocks,
)
from lineapy.instrumentation.tracer import Tracer

# Imports added to every generated DAG file, after the imports of the slice.
# PythonOperator is imported from ``airflow.operators.python``: the
# ``airflow.operators.python_operator`` module is deprecated in Airflow 2.
AIRFLOW_IMPORTS_TEMPLATE = """
from airflow import DAG
from airflow.utils.dates import days_ago
from airflow.operators.python import PythonOperator
"""

# DAG definition appended after the sliced function. Every occurrence of the
# ``DAG_NAME`` placeholder is substituted with the function name by
# ``sliced_aiflow_dag`` (plain ``str.replace``), so the placeholder must not
# appear anywhere it should not be replaced.
AIRFLOW_MAIN_TEMPLATE = """
default_dag_args = {"owner": "airflow", "retries": 2, "start_date": days_ago(1)}
dag = DAG(
    dag_id="DAG_NAME_dag",
    schedule_interval="*/15 * * * *",  # Every 15 minutes
    max_active_runs=1,
    catchup=False,
    default_args=default_dag_args,
)
DAG_NAME = PythonOperator(
    dag=dag, task_id="DAG_NAME_task", python_callable=DAG_NAME,
)
"""


def sliced_aiflow_dag(tracer: Tracer, slice_name: str, func_name: str) -> str:
    """
    Build the source code of an Airflow DAG that runs the sliced code.

    :param tracer: the tracer object; its ``db`` and ``graph`` are used to
        look up the artifact and compute the slice.
    :param slice_name: name of the artifact to get the code slice for.
    :param func_name: name handed to ``split_code_blocks`` and substituted
        for the ``DAG_NAME`` placeholder of the DAG template.
    :return: string containing the black-formatted code of the Airflow DAG
        running this slice, or the message "Unable to extract the slice"
        when no variable name can be determined for the artifact.
    """
    artifact = tracer.db.get_artifact_by_name(slice_name)
    var_name = tracer.slice_var_name(artifact)
    if not var_name:
        return "Unable to extract the slice"

    # Split the slice into import / code / main blocks. The main block is
    # not used here: the Airflow template takes its place.
    sliced_source = get_program_slice(tracer.graph, [artifact.id])
    import_block, code_block, _main_block = split_code_blocks(
        sliced_source, func_name
    )

    dag_definition = AIRFLOW_MAIN_TEMPLATE.replace("DAG_NAME", func_name)
    pieces = [
        import_block,
        "\n",
        AIRFLOW_IMPORTS_TEMPLATE,
        "\n\n",
        code_block,
        # TODO What to do with artifact_var in a DAG?
        f"\n\tprint({var_name})",
        "\n\n",
        dag_definition,
    ]

    # Let black normalise the assembled code, limited to 79 columns.
    return format_str("".join(pieces), mode=FileMode(line_length=79))
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ relative_files = true

# https://docs.sqlalchemy.org/en/14/orm/extensions/mypy.html
# https://pydantic-docs.helpmanual.io/mypy_plugin/#enabling-the-plugin
plugins = ["sqlalchemy.ext.mypy.plugin", "pydantic.mypy"]
# plugins = ["sqlalchemy.ext.mypy.plugin", "pydantic.mypy"]


# Enable function body type checking, even if function types are not annotated
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def version(path):
"coveralls",
"seaborn",
"graphviz",
"apache-airflow==2.2.0",
]
},
include_package_data=True,
Expand Down
35 changes: 35 additions & 0 deletions tests/test_airflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import subprocess


def test_export_slice_housing_dag():
    """
    Verifies that the "--airflow" CLI command produces a working Airflow DAG
    """
    # Commands are run in order; check_call raises on the first non-zero
    # exit status, which fails the test and skips the remaining commands.
    commands = (
        # 1. Export the "p value" slice of the housing script as a DAG file.
        [
            "lineapy",
            "tests/housing.py",
            "--slice",
            "p value",
            "--airflow",
            "sliced_housing_dag",
        ],
        # 2. Initialise the Airflow metadata database.
        ["airflow", "db", "init"],
        # 3. Run the generated DAG once, loading DAG files from the
        #    current directory.
        [
            "airflow",
            "dags",
            "test",
            "sliced_housing_dag_dag",
            "2020-10-19",
            "-S",
            ".",
        ],
    )
    for command in commands:
        subprocess.check_call(command)

0 comments on commit 0ef50e4

Please sign in to comment.