# Source code for astrophot.plots.profile

from functools import partial

import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.stats import binned_statistic, iqr

from .. import AP_config
from ..models import Warp_Galaxy
from ..utils.conversions.units import flux_to_sb
from .visuals import *
from ..errors import InvalidModel

# Public plotting functions exported by ``from astrophot.plots.profile import *``;
# keep this list in sync with the functions defined below.
__all__ = [
    "radial_light_profile",
    "radial_median_profile",
    "ray_light_profile",
    "wedge_light_profile",
    "warp_phase_profile",
]


def radial_light_profile(
    fig,
    ax,
    model,
    rad_unit="arcsec",
    extend_profile=1.0,
    R0=0.0,
    resolution=1000,
    doassert=True,
    plot_kwargs=None,
):
    """Plot the analytic radial light profile of a model.

    Evaluates ``model.radial_model`` on ``resolution`` evenly spaced radii
    from ``R0`` out to ``extend_profile`` times half the largest window
    dimension, and draws the result on ``ax``. If the model's target has a
    zeropoint, the flux is converted to surface brightness with
    ``flux_to_sb`` and the y-axis is inverted (if it is not already).
    Otherwise ``log10`` of the flux is plotted.

    Args:
        fig: matplotlib figure object (returned unchanged).
        ax: matplotlib axis object to draw on.
        model: Model providing ``radial_model``, ``window``, ``target`` and
            ``name``.
        rad_unit (str): Name of the radius unit used in the x-axis label.
            Default: "arcsec"
        extend_profile (float): Multiplier on half the largest window
            dimension that sets the outermost radius. Default: 1.0
        R0 (float): Innermost radius sampled; also used as the left x-limit.
            Default: 0.0
        resolution (int): Number of radii at which the profile is evaluated.
            Default: 1000
        doassert (bool): Not used by this function; accepted so all profile
            plotters share a common signature. Default: True
        plot_kwargs (dict, optional): Keyword arguments that override the
            defaults passed to ``ax.plot`` (linewidth, color, label).

    Returns:
        tuple: ``(fig, ax)``.
    """
    # None instead of a shared mutable ``{}`` default.
    if plot_kwargs is None:
        plot_kwargs = {}

    # Only numeric values are needed for plotting, so skip autograd tracking
    # for the model evaluation as well as for the plot call.
    with torch.no_grad():
        xx = torch.linspace(
            R0,
            torch.max(model.window.shape / 2) * extend_profile,
            int(resolution),
            dtype=AP_config.ap_dtype,
            device=AP_config.ap_device,
        )
        flux = model.radial_model(xx).detach().cpu().numpy()
        radii = xx.detach().cpu().numpy()

    if model.target.zeropoint is not None:
        yy = flux_to_sb(
            flux, model.target.pixel_area.item(), model.target.zeropoint.item()
        )
    else:
        yy = np.log10(flux)

    kwargs = {
        "linewidth": 2,
        "color": main_pallet["primary1"],
        "label": f"{model.name} profile",
    }
    kwargs.update(plot_kwargs)
    ax.plot(radii, yy, **kwargs)

    if model.target.zeropoint is not None:
        ax.set_ylabel("Surface Brightness [mag/arcsec$^2$]")
        # Magnitudes: brighter is numerically smaller, so flip the axis once.
        if not ax.yaxis_inverted():
            ax.invert_yaxis()
    else:
        ax.set_ylabel("log$_{10}$(flux/arcsec$^2$)")
    ax.set_xlabel(f"Radius [{rad_unit}]")
    ax.set_xlim([R0, None])
    return fig, ax
