From b483bd396695ea448a3e60ef3d4c65d30f0b6a62 Mon Sep 17 00:00:00 2001
From: Alejandro Mossi
Date: Sat, 1 Apr 2023 14:38:39 +0200
Subject: [PATCH 01/11] Normalize file modes from 100755 to 100644

---
 .gitignore | 0
 FISHscale/__init__.py | 0
 FISHscale/graphNN/__init__.py | 0
 FISHscale/graphNN/cluster_utils.py | 0
 FISHscale/graphNN/graph_decoder.py | 0
 FISHscale/graphNN/graph_pci.py | 0
 FISHscale/graphNN/graph_utils.py | 0
 FISHscale/graphNN/graphdata.py | 0
 FISHscale/graphNN/models.py | 0
 FISHscale/graphNN/pciSeq/__init__.py | 0
 FISHscale/graphNN/pciSeq/app.py | 0
 FISHscale/graphNN/pciSeq/config.py | 0
 FISHscale/graphNN/pciSeq/src/__init__.py | 0
 FISHscale/graphNN/pciSeq/src/_version.py | 0
 FISHscale/graphNN/pciSeq/src/cell_call/__init__.py | 0
 FISHscale/graphNN/pciSeq/src/cell_call/datatypes.py | 0
 .../graphNN/pciSeq/src/cell_call/log_config.py | 0
 FISHscale/graphNN/pciSeq/src/cell_call/main.py | 0
 FISHscale/graphNN/pciSeq/src/cell_call/summary.py | 0
 FISHscale/graphNN/pciSeq/src/cell_call/utils.py | 0
 FISHscale/graphNN/pciSeq/src/preprocess/__init__.py | 0
 .../graphNN/pciSeq/src/preprocess/cell_borders.py | 0
 .../graphNN/pciSeq/src/preprocess/segmentation.py | 0
 .../graphNN/pciSeq/src/preprocess/spot_labels.py | 0
 FISHscale/graphNN/pciSeq/src/preprocess/utils.py | 0
 FISHscale/graphNN/pciSeq/src/viewer/__init__.py | 0
 FISHscale/graphNN/pciSeq/src/viewer/stage_image.py | 0
 FISHscale/graphNN/pciSeq/src/viewer/utils.py | 0
 FISHscale/graphNN/sample2shoji.py | 0
 FISHscale/graphNN/submodules.py | 0
 FISHscale/segmentation/__init__.py | 0
 FISHscale/segmentation/cellpose.py | 0
 FISHscale/spatial/__init__.py | 0
 FISHscale/spatial/boundaries.py | 0
 FISHscale/spatial/gene_order.py | 0
 FISHscale/utils/Louvain_modified.py | 0
 FISHscale/utils/__init__.py | 0
 FISHscale/utils/bonefight.py | 0
 FISHscale/utils/colors.py | 0
 FISHscale/utils/coordinate_based_colocalization.py | 0
 FISHscale/utils/data_handling.py | 0
 FISHscale/utils/dataset.py | 0
 FISHscale/utils/decomposition.py | 0
 FISHscale/utils/density_1D.py | 0
 FISHscale/utils/dpca.py | 0
 FISHscale/utils/fast_iteration.py | 0
 FISHscale/utils/gene_correlation.py | 0
 FISHscale/utils/hex_regionalization.py | 0
 FISHscale/utils/inside_polygon.py | 0
 FISHscale/utils/normalization.py | 0
 FISHscale/utils/regionalization_gradient.py | 0
 FISHscale/utils/regionalization_multi.py | 0
 FISHscale/utils/segmentation_utils.py | 0
 FISHscale/utils/spatial_metrics.py | 0
 FISHscale/utils/volume_align.py | 0
 FISHscale/visualization/__init__.py | 0
 FISHscale/visualization/gene_scatter.py | 0
 FISHscale/visualization/vis_linux.py | 0
 FISHscale/visualization/vis_macos.py | 0
 FISHscale_demo.png | Bin
 FISHscale_open_3D.gif | Bin
 FISHscale_open_3D_2.gif | Bin
 FISHscale_tutorial_multi_dataset.ipynb | 0
 Images/144_color_example.png | Bin
 Images/1D_distribution_example.png | Bin
 Images/2133.png_300.png | Bin
 Images/CBC_example_high.png | Bin
 Images/CBC_example_low.png | Bin
 Images/Hex_bin_example_data1810um.png | Bin
 Images/test16x16.png | Bin
 README.md | 0
 .../FISHscale_dask_attrs_and_graphsage.ipynb | 0
 .../FISHscale_tutorial_multi_dataset.ipynb | 0
 .../FISHscale_tutorial_single_dataset.ipynb | 0
 example_notebooks/MultiAnalysis.ipynb | 0
 setup.py | 0
 76 files changed, 0 insertions(+), 0 deletions(-)
 mode change 100755 => 100644 .gitignore
 mode change 100755 => 100644 FISHscale/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/cluster_utils.py
 mode change 100755 => 100644 FISHscale/graphNN/graph_decoder.py
 mode change 100755 => 100644 FISHscale/graphNN/graph_pci.py
 mode change 100755 => 100644 FISHscale/graphNN/graph_utils.py
 mode change 100755 => 100644 FISHscale/graphNN/graphdata.py
 mode change 100755 => 100644 FISHscale/graphNN/models.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/app.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/config.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/_version.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/cell_call/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/cell_call/datatypes.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/cell_call/log_config.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/cell_call/main.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/cell_call/summary.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/cell_call/utils.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/preprocess/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/preprocess/cell_borders.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/preprocess/segmentation.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/preprocess/spot_labels.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/preprocess/utils.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/viewer/__init__.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/viewer/stage_image.py
 mode change 100755 => 100644 FISHscale/graphNN/pciSeq/src/viewer/utils.py
 mode change 100755 => 100644 FISHscale/graphNN/sample2shoji.py
 mode change 100755 => 100644 FISHscale/graphNN/submodules.py
 mode change 100755 => 100644 FISHscale/segmentation/__init__.py
 mode change 100755 => 100644 FISHscale/segmentation/cellpose.py
 mode change 100755 => 100644 FISHscale/spatial/__init__.py
 mode change 100755 => 100644 FISHscale/spatial/boundaries.py
 mode change 100755 => 100644 FISHscale/spatial/gene_order.py
 mode change 100755 => 100644 FISHscale/utils/Louvain_modified.py
 mode change 100755 => 100644 FISHscale/utils/__init__.py
 mode change 100755 => 100644 FISHscale/utils/bonefight.py
 mode change 100755 => 100644 FISHscale/utils/colors.py
 mode change 100755 => 100644 FISHscale/utils/coordinate_based_colocalization.py
 mode change 100755 => 100644 FISHscale/utils/data_handling.py
 mode change 100755 => 100644 FISHscale/utils/dataset.py
 mode change 100755 => 100644 FISHscale/utils/decomposition.py
 mode change 100755 => 100644 FISHscale/utils/density_1D.py
 mode change 100755 => 100644 FISHscale/utils/dpca.py
 mode change 100755 => 100644 FISHscale/utils/fast_iteration.py
 mode change 100755 => 100644 FISHscale/utils/gene_correlation.py
 mode change 100755 => 100644 FISHscale/utils/hex_regionalization.py
 mode change 100755 => 100644 FISHscale/utils/inside_polygon.py
 mode change 100755 => 100644 FISHscale/utils/normalization.py
 mode change 100755 => 100644 FISHscale/utils/regionalization_gradient.py
 mode change 100755 => 100644 FISHscale/utils/regionalization_multi.py
 mode change 100755 => 100644 FISHscale/utils/segmentation_utils.py
 mode change 100755 => 100644 FISHscale/utils/spatial_metrics.py
 mode change 100755 => 100644 FISHscale/utils/volume_align.py
 mode change 100755 => 100644 FISHscale/visualization/__init__.py
 mode change 100755 => 100644 FISHscale/visualization/gene_scatter.py
 mode change 100755 => 100644 FISHscale/visualization/vis_linux.py
 mode change 100755 => 100644 FISHscale/visualization/vis_macos.py
 mode change 100755 => 100644 FISHscale_demo.png
 mode change 100755 => 100644 FISHscale_open_3D.gif
 mode change 100755 => 100644 FISHscale_open_3D_2.gif
 mode change 100755 => 100644 FISHscale_tutorial_multi_dataset.ipynb
 mode change 100755 => 100644 Images/144_color_example.png
 mode change 100755 => 100644 Images/1D_distribution_example.png
 mode change 100755 => 100644 Images/2133.png_300.png
 mode change 100755 => 100644 Images/CBC_example_high.png
 mode change 100755 => 100644 Images/CBC_example_low.png
 mode change 100755 => 100644 Images/Hex_bin_example_data1810um.png
 mode change 100755 => 100644 Images/test16x16.png
 mode change 100755 => 100644 README.md
 mode change 100755 => 100644 example_notebooks/FISHscale_dask_attrs_and_graphsage.ipynb
 mode change 100755 => 100644 example_notebooks/FISHscale_tutorial_multi_dataset.ipynb
 mode change 100755 => 100644 example_notebooks/FISHscale_tutorial_single_dataset.ipynb
 mode change 100755 => 100644 example_notebooks/MultiAnalysis.ipynb
 mode change 100755 => 100644 setup.py

From ecf266c9d17aeceedcd164be48b0d676e7d75693 Mon Sep 17 00:00:00 2001
From: Alejandro Mossi
Date: Sat, 1 Apr 2023 16:19:18 +0200
Subject: [PATCH 02/11] allow to plot different genes with different sizes

---
 FISHscale/visualization/gene_scatter.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/FISHscale/visualization/gene_scatter.py b/FISHscale/visualization/gene_scatter.py
index dc5fb8e8..1d711462 100644
--- a/FISHscale/visualization/gene_scatter.py
+++ b/FISHscale/visualization/gene_scatter.py
@@ -131,8 +131,7 @@ def _add_scale_bar(self, ax):
 
 class GeneScatter(AxSize):
-
-    def scatter_plot(self, genes: Union[List, np.ndarray], s: float=0.1,
+    def scatter_plot(self, genes: Union[List, np.ndarray], s: Union[float, list]=0.1,
                      colors: Union[List, np.ndarray] = None, ax_scale_factor: int=10,
                      view: Union[Any, List] = None, scalebar: bool=True,
                      show_axes: bool=False,
@@ -205,7 +204,11 @@ def scatter_plot(self, genes: Union[List, np.ndarray], s: float=0.1,
         #Plot points
         if type(colors) == type(None):
             colors = [self.color_dict[g] for g in genes]
-        for g, c in zip(genes, colors):
+
+        # Allows to pass different sizes for each gene
+        if type(s) == float:
+            s = [s] * len(genes)
+        for g, c, s_ in zip(genes, colors, s):
             data = self.get_gene(g)
             x = data.x
             y = data.y
@@ -218,7 +221,7 @@ def scatter_plot(self, genes: Union[List, np.ndarray], s: float=0.1,
             if reset_xy:
                 x = x - view[0][0]
                 y = y - view[0][1]
-            ax.scatter(x, y, s=s, color=c, zorder=0, label=g, alpha=alpha, rasterized=True)
+            ax.scatter(x, y, s=s_, color=c, zorder=0, label=g, alpha=alpha, rasterized=True)
             del data
 
         if invert_yaxis:
@@ -485,9 +488,11 @@ def attribute_scatter_plot(self, attributes: Union[List, np.ndarray], section:st
                 self.color_dict[str(a)] = (r()/255,r()/255,r()/255)
         colors = [self.color_dict[g] for g in attributes]
 
-        for g, c in zip(attributes, colors):
-            data = self.dask_attrs[section]
-            data= data[data[section].isin([g])].compute()
+        # Allows to pass a different size for each attribute
+        if type(s) == float:
+            s = [s] * len(attributes)
+        for g, c, s_ in zip(attributes, colors, s):
+            data = self.dask_attrs[section]
+            data = data[data[section].isin([g])].compute()
             x = data.x
             y = data.y
             if isinstance(view, list):
@@ -496,7 +501,10 @@ def attribute_scatter_plot(self, attributes: Union[List, np.ndarray], section:st
                 filt = filt_x & filt_y
                 x = x[filt]
                 y = y[filt]
-            ax.scatter(x, y, s=s, color=c, zorder=0, label=g, alpha=alpha, rasterized=True)
+            if reset_xy:
+                x = x - view[0][0]
+                y = y - view[0][1]
+            ax.scatter(x, y, s=s_, color=c, zorder=0, label=g, alpha=alpha, rasterized=True)
             del data
 
         #Rescale
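For reference, a minimal usage sketch of the per-gene point sizes introduced in PATCH 02; the dataset handle `ds` and the gene names are hypothetical placeholders, not values from this repository:

    # s may now be a single float or one size per gene, matched by position
    ds.scatter_plot(
        genes=['Plp1', 'Aqp4'],
        s=[0.5, 0.05],          # large points for Plp1, small points for Aqp4
        ax_scale_factor=10,
    )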
From b976ca78848ce79d4d58dbaf8770f9932feef791 Mon Sep 17 00:00:00 2001
From: alejandro
Date: Wed, 5 Apr 2023 16:14:51 +0200
Subject: [PATCH 03/11] Add deep-residual model variant; drop log1p from
 neighborhood normalization

---
 FISHscale/graphNN/cellularneighborhoods.py |   1 -
 FISHscale/graphNN/models_deepresidual.py   | 394 +++++++++++++++++++++
 2 files changed, 394 insertions(+), 1 deletion(-)
 create mode 100755 FISHscale/graphNN/models_deepresidual.py

diff --git a/FISHscale/graphNN/cellularneighborhoods.py b/FISHscale/graphNN/cellularneighborhoods.py
index bad0fde6..9adbddba 100755
--- a/FISHscale/graphNN/cellularneighborhoods.py
+++ b/FISHscale/graphNN/cellularneighborhoods.py
@@ -113,7 +113,6 @@ def __init__(self,
         anndata.raw = anndata
         if normalize:
             sc.pp.normalize_total(anndata, target_sum=1e4)
-            sc.pp.log1p(anndata)
         self.anndata = anndata
 
         ### Model hyperparameters
diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
new file mode 100755
index 00000000..b81ba4fa
--- /dev/null
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -0,0 +1,394 @@
+from matplotlib.pyplot import get
+import torchmetrics
+import dgl
+import torch as th
+import torch.nn as nn
+import torch.nn.functional as F
+import dgl.nn.pytorch as dglnn
+import dgl.function as fn
+import tqdm
+from pytorch_lightning import LightningModule
+from FISHscale.graphNN.submodules import Classifier
+from pyro.distributions import GammaPoisson, Poisson
+from torch.distributions import Gamma,Normal, Multinomial, kl_divergence as kl
+import pyro
+from pyro import distributions as dist
+from pyro.nn import PyroModule, PyroSample
+from pyro.distributions import constraints
+from pyro import poutine
+from scvi.nn import FCLayers
+
+class CrossEntropyLoss(nn.Module):
+    def forward(self, block_outputs, pos_graph, neg_graph):
+        with pos_graph.local_scope():
+            pos_graph.ndata['h'] = block_outputs
+            pos_graph.apply_edges(fn.u_mul_v('h', 'h', 'score'))
+            pos_score = pos_graph.edata['score']
+        with neg_graph.local_scope():
+            neg_graph.ndata['h'] = block_outputs
+            neg_graph.apply_edges(fn.u_mul_v('h', 'h', 'score'))
+            neg_score = neg_graph.edata['score']
+
+        pos_loss, neg_loss = -F.logsigmoid(pos_score.sum(-1)), -F.logsigmoid(-neg_score.sum(-1))
+        neg_loss = neg_loss.reshape([pos_score.shape[0], int(neg_score.shape[0]/pos_score.shape[0])])#.mean(axis=1)
+        loss = pos_loss + neg_loss.mean(axis=1)
+        #score = th.cat([pos_score, neg_score])
+        #label = th.cat([th.ones_like(pos_score), th.zeros_like(neg_score)]).long()
+        #loss = F.binary_cross_entropy_with_logits(score, label.float())
+        return loss#, pos_loss, neg_loss
+
+class SAGELightning(LightningModule):
+    def __init__(self,
+                 in_feats,
+                 n_latent,
+                 n_classes,
+                 n_layers,
+                 n_hidden=64,
+                 dropout=0.1,
+                 lr=0.001,
+                 features_name='gene',
+                 supervised=False,
+                 reference=0,
+                 smooth=False,
+                 device='cpu',
+                 aggregator='attentional',
+                 celltype_distribution=None,
+                 ncells = None,
+                 inference_type='deterministic',
+                 l_loc = None,
+                 l_scale = None,
+                 scale_factor = 1,
+                 warmup_factor = 1,
+                 loss_type='unsupervised',#or supervised
+                 ):
+        super().__init__()
+
+        self.module = SAGE(in_feats=in_feats,
+                           n_hidden=n_hidden,
+                           n_latent=n_latent,
+                           n_classes=n_classes,
+                           n_layers=n_layers,
+                           dropout=dropout,
+                           supervised=supervised,
+                           aggregator= aggregator,
+                           features_name=features_name,
+                           )
+
+        self.lr = lr
+        self.supervised= supervised
+        self.loss_fcn = CrossEntropyLoss()
+        self.kappa = 0
+        self.reference=th.tensor(reference,dtype=th.float32, device=device)
+        self.smooth = smooth
+        self.in_feats = in_feats
+        self.n_hidden = n_hidden
+        self.n_latent = n_latent
+        self.inference_type = inference_type
+        self.loss_type = loss_type
+        self.features_name = features_name
+
+        if self.inference_type == 'VI':
+            self.automatic_optimization = False
+            self.svi = PyroOptWrap(model=self.model,
+                                   guide=self.guide,
+                                   optim=pyro.optim.Adam({"lr": self.lr}),
+                                   loss=pyro.infer.Trace_ELBO())
+
+        self.automatic_optimization = False
+        if self.supervised:
+            self.num_classes = n_classes
+            '''
+            self.l_loc = l_loc
+            self.l_scale = l_scale
+            self.train_acc = torchmetrics.Accuracy()
+            self.kl = th.nn.KLDivLoss(reduction='sum')
+            self.dist = celltype_distribution
+            self.ncells = ncells
+            self.scale_factor = scale_factor
+            self.alpha = 1
+            self.warmup_counter = 0
+            self.warmup_factor=warmup_factor
+            '''
+
+    def training_step(self, batch, batch_idx):
+
+        self.reference = self.reference.to(self.device)
+        losses = []
+        for sub_batch in batch:
+            if self.supervised:
+                _, _, mfgs = sub_batch
+                mfgs = [mfg.int() for mfg in mfgs]
+                batch_inputs = mfgs[0].srcdata[self.features_name]
+
+            else:
+                _, pos, neg, mfgs = sub_batch
+                pos_ids = pos.edges()[0]
+                mfgs = [mfg.int() for mfg in mfgs]
+                batch_inputs = mfgs[0].srcdata[self.features_name]
+
+            if len(batch_inputs.shape) == 1:
+                if self.supervised == False:
+                    batch_inputs = F.one_hot(batch_inputs.to(th.int64), num_classes=self.in_feats)
+
+            zn_loc = self.module.encoder(batch_inputs,mfgs)
+            if self.loss_type == 'unsupervised':
+                graph_loss = self.loss_fcn(zn_loc, pos, neg).mean()
+            else:
+                graph_loss = F.cross_entropy(zn_loc, mfgs[-1].dstdata['label'])
+
+            opt_g = self.optimizers()
+            opt_g.zero_grad()
+            self.manual_backward(graph_loss)
+            opt_g.step()
+
+            loss = graph_loss
+            losses.append(loss)
+        loss = th.stack(losses).mean()
+        self.log('train_loss',
+                 loss, prog_bar=True, on_step=True, on_epoch=True,batch_size=zn_loc.shape[0])
+        return loss
+
+    def configure_optimizers(self):
+        optimizer_graph = th.optim.Adam(self.module.encoder.parameters(), lr=self.lr)
+        #optimizer_nb = th.optim.Adam(self.module.encoder_molecule.parameters(), lr=0.01)
+        lr_scheduler = th.optim.lr_scheduler.ReduceLROnPlateau(optimizer_graph,)
+        scheduler = {
+            'scheduler': lr_scheduler,
+            'reduce_on_plateau': True,
+            # val_checkpoint_on is val_loss passed in as checkpoint_on
+            'monitor': 'train_loss'
+        }
+        return [optimizer_graph],[scheduler]
+
+    def validation_step(self,batch, batch_idx):
+        pass
+
+    def training_epoch_end(self, outputs):
+        sch = self.lr_schedulers()
+
+        # If the selected scheduler is a ReduceLROnPlateau scheduler.
+        if isinstance(sch, th.optim.lr_scheduler.ReduceLROnPlateau):
+            sch.step(self.trainer.callback_metrics["train_loss"])
+
+class PyroOptWrap(pyro.infer.SVI):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def state_dict(self,):
+        return {}
+
+class SAGE(nn.Module):
+    def __init__(self,
+                 in_feats,
+                 n_hidden,
+                 n_latent,
+                 n_classes,
+                 n_layers,
+                 dropout,
+                 supervised,
+                 aggregator,
+                 features_name='gene'):
+        super().__init__()
+        self.n_layers = n_layers
+        self.n_hidden = n_hidden
+        self.n_latent = n_latent
+        self.n_classes = n_classes
+        self.supervised = supervised
+        self.aggregator = aggregator
+        self.in_feats = in_feats
+        self.features_name = features_name
+
+        n_embed = 256
+        self.embedding = nn.Embedding(in_feats, n_embed)
+        self.encoder = Encoder(in_feats=in_feats,
+                               n_hidden=n_hidden,
+                               n_latent=n_latent,
+                               n_layers=n_layers,
+                               supervised=supervised,
+                               aggregator=aggregator,
+                               dropout= dropout,
+                               )
+
+    def inference(self, g, device, batch_size, num_workers, core_nodes=None):
+        """
+        Inference with the GraphSAGE model on full neighbors (i.e. without neighbor sampling).
+        g : the entire graph.
+        The inference code is written in a fashion that it could handle any number of nodes and
+        layers.
+        """
+        self.eval()
+        if len(g.ndata[self.features_name].shape) == 1:
+            g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
+
+            if core_nodes is None:
+                dataloader = dgl.dataloading.NodeDataLoader(
+                    g, th.arange(g.num_nodes()).to(g.device), sampler, device=device,
+                    batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
+                    persistent_workers=(num_workers > 0))
+            else:
+                dataloader = dgl.dataloading.NodeDataLoader(
+                    g, core_nodes.to(g.device), sampler, device=device,
+                    batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
+                    persistent_workers=(num_workers > 0))
+
+        else:
+            g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
+            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
+            dataloader = dgl.dataloading.NodeDataLoader(
+                g, th.arange(g.num_nodes()).to(g.device), sampler, device=device,
+                batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
+                persistent_workers=(num_workers > 0))
+
+        for l, layer in enumerate(self.encoder.encoder_dict['GS']):
+            if l == self.n_layers - 1:
+                y = th.zeros(g.num_nodes(), self.n_latent) #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes)
+            else:
+                if self.aggregator == 'attentional':
+                    y = th.zeros(g.num_nodes(), self.n_hidden*4)
+                else:
+                    y = th.zeros(g.num_nodes(), self.n_hidden)
+
+            if self.supervised:
+                p_class = th.zeros(g.num_nodes(), self.n_classes)
+            else:
+                p_class = None
+
+            for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
+                x = blocks[0].srcdata['h']
+                if l != self.n_layers-1:
+                    h = layer(blocks[0], x)
+                    if self.aggregator == 'attentional':
+                        h= h.flatten(1)
+                else:
+                    h = layer(blocks[0], x)
+                    if self.aggregator == 'attentional':
+                        h = h.mean(1)
+                    #h = self.encoder.gs_mu(h)
+                y[output_nodes] = h.cpu().detach()#.numpy()
+            g.ndata['h'] = y
+        return y, p_class
+
+    def inference_attention(self, g, device, batch_size, num_workers, nodes=None, buffer_device=None):
+        # The difference between this inference function and the one in the official
+        # example is that the intermediate results can also benefit from prefetching.
+        if type(nodes) == type(None):
+            nodes = th.arange(g.num_nodes()).to(g.device)
+
+        if len(g.ndata[self.features_name].shape) == 1:
+            g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
+        else:
+            g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
+
+        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
+        dataloader = dgl.dataloading.NodeDataLoader(
+            g, nodes, sampler, device=device,
+            batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
+            persistent_workers=(num_workers > 0))
+
+        if buffer_device is None:
+            buffer_device = device
+
+        for l, layer in enumerate(self.encoder.encoder_dict['GS']):
+            if l == self.n_layers - 1:
+                y = th.zeros(g.num_nodes(), self.n_latent,device=buffer_device)
+                att2_list = [] #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes)
+            else:
+                y = th.zeros(g.num_nodes(), self.n_hidden*4, device=buffer_device)
+                att1_list = []
+
+            for input_nodes, output_nodes, blocks in dataloader:
+                x = blocks[0].srcdata['h']
+                if l != self.n_layers-1:
+                    h,att1 = layer(blocks[0], x,get_attention=True)
+                    att1_list.append(att1.mean(1).cpu().detach())
+                    h= h.flatten(1)
+
+                else:
+                    h, att2 = layer(blocks[0], x,get_attention=True)
+                    att2_list.append(att2.mean(1).cpu().detach())
+                    h = h.mean(1)
+                #h = self.encoder.gs_mu(h)
+                y[output_nodes] = h.cpu().detach().to(buffer_device)
+            g.ndata['h'] = y
+        return th.concat(att1_list), th.concat(att2_list)
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        in_feats,
+        n_hidden,
+        n_latent,
+        n_layers,
+        supervised,
+        aggregator,
+        dropout,
+    ):
+        super().__init__()
+        self.aggregator = aggregator
+        layers = nn.ModuleList()
+        if supervised:
+            self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,n_classes,None)
+        else:
+            self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,20)
+        self.num_heads = 4
+        self.n_layers = n_layers
+        for i in range(0,n_layers-1):
+            if i > 0:
+                in_feats = n_hidden
+                x = 0.2
+            else:
+                x = 0
+
+            if aggregator == 'attentional':
+                layers.append(dglnn.GATv2Conv(in_feats,
+                                              n_hidden,
+                                              num_heads=self.num_heads,
+                                              feat_drop=x,
+                                              #allow_zero_in_degree=False
+                                              ))
+
+            else:
+                layers.append(dglnn.SAGEConv(in_feats,
+                                             n_hidden,
+                                             aggregator_type=aggregator,
+                                             #feat_drop=0.2,
+                                             activation=F.relu,
+                                             norm=self.norm,
+                                             ))
+
+        if aggregator == 'attentional':
+            layers.append(dglnn.GATv2Conv(n_hidden*self.num_heads,
+                                          n_latent,
+                                          num_heads=self.num_heads,
+                                          feat_drop=dropout,
+                                          #allow_zero_in_degree=False
+                                          ))
+
+        else:
+            layers.append(dglnn.SAGEConv(n_hidden,
+                                         n_latent,
+                                         aggregator_type=aggregator,
+                                         feat_drop=dropout,
+                                         activation=F.relu,
+                                         norm=self.norm
+                                         ))
+
+        self.encoder_dict = nn.ModuleDict({'GS': layers})
+        #self.gs_mu = nn.Linear(n_hidden, n_latent)
+        #self.gs_var = nn.Linear(n_hidden, n_latent)
+
+    def forward(self, x, blocks=None):
+        h = th.log(x+1)
+        for l, (layer, block) in enumerate(zip(self.encoder_dict['GS'], blocks)):
+            if self.aggregator != 'attentional':
+                h = layer(block, h,)
+            else:
+                if l != self.n_layers-1:
+                    h = layer(block, h,).flatten(1)
+                else:
+                    h = layer(block, h,).mean(1)
+
+        #z_loc = self.gs_mu(h)
+        #z_scale = th.exp(self.gs_var(h)) +1e-6
+        return h
\ No newline at end of file
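The unsupervised objective in PATCH 03 is a standard negative-sampling link-prediction loss: positive edges should score high under a dot product of endpoint embeddings, sampled negatives low. A self-contained sketch of the same arithmetic, with toy shapes assumed (E positives, k negatives per positive):

    import torch as th
    import torch.nn.functional as F

    E, k = 4, 2                                  # assumed toy sizes
    pos_score = th.randn(E)                      # stand-in for (u * v).sum(-1) on positive edges
    neg_score = th.randn(E * k)                  # same on sampled negative edges
    pos_loss = -F.logsigmoid(pos_score)          # push positive pairs toward high scores
    neg_loss = -F.logsigmoid(-neg_score).reshape(E, k).mean(1)   # push negatives low
    loss = (pos_loss + neg_loss).mean()          # mirrors CrossEntropyLoss.forward above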
From bca35576eb1e801a4f48270c6f5d565278227131 Mon Sep 17 00:00:00 2001
From: alejandro
Date: Wed, 5 Apr 2023 16:16:24 +0200
Subject: [PATCH 04/11] Move the gene embedding from SAGE into Encoder

---
 FISHscale/graphNN/models_deepresidual.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index b81ba4fa..c04de648 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -197,9 +197,6 @@ def __init__(self,
         self.aggregator = aggregator
         self.in_feats = in_feats
         self.features_name = features_name
-
-        n_embed = 256
-        self.embedding = nn.Embedding(in_feats, n_embed)
         self.encoder = Encoder(in_feats=in_feats,
                                n_hidden=n_hidden,
                                n_latent=n_latent,
@@ -326,11 +323,15 @@ def __init__(
     ):
         super().__init__()
         self.aggregator = aggregator
+        n_embed = 256
+        self.embedding = nn.Embedding(in_feats, n_embed)
+
         layers = nn.ModuleList()
         if supervised:
             self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,n_classes,None)
         else:
             self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,20)
+
         self.num_heads = 4
         self.n_layers = n_layers
         for i in range(0,n_layers-1):
From 81688e144f35865af09f3c96b459e1ed6f29147a Mon Sep 17 00:00:00 2001
From: alejandro
Date: Wed, 5 Apr 2023 16:22:52 +0200
Subject: [PATCH 05/11] deep residuals

---
 FISHscale/graphNN/models_deepresidual.py | 25 +++++++++++++++---------
 1 file changed, 16 insertions(+), 9 deletions(-)

diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index c04de648..b3a29517 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -198,8 +198,8 @@ def __init__(self,
         self.in_feats = in_feats
         self.features_name = features_name
         self.encoder = Encoder(in_feats=in_feats,
-                               n_hidden=n_hidden,
-                               n_latent=n_latent,
+                               n_hidden=128,
+                               n_latent=128,
                                n_layers=n_layers,
                                supervised=supervised,
                                aggregator=aggregator,
@@ -215,7 +215,8 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
         """
         self.eval()
         if len(g.ndata[self.features_name].shape) == 1:
-            g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
+            g.ndata['h'] = self.encoder.embedding(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats))
+            #g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
             sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
 
             if core_nodes is None:
@@ -230,7 +231,8 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
                     persistent_workers=(num_workers > 0))
 
         else:
-            g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
+            #g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
+            g.ndata['h'] = self.encoder.embedding(g.ndata[self.features_name])
             sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
             dataloader = dgl.dataloading.NodeDataLoader(
@@ -273,9 +275,11 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
             nodes = th.arange(g.num_nodes()).to(g.device)
 
         if len(g.ndata[self.features_name].shape) == 1:
-            g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
+            #g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
+            g.ndata['h'] = self.encoder.embedding(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats))
         else:
-            g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
+            #g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
+            g.ndata['h'] = self.encoder.embedding(g.ndata[self.features_name])
 
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
         dataloader = dgl.dataloading.NodeDataLoader(
@@ -323,9 +327,9 @@ def __init__(
     ):
         super().__init__()
         self.aggregator = aggregator
-        n_embed = 256
+        n_embed = 128
         self.embedding = nn.Embedding(in_feats, n_embed)
-
+
         layers = nn.ModuleList()
         if supervised:
             self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,n_classes,None)
@@ -380,7 +384,9 @@ def __init__(
         #self.gs_var = nn.Linear(n_hidden, n_latent)
 
     def forward(self, x, blocks=None):
-        h = th.log(x+1)
+        #h = th.log(x+1)
+        e = self.embedding(x)
+        h = e
         for l, (layer, block) in enumerate(zip(self.encoder_dict['GS'], blocks)):
             if self.aggregator != 'attentional':
                 h = layer(block, h,)
@@ -389,6 +395,7 @@ def forward(self, x, blocks=None):
                 h = layer(block, h,).flatten(1)
             else:
                 h = layer(block, h,).mean(1)
+        h = h + e
 
         #z_loc = self.gs_mu(h)
         #z_scale = th.exp(self.gs_var(h)) +1e-6

From dc617173532eda6895d7497d61cb8279755bb6b9 Mon Sep 17 00:00:00 2001
From: alejandro
Date: Fri, 7 Apr 2023 10:58:00 +0200
Subject: [PATCH 06/11] Embed gene ids; add residual LayerNorm blocks and a
 neighborhood decoder loss

---
 FISHscale/graphNN/cellularneighborhoods.py |   9 +-
 FISHscale/graphNN/models_deepresidual.py   | 180 +++++++++------------
 2 files changed, 83 insertions(+), 106 deletions(-)

diff --git a/FISHscale/graphNN/cellularneighborhoods.py b/FISHscale/graphNN/cellularneighborhoods.py
index 9adbddba..545f0162 100755
--- a/FISHscale/graphNN/cellularneighborhoods.py
+++ b/FISHscale/graphNN/cellularneighborhoods.py
@@ -10,7 +10,8 @@
 import pytorch_lightning as pl
 import pandas as pd
 import dgl
-from FISHscale.graphNN.models import SAGELightning
+from FISHscale.graphNN.models_deepresidual import SAGELightning
+#from FISHscale.graphNN.models import SAGELightning
 from FISHscale.graphNN.graph_utils import GraphUtils, GraphPlotting
 from FISHscale.graphNN.graph_decoder import GraphDecoder
 
@@ -397,10 +398,10 @@ def get_latents(self):
             labelled (bool, optional): [description]. Defaults to True.
         """
         self.model.eval()
-
-        self.latent_unlabelled, prediction_unlabelled = self.model.module.inference(
+        self.g.to('cuda')
+        self.latent_unlabelled, _ = self.model.module.inference(
             self.g,
-            self.model.device,
+            self.g.device,
             10*512,
             0)
 
diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index b3a29517..940052bd 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -45,21 +45,16 @@ def __init__(self,
                  n_layers,
                  n_hidden=64,
                  dropout=0.1,
-                 lr=0.001,
+                 lr=1e-4,
                  features_name='gene',
                  supervised=False,
                  reference=0,
                  smooth=False,
                  device='cpu',
                  aggregator='attentional',
-                 celltype_distribution=None,
-                 ncells = None,
                  inference_type='deterministic',
-                 l_loc = None,
-                 l_scale = None,
-                 scale_factor = 1,
-                 warmup_factor = 1,
                  loss_type='unsupervised',#or supervised
+                 #decoder_loss=True,
                  ):
         super().__init__()
 
@@ -97,18 +92,6 @@ def __init__(self,
         self.automatic_optimization = False
         if self.supervised:
             self.num_classes = n_classes
-            '''
-            self.l_loc = l_loc
-            self.l_scale = l_scale
-            self.train_acc = torchmetrics.Accuracy()
-            self.kl = th.nn.KLDivLoss(reduction='sum')
-            self.dist = celltype_distribution
-            self.ncells = ncells
-            self.scale_factor = scale_factor
-            self.alpha = 1
-            self.warmup_counter = 0
-            self.warmup_factor=warmup_factor
-            '''
 
     def training_step(self, batch, batch_idx):
 
@@ -125,14 +108,24 @@ def training_step(self, batch, batch_idx):
                 pos_ids = pos.edges()[0]
                 mfgs = [mfg.int() for mfg in mfgs]
                 batch_inputs = mfgs[0].srcdata[self.features_name]
+                dr = mfgs[-1].dstdata[self.features_name]
 
             if len(batch_inputs.shape) == 1:
                 if self.supervised == False:
-                    batch_inputs = F.one_hot(batch_inputs.to(th.int64), num_classes=self.in_feats)
+                    #batch_inputs = F.one_hot(batch_inputs.to(th.int64), num_classes=self.in_feats)
+                    batch_inputs = batch_inputs.to(th.int64)
 
-            zn_loc = self.module.encoder(batch_inputs,mfgs)
+            zn_loc = self.module.encoder(batch_inputs,mfgs, dr=dr)
             if self.loss_type == 'unsupervised':
                 graph_loss = self.loss_fcn(zn_loc, pos, neg).mean()
+                decoder_n1 = self.module.encoder.decoder(zn_loc).softmax(dim=-1)
+                adjacency_matrix = mfgs[1].adjacency_matrix().to_dense()
+                feats_n1 = F.one_hot((mfgs[1].srcdata[self.features_name]), num_classes=self.in_feats).T
+                feats_n1 = (th.tensor(feats_n1,dtype=th.float32)@adjacency_matrix.to(self.device)).T
+                feats_n1 = feats_n1.softmax(dim=-1)
+                #print(feats_n1.shape, decoder_n1.shape)
+                graph_loss += - nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean(axis=0)
+
             else:
                 graph_loss = F.cross_entropy(zn_loc, mfgs[-1].dstdata['label'])
 
@@ -198,8 +191,8 @@ def __init__(self,
         self.in_feats = in_feats
         self.features_name = features_name
         self.encoder = Encoder(in_feats=in_feats,
-                               n_hidden=128,
-                               n_latent=128,
+                               n_hidden= 64,
+                               n_latent= 64,
                                n_layers=n_layers,
                                supervised=supervised,
                                aggregator=aggregator,
@@ -214,31 +207,15 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
         layers.
         """
         self.eval()
-        if len(g.ndata[self.features_name].shape) == 1:
-            g.ndata['h'] = self.encoder.embedding(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats))
-            #g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
-            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
-
-            if core_nodes is None:
-                dataloader = dgl.dataloading.NodeDataLoader(
-                    g, th.arange(g.num_nodes()).to(g.device), sampler, device=device,
-                    batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
-                    persistent_workers=(num_workers > 0))
-            else:
-                dataloader = dgl.dataloading.NodeDataLoader(
-                    g, core_nodes.to(g.device), sampler, device=device,
-                    batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
-                    persistent_workers=(num_workers > 0))
+        g.ndata['h'] = g.ndata[self.features_name]
+        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
+
+        dataloader = dgl.dataloading.NodeDataLoader(
+            g, th.arange(g.num_nodes()).to(g.device), sampler, device=device,
+            batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
+            persistent_workers=(num_workers > 0))
+
 
-        else:
-            #g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
-            g.ndata['h'] = self.encoder.embedding(g.ndata[self.features_name])
-            sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
-            dataloader = dgl.dataloading.NodeDataLoader(
-                g, th.arange(g.num_nodes()).to(g.device), sampler, device=device,
-                batch_size=batch_size, shuffle=False, drop_last=False, num_workers=num_workers,
-                persistent_workers=(num_workers > 0))
-
         for l, layer in enumerate(self.encoder.encoder_dict['GS']):
             if l == self.n_layers - 1:
                 y = th.zeros(g.num_nodes(), self.n_latent) #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes)
@@ -254,16 +231,22 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
                 p_class = None
 
             for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
-                x = blocks[0].srcdata['h']
+                if l == 0:
+                    x = self.encoder.embedding(blocks[0].srcdata['h'])
+                else:
+                    x = blocks[0].srcdata['h']
+                dr = blocks[0].dstdata[self.features_name]
                 if l != self.n_layers-1:
-                    h = layer(blocks[0], x)
-                    if self.aggregator == 'attentional':
-                        h= h.flatten(1)
+                    h,att1 = layer(blocks[0], x,get_attention=True)
+                    h= h.flatten(1)
+
                 else:
-                    h = layer(blocks[0], x)
-                    if self.aggregator == 'attentional':
-                        h = h.mean(1)
-                    #h = self.encoder.gs_mu(h)
+                    h, att2 = layer(blocks[0], x,get_attention=True)
+                    h = h.mean(1)
+
+                h = self.encoder.ln1(h) + self.encoder.embedding(dr)
+                h = self.encoder.fw(self.encoder.ln2(h)) + h
+
                 y[output_nodes] = h.cpu().detach()#.numpy()
             g.ndata['h'] = y
         return y, p_class
@@ -274,13 +257,8 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
         if type(nodes) == type(None):
             nodes = th.arange(g.num_nodes()).to(g.device)
 
-        if len(g.ndata[self.features_name].shape) == 1:
-            #g.ndata['h'] = th.log(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats)+1)
-            g.ndata['h'] = self.encoder.embedding(F.one_hot(g.ndata[self.features_name], num_classes=self.in_feats))
-        else:
-            #g.ndata['h'] = th.log(g.ndata[self.features_name]+1)
-            g.ndata['h'] = self.encoder.embedding(g.ndata[self.features_name])
-
+
+        g.ndata['h'] = g.ndata[self.features_name]
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
         dataloader = dgl.dataloading.NodeDataLoader(
             g, nodes, sampler, device=device,
@@ -299,7 +277,12 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
 
             for input_nodes, output_nodes, blocks in dataloader:
-                x = blocks[0].srcdata['h']
+                #x = blocks[0].srcdata['h']
+                if l == 0:
+                    x = self.encoder.embedding(blocks[0].srcdata['h'])
+                else:
+                    x = blocks[0].srcdata['h']
+                dr = blocks[0].dstdata[self.features_name]
                 if l != self.n_layers-1:
                     h,att1 = layer(blocks[0], x,get_attention=True)
                     att1_list.append(att1.mean(1).cpu().detach())
@@ -309,7 +292,10 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
                     h, att2 = layer(blocks[0], x,get_attention=True)
                     att2_list.append(att2.mean(1).cpu().detach())
                     h = h.mean(1)
-                #h = self.encoder.gs_mu(h)
+
+                h = self.encoder.ln1(h) + self.encoder.embedding(dr)
+                h = self.encoder.fw(self.encoder.ln2(h)) + h
+
                 y[output_nodes] = h.cpu().detach().to(buffer_device)
             g.ndata['h'] = y
         return th.concat(att1_list), th.concat(att2_list)
@@ -327,15 +313,12 @@ def __init__(
     ):
         super().__init__()
         self.aggregator = aggregator
-        n_embed = 128
+        n_embed = 64
         self.embedding = nn.Embedding(in_feats, n_embed)
+        self.ln1 = nn.LayerNorm(n_embed)
+        self.ln2 = nn.LayerNorm(n_embed)
 
         layers = nn.ModuleList()
-        if supervised:
-            self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,n_classes,None)
-        else:
-            self.norm = F.normalize#PairNorm()#DiffGroupNorm(n_hidden,20)
-
         self.num_heads = 4
         self.n_layers = n_layers
         for i in range(0,n_layers-1):
@@ -345,48 +328,42 @@ def __init__(
             else:
                 x = 0
 
-            if aggregator == 'attentional':
-                layers.append(dglnn.GATv2Conv(in_feats,
+            layers.append(dglnn.GATv2Conv(n_embed,
                                           n_hidden,
                                           num_heads=self.num_heads,
                                           feat_drop=x,
                                           #allow_zero_in_degree=False
                                           ))
-            else:
-                layers.append(dglnn.SAGEConv(in_feats,
-                                             n_hidden,
-                                             aggregator_type=aggregator,
-                                             #feat_drop=0.2,
-                                             activation=F.relu,
-                                             norm=self.norm,
-                                             ))
 
-        if aggregator == 'attentional':
-            layers.append(dglnn.GATv2Conv(n_hidden*self.num_heads,
-                                          n_latent,
-                                          num_heads=self.num_heads,
-                                          feat_drop=dropout,
-                                          #allow_zero_in_degree=False
-                                          ))
-
-        else:
-            layers.append(dglnn.SAGEConv(n_hidden,
-                                         n_latent,
-                                         aggregator_type=aggregator,
-                                         feat_drop=dropout,
-                                         activation=F.relu,
-                                         norm=self.norm
-                                         ))
+        layers.append(dglnn.GATv2Conv(n_embed*self.num_heads,
+                                      n_latent,
+                                      num_heads=self.num_heads,
+                                      feat_drop=dropout,
+                                      #allow_zero_in_degree=False
+                                      ))
 
         self.encoder_dict = nn.ModuleDict({'GS': layers})
-        #self.gs_mu = nn.Linear(n_hidden, n_latent)
-        #self.gs_var = nn.Linear(n_hidden, n_latent)
+        #self.fw = nn.Linear(n_hidden, n_embed)
+        self.fw = nn.Sequential(
+            nn.Linear(n_latent, 4 * n_latent),
+            nn.ReLU(),
+            nn.Linear(4 * n_latent, n_latent),
+            nn.Dropout(dropout),
+        )
+
+        self.decoder = nn.Sequential(
+            nn.Linear(n_latent, 4 * n_latent),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(4 * n_latent, in_feats),
+        )
 
-    def forward(self, x, blocks=None):
+    def forward(self, x, blocks=None, dr=0):
         #h = th.log(x+1)
         e = self.embedding(x)
         h = e
+        #print(h.shape)
         for l, (layer, block) in enumerate(zip(self.encoder_dict['GS'], blocks)):
             if self.aggregator != 'attentional':
                 h = layer(block, h,)
@@ -395,8 +372,7 @@ def forward(self, x, blocks=None, dr=0):
                 h = layer(block, h,).flatten(1)
             else:
                 h = layer(block, h,).mean(1)
-        h = h + e
-
-        #z_loc = self.gs_mu(h)
+        h = self.ln1(h) + self.embedding(dr)
+        h = self.fw(self.ln2(h)) + h
         #z_scale = th.exp(self.gs_var(h)) +1e-6
         return h
\ No newline at end of file
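The Encoder rewrite in PATCH 06 follows a pre-norm residual pattern: normalize the aggregated neighborhood state, add the destination node's own embedding back in, then apply a feed-forward block with a skip connection. A standalone sketch of that update, with dimensions assumed:

    import torch as th
    import torch.nn as nn

    n_embed, n_genes = 64, 100                   # assumed sizes
    emb = nn.Embedding(n_genes, n_embed)
    ln1, ln2 = nn.LayerNorm(n_embed), nn.LayerNorm(n_embed)
    fw = nn.Sequential(nn.Linear(n_embed, 4 * n_embed), nn.ReLU(),
                       nn.Linear(4 * n_embed, n_embed), nn.Dropout(0.1))

    h = th.randn(32, n_embed)                    # aggregated neighborhood features
    dr = th.randint(0, n_genes, (32,))           # destination-node gene ids
    h = ln1(h) + emb(dr)                         # re-inject the node's own identity
    h = fw(ln2(h)) + h                           # feed-forward block with skip connection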
From 420daae9893a32f9aa1debe6af0ec4068f43c4ce Mon Sep 17 00:00:00 2001
From: alejandro
Date: Fri, 7 Apr 2023 11:06:32 +0200
Subject: [PATCH 07/11] Size inference buffers from n_embed and num_heads

---
 FISHscale/graphNN/models_deepresidual.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)

diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index 940052bd..0ce800e3 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -218,13 +218,11 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
 
         for l, layer in enumerate(self.encoder.encoder_dict['GS']):
             if l == self.n_layers - 1:
-                y = th.zeros(g.num_nodes(), self.n_latent) #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes)
+                y = th.zeros(g.num_nodes(), self.encoder.n_embed) #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes)
             else:
                 if self.aggregator == 'attentional':
-                    y = th.zeros(g.num_nodes(), self.n_hidden*4)
-                else:
-                    y = th.zeros(g.num_nodes(), self.n_hidden)
-
+                    y = th.zeros(g.num_nodes(), self.encoder.n_embed*self.encoder.num_heads)
+
             if self.supervised:
                 p_class = th.zeros(g.num_nodes(), self.n_classes)
             else:
@@ -270,10 +268,10 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
 
         for l, layer in enumerate(self.encoder.encoder_dict['GS']):
             if l == self.n_layers - 1:
-                y = th.zeros(g.num_nodes(), self.n_latent,device=buffer_device)
+                y = th.zeros(g.num_nodes(), self.encoder.n_embed, device=buffer_device)
                 att2_list = [] #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes)
             else:
-                y = th.zeros(g.num_nodes(), self.n_hidden*4, device=buffer_device)
+                y = th.zeros(g.num_nodes(), self.encoder.n_embed*self.encoder.num_heads, device=buffer_device)
                 att1_list = []
 
@@ -314,6 +312,7 @@ def __init__(
         super().__init__()
         self.aggregator = aggregator
         n_embed = 64
+        self.n_embed = n_embed
         self.embedding = nn.Embedding(in_feats, n_embed)
         self.ln1 = nn.LayerNorm(n_embed)
         self.ln2 = nn.LayerNorm(n_embed)
From 75d9ab3260b20c649fb55a7f5f718d86a2a34fb4 Mon Sep 17 00:00:00 2001
From: alejandro
Date: Fri, 7 Apr 2023 12:11:25 +0200
Subject: [PATCH 08/11] Stack GATv2 layers for all hidden layers; keep
 per-head attention

---
 FISHscale/graphNN/models_deepresidual.py | 28 ++++++++++++++++--------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index 0ce800e3..2eedce07 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -283,12 +283,14 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
                 dr = blocks[0].dstdata[self.features_name]
                 if l != self.n_layers-1:
                     h,att1 = layer(blocks[0], x,get_attention=True)
-                    att1_list.append(att1.mean(1).cpu().detach())
+                    #att1_list.append(att1.mean(1).cpu().detach())
+                    att1_list.append(att1.cpu().detach())
                     h= h.flatten(1)
 
                 else:
                     h, att2 = layer(blocks[0], x,get_attention=True)
-                    att2_list.append(att2.mean(1).cpu().detach())
+                    #att2_list.append(att2.mean(1).cpu().detach())
+                    att2_list.append(att2.cpu().detach())
                     h = h.mean(1)
 
                 h = self.encoder.ln1(h) + self.encoder.embedding(dr)
@@ -320,19 +322,27 @@ def __init__(
         layers = nn.ModuleList()
         self.num_heads = 4
         self.n_layers = n_layers
+        [0,1,2]
         for i in range(0,n_layers-1):
-            if i > 0:
-                in_feats = n_hidden
-                x = 0.2
-            else:
-                x = 0
-
+            #if i > 0:
+            #    in_feats = n_hidden
+            #    x = 0.2
+            #else:
+            #    x = 0
+            if i == 0:
                 layers.append(dglnn.GATv2Conv(n_embed,
                                               n_hidden,
                                               num_heads=self.num_heads,
-                                              feat_drop=x,
+                                              feat_drop=dropout,
                                               #allow_zero_in_degree=False
                                               ))
+            else:
+                layers.append(dglnn.GATv2Conv(n_embed*self.num_heads,
+                                              n_hidden,
+                                              num_heads=self.num_heads,
+                                              feat_drop=dropout,
+                                              #allow_zero_in_degree=False
+                                              ))
 
         layers.append(dglnn.GATv2Conv(n_embed*self.num_heads,

From 0dcad06a276792c826ab5eaf4ddbdc09859e0fb8 Mon Sep 17 00:00:00 2001
From: alejandro
Date: Fri, 7 Apr 2023 12:11:43 +0200
Subject: [PATCH 09/11] Remove stray debug expression

---
 FISHscale/graphNN/models_deepresidual.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index 2eedce07..5eaba376 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -322,7 +322,6 @@ def __init__(
         layers = nn.ModuleList()
         self.num_heads = 4
         self.n_layers = n_layers
-        [0,1,2]
         for i in range(0,n_layers-1):
             #if i > 0:
             #    in_feats = n_hidden
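A note on the head bookkeeping in PATCH 08: hidden GATv2 layers concatenate their num_heads outputs via flatten(1), so every layer after the first takes n_embed*num_heads inputs, while the final layer averages heads instead. A dimension check under assumed toy sizes:

    import torch as th
    import dgl
    import dgl.nn.pytorch as dglnn

    g = dgl.add_self_loop(dgl.rand_graph(10, 40))   # toy graph; self-loops avoid zero in-degree
    n_embed, heads = 64, 4
    l1 = dglnn.GATv2Conv(n_embed, n_embed, num_heads=heads)
    l2 = dglnn.GATv2Conv(n_embed * heads, n_embed, num_heads=heads)
    h = l1(g, th.randn(10, n_embed)).flatten(1)     # (10, 256): heads concatenated
    h = l2(g, h).mean(1)                            # (10, 64): heads averaged at the top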
""" self.model.eval() - att1,att2 = self.model.module.inference_attention(self.g, + att = self.model.module.inference_attention(self.g, self.model.device, 5*512, 0, nodes=nodes, buffer_device=self.g.device)#.detach().numpy() - return att1, att2 + return att def compute_distance_th(self,coords): """ diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py index 5eaba376..fa323229 100755 --- a/FISHscale/graphNN/models_deepresidual.py +++ b/FISHscale/graphNN/models_deepresidual.py @@ -119,9 +119,12 @@ def training_step(self, batch, batch_idx): if self.loss_type == 'unsupervised': graph_loss = self.loss_fcn(zn_loc, pos, neg).mean() decoder_n1 = self.module.encoder.decoder(zn_loc).softmax(dim=-1) - adjacency_matrix = mfgs[1].adjacency_matrix().to_dense() - feats_n1 = F.one_hot((mfgs[1].srcdata[self.features_name]), num_classes=self.in_feats).T - feats_n1 = (th.tensor(feats_n1,dtype=th.float32)@adjacency_matrix.to(self.device)).T + feats_n1 = F.one_hot((mfgs[-1].srcdata[self.features_name]), num_classes=self.in_feats).T + #feats_n1 = (th.tensor(feats_n1,dtype=th.float32)@adjacency_matrix.to(self.device)).T + feats_n1 = th.sparse.mm( + th.tensor(feats_n1,dtype=th.float32).to_sparse_coo(), + mfgs[-1].adjacency_matrix().to(self.device) + ).to_dense().T feats_n1 = feats_n1.softmax(dim=-1) #print(feats_n1.shape, decoder_n1.shape) graph_loss += - nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean(axis=0) @@ -265,16 +268,15 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu if buffer_device is None: buffer_device = device + self.attention_list = [[] for x in range(self.n_layers)] for l, layer in enumerate(self.encoder.encoder_dict['GS']): if l == self.n_layers - 1: y = th.zeros(g.num_nodes(), self.encoder.n_embed, device=buffer_device) - att2_list = [] #if not self.supervised else th.zeros(g.num_nodes(), self.n_classes) else: y = th.zeros(g.num_nodes(), self.encoder.n_embed*self.encoder.num_heads, device=buffer_device) - att1_list = [] - for input_nodes, output_nodes, blocks in dataloader: + for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader): #x = blocks[0].srcdata['h'] if l == 0: x = self.encoder.embedding(blocks[0].srcdata['h']) @@ -282,15 +284,15 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu x = blocks[0].srcdata['h'] dr = blocks[0].dstdata[self.features_name] if l != self.n_layers-1: - h,att1 = layer(blocks[0], x,get_attention=True) + h,att = layer(blocks[0], x,get_attention=True) #att1_list.append(att1.mean(1).cpu().detach()) - att1_list.append(att1.cpu().detach()) + self.attention_list[l].append(att.cpu().detach()) h= h.flatten(1) else: - h, att2 = layer(blocks[0], x,get_attention=True) + h, att = layer(blocks[0], x,get_attention=True) #att2_list.append(att2.mean(1).cpu().detach()) - att2_list.append(att2.cpu().detach()) + self.attention_list[l].append(att.cpu().detach()) h = h.mean(1) h = self.encoder.ln1(h) + self.encoder.embedding(dr) @@ -298,7 +300,7 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu y[output_nodes] = h.cpu().detach().to(buffer_device) g.ndata['h'] = y - return th.concat(att1_list), th.concat(att2_list) + return [th.concat(a) for a in self.attention_list] class Encoder(nn.Module): def __init__( From 7130826f73d4c230c5c08513dfadd58aaf7b2385 Mon Sep 17 00:00:00 2001 From: alejandro Date: Fri, 21 Apr 2023 17:42:05 +0200 Subject: [PATCH 11/11] m --- FISHscale/graphNN/cellularneighborhoods.py | 41 
From 7130826f73d4c230c5c08513dfadd58aaf7b2385 Mon Sep 17 00:00:00 2001
From: alejandro
Date: Fri, 21 Apr 2023 17:42:05 +0200
Subject: [PATCH 11/11] Disable decoder loss; cast gene ids to long; tighten
 graph distance threshold

---
 FISHscale/graphNN/cellularneighborhoods.py | 41 +++++++++++++++++++---
 FISHscale/graphNN/models_deepresidual.py   | 14 ++++----
 2 files changed, 44 insertions(+), 11 deletions(-)

diff --git a/FISHscale/graphNN/cellularneighborhoods.py b/FISHscale/graphNN/cellularneighborhoods.py
index 6875230e..9a74009b 100755
--- a/FISHscale/graphNN/cellularneighborhoods.py
+++ b/FISHscale/graphNN/cellularneighborhoods.py
@@ -398,10 +398,10 @@ def get_latents(self):
             labelled (bool, optional): [description]. Defaults to True.
         """
         self.model.eval()
-        self.g.to('cuda')
+        #self.g.to('cuda')
         self.latent_unlabelled, _ = self.model.module.inference(
             self.g,
-            self.g.device,
+            self.model.device,
             10*512,
             0)
 
@@ -466,8 +466,8 @@ def compute_distance_th(self,coords):
         from scipy.spatial import cKDTree as KDTree
 
         kdT = KDTree(coords)
-        d,i = kdT.query(coords,k=3)
-        d_th = np.percentile(d[:,-1],95)*self.distance_factor
+        d,i = kdT.query(coords,k=2)
+        d_th = np.percentile(d[:,-1],97)*self.distance_factor
         logging.info('Chosen dist to connect molecules into a graph: {}'.format(d_th))
         print('Chosen dist to connect molecules into a graph: {}'.format(d_th))
         return d_th
@@ -566,9 +566,42 @@ def cluster(self, n_clusters=10):
             [type]: [description]
         """
         from sklearn.cluster import MiniBatchKMeans
+        import scanpy as sc
+        from sklearn.linear_model import SGDClassifier
+        from sklearn.preprocessing import StandardScaler
+        from sklearn.pipeline import make_pipeline
 
         clusters = MiniBatchKMeans(n_clusters=n_clusters).fit_predict(self.latent_unlabelled)
         self.g.ndata['CellularNgh'] = th.tensor(clusters)
+
+        '''logging.info('Latent embeddings generated for {} molecules'.format(self.latent_unlabelled.shape[0]))
+
+        random_sample_train = np.random.choice(
+            len(self.latent_unlabelled ),
+            np.min([len(self.latent_unlabelled ),250000]),
+            replace=False)
+
+        training_latents = self.latent_unlabelled[random_sample_train,:]
+        adata = sc.AnnData(X=training_latents.detach().numpy())
+        logging.info('Building neighbor graph for clustering...')
+        sc.pp.neighbors(adata, n_neighbors=15)
+        logging.info('Running Leiden clustering...')
+        sc.tl.leiden(adata, random_state=42, resolution=1)
+        logging.info('Leiden clustering done.')
+        clusters= adata.obs['leiden'].values
+
+        logging.info('Total of {} found'.format(len(np.unique(clusters))))
+        clf = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3))
+        clf.fit(training_latents, clusters)
+        clusters = clf.predict(self.latent_unlabelled).astype('int8')
+
+        clf_total = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3))
+        clf_total.fit(self.latent_unlabelled.detach().numpy(), clusters)
+        clusters = clf.predict(self.latent_unlabelled.detach().numpy()).astype('int8')
+        self.g.ndata['CellularNgh'] = th.tensor(clusters)'''
+
+        self.save_graph()
         return clusters
diff --git a/FISHscale/graphNN/models_deepresidual.py b/FISHscale/graphNN/models_deepresidual.py
index fa323229..29390767 100755
--- a/FISHscale/graphNN/models_deepresidual.py
+++ b/FISHscale/graphNN/models_deepresidual.py
@@ -118,7 +118,7 @@ def training_step(self, batch, batch_idx):
             zn_loc = self.module.encoder(batch_inputs,mfgs, dr=dr)
             if self.loss_type == 'unsupervised':
                 graph_loss = self.loss_fcn(zn_loc, pos, neg).mean()
-                decoder_n1 = self.module.encoder.decoder(zn_loc).softmax(dim=-1)
+                '''decoder_n1 = self.module.encoder.decoder(zn_loc).softmax(dim=-1)
                 feats_n1 = F.one_hot((mfgs[-1].srcdata[self.features_name]), num_classes=self.in_feats).T
                 #feats_n1 = (th.tensor(feats_n1,dtype=th.float32)@adjacency_matrix.to(self.device)).T
                 feats_n1 = th.sparse.mm(
@@ -127,7 +127,7 @@ def training_step(self, batch, batch_idx):
                     ).to_dense().T
                 feats_n1 = feats_n1.softmax(dim=-1)
                 #print(feats_n1.shape, decoder_n1.shape)
-                graph_loss += - nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean(axis=0)
+                graph_loss += - nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean(axis=0)'''
 
             else:
                 graph_loss = F.cross_entropy(zn_loc, mfgs[-1].dstdata['label'])
@@ -210,7 +210,7 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
         layers.
         """
         self.eval()
-        g.ndata['h'] = g.ndata[self.features_name]
+        g.ndata['h'] = g.ndata[self.features_name].long()
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
 
         dataloader = dgl.dataloading.NodeDataLoader(
@@ -236,7 +236,7 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
                     x = blocks[0].srcdata['h']
-                dr = blocks[0].dstdata[self.features_name]
+                dr = blocks[0].dstdata[self.features_name].long()
                 if l != self.n_layers-1:
                     h,att1 = layer(blocks[0], x,get_attention=True)
                     h= h.flatten(1)
@@ -259,7 +259,7 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
 
-        g.ndata['h'] = g.ndata[self.features_name]
+        g.ndata['h'] = g.ndata[self.features_name].long()
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
         dataloader = dgl.dataloading.NodeDataLoader(
             g, nodes, sampler, device=device,
@@ -282,7 +282,7 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
                     x = blocks[0].srcdata['h']
-                dr = blocks[0].dstdata[self.features_name]
+                dr = blocks[0].dstdata[self.features_name].long()
                 if l != self.n_layers-1:
                     h,att = layer(blocks[0], x,get_attention=True)
                     #att1_list.append(att1.mean(1).cpu().detach())
@@ -382,7 +382,7 @@ def forward(self, x, blocks=None, dr=0):
                 h = layer(block, h,).flatten(1)
             else:
                 h = layer(block, h,).mean(1)
-        h = self.ln1(h) + self.embedding(dr)
+        h = self.ln1(h) + self.embedding(dr.long())
         h = self.fw(self.ln2(h)) + h
         #z_scale = th.exp(self.gs_var(h)) +1e-6
         return h
\ No newline at end of file
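With the decoder loss disabled in PATCH 11, the clustering path reduces to k-means over the inferred latents; the equivalent standalone call, with the array shape assumed:

    import numpy as np
    from sklearn.cluster import MiniBatchKMeans

    latents = np.random.rand(10_000, 64)    # stand-in for self.latent_unlabelled
    clusters = MiniBatchKMeans(n_clusters=10).fit_predict(latents)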