import numpy as np
import pandas as pd
from typing import Union, Optional
from . import logger
from .symbols import NOVEL, REMAIN
import sys
try:
import plotly.graph_objects as go
except ImportError:
pass
import matplotlib
from matplotlib import pyplot as plt
import os
# Separator placed between a dataset name and a cell type name to build a unique node label.
SEP1 = '_____'
# Separator placed between a source label and a target label when counting links.
SEP2 = 'CXCXCXCX'
# Default per-row node colour cycle (rows pick `DEFAULT_SANKEY_COLORS[row % len(...)]`).
DEFAULT_SANKEY_COLORS = [
    '#2E91E5', '#E15F99', '#1CA71C', '#FB0D0D', '#DA16FF', '#B68100', '#750D86', '#EB663B',
    '#511CFB', '#00A08B', '#FB00D1', '#FC0080', '#B2828D', '#6C7C32', '#778AAE', '#862A16',
    '#A777F1', '#620042', '#1616A7', '#DA60CA', '#6C4516', '#0D2A63', '#AF0038',
    '#FD3216', '#00FE35', '#6A76FC', '#FED4C4', '#FE00CE', '#0DF9FF', '#F6F926', '#FF9616',
    '#479B55', '#EEA6FB', '#DC587D', '#D626FF', '#6E899C', '#00B5F7', '#B68E00', '#C9FBE5',
    '#FF0092', '#22FFA7', '#E3EE9E', '#86CE00', '#BC7196', '#7E7DCD', '#FC6955', '#E48F72',
]
def _qq_order(sub_relation, ordered_cols):
"""
Disentangle the cross connections in tree plot.
"""
sub_relation2 = sub_relation[ordered_cols].copy()
nms = sub_relation2.iloc[:, 0].value_counts().index.values
for i in range(len(ordered_cols) - 1):
i += 1
val1 = sub_relation2.iloc[:, :i].agg('_'.join, axis = 1)
val2 = sub_relation2.iloc[:, i]
tb = pd.crosstab(val1, val2).loc[nms]
nms = []
flag = 0
for j in range(tb.shape[0]):
ind1 = (tb.iloc[j] > 0).values
if j < (tb.shape[0] - 1):
ind2 = (tb.iloc[j + 1] > 0).values
else:
ind2 = np.repeat([False], tb.shape[1])
ind = ind1 & ind2
all_nms = tb.columns[ind1][np.argsort(tb.loc[:, ind1].iloc[j])[::-1]]
nm1 = tb.index[j]
if ind.sum() > 0:
last_nm = tb.columns[ind][0]
mid_nms = np.setdiff1d(all_nms, last_nm)
if flag == 1:
first_nm = fnm
mid_nms = np.setdiff1d(mid_nms, first_nm)
nms.append(nm1 + '_' + first_nm)
for md in mid_nms:
nms.append(nm1 + '_' + md)
nms.append(nm1 + '_' + last_nm)
flag = 1
fnm = last_nm
else:
if flag == 1:
all_nms = np.setdiff1d(all_nms, fnm)
nms.append(nm1 + '_' + fnm)
for md in all_nms:
nms.append(nm1 + '_' + md)
flag = 0
sub_relation.index = sub_relation2.agg('_'.join, axis = 1)
return sub_relation.loc[nms]
#def _qq_order2(sub_relation, ordered_cols):
# sub_relation2 = sub_relation[ordered_cols].copy()
# nms = sub_relation2.iloc[:, 0].value_counts().index.values
# for i in range(len(ordered_cols)-1):
# i += 1
# val1 = sub_relation2.iloc[:, :i].agg('@'.join, axis=1)
# val2 = sub_relation2.iloc[:, i]
# tb = pd.crosstab(val1, val2)
# tb = tb.loc[nms]
# csum = (tb > 0).sum(0)
# if csum.max() > 1:
# idxs = np.argsort(csum)
# for idx in idxs[csum[idxs] > 1]:
# idxs2 = np.where(tb.iloc[:, idx] > 0)[0]
# if np.any(np.diff(idxs2) > 1):
# ti = idxs2[0]
# tnm = np.array(nms[ti].split('@'))
# for idx2 in idxs2[1:]:
# nm = np.array(nms[idx2].split('@'))
# if idx2 == 0:
# nm2 = nms[idx2+1].split('_')
# elif idx2 + 1 < len(nms):
# nm2 = nms[idx2-1].split('_') + nms[idx2+1].split('_')
# else:
# nm2 = nms[idx2-1].split('_')
# if len(set(nm) & set(nm2)) == 0:
# mv_idxs.append(idx2)
# nms = []
# flag = 0
# for j in range(tb.shape[0]):
# ind1 = (tb.iloc[j] > 0).values
# if j < (tb.shape[0] - 1):
# ind2 = (tb.iloc[j+1] > 0).values
# else:
# ind2 = np.repeat([False], tb.shape[1])
# ind = (ind1 & ind2)
# all_nms = tb.columns[ind1][np.argsort(tb.loc[:, ind1].iloc[j])[::-1]]
# nm1 = tb.index[j]
# if ind.sum() > 0:
# last_nm = tb.columns[ind][0]
# mid_nms = np.setdiff1d(all_nms, last_nm)
# if flag == 1:
# first_nm = fnm
# mid_nms = np.setdiff1d(mid_nms, first_nm)
# nms.append(nm1 + '_' + first_nm)
# for md in mid_nms:
# nms.append(nm1 + '_' + md)
# nms.append(nm1 + '_' + last_nm)
# flag = 1
# fnm = last_nm
# else:
# if flag == 1:
# all_nms = np.setdiff1d(all_nms, fnm)
# nms.append(nm1 + '_' + fnm)
# for md in all_nms:
# nms.append(nm1 + '_' + md)
# flag = 0
# sub_relation.index = sub_relation2.agg('_'.join, axis=1)
# sub_relation = sub_relation.loc[nms]
# return sub_relation
def _relation_to_data(relation, return_sankey: bool = True) -> Union[pd.DataFrame, tuple]:
    """
    For internal use. Turn the harmonization result into Sankey input.

    Parameters
    ----------
    relation
        Harmonization table. Only the columns at even positions (0, 2, 4, ...) are
        used. Each holds the labels of one dataset, with `NOVEL` / `REMAIN` marking
        blank entries.
    return_sankey
        Whether to also build the Sankey trace.
        (Default: `True`)

    Returns
    ----------
    The expanded relation table (row index reset to 0..n-1).
    - Every label is prefixed as `dataset + SEP1 + label`.
    - Every blank entry becomes `dataset + SEP1 + NOVEL/REMAIN + SEP1 + k`.
      `k` is numbered per dataset and per "relay" label: the first non-blank label
      of the row for the first dataset, otherwise the label of the preceding dataset.
    If `return_sankey` is `True`, a tuple of that table and a
    :class:`plotly.graph_objects.Sankey` trace is returned instead.
    """
    datasets = relation.columns[0::2]
    relation = relation[datasets].copy()
    relation.index = np.arange(relation.shape[0])
    blanks = [NOVEL, REMAIN]
    #add prefix for cell types
    for i in relation.index:
        for j in datasets:
            content = relation.loc[i, j]
            if content not in blanks:
                relation.loc[i, j] = j + SEP1 + content
    #rename blank entries so that each becomes a distinct, traceable node
    for _j in range(len(datasets)):
        j = datasets[_j]
        mapping_NOVEL = dict()
        mapping_REMAIN = dict()
        suffix_NOVEL = 0
        suffix_REMAIN = 0
        for i in relation.index:
            content = relation.loc[i, j]
            if content in blanks:
                celltypes = relation.loc[i].values
                relay = celltypes[~np.isin(celltypes, blanks)][0] if _j == 0 else celltypes[_j-1]
                if content == NOVEL:
                    if relay not in mapping_NOVEL:
                        suffix_NOVEL += 1
                        mapping_NOVEL[relay] = f"{NOVEL}{SEP1}{suffix_NOVEL}"
                    relation.loc[i, j] = j + SEP1 + mapping_NOVEL[relay]
                else:
                    if relay not in mapping_REMAIN:
                        suffix_REMAIN += 1
                        mapping_REMAIN[relay] = f"{REMAIN}{SEP1}{suffix_REMAIN}"
                    relation.loc[i, j] = j + SEP1 + mapping_REMAIN[relay]
    if not return_sankey:
        return relation
    #label
    label = np.concatenate([np.unique(relation[j]) for j in datasets])
    refer = pd.Series(label).reset_index().set_index(0)
    #link = source + target + value
    link_frames = []
    for k in range(len(datasets) - 1):
        vc = (relation[datasets[k]] + SEP2 + relation[datasets[k+1]]).value_counts().reset_index()
        # Column names produced by `value_counts().reset_index()` differ across pandas
        # versions, so the joined pair (first column) and count (second) are taken by position.
        pairs = vc.iloc[:, 0].str.split(SEP2, expand = True).values
        link_frames.append(pd.DataFrame({'source': pairs[:, 0], 'target': pairs[:, 1], 'value': vc.iloc[:, 1].values}))
    # Concatenate once; concatenating onto an empty frame is deprecated in pandas >= 2.1.
    if link_frames:
        link = pd.concat(link_frames, ignore_index = True)
    else:
        link = pd.DataFrame(columns = ['source', 'target', 'value'])
    source = refer.loc[link.source.values, 'index'].values
    target = refer.loc[link.target.values, 'index'].values
    value = link.value.astype(int).values
    return relation, go.Sankey(node = dict(label = label), link = dict(source = source, target = target, value = value))
def _identify_relation_groups(relation, group_prefix: str = 'Group', order_row: bool = True, order_column: bool = False) -> tuple:
    """
    For internal use. Identify cell type groups based on the cell type harmonization result.

    The expanded relation table comes from `_relation_to_data`. Rows that share a
    duplicated label, directly or transitively, are assigned to the same group.
    Groups are named `f"{group_prefix}{k}"` with k = 1, 2, ...

    Parameters
    ----------
    relation
        Harmonization table, as accepted by `_relation_to_data`.
    group_prefix
        Prefix of group names.
        (Default: `'Group'`)
    order_row
        Whether to return the rows arranged group by group, with each multi-row
        group reordered by `_qq_order`.
        (Default: `True`)
    order_column
        Only used when `order_row` is `True`. Whether to reorder the datasets within
        each group; the columns are then renamed to `D1`, `D2`, ...
        (Default: `False`)

    Returns
    ----------
    A tuple of an object array of group names (one per row) and the corresponding
    expanded relation table.
    """
    new_relation = _relation_to_data(relation, False)
    datasets = new_relation.columns
    blanks = [NOVEL, REMAIN]
    # Labels occurring in more than one row of the same column; they link rows together.
    dup_celltypes = np.unique(np.concatenate([new_relation[dataset].values[new_relation[dataset].duplicated()] for dataset in datasets]))
    if len(dup_celltypes) > 0:
        groups = np.full(new_relation.shape[0], f"{group_prefix}0", dtype = object)
        rownames = new_relation.index
        i = 0
        # Each pass of the outer loop collects one connected set of rows.
        while len(rownames) > 0:
            i += 1
            receive = []
            provide = [rownames[0]]
            while len(provide) > 0:
                pp = provide.pop()
                receive.append(pp)
                for dataset in datasets:
                    celltype = new_relation.loc[pp, dataset]
                    if celltype in dup_celltypes:
                        extends = new_relation[new_relation[dataset] == celltype].index
                        provide.extend(extends.tolist())
                provide = list(set(provide).difference(receive))
            groups[new_relation.index.isin(receive)] = f"{group_prefix}{i}"
            rownames = rownames[~np.isin(rownames, receive)]
    else:
        groups = np.array([f"{group_prefix}{i+1}" for i in range(new_relation.shape[0])], dtype = object)
    ##remove in the future-->
    assert np.all(groups != f"{group_prefix}0")
    ##<<-remove in the future
    if not order_row:
        return groups, new_relation
    df = []
    gs = []
    for j in range(1, len(np.unique(groups))+1):
        each_group = f"{group_prefix}{j}"
        sub_relation = new_relation[groups == each_group]
        if sub_relation.shape[0] > 1:
            # Columns with fewer distinct labels come first.
            col_uniques = sub_relation.apply(pd.Series.nunique).values
            ordered_cols = datasets[np.argsort(col_uniques)]
            sub_relation = _qq_order(sub_relation, ordered_cols)
            if order_column:
                # Explicit copy: the frame is modified in place below.
                sub_relation = sub_relation[ordered_cols].copy()
                sub_relation.columns = [f"D{x+1}" for x in range(sub_relation.shape[1])]
                flag_unique = ~sub_relation.isin(dup_celltypes)
                for row_index in range(flag_unique.shape[0]):
                    row_flags = flag_unique.iloc[row_index].values
                    # Collect the trailing run of columns holding non-duplicated labels.
                    # The `last_col >= 0` bound stops negative indices from wrapping around.
                    col_indices = []
                    last_col = flag_unique.shape[1] - 1
                    while last_col >= 0 and row_flags[last_col]:
                        col_indices.append(last_col)
                        last_col -= 1
                    col_indices.reverse()
                    if len(col_indices) >= 2:
                        # Move blank (NOVEL/REMAIN) placeholders to the front of that run.
                        is_blank = np.isin([x.split(SEP1)[1] for x in sub_relation.iloc[row_index, col_indices]], blanks)
                        sub_relation.iloc[row_index, col_indices] = sub_relation.iloc[row_index, col_indices].values[np.argsort(~is_blank)]
        else:
            if order_column:
                # Single-row group: place the blank datasets first.
                is_blank = np.isin(relation.loc[groups == each_group, datasets].values[0], blanks)
                sub_relation = sub_relation[datasets[np.argsort(~is_blank)]]
                sub_relation.columns = [f"D{x+1}" for x in range(sub_relation.shape[1])]
        df.append(sub_relation)
        gs.extend([each_group] * sub_relation.shape[0])
    return np.array(gs, dtype = object), pd.concat(df, axis = 0, ignore_index = True)
