diff --git a/flash/graph/GraphClassification/data.py b/flash/graph/GraphClassification/data.py index ac985486387..a02e865ea54 100644 --- a/flash/graph/GraphClassification/data.py +++ b/flash/graph/GraphClassification/data.py @@ -33,7 +33,7 @@ class BasicGraphDataset(Dataset): ''' - Probably unnecessary having the following class. + #todo: Probably unnecessary having the following class. ''' def __init__(self, root = None, processed_dir = 'processed', raw_dir = 'raw', transform=None, pre_transform=None, pre_filter=None): @@ -94,6 +94,10 @@ def __init__( if self.has_labels: self.label_to_class_mapping = {v: k for k, v in enumerate(list(sorted(list(set(self.fnames)))))} + @property + def has_dict_labels(self) -> bool: + return isinstance(self.labels, dict) + @property def has_labels(self) -> bool: return self.labels is not None @@ -105,6 +109,10 @@ def __getitem__(self, index: int) -> Tuple[Any, Optional[int]]: filename = self.fnames[index] graph = self.loader(filename) label = None + if self.has_dict_labels: + name = os.path.splitext(filename)[0] + name = os.path.basename(name) + label = self.labels[name] if self.has_labels: label = self.label_to_class_mapping[filename] return graph, label @@ -136,7 +144,7 @@ class FlashDatasetFolder(torch.utils.data.Dataset): with_targets: Whether to include targets graph_paths: List of graph paths to load. Only used when ``with_targets=False`` - Attributes: + Attributes: classes (list): List of the class names sorted alphabetically. class_to_idx (dict): Dict with items (class_name, class_index). samples (list): List of (sample path, class_index) tuples @@ -147,7 +155,7 @@ def __init__( self, root: str, loader: Callable, - extensions: Tuple[str] = Graph_EXTENSIONS, + extensions: Tuple[str] = Graph_EXTENSIONS, #todo: Graph_EXTENSIONS is not defined. In PyG the extension .pt is used transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, is_valid_file: Optional[Callable] = None, @@ -175,7 +183,7 @@ def __init__( else: if not graph_paths: raise MisconfigurationException( - "`FlashDatasetFolder(with_target=False)` but no `img_paths` were provided" + "`FlashDatasetFolder(with_target=False)` but no `graph_paths` were provided" ) self.samples = graph_paths @@ -281,6 +289,16 @@ def from_filepaths( >>> _data = GraphClassificationData.from_filepaths(["a.pt", "b.pt"], [0, 1]) # doctest: +SKIP """ + + # enable passing in a string which loads all files in that folder as a list + if isinstance(train_filepaths, str): + train_filepaths = [os.path.join(train_filepaths, x) for x in os.listdir(train_filepaths)] + if isinstance(valid_filepaths, str): + valid_filepaths = [os.path.join(valid_filepaths, x) for x in os.listdir(valid_filepaths)] + if isinstance(test_filepaths, str): + test_filepaths = [os.path.join(test_filepaths, x) for x in os.listdir(test_filepaths)] + + train_ds = FilepathDataset( filepaths=train_filepaths, labels=train_labels, diff --git a/flash/graph/GraphClassification/model.py b/flash/graph/GraphClassification/model.py index 691436e0a3d..225d4f4222e 100644 --- a/flash/graph/GraphClassification/model.py +++ b/flash/graph/GraphClassification/model.py @@ -14,6 +14,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union, Mapping, Sequence, Union import torch +import pytorch_lightning as pl from pytorch_lightning.metrics import Accuracy from torch import nn from torch.nn import functional as F @@ -21,6 +22,7 @@ from torch_geometric.nn import GCNConv from torch_geometric.nn import global_mean_pool + from flash.core.classification import ClassificationTask from flash.core.data import DataPipeline @@ -47,25 +49,26 @@ def __init__( optimizer: Type[torch.optim.Optimizer] = torch.optim.Adam, metrics: Union[Callable, Mapping, Sequence, None] = [Accuracy()], learning_rate: float = 1e-3, + model: torch.nn.Module = None, ): if isinstance(hidden, int): hidden = [hidden] #sizes = [input_size] + hidden + [num_classes] + if model == None: + self.model = GCN(in_features = num_features, hidden_channels=hidden, out_features = num_classes) super().__init__( - model = GCN(in_features = num_features, hidden_channels=hidden, out_features = num_classes), + model = model, loss_fn=loss_fn, optimizer=optimizer, metrics=metrics, learning_rate=learning_rate, ) - #train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) - def forward(self, data) -> Any: - x = self.model(data.x, data.edge_index, data.batch) #This line is probably something to change + x = self.model(data.x, data.edge_index, data.batch) return self.head(x) @staticmethod @@ -73,10 +76,10 @@ def default_pipeline() -> ClassificationDataPipeline: return GraphClassificationData.default_pipeline() #Taken from https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=CN3sRVuaQ88l -class GCN(torch.nn.Module): +class GCN(pl.LightningModule): def __init__(self, num_features, hidden_channels, num_classes): - super(GCN, self).__init__() #I don't understand why we need to call super here with GCN as an argument - #torch.manual_seed(12345) + super(GCN, self).__init__() + torch.manual_seed(12345) self.conv1 = GCNConv(num_features, hidden_channels) self.conv2 = GCNConv(hidden_channels, hidden_channels) self.conv3 = GCNConv(hidden_channels, hidden_channels) @@ -96,5 +99,14 @@ def forward(self, x, edge_index, batch): # 3. Apply a final classifier x = F.dropout(x, p=0.5, training=self.training) x = self.lin(x) - - return x \ No newline at end of file + + return x + + def training_step(self, batch, batch_idx): #todo: is this needed? + x, y = batch + y_hat = self(x) + loss = F.cross_entropy(y_hat, y) + return loss + + def configure_optimizers(self): #todo: is this needed? + return torch.optim.Adam(self.parameters(), lr=0.02) \ No newline at end of file