"""Plotting functions for 1D and 2D power spectra."""
import warnings
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
from matplotlib.colors import LogNorm
from scipy.ndimage import gaussian_filter
from ..summaries import CylindricalPS, SphericalPS
[docs]
def plot_1d_power_spectrum_k(
power_spectrum: SphericalPS,
*,
ax: plt.Axes | None = None,
title: str | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
color: list | None = None,
log: list[bool] | None = False,
fontsize: float | None = 16,
legend: str | None = None,
smooth: float | bool = False,
legend_kwargs: dict | None = None,
) -> tuple[plt.Figure, plt.Axes]:
"""
Plot 1D power spectrum vs wave mode.
Parameters
----------
power_spectrum : SphericalPS
Instance of the SphericalPS class.
ax : plt.Axes, optional
Axes object to plot on. If None, a new axes is created.
title : str, optional
Title of the plot.
xlabel : str, optional
Label for the x-axis.
ylabel : str, optional
Label for the y-axis.
color : str, optional
Color of the PS line in the plot.
log : list[bool], optional
List of booleans to set the x and y axes to log scale.
fontsize : float, optional
Font size for the plot labels.
legend : str, optional
Legend label for the PS.
smooth : float, optional
Standard deviation for Gaussian smoothing.
If True, uses a standard deviation of 1.
legend_kwargs : dict, optional
Keyword arguments for the legend.
"""
if not isinstance(power_spectrum, SphericalPS):
raise ValueError(
"power_spectrum must be a SphericalPS object,"
f" got {type(power_spectrum)} instead."
)
rcParams.update({"font.size": fontsize})
wavemodes = power_spectrum.kcenters
is_deltasq = power_spectrum.is_deltasq
power_spectrum = power_spectrum.ps
if color is None:
color = "C0"
if xlabel is None:
xlabel = f"k [{wavemodes.unit:latex_inline}]"
if ylabel is None:
ylabel = f"[{power_spectrum.unit:latex_inline}]"
ylabel = r"$\Delta^2_{21} \,$" + ylabel if is_deltasq else r"$P(k) \,$" + ylabel
if smooth:
power_spectrum = gaussian_filter(power_spectrum, sigma=smooth)
ax.plot(wavemodes, power_spectrum, color=color, label=legend)
if title is not None:
ax.set_title(title, fontsize=fontsize)
ax.set_xlabel(xlabel, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
if log[0]:
ax.set_xscale("log")
if log[1]:
ax.set_yscale("log")
if legend is not None:
ax.legend(**legend_kwargs)
return ax
[docs]
def plot_1d_power_spectrum_z(
power_spectra: list[SphericalPS],
at_k: float,
*,
ax: plt.Axes | None = None,
title: str | None = None,
xlabel: str | None = "Redshift",
ylabel: str | None = None,
color: list | None = "C0",
log: list[bool] | None = False,
fontsize: float | None = 16,
legend: str | None = None,
smooth: float | bool = False,
legend_kwargs: dict | None = None,
) -> tuple[plt.Figure, plt.Axes]:
"""
Plot 1D power spectra as a function of redshift at a given scale.
Parameters
----------
power_spectrum : list[SphericalPS]
List of instances of the SphericalPS class.
at_k : float
If provided, plots the 1D power spectrum at a specific k value.
The k value is assumed to be in the same unit
as the k in the SphericalPS instance wavemodes.
ax : plt.Axes, optional
Axes object to plot on. If None, a new axes is created.
title : str, optional
Title of the plot.
xlabel : str, optional
Label for the x-axis.
ylabel : str, optional
Label for the y-axis.
color : str, optional
Color of the PS line in the plot.
log : list[bool], optional
List of booleans to set the x and y axes to log scale.
fontsize : float, optional
Font size for the plot labels.
legend : str, optional
Legend label for the PS.
smooth : float, optional
Standard deviation for Gaussian smoothing.
If True, uses a standard deviation of 1.
legend_kwargs : dict, optional
Keyword arguments for the legend.
"""
for i in range(len(power_spectra)):
if not isinstance(power_spectra[i], SphericalPS):
raise ValueError(
"power_spectrum must be a SphericalPS object or a list of "
"SphericalPS objects,"
f" got {type(power_spectra[i])} instead."
)
rcParams.update({"font.size": fontsize})
is_deltasq = power_spectra[0].is_deltasq
xaxis = [ps.redshift for ps in power_spectra]
kbins = np.abs(power_spectra[0].kcenters.value - at_k)
kbins[np.isnan(kbins)] = np.inf # Avoid NaNs
at_k = np.argmin(kbins)
if ylabel is None:
ylabel = f"[{power_spectra[0].ps.unit:latex_inline}]"
ylabel = r"$\Delta^2_{21} \,$" + ylabel if is_deltasq else r"$P(k) \,$" + ylabel
psvals = []
for power_spectrum in power_spectra:
if smooth:
ps = gaussian_filter(power_spectrum.ps, sigma=smooth)
else:
ps = power_spectrum.ps.value
psvals.append(ps[at_k])
ax.plot(xaxis, psvals, color=color, label=legend)
if title is not None:
ax.set_title(title, fontsize=fontsize)
ax.set_xlabel(xlabel, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
if log[0]:
ax.set_xscale("log")
if log[1]:
ax.set_yscale("log")
if legend is not None:
ax.legend(**legend_kwargs)
return ax
[docs]
def plot_2d_power_spectrum(
power_spectrum: CylindricalPS,
*,
ax: plt.Axes | None = None,
title: str | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
clabel: str | None = None,
cmap: str | None = "viridis",
fontsize: float | None = 16,
vmin: float | None = None,
vmax: float | None = None,
log: list[bool] | None = False,
smooth: float | bool = False,
cbar: bool | None = True,
) -> tuple[plt.Figure, plt.Axes]:
"""
Plot a 2D power spectrum.
Parameters
----------
power_spectrum : CylindricalPS
Instance of the CylindricalPS class.
axs : plt.Axes | list[plt.Axes], optional
Axes object(s) to plot on. If None, new axes are created.
title : str, optional
Title(s) of the plot.
xlabel : str, optional
Label for the x-axis.
ylabel : str, optional
Label for the y-axis.
clabel : str, optional
Label for the colorbar.
cmap : str, optional
Colormap for the plot.
fontsize : float, optional
Font size for the plot labels.
vmin : float, optional
Minimum value for the color scale.
vmax : float, optional
Maximum value for the color scale.
log : list[bool], optional
List of booleans to set the kperp, kpar, and PS axes to log scale.
smooth : float, optional
Standard deviation for Gaussian smoothing.
Default is False, if True, uses a standard deviation of 1.
"""
if not isinstance(power_spectrum, CylindricalPS):
raise ValueError(
"power_spectrum must be a CylindricalPS object,"
f" got {type(power_spectrum)} instead."
)
rcParams.update({"font.size": fontsize})
kperp = power_spectrum.kperp
kpar = power_spectrum.kpar
is_deltasq = power_spectrum.is_deltasq
power_spectrum = power_spectrum.ps
if xlabel is None:
xlabel = r"k$_\perp \,$" + f"[{kperp.unit:latex_inline}]"
if ylabel is None:
ylabel = r"k$_\parallel \,$" + f"[{kpar.unit:latex_inline}]"
if clabel is None:
clabel = f"[{power_spectrum.unit:latex_inline}]"
clabel = r"$\Delta^2_{21} \,$" + clabel if is_deltasq else r"$P(k) \,$" + clabel
cmap_kwargs = {}
if vmin is None:
if log[2]:
cmap_kwargs["vmin"] = np.nanpercentile(np.log10(power_spectrum.value), 5)
else:
cmap_kwargs["vmin"] = np.nanpercentile(power_spectrum.value, 5)
if vmax is None:
if log[2]:
cmap_kwargs["vmax"] = np.nanpercentile(np.log10(power_spectrum.value), 95)
else:
cmap_kwargs["vmax"] = np.nanpercentile(power_spectrum.value, 95)
if log[2]:
cmap_kwargs = {}
cmap_kwargs["norm"] = LogNorm(vmin=vmin, vmax=vmax)
if title is not None:
ax.set_title(title, fontsize=fontsize)
ax.set_ylabel(ylabel, fontsize=fontsize)
if smooth:
unit = power_spectrum.unit
power_spectrum = gaussian_filter(power_spectrum, sigma=smooth) * unit
mask = np.isnan(np.nanmean(power_spectrum, axis=-1))
power_spectrum = power_spectrum[~mask]
kperp = kperp[~mask]
im = ax.pcolormesh(
kperp.value,
kpar.value,
power_spectrum.value.T,
cmap=cmap,
**cmap_kwargs,
)
ax.set_xlabel(xlabel, fontsize=fontsize)
if cbar:
plt.colorbar(im, label=clabel)
if log[0]:
ax.set_xscale("log")
if log[1]:
ax.set_yscale("log")
return ax
[docs]
def plot_power_spectrum(
power_spectrum: SphericalPS | CylindricalPS | list[SphericalPS],
*,
ax: plt.Axes | list[plt.Axes] | None = None,
title: str | None = None,
xlabel: str | None = None,
ylabel: str | None = None,
clabel: str | None = None,
at_k: float | int | None = None,
cmap: str | None = "viridis",
color: list | None = None,
fontsize: float | None = 16,
vmin: float | None = None,
vmax: float | None = None,
logx: bool | None = False,
logy: bool | None = False,
logc: bool | None = False,
cbar: bool | None = True,
legend: str | None = None,
smooth: float | bool = False,
legend_kwargs: dict | None = None,
) -> tuple[plt.Figure, plt.Axes]:
"""
Plot a power spectrum.
Parameters
----------
power_spectrum : CylindricalPS | SphericalPS | list[SphericalPS]
Instance of the CylindricalPS class, or instance or
list of instances of the or SphericalPS class.
ax : plt.Axes | list[plt.Axes], optional
Axes object(s) to plot on. If None, new axes are created.
title : str, optional
Title of the plot.
xlabel : str, optional
Label for the x-axis.
ylabel : str, optional
Label for the y-axis.
clabel : str, optional
Label for the colorbar.
at_k : float | int, optional
If provided, plots the 1D power spectrum at a specific k value.
If int, it is interpreted as the index of the k value.
If float, it is interpreted as the k value itself
in the same unit as the k in the SphericalPS instance wavemodes.
cmap : str, optional
Colormap for the plot.
colors : list, optional
List of colors for each line in the plot.
fontsize : float, optional
Font size for the plot labels.
vmin : float, optional
Minimum value for the color scale.
vmax : float, optional
Maximum value for the color scale.
logx : bool, optional
Whether to set the x-axis to log scale.
logy : bool, optional
Whether to set the y-axis to log scale.
logc : bool, optional
Whether to set the color-axis to log scale.
legend : str, optional
Legend label for the 1D PS.
smooth : bool or float, optional
Standard deviation for Gaussian smoothing.
If True, uses a standard deviation of 1.
legend_kwargs : dict, optional
Keyword arguments for the legend on the 1D PS plot.
"""
if isinstance(smooth, bool) and smooth:
smooth = 1.0
if isinstance(power_spectrum, SphericalPS):
if legend_kwargs is None:
legend_kwargs = {}
if ax is None:
fig, ax = plt.subplots(
nrows=1, ncols=1, figsize=(7, 6), sharey=True, sharex=True
)
ax = plot_1d_power_spectrum_k(
power_spectrum,
ax=ax,
title=title,
xlabel=xlabel,
ylabel=ylabel,
color=color,
fontsize=fontsize,
log=[logx, logy],
legend=legend,
smooth=smooth,
legend_kwargs=legend_kwargs,
)
elif hasattr(power_spectrum, "__len__") and np.all(
[isinstance(ps, SphericalPS) for ps in power_spectrum]
):
if legend_kwargs is None:
legend_kwargs = {}
ax = plot_1d_power_spectrum_z(
power_spectrum,
at_k,
ax=ax,
title=title,
xlabel=xlabel,
ylabel=ylabel,
color=color,
fontsize=fontsize,
log=[logx, logy],
legend=legend,
smooth=smooth,
legend_kwargs=legend_kwargs,
)
elif isinstance(power_spectrum, CylindricalPS):
if legend is not None or legend_kwargs is not None:
warnings.warn(
"Cylindrical PS plots do not support labels and legends.", stacklevel=2
)
if ax is None:
fig, ax = plt.subplots(
nrows=1, ncols=1, figsize=(7, 6), sharey=True, sharex=True
)
cbar = True
else:
fig = ax.get_figure()
if len(fig.get_axes()) > 1:
cbar = False
ax = plot_2d_power_spectrum(
power_spectrum,
ax=ax,
title=title,
xlabel=xlabel,
ylabel=ylabel,
clabel=clabel,
cmap=cmap,
fontsize=fontsize,
vmin=vmin,
vmax=vmax,
log=[logx, logy, logc],
smooth=smooth,
cbar=cbar,
)
else:
raise ValueError(
"Input must be SphericalPS or CylindricalPS objects,"
f"got {type(power_spectrum)} instead."
)
return ax