Source code for astrophot.plots.image

import numpy as np
import torch

from astropy.visualization import HistEqStretch, ImageNormalize, LogStretch, SqrtStretch
from matplotlib.patches import Rectangle, Polygon
from matplotlib import pyplot as plt
import matplotlib
from scipy.stats import iqr

from ..models import Group_Model, Sky_Model, PSF_Model
from ..image import Image_List, Window_List
from .. import AP_config
from ..utils.conversions.units import flux_to_sb
from .visuals import *


# Public plotting functions exported by this module.
__all__ = ["target_image", "psf_image", "model_image", "residual_image", "model_window"]


def target_image(fig, ax, target, window=None, **kwargs):
    """Display a target image on the provided figure and axes.

    By default the image is drawn in two layers: every pixel is first plotted
    with a histogram-equalised grey colormap, then pixels at or above
    ``sky + 3 * noise`` are overplotted with a log-normalised gradient
    colormap. ``sky`` is the median of the unmasked pixels and ``noise`` is
    half of their 16th-84th percentile range (falling back to the standard
    deviation when that range is zero).

    Args:
        fig (matplotlib.figure.Figure): The figure object in which the target
            image will be displayed.
        ax (matplotlib.axes.Axes): The axes object on which the target image
            will be plotted. Must be indexable (one axes per image) when
            `target` is an `Image_List`.
        target (Image or Image_List): The image or list of images to be
            displayed.
        window (Window, optional): The window through which the image is
            viewed. If `None`, the window of the provided `target` is used.
            Defaults to `None`.
        **kwargs: Only the following keys are read by this function:
            flipx (bool): if true, invert the x axis. Defaults to `False`.
            linear (bool): if true, draw a single layer with the gradient
            colormap and matplotlib's default normalisation instead of the
            two-layer display. Defaults to `False`.

    Returns:
        fig (matplotlib.figure.Figure): The figure object containing the
            displayed target image.
        ax (matplotlib.axes.Axes): The axes object containing the displayed
            target image.

    Note:
        If the `target` is an `Image_List`, this function will recursively
        call itself for each image in the list. The `window` parameter and
        `kwargs` are passed unchanged to each recursive call.
    """
    # Recursive call for a list of target images, one axes per image.
    if isinstance(target, Image_List):
        for i, image in enumerate(target.image_list):
            target_image(fig, ax[i], image, window=window, **kwargs)
        return fig, ax

    if window is None:
        window = target.window
    if kwargs.get("flipx", False):
        ax.invert_xaxis()

    target_area = target[window]
    # Work on a copy so that blanking masked pixels never touches the image.
    dat = np.copy(target_area.data.detach().cpu().numpy())
    if target_area.has_mask:
        dat[target_area.mask.detach().cpu().numpy()] = np.nan
    X, Y = target_area.get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()

    sky = np.nanmedian(dat)
    noise = iqr(dat[np.isfinite(dat)], rng=(16, 84)) / 2
    if noise == 0:
        noise = np.nanstd(dat)
    # Boundary between the grey (faint) layer and the gradient (bright) layer.
    threshold = sky + 3 * noise

    if kwargs.get("linear", False):
        ax.pcolormesh(X, Y, dat, cmap=cmap_grad)
    else:
        # Faint layer: histogram equalisation built from the pixels at or
        # below the threshold only.
        faint_pixels = dat[np.logical_and(dat <= threshold, np.isfinite(dat))]
        ax.pcolormesh(
            X,
            Y,
            dat,
            cmap="Greys",
            norm=ImageNormalize(
                stretch=HistEqStretch(faint_pixels),
                clip=False,
                vmax=threshold,
                vmin=np.nanmin(dat),
            ),
        )
        # Bright layer: everything below the threshold is masked out so the
        # grey layer underneath stays visible there.
        ax.pcolormesh(
            X,
            Y,
            np.ma.masked_where(dat < threshold, dat),
            cmap=cmap_grad,
            norm=matplotlib.colors.LogNorm(),
            clim=[threshold, None],
        )

    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")
    return fig, ax
@torch.no_grad()
def psf_image(
    fig,
    ax,
    psf,
    window=None,
    cmap_levels=None,
    flipx=False,
    **kwargs,
):
    """Display a PSF image on the provided figure and axes.

    Args:
        fig (matplotlib.figure.Figure): The figure object in which the PSF
            will be displayed.
        ax (matplotlib.axes.Axes): The axes object on which the PSF will be
            plotted. Must be indexable (one axes per image) when `psf` is an
            `Image_List`.
        psf (PSF_Model, Image or Image_List): The PSF to display. A
            `PSF_Model` is called with no arguments to obtain its image.
        window (Window, optional): The window through which the image is
            viewed. If `None`, the window of the PSF image is used. Defaults
            to `None`.
        cmap_levels (int, optional): The number of discrete levels to convert
            the continuous color map to. Defaults to `None` (continuous).
        flipx (bool, optional): Whether to invert the x axis. Defaults to
            `False`.
        **kwargs: Arbitrary keyword arguments. These override the default
            keyword arguments passed to `ax.pcolormesh`.

    Returns:
        fig (matplotlib.figure.Figure): The figure object containing the
            displayed PSF.
        ax (matplotlib.axes.Axes): The axes object containing the displayed
            PSF.

    Note:
        If `psf` is an `Image_List`, this function recursively calls itself
        for each image in the list; `window`, `cmap_levels`, `flipx` and
        `kwargs` are passed unchanged to each recursive call.
    """
    if isinstance(psf, PSF_Model):
        psf = psf()

    # Recursive call for a list of PSF images. The named options must be
    # forwarded explicitly because they are not part of ``kwargs``.
    if isinstance(psf, Image_List):
        for i, image in enumerate(psf.image_list):
            psf_image(
                fig,
                ax[i],
                image,
                window=window,
                cmap_levels=cmap_levels,
                flipx=flipx,
                **kwargs,
            )
        return fig, ax

    if window is None:
        window = psf.window
    if flipx:
        ax.invert_xaxis()

    # Cut out the requested window.
    psf = psf[window]

    # Pixel-corner coordinates and pixel values as numpy arrays.
    X, Y = psf.get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()
    psf_data = psf.data.detach().cpu().numpy()

    # Default plotting options, overridden by anything the user provides.
    imshow_kwargs = {
        "cmap": cmap_grad,
        "norm": matplotlib.colors.LogNorm(),
    }
    imshow_kwargs.update(kwargs)

    # If requested, convert the continuous colourmap into discrete levels.
    if cmap_levels is not None:
        continuous_cmap = imshow_kwargs["cmap"]
        imshow_kwargs["cmap"] = matplotlib.colors.ListedColormap(
            [continuous_cmap(c) for c in np.linspace(0.0, 1.0, cmap_levels)]
        )

    ax.pcolormesh(X, Y, psf_data, **imshow_kwargs)

    # Enforce equal spacing on x and y.
    ax.axis("equal")
    ax.set_xlabel("PSF X [arcsec]")
    ax.set_ylabel("PSF Y [arcsec]")
    return fig, ax
