Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed May 13, 2022
1 parent 8f49289 commit f09f687
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions torch_geometric/graphgym/model_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
import typing
import warnings
from typing import Any, Dict, Tuple

Expand All @@ -11,9 +10,6 @@
from torch_geometric.graphgym.optim import create_optimizer, create_scheduler
from torch_geometric.graphgym.register import network_dict, register_network

if typing.TYPE_CHECKING:
from yacs.config import CfgNode

try:
from pytorch_lightning import LightningModule
except ImportError:
Expand All @@ -25,7 +21,7 @@


class GraphGymModule(LightningModule):
def __init__(self, dim_in, dim_out, cfg: "CfgNode"):
def __init__(self, dim_in, dim_out, cfg):
super().__init__()
self.cfg = cfg
self.model = network_dict[cfg.model.type](dim_in=dim_in,
Expand All @@ -48,13 +44,13 @@ def _shared_step(self, batch, split: str) -> Dict:
step_end_time=step_end_time)

def training_step(self, batch, *args, **kwargs):
return self._shared_step(batch, "train")
return self._shared_step(batch, split="train")

def validation_step(self, batch, *args, **kwargs):
return self._shared_step(batch, "val")
return self._shared_step(batch, split="val")

def test_step(self, batch, *args, **kwargs):
return self._shared_step(batch, "test")
return self._shared_step(batch, split="test")

@property
def encoder(self) -> torch.nn.Module:
Expand All @@ -74,8 +70,7 @@ def pre_mp(self) -> torch.nn.Module:


def create_model(to_device=True, dim_in=None, dim_out=None):
r"""
Create model for graph machine learning
r"""Create model for graph machine learning.
Args:
to_device (string): The devide that the model will be transferred to
Expand Down

0 comments on commit f09f687

Please sign in to comment.