import numpy as np
import pandas as pd
from anndata import AnnData
from typing import Optional, Union
import types
import sys
try:
from sklearn.neighbors import DistanceMetric
except ImportError:
from sklearn.metrics import DistanceMetric
from sklearn.neighbors import KDTree
from sklearn.metrics import pairwise_distances
from . import logger
from .symbols import UNASSIGN
from umap.umap_ import fuzzy_simplicial_set
import pynndescent
from annoy import AnnoyIndex
from scipy.spatial import cKDTree
from scipy.sparse import coo_matrix
try:
import faiss
except ImportError:
pass
def _locate_meta_neighbors(pca, batch_list, celltype_list, n_meta_neighbors):
    """
    Find the cell type neighbors for a given cell type (union across batches).

    Parameters
    ----------
    pca
        Cell-by-feature matrix; rows align with `batch_list` and `celltype_list`.
    batch_list, celltype_list
        Per-cell batch and cell type labels.
    n_meta_neighbors
        Number of closest cell types (by centroid distance, the cell type itself included)
        taken from each batch that contains a given cell type.

    Returns
    -------
    dict
        Maps every cell type to the array of cell types whose cells it may draw neighbors from.
        With `n_meta_neighbors == 1` or a single cell type, every cell type maps only to itself.
        If `UNASSIGN` is among the cell types, it maps to all cell types, and every other cell
        type additionally gets `UNASSIGN` in its group.
    """
    unique_types = np.unique(celltype_list)
    if n_meta_neighbors != 1 and len(unique_types) != 1:
        # One centroid-distance table per batch, restricted to the cell types present in it.
        distance_tables = {}
        for b in np.unique(batch_list):
            in_batch = batch_list == b
            present = np.unique(celltype_list[in_batch])
            centroids = np.array([pca[in_batch & (celltype_list == ct)].mean(axis=0) for ct in present])
            distance_tables[b] = pd.DataFrame(pairwise_distances(centroids), index = present, columns = present)
        groups = {}
        for ct in unique_types:
            host_batches = np.unique(batch_list[celltype_list == ct])
            # closest n_meta_neighbors types in each host batch, unioned across batches
            nearest = [distance_tables[b].loc[ct].sort_values().index[:n_meta_neighbors] for b in host_batches]
            groups[ct] = np.unique(np.concatenate(nearest))
    else:
        groups = {ct: np.array([ct], dtype = object) for ct in unique_types}
    if UNASSIGN in unique_types:
        # Unassigned cells may link to any cell type; every cell type may link to unassigned cells.
        for ct in groups:
            groups[ct] = unique_types if ct == UNASSIGN else np.unique(np.append(groups[ct], UNASSIGN))
    return groups
def _create_tree(pca, computation, metric, annoy_n_trees, pynndescent_n_neighbors, pynndescent_random_state):
    """
    Copied from BBKNN. Create a faiss/cKDTree/KDTree/annoy/pynndescent index for nearest neighbor lookup.

    Parameters
    ----------
    pca
        Cell-by-feature matrix to index; row `i` becomes item `i` of the index.
    computation
        One of `'annoy'`, `'pynndescent'`, `'faiss'`, `'cKDTree'` or `'KDTree'`.
    metric
        Distance metric passed to annoy, pyNNDescent and sklearn `KDTree`
        (not used by the faiss `IndexFlatL2` and scipy `cKDTree` branches).
    annoy_n_trees
        Number of trees built by annoy.
    pynndescent_n_neighbors, pynndescent_random_state
        `n_neighbors` and `random_state` passed to `pynndescent.NNDescent`.

    Returns
    -------
    The built index object, to be queried with `_query_tree`.

    Raises
    ------
    ValueError
        If `computation` is not one of the supported backends.
    """
    if computation == 'annoy':
        ckd = AnnoyIndex(pca.shape[1], metric = metric)
        for i in np.arange(pca.shape[0]):
            ckd.add_item(i, pca[i, :])
        ckd.build(annoy_n_trees)
    elif computation == 'pynndescent':
        ckd = pynndescent.NNDescent(pca, metric = metric, n_jobs = -1, n_neighbors = pynndescent_n_neighbors, random_state = pynndescent_random_state)
        ckd.prepare()
    elif computation == 'faiss':
        ckd = faiss.IndexFlatL2(pca.shape[1])
        # faiss requires C-contiguous float32 input; this is a no-op when pca already is.
        ckd.add(np.ascontiguousarray(pca, dtype = np.float32))
    elif computation == 'cKDTree':
        ckd = cKDTree(pca)
    elif computation == 'KDTree':
        ckd = KDTree(pca, metric = metric)
    else:
        # Previously this fell through and failed with an opaque UnboundLocalError on `ckd`.
        raise ValueError(
                f"🛑 Unrecognized `computation`: '{computation}'")
    return ckd
