Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
pzajec committed Jul 13, 2024
1 parent 497947b commit f831a1d
Show file tree
Hide file tree
Showing 11 changed files with 676 additions and 11 deletions.
13 changes: 13 additions & 0 deletions configs/datasets/annulus.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Configuration for the toy 2D annulus point cloud dataset.
data_domain: pointcloud
data_type: toy_dataset
data_name: annulus
data_dir: datasets/${data_domain}/${data_type}/${data_name}

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
num_features: 10  # NOTE(review): points are generated in dim 2 — confirm 10 features is intended
num_classes: 3
num_points: 1000  # number of points sampled on the annulus
dim: 2            # ambient dimension of the generated point cloud
13 changes: 13 additions & 0 deletions configs/datasets/random_pointcloud.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Configuration for a uniformly random toy point cloud dataset.
data_domain: pointcloud
data_type: toy_dataset
data_name: random_pointcloud
data_dir: datasets/${data_domain}/${data_type}/${data_name}

# Dataset parameters
task: classification
loss_type: cross_entropy
monitor_metric: accuracy
num_features: 10  # NOTE(review): points are generated in dim 2 — confirm 10 features is intended
num_classes: 3
num_points: 50    # number of random points to generate
dim: 2            # ambient dimension of the generated point cloud
11 changes: 11 additions & 0 deletions configs/datasets/stanford_bunny.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Configuration for the Stanford bunny point cloud
# (downloaded via gudhi's fetch_bunny in the loader).
data_domain: pointcloud
data_type: toy_dataset
data_name: stanford_bunny
data_dir: datasets/${data_domain}/${data_type}/${data_name}

# Dataset parameters
num_classes: 1
task: regression
loss_type: mse
monitor_metric: mae
task_level: graph
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Configuration for the CoverLifting transform (point cloud -> graph lifting).
transform_type: 'lifting'
transform_name: "CoverLifting"
feature_lifting: ProjectionSum  # how node features are lifted to higher-order cells
34 changes: 34 additions & 0 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
load_cell_complex_dataset,
load_hypergraph_pickle_dataset,
load_manual_graph,
load_pointcloud_dataset,
load_simplicial_dataset,
)

