From e8858a377ec43d93c845dd33c18cac2810b124d5 Mon Sep 17 00:00:00 2001 From: Crystal Shin <54536735+sjcshin@users.noreply.github.com> Date: Thu, 29 Aug 2024 16:28:31 -0400 Subject: [PATCH] Fix functions to be generalizable for more datasets --- R/interoperability.R | 77 ++++++++++++++++++++++++++++++++------------ inst/python/g2sd.py | 33 ++++++++++++------- inst/python/sd2g.py | 31 ++++++++++-------- 3 files changed, 95 insertions(+), 46 deletions(-) diff --git a/R/interoperability.R b/R/interoperability.R index 74466136..34da2872 100644 --- a/R/interoperability.R +++ b/R/interoperability.R @@ -3355,9 +3355,12 @@ spatialdataToGiotto <- function( ) # Attach hires image - raster <- terra::rast(extract_image(sdata)) - giotto_image <- createGiottoLargeImage(raster) - gobject <- addGiottoLargeImage(gobject = gobject, largeImages = c(giotto_image)) + extracted_images <- extract_image(sdata) + extract_image_names <- extract_image_names(sdata) + + raster_image_list <- lapply(extracted_images, terra::rast) + large_image_list <- createGiottoLargeImageList(raster_image_list, names = extract_image_names) + gobject <- addGiottoLargeImage(gobject = gobject, largeImages = large_image_list) # Attach metadata cm <- readCellMetadata(cm) @@ -3463,7 +3466,7 @@ spatialdataToGiotto <- function( vert <- unique(x = c(nn_dt$from_cell_ID, nn_dt$to_cell_ID)) nn_network_igraph <- igraph::graph_from_data_frame(nn_dt[, .(from_cell_ID, to_cell_ID, weight, distance)], directed = TRUE, vertices = vert) - nn_info <- extract_NN_info(adata = adata, key_added = n_key_added_it) + nn_info <- extract_NN_info(sdata = sdata, key_added = n_key_added_it) net_type <- "kNN" # anndata default if (("sNN" %in% n_key_added_it) & !is.null(n_key_added_it)) { @@ -3587,6 +3590,35 @@ spatialdataToGiotto <- function( ) } } + + ### Layers + lay_names <- extract_layer_names(sdata) + if (!is.null(lay_names)) { + for (l_n in lay_names) { + lay <- extract_layered_data(sdata, layer_name = l_n) + if ("data.frame" %in% class(lay)) { + names(lay) <- fID + row.names(lay) <- cID + } else { + lay@Dimnames[[1]] <- fID + lay@Dimnames[[2]] <- cID + } + layExprObj <- createExprObj(lay, name = l_n) + gobject <- set_expression_values( + gobject = gobject, + spat_unit = spat_unit, + feat_type = feat_type, + name = l_n, + values = layExprObj + ) + } + } + + gobject <- update_giotto_params( + gobject = gobject, + description = "_AnnData_Conversion" + ) + return(gobject) } @@ -3646,29 +3678,32 @@ giottoToSpatialData <- function( save_directory = temp ) - # Extract GiottoImage - gimg <- getGiottoImage(gobject, image_type = "largeImage") - - # Temporarily save the image to disk - writeGiottoLargeImage( - giottoLargeImage = gimg, - gobject = gobject, - largeImage_name = "largeImage", - filename = "temp_image.png", - dataType = NULL, - max_intensity = NULL, - overwrite = TRUE, - verbose = TRUE - ) + # Extract GiottoImage only if an image exists + image_exists <- NULL + if (length(slot(gobject, "images")) > 0) { + image_exists <- TRUE + gimg_list <- slot(gobject, "images") + for (i in seq_along(gimg_list)) { + img_name <- slot(gimg_list[[i]], "name") + writeGiottoLargeImage( + giottoLargeImage = gimg_list[[i]], + gobject = gobject, + largeImage_name = img_name, + filename = paste0(temp, img_name, ".png"), + dataType = NULL, + max_intensity = NULL, + overwrite = TRUE, + verbose = TRUE + ) + } + } spat_locs <- getSpatialLocations(gobject, output="data.table") # Create SpatialData object - createSpatialData(temp, spat_locs, spot_radius, save_directory) + createSpatialData(temp, spat_locs, spot_radius, save_directory, image_exists) # Delete temporary files and folders - unlink("temp_image.png") - unlink("temp_image.png.aux.xml") unlink(temp, recursive = TRUE) # Successful Conversion diff --git a/inst/python/g2sd.py b/inst/python/g2sd.py index 5f7d2233..7e15559c 100644 --- a/inst/python/g2sd.py +++ b/inst/python/g2sd.py @@ -5,18 +5,23 @@ from xarray import DataArray import geopandas as gpd from shapely.geometry import Point -import os +import glob, os from spatialdata import SpatialData from spatialdata.models import Image2DModel, ShapesModel, TableModel from spatialdata.transformations.transformations import Identity -def createImageModel(): +def createImageModel(temp): images = {} - hires_image_path = "temp_image.png" - hires_img = imread(hires_image_path).squeeze().transpose(2,0,1) - hires_img = DataArray(hires_img, dims=("c","y","x")) - images["hires_image"] = Image2DModel.parse(hires_img, transformations={"downscaled_hires": Identity()}) + image_paths = glob.glob(temp+"*.png") + for path in image_paths: + image = imread(path).squeeze() + if len(image.shape) == 2: + image = np.expand_dims(image, axis=-1) + image = image.transpose(2,0,1) + image = DataArray(image, dims=("c","y","x")) + image_name = os.path.splitext(os.path.basename(path))[0] + images[image_name] = Image2DModel.parse(image) return images def createShapeModel(spat_locs, spot_radius): @@ -33,15 +38,19 @@ def createShapeModel(spat_locs, spot_radius): return shapes def createTableModel(temp): - alist = os.listdir(temp)[0] - adata = ad.read_h5ad(os.path.join(temp, alist)) + alist = glob.glob(temp+"*.h5ad") + adata = ad.read_h5ad(alist[0]) table = TableModel.parse(adata) return table -def createSpatialData(temp, spat_locs, spot_radius, save_directory): - images = createImageModel() +def createSpatialData(temp, spat_locs, spot_radius, save_directory, image_exists): + if image_exists: + images = createImageModel(temp) table = createTableModel(temp) shapes = createShapeModel(spat_locs, spot_radius) - sd = SpatialData(table = table, images = images) + if image_exists: + sd = SpatialData(table = table, images = images) + else: + sd = SpatialData(table = table) sd.shapes["Shapes"] = shapes - sd.write(save_directory) + sd.write(save_directory, overwrite = True) diff --git a/inst/python/sd2g.py b/inst/python/sd2g.py index c63aee1e..f6b8af88 100644 --- a/inst/python/sd2g.py +++ b/inst/python/sd2g.py @@ -28,17 +28,17 @@ def read_spatialdata_from_path(sd_path = None): # Extract gene expression def extract_expression(sdata = None): expr = sdata.table.X.transpose().todense() - expr_df = pd.DataFrame(expr, index=sdata.table.var['gene_ids'].index, columns=sdata.table.obs['array_row'].index) + expr_df = pd.DataFrame(expr, index=sdata.table.var.index, columns=sdata.table.obs.index) return expr_df # Extract cell IDs def extract_cell_IDs(sdata = None): - cell_IDs = sdata.table.obs['array_row'].index.tolist() + cell_IDs = sdata.table.obs.index.tolist() return cell_IDs # Extract feature IDs def extract_feat_IDs(sdata = None): - feat_IDs = sdata.table.var['gene_ids'].index.tolist() + feat_IDs = sdata.table.var.index.tolist() return feat_IDs # Metadata @@ -124,17 +124,23 @@ def parse_obsm_for_spat_locs(sdata = None): spat_locs["sdimy"] = -1 * spat_locs["sdimy"] return spat_locs -# Extract hires image +# Extract images def extract_image(sdata = None): - # Find SpatialData image name for hires image - for key in sdata.images.keys(): - if "hires" in key: - hires_image_name = key + # Retrieve the list of images + image_list = list(sdata.images.keys()) # Extract image from SpatialData and convert it to numpy array - hires_image = sdata.images[hires_image_name] - hires_image_array = np.transpose(hires_image.compute().data, (1, 2, 0)) # Transpose to (y, x, c) - return hires_image_array + extracted_images = [] + for image_key in image_list: + image = sdata.images[image_key] + image_array = np.transpose(image.compute().data, (1, 2, 0)) # Transpose to (y, x, c) + extracted_images.append(image_array) + return extracted_images + +# Extract image names +def extract_image_names(sdata = None): + image_names = list(sdata.images.keys()) + return image_names # Extract PCA def extract_pca(sdata = None): @@ -240,7 +246,6 @@ def extract_NN_connectivities(sdata = None, key_added = None): for nk in nn_key_list: if "connectivities" in nk: connectivities = sdata.table.obsp[nk] - return connectivities def extract_NN_distances(sdata = None, key_added = None): @@ -318,7 +323,7 @@ def find_SN_keys(sdata = None, key_added = None): with open(key_added) as f: for line in f.readlines(): line = line.strip() - line_key_added = line + suffix + line_key_added = line + "_" + suffix line_keys.append(line_key_added) for key in line_keys: map_keys = sdata.table.uns[key].keys()