def radial_median_profile(
    fig,
    ax,
    model: "AstroPhot_Model",
    count_limit: int = 10,
    return_profile: bool = False,
    rad_unit: str = "arcsec",
    doassert: bool = True,
    plot_kwargs: "dict | None" = None,
):
    """Plot an SB profile by taking flux median at each radius.

    Using the coordinate transforms defined by the model object, assigns a
    radius to each pixel, then bins the pixel radii and computes the median
    in each bin. This gives a simple representation of the image data if one
    were to simply average the pixels along isophotes.

    Bin edges start at 0 and are at least 2 pixels wide, growing by 10% of
    the current radius per bin further out. They are expressed in the same
    physical units as the model coordinates, so the result lines up with
    ``radial_light_profile``.

    Args:
        fig: matplotlib figure object
        ax: matplotlib axis object
        model (AstroPhot_Model): Model object from which to determine the
            radial binning. Also provides the target image to extract the data.
        count_limit (int): Bins with at most this many pixels get an
            uncertainty of 0 instead of a computed one. Default: 10
        return_profile (bool): Instead of just returning the fig and ax
            object, return the extracted profile as a tuple of:
            Rbins (the radial bin edges, in physical units), medians (the
            median in each bin, converted to surface brightness if the target
            has a zeropoint, otherwise log10), scatter (the 16-84 percentile
            range divided by ``2 * sqrt(count)`` for bins with more than
            ``count_limit`` pixels, 0 otherwise; in the flux units of the
            image data), count (the number of pixels in each bin).
            Default: False
        rad_unit (str): The name of the physical radius units. Default: "arcsec"
        doassert (bool): If any requirements are imposed on which kind of
            profile can be plotted, this activates them. Currently unused in
            this function. Default: True
        plot_kwargs (dict, optional): Overrides for the keyword arguments
            passed to ``ax.errorbar``.

    Returns:
        ``(fig, ax)``, or ``(Rbins, stat, scat, count)`` if ``return_profile``.
    """
    # None instead of a shared mutable ``{}`` default.
    if plot_kwargs is None:
        plot_kwargs = {}

    pixel_length = model.target.pixel_length.item()
    Rlast_phys = torch.max(model.window.shape / 2).item()
    Rlast_pix = Rlast_phys / pixel_length

    # Build the edges in pixel units (minimum width of 2 pixels), then convert
    # them to physical units: the radii R below come from the physical
    # coordinate meshgrid, so both must use the same unit.
    Rbins = [0.0]
    while Rbins[-1] < Rlast_pix:
        Rbins.append(Rbins[-1] + max(2, Rbins[-1] * 0.1))
    Rbins = np.array(Rbins) * pixel_length

    with torch.no_grad():
        image = model.target[model.window]
        X, Y = image.get_coordinate_meshgrid() - model["center"].value[..., None, None]
        X, Y = model.transform_coordinates(X, Y)
        R = model.radius_metric(X, Y).detach().cpu().numpy().ravel()
        data = image.data.detach().cpu().numpy().ravel()

    count, _, _ = binned_statistic(R, data, statistic="count", bins=Rbins)
    stat, _, _ = binned_statistic(R, data, statistic="median", bins=Rbins)
    stat[count == 0] = np.nan
    scat, _, _ = binned_statistic(
        R, data, statistic=partial(iqr, rng=(16, 84)), bins=Rbins
    )
    # Scale the 16-84 spread by 2*sqrt(N) where there are enough pixels;
    # sparsely populated bins get no error bar.
    well_sampled = count > count_limit
    scat[well_sampled] /= 2 * np.sqrt(count[well_sampled])
    scat[~well_sampled] = 0

    if model.target.zeropoint is not None:
        stat = flux_to_sb(
            stat, model.target.pixel_area.item(), model.target.zeropoint.item()
        )
        ax.set_ylabel("Surface Brightness [mag/arcsec$^2$]")
        # Magnitudes: brighter is numerically smaller, so flip the axis once.
        if not ax.yaxis_inverted():
            ax.invert_yaxis()
    else:
        stat = np.log10(stat)
        ax.set_ylabel("log$_{10}$(flux/arcsec$^2$)")

    # NOTE(review): ``scat`` stays in the flux units of the image data while
    # ``stat`` was converted to magnitudes / log10 above, so the plotted error
    # bars are not in the plotted quantity's units. Confirm whether they
    # should be propagated (e.g. 2.5/ln(10) * scat / flux for magnitudes).
    kwargs = {
        "linewidth": 0,
        "elinewidth": 1,
        "color": main_pallet["primary2"],
        "label": "data profile",
    }
    kwargs.update(plot_kwargs)
    ax.errorbar(
        (Rbins[:-1] + Rbins[1:]) / 2,
        stat,
        yerr=scat,
        fmt=".",
        **kwargs,
    )
    ax.set_xlabel(f"Radius [{rad_unit}]")

    if return_profile:
        return Rbins, stat, scat, count
    return fig, ax
def ray_light_profile(
    fig,
    ax,
    model,
    rad_unit="arcsec",
    extend_profile=1.0,
    resolution=1000,
    doassert=True,
):
    """Plot log10 of the radial light profile of every ray of a model.

    For each ray index ``r`` in ``range(model.rays)``, evaluates
    ``model.iradial_model(r, radii)`` on ``resolution`` radii from 0 to
    ``extend_profile`` times half the largest window dimension, and plots its
    log10 on ``ax``.

    Args:
        fig: matplotlib figure object (returned unchanged).
        ax: matplotlib axis object to draw on.
        model: Model providing ``rays``, ``iradial_model``, ``window`` and
            ``name``.
        rad_unit (str): Name of the radius unit for the x-axis label.
            Default: "arcsec"
        extend_profile (float): Multiplier on half the largest window
            dimension that sets the outermost radius. Default: 1.0
        resolution (int): Number of radii sampled. Default: 1000
        doassert (bool): Not used by this function; kept for a uniform
            signature with the other profile plotters. Default: True

    Returns:
        tuple: ``(fig, ax)``.
    """
    xx = torch.linspace(
        0,
        torch.max(model.window.shape / 2) * extend_profile,
        int(resolution),
        dtype=AP_config.ap_dtype,
        device=AP_config.ap_device,
    )
    radii = xx.detach().cpu().numpy()
    n_rays = model.rays

    with torch.no_grad():
        for r in range(n_rays):
            col = None
            if n_rays <= 5:
                # Named palette colours for a handful of rays. The palette may
                # not define every ``primaryN`` key, so fall back to the
                # gradient colormap instead of raising KeyError.
                try:
                    col = main_pallet[f"primary{r+1}"]
                except KeyError:
                    col = None
            if col is None:
                col = cmap_grad(r / n_rays)
            ax.plot(
                radii,
                np.log10(model.iradial_model(r, xx).detach().cpu().numpy()),
                linewidth=2,
                color=col,
                label=f"{model.name} profile {r}",
            )
    ax.set_ylabel("log$_{10}$(flux)")
    ax.set_xlabel(f"Radius [{rad_unit}]")
    return fig, ax
