# Source code for tuesday.core.plotting.sliceplots

"""Module for LC and coeval sliceplots."""

from collections.abc import Callable

import matplotlib.pyplot as plt
import numpy as np
from astropy import units as un
from matplotlib import colormaps, colors, rcParams
from matplotlib.colors import LogNorm
from scipy.ndimage import gaussian_filter

from ..units import validate

# Custom "eor" colormap: white -> yellow -> orange -> red up to the black stop
# at 0.86, then blue -> cyan for the remaining top of the range.
eor_colour = colors.LinearSegmentedColormap.from_list(
    "eor",
    [
        (0, "white"),
        (0.21, "yellow"),
        (0.42, "orange"),
        (0.63, "red"),
        (0.86, "black"),
        (0.9, "blue"),
        (1, "cyan"),
    ],
)

# Only register once: re-importing this module (or any prior registration of a
# colormap named "eor") leaves the existing registry entry untouched.
if "eor" not in colormaps:
    colormaps.register(cmap=eor_colour)


def _plot_slice(
    img_slice: un.Quantity,
    xaxis: un.Quantity,
    yaxis: un.Quantity,
    *,
    vmin: float | None = None,
    vmax: float | None = None,
    fontsize: float | None = 16,
    log: tuple[bool, bool, bool] = (False, False, False),
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    clabel: str | None = None,
    ax: plt.Axes | None = None,
    cmap: str = "viridis",
) -> plt.Axes:
    """Draw a 2D slice with ``pcolormesh`` and attach a colorbar.

    Parameters
    ----------
    img_slice : un.Quantity
        2D data indexed as ``img_slice[i_x, i_y]``. It is transposed before
        plotting so that its first axis runs along ``xaxis``.
    xaxis : un.Quantity
        Coordinates along the horizontal axis. Either dimensionless
        (e.g. redshift) or a length.
    yaxis : un.Quantity
        Coordinates along the vertical axis; must be a length.
    vmin, vmax : float, optional
        Colour-scale limits in data units. Each omitted limit defaults to the
        5th / 95th percentile of the data. With a logarithmic colour scale,
        only finite, strictly positive pixels are used for these defaults.
    fontsize : float, optional
        Font size written to the global matplotlib ``rcParams``. ``None``
        leaves the current setting untouched.
    log : tuple of bool
        Whether to use a log scale for the x-axis, the y-axis and the colour
        scale, respectively.
    title, xlabel, ylabel, clabel : str, optional
        Plot title, axis labels and colorbar label.
    ax : plt.Axes, optional
        Axes to draw on. A new figure and axes are created when omitted.
    cmap : str
        Name of the colormap.

    Returns
    -------
    plt.Axes
        The axes containing the plot.
    """
    validate(yaxis, "length")
    if fontsize is not None:
        # NOTE: rcParams is global state, so this affects later figures too.
        rcParams.update({"font.size": fontsize})
    if xaxis.unit.physical_type != "dimensionless":
        validate(xaxis, "length")
    if ax is None:
        _, ax = plt.subplots()

    values = img_slice.value
    if log[2]:
        # LogNorm takes limits in data units and cannot map values <= 0, so
        # default limits come from the finite, strictly positive pixels only.
        positive = values[np.isfinite(values) & (values > 0)]
        if positive.size:
            if vmin is None:
                vmin = np.percentile(positive, 5)
            if vmax is None:
                vmax = np.percentile(positive, 95)
        # Matplotlib does not accept ``norm`` together with ``vmin``/``vmax``.
        cmap_kwargs = {"norm": LogNorm(vmin=vmin, vmax=vmax)}
    else:
        if vmin is None:
            vmin = np.nanpercentile(values, 5)
        if vmax is None:
            vmax = np.nanpercentile(values, 95)
        cmap_kwargs = {"vmin": vmin, "vmax": vmax}

    im = ax.pcolormesh(
        xaxis.value,
        yaxis.value,
        # pcolormesh wants rows along ``yaxis`` and columns along ``xaxis``.
        values.T,
        cmap=cmap,
        shading="auto",
        **cmap_kwargs,
    )

    if log[0]:
        ax.set_xscale("log")
    if log[1]:
        ax.set_yscale("log")
    if title is not None:
        ax.set_title(title)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    plt.colorbar(im, ax=ax, label=clabel)

    return ax