@torch.no_grad()
def model_image(
    fig,
    ax,
    model,
    sample_image=None,
    window=None,
    target=None,
    showcbar=True,
    target_mask=False,
    cmap_levels=None,
    flipx=False,
    magunits=True,
    sample_full_image=False,
    **kwargs,
):
    """Generate a model image and display it on the provided figure and axes.

    Args:
        fig (matplotlib.figure.Figure): The figure object in which the image
            will be displayed.
        ax (matplotlib.axes.Axes): The axes object on which the image will be
            plotted. Must be indexable (one axes per image) when the sampled
            image is an `Image_List`.
        model (Model): The model object used to generate a model image if
            `sample_image` is not provided.
        sample_image (Image or Image_List, optional): The image or list of
            images to be displayed. If `None`, a model image is generated
            using the provided `model`. Defaults to `None`.
        window (Window, optional): The window through which the image is
            viewed. If `None`, the window of the provided `model` is used.
            Defaults to `None`.
        target (Target, optional): The target or list of targets for the
            image or image list. If `None`, the target of the `model` is
            used. Defaults to `None`.
        showcbar (bool, optional): Whether to show the color bar. Defaults to
            `True`.
        target_mask (bool, optional): Whether to apply the mask of the
            target. If `True` and if the target has a mask, masked pixels are
            set to NaN in the displayed image. Defaults to `False`.
        cmap_levels (int, optional): The number of discrete levels to convert
            the continuous color map to. If not `None`, the color map is
            converted to a ListedColormap with the specified number of
            levels. Defaults to `None`.
        flipx (bool, optional): Whether to invert the x axis. Defaults to
            `False`.
        magunits (bool, optional): If `True` and the target has a zeropoint,
            the image is converted with `flux_to_sb` and shown with a
            reversed colormap and a surface brightness colorbar label;
            otherwise the flux is shown with the default log normalisation.
            Defaults to `True`.
        sample_full_image: If True, every model will be sampled on the full
            image window. If False (default) each model will only be sampled
            in its fitting window. Only used when `sample_image` is `None`.
        **kwargs: Arbitrary keyword arguments. These are used to override the
            default imshow_kwargs.

    Returns:
        fig (matplotlib.figure.Figure): The figure object containing the
            displayed image.
        ax (matplotlib.axes.Axes): The axes object containing the displayed
            image.

    Note:
        If the `sample_image` is an `Image_List`, this function will
        recursively call itself for each image in the list, with the
        corresponding target and window. The remaining options and `kwargs`
        are passed unchanged to each recursive call.
    """
    if sample_image is None:
        if sample_full_image:
            sample_image = model.make_model_image()
            sample_image = model(sample_image)
        else:
            sample_image = model()

    # Use the model's target and window when they are not given.
    if target is None:
        target = model.target
    if window is None:
        window = model.window

    # Handle image lists: one axes, target and window per image.
    if isinstance(sample_image, Image_List):
        for i, (sub_image, sub_target, sub_window) in enumerate(
            zip(sample_image, target, window)
        ):
            model_image(
                fig,
                ax[i],
                model,
                sample_image=sub_image,
                window=sub_window,
                target=sub_target,
                showcbar=showcbar,
                target_mask=target_mask,
                cmap_levels=cmap_levels,
                flipx=flipx,
                magunits=magunits,
                **kwargs,
            )
        return fig, ax

    if flipx:
        ax.invert_xaxis()

    # Cut out the requested window.
    sample_image = sample_image[window]

    # Pixel-corner coordinates and pixel values as numpy arrays.
    X, Y = sample_image.get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()
    sample_image = sample_image.data.detach().cpu().numpy()

    # Default plotting options, overridden by anything the user provides.
    imshow_kwargs = {
        "cmap": cmap_grad,
        "norm": matplotlib.colors.LogNorm(),
    }
    imshow_kwargs.update(kwargs)

    # If requested, convert the continuous colourmap into discrete levels.
    if cmap_levels is not None:
        continuous_cmap = imshow_kwargs["cmap"]
        imshow_kwargs["cmap"] = matplotlib.colors.ListedColormap(
            [continuous_cmap(c) for c in np.linspace(0.0, 1.0, cmap_levels)]
        )

    # If a zeropoint is available, convert to surface brightness units.
    use_magunits = target.zeropoint is not None and magunits
    if use_magunits:
        sample_image = flux_to_sb(
            sample_image, target.pixel_area.item(), target.zeropoint.item()
        )
        # The log normalisation is dropped for the converted values, and the
        # colormap is reversed.
        del imshow_kwargs["norm"]
        imshow_kwargs["cmap"] = imshow_kwargs["cmap"].reversed()

    # Apply the mask if available.
    if target_mask and target.has_mask:
        # The array obtained through ``.numpy()`` may share memory with the
        # image tensor, so copy before writing NaNs into it.
        sample_image = np.copy(sample_image)
        # Crop the target to the same window as the model so that the mask
        # matches the plotted array (same indexing as in residual_image).
        sample_image[target[window].mask.detach().cpu().numpy()] = np.nan

    im = ax.pcolormesh(X, Y, sample_image, **imshow_kwargs)

    # Enforce equal spacing on x and y.
    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")

    if showcbar:
        if use_magunits:
            clb = fig.colorbar(im, ax=ax, label="Surface Brightness [mag/arcsec$^2$]")
            clb.ax.invert_yaxis()
        else:
            clb = fig.colorbar(im, ax=ax, label="log$_{10}$(flux)")

    return fig, ax
@torch.no_grad()
def residual_image(
    fig,
    ax,
    model,
    target=None,
    sample_image=None,
    showcbar=True,
    window=None,
    center_residuals=False,
    clb_label=None,
    normalize_residuals=False,
    flipx=False,
    sample_full_image=False,
    **kwargs,
):
    """Calculate and display the residuals of a model image.

    The residuals are calculated as the difference between the target image
    and the sample image. Before plotting they are divided by twice their
    10th-90th percentile range and passed through an arctangent, so that
    outliers do not dominate the colour scale.

    Args:
        fig (matplotlib.figure.Figure): The figure object in which the
            residuals will be displayed.
        ax (matplotlib.axes.Axes): The axes object on which the residuals
            will be plotted. Must be iterable (one axes per image) when
            `window` or `target` is a list.
        model (Model): The model object used to generate a model image if
            `sample_image` is not provided.
        target (Target or Image_List, optional): The target or list of
            targets for the image or image list. If `None`, the target of the
            `model` is used. Defaults to `None`.
        sample_image (Image or Image_List, optional): The image or list of
            images from which residuals will be calculated. If `None`, a
            model image is generated using the provided `model`. Defaults to
            `None`.
        showcbar (bool, optional): Whether to show the color bar. Defaults to
            `True`.
        window (Window or Window_List, optional): The window through which
            the image is viewed. If `None`, the window of the provided
            `model` is used. Defaults to `None`.
        center_residuals (bool, optional): Whether to subtract the median of
            the residuals. If `True`, the median is subtracted from the
            residuals. Defaults to `False`.
        clb_label (str, optional): The label for the colorbar. If `None`, a
            default label is used based on the normalization of the
            residuals. Defaults to `None`.
        normalize_residuals (bool, optional): Whether to normalize the
            residuals. If `True`, residuals are divided by the square root of
            the variance of the target. Defaults to `False`.
        flipx (bool, optional): Whether to invert the x axis. Defaults to
            `False`.
        sample_full_image: If True, every model will be sampled on the full
            image window. If False (default) each model will only be sampled
            in its fitting window. Only used when `sample_image` is `None`.
        **kwargs: Arbitrary keyword arguments. These are used to override the
            default imshow_kwargs.

    Returns:
        fig (matplotlib.figure.Figure): The figure object containing the
            displayed residuals.
        ax (matplotlib.axes.Axes): The axes object containing the displayed
            residuals.

    Note:
        If the `window` or `target` are lists, this function will recursively
        call itself for each element in the list, with the corresponding
        window, target, and sample image. The remaining options and `kwargs`
        are passed unchanged to each recursive call.
    """
    if window is None:
        window = model.window
    if target is None:
        target = model.target
    if sample_image is None:
        if sample_full_image:
            sample_image = model.make_model_image()
            sample_image = model(sample_image)
        else:
            sample_image = model()

    # Handle lists: one axes, window, target and sample image per entry.
    if isinstance(window, Window_List) or isinstance(target, Image_List):
        for i_ax, win, tar, sam in zip(ax, window, target, sample_image):
            residual_image(
                fig,
                i_ax,
                model,
                target=tar,
                sample_image=sam,
                window=win,
                showcbar=showcbar,
                center_residuals=center_residuals,
                clb_label=clb_label,
                normalize_residuals=normalize_residuals,
                flipx=flipx,
                **kwargs,
            )
        return fig, ax

    if flipx:
        ax.invert_xaxis()

    target_area = target[window]
    sample_area = sample_image[window]

    X, Y = sample_area.get_coordinate_corner_meshgrid()
    X = X.detach().cpu().numpy()
    Y = Y.detach().cpu().numpy()

    residuals = (target_area - sample_area).data
    if normalize_residuals:
        residuals = residuals / torch.sqrt(target_area.variance)
    residuals = residuals.detach().cpu().numpy()

    if target.has_mask:
        residuals[target_area.mask.detach().cpu().numpy()] = np.nan
    if center_residuals:
        residuals -= np.nanmedian(residuals)

    # Scale for the arctan compression: twice the 10-90 percentile range.
    # That range can be zero (perfect fit, mostly constant residuals) or
    # undefined (no finite pixel); fall back to the standard deviation, and
    # finally to 1, so the division below never yields 0/0 or x/0.
    finite = residuals[np.isfinite(residuals)]
    scale = iqr(finite, rng=[10, 90]) * 2 if finite.size else 0.0
    if not np.isfinite(scale) or scale == 0:
        scale = np.std(finite) if finite.size else 0.0
    if not np.isfinite(scale) or scale == 0:
        scale = 1.0
    residuals = np.arctan(residuals / scale)

    # Symmetric colour limits around zero. Without any finite non-zero
    # residual, use the bound of arctan so the limits stay valid.
    finite = residuals[np.isfinite(residuals)]
    extreme = np.max(np.abs(finite)) if finite.size else 0.0
    if not extreme > 0:
        extreme = np.pi / 2

    imshow_kwargs = {
        "cmap": cmap_div,
        "vmin": -extreme,
        "vmax": extreme,
    }
    imshow_kwargs.update(kwargs)
    im = ax.pcolormesh(X, Y, residuals, **imshow_kwargs)

    ax.axis("equal")
    ax.set_xlabel("Tangent Plane X [arcsec]")
    ax.set_ylabel("Tangent Plane Y [arcsec]")

    if showcbar:
        if normalize_residuals:
            default_label = f"tan$^{{-1}}$((Target - {model.name}) / $\\sigma$)"
        else:
            default_label = f"tan$^{{-1}}$(Target - {model.name})"
        clb = fig.colorbar(im, ax=ax, label=default_label if clb_label is None else clb_label)
        # The tick values would be in compressed arctan units, so hide them.
        clb.ax.set_yticks([])
        clb.ax.set_yticklabels([])
    return fig, ax
