Commit

m

alejandro committed Apr 21, 2023
1 parent 529214c commit 7130826
Showing 2 changed files with 44 additions and 11 deletions.
41 changes: 37 additions & 4 deletions FISHscale/graphNN/cellularneighborhoods.py
@@ -398,10 +398,10 @@ def get_latents(self):
         labelled (bool, optional): [description]. Defaults to True.
         """
         self.model.eval()
-        self.g.to('cuda')
+        #self.g.to('cuda')
         self.latent_unlabelled, _ = self.model.module.inference(
             self.g,
-            self.g.device,
+            self.model.device,
             10*512,
             0)
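
A note on this change: instead of moving the whole graph to the GPU up front, inference now takes its device from the model, so data can be shipped per mini-batch. A minimal torch-only sketch of that pattern, using a hypothetical stand-in model rather than the repo's actual module:

    import torch as th
    import torch.nn as nn

    model = nn.Linear(48, 10)                 # hypothetical stand-in for the trained module
    device = next(model.parameters()).device  # what a LightningModule's .device resolves to
    batch = th.rand(512, 48)                  # inputs stay on CPU until needed...
    out = model(batch.to(device))             # ...and are moved per batch, not via g.to('cuda')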

@@ -466,8 +466,8 @@ def compute_distance_th(self,coords):

         from scipy.spatial import cKDTree as KDTree
         kdT = KDTree(coords)
-        d,i = kdT.query(coords,k=3)
-        d_th = np.percentile(d[:,-1],95)*self.distance_factor
+        d,i = kdT.query(coords,k=2)
+        d_th = np.percentile(d[:,-1],97)*self.distance_factor
         logging.info('Chosen dist to connect molecules into a graph: {}'.format(d_th))
         print('Chosen dist to connect molecules into a graph: {}'.format(d_th))
         return d_th
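
The threshold change can be checked in isolation: with k=2, cKDTree.query over the same points returns each point's distance to itself (zero) and to its single nearest neighbour, so d[:,-1] is now the first-neighbour distance rather than the second-neighbour distance it was under k=3. A minimal sketch, assuming random 2D coordinates and a distance_factor of 1:

    import numpy as np
    from scipy.spatial import cKDTree as KDTree

    coords = np.random.rand(1000, 2)          # hypothetical molecule coordinates
    d, i = KDTree(coords).query(coords, k=2)  # d[:, 0] is 0 (self), d[:, 1] the nearest neighbour
    d_th = np.percentile(d[:, -1], 97) * 1.0  # 97th percentile, scaled by distance_factor
    print('Chosen dist to connect molecules into a graph: {}'.format(d_th))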
@@ -566,9 +566,42 @@ def cluster(self, n_clusters=10):
         [type]: [description]
         """
         from sklearn.cluster import MiniBatchKMeans
+        import scanpy as sc
+        from sklearn.linear_model import SGDClassifier
+        from sklearn.preprocessing import StandardScaler
+        from sklearn.pipeline import make_pipeline

         clusters = MiniBatchKMeans(n_clusters=n_clusters).fit_predict(self.latent_unlabelled)
         self.g.ndata['CellularNgh'] = th.tensor(clusters)

+        '''logging.info('Latent embeddings generated for {} molecules'.format(self.latent_unlabelled.shape[0]))
+        random_sample_train = np.random.choice(
+            len(self.latent_unlabelled),
+            np.min([len(self.latent_unlabelled),250000]),
+            replace=False)
+        training_latents = self.latent_unlabelled[random_sample_train,:]
+        adata = sc.AnnData(X=training_latents.detach().numpy())
+        logging.info('Building neighbor graph for clustering...')
+        sc.pp.neighbors(adata, n_neighbors=15)
+        logging.info('Running Leiden clustering...')
+        sc.tl.leiden(adata, random_state=42, resolution=1)
+        logging.info('Leiden clustering done.')
+        clusters = adata.obs['leiden'].values
+        logging.info('Total of {} found'.format(len(np.unique(clusters))))
+        clf = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3))
+        clf.fit(training_latents, clusters)
+        clusters = clf.predict(self.latent_unlabelled).astype('int8')
+        clf_total = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', max_iter=1000, tol=1e-3))
+        clf_total.fit(self.latent_unlabelled.detach().numpy(), clusters)
+        clusters = clf.predict(self.latent_unlabelled.detach().numpy()).astype('int8')
+        self.g.ndata['CellularNgh'] = th.tensor(clusters)'''

         self.save_graph()
         return clusters
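
With the Leiden/SGD pipeline parked inside the string above, clustering now reduces to MiniBatchKMeans over the latent embeddings. A minimal sketch of that path, assuming the latents are a float32 array and torch is imported as th (hypothetical sizes):

    import numpy as np
    import torch as th
    from sklearn.cluster import MiniBatchKMeans

    latents = np.random.rand(10000, 32).astype(np.float32)  # hypothetical latent embeddings
    clusters = MiniBatchKMeans(n_clusters=10).fit_predict(latents)
    labels = th.tensor(clusters)  # stored on the graph as g.ndata['CellularNgh']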

14 changes: 7 additions & 7 deletions FISHscale/graphNN/models_deepresidual.py
@@ -118,7 +118,7 @@ def training_step(self, batch, batch_idx):
         zn_loc = self.module.encoder(batch_inputs,mfgs, dr=dr)
         if self.loss_type == 'unsupervised':
             graph_loss = self.loss_fcn(zn_loc, pos, neg).mean()
-            decoder_n1 = self.module.encoder.decoder(zn_loc).softmax(dim=-1)
+            '''decoder_n1 = self.module.encoder.decoder(zn_loc).softmax(dim=-1)
             feats_n1 = F.one_hot((mfgs[-1].srcdata[self.features_name]), num_classes=self.in_feats).T
             #feats_n1 = (th.tensor(feats_n1,dtype=th.float32)@adjacency_matrix.to(self.device)).T
             feats_n1 = th.sparse.mm(
@@ -127,7 +127,7 @@ def training_step(self, batch, batch_idx):
             ).to_dense().T
             feats_n1 = feats_n1.softmax(dim=-1)
             #print(feats_n1.shape, decoder_n1.shape)
-            graph_loss += - nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean(axis=0)
+            graph_loss += - nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean(axis=0)'''

         else:
             graph_loss = F.cross_entropy(zn_loc, mfgs[-1].dstdata['label'])
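
For reference, the term disabled above compared the decoder's per-node softmax over genes with an aggregated one-hot neighbour profile via cosine similarity, subtracting the mean similarity from the graph loss. A self-contained sketch of that pattern with hypothetical shapes (torch only):

    import torch as th
    import torch.nn as nn

    decoder_n1 = th.rand(256, 40).softmax(dim=-1)  # hypothetical decoder output: nodes x genes
    feats_n1 = th.rand(256, 40).softmax(dim=-1)    # hypothetical aggregated neighbour profile
    cos_term = -nn.CosineSimilarity(dim=1, eps=1e-08)(decoder_n1, feats_n1).mean()
    # when enabled, graph_loss is incremented by cos_term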
@@ -210,7 +210,7 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
             layers.
         """
         self.eval()
-        g.ndata['h'] = g.ndata[self.features_name]
+        g.ndata['h'] = g.ndata[self.features_name].long()
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])

         dataloader = dgl.dataloading.NodeDataLoader(
@@ -236,7 +236,7 @@ def inference(self, g, device, batch_size, num_workers, core_nodes=None):
                 x = self.encoder.embedding(blocks[0].srcdata['h'])
             else:
                 x = blocks[0].srcdata['h']
-            dr = blocks[0].dstdata[self.features_name]
+            dr = blocks[0].dstdata[self.features_name].long()
             if l != self.n_layers-1:
                 h,att1 = layer(blocks[0], x,get_attention=True)
                 h = h.flatten(1)
@@ -259,7 +259,7 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
             nodes = th.arange(g.num_nodes()).to(g.device)

-        g.ndata['h'] = g.ndata[self.features_name]
+        g.ndata['h'] = g.ndata[self.features_name].long()
         sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1, prefetch_node_feats=['h'])
         dataloader = dgl.dataloading.NodeDataLoader(
             g, nodes, sampler, device=device,
@@ -282,7 +282,7 @@ def inference_attention(self, g, device, batch_size, num_workers, nodes=None, bu
                 x = self.encoder.embedding(blocks[0].srcdata['h'])
             else:
                 x = blocks[0].srcdata['h']
-            dr = blocks[0].dstdata[self.features_name]
+            dr = blocks[0].dstdata[self.features_name].long()
             if l != self.n_layers-1:
                 h,att = layer(blocks[0], x,get_attention=True)
                 #att1_list.append(att1.mean(1).cpu().detach())
@@ -382,7 +382,7 @@ def forward(self, x, blocks=None, dr=0):
                 h = layer(block, h,).flatten(1)
             else:
                 h = layer(block, h,).mean(1)
-        h = self.ln1(h) + self.embedding(dr)
+        h = self.ln1(h) + self.embedding(dr.long())
         h = self.fw(self.ln2(h)) + h
         #z_scale = th.exp(self.gs_var(h)) +1e-6
         return h
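
All of the .long() casts in this commit serve one purpose: nn.Embedding is a table lookup and requires integer (LongTensor) indices, so float-typed node features would raise a runtime error when embedded. A minimal sketch with hypothetical sizes:

    import torch as th
    import torch.nn as nn

    embedding = nn.Embedding(num_embeddings=40, embedding_dim=48)  # hypothetical table
    dr = th.tensor([3.0, 17.0, 25.0])  # float-typed ids, as raw features might arrive
    h = embedding(dr.long())           # cast to int64 before the lookup -> shape (3, 48)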
