Forward pass in tutorial notebook
sukjulian committed Jul 13, 2024
1 parent 0b221bb commit 5a91883
Showing 2 changed files with 39 additions and 3 deletions.
9 changes: 9 additions & 0 deletions configs/datasets/wall_shear_stress.yaml
@@ -2,3 +2,12 @@ data_domain: simplicial_complex
data_type: artery
data_name: wall_shear_stress
data_dir: datasets/${data_domain}/${data_type}

num_features:
- 1 # initial node features
- 3 # initial edge features
num_classes: 3
task: regression
loss_type: mae
monitor_metric: mae
task_level: node
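The `data_dir` entry uses `${...}` interpolation, so its value is assembled from the other keys at load time. A minimal stdlib-only sketch of that resolution, assuming a simple top-level-key substitution (the repo itself most likely delegates this to Hydra/OmegaConf via `load_dataset_config`):

```python
import re

# The config above, written as a plain dict (the real file is YAML).
config = {
    "data_domain": "simplicial_complex",
    "data_type": "artery",
    "data_name": "wall_shear_stress",
    "data_dir": "datasets/${data_domain}/${data_type}",
    "num_features": [1, 3],  # [initial node features, initial edge features]
    "num_classes": 3,
    "task": "regression",
    "loss_type": "mae",
    "monitor_metric": "mae",
    "task_level": "node",
}

def resolve(value, cfg):
    """Replace each ${key} in a string with the corresponding config entry."""
    if not isinstance(value, str):
        return value
    return re.sub(r"\$\{(\w+)\}", lambda m: str(cfg[m.group(1)]), value)

resolved = {k: resolve(v, config) for k, v in config.items()}
print(resolved["data_dir"])  # datasets/simplicial_complex/artery
```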
33 changes: 30 additions & 3 deletions tutorials/pointcloud2hypergraph/pointnet_lifting.ipynb
@@ -33,6 +33,7 @@
"from modules.utils.utils import (\n",
" describe_data,\n",
" load_dataset_config,\n",
" load_model_config,\n",
" load_transform_config,\n",
")"
]
@@ -94,7 +95,8 @@
" pos=dataset.pos[slice_pos],\n",
" face=dataset.face[:, slice_face],\n",
" # x_2=dataset.x_2[slice_face],\n",
" # incidence_2=dataset.incidence_2[slice_pos, slice_face] # not supported by PyTorch\n",
" # incidence_2=dataset.incidence_2[slice_pos, slice_face] # not supported by PyTorch,\n",
" num_features=dataset.num_features,\n",
")\n",
"\n",
"print(\"Data sample:\")\n",
@@ -122,10 +124,10 @@
"transform_config[\"lifting\"][\"sampling_ratio\"] = 0.2\n",
"transform_config[\"lifting\"][\"cluster_radius\"] = 0.1\n",
"\n",
"lifted_dataset = PreProcessor(\n",
"lifted_data = PreProcessor(\n",
" data, transform_config, os.path.join(rootutils.find_root(), dataset_config.data_dir)\n",
")\n",
"describe_data(lifted_dataset)"
"describe_data(lifted_data)"
]
},
{
@@ -134,6 +136,31 @@
"source": [
"This hypergraph represents the first set abstraction layer that is used by [PointNet++](https://arxiv.org/abs/1706.02413). To construct a complete PointNet++ out of this, we would have to recursively apply the lifting while regarding the previous hyperedges as new \"hyper-nodes\"."
]
},
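The set-abstraction lifting the notebook configures (`sampling_ratio = 0.2`, `cluster_radius = 0.1`) can be sketched as farthest point sampling followed by a radius query, each neighbourhood becoming one hyperedge. Everything below is a hypothetical illustration; the repo's actual lifting transform may differ in detail:

```python
import numpy as np

def pointnet_lifting(pos, sampling_ratio=0.2, cluster_radius=0.1, seed=0):
    """Sketch of a PointNet++-style set abstraction as a hypergraph lifting."""
    rng = np.random.default_rng(seed)
    n = pos.shape[0]
    n_centroids = max(1, int(sampling_ratio * n))

    # Farthest point sampling: greedily pick the point farthest
    # from everything already chosen.
    chosen = [int(rng.integers(n))]
    dist = np.linalg.norm(pos - pos[chosen[0]], axis=1)
    for _ in range(n_centroids - 1):
        chosen.append(int(dist.argmax()))
        dist = np.minimum(dist, np.linalg.norm(pos - pos[chosen[-1]], axis=1))

    # Ball query: all points within cluster_radius of a centroid
    # form one hyperedge.
    hyperedges = [
        np.flatnonzero(np.linalg.norm(pos - pos[c], axis=1) <= cluster_radius)
        for c in chosen
    ]
    return chosen, hyperedges

pos = np.random.default_rng(1).random((100, 3))
centroids, hyperedges = pointnet_lifting(pos)
print(len(centroids), "hyperedges")  # 20 hyperedges
```

Recursing, as the cell above suggests, would amount to calling `pointnet_lifting(pos[centroids])` again, treating the previous hyperedges as the new "hyper-nodes".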
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running inference"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from modules.models.hypergraph.unigcn import UniGCNModel\n",
"\n",
"model_type = \"hypergraph\"\n",
"model_id = \"unigcn\"\n",
"model_config = load_model_config(model_type, model_id)\n",
"\n",
"model = UniGCNModel(model_config, dataset_config)\n",
"\n",
"print(\"\\nModel output:\")\n",
"model(lifted_data)"
]
}
],
"metadata": {
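The new inference cell feeds the lifted hypergraph through `UniGCNModel`. A UniGCN-style layer does two-stage message passing: aggregate node features into each hyperedge, then scatter the transformed hyperedge features back to their member nodes. A minimal NumPy sketch of that idea, under stated assumptions (mean aggregation, a single weight matrix; the repo's model adds learned normalisation and nonlinearities):

```python
import numpy as np

def unigcn_layer(x, hyperedges, weight):
    """One hedged UniGCN-style layer.

    x: (num_nodes, in_dim) node features
    hyperedges: list of integer index arrays, one per hyperedge
    weight: (in_dim, out_dim) linear transform
    """
    # Stage 1: node -> hyperedge (mean over member nodes).
    edge_feat = np.stack([x[e].mean(axis=0) for e in hyperedges])
    # Stage 2: hyperedge -> node (average of incident hyperedges).
    out = np.zeros((x.shape[0], weight.shape[1]))
    count = np.zeros(x.shape[0])
    for feat, e in zip(edge_feat @ weight, hyperedges):
        out[e] += feat
        count[e] += 1
    return out / np.maximum(count, 1)[:, None]

rng = np.random.default_rng(0)
x = rng.random((6, 4))
hyperedges = [np.array([0, 1, 2]), np.array([2, 3, 4, 5])]
w = rng.random((4, 8))
print(unigcn_layer(x, hyperedges, w).shape)  # (6, 8)
```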