def _mix_colors(cols) -> str:
    """
    For internal use. Get the blended color.

    Each color in `cols` is converted to an RGB triple with matplotlib. The triples
    are averaged channel by channel, and the mean is returned as a hex string.
    """
    rgb_rows = [matplotlib.colors.to_rgb(color) for color in cols]
    mean_rgb = np.mean(np.array(rgb_rows), axis = 0)
    return matplotlib.colors.to_hex(mean_rgb)
def _new_relation_to_color(new_relation, node_color, novel_node_color, remain_node_color, cmap = 'Reds') -> dict:
    """
    For internal use. Get the cell-type-to-color mapping.

    Parameters
    ----------
    new_relation
        Expanded relation table whose entries are `dataset + SEP1 + label` strings.
        Its integer row index selects each row's colour from `DEFAULT_SANKEY_COLORS`.
    node_color
        Either `None`, or a data frame whose first three columns are dataset,
        cell type, and a color (string) or numeric value.
        - `None`: each row gets a colour from `DEFAULT_SANKEY_COLORS`. A label that
          appears in several rows gets the running blend of those rows' colours.
        - Numeric values: they are clipped to their 10%-85% quantile range,
          rescaled to [0, 1], and mapped through `cmap`.
    novel_node_color, remain_node_color
        Colors stored under the `NOVEL` and `REMAIN` keys.
    cmap
        Colormap used when `node_color` holds numeric values.
        (Default: `'Reds'`)

    Returns
    ----------
    A dict mapping `dataset + SEP1 + cell type` (plus `NOVEL` and `REMAIN`) to colors.

    Raises
    ----------
    TypeError
        If `node_color` is neither `None` nor a data frame.
    """
    map_color = {NOVEL: novel_node_color, REMAIN: remain_node_color}
    if node_color is None:
        n_colors = len(DEFAULT_SANKEY_COLORS)
        for i in new_relation.index:
            row_color = DEFAULT_SANKEY_COLORS[i % n_colors]
            for j in new_relation.columns:
                content = new_relation.loc[i, j]
                # Blank placeholders are coloured through the NOVEL/REMAIN keys instead.
                if content.split(SEP1)[1] in [NOVEL, REMAIN]:
                    continue
                if content not in map_color:
                    map_color[content] = row_color
                else:
                    map_color[content] = _mix_colors([map_color[content], row_color])
    elif isinstance(node_color, pd.DataFrame):
        node_color = node_color.copy()
        map_values = node_color.iloc[:, 2].values
        if not isinstance(map_values[0], str):
            # Work on a float copy. `.values` may be a read-only view (pandas Copy-on-Write),
            # and clipping an integer array in place would truncate the float quantiles.
            map_values = np.asarray(map_values, dtype = float)
            q10 = np.quantile(map_values, 0.10)
            q85 = np.quantile(map_values, 0.85)
            map_values = np.clip(map_values, q10, q85)
            # Function form: the `ndarray.ptp` method was removed in NumPy 2.0.
            value_range = np.ptp(map_values)
            if value_range > 0:
                map_values = (map_values - map_values.min()) / value_range
            else:
                # All values identical after clipping. Avoid 0/0 (NaN, drawn as black)
                # by mapping every node to the low end of the colormap.
                map_values = np.zeros_like(map_values)
            map_values = plt.get_cmap(cmap, 256)(map_values)
            map_values = [matplotlib.colors.to_hex(map_value) for map_value in map_values]
        node_color['_combination'] = node_color.iloc[:, 0].astype(str) + SEP1 + node_color.iloc[:, 1].astype(str)
        map_color.update(dict(zip(node_color['_combination'].values, map_values)))
    else:
        raise TypeError(
            f"🛑 Please provide `node_color` as a data frame")
    return map_color
