Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev(hansbug): add register custom dicts #88

Merged
merged 7 commits into from
Oct 22, 2023
Merged

dev(hansbug): add register custom dicts #88

merged 7 commits into from
Oct 22, 2023

Conversation

HansBug
Copy link
Member

@HansBug HansBug commented Oct 13, 2023

Description

import torch.nn

from treevalue import FastTreeValue


class MyModule(torch.nn.Module):
    def __init__(self, p):
        torch.nn.Module.__init__(self)
        self.relu = torch.nn.ReLU()
        self.p = torch.tensor(p)

    def forward(self, x):
        return self.relu(x + self.p)


class FullModule(torch.nn.Module):
    def __init__(self, **kwargs):
        torch.nn.Module.__init__(self)
        self._module_dict = torch.nn.ModuleDict({
            key: MyModule(value)
            for key, value in kwargs.items()
        })
        self._module_tv = FastTreeValue(self._module_dict)

    def forward(self, x):
        return self._module_tv(x)


model = FullModule(a=1, b=2)
print(model)

input_ = FastTreeValue({
    'a': torch.randn(3, 4),
    'b': torch.randn(2, 3),
})
print(model(input_))

TODO

  • Try to reduce the lines of ModuleDict&TreeValue usage

Check List

  • merge the latest version source branch/repo, and resolve all the conflicts
  • pass style check
  • pass all the tests

@HansBug HansBug added the enhancement New feature or request label Oct 13, 2023
@HansBug HansBug self-assigned this Oct 13, 2023
@codecov
Copy link

codecov bot commented Oct 13, 2023

Codecov Report

Merging #88 (ab350d6) into main (bd171e4) will decrease coverage by 0.09%.
The diff coverage is 91.89%.

@@            Coverage Diff             @@
##             main      #88      +/-   ##
==========================================
- Coverage   98.88%   98.80%   -0.09%     
==========================================
  Files          43       43              
  Lines        2792     2837      +45     
==========================================
+ Hits         2761     2803      +42     
- Misses         31       34       +3     
Flag Coverage Δ
unittests 98.80% <91.89%> (-0.09%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Files Coverage Δ
treevalue/tree/tree/__init__.py 100.00% <100.00%> (ø)
treevalue/tree/tree/tree.pyx 97.97% <96.66%> (-0.15%) ⬇️
treevalue/tree/integration/torch.py 90.47% <66.66%> (+23.80%) ⬆️

... and 6 files with indirect coverage changes

test/tree/general/base.py Outdated Show resolved Hide resolved
treevalue/tree/integration/jax.py Outdated Show resolved Hide resolved
@HansBug HansBug merged commit 83bca17 into main Oct 22, 2023
54 of 55 checks passed
@HansBug HansBug deleted the dev/dict branch October 22, 2023 14:27
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants