Skip to content

Commit

Permalink
Fix functions to be generalizable for more datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
sjcshin committed Aug 29, 2024
1 parent 7b2b283 commit e8858a3
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 46 deletions.
77 changes: 56 additions & 21 deletions R/interoperability.R
Original file line number Diff line number Diff line change
Expand Up @@ -3355,9 +3355,12 @@ spatialdataToGiotto <- function(
)

# Attach hires image
raster <- terra::rast(extract_image(sdata))
giotto_image <- createGiottoLargeImage(raster)
gobject <- addGiottoLargeImage(gobject = gobject, largeImages = c(giotto_image))
extracted_images <- extract_image(sdata)
extract_image_names <- extract_image_names(sdata)

raster_image_list <- lapply(extracted_images, terra::rast)
large_image_list <- createGiottoLargeImageList(raster_image_list, names = extract_image_names)
gobject <- addGiottoLargeImage(gobject = gobject, largeImages = large_image_list)

# Attach metadata
cm <- readCellMetadata(cm)
Expand Down Expand Up @@ -3463,7 +3466,7 @@ spatialdataToGiotto <- function(
vert <- unique(x = c(nn_dt$from_cell_ID, nn_dt$to_cell_ID))
nn_network_igraph <- igraph::graph_from_data_frame(nn_dt[, .(from_cell_ID, to_cell_ID, weight, distance)], directed = TRUE, vertices = vert)

nn_info <- extract_NN_info(adata = adata, key_added = n_key_added_it)
nn_info <- extract_NN_info(sdata = sdata, key_added = n_key_added_it)

net_type <- "kNN" # anndata default
if (("sNN" %in% n_key_added_it) & !is.null(n_key_added_it)) {
Expand Down Expand Up @@ -3587,6 +3590,35 @@ spatialdataToGiotto <- function(
)
}
}

### Layers
lay_names <- extract_layer_names(sdata)
if (!is.null(lay_names)) {
for (l_n in lay_names) {
lay <- extract_layered_data(sdata, layer_name = l_n)
if ("data.frame" %in% class(lay)) {
names(lay) <- fID
row.names(lay) <- cID
} else {
lay@Dimnames[[1]] <- fID
lay@Dimnames[[2]] <- cID
}
layExprObj <- createExprObj(lay, name = l_n)
gobject <- set_expression_values(
gobject = gobject,
spat_unit = spat_unit,
feat_type = feat_type,
name = l_n,
values = layExprObj
)
}
}

gobject <- update_giotto_params(
gobject = gobject,
description = "_AnnData_Conversion"
)

return(gobject)
}

Expand Down Expand Up @@ -3646,29 +3678,32 @@ giottoToSpatialData <- function(
save_directory = temp
)

# Extract GiottoImage
gimg <- getGiottoImage(gobject, image_type = "largeImage")

# Temporarily save the image to disk
writeGiottoLargeImage(
giottoLargeImage = gimg,
gobject = gobject,
largeImage_name = "largeImage",
filename = "temp_image.png",
dataType = NULL,
max_intensity = NULL,
overwrite = TRUE,
verbose = TRUE
)
# Extract GiottoImage only if an image exists
image_exists <- NULL
if (length(slot(gobject, "images")) > 0) {
image_exists <- TRUE
gimg_list <- slot(gobject, "images")
for (i in seq_along(gimg_list)) {
img_name <- slot(gimg_list[[i]], "name")
writeGiottoLargeImage(
giottoLargeImage = gimg_list[[i]],
gobject = gobject,
largeImage_name = img_name,
filename = paste0(temp, img_name, ".png"),
dataType = NULL,
max_intensity = NULL,
overwrite = TRUE,
verbose = TRUE
)
}
}

spat_locs <- getSpatialLocations(gobject, output="data.table")

# Create SpatialData object
createSpatialData(temp, spat_locs, spot_radius, save_directory)
createSpatialData(temp, spat_locs, spot_radius, save_directory, image_exists)

# Delete temporary files and folders
unlink("temp_image.png")
unlink("temp_image.png.aux.xml")
unlink(temp, recursive = TRUE)

# Successful Conversion
Expand Down
33 changes: 21 additions & 12 deletions inst/python/g2sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,23 @@
from xarray import DataArray
import geopandas as gpd
from shapely.geometry import Point
import os
import glob, os

from spatialdata import SpatialData
from spatialdata.models import Image2DModel, ShapesModel, TableModel
from spatialdata.transformations.transformations import Identity

def createImageModel():
def createImageModel(temp):
images = {}
hires_image_path = "temp_image.png"
hires_img = imread(hires_image_path).squeeze().transpose(2,0,1)
hires_img = DataArray(hires_img, dims=("c","y","x"))
images["hires_image"] = Image2DModel.parse(hires_img, transformations={"downscaled_hires": Identity()})
image_paths = glob.glob(temp+"*.png")
for path in image_paths:
image = imread(path).squeeze()
if len(image.shape) == 2:
image = np.expand_dims(image, axis=-1)
image = image.transpose(2,0,1)
image = DataArray(image, dims=("c","y","x"))
image_name = os.path.splitext(os.path.basename(path))[0]
images[image_name] = Image2DModel.parse(image)
return images

def createShapeModel(spat_locs, spot_radius):
Expand All @@ -33,15 +38,19 @@ def createShapeModel(spat_locs, spot_radius):
return shapes

def createTableModel(temp):
alist = os.listdir(temp)[0]
adata = ad.read_h5ad(os.path.join(temp, alist))
alist = glob.glob(temp+"*.h5ad")
adata = ad.read_h5ad(alist[0])
table = TableModel.parse(adata)
return table

def createSpatialData(temp, spat_locs, spot_radius, save_directory):
images = createImageModel()
def createSpatialData(temp, spat_locs, spot_radius, save_directory, image_exists):
if image_exists:
images = createImageModel(temp)
table = createTableModel(temp)
shapes = createShapeModel(spat_locs, spot_radius)
sd = SpatialData(table = table, images = images)
if image_exists:
sd = SpatialData(table = table, images = images)
else:
sd = SpatialData(table = table)
sd.shapes["Shapes"] = shapes
sd.write(save_directory)
sd.write(save_directory, overwrite = True)
31 changes: 18 additions & 13 deletions inst/python/sd2g.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def read_spatialdata_from_path(sd_path = None):
# Extract gene expression
def extract_expression(sdata = None):
expr = sdata.table.X.transpose().todense()
expr_df = pd.DataFrame(expr, index=sdata.table.var['gene_ids'].index, columns=sdata.table.obs['array_row'].index)
expr_df = pd.DataFrame(expr, index=sdata.table.var.index, columns=sdata.table.obs.index)
return expr_df

# Extract cell IDs
def extract_cell_IDs(sdata = None):
cell_IDs = sdata.table.obs['array_row'].index.tolist()
cell_IDs = sdata.table.obs.index.tolist()
return cell_IDs

# Extract feature IDs
def extract_feat_IDs(sdata = None):
feat_IDs = sdata.table.var['gene_ids'].index.tolist()
feat_IDs = sdata.table.var.index.tolist()
return feat_IDs

# Metadata
Expand Down Expand Up @@ -124,17 +124,23 @@ def parse_obsm_for_spat_locs(sdata = None):
spat_locs["sdimy"] = -1 * spat_locs["sdimy"]
return spat_locs

# Extract hires image
# Extract images
def extract_image(sdata = None):
# Find SpatialData image name for hires image
for key in sdata.images.keys():
if "hires" in key:
hires_image_name = key
# Retrieve the list of images
image_list = list(sdata.images.keys())

# Extract image from SpatialData and convert it to numpy array
hires_image = sdata.images[hires_image_name]
hires_image_array = np.transpose(hires_image.compute().data, (1, 2, 0)) # Transpose to (y, x, c)
return hires_image_array
extracted_images = []
for image_key in image_list:
image = sdata.images[image_key]
image_array = np.transpose(image.compute().data, (1, 2, 0)) # Transpose to (y, x, c)
extracted_images.append(image_array)
return extracted_images

# Extract image names
def extract_image_names(sdata = None):
image_names = list(sdata.images.keys())
return image_names

# Extract PCA
def extract_pca(sdata = None):
Expand Down Expand Up @@ -240,7 +246,6 @@ def extract_NN_connectivities(sdata = None, key_added = None):
for nk in nn_key_list:
if "connectivities" in nk:
connectivities = sdata.table.obsp[nk]

return connectivities

def extract_NN_distances(sdata = None, key_added = None):
Expand Down Expand Up @@ -318,7 +323,7 @@ def find_SN_keys(sdata = None, key_added = None):
with open(key_added) as f:
for line in f.readlines():
line = line.strip()
line_key_added = line + suffix
line_key_added = line + "_" + suffix
line_keys.append(line_key_added)
for key in line_keys:
map_keys = sdata.table.uns[key].keys()
Expand Down

0 comments on commit e8858a3

Please sign in to comment.