[docs]
def sankeyplot(alignment,
               #node colors
               node_color: Optional[pd.DataFrame] = None, novel_node_color: str = '#FFFFFF', remain_node_color: str = '#F0F0F0',
               #link color
               link_color: Optional[str] = None,
               #figure elements
               title: str = 'CellHint label harmonization',
               #figure size
               show: bool = True, save: Union[str, bool] = False, width: Optional[int] = None, height: Optional[int] = None,
               #to fig.update_layout
               layout_dict: Optional[dict] = None,
               #to fig.update_traces
               trace_dict: Optional[dict] = None,
               #for developer use
               expand_label: bool = False,
               ) -> None:
    """
    Generate a Sankey diagram showing the CellHint label harmonization in a qualitative manner.
    Parameters
    ----------
    alignment
        A :class:`~cellhint.align.DistanceAlignment` or :class:`~pandas.DataFrame` object representing the harmonization result.
    node_color
        A :class:`~pandas.DataFrame` with three consecutive columns representing dataset, cell type, and color, respectively.
        Default to a color scheme that allows matched cell types to have the same colors.
    novel_node_color
        Color of dataset-specific (i.e., novel) cell types.
        (Default: `'#FFFFFF'`)
    remain_node_color
        Color of remaining unresolved cell types.
        (Default: `'#F0F0F0'`)
    link_color
        Color of links. Default to translucent grey as used in plotly.
    title
        Figure title.
        (Default: `'CellHint label harmonization'`)
    show
        Whether to show the plot.
        (Default: `True`)
    save
        Whether to save the plot. This can also be a figure filename.
        If `True` without a filename, the plot is written to `CellHint_sankeyplot.html`.
        Supported figure suffixes are: .html, .png, .jpg, .jpeg, .webp, .svg, .pdf, .eps.
        (Default: `False`)
    width
        Figure width in pixels.
        Default to 700 in a canonical Plotly setting.
    height
        Figure height in pixels.
        Default to 450 in a canonical Plotly setting.
    layout_dict
        A dict passed to the method `.update_layout` of :class:`plotly.graph_objects.Figure` for setting the figure layout.
        Example keys include `plot_bgcolor` which sets the plot area color, `font_color` which sets the text color, `font_size` which sets the text size, etc.
        (Default: `None`, i.e., no extra settings)
    trace_dict
        A dict passed to the method `.update_traces` of :class:`plotly.graph_objects.Figure` for setting the Sankey plot.
        Example keys include `node_pad` which sets the padding between nodes, `link_line_color` which sets the link border color, `orientation` which sets the plot orientation, etc.
        (Default: `None`, i.e., no extra settings)
    expand_label
        Whether to show the unique expanded labels (dataset-prefixed names, including blank placeholders) instead of the cell type names. Only for developer use.
        (Default: `False`)
    Returns
    ----------
    None
    """
    if 'plotly' not in sys.modules:
        logger.warn(f"⚠️ Warning: to draw a Sankey diagram, package `plotly` is required. Please install `plotly` first")
        return
    # `None` defaults avoid sharing a mutable dict between calls.
    layout_dict = {} if layout_dict is None else layout_dict
    trace_dict = {} if trace_dict is None else trace_dict
    if isinstance(alignment, pd.DataFrame):
        relation = alignment
    elif hasattr(alignment, 'relation'):
        relation = alignment.relation
    else:
        raise TypeError(
            f"🛑 Please provide correct input - either a DistanceAlignment or a data frame")
    trace = _relation_to_data(relation, return_sankey = True)[1]
    #relation2new_relation is run twice actually, but time cost is negligible; this new relation is row ordered
    new_relation = _identify_relation_groups(relation, group_prefix = 'Group', order_row = True, order_column = False)[1]
    expanded_label = trace.node.label
    original_label = pd.Series(expanded_label).str.split(SEP1, expand = True)[1].values
    blank_flag = np.isin(original_label, [NOVEL, REMAIN])
    #node color
    if isinstance(node_color, pd.DataFrame) and not np.array_equal(np.sort(node_color.iloc[:, 0].astype(str) + SEP1 + node_color.iloc[:, 1].astype(str)), np.sort(expanded_label[~blank_flag])):
        raise ValueError(
            f"🛑 Please provide a comprehensive combination of datasets and cell types in `node_color`")
    color_mapping = _new_relation_to_color(new_relation, node_color, novel_node_color, remain_node_color)
    node_color = np.array([color_mapping[x] for x in np.where(blank_flag, original_label, expanded_label)], dtype = object)
    #annotations: dataset names placed along the axis of the Sankey flow
    datasets = new_relation.columns
    # Guard against division by zero when only one dataset is present.
    n_gaps = max(len(datasets) - 1, 1)
    if trace_dict.get('orientation') == 'v':
        annotations = [dict(text = datasets[i], y = 1 - i/n_gaps, x = 0, xanchor = 'right', xref = 'paper', yanchor = 'top' if i == 0 else ('bottom' if i == len(datasets)-1 else 'middle'), yref = 'paper', showarrow = False) for i in range(len(datasets))]
    else:
        annotations = [dict(text = datasets[i], x = i/n_gaps, y = 0, yanchor = 'top', yref = 'paper', xanchor = 'left' if i == 0 else ('right' if i == len(datasets)-1 else 'center'), xref = 'paper', showarrow = False) for i in range(len(datasets))]
    #update trace and layout
    fig = go.Figure(trace)
    fig.update_traces(node_label = np.where(blank_flag, '', original_label) if not expand_label else expanded_label, node_color = node_color, link_color = link_color, **trace_dict)
    fig.update_layout(title_text = title, width = width, height = height, annotations = annotations, **layout_dict)
    if show:
        fig.show()
    if save:
        ext = os.path.splitext(save)[1] if isinstance(save, str) else '.html'
        if ext not in ['.html', '.png', '.jpg', '.jpeg', '.webp', '.svg', '.pdf', '.eps']:
            raise ValueError(
                f"🛑 Please provide valid figure suffix: .html, .png, .jpg, .jpeg, .webp, .svg, .pdf, .eps")
        if ext == '.html':
            fig.write_html(save if isinstance(save, str) else 'CellHint_sankeyplot.html')
        else:
            fig.write_image(save)
