Skip to content

Commit

Permalink
Merge pull request #63 from dpeerlab/fix_hvg
Browse files Browse the repository at this point in the history
Fix args to highly_variable_genes function in ENVI.py
  • Loading branch information
DoronHav authored May 20, 2024
2 parents 0f145a1 + 3923485 commit 908b12b
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "scenvi"
version = "0.3.3"
version = "0.3.4"
description = "Integration of scRNA-seq and spatial transcriptomics data"
authors = ["Doron Haviv"]
license = "MIT"
Expand Down
20 changes: 12 additions & 8 deletions scenvi/ENVI.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,22 @@ def __init__(
stable_eps=1e-6,
):

self.spatial_data = spatial_data
self.spatial_data = spatial_data[:, np.intersect1d(spatial_data.var_names, sc_data.var_names)]
self.sc_data = sc_data

if "highly_variable" not in sc_data.var.columns:

sc_data.layers["log"] = np.log(sc_data.X + 1)
sc.pp.highly_variable_genes(
sc_data, layer="log", n_top_genes=min(num_HVG, sc_data.shape[-1])
)
if "highly_variable" not in self.sc_data.var.columns:
if 'log' in self.sc_data.layers.keys():
sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG, layer="log")
elif('log1p' in self.sc_data.layers.keys()):
sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG, layer="log1p")
elif(self.sc_data.X.min() < 0):
sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG)
else:
sc_data.layers["log"] = np.log(self.sc_data.X + 1)
sc.pp.highly_variable_genes(self.sc_data, n_top_genes=num_HVG, layer="log")

sc_genes_keep = np.union1d(
sc_data.var_names[sc_data.var.highly_variable], self.spatial_data.var_names
self.sc_data.var_names[self.sc_data.var.highly_variable], self.spatial_data.var_names
)
if len(sc_genes) > 0:
sc_genes_keep = np.union1d(sc_genes_keep, sc_genes)
Expand Down
14 changes: 10 additions & 4 deletions scenvi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from flax import struct
from flax.training import train_state
from jax import random

import scipy.sparse

class FeedForward(nn.Module):
"""
Expand Down Expand Up @@ -254,7 +254,6 @@ def niche_cell_type(
)
return cell_type_niche


def compute_covet(
spatial_data, k=8, g=64, genes=[], spatial_key="spatial", batch_key=-1
):
Expand All @@ -277,8 +276,15 @@ def compute_covet(
CovGenes = spatial_data.var_names
else:
if "highly_variable" not in spatial_data.var.columns:
spatial_data.layers["log"] = np.log(spatial_data.X + 1)
sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log")
if 'log' in spatial_data.layers.keys():
sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log")
elif('log1p' in spatial_data.layers.keys()):
sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log1p")
elif(spatial_data.X.min() < 0):
sc.pp.highly_variable_genes(spatial_data, n_top_genes=g)
else:
spatial_data.layers["log"] = np.log(spatial_data.X + 1)
sc.pp.highly_variable_genes(spatial_data, n_top_genes=g, layer="log")

CovGenes = np.asarray(spatial_data.var_names[spatial_data.var.highly_variable])
if len(genes) > 0:
Expand Down

0 comments on commit 908b12b

Please sign in to comment.