Skip to content

Commit

Permalink
fixed merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
mmasden committed Jul 10, 2024
2 parents f714757 + f4a9ce2 commit 6c38d3a
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 125 deletions.
106 changes: 55 additions & 51 deletions modules/transforms/liftings/graph2hypergraph/mapper_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
SVDFeatureReduction,
ToUndirected,
)
from torch_geometric.utils import subgraph

from modules.transforms.liftings.graph2hypergraph.base import Graph2HypergraphLifting

Expand All @@ -21,16 +20,14 @@ class MapperCover:
resolution : int, optional
The number of intervals in the MapperCover. Default is 10.
gain : float, optional
The percentage of overlap between consectutive intervals
in MapperCover and should be value between 0 and 0.5.
The proportion of overlap between consecutive intervals
in the MapperCover; should be a value between 0 and 0.5.
Default is 0.3.
Attributes
----------
left_endpoints : (resolution, 1) Tensor
The left endpoints for each interval in the MapperCover.
right_endpoints : (resolution, 1) Tensor
The right endpoints for each interval in the MapperCover.
cover_intervals : (resolution, 2) Tensor
A tensor containing each interval in the MapperCover.
"""

def __init__(self, resolution=10, gain=0.3):
Expand Down Expand Up @@ -59,7 +56,8 @@ def fit_transform(self, filtered_data):
data_min = torch.min(filtered_data) - 1e-3
data_max = torch.max(filtered_data) + 1e-3
data_range = data_max - data_min
# width of each interval in the cover

# width of each interval in the cover and last left endpoint
cover_width = data_range / (self.resolution - (self.resolution - 1) * self.gain)
last_lower_endpoint = data_min + cover_width * (self.resolution - 1) * (
1 - self.gain
Expand All @@ -75,16 +73,19 @@ def fit_transform(self, filtered_data):
# want a n x resolution Boolean tensor
lower_values = torch.gt(filtered_data, lower_endpoints)
upper_values = torch.lt(filtered_data, upper_endpoints)

# need to check close values to deal with some endpoint issues
lower_is_close_values = torch.isclose(filtered_data, lower_endpoints)
upper_is_close_values = torch.isclose(filtered_data, upper_endpoints)

# construct the boolean mask
mask = torch.logical_and(
torch.logical_or(lower_values, lower_is_close_values),
torch.logical_or(upper_values, upper_is_close_values),
)
# assert every data point is covered
assert torch.all(torch.any(mask, 1)), f"{torch.any(mask,1)}"

# remove empty intervals from cover
non_empty_covers = torch.any(mask, 0)
return mask[:, non_empty_covers]
Expand Down Expand Up @@ -120,7 +121,7 @@ def _verify_cover_parameters(self):

class MapperLifting(Graph2HypergraphLifting):
r"""Lifts graphs to hypergraph domain using a Mapper construction for CC-pooling.
(See Figure 30 in [1])
(See Figure 30 in \[1\])
Parameters
----------
Expand All @@ -135,7 +136,7 @@ class MapperLifting(Graph2HypergraphLifting):
Default is 10.
gain : float, optional
The percentage of overlap between consecutive intervals
in MapperCover and should be value between 0 and 0.5.
in MapperCover and should be a value between 0 and 0.5.
Default is 0.3.
filter_func : object, optional
Filter function used for Mapper construction.
Expand All @@ -147,24 +148,41 @@ class MapperLifting(Graph2HypergraphLifting):
**kwargs : optional
Additional arguments for the class.
Attributes
----------
filtered_data : dict
Filtered data used to compute the Mapper lifting.
Dictionary is of the form
{filter_attr: filter_func(data)}.
cover : (n_sample, k) boolean Tensor
Mask computed from the MapperCover class
to compute the Mapper lifting, with k <= resolution
(empty cover intervals are removed).
clusters : dict
Distinct connected components in each cover set
computed after fitting the Mapper cover.
Dictionary has integer keys and tuple values
of the form (cover_set_i, nodes_in_cluster).
Each cluster is a rank 2 hyperedge in the
hypergraph.
Notes
-----
The following are common filter functions which can be called with
filter_attr.
1. "laplacian" : Converts data to an undirected graph and then applies the
torch_geometric.transforms.AddLaplacianEigenvectorPE(k=1) transform and
projects onto the 1st eigenvector.
projects onto the eigenvector associated with the smallest nonzero eigenvalue.
2. "svd" : Applies the torch_geometric.transforms.SVDFeatureReduction(out_channels=1)
transform to the node feature matrix (ie. torch_geometric.Data.data.x)
to project data to a 1-dimensional subspace.
3. "feature_pca" : Applies torch.pca_lowrank(q=1) transform to node feature matrix
(ie. torch_geometric.Data.data.x) and then projects to the 1st principle component.
(ie. torch_geometric.Data.data.x) and then projects to the 1st principal component.
4. "position_pca" : Applies torch.pca_lowrank(q=1) transform to node feature matrix
(ie. torch_geometric.Data.data.pos) and then projects to the 1st principle component.
4. "position_pca" : Applies torch.pca_lowrank(q=1) transform to node position matrix
(ie. torch_geometric.Data.data.pos) and then projects to the 1st principal component.
5. "feature_sum" : Applies torch.sum(dim=1) to the node feature matrix in the graph
(ie. torch_geometric.Data.data.x).
Expand All @@ -174,9 +192,9 @@ class MapperLifting(Graph2HypergraphLifting):
You may also construct your own filter_attr and filter_func:
7. "my_filter_attr" : my_filter_func = lambda data : my_filter_func(data)
7. "my_filter_attr" : Name of a self defined function
my_filter_func = lambda data : my_filter_func(data)
where my_filter_func(data) outputs a (n_sample, 1) Tensor.
Additionally, assign filter_func = my_filter_func.
References
----------
Expand All @@ -201,6 +219,19 @@ def __init__(
self.filter_func = filter_func
self._verify_filter_parameters(filter_attr, filter_func)

def _verify_filter_parameters(self, filter_attr, filter_func):
    """Validate the ``filter_attr`` / ``filter_func`` combination.

    Parameters
    ----------
    filter_attr : str
        Name of the filter. Must be a string.
    filter_func : object or None
        User supplied filter function, or None to use a built-in filter.

    Raises
    ------
    AssertionError
        If ``filter_attr`` is not a string, if ``filter_func`` is None and
        ``self.filter_attr`` is not a key of ``filter_dict``, or if
        ``filter_func`` is given and ``self.filter_attr`` is a key of
        ``filter_dict``.

    Notes
    -----
    ``filter_dict`` is defined elsewhere in this module (presumably the
    registry of built-in filters -- confirm against its definition).
    Validation relies on ``assert`` and is therefore skipped when Python
    runs with ``-O``.
    """
    # Check the type first: a non-string (possibly unhashable) value should
    # be reported as a type problem, not as an unknown filter name or a
    # TypeError raised by the membership test below.
    assert isinstance(filter_attr, str), f"{filter_attr} must be a string."

    if filter_func is None:
        # No custom function: the name must select a built-in filter.
        assert self.filter_attr in filter_dict, (
            f"Please add function to filter_func or choose filter_attr from {list(filter_dict)}. "
            f"Currently filter_func is {filter_func} and filter_attr is {filter_attr}."
        )
    else:
        # Custom function: its name must not collide with a built-in filter.
        assert self.filter_attr not in filter_dict, (
            f"Assign new filter_attr not in {list(filter_dict)} or leave filter_func as None. "
            f"Currently filter_func is {filter_func} and filter_attr is {filter_attr}"
        )

def _filter(self, data):
"""Applies 1-dimensional filter function to
torch_geometric.Data.data.
Expand Down Expand Up @@ -233,10 +264,11 @@ def _cluster(self, data, cover_mask):
"""Finds clusters in each cover set within cover_mask.
For each cover set, a cluster is a
distinct connected component.
Clusters are stored in dictionary, self.clusters.
Clusters are stored in the dictionary, self.clusters.
"""
mapper_clusters = {}
num_clusters = 0

# convert data to undirected graph for clustering
to_undirected = ToUndirected()
data = to_undirected(data)
Expand All @@ -256,17 +288,20 @@ def _cluster(self, data, cover_mask):
]

nodes = [i.item() for i in torch.where(cover_set)[0]]

# build graph to find clusters
cover_graph = nx.Graph()
cover_graph.add_nodes_from(nodes)
cover_graph.add_edges_from(edges)

# find clusters
clusters = nx.connected_components(cover_graph)

for cluster in clusters:
# index is the subset of nodes in data
# contained in cluster
index = torch.Tensor(list(cluster))

# kth cluster is item in dictionary
# of the form:
# k : (cover_set_index, nodes_in_cluster)
Expand All @@ -278,30 +313,13 @@ def _cluster(self, data, cover_mask):
return mapper_clusters

def lift_topology(self, data: torch_geometric.data.Data) -> dict:
r"""Lifts the topology of a graph to hypergraph domain by considering k-nearest neighbors.
r"""Lifts the topology of a graph to hypergraph domain by Mapper on Graphs.
Parameters
----------
data : torch_geometric.data.Data
The input data to be lifted.
Attributes
----------
filtered_data : dict
Filtered data used to compute the Mapper lifting.
Dictionary is of the form
{filter_attr: filter_func(data)}.
cover : (n_sample, resolution) boolean Tensor
Mask computed from the MapperCover class
to compute the Mapper lifting.
clusters : dict
Distinct connected components in each cover set
computed after fitting the Mapper cover.
Dictionary has integer keys and tuple values
of the form (cover_set_i, nodes_in_cluster).
Each cluster is a rank 2 hyperedge in the
hypergraph.
Returns
-------
dict
Expand All @@ -314,8 +332,10 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
cover = MapperCover(self.resolution, self.gain)
cover_mask = cover.fit_transform(filtered_data)
self.cover = cover_mask

# Find the clusters in the fitted cover
mapper_clusters = self._cluster(data, cover_mask)
self.clusters = mapper_clusters

# Construct the hypergraph dictionary
num_nodes = data["x"].shape[0]
Expand All @@ -337,27 +357,11 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
incidence_hyperedges[j.int(), i] = 1

# Incidence matrix is (num_nodes, num_edges + num_clusters) size matrix

incidence = torch.hstack([incidence_edges, incidence_hyperedges])

incidence = torch.Tensor(incidence).to_sparse_coo()
print("hyperedges", incidence_hyperedges)

return {
"incidence_hyperedges": incidence,
"num_hyperedges": num_hyperedges,
"x_0": data.x,
}

def _verify_filter_parameters(self, filter_attr, filter_func):
    """Validate the ``filter_attr`` / ``filter_func`` combination.

    Parameters
    ----------
    filter_attr : str
        Name of the filter. Must be a string.
    filter_func : object or None
        User supplied filter function, or None to use a built-in filter.

    Raises
    ------
    AssertionError
        If ``filter_attr`` is not a string, if ``filter_func`` is None and
        ``self.filter_attr`` is not a key of ``filter_dict``, or if
        ``filter_func`` is given and ``self.filter_attr`` is a key of
        ``filter_dict``.

    Notes
    -----
    ``filter_dict`` is defined elsewhere in this module (presumably the
    registry of built-in filters -- confirm against its definition).
    Validation relies on ``assert`` and is therefore skipped when Python
    runs with ``-O``.
    """
    # Type check comes first so a non-string value is reported as such
    # instead of failing (or raising TypeError) in the membership tests.
    # The message has no placeholders, so a plain string literal is used.
    assert isinstance(filter_attr, str), "filter_attr must be a string."

    if filter_func is None:
        # No custom function: the name must select a built-in filter.
        assert self.filter_attr in filter_dict, (
            f"Please add function to filter_func or choose filter_attr from {list(filter_dict)}. "
            f"Currently filter_func is {filter_func} and filter_attr is {filter_attr}."
        )
    else:
        # Custom function: its name must not collide with a built-in filter.
        assert self.filter_attr not in filter_dict, (
            f"Assign new filter_attr not in {list(filter_dict)} or leave filter_func as None. "
            f"Currently filter_func is {filter_func} and filter_attr is {filter_attr}"
        )
31 changes: 17 additions & 14 deletions test/transforms/liftings/graph2hypergraph/test_mapper_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@
import torch_geometric
from torch_geometric.transforms import (
AddLaplacianEigenvectorPE,
Compose,
SVDFeatureReduction,
ToUndirected,
)

from modules.data.utils.utils import load_manual_graph
from modules.transforms.liftings.graph2hypergraph.mapper_lifting import (
MapperCover,
MapperLifting,
)
from modules.transforms.liftings.graph2hypergraph.mapper_lifting import MapperLifting

expected_edge_incidence = torch.tensor(
[
Expand Down Expand Up @@ -243,6 +239,10 @@
]
)

""" Enrich the `load_manual_graph` graph with the necessary information to test
additional filter functions.
"""


def enriched_manual_graph():
data = load_manual_graph()
Expand All @@ -264,7 +264,9 @@ def enriched_manual_graph():
return data


"""Construct a naive implementation to create the filtered data set given data and filter function."""
""" Construct a naive implementation to create the filtered data set given data and filter function.
Used for testing filter function.
"""


def naive_filter(data, filter):
Expand Down Expand Up @@ -296,7 +298,9 @@ def naive_filter(data, filter):
return filtered_data


"""Construct a cover_mask from filtered data and default lift parameters."""
""" Construct a naive cover_mask from filtered data and default lift parameters.
This tests the cover method.
"""


def naive_cover(filtered_data):
Expand All @@ -306,7 +310,6 @@ def naive_cover(filtered_data):
data_range = data_max - data_min
# width of each interval in the cover
cover_width = data_range / (10 - (10 - 1) * 0.3)
last = data_min + (10 - 1) * (1 - 0.3) * cover_width
lows = torch.zeros(10)
for i in range(10):
lows[i] = (data_min) + (i) * (1 - 0.3) * cover_width
Expand Down Expand Up @@ -359,7 +362,7 @@ def test_filter(self, filter):
torch.isclose(lift_filter_data, naive_filter_data)
), f"Something is wrong with filtered values using {self.filter_name}. The lifted filter data is {lift_filter_data} and the naive filter data is {naive_filter_data}."
if filter == "laplacian":
# laplacian produce eigenvector up to a unit multiple.
# laplacian filter produces an eigenvector up to a unit multiple.
# instead we check their absolute values.
assert torch.all(
torch.isclose(torch.abs(lift_filter_data), torch.abs(naive_filter_data))
Expand All @@ -378,7 +381,7 @@ def test_filter(self, filter):
)
def test_cover(self, filter):
self.setup(filter)
transformed_data = self.mapper_lift.forward(self.data.clone())
self.mapper_lift.forward(self.data.clone())
lift_cover_mask = self.mapper_lift.cover
naive_cover_mask = naive_cover(self.mapper_lift.filtered_data[filter])
assert torch.all(
Expand Down Expand Up @@ -447,13 +450,13 @@ def test_cluster(self, filter):
},
}
self.setup(filter)
transformed_data = self.mapper_lift.forward(self.data.clone())
self.mapper_lift.forward(self.data.clone())
lift_clusters = self.mapper_lift.clusters
if filter != "laplacian":
assert (
expected_clusters[self.filter_name].keys() == lift_clusters.keys()
), f"Different number of clusters using {filter}. Expected {list(expected_clusters[filter])} but got {list(lift_clusters)}."
for cluster in lift_clusters.keys():
for cluster in lift_clusters:
assert (
expected_clusters[self.filter_name][cluster][0]
== lift_clusters[cluster][0]
Expand All @@ -462,7 +465,7 @@ def test_cluster(self, filter):
expected_clusters[self.filter_name][cluster][1],
lift_clusters[cluster][1],
), f"Something is wrong with the clustering using {self.filter_name}. Expected node subset {expected_clusters[self.filter_name][cluster][1]} but got {lift_clusters[cluster][1]} for cluster {cluster}."
# Laplacian function projects up to a unit. This causes clusters to not be identical
# Laplacian function projects up to a unit. This causes clusters to not be identical by index
# instead we check if the node subsets of the lifted set are somewhere in the expected set.
if filter == "laplacian":
assert len(lift_clusters) == len(
Expand All @@ -479,7 +482,7 @@ def test_cluster(self, filter):
expected_cluster_nodes.remove(node_subset)
assert (
expected_cluster_nodes == []
), f"Expected clusters contain more clusters than in the lifted cluster."
), "Expected clusters contain more clusters than in the lifted cluster."

@pytest.mark.parametrize(
"filter",
Expand Down
Loading

0 comments on commit 6c38d3a

Please sign in to comment.