def _query_tree(pca, computation, ckd, n_neighbors):
    """
    Copied from BBKNN. Query the faiss/cKDTree/KDTree/annoy/pynndescent index.

    Parameters
    ----------
    pca
        Query points, one per row.
    computation
        Backend that built `ckd` (see `_create_tree`).
    ckd
        Index returned by `_create_tree` for the same `computation`.
    n_neighbors
        Number of neighbors to retrieve per query point.

    Returns
    -------
    tuple
        `(distances, indices)` for every query row. Indices refer to rows of the indexed data.

    Raises
    ------
    ValueError
        If `computation` is not one of the supported backends.
    """
    if computation == 'annoy':
        ckdo_ind = []
        ckdo_dist = []
        for i in np.arange(pca.shape[0]):
            holder = ckd.get_nns_by_vector(pca[i, :], n_neighbors, include_distances = True)
            ckdo_ind.append(holder[0])
            ckdo_dist.append(holder[1])
        # NOTE(review): assumes every query returns exactly n_neighbors hits (index has >= n_neighbors items)
        ckdout = (np.asarray(ckdo_dist), np.asarray(ckdo_ind))
    elif computation == 'pynndescent':
        ckdout = ckd.query(pca, k = n_neighbors)
        # pynndescent returns (indices, distances); swap to (distances, indices)
        ckdout = (ckdout[1], ckdout[0])
    elif computation == 'faiss':
        D, I = ckd.search(np.ascontiguousarray(pca, dtype = np.float32), n_neighbors)
        # IndexFlatL2 reports squared L2 distances; clamp rounding negatives before the square root.
        D[D < 0] = 0
        ckdout = (np.sqrt(D), I)
    elif computation == 'cKDTree':
        ckdout = ckd.query(x = pca, k = n_neighbors, workers = -1)
    elif computation == 'KDTree':
        ckdout = ckd.query(pca, k = n_neighbors)
    else:
        # Previously this fell through and failed with an opaque UnboundLocalError on `ckdout`.
        raise ValueError(
                f"🛑 Unrecognized `computation`: '{computation}'")
    return ckdout