[docs]
def treeplot(alignment, group_celltype: bool = True, order_dataset: bool = False,
             #link
             link_color: str = '#0000007B', link_width: Optional[float] = None,
             #root and node
             node_shape: Union[list, str] = 'o', node_color: Optional[pd.DataFrame] = None, cmap: Union[matplotlib.colors.Colormap, str] = 'Reds', node_size: Optional[float] = None,
             #label
             show_label: bool = True, label_color: str = '#000000', label_size: Optional[Union[float, str]] = None, label_ha: str = 'center', label_va: str = 'top',
             #figure elements
             title: str = 'CellHint label harmonization tree',
             #show and/or save figure
             ax: Optional[matplotlib.axes.Axes] = None, figsize: Optional[Union[list, tuple]] = None, show: bool = True, save: Union[str, bool] = False,
             #link setting
             link_dict: Optional[dict] = None,
             #node and root setting
             node_dict: Optional[dict] = None,
             #label setting
             label_dict: Optional[dict] = None,
             #for developer use
             expand_label: bool = False,
             ) -> None:
    """
    Generate a tree showing the CellHint label harmonization in a qualitative manner.
    Parameters
    ----------
    alignment
        A :class:`~cellhint.align.DistanceAlignment` or :class:`~pandas.DataFrame` object representing the harmonization result.
    group_celltype
        Whether to group cell types (rows) in the harmonization table for plotting.
        (N.B. Do not change the default value of this argument unless you know what you are doing.)
        (Default: `True`)
    order_dataset
        Whether to change the dataset order in each cell type group to manifest as hierarchy (tree).
        (Default: `False`)
    link_color
        Color of links/branches.
        (Default: `'#0000007B'`)
    link_width
        Width of links/branches in points.
        Default to 1.5 in a canonical Matplotlib setting.
    node_shape
        Shape of the node. This can also be a list of symbols for datasets that are aligned.
        (Default: `'o'`)
    node_color
        A :class:`~pandas.DataFrame` with three consecutive columns representing dataset, cell type, and color, respectively.
        Default to a color scheme that allows matched cell types to have the same colors.
        This can also be a data frame with columns of dataset, cell type, and numeric value (for mapping color gradient).
    cmap
        Color map to use. This parameter is only relevant if `node_color` is a value-mapping data frame.
        (Default: `'Reds'`)
    node_size
        Size of nodes (cell types) in points.
        Default to 6.0 in a canonical Matplotlib setting.
    show_label
        Whether to label each node with its cell type name.
        (Default: `True`)
    label_color
        Color of cell type labels.
        (Default: `'#000000'`)
    label_size
        Size of cell type labels.
        Default to 10.0 in a canonical Matplotlib setting.
    label_ha
        Horizontal alignment of cell type labels relative to the nodes.
        (Default: `'center'`)
    label_va
        Vertical alignment of cell type labels relative to the nodes.
        (Default: `'top'`)
    title
        Figure title.
        (Default: `'CellHint label harmonization tree'`)
    ax
        An :class:`~matplotlib.axes.Axes` where the tree will be drawn. Default to draw the tree on a new axes.
    figsize
        Tuple of figure width and height in inches.
        Default to auto-adjusting the figure size based on the numbers of datasets and cell types.
    show
        Whether to show the plot.
        (Default: `True`)
    save
        Whether to save the plot. This can also be a figure filename.
        If `True` without a filename, the plot is saved as `CellHint_treeplot.pdf`.
        (Default: `False`)
    link_dict
        A dict passed to :class:`~matplotlib.lines.Line2D` for setting the links/branches.
        (Default: `None`, i.e., no extra settings)
    node_dict
        A dict passed to :class:`~matplotlib.lines.Line2D` for setting the nodes.
        (Default: `None`, i.e., no extra settings)
    label_dict
        A dict passed to :class:`~matplotlib.text.Text` for setting cell type labels.
        (Default: `None`, i.e., no extra settings)
    expand_label
        Whether to label nodes with the unique expanded labels (dataset-prefixed names) instead of the cell type names. Only for developer use.
        (Default: `False`)
    Returns
    ----------
    None
    """
    # `None` defaults avoid sharing a mutable dict between calls.
    link_dict = {} if link_dict is None else link_dict
    node_dict = {} if node_dict is None else node_dict
    label_dict = {} if label_dict is None else label_dict
    if isinstance(alignment, pd.DataFrame):
        relation = alignment
    elif hasattr(alignment, 'relation'):
        relation = alignment.relation
    else:
        raise TypeError(
            f"🛑 Please provide correct input - either a DistanceAlignment or a data frame")
    new_relation = _identify_relation_groups(relation, group_prefix = 'Group', order_row = group_celltype, order_column = order_dataset)[1]
    n_col = new_relation.shape[1]
    n_row = new_relation.shape[0]
    #node coordinates: x is the column rank, y is the mean flipped row position of the label
    node_coord = {}
    for col_rank in range(1, n_col + 1):
        col_content = new_relation.iloc[:, col_rank - 1]
        for celltype in np.unique(col_content):
            rows = n_row - np.where(col_content == celltype)[0]
            node_coord[celltype] = [col_rank, rows.mean()]
    #link pairs between adjacent columns (np.vstack: np.row_stack is deprecated in NumPy 2.0)
    pair_blocks = [new_relation.iloc[:, [i, i+1]].drop_duplicates().values for i in range(n_col - 1)]
    link_pairs = np.vstack(pair_blocks) if pair_blocks else np.empty((0, 2), dtype = object)
    #ax
    if ax is None:
        figsize = [3.5*n_col, n_row/3.5] if figsize is None else figsize
        _, ax = plt.subplots(figsize = figsize)
    #links
    for start, end in link_pairs:
        xs = [node_coord[start][0], node_coord[end][0]]
        ys = [node_coord[start][1], node_coord[end][1]]
        ax.plot(xs, ys, ls = '-', marker = 'None', color = link_color, lw = link_width, **link_dict)
    if order_dataset:
        # Connect every first-column node to a common root drawn at x = 0.5.
        for first_cell_type in np.unique(new_relation.iloc[:, 0]):
            ax.plot([0.5, 1], [(n_row + 1)/2, node_coord[first_cell_type][1]], ls = '-', marker = 'None', color = link_color, lw = link_width, **link_dict)
    #nodes and labels
    expanded_labels = np.array(list(node_coord.keys()))
    original_labels = pd.Series(expanded_labels).str.split(SEP1, expand = True)[1].values
    blank_flag = np.isin(original_labels, [NOVEL, REMAIN])
    if isinstance(node_color, pd.DataFrame) and not np.array_equal(np.sort(node_color.iloc[:, 0].astype(str) + SEP1 + node_color.iloc[:, 1].astype(str)), np.sort(expanded_labels[~blank_flag])):
        raise ValueError(
            f"🛑 Please provide a comprehensive combination of datasets and cell types in `node_color`")
    color_mapping = _new_relation_to_color(new_relation, node_color, None, None, cmap)
    if not isinstance(node_shape, str) and len(node_shape) < n_col:
        raise ValueError(
            f"🛑 Please provide `node_shape` of length {n_col}")
    node_shapes = dict(zip(relation.columns[::2], [node_shape]*n_col)) if isinstance(node_shape, str) else dict(zip(relation.columns[::2], node_shape))
    for node, coord in node_coord.items():
        dst = node.split(SEP1)[0]
        original_label = node.split(SEP1)[1]
        # Blank placeholders keep their links but are not drawn as nodes.
        if original_label in [NOVEL, REMAIN]:
            continue
        ax.plot(coord[0], coord[1], marker = node_shapes[dst], ms = node_size, color = color_mapping[node], ls = 'None', **node_dict)
        if show_label:
            ax.text(coord[0], coord[1], node if expand_label else original_label, color = label_color, size = label_size, ha = label_ha, va = label_va, **label_dict)
    if order_dataset:
        ax.plot(0.5, (n_row + 1)/2, marker = 'o', ms = node_size, color = '#000000', ls = 'None', **node_dict)
    #others
    ax.set(xlim = [0, n_col+1], ylim = [0, n_row+1], title = title)
    ax.set_axis_off()
    if not order_dataset:
        for col_rank in range(1, n_col + 1):
            ax.text(col_rank, n_row+0.5, new_relation.columns[col_rank-1], color = label_color, size = label_size, ha = 'center', va = 'center', weight = 'bold')
    #show and save
    if save:
        plt.savefig(save if isinstance(save, str) else 'CellHint_treeplot.pdf')
    if show:
        plt.show()
    if save:
        plt.close()
