Source code for constclust.plotting

from __future__ import annotations

from . import aggregate
import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
import scanpy as sc
from anndata import AnnData
import pandas as pd
from pandas.api.types import is_float_dtype, is_numeric_dtype, is_categorical_dtype
import numpy as np
from itertools import product
from typing import Optional

# TODO: Make it possible to have plot objects returned
# TODO: Think up visualization for the set of components


# TODO: Add title arg
# TODO: crosstab instead of pivot_table?
[docs]def component_param_range( component: "aggregate.Component", x: str = "n_neighbors", y: str = "resolution", ax: Optional[mpl.axis.Axis] = None, ) -> mpl.axis.Axis: """ Given a component, show which parameters it's found at as a heatmap. Params ------ component The component to plot. x The parameter for the x axis. y The parameter to place on the y axis. ax Optional axis to plot on. Example ------- >>> comps = reconciler.get_comps(0.9) >>> plotting.component_param_range(comps[0]) """ # Calculate colorbar param_states = pd.Series(component._parent.settings[[x, y]].itertuples(index=False)) ncolors = param_states.value_counts().max() + 1 cmap = mpl.colors.LinearSegmentedColormap.from_list( "dummy_name", sns.cm.rocket( np.linspace(20, 240, ncolors, dtype=int) ), # Shifted a little so it's prettier ncolors, ) # Initialize blank array data = pd.pivot_table( component._parent.settings[[x, y]], index=y, columns=x, aggfunc=lambda x: 0 ) params = pd.pivot_table( component.settings[[x, y]], index=y, columns=x, aggfunc=len, fill_value=0 ) data = (data + params).fillna(0).astype(int) # Fix axis labels if is_float_dtype(data.index.dtype): data.index = data.index.map(lambda x: "{:g}".format(x)) if is_float_dtype(data.columns.dtype): data.columns = data.columns.map(lambda x: "{:g}".format(x)) ax = sns.heatmap( data, linewidths=0.2, ax=ax, vmin=0, vmax=ncolors, cmap=cmap, ) # Discretize colorbar cb = ax.collections[0].colorbar new_pos = np.stack([cb._boundaries[:-1], cb._boundaries[1:]], 1).mean(axis=1) assert len(new_pos) == ncolors cb.set_ticks(new_pos) cb.set_ticklabels(list(range(ncolors))) return ax
def umap_cells(cells, adata, ax=None, umap_kwargs={}): if hasattr(cells, "keys") and hasattr(cells, "values"): cells = pd.Series(cells) else: cells = pd.Series(1, index=cells) cell_values = pd.Series(0, index=adata.obs_names, dtype=float) cell_values[cells.index] += cells adata.obs["_tmp"] = cell_values sc.pl.umap(adata, color="_tmp", ax=ax, title="UMAP", **umap_kwargs) adata.obs.drop(columns="_tmp", inplace=True) def component_embedding(component, adata, ax=None, embedding_kwargs={}, basis="X_umap"): # TODO: Views should have parents, which is where I should get my obs names embedding_kwargs = embedding_kwargs.copy() if "title" not in embedding_kwargs: embedding_kwargs["title"] = basis cell_names = component._parent._obs_names # if not all(cell_names.isin(adata.obs_names)): # raise ValueError("Counldn't find all cells in component's parent in adata.") cell_value = pd.Series(0, index=adata.obs_names, dtype=float) # present_freqs = component.cell_frequency # cell_value[present_freqs.index] = present_freqs # TODO: Apparently, this is super slow. Sum on a ndarray instead for cluster in component.cluster_ids: cell_value[component._parent._mapping.iloc[cluster]] += 1 cell_value = cell_value / cell_value.max() adata.obs["_tmp"] = cell_value if len(cell_names) < len(adata.obs_names): # Take view sc.pl.embedding( adata[cell_names, :], basis=basis, color="_tmp", ax=ax, **embedding_kwargs ) else: sc.pl.embedding( adata[cell_names, :], basis=basis, color="_tmp", ax=ax, **embedding_kwargs ) adata.obs.drop(columns="_tmp", inplace=True) def global_stability( settings, clusters, x="n_neighbors", y="resolution", cmap=sns.cm.rocket, ax=None ): # This should probably aggregate, currently do hacky thing of just subsetting if len(set(settings[[x, y]].itertuples(index=False))) != len(settings): # raise NotImplementedError("Aggregation of multiple global solutions not yet implemented.") settings = settings[settings["random_state"] == 0].copy() xlabs = sorted(settings[x].unique()) ylabs = sorted(settings[y].unique()) clusters = clusters[settings.index].copy() mapping = dict( zip(settings.index, settings[[x, y]].itertuples(index=False, name=None)) ) # + .5 might make it align better with local plot xpos = np.arange(len(xlabs)) # + .5 ypos = np.arange(len(ylabs)) # + .5 pos_map = {} for (xlab, x), (ylab, y) in product(zip(xlabs, xpos), zip(ylabs, ypos)): pos_map[(xlab, ylab)] = (x, y) edges = aggregate.build_global_graph(settings, clusters) lines = [] colors = [] for edge in edges: color = edge[2] line = [pos_map[mapping[edge[0]]], pos_map[mapping[edge[1]]]] colors.append(color) lines.append(line) lc = mpl.collections.LineCollection(lines, cmap=cmap) lc.set_array(np.array(colors)) if ax is not None: fig = ax.get_figure() else: fig, ax = plt.subplots() ax.add_collection(lc) xstep = _tick_step(ax, xlabs, 0) ystep = _tick_step(ax, ylabs, 1) ax.set(xticks=xpos[::xstep], yticks=ypos[::ystep]) fmat = np.vectorize("{:g}".format) ax.set_xticklabels(labels=fmat(xlabs[::xstep])) ax.set_yticklabels(labels=fmat(ylabs[::ystep])) ax.invert_yaxis() ax.autoscale() ax.set_frame_on(False) cb = fig.colorbar(lc, ax=ax) cb.outline.set_visible(False)
[docs]def component( component: "Component", adata: AnnData, x: str = "n_neighbors", y: str = "resolution", embedding_basis: str = "X_umap", plot_global: bool = False, aspect: float = None, embedding_kwargs: dict = {}, ): """ Plot stability and embedding for component. Params ------ component Component object to plot. adata AnnData to use for plotting UMAP. Should have same cell names as `Component`s parent `Reconciler`. x Parameter to plot on the X-axis of the heatmap. y Parameter to plot on the Y-axis of the heatmap. embedding_basis Which basis from the AnnData object to use for embedding. aspect Aspect ratio of entire plot. Defaults to 1/2. embedding_kwargs Arguments passed to `sc.pl.embedding`. """ if aspect is None: if plot_global: aspect = 1 / 3 else: aspect = 1 / 2 fig = plt.figure(figsize=mpl.figure.figaspect(aspect)) if plot_global: reconciler = component._parent gs = fig.add_gridspec(1, 3) global_ax = fig.add_subplot(gs[0, 1]) global_stability( reconciler.settings, reconciler.clusterings, x, y, ax=global_ax ) else: gs = fig.add_gridspec(1, 2) heatmap_ax = fig.add_subplot(gs[0, 0]) embedding_ax = fig.add_subplot(gs[0, -1]) component_param_range(component, x, y, ax=heatmap_ax) component_embedding( component, adata, ax=embedding_ax, basis=embedding_basis, embedding_kwargs=embedding_kwargs, ) return fig
def edge_weight_distribution(recon, **kwargs): return sns.histplot(recon.graph.es["weight"], **kwargs) # Modified from seaborn def _tick_step(ax, labels, axis): transform = ax.figure.dpi_scale_trans.inverted() bbox = ax.get_window_extent().transformed(transform) size = [bbox.width, bbox.height][axis] axis = [ax.xaxis, ax.yaxis][axis] (tick,) = axis.set_ticks([0]) fontsize = tick.label.get_size() max_ticks = int(size // (fontsize / 72)) tick_every = len(labels) // max_ticks + 1 tick_every = 1 if tick_every == 0 else tick_every return tick_every