Skip to content
This repository has been archived by the owner on Jul 3, 2023. It is now read-only.

Commit

Permalink
Removes "augment"
Browse files Browse the repository at this point in the history
This is not going to make it into the final release. Instead see draft PR

#60
  • Loading branch information
elijahbenizzy committed Feb 6, 2022
1 parent e3c72e8 commit 60f6415
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 93 deletions.
70 changes: 2 additions & 68 deletions hamilton/function_modifiers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import ast
import functools
import functools
import inspect
from typing import Dict, Callable, Collection, Tuple, Union, Any, Type

import pandas as pd

from hamilton import node
from hamilton.function_modifiers_base import NodeCreator, NodeResolver, NodeExpander, NodeTransformer, sanitize_function_name
from hamilton.function_modifiers_base import NodeCreator, NodeResolver, NodeExpander, sanitize_function_name
from hamilton.models import BaseModel
from hamilton.node import DependencyType

"""
Annotations for modifying the way functions get added to the DAG.
Expand Down Expand Up @@ -403,68 +402,3 @@ def resolves(configuration: Dict[str, Any]) -> bool:
return all(configuration.get(key) not in value for key, value in key_value_group_pairs.items())

return config(resolves)


class augment(NodeTransformer):
def __init__(self, expr: str, output_type: Type[Type] = None):
"""Takes in an expression to transform a node. Expression must be of the form: foo=...
Note that this is potentially unsafe!!! We use `eval` in this, so if you execute hamilton code with this
it could easily carry out extra code. Keep this in mind when you use @augment -- utilize at your own peril.
:param expr: expression to transform the node. Must have a variable in it with the function's name.
:param output_type: Type of the output, if it differs from the node. Make sure to specify this
if it differs or the types will not compile.
"""
self.expr = expr
self.expression = expr
self.output_type = output_type

def transform_node(self, node_: node.Node, config: Dict[str, Any], fn: Callable) -> Collection[node.Node]:
"""Transforms a node using the expression.
:param node_: Node to transform
:param output_type: The type of the node output. If it changes, then we specify it as Any
:return: The collection of nodes that are outputted by this transform. Will be an extra node
with the expression applied to the result of the original node.
"""

expression_parsed = ast.parse(self.expression, mode='eval')
dependent_variables = {item.id for item in ast.walk(expression_parsed) if isinstance(item, ast.Name)}

var_name = sanitize_function_name(fn.__name__)

# We should be passing the function through (or the name?)

def new_function(**kwargs):
kwargs_for_original_call = {key: value for key, value in kwargs.items() if key in node_.input_types}
unmodified_result = node_.callable(**kwargs_for_original_call)
kwargs_for_new_call = {key: value for key, value in kwargs.items() if key in dependent_variables}
kwargs_for_new_call[var_name] = unmodified_result
return eval(self.expression, kwargs_for_new_call)

replacement_node = node.Node(
name=node_.name,
typ=node_.type if self.output_type is None else self.output_type,
doc_string=node_.documentation,
callabl=new_function,
node_source=node_.node_source,
input_types={
**{dependent_var: (Any, DependencyType.REQUIRED) for dependent_var in dependent_variables if dependent_var != var_name},
**{param: (input_type, dependency_type) for param, (input_type, dependency_type) in node_.input_types.items()}
}
)
return [replacement_node]

# TODO -- find a cleaner way to to copy a node

def validate(self, fn: Callable):
"""Validates the expression is of the right form"""
parsed_expression = ast.parse(self.expression, mode='eval')
if not isinstance(parsed_expression, ast.Expression):
raise ValueError(f'Expression {self.expr} must be an expression.')
dependent_variables = {item.id for item in ast.walk(parsed_expression) if isinstance(item, ast.Name)}
var_name = sanitize_function_name(fn.__name__)
if var_name not in dependent_variables:
raise ValueError(f'Expression must depend on the function its transforming. Did not find {var_name} in your expression\'s AST '
f'If you have a function called "foo", your expression must be a function of foo (as well as other variables). '
f'If you want to replace the value of this function, write your function as you normally would (E.G. as a function of its parameters).')
2 changes: 2 additions & 0 deletions tests/experimental/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from hamilton import node
from hamilton.experimental.decorators import augment
5 changes: 1 addition & 4 deletions tests/resources/layered_decorators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from hamilton.function_modifiers import config, does, parametrized, augment

from hamilton.function_modifiers import config, does, parametrized

def _sum(**kwargs: int) -> int:
return sum(kwargs.values())


@does(_sum)
@parametrized(parameter='a', assigned_output={('e', 'First value'): 10, ('f', 'First value'): 20})
@augment('c**2+d')
@config.when(foo='bar')
def c__foobar(a: int, b: int) -> int:
"""Demonstrates utilizing a bunch of decorators.
Expand All @@ -22,7 +20,6 @@ def c__foobar(a: int, b: int) -> int:

@does(_sum)
@parametrized(parameter='a', assigned_output={('e', 'First value'): 11, ('f', 'First value'): 22})
@augment('c**2+d')
@config.when(foo='baz')
def c__foobaz(a: int, b: int) -> int:
"""Demonstrates utilizing a bunch of decorators.
Expand Down
19 changes: 0 additions & 19 deletions tests/test_function_modifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,22 +368,3 @@ def config_when_fn() -> int:

annotation = function_modifiers.config.when(key='value', name='new_function_name')
assert annotation.resolve(config_when_fn, {'key': 'value'}).__name__ == 'new_function_name'


def test_augment_decorator():

def foo(a: int) -> int:
return a*2

annotation = function_modifiers.augment('foo*MULTIPLIER_foo+OFFSET_foo')
annotation.validate(foo)
nodes = annotation.transform_dag([node.Node.from_fn(foo)], {}, foo)
assert 1 == len(nodes)
nodes_by_name = {node_.name: node_ for node_ in nodes}
assert set(nodes_by_name) == {'foo'}
a = 5
MULTIPLIER_foo = 3
OFFSET_foo = 7
foo = a*2
foo = MULTIPLIER_foo*foo + OFFSET_foo
assert nodes_by_name['foo'].callable(a=a, MULTIPLIER_foo=MULTIPLIER_foo, OFFSET_foo=OFFSET_foo) == foo # note its foo_raw as that's the node on which it depends
4 changes: 2 additions & 2 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,8 @@ def test_end_to_end_with_layered_decorators_resolves_true():
fg = graph.FunctionGraph(tests.resources.layered_decorators, config={'foo': 'bar', 'd': 10, 'b': 20})
out = fg.execute([n for n in fg.get_nodes()], overrides={'b': 10})
assert len(out) > 0 # test config.when resolves correctly
assert out['e'] == (20+10) ** 2 + 10
assert out['f'] == (30+10) ** 2 + 10
assert out['e'] == (20+10)
assert out['f'] == (30+10)


def test_end_to_end_with_layered_decorators_resolves_false():
Expand Down

0 comments on commit 60f6415

Please sign in to comment.