def _get_graph(pca, batch_list, celltype_list, computation, n_neighbors, n_meta_neighbors, metric, annoy_n_trees, pynndescent_n_neighbors, pynndescent_random_state, random_state):
    """
    Identify the cell-type-controlled KNN structure to be used in graph construction.

    For every cell type, its cells query neighbors in each batch that contains the cell type or
    one of its meta neighbors (see `_locate_meta_neighbors`). The `n_neighbors` slots are split
    across those batches as evenly as possible. The batches that receive one extra slot from the
    remainder are chosen by a permutation seeded with `random_state`. Within a batch, the search
    is restricted to the meta-neighbor group. If the group has fewer cells than the batch's quota,
    the whole batch is searched instead.

    Parameters
    ----------
    pca
        Cell-by-feature search space; rows align with `batch_list` and `celltype_list`.
    batch_list, celltype_list
        Per-cell batch and cell type labels.
    computation
        Backend name understood by `_create_tree` / `_query_tree`.
    n_neighbors
        Total number of neighbors per cell (number of output columns).
    n_meta_neighbors
        Passed to `_locate_meta_neighbors`.
    metric, annoy_n_trees, pynndescent_n_neighbors, pynndescent_random_state
        Passed to `_create_tree`.
    random_state
        Seed for distributing the remainder neighbors across batches. A local generator is used,
        so the global NumPy random state is left untouched.

    Returns
    -------
    tuple
        `(knn_dists, knn_indices)`, both of shape (n_cells, n_neighbors). Indices are rows of `pca`.
        Columns are grouped by batch, not sorted by distance.
    """
    celltype_groups = _locate_meta_neighbors(pca, batch_list, celltype_list, n_meta_neighbors)
    if computation == 'faiss':
        # the faiss index works on float32 data
        pca = pca.astype('float32')
    all_cells = np.arange(len(batch_list))
    knn_dists = np.zeros((pca.shape[0], n_neighbors))
    knn_indices = np.zeros((pca.shape[0], n_neighbors), dtype = int)
    #main
    for celltype in np.unique(celltype_list):
        flag_celltype = celltype_list == celltype
        ind_celltype = all_cells[flag_celltype]
        flag_group = np.isin(celltype_list, celltype_groups[celltype])
        batches = np.unique(batch_list[flag_group])
        q, mod = divmod(n_neighbors, len(batches))
        # Reseeded per cell type. This yields the same draws as `np.random.seed(random_state)` followed by
        # `np.random.permutation`, without mutating the global NumPy random state.
        rng = np.random.RandomState(random_state)
        n_neighbors_across = rng.permutation([q]*(len(batches)-mod) + [q+1]*mod)
        col_ends = np.cumsum(n_neighbors_across)
        col_starts = col_ends - n_neighbors_across
        for batch, n_neighbors_each, col_start, col_end in zip(batches, n_neighbors_across, col_starts, col_ends):
            # A zero quota (n_neighbors < number of batches) owns no output columns. Skip it rather than
            # building a tree and querying k=0, which e.g. the annoy path cannot turn into valid indices.
            if n_neighbors_each == 0:
                continue
            flag_batch = batch_list == batch
            flag = flag_group & flag_batch
            # fall back to the whole batch when the meta-neighbor group is too small for this quota
            if n_neighbors_each > flag.sum():
                flag = flag_batch
            ind = all_cells[flag]
            ckd = _create_tree(pca[flag], computation, metric, annoy_n_trees, pynndescent_n_neighbors, pynndescent_random_state)
            ckdout = _query_tree(pca[flag_celltype], computation, ckd, n_neighbors_each)
            col_range = np.arange(col_start, col_end)
            # map batch-local neighbor positions back to global cell indices
            knn_indices[ind_celltype[:, np.newaxis], col_range] = ind[ckdout[1]]
            knn_dists[ind_celltype[:, np.newaxis], col_range] = ckdout[0]
    return knn_dists, knn_indices
def _resolve_computation(approx, use_annoy, use_faiss, metric):
    """
    Pick the nearest-neighbor backend and validate `metric` for it (logic from BBKNN).

    Returns
    -------
    tuple
        `(computation, metric, swapped)`. `swapped` is True when `metric` is not supported by the
        chosen backend and has been replaced with `'euclidean'`.
    """
    swapped = False
    if approx:
        if use_annoy:
            computation = 'annoy'
            if metric not in ['angular', 'euclidean', 'manhattan', 'hamming']:
                swapped = True
                metric = 'euclidean'
        else:
            computation = 'pynndescent'
            if not (metric in pynndescent.distances.named_distances or isinstance(metric, types.FunctionType)):
                swapped = True
                metric = 'euclidean'
    else:
        # scikit-learn >= 1.3 exposes `valid_metrics` as a class method instead of a list attribute.
        valid_metrics = KDTree.valid_metrics
        if callable(valid_metrics):
            valid_metrics = valid_metrics()
        if not ((metric == 'euclidean') or isinstance(metric, DistanceMetric) or metric in valid_metrics):
            swapped = True
            metric = 'euclidean'
        if metric == 'euclidean':
            computation = 'faiss' if ('faiss' in sys.modules and use_faiss) else 'cKDTree'
        else:
            computation = 'KDTree'
    return computation, metric, swapped

