"""
Plotting utilities for RamanLib.
This module provides convenience functions to visualize spectra stored in a
:class:`ramanlib.core.GroupedSpectralContainer` (GSC) and derived statistics
computed by :mod:`ramanlib.calc`. Functions are designed so that common analysis
outputs feed directly into plotting helpers—for example:
- :func:`ramanlib.calc.outliers_per_group` → :func:`ramanlib.plot.outliers_per_group`
- :func:`ramanlib.calc.mean_difference` → :func:`ramanlib.plot.mean_difference`
- :func:`ramanlib.calc.mean_correlation_per_group` → :func:`ramanlib.plot.mean_correlation_per_group`
Most functions delegate the actual drawing to :mod:`ramanspy`’s plotting backend
and Matplotlib/Seaborn, adding light logic for grouping, sampling, and overlays.
Notes
-----
Return types mirror whatever the underlying plotting call returns (usually a
Matplotlib :class:`matplotlib.axes.Axes`, a list of Axes, or a
:class:`matplotlib.figure.Figure`), so these functions compose naturally with
your own Matplotlib pipelines (titles, labels, styling).
"""
from __future__ import annotations
import numpy as np
import ramanspy as rp
from .utils import _normalize_axes_obj
import random
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
[docs]
def mean_per_group(gsc, by=None, interval=None, plot_type="separate", ci_z=1.96, **kwargs):
"""
Plot mean spectrum per group with optional uncertainty bands.
The group means and (optionally) per-wavenumber standard deviation and
confidence intervals are computed via :meth:`gsc.mean` with
``include_stats=True``. Bands can represent standard deviation (``±SD``) or
normal-approximation confidence intervals (``± z * SD / sqrt(n)``).
Parameters
----------
gsc : GroupedSpectralContainer
Input container.
by : str or list of str or None, optional
Column name(s) used to form groups. If ``None``, all rows form a single
group named ``"all"``. Passed to :meth:`pandas.DataFrame.groupby`.
interval : {None, "ci", "sd"}, optional
Type of band to draw around the mean. ``"sd"`` plots ``± SD``.
``"ci"`` plots ``± z * SD / sqrt(n)`` with ``z=ci_z``.
``None`` (default) disables bands.
plot_type : {"separate", "single", "single stacked"}, optional
Plot style passed to :func:`ramanspy.plot.spectra`. Default is
``"separate"`` (one subplot per group). For ``"single stacked"``,
interval bands are disabled due to vertical offsets.
ci_z : float, optional
Z-score for CI bands (e.g., ``1.96`` ≈ 95% CI). Ignored if
``interval != "ci"``. Default is ``1.96``.
**kwargs
Forwarded to :func:`ramanspy.plot.spectra` and Matplotlib.
Returns
-------
matplotlib.axes.Axes or list[matplotlib.axes.Axes] or matplotlib.figure.Figure
Whatever :func:`ramanspy.plot.spectra` returns.
See Also
--------
ramanlib.core.GroupedSpectralContainer.mean
Computes group means and optional per-wavenumber statistics.
ramanlib.plot.random_per_group
Sample and plot random spectra per group.
Notes
-----
Bands require the presence of ``'std_vector'`` and, for ``interval="ci"``,
also ``'n'`` in the dataframe returned by ``gsc.mean(include_stats=True)``.
"""
# 1) Compute group means + stats once
means_gsc = gsc.mean(by=by, include_stats=True, ddof=1)
df = means_gsc.df
# 2) Prepare means and labels
group_means = df["spectrum"].tolist()
# Use user labels if provided; otherwise build them. Avoid double-passing via kwargs.
group_labels = kwargs.pop("label", None)
if group_labels is None:
if by is None:
group_labels = ["all"]
else:
grouped = df.groupby(by, dropna=False)
group_labels = [
", ".join(map(str, key)) if isinstance(key, tuple) else str(key)
for key, _ in grouped
]
spectral_axis = group_means[0].spectral_axis if group_means else None
# 3) Precompute bands if requested
bands = [None] * len(df)
if interval in ("ci", "sd"):
if "std_vector" not in df or (interval == "ci" and "n" not in df):
warnings.warn("Required statistics not present; skipping interval bands.")
else:
for i, row in df.iterrows():
std = row["std_vector"]
if std is None or (isinstance(std, np.ndarray) and std.size == 0):
bands[i] = None
continue
if interval == "sd":
band = std
else: # 'ci'
n = int(row["n"])
band = (ci_z * std / np.sqrt(n)) if n > 0 else None
bands[i] = band
# 4) Plot means
axes_obj = rp.plot.spectra(group_means, label=group_labels, plot_type=plot_type, **kwargs)
axes = _normalize_axes_obj(axes_obj)
# 5) Handle stacked-with-offset limitation
if (plot_type or "").lower() == "single stacked" and interval is not None:
warnings.warn("Interval bands disabled for 'single stacked' due to vertical offsets.")
return axes_obj
# 6) Overlay bands on the correct axes
if spectral_axis is not None and any(b is not None for b in bands):
for ax, mean_spec, band in zip(axes, group_means, bands):
if band is None:
continue
ax.fill_between(
spectral_axis,
mean_spec.spectral_data - band,
mean_spec.spectral_data + band,
alpha=0.2
)
return axes_obj
[docs]
def random_per_group(gsc, by=None, n_samples=3, plot_type="single", seed=None, **kwargs): # Note: fix 'label' attribute if user gives this as input
"""
Plot a random sample of spectra from each group.
Parameters
----------
gsc : GroupedSpectralContainer
Input container.
by : str or list of str or None, optional
Column name(s) to group by. If ``None``, samples from all rows as one group.
n_samples : int, optional
Number of spectra to sample per group. If a group has fewer than
``n_samples`` rows, sampling with replacement is used to reach ``n_samples``.
Default is ``3``.
plot_type : {"single", "separate", "single stacked"}, optional
Plot style passed to :func:`ramanspy.plot.spectra`. Default ``"single"``.
seed : int or None, optional
Random seed for reproducibility. Default ``None``.
**kwargs
Forwarded to :func:`ramanspy.plot.spectra` and Matplotlib.
Returns
-------
matplotlib.axes.Axes or list[matplotlib.axes.Axes] or matplotlib.figure.Figure
Whatever :func:`ramanspy.plot.spectra` returns.
See Also
--------
ramanlib.plot.mean_per_group
Plot group means with optional uncertainty bands.
"""
rng = random.Random(seed) # local RNG
def _sample_k(spectra, k):
if len(spectra) == 0:
return []
if k <= len(spectra):
return rng.sample(spectra, k)
return spectra[:] + rng.choices(spectra, k=k - len(spectra))
if by is None:
spectra = gsc.df["spectrum"].tolist()
spectra_groups = [_sample_k(spectra, n_samples)]
group_labels = ["all"]
else:
grouped = gsc.df.groupby(by)
spectra_groups, group_labels = [], []
for key, group_df in grouped:
sample = _sample_k(group_df["spectrum"].tolist(), n_samples)
spectra_groups.append(sample)
group_labels.append(", ".join(map(str, key)) if isinstance(key, tuple) else str(key))
return rp.plot.spectra(spectra_groups, label=group_labels, plot_type=plot_type, **kwargs)
[docs]
def outliers_per_group(gsc, results, **kwargs):
"""
Plot detected outlier spectra for each group and overlay the group mean.
This is the plotting counterpart of :func:`ramanlib.calc.outliers_per_group`.
It expects the exact mapping produced by that function and draws, for each
group, the selected "outlier" spectra in a separate subplot, with the group's
mean spectrum overlaid.
Parameters
----------
gsc : GroupedSpectralContainer
The container from which spectra are retrieved by global row index.
results : dict[str, tuple[list[int], rp.Spectrum]]
Output mapping from :func:`ramanlib.calc.outliers_per_group`, i.e.,
``{ group_label: ([row_indices_into_gsc_df], mean_spectrum) }``.
**kwargs
Forwarded to :func:`ramanspy.plot.spectra` (e.g., ``color``, ``linewidth``,
``title``, ``ax``). Any provided ``plot_type`` is ignored (layout is fixed
to ``"separate"`` for robustness and clarity).
Returns
-------
matplotlib.axes.Axes or list[matplotlib.axes.Axes] or matplotlib.figure.Figure or None
Whatever :func:`ramanspy.plot.spectra` returns. Returns ``None`` if
``results`` is empty.
See Also
--------
ramanlib.calc.outliers_per_group
Compute per-group outlier indices and each group's mean spectrum.
Notes
-----
Overlays the supplied per-group mean spectrum in red. No legend/tight layout
adjustments are performed here, so you can customize them upstream or
downstream as needed.
"""
if not results:
return None
# Don’t allow plot_type injection here to avoid edge cases.
if "plot_type" in kwargs:
warnings.warn("plot_type is fixed to 'separate' for this plot; ignoring provided plot_type.")
kwargs = {k: v for k, v in kwargs.items() if k != "plot_type"}
group_labels = list(results.keys())
spectra_groups = []
means_for_overlay = []
for label in group_labels:
idxs, mean_spec = results[label]
spectra_groups.append(gsc.df.loc[idxs, "spectrum"].tolist())
means_for_overlay.append(mean_spec)
axes_obj = rp.plot.spectra(
spectra_groups,
label=group_labels,
plot_type="separate",
**kwargs
)
# Normalize to a list of Axes to overlay the mean
axes_list = _normalize_axes_obj(axes_obj)
# Overlay the mean line (no labels/legend/tight_layout/show here)
for ax, mean_spec in zip(axes_list, means_for_overlay):
ax.plot(mean_spec.spectral_axis, mean_spec.spectral_data, color="red", linewidth=1.5)
return axes_obj
[docs]
def baseline(spectrum, baseline_process, **kwargs):
"""
Plot a spectrum, its estimated baseline, and the baseline-corrected spectrum.
Parameters
----------
spectrum : rp.Spectrum
Input spectrum to correct.
baseline_process : ramanspy.preprocessing.PreprocessingStep or ramanspy.preprocessing.Pipeline
A single preprocessing step or a multi-step preprocessing pipeline from RamanSPy.
The object must expose ``.apply(Spectrum) -> Spectrum``, returning the corrected
spectrum on the same spectral axis (Raman wavenumber cm⁻¹).
**kwargs
Forwarded to :func:`ramanspy.plot.spectra` (e.g., ``ax``, ``alpha``,
line styling, etc.).
Returns
-------
matplotlib.axes.Axes or list[matplotlib.axes.Axes] or matplotlib.figure.Figure
Whatever :func:`ramanspy.plot.spectra` returns.
Notes
-----
The baseline is computed as ``baseline = original - corrected`` and plotted alongside
the original and corrected spectra with fixed labels:
``["Original spectrum", "removed baseline", "corrected spectrum"]``.
See Also
--------
ramanspy.preprocessing.PreprocessingStep
ramanspy.preprocessing.Pipeline
"""
corrected_spectrum = baseline_process.apply(spectrum)
baseline = rp.Spectrum(spectrum.spectral_data - corrected_spectrum.spectral_data, spectrum.spectral_axis)
spectra = [spectrum, baseline, corrected_spectrum]
labels = ["Original spectrum", "removed baseline", "corrected spectrum"]
return rp.plot.spectra(spectra, label=labels, plot_type="single", alpha=0.9, **kwargs)
[docs]
def n_baselines(raw_gsc, baseline_process, process_name, n_samples=3, figsize=(8,7), seed=None):
"""
Plot several randomly selected spectra with their baselines in a vertical figure.
Parameters
----------
raw_gsc : GroupedSpectralContainer
Container from which spectra are sampled.
baseline_process : ramanspy.preprocessing.PreprocessingStep or ramanspy.preprocessing.Pipeline
Baseline-correction operator applied to each sampled spectrum. Must implement
``.apply(Spectrum) -> Spectrum``.
process_name : str
Title displayed above the figure.
n_samples : int, optional
Number of spectra (rows) to sample. Default ``3``.
figsize : tuple[float, float], optional
Matplotlib figure size passed to :func:`matplotlib.pyplot.subplots`.
Default ``(8, 7)``.
seed : int or None, optional
Random seed used when sampling. Default ``None``.
Returns
-------
list[matplotlib.axes.Axes]
The list of axes for each subplot row.
See Also
--------
ramanspy.preprocessing.PreprocessingStep
ramanspy.preprocessing.Pipeline
"""
spec_samples = raw_gsc.df.sample(n=n_samples)["spectrum"]
fig, axs = plt.subplots(n_samples, 1, figsize=figsize)
for i, spec in enumerate(spec_samples):
baseline(spec, baseline_process, ax=axs[i], title="", xlabel="")
fig.suptitle(f"{process_name}")
plt.xlabel("Wavenumber (cm⁻¹)")
plt.tight_layout()
return axs
[docs]
def compare_baselines(spectrum, baseline_processes, process_names, figsize=(8,7)):
"""
Compare multiple baseline algorithms (steps or pipelines) on the same spectrum.
Parameters
----------
spectrum : rp.Spectrum
Input spectrum to be corrected by each baseline process.
baseline_processes : list[ramanspy.preprocessing.PreprocessingStep or ramanspy.preprocessing.Pipeline]
Sequence of RamanSPy preprocessing operators. Each must implement
``.apply(Spectrum) -> Spectrum``.
process_names : list[str]
Display names for each process. Must have the same length and order as
``baseline_processes``.
figsize : tuple[float, float], optional
Matplotlib figure size. Default ``(8, 7)``.
Returns
-------
list[matplotlib.axes.Axes]
The list of axes for each subplot row.
See Also
--------
ramanspy.preprocessing.PreprocessingStep
ramanspy.preprocessing.Pipeline
"""
fig, axs = plt.subplots(len(baseline_processes), 1, figsize=figsize)
for i, process in enumerate(baseline_processes):
baseline(spectrum, process, ax=axs[i], title=f"{process_names[i]}", xlabel="")
plt.xlabel("Wavenumber(cm⁻¹)")
plt.tight_layout()
return axs
[docs]
def mean_difference(diff_spectrum, ci_band, label="Difference in Means", **kwargs):
"""
Plot a difference-of-means spectrum with a two-sided CI band centered at zero.
This is the plotting counterpart of :func:`ramanlib.calc.mean_difference`.
Pass in the tuple returned by that function’s computation step.
Parameters
----------
diff_spectrum : rp.Spectrum
The difference spectrum (e.g., group A mean minus group B mean).
Typically the first element returned by
:func:`ramanlib.calc.mean_difference`.
ci_band : numpy.ndarray
One-dimensional nonnegative array giving the half-width of the symmetric
confidence band at each wavenumber (i.e., plot ``± ci_band``). Typically
the second element returned by :func:`ramanlib.calc.mean_difference`.
label : str, optional
Legend label for the difference trace. Default ``"Difference in Means"``.
**kwargs
Forwarded to :func:`ramanspy.plot.spectra`. If ``plot_type`` is supplied,
it is ignored and a warning is issued (this plot always uses ``"single"``).
Returns
-------
matplotlib.axes.Axes or list[matplotlib.axes.Axes] or matplotlib.figure.Figure
Whatever :func:`ramanspy.plot.spectra` returns.
Notes
-----
The shaded band is drawn as ``[-ci_band, +ci_band]`` about zero on the same
x-axis as ``diff_spectrum``. A horizontal reference line at ``y=0`` is added.
"""
if "plot_type" in kwargs.keys():
warnings.warn('Only plot_type="single" is supported for mean_difference')
kwargs.pop("plot_type")
ax_obj = rp.plot.spectra(diff_spectrum, label=label, plot_type="single", **kwargs)
axs = _normalize_axes_obj(ax_obj)
axs[0].fill_between(diff_spectrum.spectral_axis, -ci_band, ci_band, color='gray', alpha=0.3, label='95% Confidence Band')
plt.axhline(0, color='gray', linestyle='--', linewidth=1)
return ax_obj
[docs]
def mean_correlation_per_group(
correlation_matrix,
title="Correlation Matrix of Raman Spectra",
vmin=0,
vmax=1,
annot=True,
cmap="coolwarm",
figsize=(8, 6),
**kwargs,
):
"""
Plot a heatmap of correlations between per-group mean spectra.
This is the plotting counterpart of
:func:`ramanlib.calc.mean_correlation_per_group`. Pass in the
square correlation matrix that function returns.
Parameters
----------
correlation_matrix : pandas.DataFrame
Square matrix of correlation coefficients; index/columns are group labels
in the same order used to compute the means.
title : str, optional
Figure title. Default ``"Correlation Matrix of Raman Spectra"``.
vmin : float, optional
Lower bound for the color scale. Default ``0``.
vmax : float, optional
Upper bound for the color scale. Default ``1``.
annot : bool, optional
Whether to annotate each cell with its numeric value. Default ``True``.
cmap : str, optional
Colormap passed to :func:`seaborn.heatmap`. Default ``"coolwarm"``.
figsize : tuple[float, float], optional
Matplotlib figure size. Default ``(8, 6)``.
**kwargs
Additional keyword arguments forwarded to :func:`seaborn.heatmap`.
Returns
-------
matplotlib.axes.Axes
The Axes containing the heatmap.
"""
plt.figure(figsize=figsize)
sns.heatmap(correlation_matrix, annot=annot, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
plt.title(title)