def _add_window_outline(ax, component, target, linewidth):
    """Add the outline of one model's fitting window to ``ax`` as a polygon."""
    # Pick the window that belongs to the requested target image.
    if isinstance(component.window, Window_List):
        use_window = component.window.window_list[component.target.index(target)]
    else:
        use_window = component.window

    origin = use_window.origin

    # Corner reached from the origin along the first pixel axis only.
    lowright = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype)
    lowright[1] = 0.0
    lowright = origin + use_window.pixel_to_plane_delta(lowright)

    # Corner reached from the origin along the second pixel axis only.
    upleft = use_window.pixel_shape.clone().to(dtype=AP_config.ap_dtype)
    upleft[0] = 0.0
    upleft = origin + use_window.pixel_to_plane_delta(upleft)

    # Corner opposite to the origin.
    end = origin + use_window.end

    # Walk the corners in order around the outline.
    corners = [
        corner.detach().cpu().numpy() for corner in (origin, lowright, end, upleft)
    ]
    ax.add_patch(
        Polygon(
            xy=[(corner[0], corner[1]) for corner in corners],
            fill=False,
            linewidth=linewidth,
            edgecolor=main_pallet["secondary1"],
        )
    )


def model_window(fig, ax, model, target=None, rectangle_linewidth=2, **kwargs):
    """Draw the fitting window of a model on the provided axes.

    For a `Group_Model` one outline is drawn for every model in
    `model.models`; otherwise a single outline is drawn for `model` itself.

    Args:
        fig (matplotlib.figure.Figure): The figure object containing `ax`.
        ax (matplotlib.axes.Axes or numpy.ndarray): The axes on which the
            outlines are drawn. If a numpy array of axes is given, the
            function is called once per axes with the image at the same
            position in `model.target.image_list` as target.
        model (Model): The model whose window(s) will be drawn.
        target (Image, optional): The target image shown on `ax`. Used to
            select the matching window when a model has a `Window_List`.
            Defaults to `None`.
        rectangle_linewidth (float, optional): Line width of the outline.
            Defaults to 2.
        **kwargs: Arbitrary keyword arguments. They are only forwarded to the
            recursive calls and are otherwise unused.

    Returns:
        fig (matplotlib.figure.Figure): The figure object passed in.
        ax (matplotlib.axes.Axes): The axes object(s) with the outlines
            added.
    """
    # One axes per target image. The line width must be forwarded explicitly
    # because it is not part of ``kwargs``.
    if isinstance(ax, np.ndarray):
        for i, axitem in enumerate(ax):
            model_window(
                fig,
                axitem,
                model,
                target=model.target.image_list[i],
                rectangle_linewidth=rectangle_linewidth,
                **kwargs,
            )
        return fig, ax

    if isinstance(model, Group_Model):
        components = model.models.values()
    else:
        components = (model,)
    for component in components:
        _add_window_outline(ax, component, target, rectangle_linewidth)

    return fig, ax