def _knn_to_distances(knn_indices, knn_dists):
    """
    Build the sparse (n_obs, n_obs) distance matrix from KNN arrays of shape (n_obs, n_neighbors).

    Neighbors with index -1 are skipped and self-distances are stored as 0. All zero entries are
    then removed, so self-loops and zero-distance pairs do not appear as stored values.
    """
    n_obs, n_neighbors = knn_indices.shape
    rows = np.repeat(np.arange(n_obs, dtype = np.int64), n_neighbors)
    cols = knn_indices.ravel().astype(np.int64)
    vals = np.where(cols == rows, 0.0, knn_dists.ravel().astype(np.float64))
    valid = cols != -1
    distances = coo_matrix((vals[valid], (rows[valid], cols[valid])), shape=(n_obs, n_obs))
    distances.eliminate_zeros()
    return distances

def _trim_connectivities(connectivities, trim):
    """
    Keep roughly the top `trim` connectivities per cell (logic from BBKNN).

    A cutoff per row is taken as the `trim`-th largest value of that row (0 for rows with at most
    `trim` entries). Values below the cutoff are zeroed row-wise and the matrix is transposed. This
    is applied twice, so rows and then columns are trimmed with the same per-cell cutoffs and the
    result ends up in the original orientation. The input's data array is modified in place.
    """
    n_obs = connectivities.shape[0]
    cutoffs = np.zeros(n_obs)
    for i in range(n_obs):
        row_array = connectivities.data[connectivities.indptr[i]: connectivities.indptr[i+1]]
        if row_array.shape[0] <= trim:
            continue
        cutoffs[i] = row_array[np.argsort(row_array)[-1*trim]]
    for _ in range(2):
        for i in range(n_obs):
            # slicing .data yields a view, so zeroing it edits the matrix in place
            row_array = connectivities.data[connectivities.indptr[i]: connectivities.indptr[i+1]]
            row_array[row_array < cutoffs[i]] = 0
        connectivities.eliminate_zeros()
        connectivities = connectivities.T.tocsr()
    return connectivities

def integrate(
    #input adata
    adata: AnnData, batch: str, cell_type: Optional[str] = None, use_rep: Optional[str] = None, n_latent: int = 50,
    #neighbors global setting
    n_neighbors: Optional[int] = None, n_meta_neighbors: int = 3, approx: bool = True, metric: Union[str, types.FunctionType, DistanceMetric] = 'euclidean',
    #if approx = True, annoy or pyNNDescent
    use_annoy: bool = True, annoy_n_trees: int = 10, pynndescent_n_neighbors: int = 30, pynndescent_random_state: int = 0,
    #if approx = False
    use_faiss: bool = True,
    #connectivities
    set_op_mix_ratio: float = 1.0, local_connectivity: int = 1, trim: Optional[int] = None,
    #random and copy
    neighbor_random_state: int = 0, copy: bool = False) -> Union[AnnData, None]:
    """
    Cell type controlled k nearest neighbors. This is a variant of BBKNN that searches neighbors across matched cell groups in different batches.

    For a given cell belonging to cell type 'c', first determine the batches that contain 'c' or its neighboring cell types.
    Then, in each of these batches, search for nearest neighbors among the cells of these cell types.

    Parameters
    ----------
    adata
        An :class:`~anndata.AnnData` object containing batch and cell type information in `.obs`, as well as latent space (e.g., `'X_pca'`) in `.obsm`.
    batch
        Column name (key) of cell metadata specifying batch information.
    cell_type
        Column name (key) of cell metadata specifying cell type information.
        Default to no cell type information provided (i.e., searching nearest neighbors in the entire batch space).
    use_rep
        Representation used to calculate distances. This can be any representation stored in `.obsm`.
        Default to the PCA coordinates (`'X_pca'`); a `KeyError` is raised if the chosen key is absent from `.obsm`.
    n_latent
        Number of latent representations used.
        Default to min(50, number of available latent representations).
    n_neighbors
        Total number of nearest neighbors for each cell. This number will be contributed equally from batches that qualify.
        Default to max(15, n) where n is the number of batches times three, meaning that each qualified batch will provide at least 3 neighbors.
        For example, if one cell type exists exclusively in one batch, then this batch needs to provide 15 neighbors.
    n_meta_neighbors
        Total number of nearest meta neighbors for each cell type in each batch (calculated from cell centroids).
        The final nearest meta neighbors are the union across batches that contain this given cell type.
        The smaller this value, the stronger the bonding of the same cell type.
        Setting to 1 will make each cell search nearest neighbors only in the cell type it belongs to (i.e., forcibly clustering the same cell types).
        (Default: `3`)
    approx
        Whether to use fast approximate neighbor finding (annoy or pyNNDescent).
        (Default: `True`)
    metric
        Distance metric to use. Metrics unsupported by the selected backend are replaced with `'euclidean'` (with a warning).
        (Default: `'euclidean'`)
    use_annoy
        Whether to use annoy for neighbor finding when `approx = True`. Setting `use_annoy = False` will use pyNNDescent instead.
        (Default: `True`)
    annoy_n_trees
        Number of trees to construct in the annoy forest when `approx = True` and `use_annoy = True`.
        (Default: `10`)
    pynndescent_n_neighbors
        Number of neighbors to include in the approximate neighbor graph when `approx = True` and `use_annoy = False`.
        (Default: `30`)
    pynndescent_random_state
        Random seed to use in pyNNDescent when `approx = True` and `use_annoy = False`.
        (Default: `0`)
    use_faiss
        Whether to use the faiss package to compute nearest neighbors if installed when `approx = False` and `metric = 'euclidean'`.
        (Default: `True`)
    set_op_mix_ratio
        Float between 0 and 1 controlling the blend between a connectivity matrix formed exclusively from mutual nearest neighbor pairs (0)
        and a union of all observed neighbor relationships with the mutual pairs emphasized (1).
        (Default: `1.0`)
    local_connectivity
        UMAP connectivity computation parameter controlling how many nearest neighbors of each cell are assumed to be fully connected (with a connectivity value of 1).
        (Default: `1`)
    trim
        Trim each cell to top `trim` connectivities. May help with population independence and improve the tidiness of clustering.
        Default to n_neighbors*10. Set to 0 to skip trimming.
    neighbor_random_state
        Random seed to use in assigning the remainder neighbors to batches.
        For example, assigning 10 nearest neighbors to 3 batches will make one remainder neighbor randomly assigned to one of the three batches.
        (Default: `0`)
    copy
        Whether to copy the adata or modify in-place.
        (Default: `False`)

    Returns
    ----------
    Union[AnnData, None]
        If `copy = True`, the copied :class:`~anndata.AnnData` with the neighborhood graph (`.obsp['distances']`, `.obsp['connectivities']`, `.uns['neighbors']`) added;
        otherwise `None`, with `adata` updated in place.
        If any batch has 10 or fewer cells, a warning is logged and `None` is returned without adding a graph.

    Raises
    ----------
    KeyError
        If `batch` or `cell_type` is not a column of `.obs`, or `use_rep` is not a key of `.obsm`.
    """
    #check adata
    adata = adata.copy() if copy else adata
    if batch not in adata.obs:
        raise KeyError(
                f"🛑 '{batch}' is not found in the provided AnnData")
    batch_list = adata.obs[batch].astype(str).values
    if isinstance(cell_type, str) and cell_type not in adata.obs:
        raise KeyError(
                f"🛑 '{cell_type}' is not found in the provided AnnData")
    # without cell type information every cell shares one label, i.e. plain batch-balanced KNN
    celltype_list = adata.obs[cell_type].astype(str).values if isinstance(cell_type, str) else np.full(adata.n_obs, 'cell', dtype = object)
    batch_counts = adata.obs[batch].astype(str).value_counts()
    few_batches = set(batch_counts.index[batch_counts <= 10])
    if len(few_batches) > 0:
        logger.warn(f"⚠️ The following batch(es) have too few cells (<= 10), please remove them before running `cellhint.integrate`: {few_batches}")
        return
    if use_rep is None:
        logger.info(f"👀 `use_rep` is not specified, will use `'X_pca'` as the search space")
        use_rep = 'X_pca'
    if use_rep not in adata.obsm.keys():
        raise KeyError(
                f"🛑 '{use_rep}' is not found in `.obsm`")
    n_latent = min([n_latent, adata.obsm[use_rep].shape[1]])
    pca = adata.obsm[use_rep][:, :n_latent]
    #check knn search params
    n_obs = adata.n_obs
    n_neighbors = max([15, 3 * len(np.unique(batch_list))]) if n_neighbors is None else n_neighbors
    computation, metric, swapped = _resolve_computation(approx, use_annoy, use_faiss, metric)
    if swapped:
        logger.warn(f"👀 Unrecognized `metric` for type of neighbor calculation, will switch to 'euclidean'")
    #knn construction
    knn_dists, knn_indices = _get_graph(pca, batch_list, celltype_list, computation, n_neighbors, n_meta_neighbors, metric, annoy_n_trees, pynndescent_n_neighbors, pynndescent_random_state, neighbor_random_state)
    # _get_graph groups columns by batch; sort each cell's neighbors by ascending distance
    newidx = np.argsort(knn_dists, axis = 1)
    row_idx = np.arange(n_obs)[:, np.newaxis]
    knn_indices = knn_indices[row_idx, newidx]
    knn_dists = knn_dists[row_idx, newidx]
    #connectivities + distances
    # Empty placeholder data matrix: neighbors are supplied precomputed, so presumably only its row
    # count matters to fuzzy_simplicial_set (TODO confirm against the pinned UMAP version).
    X = coo_matrix(([], ([], [])), shape=(n_obs, 1))
    connectivities = fuzzy_simplicial_set(X, n_neighbors, None, None, knn_indices = knn_indices, knn_dists = knn_dists, set_op_mix_ratio = set_op_mix_ratio, local_connectivity = local_connectivity)
    # depending on the UMAP version a tuple may be returned; the graph is its first element
    if isinstance(connectivities, tuple):
        connectivities = connectivities[0]
    connectivities = connectivities.tocsr()
    distances = _knn_to_distances(knn_indices, knn_dists)
    #trim
    if trim is None:
        trim = 10 * n_neighbors
    if trim > 0:
        connectivities = _trim_connectivities(connectivities, trim)
    #assign
    adata.uns['neighbors'] = {}
    adata.uns['neighbors']['params'] = {'n_neighbors': n_neighbors, 'method': 'umap', 'metric': metric, 'n_pcs': n_latent, 'ccknn': {'n_meta_neighbors': n_meta_neighbors, 'trim': trim, 'computation': computation, 'batch': batch}}
    adata.uns['neighbors']['params']['use_rep'] = use_rep
    adata.obsp['distances'] = distances.tocsr()
    adata.obsp['connectivities'] = connectivities
    adata.uns['neighbors']['distances_key'] = 'distances'
    adata.uns['neighbors']['connectivities_key'] = 'connectivities'
    return adata if copy else None