def lc2slice_x(
    zmin: float | None = None,
    zmax: float | None = None,
    idx: int | None = 0,
) -> Callable[[un.Quantity, np.ndarray | un.Quantity], un.Quantity]:
    """Build a function that slices a lightcone at a fixed first-axis index.

    The returned function takes ``(box, redshift)`` and returns
    ``box[idx, :, idx_min:idx_max]``. This is the plane spanned by the second
    axis and the last (redshift) axis, optionally cropped in redshift.

    Parameters
    ----------
    zmin, zmax : float, optional
        Redshift bounds. Each is mapped to the index of the closest entry in
        ``redshift``. The ``zmax`` entry is included. ``None`` means no
        cropping on that side. If ``redshift`` decreases along the array, the
        ``zmin`` index can exceed the ``zmax`` index and the crop is empty.
    idx : int, optional
        Index along the first axis of the lightcone. Default is 0.

    Returns
    -------
    Callable
        ``slice_index(box, redshift) -> un.Quantity``.
    """

    def slice_index(
        box: un.Quantity, redshift: np.ndarray | un.Quantity
    ) -> un.Quantity:
        """Return ``box[idx, :, idx_min:idx_max]`` for the configured bounds."""
        idx_min = 0 if zmin is None else np.argmin(np.abs(redshift - zmin))
        # +1 so that the redshift closest to zmax is kept (inclusive bound).
        idx_max = (
            box.shape[-1] if zmax is None else np.argmin(np.abs(redshift - zmax)) + 1
        )

        return box[idx, :, idx_min:idx_max]

    return slice_index


def lc2slice_y(
    zmin: float | None = None,
    zmax: float | None = None,
    idx: int | None = 0,
) -> Callable[[un.Quantity, np.ndarray | un.Quantity], un.Quantity]:
    """Build a function that slices a lightcone at a fixed second-axis index.

    The returned function takes ``(box, redshift)`` and returns
    ``box[:, idx, idx_min:idx_max]``. This is the plane spanned by the first
    axis and the last (redshift) axis, optionally cropped in redshift.

    Parameters
    ----------
    zmin, zmax : float, optional
        Redshift bounds. Each is mapped to the index of the closest entry in
        ``redshift``. The ``zmax`` entry is included. ``None`` means no
        cropping on that side. If ``redshift`` decreases along the array, the
        ``zmin`` index can exceed the ``zmax`` index and the crop is empty.
    idx : int, optional
        Index along the second axis of the lightcone. Default is 0.

    Returns
    -------
    Callable
        ``slice_index(box, redshift) -> un.Quantity``.
    """

    def slice_index(
        box: un.Quantity, redshift: np.ndarray | un.Quantity
    ) -> un.Quantity:
        """Return ``box[:, idx, idx_min:idx_max]`` for the configured bounds."""
        idx_min = 0 if zmin is None else np.argmin(np.abs(redshift - zmin))
        # +1 so that the redshift closest to zmax is kept (inclusive bound).
        idx_max = (
            box.shape[-1] if zmax is None else np.argmin(np.abs(redshift - zmax)) + 1
        )

        return box[:, idx, idx_min:idx_max]

    return slice_index


def coeval2slice_x(
    idx: int | None = 0,
) -> Callable[[un.Quantity], un.Quantity]:
    """Build a function that extracts the plane ``box[idx, :, :]``.

    Parameters
    ----------
    idx : int, optional
        Index along the first (x) axis of the box. Default is 0.

    Returns
    -------
    Callable
        ``slice_index(box) -> un.Quantity``, returning ``box[idx, :, :]``.
    """

    def slice_index(box: un.Quantity) -> un.Quantity:
        """Return the slice of ``box`` at ``idx`` along the first axis."""
        return box[idx, :, :]

    return slice_index


def coeval2slice_y(
    idx: int | None = 0,
) -> Callable[[un.Quantity], un.Quantity]:
    """Build a function that extracts the plane ``box[:, idx, :]``.

    Parameters
    ----------
    idx : int, optional
        Index along the second (y) axis of the box. Default is 0.

    Returns
    -------
    Callable
        ``slice_index(box) -> un.Quantity``, returning ``box[:, idx, :]``.
    """

    def slice_index(box: un.Quantity) -> un.Quantity:
        """Return the slice of ``box`` at ``idx`` along the second axis."""
        return box[:, idx, :]

    return slice_index