Expand Down Expand Up @@ -204,3 +205,36 @@ def load(
torch_geometric.data.Dataset object containing the loaded data.
"""
return load_hypergraph_pickle_dataset(self.parameters)


class PointCloudLoader(AbstractLoader):
    r"""Loader for point cloud datasets.

    Parameters
    ----------
    parameters : DictConfig
        Configuration parameters.
    """

    def __init__(self, parameters: DictConfig):
        super().__init__(parameters)
        self.parameters = parameters

    def load(
        self,
    ) -> torch_geometric.data.Dataset:
        r"""Load point cloud dataset.

        Returns
        -------
        torch_geometric.data.Dataset
            torch_geometric.data.Dataset object containing the loaded data.
        """
        # Resolve the absolute data directory relative to the project root.
        root_folder = rootutils.find_root()
        self.data_dir = os.path.join(root_folder, self.parameters["data_dir"])

        # Load once and return; the original loaded the dataset twice
        # (discarding the first result) and leaked a debug print.
        return load_pointcloud_dataset(self.parameters)
103 changes: 95 additions & 8 deletions modules/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,17 @@
import networkx as nx
import numpy as np
import omegaconf
import rootutils
import toponetx.datasets.graph as graph
import torch
import torch_geometric
from gudhi.datasets.remote import fetch_bunny
from topomodelx.utils.sparse import from_sparse
from torch_geometric.data import Data
from torch_sparse import coalesce

from modules.data.utils.custom_dataset import CustomDataset


def get_complex_connectivity(complex, max_rank, signed=False):
r"""Gets the connectivity matrices for the complex.
Expand Down Expand Up @@ -50,16 +54,16 @@ def get_complex_connectivity(complex, max_rank, signed=False):
)
except ValueError: # noqa: PERF203
if connectivity_info == "incidence":
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
)
else:
connectivity[f"{connectivity_info}_{rank_idx}"] = (
generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity[
f"{connectivity_info}_{rank_idx}"
] = generate_zero_sparse_connectivity(
m=practical_shape[rank_idx], n=practical_shape[rank_idx]
)
connectivity["shape"] = practical_shape
return connectivity
Expand Down Expand Up @@ -372,6 +376,89 @@ def get_TUDataset_pyg(cfg):
return [data for data in dataset]


def load_pointcloud_dataset(cfg):
    r"""Loads point cloud datasets.

    Parameters
    ----------
    cfg : DictConfig
        Configuration parameters. Must contain ``data_dir`` and
        ``data_name``; the toy generators additionally read
        ``num_points`` and ``dim``.

    Returns
    -------
    CustomDataset
        Dataset wrapping a single ``torch_geometric.data.Data`` object
        holding the (subsampled) point positions.

    Raises
    ------
    ValueError
        If ``cfg["data_name"]`` is not one of the supported datasets.
    """
    root_folder = rootutils.find_root()
    data_dir = osp.join(root_folder, cfg["data_dir"])

    if cfg["data_name"] == "random_pointcloud":
        num_points, dim = cfg["num_points"], cfg["dim"]
        pos = torch.rand((num_points, dim))
    elif cfg["data_name"] == "annulus":
        num_points, dim = cfg["num_points"], cfg["dim"]
        pos = annulus_2d(dim, num_points)
    elif cfg["data_name"] == "stanford_bunny":
        pos = fetch_bunny(
            file_path=osp.join(data_dir, "stanford_bunny.npy"),
            accept_license=False,
        )
    else:
        # Fail fast with a clear message instead of hitting a confusing
        # NameError on `pos` below.
        raise ValueError(f"Unknown point cloud dataset: {cfg['data_name']}")

    num_points = cfg["num_points"] if "num_points" in cfg else len(pos)
    # as_tensor avoids an unnecessary copy (and the UserWarning) when
    # `pos` is already a torch.Tensor.
    pos = torch.as_tensor(pos)

    # Subsample without replacement down to the requested number of points.
    pos = pos[np.random.choice(pos.shape[0], num_points, replace=False)]

    return CustomDataset(
        [
            torch_geometric.data.Data(
                pos=pos,
            )
        ],
        data_dir,
    )


def annulus_2d(D, N, R1=0.8, R2=1, A=0):
    r"""Sample ``N`` points uniformly from an annulus via rejection sampling.

    Parameters
    ----------
    D : int
        Dimension of each sampled point. Intended for ``D=2``: the
        aperture test below reads coordinates 0 and 1.
    N : int
        Number of points to sample.
    R1 : float, optional
        Inner radius of the annulus.
    R2 : float, optional
        Outer radius of the annulus.
    A : float, optional
        Width of an aperture cut out on the positive-x side: candidates
        with ``p[0] > 0`` and ``|p[1]| < A / 2`` are rejected.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(N, D)`` containing the sampled points.
    """
    # Preallocate idiomatically instead of np.array([[0.0] * D] * N).
    points = np.zeros((N, D))
    count = 0
    while count < N:
        candidate = np.random.uniform(-R2, R2, D)
        # Compute the norm once (the original computed it twice).
        radius = np.linalg.norm(candidate)
        if radius > R2 or radius < R1:
            continue  # outside the annulus ring
        if (candidate[0] > 0) and (np.abs(candidate[1]) < A / 2):
            continue  # inside the aperture slit
        points[count] = candidate
        count += 1
    return torch.tensor(points)


def load_annulus():
    """Build a toy 2D annulus point cloud with 1000 points."""
    return torch_geometric.data.Data(pos=annulus_2d(2, 1000))


def load_manual_pointcloud():
    """Create a manual pointcloud for testing purposes."""
    # Small clusters of unit-offset points placed at x = 0, 10, 20, 30.
    coordinates = [
        [0, 0, 0],
        [0, 0, 1],
        [0, 1, 0],
        [10, 0, 0],
        [10, 0, 1],
        [10, 1, 0],
        [10, 1, 1],
        [20, 0, 0],
        [20, 0, 1],
        [20, 1, 0],
        [20, 1, 1],
        [30, 0, 0],
    ]
    positions = torch.tensor(coordinates, dtype=torch.float)

    return torch_geometric.data.Data(
        pos=positions, num_nodes=positions.size(0), num_features=0
    )


def ensure_serializable(obj):
r"""Ensures that the object is serializable.
Expand Down
3 changes: 3 additions & 0 deletions modules/transforms/data_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from modules.transforms.liftings.graph2simplicial.clique_lifting import (
SimplicialCliqueLifting,
)
from modules.transforms.liftings.pointcloud2graph.cover_lifting import CoverLifting

TRANSFORMS = {
# Graph -> Hypergraph
Expand All @@ -23,6 +24,8 @@
"SimplicialCliqueLifting": SimplicialCliqueLifting,
# Graph -> Cell Complex
"CellCycleLifting": CellCycleLifting,
# Point Cloud -> Graph
"CoverLifting": CoverLifting,
# Feature Liftings
"ProjectionSum": ProjectionSum,
# Data Manipulations
Expand Down
9 changes: 6 additions & 3 deletions modules/transforms/liftings/lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
initial_data = data.to_dict()
lifted_topology = self.lift_topology(data)
lifted_topology = self.feature_lifting(lifted_topology)

return torch_geometric.data.Data(**initial_data, **lifted_topology)


Expand Down Expand Up @@ -118,9 +119,11 @@ def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph
# In case edge features are given, assign features to every edge
edge_index, edge_attr = (
data.edge_index,
data.edge_attr
if is_undirected(data.edge_index, data.edge_attr)
else to_undirected(data.edge_index, data.edge_attr),
(
data.edge_attr
if is_undirected(data.edge_index, data.edge_attr)
else to_undirected(data.edge_index, data.edge_attr)
),
)
edges = [
(i.item(), j.item(), dict(features=edge_attr[edge_idx], dim=1))
Expand Down
Loading

0 comments on commit f831a1d

Please sign in to comment.