Skip to content

Commit 372b89b

Browse files
authored
Improve signature calculation efficiency (#27)
* more efficient/less flexible signature calculation * don't sort when grouping * improve typing in signatures * don't normalize signatures
1 parent 6552400 commit 372b89b

File tree

1 file changed

+19
-21
lines changed

1 file changed

+19
-21
lines changed

sainsc/utils/_signatures.py

+19-21
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
1-
from collections.abc import Callable
1+
from collections.abc import Hashable
22

33
import anndata as ad
4+
import numpy as np
45
import pandas as pd
6+
from numpy.typing import DTypeLike
57

68

79
def celltype_signatures(
810
adata: ad.AnnData,
911
*,
1012
celltype_col: str = "leiden",
1113
layer: str | None = None,
12-
agg_method: str | Callable = "mean",
14+
dtype: DTypeLike = np.float32,
1315
) -> pd.DataFrame:
1416
"""
15-
Calculate gene expression signatures per 'celltype'.
16-
17-
Note, that this will make a dense copy of `adata.X` or the selected `layer`,
18-
therefore potentially leading to large memory usage.
17+
Calculate gene expression signatures per 'cell type'.
1918
2019
Parameters
2120
----------
@@ -24,25 +23,24 @@ def celltype_signatures(
2423
Name of column in :py:attr:`anndata.AnnData.obs` containing cell-type
2524
information.
2625
layer : str, optional
27-
Which layer to use for aggregation. If `None`, `adata.X` is used.
28-
agg_method : str or collections.abc.Callable, optional
29-
Function to aggregate gene expression per cluster used by
30-
:py:meth:`pandas.DataFrame.agg`.
26+
Which :py:attr:`anndata.AnnData.layers` to use for aggregation. If `None`,
27+
:py:attr:`anndata.AnnData.X` is used.
28+
dytpe : numpy.typing.DTypeLike
29+
Data type to use for the signatures.
3130
3231
Returns
3332
-------
3433
pandas.DataFrame
35-
:py:class:`pandas.DataFrame` of gene expression aggregated per 'celltype'.
34+
:py:class:`pandas.DataFrame` of gene expression aggregated per 'cell type'.
3635
"""
37-
signatures = (
38-
adata.to_df(layer=layer)
39-
.merge(adata.obs[celltype_col], left_index=True, right_index=True)
40-
.groupby(celltype_col, observed=True, sort=False)
41-
.agg(agg_method)
42-
.transpose()
43-
.rename_axis(adata.var_names.name)
44-
)
36+
X = adata.X if layer is None else adata.layers[layer]
37+
grouping = adata.obs.groupby(celltype_col, observed=True, sort=False).indices
4538

46-
signatures /= signatures.sum(axis=0)
39+
signatures: dict[Hashable, np.ndarray] = {}
40+
for name, indices in grouping.items():
41+
mean_X_group = X[indices].mean(axis=0, dtype=dtype)
42+
signatures[name] = (
43+
mean_X_group.A1 if isinstance(mean_X_group, np.matrix) else mean_X_group
44+
)
4745

48-
return signatures
46+
return pd.DataFrame(signatures, index=adata.var_names)

0 commit comments

Comments
 (0)