def coeval2slice_z(
    idx: int | None = 0,
) -> Callable[[un.Quantity], un.Quantity]:
    """Build a function that extracts the plane ``box[:, :, idx]``.

    Parameters
    ----------
    idx : int, optional
        Index along the third (z) axis of the box. Default is 0.

    Returns
    -------
    Callable
        ``slice_index(box) -> un.Quantity``, returning ``box[:, :, idx]``.
    """

    def slice_index(box: un.Quantity) -> un.Quantity:
        """Return the slice of ``box`` at ``idx`` along the third axis."""
        return box[:, :, idx]

    return slice_index


def plot_redshift_slice(
    lightcone: un.Quantity,
    box_length: un.Quantity,
    redshift: np.ndarray | un.Quantity,
    *,
    fontsize: float | None = 16,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    clabel: str | None = None,
    cmap: str = "eor",
    logx: bool = False,
    logy: bool = False,
    logc: bool = False,
    zmin: float | None = None,
    zmax: float | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    ax: plt.Axes | None = None,
    smooth: bool | float = False,
    transform2slice: Callable | None = None,
) -> plt.Axes:
    """Plot a slice from a lightcone of shape (N_x, N_y, N_redshifts).

    Parameters
    ----------
    lightcone : un.Quantity
        The lightcone data to plot with shape (N_x, N_y, N_redshifts).
    box_length : un.Quantity
        The length of the box. It sets the extent of the vertical axis.
    redshift : np.ndarray | un.Quantity
        The redshift values along the last axis of the lightcone.
    fontsize : float, optional
        Font size written to the global matplotlib ``rcParams``. ``None``
        keeps the current setting.
    title : str, optional
        The title of the plot.
    xlabel : str, optional
        The label for the x-axis. Defaults to "Redshift".
    ylabel : str, optional
        The label for the y-axis. Defaults to a distance label.
    clabel : str, optional
        The label for the colorbar. Inferred from the data unit if omitted.
    cmap : str, optional
        The colormap to use for the plot.
    logx : bool, optional
        Whether to use a logarithmic scale for the x-axis.
    logy : bool, optional
        Whether to use a logarithmic scale for the y-axis.
    logc : bool, optional
        Whether to use a logarithmic scale for the colorbar.
    zmin : float, optional
        The minimum redshift of the plotted range. It is snapped to the
        closest entry of ``redshift``.
    zmax : float, optional
        The maximum redshift of the plotted range. It is snapped to the
        closest entry of ``redshift`` and that entry is included.
    vmin : float, optional
        The minimum value for the color scale, in data units. If both
        ``vmin`` and ``vmax`` are omitted, ``vmin`` defaults to the 5th
        percentile of the slice. With ``logc``, only positive finite values
        are used. With the "eor" colormap, a negative default ``vmin`` and a
        linear colour scale, ``vmax`` is chosen so that zero lands on the
        colormap's black stop at 0.86.
    vmax : float, optional
        The maximum value for the color scale, in data units.
    ax : plt.Axes, optional
        The axes to plot on. If None, a new figure and axes will be created.
    smooth : bool | float, optional
        Whether to apply Gaussian smoothing to the lightcone data.
        If True, a default sigma of 1.0 will be used.
        If a float, it will be used as the sigma for the Gaussian filter.
    transform2slice : Callable, optional
        Function ``(lightcone, redshift) -> 2D slice`` used instead of the
        default ``lc2slice_x(zmin=zmin, zmax=zmax, idx=0)``. The slice's last
        axis must correspond to ``redshift`` cropped by ``zmin``/``zmax``.

    Returns
    -------
    plt.Axes
        The axes with the lightcone slice plot.

    Raises
    ------
    ValueError
        If the slice's redshift axis cannot be matched to ``redshift``.
    """
    validate(box_length, "length")
    if fontsize is not None:
        rcParams.update({"font.size": fontsize})
    if ax is None:
        _, ax = plt.subplots(figsize=(20, 4))

    if transform2slice is not None:
        lightcone = transform2slice(lightcone, redshift)
    else:
        lightcone = lc2slice_x(zmin=zmin, zmax=zmax, idx=0)(lightcone, redshift)

    # The slicing step may have cropped the redshift axis (zmin/zmax). Crop the
    # coordinate array with the same nearest-index rule so pcolormesh receives
    # matching shapes.
    n_slice = lightcone.shape[-1]
    if len(redshift) != n_slice:
        idx_min = 0 if zmin is None else np.argmin(np.abs(redshift - zmin))
        idx_max = (
            len(redshift)
            if zmax is None
            else np.argmin(np.abs(redshift - zmax)) + 1
        )
        redshift = redshift[idx_min:idx_max]
        if len(redshift) != n_slice:
            raise ValueError(
                f"The lightcone slice has {n_slice} redshift bins but "
                f"{len(redshift)} redshift values remain after applying "
                "zmin/zmax; transform2slice must crop consistently."
            )

    if smooth:
        sigma = 1.0 if isinstance(smooth, bool) else smooth
        lightcone = gaussian_filter(lightcone.value, sigma=sigma) * lightcone.unit

    yaxis = np.linspace(0, box_length, lightcone.shape[0])
    if not isinstance(redshift, un.Quantity):
        redshift = redshift * un.dimensionless_unscaled

    if clabel is None:
        if lightcone.unit.physical_type == un.get_physical_type("temperature"):
            clabel = "Brightness Temperature " + f" [{lightcone.unit:latex_inline}]"
        elif lightcone.unit.is_equivalent(un.dimensionless_unscaled):
            clabel = "Density Contrast"
        else:
            clabel = (
                f"{lightcone.unit.physical_type} "
                + f" [{lightcone.unit:latex_inline}]"
            )

    if vmin is None and vmax is None:
        values = lightcone.value
        if logc:
            # A log colour scale takes limits in data units and needs them > 0.
            positive = values[np.isfinite(values) & (values > 0)]
            if positive.size:
                vmin = np.percentile(positive, 5)
        else:
            vmin = np.nanpercentile(values, 5)
            # The "eor" colormap has its black stop at 0.86. Pick vmax so that
            # zero maps there; this only makes sense when vmin is negative.
            if isinstance(cmap, str) and cmap.lower() == "eor" and vmin < 0:
                vmax = -1.0 * vmin / 0.86 + vmin

    return _plot_slice(
        lightcone.T,
        redshift,
        yaxis,
        vmin=vmin,
        vmax=vmax,
        # Forward the effective font size so _plot_slice's own default (16)
        # does not override the user's choice.
        fontsize=rcParams["font.size"] if fontsize is None else fontsize,
        log=[logx, logy, logc],
        title=title,
        xlabel="Redshift" if xlabel is None else xlabel,
        ylabel=f"Distance [{box_length.unit:latex_inline}]"
        if ylabel is None
        else ylabel,
        clabel=clabel,
        cmap=cmap,
        ax=ax,
    )
def plot_coeval_slice(
    coeval: un.Quantity,
    box_length: un.Quantity,
    *,
    fontsize: float | None = 16,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    clabel: str | None = None,
    cmap: str = "viridis",
    logx: bool = False,
    logy: bool = False,
    logc: bool = False,
    idx: int = 0,
    vmin: float | None = None,
    vmax: float | None = None,
    ax: plt.Axes | None = None,
    smooth: bool | float = False,
    transform2slice: Callable | None = None,
    v_x: un.Quantity | None = None,
    v_y: un.Quantity | None = None,
    quiver_label: str | bool = False,
    quiver_kwargs: dict | None = None,
    quiver_label_kwargs: dict | None = None,
    quiver_decimate_factor: int = 1,
) -> plt.Axes:
    """Plot a slice from a coeval box of shape (Nx, Ny, Nz).

    Parameters
    ----------
    coeval : un.Quantity
        The coeval data cube with shape (Nx, Ny, Nz).
    box_length : un.Quantity
        The length of the box. It sets the extent of both plot axes.
    fontsize : float, optional
        Font size written to the global matplotlib ``rcParams``. ``None``
        keeps the current setting.
    title : str, optional
        The title of the plot.
    xlabel : str, optional
        The label for the x-axis.
    ylabel : str, optional
        The label for the y-axis.
    clabel : str, optional
        The label for the colorbar. Inferred from the data unit if omitted.
    cmap : str, optional
        The colormap to use for the plot.
    logx : bool, optional
        Whether to use a logarithmic scale for the x-axis.
    logy : bool, optional
        Whether to use a logarithmic scale for the y-axis.
    logc : bool, optional
        Whether to use a logarithmic scale for the colorbar.
    idx : int, optional
        The index of the slice to plot along the z-axis. Default is 0.
        It is ignored when ``transform2slice`` is given.
    vmin : float, optional
        The minimum value for the color scale.
    vmax : float, optional
        The maximum value for the color scale.
    ax : plt.Axes, optional
        The axes to plot on. If None, a new figure and axes will be created.
    smooth : bool | float, optional
        Whether to apply Gaussian smoothing to the coeval data.
        If True, a default sigma of 1.0 will be used.
        If a float, it will be used as the sigma for the Gaussian filter.
    transform2slice : Callable, optional
        A function ``coeval -> 2D slice`` used instead of the default
        ``coeval2slice_z(idx=idx)``.
    v_x : un.Quantity, optional
        The x-component of the velocity field to draw as arrows on top of the
        slice, as a 2D array with shape (Nx, Ny). Arrows are only drawn when
        both ``v_x`` and ``v_y`` are given.
    v_y : un.Quantity, optional
        The y-component of the velocity field, as a 2D array with shape
        (Nx, Ny).
    quiver_label : str | bool, optional
        Label of the quiver key. If a string, it is used as the label. If
        True, a default velocity label built from the unit of ``v_x`` is
        used. If False, no key is added.
    quiver_kwargs : dict, optional
        Additional keyword arguments for ``matplotlib.axes.Axes.quiver``,
        such as arrow color or width. The keys "X", "Y", "U" and "V" override
        the positional arguments. The caller's dict is not modified.
    quiver_label_kwargs : dict, optional
        Additional keyword arguments for ``matplotlib.axes.Axes.quiverkey``.
        The keys "X", "Y" and "U" override the key position and reference
        length. The caller's dict is not modified.
    quiver_decimate_factor : int, optional
        Plot only every ``quiver_decimate_factor``-th arrow along each axis,
        to avoid clutter. Must be >= 1. Default is 1 (no decimation).

    Returns
    -------
    plt.Axes
        The axes with the coeval slice plot.

    Raises
    ------
    ValueError
        If ``quiver_decimate_factor`` is smaller than 1 while a velocity
        field is being plotted.
    """
    validate(box_length, "length")
    if fontsize is not None:
        rcParams.update({"font.size": fontsize})
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 6))

    if transform2slice is not None:
        coeval = transform2slice(coeval)
    else:
        coeval = coeval2slice_z(idx=idx)(coeval)

    if smooth:
        sigma = 1.0 if isinstance(smooth, bool) else smooth
        coeval = gaussian_filter(coeval.value, sigma=sigma) * coeval.unit

    xaxis = np.linspace(0, box_length, coeval.shape[0])
    yaxis = np.linspace(0, box_length, coeval.shape[1])

    if clabel is None:
        if coeval.unit.physical_type == un.get_physical_type("temperature"):
            clabel = "Brightness Temperature " + f" [{coeval.unit:latex_inline}]"
        elif coeval.unit.is_equivalent(un.dimensionless_unscaled):
            clabel = "Density Contrast"
        else:
            clabel = f"{coeval.unit.physical_type} " + f" [{coeval.unit:latex_inline}]"

    distance_label = f"Distance [{box_length.unit:latex_inline}]"
    ax = _plot_slice(
        coeval,
        xaxis,
        yaxis,
        vmin=vmin,
        vmax=vmax,
        # Forward the effective font size so _plot_slice's own default (16)
        # does not override the user's choice.
        fontsize=rcParams["font.size"] if fontsize is None else fontsize,
        log=[logx, logy, logc],
        title=title,
        xlabel=distance_label if xlabel is None else xlabel,
        ylabel=distance_label if ylabel is None else ylabel,
        clabel=clabel,
        cmap=cmap,
        ax=ax,
    )

    if v_x is not None and v_y is not None:
        if quiver_decimate_factor < 1:
            raise ValueError(
                "quiver_decimate_factor must be a positive integer, "
                f"got {quiver_decimate_factor}."
            )
        step = quiver_decimate_factor

        # Work on copies: the .pop calls below must not mutate caller dicts.
        q_kwargs = (
            {"color": "k", "width": 0.006, "headwidth": 4}
            if quiver_kwargs is None
            else dict(quiver_kwargs)
        )
        key_kwargs = (
            {"labelpos": "E", "coordinates": "figure"}
            if quiver_label_kwargs is None
            else dict(quiver_label_kwargs)
        )

        # Only a non-string truthy value (e.g. True) requests the default
        # label; an explicit string label is kept as given.
        if quiver_label and not isinstance(quiver_label, str):
            quiver_label = "Velocity " + f"[{v_x.unit:latex_inline}]"

        # The image is drawn as coeval.T (rows follow y), so the (Nx, Ny)
        # velocity components are transposed to (Ny, Nx) to line up with it.
        axq = ax.quiver(
            q_kwargs.pop("X", xaxis.value[::step]),
            q_kwargs.pop("Y", yaxis.value[::step]),
            q_kwargs.pop("U", v_x.value[::step, ::step].T),
            q_kwargs.pop("V", v_y.value[::step, ::step].T),
            **q_kwargs,
        )
        if isinstance(quiver_label, str):
            ax.quiverkey(
                axq,
                key_kwargs.pop("X", 0.9),
                key_kwargs.pop("Y", 0.9),
                key_kwargs.pop("U", 1.0),
                quiver_label,
                **key_kwargs,
            )
    return ax
