Skip to content

Commit

Permalink
updated toy data loading
Browse files Browse the repository at this point in the history
  • Loading branch information
bekemax committed Jul 13, 2024
1 parent d4a71a6 commit 248a804
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
12 changes: 12 additions & 0 deletions configs/datasets/toy_point_cloud.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
data_domain: point_cloud
data_type: toy_dataset
data_name: toy_point_cloud
data_dir: datasets/${data_domain}/${data_type}

# Dataset parameters
num_points: 8
num_classes: 2

num_features: 1
task: classification
loss_type: cross_entropy
9 changes: 4 additions & 5 deletions modules/data/load/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,7 @@ def load(self) -> torch_geometric.data.Dataset:
torch_geometric.data.Dataset
torch_geometric.data.Dataset object containing the loaded data.
"""
if "data_name" not in self.parameters:
data = load_point_cloud(num_classes=self.cfg["num_classes"])
return CustomDataset([data], self.cfg["data_dir"])
else: # noqa: RET505
raise NotImplementedError
data = load_point_cloud(
num_classes=self.cfg["num_classes"], num_points=self.cfg["num_points"]
)
return CustomDataset([data], self.cfg["data_dir"])

0 comments on commit 248a804

Please sign in to comment.