Skip to content

Commit

Permalink
Update get_cbc
Browse files Browse the repository at this point in the history
* Add return type.
* Rename argument `graph` to `graph_key`.
* Update docstrings.
* Rename and refactor `get_pearson_corr`.
  • Loading branch information
WeilerP committed Feb 29, 2024
1 parent c049dac commit 2c5ee8f
Showing 1 changed file with 30 additions and 19 deletions.
49 changes: 30 additions & 19 deletions src/cellrank/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,33 +550,44 @@ def _get_vector_field_estimate(self, rep: str) -> np.ndarray:
extrapolated_gex = self.transition_matrix @ self.adata.obsm[rep]
return extrapolated_gex - self.adata.obsm[rep]

def get_cbc(self, source: str, target: str, cluster_key: str, rep: str, graph: str = "distances"):
"""Compute cross-boundary correctness score between source and target cluster."""
def get_cbc(self, source: str, target: str, cluster_key: str, rep: str, graph_key: str = "distances") -> np.ndarray:
"""Compute cross-boundary correctness score between source and target cluster.
def get_pearson_corr(x, y):
Parameters
----------
source
Name of source cluster.
target
Name of target cluster.
cluster_key
Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
rep
Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
graph_key
Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.
Returns
-------
Cross-boundary correctness score for each observation.
"""

def _pearsonr(x, y):
x_centered = x - np.mean(x, axis=1).reshape(-1, 1)
if y.ndim == 1:
y_centered = y - np.mean(y)
pearson_corr = (
np.dot(x_centered, y_centered) / np.linalg.norm(x_centered, axis=1) / np.linalg.norm(y_centered)
)
else:
y_centered = y - np.mean(y, axis=1).reshape(-1, 1)
pearson_corr = (
np.sum(x_centered * y_centered, axis=1)
/ np.linalg.norm(x_centered, axis=1)
/ np.linalg.norm(y_centered, axis=1)
)
return pearson_corr
y_centered = y - np.mean(y, axis=1).reshape(-1, 1)
return (
np.sum(x_centered * y_centered, axis=1)
/ np.linalg.norm(x_centered, axis=1)
/ np.linalg.norm(y_centered, axis=1)
)

target_obs_mask = self.adata.obs[cluster_key].isin([target] if isinstance(target, str) else target)
boundary_ids = self._get_boundary(source=source, target=target, cluster_key=cluster_key, graph_key=graph)
boundary_ids = self._get_boundary(source=source, target=target, cluster_key=cluster_key, graph_key=graph_key)
empirical_velo = self._get_empirical_velocity_field(
boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph_key=graph
boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph_key=graph_key
)
estimated_velo = self._get_vector_field_estimate(rep=rep)[boundary_ids, :]

cbc = get_pearson_corr(x=estimated_velo, y=empirical_velo)
cbc = _pearsonr(x=estimated_velo, y=empirical_velo)
if hasattr(self, "cbc"):
self.cbc[(source, target)] = cbc
return cbc
Expand Down

0 comments on commit 2c5ee8f

Please sign in to comment.