diff --git a/grace/evaluation/dim_reduction.py b/grace/evaluation/dim_reduction.py new file mode 100644 index 0000000..cc96a0d --- /dev/null +++ b/grace/evaluation/dim_reduction.py @@ -0,0 +1,153 @@ +import torch + +from pathlib import Path + +import matplotlib.pyplot as plt +import networkx as nx + +import numpy.typing as npt +import matplotlib + +from sklearn.manifold import TSNE + +from grace.models.datasets import dataset_from_graph + +from grace.logger import LOGGER + + +def drop_linear_layers_from_model( + model: torch.nn.Module, +) -> torch.nn.Sequential: + """Chops off last 2 Linear layers from the classifier to + access node embeddings learnt by the GCN classifier.""" + + modules = list(model.children())[:-2] + node_emb_extractor = torch.nn.Sequential(*modules) + for p in node_emb_extractor.parameters(): + p.requires_grad = False + + return node_emb_extractor + + +class TSNEDimensionalityReduction(object): + def __init__( + self, + graph: nx.Graph, + model: str | Path, + ) -> None: + self.graph = graph + self.model = model + + if self.model is not None: + self.model = Path(self.model) + assert self.model.is_file() + + def read_graph_dataset_IO(self) -> tuple[torch.stack]: + # Prepare GT labels: + dataset_batches = dataset_from_graph( + graph=self.graph, mode="whole", in_train_mode=False + ) + dataset_batches = dataset_batches[0] + + # Prepare the data: + node_labels = dataset_batches.y + node_embeds = dataset_batches.x + edge_indices = dataset_batches.edge_index + + return node_labels, node_embeds, edge_indices + + def extract_GCN_node_embeddings(self) -> tuple[torch.stack]: + node_labels, node_embeds, edge_indices = self.read_graph_dataset_IO() + + if self.model is None: + LOGGER.info( + "Warning, only returning the 'node_embeddings' as" + "no pre-trained GCN model was specified..." + ) + + else: + # Log the classifier time-stamp name: + name = self.model.parent.name + LOGGER.info( + f"Processing the model time-stamp: '{name}/classifier.py'" + ) + + # Load the model & drop the `Linear` layers: + full_gcn_classifier = torch.load(self.model) + gcn_only_classifier = drop_linear_layers_from_model( + full_gcn_classifier + ) + + # If only `Linear` model, log the warning: + if len(gcn_only_classifier) < 1: + LOGGER.info( + "Warning, only returning the 'node_embeddings' as " + "the GCN contains no graph convolutional layers..." + ) + + # Get the GCN node embeddings: + else: + # Prep the model & modify embeddings in-place: + gcn_only_classifier.eval() + for module in gcn_only_classifier[0]: + node_embeds = module(node_embeds, edge_indices) + + # Log the shapes: + LOGGER.info( + "Extracted 'node_embeddings' -> " + f"{node_embeds.shape}, {node_embeds.dtype}" + ) + + return node_labels, node_embeds + + def perform_and_plot_tsne( + self, + node_GT_label: npt.NDArray, + node_features: npt.NDArray, + *, + n_components: int = 2, + title: str = "", + ax: matplotlib.axes = None, + ) -> matplotlib.axes: + # Shapes must agree: + assert len(node_GT_label) == len(node_features) + tsne = TSNE(n_components=n_components) + node_embed = tsne.fit_transform(X=node_features) + + # Plot the TSNE manifold: + title = f"TSNE of Patch Features\n{title}" + umap1, umap2 = node_embed[:, 0], node_embed[:, 1] + scatter = ax.scatter( + x=umap1, y=umap2, c=node_GT_label, cmap="coolwarm" + ) + cbar = plt.colorbar(scatter) + cbar.ax.get_yaxis().labelpad = 15 + cbar.ax.set_ylabel("Ground Truth Node Label", rotation=270) + ax.set_xlabel("UMAP 1") + ax.set_ylabel("UMAP 2") + ax.set_title(title) + return ax + + def plot_TSNE_before_and_after_GCN(self, **kwargs) -> None: + # Plot the subplots: + size = 5 + _, axes = plt.subplots(1, 2, figsize=(size * 2 + 2, size * 1)) + + # Get the embeddings: + for p, (plot_name, method) in enumerate( + zip( + ["Before", "After"], + [self.read_graph_dataset_IO, self.extract_GCN_node_embeddings], + ) + ): + labels, embeds = method()[:2] + shape = embeds.shape[-1] + title = f"{plot_name} GCN | Node Feature Embedding [{shape}]" + self.perform_and_plot_tsne( + labels, embeds, title=title, ax=axes[p], **kwargs + ) + + # Annotate & display: + plt.tight_layout() + plt.show() + plt.close() diff --git a/grace/evaluation/inference.py b/grace/evaluation/inference.py index 13dff5c..8b0a6e7 100644 --- a/grace/evaluation/inference.py +++ b/grace/evaluation/inference.py @@ -1,6 +1,8 @@ import networkx as nx import numpy as np +from pathlib import Path + import torch from torch_geometric.data import Data from tqdm.auto import tqdm @@ -13,14 +15,23 @@ accuracy_metric, areas_under_curves_metrics, ) +from grace.evaluation.visualisation import ( + visualise_prediction_probs_hist, + visualise_node_and_edge_probabilities, +) class GraphLabelPredictor(object): """TODO: Fill in.""" - def __init__(self, model: torch.nn.Module) -> None: + def __init__(self, model: str | torch.nn.Module) -> None: super().__init__() + if isinstance(model, str): + assert Path(model).is_file() + model = torch.load(model) + + model.eval() self.pretrained_gcn = model def set_node_and_edge_probabilities(self, G: nx.Graph): @@ -40,16 +51,22 @@ def set_node_and_edge_probabilities(self, G: nx.Graph): prediction = (int(e_pred[e_idx].item()), e_probabs[e_idx].numpy()) edge[-1][GraphAttrs.EDGE_PREDICTION] = prediction - def visualise_performance(self, G: nx.Graph): + def visualise_model_performance_on_graph( + self, G: nx.Graph, positive_class: int = 1 + ): # Prep the data & plot them: - node_true = [ - node[GraphAttrs.NODE_GROUND_TRUTH] - for _, node in G.nodes(data=True) - ] - node_pred = [ - node[GraphAttrs.NODE_PREDICTION][0] - for _, node in G.nodes(data=True) - ] + node_true = np.array( + [ + node[GraphAttrs.NODE_GROUND_TRUTH] + for _, node in G.nodes(data=True) + ] + ) + node_pred = np.array( + [ + node[GraphAttrs.NODE_PREDICTION][0] + for _, node in G.nodes(data=True) + ] + ) node_probabs = np.array( [ node[GraphAttrs.NODE_PREDICTION][1] @@ -57,14 +74,18 @@ def visualise_performance(self, G: nx.Graph): ] ) - edge_true = [ - edge[GraphAttrs.EDGE_GROUND_TRUTH] - for _, _, edge in G.edges(data=True) - ] - edge_pred = [ - edge[GraphAttrs.EDGE_PREDICTION][0] - for _, _, edge in G.edges(data=True) - ] + edge_true = np.array( + [ + edge[GraphAttrs.EDGE_GROUND_TRUTH] + for _, _, edge in G.edges(data=True) + ] + ) + edge_pred = np.array( + [ + edge[GraphAttrs.EDGE_PREDICTION][0] + for _, _, edge in G.edges(data=True) + ] + ) edge_probabs = np.array( [ edge[GraphAttrs.EDGE_PREDICTION][1] @@ -72,13 +93,41 @@ def visualise_performance(self, G: nx.Graph): ] ) + # Unify the inputs - get the predictions scores for TP class: + filter_mask = np.logical_or( + node_true == 0, node_true == positive_class + ) + node_true = node_true[filter_mask] + node_pred = node_pred[filter_mask] + node_probabs = node_probabs[filter_mask] + + filter_mask = np.logical_or( + edge_true == 0, edge_true == positive_class + ) + edge_true = edge_true[filter_mask] + edge_pred = edge_pred[filter_mask] + edge_probabs = edge_probabs[filter_mask] + + # Compute & return accuracy: node_acc, edge_acc = accuracy_metric( node_pred, edge_pred, node_true, edge_true ) + + # Display metrics figures: + plot_confusion_matrix_tiles(node_pred, edge_pred, node_true, edge_true) + areas_under_curves_metrics( - node_probabs, edge_probabs, node_true, edge_true, figsize=(10, 4) + node_probabs[:, positive_class], + edge_probabs[:, positive_class], + node_true, + edge_true, + figsize=(10, 4), ) - plot_confusion_matrix_tiles(node_pred, edge_pred, node_true, edge_true) + + # Localise where the errors occur: + visualise_prediction_probs_hist(G=G) + visualise_node_and_edge_probabilities(G=G) + return node_acc, edge_acc diff --git a/grace/evaluation/metrics_classifier.py b/grace/evaluation/metrics_classifier.py index ec9bf6a..68ab00d 100644 --- a/grace/evaluation/metrics_classifier.py +++ b/grace/evaluation/metrics_classifier.py @@ -94,18 +94,11 @@ def areas_under_curves_metrics( edge_true: torch.tensor, figsize: tuple[int] = (20, 7), ) -> tuple[plt.figure]: - # Unify the inputs - get the predictions scores for TP class: - if node_pred.shape[-1] == 2: - node_pred = node_pred[:, 1] - if edge_pred.shape[-1] == 2: - edge_pred = edge_pred[:, 1] - # Instantiate the figure _, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize) # Area under ROC: roc_score_nodes = roc_auc_score(y_true=node_true, y_score=node_pred) - # rcd_nodes = RocCurveDisplay.from_predictions( RocCurveDisplay.from_predictions( y_true=node_true, y_pred=node_pred, @@ -116,7 +109,6 @@ def areas_under_curves_metrics( ) roc_score_edges = roc_auc_score(y_true=edge_true, y_score=edge_pred) - # rcd_edges = RocCurveDisplay.from_predictions( RocCurveDisplay.from_predictions( y_true=edge_true, y_pred=edge_pred, @@ -130,7 +122,6 @@ def areas_under_curves_metrics( prc_score_nodes = average_precision_score( y_true=node_true, y_score=node_pred ) - # prc_nodes = PrecisionRecallDisplay.from_predictions( PrecisionRecallDisplay.from_predictions( y_true=node_true, y_pred=node_pred, @@ -143,7 +134,6 @@ def areas_under_curves_metrics( prc_score_edges = average_precision_score( y_true=edge_true, y_score=edge_pred ) - # prc_edges = PrecisionRecallDisplay.from_predictions( PrecisionRecallDisplay.from_predictions( y_true=edge_true, y_pred=edge_pred, @@ -154,7 +144,6 @@ def areas_under_curves_metrics( ) # Annotate the figure: - # axes[0].plot([0, 0], [1, 1], ls="dashed", lw=1, color="lightgrey") axes[0].plot([0, 1], [0, 1], ls="dashed", lw=1, color="lightgrey") axes[1].plot([0, 1], [0.5, 0.5], ls="dashed", lw=1, color="lightgrey") axes[1].plot([0.5, 0.5], [0, 1], ls="dashed", lw=1, color="lightgrey") @@ -162,8 +151,6 @@ def areas_under_curves_metrics( axes[0].set_title("Area under ROC") axes[1].set_title("Average Precision Score") plt.tight_layout() - # plt.show() - return axes diff --git a/grace/evaluation/visualisation.py b/grace/evaluation/visualisation.py index f2f3a59..2735186 100644 --- a/grace/evaluation/visualisation.py +++ b/grace/evaluation/visualisation.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import networkx as nx import numpy as np +import matplotlib import numpy.typing as npt @@ -12,23 +13,19 @@ def plot_simple_graph( - # G: nx.Graph, title: str = "", figsize: tuple[int, int] = (16, 16), G: nx.Graph, title: str = "", - ax=None, + ax: matplotlib.axes = None, ) -> None: """Plots a simple graph with black nodes and edges.""" - # Fancy annotation plot - # _, ax = plt.subplots(figsize=figsize) - - # node positions + # Read node positions: pos = { idx: (node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y]) for idx, node in G.nodes(data=True) } - # draw all nodes/vertices in the graph, including noisy nodes + # Draw all nodes/vertices in the graph, including noisy nodes: nx.draw_networkx( G, ax=ax, @@ -40,33 +37,33 @@ def plot_simple_graph( ) ax.set_title(f"{title}") - # plt.show() return ax def plot_connected_components( - # G: nx.Graph, title: str = "", figsize: tuple[int, int] = (16, 16) G: nx.Graph, title: str = "", - ax=None, + ax: matplotlib.axes = None, ) -> None: """Colour-codes the connected components (individual objects) & plots them onto a simple graph with black nodes & edges. Connected component (subgraph) must contain at least one edge. """ - # Fancy annotation plot - # _, ax = plt.subplots(figsize=figsize) - - # node positions + # Read node positions: pos = { idx: (node[GraphAttrs.NODE_X], node[GraphAttrs.NODE_Y]) for idx, node in G.nodes(data=True) } - # draw all nodes/vertices in the graph, including noisy nodes + # Draw all nodes/vertices in the graph, including noisy nodes: nx.draw_networkx( - G, ax=ax, pos=pos, with_labels=False, node_color="k", node_size=32 + G, + ax=ax, + pos=pos, + with_labels=False, + node_color="k", + node_size=32, ) # get each connected subgraph and draw it with a different colour @@ -81,16 +78,20 @@ def plot_connected_components( sg = G.subgraph(sg).copy() nx.draw_networkx( - sg, ax=ax, pos=pos, edge_color=c_idx, node_color=c_idx + sg, + ax=ax, + pos=pos, + edge_color=c_idx, + node_color=c_idx, ) ax.set_title(f"{title}") - # plt.show() return ax def display_image_and_grace_annotation( - image: npt.NDArray, target: dict[str] + image: npt.NDArray, + target: dict[str], ) -> None: """Overlays the annotation image (binary mask) with annotated graph, colour-coding the true positive (TP), true negative (TN), and @@ -158,8 +159,7 @@ def display_image_and_grace_annotation( ax.imshow(annotation, cmap=plt.cm.turbo, interpolation="none") - # draw all nodes/vertices in the graph, including those not determined to be - # part of the objects + # draw all nodes/vertices in the graph: nx.draw_networkx( graph, ax=ax, @@ -171,6 +171,125 @@ def display_image_and_grace_annotation( ) ax.set_title(f"{target['metadata']['image_filename']}\n{node_GT_counter}") + plt.show() + plt.close() + + +def visualise_prediction_probs_hist(G: nx.Graph) -> None: + """Plot the prediction probabilities colour-coded by their GT label.""" + + # Process the true & pred values: + n_true, n_pred = [], [] + for _, node in G.nodes(data=True): + n_pred.append(node[GraphAttrs.NODE_PREDICTION][1][1]) + n_true.append(node[GraphAttrs.NODE_GROUND_TRUTH]) + + e_true, e_pred = [], [] + for _, _, edge in G.edges(data=True): + e_pred.append(edge[GraphAttrs.EDGE_PREDICTION][1][1]) + e_true.append(edge[GraphAttrs.EDGE_GROUND_TRUTH]) + + # Plot the node & edge histogram by label: + _, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 4)) + for i, (pred, true, att) in enumerate( + zip([n_pred, e_pred], [n_true, e_true], ["nodes", "edges"]) + ): + for lab_idx in np.unique(true): + preds = [p for p, t in zip(pred, true) if t == lab_idx] + axes[i].hist( + preds, alpha=0.7, label=f"GT = {lab_idx} | {len(preds)} {att}" + ) + axes[i].set_title(f"Inferred predictions -> TP {att}") + axes[i].set_xlabel("Predicted softmax probability") + axes[i].legend() + + axes[0].set_ylabel("Attribute count") + plt.show() + plt.close() + + +def visualise_node_and_edge_probabilities(G: nx.Graph) -> None: + """Visualise per-node & per-edge predictions on color-coded + graph of TP attribute probabilities independently for + nodes, independently for edges & in overlay of both. + """ + + # Create a figure and axes + nrows, ncols = 1, 3 + _, axes = plt.subplots(nrows, ncols, figsize=(15, 4)) + cmap = plt.cm.ScalarMappable(cmap="coolwarm") + + # JUST THE NODES: + nodes = list(G.nodes(data=True)) + x_coords = [node[GraphAttrs.NODE_X] for _, node in nodes] + y_coords = [node[GraphAttrs.NODE_Y] for _, node in nodes] + node_preds = [node[GraphAttrs.NODE_PREDICTION][1][1] for _, node in nodes] + + # Plot nodes: + axes[0].scatter( + x=x_coords, + y=y_coords, + c=node_preds, + cmap="coolwarm", + vmin=0.0, + vmax=1.0, + ) + axes[2].scatter( + x=x_coords, + y=y_coords, + c=node_preds, + cmap="coolwarm", + vmin=0.0, + vmax=1.0, + ) + + # Add colorbar: + cbar = plt.colorbar(cmap, ax=axes[0]) + cbar.set_label("Node Probability") + + # JUST THE EDGES: + for src, dst, edge in G.edges(data=True): + e_st_x, e_st_y = ( + nodes[src][1][GraphAttrs.NODE_X], + nodes[src][1][GraphAttrs.NODE_Y], + ) + e_en_x, e_en_y = ( + nodes[dst][1][GraphAttrs.NODE_X], + nodes[dst][1][GraphAttrs.NODE_Y], + ) + edge_pred = edge[GraphAttrs.EDGE_PREDICTION][1][1] + + axes[1].plot( + [e_st_x, e_en_x], + [e_st_y, e_en_y], + color=cmap.to_rgba(edge_pred), + marker="", + ) + axes[2].plot( + [e_st_x, e_en_x], + [e_st_y, e_en_y], + color=cmap.to_rgba(edge_pred), + marker="", + ) + + # Add colorbar + cbar = plt.colorbar(cmap, ax=axes[1]) + cbar.set_label("Edge Probability") + + # Annotate & display: + cbar = plt.colorbar(cmap, ax=axes[2]) + cbar.set_label("TP Probability") + + axes[0].set_title("Probability of 'nodeness'") + axes[1].set_title("Probability of 'edgeness'") + axes[2].set_title("Merged graph predictions") + + [axes[i].get_xaxis().set_visible(False) for i in range(ncols)] + [axes[i].get_yaxis().set_visible(False) for i in range(ncols)] + + plt.tight_layout() + plt.show() + plt.close() def read_patch_stack_by_label( @@ -259,6 +378,7 @@ def montage_from_image_patches(crops: list[npt.NDArray]) -> None: plt.title(f"Montage of patches\nwith 'node_label' = {c}") plt.axis("off") plt.show() + plt.close() def overlay_from_image_patches(crops: list[npt.NDArray]) -> None: @@ -281,3 +401,4 @@ def overlay_from_image_patches(crops: list[npt.NDArray]) -> None: plt.title(f"Montage of patches\nwith 'node_label' = {c}") plt.axis("off") plt.show() + plt.close() diff --git a/grace/io/image_dataset.py b/grace/io/image_dataset.py index da623ca..6f777aa 100644 --- a/grace/io/image_dataset.py +++ b/grace/io/image_dataset.py @@ -7,10 +7,10 @@ import cv2 import tifffile import mrcfile -import logging from grace.io import read_graph from grace.base import GraphAttrs, Annotation +from grace.logger import LOGGER import torch from torch.utils.data import Dataset @@ -94,7 +94,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]: # Print original graph label statistics: if self.verbose is True: - logging.info(img_path.stem) + LOGGER.info(img_path.stem) log_graph_label_statistics(graph) # Relabel Annotation.UNKNOWN in nodes: @@ -111,7 +111,7 @@ def __getitem__(self, idx: int) -> tuple[torch.Tensor, dict]: self.keep_node_unknown_labels is False or self.keep_edge_unknown_labels is False ): - logging.info("Relabelled 'Annotation.UNKNOWN'") + LOGGER.info("Relabelled 'Annotation.UNKNOWN'") log_graph_label_statistics(graph) # Package together: @@ -164,7 +164,7 @@ def log_graph_label_statistics(G: nx.Graph) -> None: perc = [item / np.sum(counter) for item in counter] perc = [float("%.2f" % (elem * 100)) for elem in perc] string = f"{attribute.capitalize()} count | {counter} x | {perc} %" - logging.info(string) + LOGGER.info(string) def mrc_reader(fn: os.PathLike) -> npt.NDArray: diff --git a/grace/logger.py b/grace/logger.py new file mode 100644 index 0000000..894a9e1 --- /dev/null +++ b/grace/logger.py @@ -0,0 +1,8 @@ +import logging + +LOGGER = logging +LOGGER.basicConfig( + level=logging.INFO, + format="%(asctime)s %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", +) diff --git a/grace/training/run.py b/grace/training/run.py index 14b9872..0d9c37b 100644 --- a/grace/training/run.py +++ b/grace/training/run.py @@ -3,11 +3,11 @@ import os import click import torch -import logging from datetime import datetime from tqdm.auto import tqdm +from grace.logger import LOGGER from grace.io.image_dataset import ImageGraphDataset from grace.training.train import train_model from grace.models.datasets import dataset_from_graph @@ -16,12 +16,6 @@ from grace.training.config import write_config_file, load_config_params from grace.utils.transforms import get_transforms -logging.basicConfig( - level=logging.INFO, - format="%(asctime)s %(message)s", - datefmt="%m/%d/%Y %I:%M:%S %p", -) - def run_grace(config_file: Union[str, os.PathLike]) -> None: """Runs the GRACE pipeline; going straight from images and .grace annotations @@ -68,7 +62,7 @@ def transform(img, grph): input_data, desc="Extracting patch features from training data... " ): file_name = target["metadata"]["image_filename"] - logging.info(f"Processing file: {file_name}") + LOGGER.info(f"Processing file: {file_name}") graph_dataset = dataset_from_graph(target["graph"], mode="sub") dataset.extend(graph_dataset) diff --git a/grace/training/train.py b/grace/training/train.py index a414065..a9f6a03 100644 --- a/grace/training/train.py +++ b/grace/training/train.py @@ -3,7 +3,6 @@ import random import torch import torch_geometric -import logging import matplotlib.pyplot as plt @@ -11,6 +10,7 @@ from grace.base import Annotation from grace.evaluation.metrics_classifier import get_metric +from grace.logger import LOGGER from torch.utils.tensorboard import SummaryWriter @@ -213,7 +213,7 @@ def valid(loader): ) # Print out the logging string: - logging.info(logger_string) + LOGGER.info(logger_string) writer.flush() writer.close() diff --git a/notebooks/infer_predictions.ipynb b/notebooks/infer_predictions.ipynb index a662e27..0cec4bb 100644 --- a/notebooks/infer_predictions.ipynb +++ b/notebooks/infer_predictions.ipynb @@ -14,7 +14,6 @@ "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", - "\n", "import torch" ] }, @@ -26,8 +25,15 @@ "source": [ "from grace.io.image_dataset import ImageGraphDataset\n", "from grace.models.feature_extractor import FeatureExtractor\n", + "\n", "from grace.evaluation.visualisation import plot_simple_graph\n", - "from grace.evaluation.inference import GraphLabelPredictor\n" + "from grace.evaluation.inference import GraphLabelPredictor\n", + "from grace.evaluation.dim_reduction import TSNEDimensionalityReduction\n", + "from grace.evaluation.visualisation import (\n", + " read_patch_stack_by_label, \n", + " montage_from_image_patches,\n", + " overlay_from_image_patches,\n", + ")" ] }, { @@ -43,9 +49,10 @@ "metadata": {}, "outputs": [], "source": [ + "bbox_size = (224, 224)\n", "extractor_filename = \"/Users/kulicna/Desktop/classifier/extractor/resnet152.pt\"\n", "pre_trained_resnet = torch.load(extractor_filename)\n", - "feature_extractor = FeatureExtractor(model=pre_trained_resnet)\n" + "feature_extractor = FeatureExtractor(model=pre_trained_resnet, bbox_size=bbox_size)\n" ] }, { @@ -56,13 +63,13 @@ "source": [ "grace_path = \"/Users/kulicna/Desktop/dataset/shape_stars/train\"\n", "# grace_path = \"/Users/kulicna/Desktop/dataset/shape_stars/infer\"\n", + "\n", "dataset = ImageGraphDataset(\n", " image_dir=grace_path, \n", " grace_dir=grace_path, \n", " transform=feature_extractor,\n", - " keep_node_unknown_labels=False, \n", - " keep_edge_unknown_labels=False, \n", - " \n", + " keep_node_unknown_labels=True, \n", + " keep_edge_unknown_labels=True, \n", ")" ] }, @@ -106,6 +113,17 @@ "plt.show()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "crops = read_patch_stack_by_label(G, image=image, crop_shape=bbox_size)\n", + "montage_from_image_patches(crops)\n", + "overlay_from_image_patches(crops)" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -119,20 +137,33 @@ "metadata": {}, "outputs": [], "source": [ - "# classifier_filename = \"/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-07-08/classifier.pt\"\n", - "# classifier_filename = \"/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-15-47/classifier.pt\"\n", - "classifier_filename = \"/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-30-51/classifier.pt\" # best Linear classifier\n", "# classifier_filename = \"/Users/kulicna/Desktop/classifier/runs/2023-09-08_15-11-58/classifier.pt\" # bad GCN + Linear classifier\n", - "\n", - "pre_trained_gcn = torch.load(classifier_filename)\n", - "pre_trained_gcn.eval()\n" + "classifier_filename = \"/Users/kulicna/Desktop/classifier/runs/2023-09-07_17-30-51/classifier.pt\" # best Linear classifier\n", + "classifier_filename\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Perform TSNE before & after GCN:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dim_red = TSNEDimensionalityReduction(graph=G, model=classifier_filename)\n", + "dim_red.plot_TSNE_before_and_after_GCN()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Features are now automatically appended to the image - predict:" + "### Show how well the classifier performs:" ] }, { @@ -141,9 +172,9 @@ "metadata": {}, "outputs": [], "source": [ - "predictor = GraphLabelPredictor(pre_trained_gcn)\n", - "predictor.set_node_and_edge_probabilities(G)\n", - "node_acc, edge_acc = predictor.visualise_performance(G)\n", + "GLP = GraphLabelPredictor(model=classifier_filename)\n", + "GLP.set_node_and_edge_probabilities(G=G)\n", + "node_acc, edge_acc = GLP.visualise_model_performance_on_graph(G=G)\n", "print(f\"Node accuracy = {node_acc:.4f} | Edge accuracy = {edge_acc:.4f}\")" ] }, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 78f7abc..7b34b7d 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,6 +1,3 @@ -# import networkx as nx -# import torch - import pytest from grace.io.core import GraphAttrs, Annotation