Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
kaidic committed May 11, 2022
1 parent 6280176 commit 9e2ef84
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 465 deletions.
42 changes: 36 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,41 @@ PyTorch implementation of STELLAR, a geometric deep learning tool for cell-type



### Dependencies
### Installation

STELLAR requires the following packages. We test our software on Ubuntu 16.04 with NVIDIA Geforce 2080 Ti GPU. Please check the [requirements.txt](https://github.com/snap-stanford/stellar/blob/main/requirements.txt) file for more details on required Python packages.
**Requirements**

- [PyTorch==1.9](https://pytorch.org/)
- [PyG==1.7](https://pytorch-geometric.readthedocs.io/en/latest/)
- [sklearn==1.0.1](https://scikit-learn.org/)
- NVIDIA GPU, Linux, Python3. We test our software on Ubuntu 16.04 with NVIDIA Geforce 2080 Ti GPU and 1T CPU memory.


**1. Python environment (Optional):**
We recommend using Conda package manager

```bash
conda create -n stellar python=3.8
source activate stellar
```

**2. Pytorch:**
Install [PyTorch](https://pytorch.org/).
We have verified under PyTorch 1.9.1. For example:
```bash
conda install pytorch cudatoolkit=11.3 -c pytorch
```

**3. Pytorch Geometric:**
Install [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html),
follow their instructions. We have verified under Pyg 2.0. For example:
```bash
conda install pyg -c pyg
```

**4. Other dependencies:**

Please run the following command to install additional packages.
```bash
pip install -r requirements.txt
```

### Getting started

Expand Down Expand Up @@ -46,9 +74,11 @@ python STELLAR_run.py --dataset Hubmap --input-dim 48 --num-heads 22
python STELLAR_run.py --dataset TonsilBE --input-dim 44 --num-heads 13 --num-seed-class 3
```

We also provided a jupyter notebook that walks through a downsampled dataset. Please consider downsample more if there's still a memory issue. Note that the performance of the model would degrade as the training data gets less.

### Use your own dataset

Please refer to `load_hubmap_data()` and implement your own loader and construct the dataset.
Our stellar function requires node features, corresponding labels and corresponding edges as inputs. Here Node feature matrix should have shape [num_nodes, num_node_features] and edge indexes should have shape [2, num_edges].

### Citing

Expand Down
4 changes: 2 additions & 2 deletions STELLAR_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def main():
parser = argparse.ArgumentParser(description='STELLAR')
parser.add_argument('--dataset', default='TonsilBE', help='dataset setting')
parser.add_argument('--dataset', default='Hubmap', help='dataset setting')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--name', type=str, default='STELLAR')
parser.add_argument('--epochs', type=int, default=20)
Expand All @@ -16,7 +16,7 @@ def main():
parser.add_argument('--input-dim', type=int, default=48)
parser.add_argument('--num-heads', type=int, default=22)
parser.add_argument('--num-seed-class', type=int, default=0)
parser.add_argument('--sample-rate', type=float, default=0.1)
parser.add_argument('--sample-rate', type=float, default=0.5)
parser.add_argument('-b', '--batch-size', default=1, type=int,
metavar='N',
help='mini-batch size')
Expand Down
6 changes: 3 additions & 3 deletions datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np
from builtins import range
from torch_geometric.data import InMemoryDataset, Data
import sklearn
from sklearn.metrics import pairwise_distances
import pandas as pd

def get_hubmap_edge_index(pos, regions, distance_thres):
Expand All @@ -11,7 +11,7 @@ def get_hubmap_edge_index(pos, regions, distance_thres):
for reg in regions_unique:
locs = np.where(regions == reg)[0]
pos_region = pos[locs, :]
dists = sklearn.metrics.pairwise_distances(pos_region)
dists = pairwise_distances(pos_region)
dists_mask = dists < distance_thres
np.fill_diagonal(dists_mask, 0)
region_edge_list = np.transpose(np.nonzero(dists_mask)).tolist()
Expand All @@ -21,7 +21,7 @@ def get_hubmap_edge_index(pos, regions, distance_thres):

def get_tonsilbe_edge_index(pos, distance_thres):
edge_list = []
dists = sklearn.metrics.pairwise_distances(pos)
dists = pairwise_distances(pos)
dists_mask = dists < distance_thres
np.fill_diagonal(dists_mask, 0)
edge_list = np.transpose(np.nonzero(dists_mask)).tolist()
Expand Down
Loading

0 comments on commit 9e2ef84

Please sign in to comment.