Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds the extensions patch functionality #101

Merged
merged 12 commits into from
Sep 11, 2024
Merged

Adds the extensions patch functionality #101

merged 12 commits into from
Sep 11, 2024

Conversation

dalonsoa
Copy link
Collaborator

@dalonsoa dalonsoa commented Sep 5, 2024

Adds the ability to patch any method, attribute or item within an attribute in a node using decorators. This includes the patching of handlers and of sub-attributes.

Supersedes #97

Comment on lines 74 to 84
# Check if all patches are registered
assert len(extensions_registry) == 4

# Apply the patches
apply_patches(model)

# Verify that the patches are applied correctly
assert model.nodes[node.name].apply_overrides == dummy_patch
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can include an example that behaves like a conventional decorator (since this will be a common use case)?

Suggested change
# Check if all patches are registered
assert len(extensions_registry) == 4
# Apply the patches
apply_patches(model)
# Verify that the patches are applied correctly
assert model.nodes[node.name].apply_overrides == dummy_patch
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch
# 5. Patch a decorator
@register_node_patch("dummy_node.pull_distributed")
def a_dummy_decorator(node, vqip):
#Only pull from Reservoir
return node.pull_distributed(vqip, of_type=['Reservoir'])
# Check if all patches are registered
assert len(extensions_registry) == 5
# Apply the patches
apply_patches(model)
# Verify that the patches are applied correctly
assert model.nodes[node.name].apply_overrides == dummy_patch
assert model.nodes[node.name].t == another_dummy_patch(node)
assert model.nodes[node.name].pull_set_handler["default"] == yet_another_dummy_patch
assert model.nodes[node.name].dummy_arc.arc_mass_balance == arc_dummy_patch
assert model.nodes[node.name].dummy_arc.pull_distributed == a_dummy_decorator

Copy link
Collaborator

@barneydobson barneydobson Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think I've misunderstood the decorator example.

If I add the line:

_ = model.nodes[node.name].pull_distributed(node.empty_vqip())

To the test (which would be the normal use case) - it fails. doesn't seem to help if I set is_attr=True... so definitely an example to cover that would be helpful ;)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand what you mean by "a conventional decorator". And where are you putting that line?

This comment was marked as outdated.

This comment was marked as outdated.

Copy link
Collaborator

@barneydobson barneydobson Sep 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK sorry for all the messages there - I was just trying to figure out how to use this properly.

The below test passes, but is it the correct way to extend an existing function (while still calling it)? If so it should be in tests and in the documentation as it will be one of the more common uses of extensions.

def assertDictAlmostEqual(d1, d2, accuracy=19):
    """

    Args:
        d1:
        d2:
        accuracy:
    """
    for d in [d1, d2]:
        for key, item in d.items():
            d[key] = round(item, accuracy)
    assert d1 == d2


def test_apply_dec(temp_extension_registry):
    from wsimod.arcs.arcs import Arc
    from wsimod.extensions import (
        apply_patches,
        extensions_registry,
        register_node_patch,
    )
    from wsimod.nodes.storage import Reservoir
    from wsimod.orchestration.model import Model

    # Create a dummy model
    node = Reservoir(name="dummy_node", initial_storage=10,capacity = 10)
    node.dummy_arc = Arc("dummy_arc", in_port=node, out_port=node)

    vq = node.pull_distributed({'volume' : 5})
    assertDictAlmostEqual(vq, node.v_change_vqip(node.empty_vqip(),5))

    model = Model()
    model.nodes[node.name] = node

    # 5. Patch a decorator
    @register_node_patch("dummy_node.pull_distributed", is_attr=True)
    def extend_function(node):
        def wrapper(f_old):
            def f(vqip, *args, **kw):
                return f_old(vqip, of_type = ['Node'], *args, **kw)
            return f
        #Only pull from Reservoir
        return wrapper(node.pull_distributed)
        
    
    # Apply the patches
    apply_patches(model)

    # Check appropriate result
    assert node.tank.storage['volume'] == 5
    vq = model.nodes[node.name].pull_distributed({'volume' : 5})
    assertDictAlmostEqual(vq, node.empty_vqip())
    assert node.tank.storage['volume'] == 5

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, now I get what you want. So a common use case is to use the old function you are overriding in the override itself, something like calling super().some_function in a child class, right?

Ok, let's see if I can figure out the most elegant way of doing it, so the user doesn't need to deal with functions, within functions, within functions...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes - while operational my approach is not the most elegant ;)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have a look now. It might be useful to re-read the docstrings to make sure the explanations are clear.

"""
for (target, item, is_attr), func in extensions_registry.items():
# Process the target string
starget = target.split(".")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there isn't anything to stop nodes have a . in the node name.. It's not in my default model setup - but I wouldn't be surprised if others have introduced this unthinkingly.

Perhaps could we have target_node as a separate argument - since the name could be anything, and then any sub-attributes can by . delimited since they will follow python syntax?

Not sure - what do you think? If it's too awful then at least we validate to ensure no . in name during model.load

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a problem. That's an easy fix. I just put it all together in a single line because I felt it was easier to understand and to cover more cases - in particular the sub-attributes one - in one, consistent approach.

About users using . for the node names... it might be a bit opinionated, but you are the developer, so you tell the users how they should use the tool. If you tell them not to use . but _ or something else, they won't use .. It is not the other way around.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you're right. There's no point on artificially restricting what a node name can be. What about changing the decorator signature to?

def register_node_patch(
    node_name: str, target: str, item: Hashable = None, is_attr: bool = False
) -> Callable:

So node_name is provided independently and can therefore be anything?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

works for me!

Copy link
Collaborator

@barneydobson barneydobson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tyop but otherwise lgtm


The overridden method or attribute can be accessed within the patched function using the
`_patched_{method_name}` attribute of the object, eg. `self._patched_pull_distributed`.
The exception to this is when patching an item, in which case the original item is no
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The exception to this is when patching an item, in which case the original item is no
The exception to this is when patching an item, in which case the original item is not

Copy link
Collaborator

@barneydobson barneydobson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually hold on.. (now with examples in new review)

This example (test_extensions.py):

    # 3. Patch a method with an item
    @register_node_patch("dummy_node", "pull_set_handler", item="default")
    def yet_another_dummy_patch():
        pass

Doesn't seem consistent with this in the docstring (extensions.py):

Example of patching an attribute item:

`patch_default_pull_set_handler` will be assigned to
`pull_set_handler["default"]`:

    >>> @register_node_patch("my_node", "pull_set_handler", item="default")
    >>> def patch_default_pull_set_handler(self, vqip):
    >>>     return {}

Does the patch have access to self?

Add handler behaviour tests
assert_dict_almost_equal(vq, node.empty_vqip())
assert node.tank.storage["volume"] == 5

def test_handler_extensions(temp_extension_registry):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dalonsoa I've added a test to update handlers in different ways - if you are happy that these are suitable then could you update the docstring accordingly please

@dalonsoa
Copy link
Collaborator Author

It does have access, and indeed the test will fail if I call the node.pull_set_handler["default"] since it does not have the right input arguments. All versions of the patch require the first argument to be a node (typically self). The only difference is that for the case of having an item the previous version of the object to override is not preserved in a self._patched_whatever. This is - I had hoped - explained in the docstring.

I'll check your test, now

@dalonsoa
Copy link
Collaborator Author

There was a but in the code making the use of the function more convoluted. I've fixed the bug and amend your test, so it works as it should. See the following for the changes 566976e

Many thanks for your tests - they have picked a few things that were just wrong.

@dalonsoa dalonsoa merged commit e74c6e7 into main Sep 11, 2024
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants