"""
This module contains the code to cluster a single cell experiment many times.
"""
from itertools import product
import scanpy as sc
from anndata import AnnData
from typing import Collection, Tuple
import numpy as np
import pandas as pd
import leidenalg
from multiprocessing import Pool
from functools import partial
from tqdm import tqdm
# import bbknn
# TODO: Is random_state being passed to the right thing?
# TODO: Refactor
[docs]def cluster(
adata: AnnData,
n_neighbors: Collection[int],
resolutions: Collection[float],
random_state: Collection[int],
n_procs: int = 1,
neighbor_kwargs: dict = {},
leiden_kwargs: dict = {},
progress_bar: bool = True,
) -> Tuple[pd.DataFrame, pd.DataFrame]:
"""
Generate clusterings for each combination of ``n_neighbors``, ``resolutions``, and ``random_state``.
Parameters
----------
adata
Object to be clustered.
n_neighbors
Values for numbers of neighbors.
resolutions
Values for resolution parameter for modularity optimization.
random_state
Random seeds to start with.
n_procs
Number of processes to use.
neighbor_kwargs
Key word arguments to pass to all calls to :func:`scanpy.pp.neighbors`.
For example: `{"use_rep": "X"}`.
leiden_kwargs
Key word argument to pass to all calls to :func:`leidenalg.find_partition`.
For example, ``{"partition_type": leidenalg.CPMVertexPartition}``.
progress_bar
Whether to diplay a progress bar for the clustering process.
Returns
-------
Pair of dataframes, where the first contains the settings for each partitioning,
and the second contains the partitionings.
Example
-------
>>> params, clusterings = cluster(
adata,
n_neighbors=np.linspace(15, 90, 4, dtype=int),
resolutions=np.geomspace(0.05, 20, 50),
random_state=[0,1,2,3],
n_procs=4
)
"""
# Argument handling
leiden_kwargs = leiden_kwargs.copy()
neighbor_kwargs = neighbor_kwargs.copy()
if "partition_type" not in leiden_kwargs:
leiden_kwargs["partition_type"] = leidenalg.RBConfigurationVertexPartition
if "weights" not in leiden_kwargs:
leiden_kwargs["weights"] = "weight"
def _check_params(kwargs, vals, arg_name):
for val in vals:
if val in kwargs:
raise ValueError(
f"You cannot pass value for key `{val}` in `{arg_name}`"
)
_check_params(
neighbor_kwargs, ["adata", "n_neighbors", "random_state"], "neighbor_kwargs"
)
_check_params(
leiden_kwargs, ["graph", "resolution_parameter", "resolution"], "leiden_kwargs"
)
n_neighbors = sorted(n_neighbors)
resolutions = sorted(resolutions)
random_state = sorted(random_state)
# Logic
neighbor_graphs = []
n_iters = len(n_neighbors) * len(random_state)
for n, seed in tqdm(
product(n_neighbors, random_state),
desc="Building neighbor graphs",
total=n_iters,
disable=not progress_bar,
):
# Neighbor finding is already multithreaded (sorta)
sc.pp.neighbors(adata, n_neighbors=n, random_state=seed, **neighbor_kwargs)
g = sc._utils.get_igraph_from_adjacency(
adata.obsp["connectivities"], directed=True
)
neighbor_graphs.append({"n_neighbors": n, "random_state": seed, "graph": g})
cluster_jobs = []
for graph, res in product(neighbor_graphs, resolutions):
job = graph.copy()
job.update({"resolution": res})
cluster_jobs.append(job)
_cluster_single_kwargd = partial(_cluster_single, leiden_kwargs=leiden_kwargs)
with Pool(n_procs) as p:
# solutions = p.map(_cluster_single, cluster_jobs)
# TODO: Make sure this is returning in the right order, also try chunking
solutions = []
for s in tqdm(
p.imap(_cluster_single_kwargd, cluster_jobs, chunksize=5),
desc="Finding communities",
total=len(cluster_jobs),
disable=not progress_bar,
):
solutions.append(s)
# solutions = p.map(_cluster_single_kwargd, cluster_jobs)
clusters = pd.DataFrame(index=adata.obs_names)
for i, clustering in enumerate(solutions):
clusters[i] = clustering
settings_iter = (
(job["n_neighbors"], job["resolution"], job["random_state"])
for job in cluster_jobs
)
settings = pd.DataFrame.from_records(
settings_iter,
columns=["n_neighbors", "resolution", "random_state"],
index=range(len(solutions)),
)
return settings, clusters
# def cluster_batch_bbknn(
# adata: AnnData,
# batch_key: str,
# neighbors_within_batch: Collection[int],
# trim: Collection[int],
# resolutions: Collection[float],
# random_state: Collection[int],
# n_procs: int = 1,
# bbknn_kwargs: dict = {},
# leiden_kwargs: dict = {},
# ) -> Tuple[pd.DataFrame, pd.DataFrame]:
# """
# Generate clusterings for each combination of ``neighbors_within_batch``, ``trim``, ``resolutions``, and ``random_state``.
# Parameters
# ----------
# adata
# Object to be clustered.
# batch_key
# Key for ``adata.obs`` which encodes batch.
# neighbors_within_batch
# Values for numbers of neighbors within batch.
# trim
# Values for numbers of top connections to keep.
# resolutions
# Values for resolution parameter for modularity optimization.
# random_state
# Random seeds to start with.
# n_procs
# Number of processes to use.
# bbknn_kwargs
# Key word arguments to pass to all calls to ``sc.pp.neighbors``. For
# example: `{"use_rep": "X"}`.
# leiden_kwargs
# Key word argument to pass to all calls to ``leidenalg.find_partition``.
# For example, ``{"partition_type": leidenalg.CPMVertexPartition}``.
# Returns
# -------
# Pair of dataframes, where the first contains the settings for each partitioning,
# and the second contains the partitionings.
# """
# # Argument handling
# leiden_kwargs = leiden_kwargs.copy()
# bbknn_kwargs = bbknn_kwargs.copy()
# if "partition_type" not in leiden_kwargs:
# leiden_kwargs["partition_type"] = leidenalg.RBConfigurationVertexPartition
# if "weights" not in leiden_kwargs:
# leiden_kwargs["weights"] = "weight"
# def _check_params(kwargs, vals, arg_name):
# for val in vals:
# if val in kwargs:
# raise ValueError(
# f"You cannot pass value for key `{val}` in `{arg_name}`"
# )
# _check_params(
# bbknn_kwargs, ["adata", "pca", "batch_list" "neighbors_within_batch", "trim"], "bbknn_kwargs"
# )
# _check_params(
# leiden_kwargs, ["graph", "resolution_parameter", "resolution"], "leiden_kwargs"
# )
# neighbors_within_batch = sorted(neighbors_within_batch)
# # trim = sorted(trim) # TODO: value can be None
# resolutions = sorted(resolutions)
# random_state = sorted(random_state)
# # Logic
# neighbor_graphs = []
# for n, t in product(neighbors_within_batch, trim):
# # Neighbor finding is already multithreaded (sorta)
# i, c = bbknn.bbknn_pca_matrix(
# pca=adata.obsm["X_pca"],
# batch_list=adata.obs[batch_key].values,
# neighbors_within_batch=n,
# trim=t,
# **bbknn_kwargs
# )
# # sc.pp.neighbors(adata, n_neighbors=n, random_state=seed, **neighbor_kwargs)
# g = sc._utils.get_igraph_from_adjacency(c, directed=True)
# neighbor_graphs.append({"neighbors_within_batch": n, "trim": t, "graph": g})
# cluster_jobs = []
# for graph, res, seed in product(neighbor_graphs, resolutions, random_state):
# job = graph.copy()
# job.update({"resolution": res, "random_state": seed})
# cluster_jobs.append(job)
# _cluster_single_kwargd = partial(_cluster_single, leiden_kwargs=leiden_kwargs)
# with Pool(n_procs) as p:
# # solutions = p.map(_cluster_single, cluster_jobs)
# solutions = p.map(_cluster_single_kwargd, cluster_jobs)
# clusters = pd.DataFrame(index=adata.obs_names)
# for i, clustering in enumerate(solutions):
# clusters[i] = clustering
# settings_iter = (
# (job["neighbors_within_batch"], job["trim"], job["resolution"], job["random_state"])
# for job in cluster_jobs
# )
# settings = pd.DataFrame.from_records(
# settings_iter,
# columns=["neighbors_within_batch", "trim", "resolution", "random_state"],
# index=range(len(solutions)),
# )
# return settings, clusters
def _cluster_single(argdict, leiden_kwargs):
part = leidenalg.find_partition(
argdict["graph"],
resolution_parameter=argdict["resolution"],
seed=argdict["random_state"],
**leiden_kwargs,
)
return np.array(part.membership)