def wedge_light_profile(
    fig,
    ax,
    model,
    rad_unit="arcsec",
    extend_profile=1.0,
    resolution=1000,
    doassert=True,
):
    """Plot log10 of the radial light profile of every wedge of a model.

    For each wedge index ``r`` in ``range(model.wedges)``, evaluates
    ``model.iradial_model(r, radii)`` on ``resolution`` radii from 0 to
    ``extend_profile`` times half the largest window dimension, and plots its
    log10 on ``ax``.

    Args:
        fig: matplotlib figure object (returned unchanged).
        ax: matplotlib axis object to draw on.
        model: Model providing ``wedges``, ``iradial_model``, ``window`` and
            ``name``.
        rad_unit (str): Name of the radius unit for the x-axis label.
            Default: "arcsec"
        extend_profile (float): Multiplier on half the largest window
            dimension that sets the outermost radius. Default: 1.0
        resolution (int): Number of radii sampled. Default: 1000
        doassert (bool): Not used by this function; kept for a uniform
            signature with the other profile plotters. Default: True

    Returns:
        tuple: ``(fig, ax)``.
    """
    xx = torch.linspace(
        0,
        torch.max(model.window.shape / 2) * extend_profile,
        int(resolution),
        dtype=AP_config.ap_dtype,
        device=AP_config.ap_device,
    )
    radii = xx.detach().cpu().numpy()
    n_wedges = model.wedges

    with torch.no_grad():
        for r in range(n_wedges):
            col = None
            if n_wedges <= 5:
                # Named palette colours for a handful of wedges. The palette
                # may not define every ``primaryN`` key, so fall back to the
                # gradient colormap instead of raising KeyError.
                try:
                    col = main_pallet[f"primary{r+1}"]
                except KeyError:
                    col = None
            if col is None:
                col = cmap_grad(r / n_wedges)
            ax.plot(
                radii,
                np.log10(model.iradial_model(r, xx).detach().cpu().numpy()),
                linewidth=2,
                color=col,
                label=f"{model.name} profile {r}",
            )
    ax.set_ylabel("log$_{10}$(flux)")
    ax.set_xlabel(f"Radius [{rad_unit}]")
    return fig, ax
def warp_phase_profile(fig, ax, model, rad_unit="arcsec", doassert=True):
    """Plot the radius-dependent axis ratio and position angle of a warp model.

    Draws ``q(R)`` and ``PA(R) / pi`` against ``model.profR`` on the same
    axis, with the y-range fixed to [0, 1].

    Args:
        fig: matplotlib figure object (returned unchanged).
        ax: matplotlib axis object to draw on.
        model: A ``Warp_Galaxy`` model with ``profR`` and the ``q(R)`` and
            ``PA(R)`` parameters.
        rad_unit (str): Name of the radius unit for the x-axis label.
            Default: "arcsec"
        doassert (bool): If True, check that ``model`` is a ``Warp_Galaxy``.
            Default: True

    Returns:
        tuple: ``(fig, ax)``.

    Raises:
        InvalidModel: if ``doassert`` is True and ``model`` is not a
            ``Warp_Galaxy``.
    """
    if doassert and not isinstance(model, Warp_Galaxy):
        raise InvalidModel(f"warp_phase_profile must be given a 'Warp_Galaxy' object. Not {type(model)}")

    ax.plot(
        model.profR,
        model["q(R)"].value.detach().cpu().numpy(),
        linewidth=2,
        color=main_pallet["primary1"],
        label=f"{model.name} axis ratio",
    )
    # Read the parameter's ``.value`` tensor, exactly as for ``q(R)`` above,
    # rather than calling ``detach`` on the parameter object itself.
    ax.plot(
        model.profR,
        model["PA(R)"].value.detach().cpu().numpy() / np.pi,
        linewidth=2,
        color=main_pallet["secondary1"],
        label=f"{model.name} position angle",
    )
    ax.set_ylim([0, 1])
    ax.set_ylabel("q [b/a], PA [rad/$\\pi$]")
    ax.set_xlabel(f"Radius [{rad_unit}]")
    return fig, ax