1
- from collections .abc import Callable
1
+ from collections .abc import Hashable
2
2
3
3
import anndata as ad
4
+ import numpy as np
4
5
import pandas as pd
6
+ from numpy .typing import DTypeLike
5
7
6
8
7
9
def celltype_signatures (
8
10
adata : ad .AnnData ,
9
11
* ,
10
12
celltype_col : str = "leiden" ,
11
13
layer : str | None = None ,
12
- agg_method : str | Callable = "mean" ,
14
+ dtype : DTypeLike = np . float32 ,
13
15
) -> pd .DataFrame :
14
16
"""
15
- Calculate gene expression signatures per 'celltype'.
16
-
17
- Note, that this will make a dense copy of `adata.X` or the selected `layer`,
18
- therefore potentially leading to large memory usage.
17
+ Calculate gene expression signatures per 'cell type'.
19
18
20
19
Parameters
21
20
----------
@@ -24,25 +23,24 @@ def celltype_signatures(
24
23
Name of column in :py:attr:`anndata.AnnData.obs` containing cell-type
25
24
information.
26
25
layer : str, optional
27
- Which layer to use for aggregation. If `None`, `adata.X` is used.
28
- agg_method : str or collections.abc.Callable, optional
29
- Function to aggregate gene expression per cluster used by
30
- :py:meth:`pandas.DataFrame.agg` .
26
+ Which :py:attr:`anndata.AnnData.layers` to use for aggregation. If `None`,
27
+ :py:attr:`anndata.AnnData.X` is used.
28
+ dtype : numpy.typing.DTypeLike
29
+ Data type to use for the signatures.
31
30
32
31
Returns
33
32
-------
34
33
pandas.DataFrame
35
- :py:class:`pandas.DataFrame` of gene expression aggregated per 'celltype '.
34
+ :py:class:`pandas.DataFrame` of gene expression aggregated per 'cell type'.
36
35
"""
37
- signatures = (
38
- adata .to_df (layer = layer )
39
- .merge (adata .obs [celltype_col ], left_index = True , right_index = True )
40
- .groupby (celltype_col , observed = True , sort = False )
41
- .agg (agg_method )
42
- .transpose ()
43
- .rename_axis (adata .var_names .name )
44
- )
36
+ X = adata .X if layer is None else adata .layers [layer ]
37
+ grouping = adata .obs .groupby (celltype_col , observed = True , sort = False ).indices
45
38
46
- signatures /= signatures .sum (axis = 0 )
39
+ signatures : dict [Hashable , np .ndarray ] = {}
40
+ for name , indices in grouping .items ():
41
+ mean_X_group = X [indices ].mean (axis = 0 , dtype = dtype )
42
+ signatures [name ] = (
43
+ mean_X_group .A1 if isinstance (mean_X_group , np .matrix ) else mean_X_group
44
+ )
47
45
48
- return signatures
46
+ return pd . DataFrame ( signatures , index = adata . var_names )
0 commit comments