added possibility to load alpha weights from params
Signed-off-by: Martin <[email protected]>
bmmtstb committed Nov 23, 2024
1 parent f8c3d5a commit 60c27c2
Showing 5 changed files with 49 additions and 3 deletions.
18 changes: 17 additions & 1 deletion dgs/models/alpha/alpha.py
@@ -8,9 +8,13 @@
 
 from dgs.models.modules.named import NamedModule
 from dgs.utils.state import State
+from dgs.utils.torchtools import init_model_params, load_pretrained_weights
 from dgs.utils.types import Config, NodePath, Validations
 
-alpha_validations: Validations = {}
+alpha_validations: Validations = {
+    # optional
+    "weight": ["optional", ("file exists", "./weights/")],
+}
 
 
 class BaseAlphaModule(NamedModule, t.nn.Module):
@@ -22,6 +26,10 @@ class BaseAlphaModule(NamedModule, t.nn.Module):
 
     Optional Params
     ---------------
+    weight (FilePath):
+        Local or absolute path to the pretrained weights of the model.
+        Can be left empty.
+
     """
 
     model: t.nn.Module
@@ -56,3 +64,11 @@ def sub_forward(self, data: t.Tensor) -> t.Tensor:
     def get_data(self, s: State) -> any:
         """Given a state, return the data which is input into the model."""
         raise NotImplementedError
+
+    def load_weights(self) -> None:
+        """Load the weights of the model from the given file path. If no weights are given, initialize the model."""
+        if "weight" in self.params:
+            fp = self.params.get("weight")
+            load_pretrained_weights(model=self.model, weight_path=fp)
+        else:
+            init_model_params(self.model)
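
For orientation, here is a sketch of how a caller could opt into the new weight loading through the params. The insert_into_config call mirrors the test further below; the node path ["alpha"], the base_cfg variable, and the weight file path are illustrative assumptions, not taken from the repository:

from dgs.models.alpha import FullyConnectedAlpha
from dgs.utils.config import insert_into_config

# Inject the optional "weight" parameter into an existing config.
# When "weight" is present, __init__ ends up calling load_pretrained_weights();
# when it is absent, load_weights() falls back to init_model_params().
cfg = insert_into_config(
    path=["alpha"],  # assumed NodePath of the alpha module
    value={"weight": "./weights/fully_connected_alpha.pth"},  # illustrative path
    original=base_cfg,  # assumed pre-existing Config
)
module = FullyConnectedAlpha(config=cfg, path=["alpha"])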
1 change: 1 addition & 0 deletions dgs/models/alpha/combined.py
@@ -74,6 +74,7 @@ def __init__(self, config: Config, path: NodePath):
             raise NotImplementedError(f"Expected list or str, got: {sub_path}")
 
         self.register_module(name="model", module=self.configure_torch_module(t.nn.Sequential(*modules)))
+        self.load_weights()
 
     def forward(self, s: State) -> t.Tensor:
         """Forward call for sequential model calls the next layer with the output of the previous layer.
1 change: 1 addition & 0 deletions dgs/models/alpha/fully_connected.py
@@ -61,6 +61,7 @@ def __init__(self, config: Config, path: NodePath):
             act_func=self.params.get("act_func", DEF_VAL["alpha"]["act_func"]),
         )
         self.register_module(name="model", module=self.configure_torch_module(model))
+        self.load_weights()
 
     def forward(self, s: State) -> t.Tensor:
         return self.model(self.get_data(s))
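
Both combined.py and fully_connected.py follow the same pattern: register the torch module under the name "model" first, then call self.load_weights(), so that self.model already exists when the weights are loaded. A hedged sketch of what a further BaseAlphaModule subclass would do; the class name and layer choice are hypothetical:

class LinearAlpha(BaseAlphaModule):  # hypothetical subclass for illustration
    def __init__(self, config: Config, path: NodePath):
        super().__init__(config=config, path=path)
        model = t.nn.Linear(4, 1)  # placeholder architecture
        # register first so self.model is set, then load pretrained weights
        # from params["weight"] or fall back to fresh initialization
        self.register_module(name="model", module=self.configure_torch_module(model))
        self.load_weights()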
8 changes: 6 additions & 2 deletions dgs/utils/validation.py
@@ -50,8 +50,12 @@
     "endswith": (lambda x, d: isinstance(x, str) and (isinstance(d, str) or bool(str(d))) and x.endswith(d)),
     # file and folder
     "file exists": (
-        lambda x, _: isinstance(x, str)
-        and (VALIDATIONS["file exists absolute"](x, _) or VALIDATIONS["file exists in project"](x, _))
+        lambda x, f: isinstance(x, str)
+        and (
+            VALIDATIONS["file exists absolute"](x, None)
+            or VALIDATIONS["file exists in project"](x, None)
+            or VALIDATIONS["file exists in folder"](x, f)
+        )
     ),
     "file exists absolute": (lambda x, _: isinstance(x, str) and os.path.isfile(x)),
     "file exists in project": (lambda x, _: isinstance(x, str) and os.path.isfile(os.path.join(PROJECT_ROOT, x))),
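
The reworked "file exists" validator now forwards its second argument to the "file exists in folder" check, which is what the ("file exists", "./weights/") entry in alpha_validations relies on. A small usage sketch, assuming VALIDATIONS is importable from dgs.utils.validation and that the file name is illustrative:

from dgs.utils.validation import VALIDATIONS

# True if "fc_alpha.pth" exists as an absolute path, relative to
# PROJECT_ROOT, or inside the "./weights/" folder; False otherwise.
ok = VALIDATIONS["file exists"]("fc_alpha.pth", "./weights/")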
24 changes: 24 additions & 0 deletions tests/models/alpha/test__alpha.py
@@ -3,6 +3,7 @@
 
 import torch as t
 
+from dgs.models.alpha import FullyConnectedAlpha
 from dgs.models.alpha.alpha import BaseAlphaModule
 from dgs.utils.config import insert_into_config
 from helper import get_test_config
@@ -33,6 +34,29 @@ def test_sub_forward(self):
         m.model = t.nn.Identity()
         self.assertTrue(t.allclose(m.sub_forward(ones), ones))
 
+    @patch.multiple(BaseAlphaModule, __abstractmethods__=set())
+    def test_init_weights_empty(self):
+        m = FullyConnectedAlpha(config=self.default_cfg, path=self.default_path)
+        self.assertTrue(isinstance(m.model, t.nn.Module))
+        self.assertTrue(len(list(m.model.parameters())) > 0)
+
+    @patch.multiple(BaseAlphaModule, __abstractmethods__=set())
+    def test_init_weights(self):
+        cfg = insert_into_config(
+            path=self.default_path,
+            value={"weight": "./tests/test_data/weights/fully_connected_alpha.pth"},
+            original=self.default_cfg.copy(),
+        )
+        m = FullyConnectedAlpha(config=cfg, path=self.default_path)
+
+        w = t.load("./tests/test_data/weights/fully_connected_alpha.pth")
+
+        self.assertEqual(len(m.model.state_dict()), len(w))
+        self.assertEqual(m.model.state_dict().keys(), w.keys())
+
+        for msd, wsd in zip(m.model.state_dict().values(), w.values()):
+            self.assertTrue(t.allclose(msd, wsd))
 
 
 if __name__ == "__main__":
     unittest.main()
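
The new tests expect a state dict on disk at ./tests/test_data/weights/fully_connected_alpha.pth; t.load() on that file must return a mapping comparable against model.state_dict(). How the fixture was produced is not part of this commit; one plausible way to generate such a file, with a placeholder architecture that is purely an assumption:

import torch as t

# Persist only the state dict: t.load() then yields an ordered mapping of
# parameter names to tensors, matching what the tests compare against.
model = t.nn.Linear(2, 1)  # placeholder, not the real FullyConnectedAlpha layout
t.save(model.state_dict(), "./tests/test_data/weights/fully_connected_alpha.pth")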
