Skip to content

Commit

Permalink
fix(analysis): update mae_per_gene function and add docstring
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Sep 17, 2024
1 parent 8174ff9 commit 0ac2cdf
Showing 1 changed file with 72 additions and 5 deletions.
77 changes: 72 additions & 5 deletions src/pyrovelocity/analysis/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,11 +327,78 @@ def vector_field_uncertainty(
return v_map_all, embeds_radian, fdri


def mae_per_gene(pred_counts: ndarray, true_counts: ndarray) -> ndarray:
"""Computes mean average error between counts and predicted probabilities."""
error = np.abs(true_counts - pred_counts).sum(-2)
total = np.clip(true_counts.sum(-2), 1, np.inf)
return -np.array(error / total)
@beartype
def mae_per_gene(
pred_counts: NDArray[np.number],
true_counts: NDArray[np.number],
) -> NDArray[np.number]:
"""
Computes mean absolute error (MAE) between predictive samples and true counts.
The function returns the negative of the normalized MAE for
consistency with the convention that higher values should indicate
better performance in visualizations.
TODO: convert to jax
```python
import jax.numpy as jnp
from jaxtyping import Array, Num, jaxtyped
@jaxtyped(typechecker=beartype)
def mae_per_gene(
pred_counts: Num[Array, "obs vars"],
true_counts: Num[Array, "obs vars"],
) -> Num[Array, "vars"]:
pass
```
Args:
pred_counts (NDArray[np.number]): Predicted counts for all observations.
true_counts (NDArray[np.number]): Observed counts for all observations.
Returns:
NDArray[np.number]: Negative mean absolute error for each gene.
Example:
>>> import numpy as np
>>> from pyrovelocity.analysis.analyze import mae_per_gene
>>> true_counts = np.array(
... [
... [1, 2, 3],
... [1, 2, 3],
... [1, 2, 3],
... [1, 2, 3],
... ]
... )
>>> pred_counts = np.array(
... [
... [1.1, 2.2, 3.3],
... [1.1, 2.2, 3.3],
... [1.1, 2.2, 3.3],
... [1.1, 2.2, 3.3],
... ]
... )
>>> mae_per_gene(pred_counts, true_counts)
array([-0.1, -0.1, -0.1])
>>> true_counts = np.array(
... [
... [10, 15],
... [20, 25],
... ]
... )
>>> pred_counts = np.array(
... [
... [12, 14],
... [18, 26],
... ]
... )
>>> mae_per_gene(pred_counts, true_counts)
array([-0.133..., -0.05... ])
"""
total_true_counts = np.maximum(true_counts.sum(axis=0), 1)
mae = np.abs(true_counts - pred_counts).sum(axis=0) / total_true_counts
return -mae


@beartype
Expand Down

0 comments on commit 0ac2cdf

Please sign in to comment.