Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to calculate embeddings for variable by distance aggregation #807

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
5a52976
Add method to calculate embeddings for variable by distance aggregation
LLehner Mar 4, 2024
eb84518
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
488da20
Fix pre-commit
LLehner Mar 4, 2024
8fce577
Fix pre-commit
LLehner Mar 4, 2024
0b72494
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
edcca87
Update param name
LLehner Mar 4, 2024
4be2529
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2024
f91c1af
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner Mar 4, 2024
cfe496c
Remove duplicate code
LLehner Apr 22, 2024
c4fca29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
64e38df
Improve performance, Update output
LLehner Apr 22, 2024
3ab8467
Improve performance, Update output
LLehner Apr 22, 2024
9eabd0d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 22, 2024
a40a8cf
Remove import
LLehner Apr 22, 2024
90108ad
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner Apr 22, 2024
09c72b0
Remove import
LLehner Apr 22, 2024
3396146
Update return
LLehner May 26, 2024
a44f661
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner May 26, 2024
41a2ae4
Merge branch 'main' into var_by_distance_clustering
LLehner May 26, 2024
67bdd5c
Fix pre-commit
LLehner May 26, 2024
99b41b0
Merge branch 'var_by_distance_clustering' of https://github.com/scver…
LLehner May 26, 2024
876c4ed
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 26, 2024
8ee07ba
Fix pre-commit
LLehner May 26, 2024
d3cefff
Fix pre-commit
LLehner May 27, 2024
80e23fc
Merge branch 'main' into var_by_distance_clustering
timtreis Jun 20, 2024
f2b0e12
Merge branch 'main' into var_by_distance_clustering
timtreis Jul 9, 2024
2a863a4
Merge branch 'main' into var_by_distance_clustering
timtreis Aug 7, 2024
5729676
Fix indices; Update return type
LLehner Aug 8, 2024
7dfa933
Add spatialdata as input
LLehner Aug 26, 2024
bf1dcff
Merge branch 'main' into var_by_distance_clustering
LLehner Aug 27, 2024
d6e5ecd
Update docstring
LLehner Aug 27, 2024
6e724f0
Merge branch 'main' into var_by_distance_clustering
timtreis Oct 1, 2024
6e28662
Merge branch 'main' into var_by_distance_clustering
LLehner Oct 8, 2024
1b1c05a
Merge branch 'main' into var_by_distance_clustering
timtreis Nov 13, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/squidpy/tl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@

from squidpy.tl._sliding_window import _calculate_window_corners, sliding_window
from squidpy.tl._var_by_distance import var_by_distance
from squidpy.tl._var_embeddings import var_embeddings
90 changes: 90 additions & 0 deletions src/squidpy/tl/_var_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from typing import Any, Optional

import numpy as np
import pandas as pd
import scanpy as sc
from anndata import AnnData
from scanpy import logging as logg
from spatialdata import SpatialData

from squidpy._docs import d

__all__ = ["var_embeddings"]


@d.dedent
def var_embeddings(
sdata: SpatialData,
table: str,
group: str,
design_matrix_key: str = "design_matrix",
n_bins: int = 100,
include_anchor: bool = False,
) -> AnnData | pd.DataFrame:
"""
Bin variables by previously calculated distance to an anchor point.

Parameters
----------
%(adata)s
table
Name of the table in `SpatialData` object.
group
Annotation column in design matrix, given by `design_matrix_key`, that is used as anchor.
design_matrix_key
Name of the design matrix saved to `.obsm`.
n_bins
Number of bins to use for aggregation.
include_anchor
Whether to include the variable counts belonging to the anchor point in the aggregation.
Returns
-------
Stores binned count matrices in `sdata.tables["var_by_dist_bins"]`.
"""

adata = sdata.tables[table]

if design_matrix_key not in adata.obsm.keys():
raise ValueError(f"`.obsm['{design_matrix_key}']` does not exist. Aborting.")

logg.info("Calculating embeddings for distance aggregations by gene.")

df = adata.obsm[design_matrix_key].copy()
# bin the data by distance
df["bins"] = pd.cut(df[group], bins=n_bins)
# get median value of each interval
df["median_value"] = df["bins"].apply(calculate_median)
# turn categorical NaNs into float 0s
df["median_value"] = pd.to_numeric(df["median_value"], errors="coerce").fillna(0).astype(float)
# get count matrix and add binned distance to each .obs
X_df = adata.to_df()
X_df["distance"] = df["median_value"]
# aggregate the count matrix by the bins
aggregated_df = X_df.groupby(["distance"]).sum()
# transpose the count matrix
result = aggregated_df.T
# optionally include or remove variable values for distance 0 (anchor point)
start_bin = 0
if not include_anchor:
result = result.drop(result.columns[0], axis=1)
start_bin = 1

# rename column names for plotting
result.columns = range(start_bin, 101)
# create genes x genes identity matrix (required for highlighting genes in plot)
obs = pd.DataFrame(np.eye(len(result)), columns=result.index)
obs.replace(1.0, pd.Series(obs.columns, obs.columns), inplace=True)
obs.replace(0.0, "other", inplace=True)
obs = obs.astype("category")
obs.index = result.index
adata_new = AnnData(X=result, obs=obs, var=pd.DataFrame(index=result.columns))

sdata.tables["var_by_dist_bins"] = adata_new


def calculate_median(interval: pd.Interval) -> Any:
median = interval.mid

return median
Loading