def plot_pdf(
    box: un.Quantity,
    *,
    fontsize: float | None = 16,
    title: str | None = None,
    xlabel: str | None = None,
    ylabel: str | None = None,
    logx: bool = False,
    ax: plt.Axes | None = None,
    smooth: bool | float = False,
    hist_kwargs: dict | None = None,
) -> plt.Axes:
    """Plot a pixel distribution function (PDF) of the box.

    Parameters
    ----------
    box : un.Quantity
        The box data to plot. All pixels are flattened into one histogram.
    fontsize : float, optional
        Font size written to the global matplotlib ``rcParams``. ``None``
        keeps the current setting.
    title : str, optional
        The title of the plot.
    xlabel : str, optional
        The label for the x-axis. Inferred from the data unit if omitted.
    ylabel : str, optional
        The label for the y-axis. Defaults to "Counts".
    logx : bool, optional
        Whether to use a logarithmic scale for the x-axis.
    ax : plt.Axes, optional
        The axes to plot on. If None, a new figure and axes will be created.
    smooth : bool | float, optional
        Whether to apply Gaussian smoothing to the box data.
        If True, a default sigma of 1.0 will be used.
        If a float, it will be used as the sigma for the Gaussian filter.
    hist_kwargs : dict, optional
        Extra keyword arguments forwarded to ``matplotlib.axes.Axes.hist``
        (e.g. ``bins``, ``density``). Defaults to none.

    Returns
    -------
    plt.Axes
        The axes with the PDF plot.
    """
    if fontsize is not None:
        rcParams.update({"font.size": fontsize})
    if smooth:
        sigma = 1.0 if isinstance(smooth, bool) else smooth
        box = gaussian_filter(box.value, sigma=sigma) * box.unit
    if ax is None:
        _, ax = plt.subplots(figsize=(7, 6))

    ax.hist(
        box.value.flatten(),
        **(hist_kwargs or {}),
    )

    if xlabel is None:
        if box.unit.physical_type == un.get_physical_type("temperature"):
            xlabel = "Brightness Temperature " + f" [{box.unit:latex_inline}]"
        elif box.unit.is_equivalent(un.dimensionless_unscaled):
            xlabel = "Density Contrast"
        else:
            xlabel = f"{box.unit.physical_type} " + f" [{box.unit:latex_inline}]"
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Counts" if ylabel is None else ylabel)
    if title is not None:
        ax.set_title(title)
    if logx:
        ax.set_xscale("log")
    return ax