Skip to content

Commit

Permalink
Merge pull request #229 from alan-turing-institute/dim-red-viz
Browse files Browse the repository at this point in the history
Graph visualisation & embedding dimensionality reduction
  • Loading branch information
KristinaUlicna authored Sep 12, 2023
2 parents f530eba + b85d6e5 commit bd91bdf
Show file tree
Hide file tree
Showing 10 changed files with 427 additions and 87 deletions.
153 changes: 153 additions & 0 deletions grace/evaluation/dim_reduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
import torch

from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx

import numpy.typing as npt
import matplotlib

from sklearn.manifold import TSNE

from grace.models.datasets import dataset_from_graph

from grace.logger import LOGGER


def drop_linear_layers_from_model(
model: torch.nn.Module,
) -> torch.nn.Sequential:
"""Chops off last 2 Linear layers from the classifier to
access node embeddings learnt by the GCN classifier."""

modules = list(model.children())[:-2]
node_emb_extractor = torch.nn.Sequential(*modules)
for p in node_emb_extractor.parameters():
p.requires_grad = False

return node_emb_extractor


class TSNEDimensionalityReduction(object):
def __init__(
self,
graph: nx.Graph,
model: str | Path,
) -> None:
self.graph = graph
self.model = model

if self.model is not None:
self.model = Path(self.model)
assert self.model.is_file()

def read_graph_dataset_IO(self) -> tuple[torch.stack]:
# Prepare GT labels:
dataset_batches = dataset_from_graph(
graph=self.graph, mode="whole", in_train_mode=False
)
dataset_batches = dataset_batches[0]

# Prepare the data:
node_labels = dataset_batches.y
node_embeds = dataset_batches.x
edge_indices = dataset_batches.edge_index

return node_labels, node_embeds, edge_indices

def extract_GCN_node_embeddings(self) -> tuple[torch.stack]:
node_labels, node_embeds, edge_indices = self.read_graph_dataset_IO()

if self.model is None:
LOGGER.info(
"Warning, only returning the 'node_embeddings' as"
"no pre-trained GCN model was specified..."
)

else:
# Log the classifier time-stamp name:
name = self.model.parent.name
LOGGER.info(
f"Processing the model time-stamp: '{name}/classifier.py'"
)

# Load the model & drop the `Linear` layers:
full_gcn_classifier = torch.load(self.model)
gcn_only_classifier = drop_linear_layers_from_model(
full_gcn_classifier
)

# If only `Linear` model, log the warning:
if len(gcn_only_classifier) < 1:
LOGGER.info(
"Warning, only returning the 'node_embeddings' as "
"the GCN contains no graph convolutional layers..."
)

# Get the GCN node embeddings:
else:
# Prep the model & modify embeddings in-place:
gcn_only_classifier.eval()
for module in gcn_only_classifier[0]:
node_embeds = module(node_embeds, edge_indices)

# Log the shapes:
LOGGER.info(
"Extracted 'node_embeddings' -> "
f"{node_embeds.shape}, {node_embeds.dtype}"
)

return node_labels, node_embeds

def perform_and_plot_tsne(
self,
node_GT_label: npt.NDArray,
node_features: npt.NDArray,
*,
n_components: int = 2,
title: str = "",
ax: matplotlib.axes = None,
) -> matplotlib.axes:
# Shapes must agree:
assert len(node_GT_label) == len(node_features)
tsne = TSNE(n_components=n_components)
node_embed = tsne.fit_transform(X=node_features)

# Plot the TSNE manifold:
title = f"TSNE of Patch Features\n{title}"
umap1, umap2 = node_embed[:, 0], node_embed[:, 1]
scatter = ax.scatter(
x=umap1, y=umap2, c=node_GT_label, cmap="coolwarm"
)
cbar = plt.colorbar(scatter)
cbar.ax.get_yaxis().labelpad = 15
cbar.ax.set_ylabel("Ground Truth Node Label", rotation=270)
ax.set_xlabel("UMAP 1")
ax.set_ylabel("UMAP 2")
ax.set_title(title)
return ax

def plot_TSNE_before_and_after_GCN(self, **kwargs) -> None:
# Plot the subplots:
size = 5
_, axes = plt.subplots(1, 2, figsize=(size * 2 + 2, size * 1))

# Get the embeddings:
for p, (plot_name, method) in enumerate(
zip(
["Before", "After"],
[self.read_graph_dataset_IO, self.extract_GCN_node_embeddings],
)
):
labels, embeds = method()[:2]
shape = embeds.shape[-1]
title = f"{plot_name} GCN | Node Feature Embedding [{shape}]"
self.perform_and_plot_tsne(
labels, embeds, title=title, ax=axes[p], **kwargs
)

# Annotate & display:
plt.tight_layout()
plt.show()
plt.close()
89 changes: 69 additions & 20 deletions grace/evaluation/inference.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import networkx as nx
import numpy as np

from pathlib import Path

import torch
from torch_geometric.data import Data
from tqdm.auto import tqdm
Expand All @@ -13,14 +15,23 @@
accuracy_metric,
areas_under_curves_metrics,
)
from grace.evaluation.visualisation import (
visualise_prediction_probs_hist,
visualise_node_and_edge_probabilities,
)


class GraphLabelPredictor(object):
"""TODO: Fill in."""

def __init__(self, model: torch.nn.Module) -> None:
def __init__(self, model: str | torch.nn.Module) -> None:
super().__init__()

if isinstance(model, str):
assert Path(model).is_file()
model = torch.load(model)

model.eval()
self.pretrained_gcn = model