def heatmap(alignment, plot_type: str = 'similarity',
            #colors
            dataset_color: Optional[pd.DataFrame] = None, celltype_color: Optional[pd.DataFrame] = None,
            #cell
            vmin: Optional[float] = None, vmax: Optional[float] = None, cmap: str = 'RdBu_r',
            #labels
            dataset_celltype_sep: str = ': ',
            #dendrogram
            cluster: bool = True, show_row_dendrogram: bool = True, show_col_dendrogram: bool = True,
            #figure
            figsize: Union[list, tuple] = (10, 10), ax: Optional[matplotlib.axes.Axes] = None, show: bool = True, save: Union[str, bool] = False,
            #others
            **kwargs) -> None:
    """
    Generate a heatmap showing the cell type relationships within and across datasets (i.e., meta-analysis).
    Parameters
    ----------
    alignment
        A :class:`~cellhint.align.DistanceAlignment` object containing the meta-analysis result.
    plot_type
        The type of heatmap to show, being either cross-dataset cell type transcriptome similarities (`'similarity'`) or membership (`'membership'`).
        (Default: `'similarity'`)
    dataset_color
        A :class:`~pandas.DataFrame` with two consecutive columns representing dataset and color, respectively.
        Default to a CellHint color cycle.
    celltype_color
        A :class:`~pandas.DataFrame` with three consecutive columns representing dataset, cell type, and color, respectively.
        Default to a color scheme that allows matched cell types to have the same colors.
    vmin
        Minimal value to anchor the colormap.
        Default to 0 unless `plot_type = 'similarity'` and normalization is not performed during cell type harmonization.
    vmax
        Maximal value to anchor the colormap.
        Default to 1 unless `plot_type = 'similarity'` and normalization is not performed during cell type harmonization.
    cmap
        Mapping from data values to color space.
        (Default: `'RdBu_r'`)
    dataset_celltype_sep
        Separator used to join dataset names and cell type names for display.
        (Default: `': '`)
    cluster
        Whether to cluster the rows and columns of the heatmap.
        (Default: `True`)
    show_{row,col}_dendrogram
        Whether to show the row/column dendrogram.
        (Default: `True`)
    figsize
        Tuple of figure width and height in inches.
        Default to 10 inches in both dimensions.
    ax
        An :class:`~matplotlib.axes.Axes` where the heatmap will be drawn. Default to draw the heatmap on a new axes.
    show
        Whether to show the plot.
        (Default: `True`)
    save
        Whether to save the plot. This can also be a figure filename.
        (Default: `False`)
    **kwargs
        All other parameters are the same as :func:`seaborn.clustermap` with selected tags and customized defaults.
    Returns
    ----------
    None
    """
    pass