def set_node_and_edge_probabilities(self, G: nx.Graph):
Expand All @@ -40,45 +51,83 @@ def set_node_and_edge_probabilities(self, G: nx.Graph):
prediction = (int(e_pred[e_idx].item()), e_probabs[e_idx].numpy())
edge[-1][GraphAttrs.EDGE_PREDICTION] = prediction

def visualise_performance(self, G: nx.Graph):
def visualise_model_performance_on_graph(
self, G: nx.Graph, positive_class: int = 1
):
# Prep the data & plot them:
node_true = [
node[GraphAttrs.NODE_GROUND_TRUTH]
for _, node in G.nodes(data=True)
]
node_pred = [
node[GraphAttrs.NODE_PREDICTION][0]
for _, node in G.nodes(data=True)
]
node_true = np.array(
[
node[GraphAttrs.NODE_GROUND_TRUTH]
for _, node in G.nodes(data=True)
]
)
node_pred = np.array(
[
node[GraphAttrs.NODE_PREDICTION][0]
for _, node in G.nodes(data=True)
]
)
node_probabs = np.array(
[
node[GraphAttrs.NODE_PREDICTION][1]
for _, node in G.nodes(data=True)
]
)

edge_true = [
edge[GraphAttrs.EDGE_GROUND_TRUTH]
for _, _, edge in G.edges(data=True)
]
edge_pred = [
edge[GraphAttrs.EDGE_PREDICTION][0]
for _, _, edge in G.edges(data=True)
]
edge_true = np.array(
[
edge[GraphAttrs.EDGE_GROUND_TRUTH]
for _, _, edge in G.edges(data=True)
]
)
edge_pred = np.array(
[
edge[GraphAttrs.EDGE_PREDICTION][0]
for _, _, edge in G.edges(data=True)
]
)
edge_probabs = np.array(
[
edge[GraphAttrs.EDGE_PREDICTION][1]
for _, _, edge in G.edges(data=True)
]
)

# Unify the inputs - get the predictions scores for TP class:
filter_mask = np.logical_or(
node_true == 0, node_true == positive_class
)
node_true = node_true[filter_mask]
node_pred = node_pred[filter_mask]
node_probabs = node_probabs[filter_mask]

filter_mask = np.logical_or(
edge_true == 0, edge_true == positive_class
)
edge_true = edge_true[filter_mask]
edge_pred = edge_pred[filter_mask]
edge_probabs = edge_probabs[filter_mask]

# Compute & return accuracy:
node_acc, edge_acc = accuracy_metric(
node_pred, edge_pred, node_true, edge_true
)

# Display metrics figures:
plot_confusion_matrix_tiles(node_pred, edge_pred, node_true, edge_true)

areas_under_curves_metrics(
node_probabs, edge_probabs, node_true, edge_true, figsize=(10, 4)
node_probabs[:, positive_class],
edge_probabs[:, positive_class],
node_true,
edge_true,
figsize=(10, 4),
)
plot_confusion_matrix_tiles(node_pred, edge_pred, node_true, edge_true)

# Localise where the errors occur:
visualise_prediction_probs_hist(G=G)
visualise_node_and_edge_probabilities(G=G)

return node_acc, edge_acc


Expand Down
13 changes: 0 additions & 13 deletions grace/evaluation/metrics_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,18 +94,11 @@ def areas_under_curves_metrics(
edge_true: torch.tensor,
figsize: tuple[int] = (20, 7),
) -> tuple[plt.figure]:
# Unify the inputs - get the predictions scores for TP class:
if node_pred.shape[-1] == 2:
node_pred = node_pred[:, 1]
if edge_pred.shape[-1] == 2:
edge_pred = edge_pred[:, 1]

# Instantiate the figure
_, axes = plt.subplots(nrows=1, ncols=2, figsize=figsize)

# Area under ROC:
roc_score_nodes = roc_auc_score(y_true=node_true, y_score=node_pred)
# rcd_nodes = RocCurveDisplay.from_predictions(
RocCurveDisplay.from_predictions(
y_true=node_true,
y_pred=node_pred,
Expand All @@ -116,7 +109,6 @@ def areas_under_curves_metrics(
)

roc_score_edges = roc_auc_score(y_true=edge_true, y_score=edge_pred)
# rcd_edges = RocCurveDisplay.from_predictions(
RocCurveDisplay.from_predictions(
y_true=edge_true,
y_pred=edge_pred,
Expand All @@ -130,7 +122,6 @@ def areas_under_curves_metrics(
prc_score_nodes = average_precision_score(
y_true=node_true, y_score=node_pred
)
# prc_nodes = PrecisionRecallDisplay.from_predictions(
PrecisionRecallDisplay.from_predictions(
y_true=node_true,
y_pred=node_pred,
Expand All @@ -143,7 +134,6 @@ def areas_under_curves_metrics(
prc_score_edges = average_precision_score(
y_true=edge_true, y_score=edge_pred
)
# prc_edges = PrecisionRecallDisplay.from_predictions(
PrecisionRecallDisplay.from_predictions(
y_true=edge_true,
y_pred=edge_pred,
Expand All @@ -154,16 +144,13 @@ def areas_under_curves_metrics(
)

# Annotate the figure:
# axes[0].plot([0, 0], [1, 1], ls="dashed", lw=1, color="lightgrey")
axes[0].plot([0, 1], [0, 1], ls="dashed", lw=1, color="lightgrey")
axes[1].plot([0, 1], [0.5, 0.5], ls="dashed", lw=1, color="lightgrey")
axes[1].plot([0.5, 0.5], [0, 1], ls="dashed", lw=1, color="lightgrey")

axes[0].set_title("Area under ROC")
axes[1].set_title("Average Precision Score")
plt.tight_layout()
# plt.show()

return axes


Expand Down
Loading

0 comments on commit bd91bdf

Please sign in to comment.