Source code for astrophot.image.target_image

from typing import List, Optional

import torch
import numpy as np
from torch.nn.functional import avg_pool2d

from .image_object import Image, Image_List
from .jacobian_image import Jacobian_Image, Jacobian_Image_List
from .model_image import Model_Image, Model_Image_List
from .psf_image import PSF_Image
from astropy.io import fits
from .. import AP_config
from ..errors import SpecificationConflict, InvalidImage

__all__ = ["Target_Image", "Target_Image_List"]


class Target_Image(Image):
    """Image object which represents the data to be fit by a model. It can
    include a variance image, mask, and PSF as ancillary data which
    describes the target image.

    Target images are a basic unit of data in `AstroPhot`, they store
    the information collected from telescopes for which models are to
    be fit. There is minimal functionality in the `Target_Image`
    object itself, it is mostly defined in terms of how other objects
    interact with it.

    Basic usage:

    .. code-block:: python

      import astrophot as ap

      # Create target image
      image = ap.image.Target_Image(
          data = <pixel data>,
          wcs = <astropy WCS object>,
          variance = <pixel uncertainties>,
          psf = <point spread function as PSF_Image object>,
          mask = <pixels to ignore>,
      )

      # Display the data
      fig, ax = plt.subplots()
      ap.plots.target_image(fig, ax, image)
      plt.show()

      # Save the image
      image.save("mytarget.fits")

      # Load the image
      image2 = ap.image.Target_Image(filename = "mytarget.fits")

      # Make low resolution version
      lowrez = image.reduce(2)

    Some important information to keep in mind. First, providing an
    `astropy WCS` object is the best way to keep track of coordinates
    and pixel scale properties, especially when dealing with
    multi-band data. If images have relative positioning, rotation,
    pixel sizes, field of view this will all be handled automatically
    by taking advantage of `WCS` objects. Second, providing accurate
    variance (or weight) maps is critical to getting a good fit to the
    data. This is a very common source of issues so it is worthwhile
    to review literature on how best to construct such a map. A good
    starting place is the FAQ for GALFIT:
    https://users.obs.carnegiescience.edu/peng/work/galfit/CHI2.html
    which is an excellent resource for all things image modeling. Just
    note that `AstroPhot` uses variance or weight maps, not sigma
    images. `AstroPhot` will not create a variance map for the user, by
    default it will just assume uniform variance which is rarely
    accurate. Third, the PSF pixelscale must be a multiple of the image
    pixelscale. So if the image has a pixelscale of 1 then the PSF
    must have a pixelscale of 1, 1/2, 1/3, etc for anything to work
    out. Note that if the PSF pixelscale is finer than the image, then
    all modelling will be done at the higher resolution. This is
    recommended for accuracy though it can mean higher memory
    consumption.

    """

    image_count = 0

    def __init__(self, *args, **kwargs):
        """Build the image, then attach weight/variance, mask and PSF.

        All arguments are forwarded to the parent constructor. The
        keyword arguments `weight`, `variance`, `mask`, `psf` and
        `psf_upscale` are additionally read here, but only for
        quantities the parent constructor has not already populated.
        `weight` takes precedence over `variance` when both are given.
        """
        super().__init__(*args, **kwargs)

        # NOTE: the `has_weight` check also guarantees `_weight` exists
        # afterwards, which `copy` relies on.
        if not self.has_weight and "weight" in kwargs:
            self.set_weight(kwargs.get("weight", None))
        elif not self.has_variance and "variance" in kwargs:
            self.set_variance(kwargs.get("variance", None))
        if not self.has_mask:
            self.set_mask(kwargs.get("mask", None))
        if not self.has_psf:
            self.set_psf(kwargs.get("psf", None), kwargs.get("psf_upscale", 1))

        # Set nan pixels to be masked automatically
        if torch.any(torch.isnan(self.data)).item():
            self.set_mask(torch.logical_or(self.mask, torch.isnan(self.data)))

    @property
    def standard_deviation(self):
        """Stores the standard deviation of the image pixels. This represents
        the uncertainty in each pixel value. It should always have the
        same shape as the image data. In the case where the standard
        deviation is not known, a tensor of ones will be created to
        stand in as the standard deviation values.

        The standard deviation is not stored directly, instead it is
        computed as :math:`\\sqrt{1/W}` where :math:`W` is the
        weights.

        """
        if self.has_variance:
            return torch.sqrt(self.variance)
        return torch.ones_like(self.data)

    @property
    def variance(self):
        """Stores the variance of the image pixels. This represents the
        uncertainty in each pixel value. It should always have the
        same shape as the image data. In the case where the variance
        is not known, a tensor of ones will be created to stand in as
        the variance values.

        The variance is not stored directly, instead it is
        computed as :math:`\\frac{1}{W}` where :math:`W` is the
        weights. Pixels with zero weight are reported with infinite
        variance.

        """
        if self.has_variance:
            return torch.where(self._weight == 0, torch.inf, 1 / self._weight)
        return torch.ones_like(self.data)

    @variance.setter
    def variance(self, variance):
        self.set_variance(variance)

    @property
    def has_variance(self):
        """Returns True when the image object has stored variance values. If
        this is False and the variance property is called then a
        tensor of ones will be returned.

        """
        try:
            return self._weight is not None
        except AttributeError:
            return False

    @property
    def weight(self):
        """Stores the weight of the image pixels. This represents the
        uncertainty in each pixel value. It should always have the
        same shape as the image data. In the case where the weight
        is not known, a tensor of ones will be created to stand in as
        the weight values.

        The weights are used to proportionately scale residuals in the
        likelihood. Most commonly this shows up as a :math:`\\chi^2`
        like:

        .. math::

          \\chi^2 = (\\vec{y} - \\vec{f(\\theta)})^TW(\\vec{y} - \\vec{f(\\theta)})

        which can be optimized to find parameter values. Using the
        Jacobian, which in this case is the derivative of every pixel
        wrt every parameter, the weight matrix also appears in the
        gradient:

        .. math::

          \\vec{g} = J^TW(\\vec{y} - \\vec{f(\\theta)})

        and the hessian approximation used in Levenberg-Marquardt:

        .. math::

          H \\approx J^TWJ

        """
        if self.has_weight:
            return self._weight
        return torch.ones_like(self.data)

    @weight.setter
    def weight(self, weight):
        self.set_weight(weight)

    @property
    def has_weight(self):
        """Returns True when the image object has stored weight values. If
        this is False and the weight property is called then a
        tensor of ones will be returned.

        """
        try:
            return self._weight is not None
        except AttributeError:
            # Create the attribute so later direct reads of `_weight` work.
            self._weight = None
            return False

    @property
    def mask(self):
        """The mask stores a tensor of boolean values which indicate any
        pixels to be ignored. These pixels will be skipped in
        likelihood evaluations and in parameter optimization. It is
        common practice to mask pixels with pathological values such
        as due to cosmic rays or satellites passing through the image.

        In a mask, a True value indicates that the pixel is masked and
        should be ignored. False indicates a normal pixel which will
        enter into most calculations.

        If no mask is provided, all pixels are assumed valid.

        """
        if self.has_mask:
            return self._mask
        return torch.zeros_like(self.data, dtype=torch.bool)

    @mask.setter
    def mask(self, mask):
        self.set_mask(mask)

    @property
    def has_mask(self):
        """
        Single boolean to indicate if a mask has been provided by the user.
        """
        try:
            return self._mask is not None
        except AttributeError:
            return False

    @property
    def psf(self):
        """Stores the point-spread-function for this target. This should be a
        `PSF_Image` object which represents the scattering of a point
        source of light. It can also be an `AstroPhot_Model` object
        which will contribute its own parameters to an optimization
        problem.

        The PSF stored for a `Target_Image` object is passed to all
        models applied to that target which have a `psf_mode` that is
        not `none`. This means they will all use the same PSF
        model. If one wishes to define a variable PSF across an image,
        then they should pass the PSF objects to the
        `AstroPhot_Model`'s directly instead of to a `Target_Image`.

        Raises:

          AttributeError: if this is called without a PSF defined

        """
        if self.has_psf:
            return self._psf
        raise AttributeError("This image does not have a PSF")

    @psf.setter
    def psf(self, psf):
        self.set_psf(psf)

    @property
    def has_psf(self):
        """Single boolean to indicate if a PSF is currently stored."""
        try:
            return self._psf is not None
        except AttributeError:
            return False

    def set_variance(self, variance):
        """
        Provide a variance tensor for the image. Variance is equal to $\\sigma^2$. This should have the same shape as the data.

        The variance is stored as a weight (its reciprocal). Passing
        `None` clears the stored weight.
        """
        if variance is None:
            self._weight = None
            return
        # Convert first so array-likes such as nested lists can be inverted.
        variance = torch.as_tensor(
            variance, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        self.set_weight(1 / variance)

    def set_weight(self, weight):
        """Provide a weight tensor for the image. Weight is equal to $\\frac{1}{\\sigma^2}$. This should have the same shape as the data.

        Accepts a tensor or anything `torch.as_tensor` can convert
        (including the nested lists produced by `get_state`). Passing
        `None` clears the stored weight.

        Raises:

          SpecificationConflict: if the weight shape differs from the data shape

        """
        if weight is None:
            self._weight = None
            return
        # Convert before the shape check: plain lists have no `.shape`.
        weight = torch.as_tensor(
            weight, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if weight.shape != self.data.shape:
            raise SpecificationConflict(
                f"weight/variance must have same shape as data ({weight.shape} vs {self.data.shape})"
            )
        self._weight = weight

    def set_psf(self, psf, psf_upscale=1):
        """Provide a psf for the `Target_Image`. This is stored and passed to
        models which need to be convolved.

        The PSF doesn't need to have the same pixelscale as the
        image. It should be some multiple of the resolution of the
        `Target_Image` though. So if the image has a pixelscale of 1,
        the psf may have a pixelscale of 1, 1/2, 1/3, 1/4 and so on.

        A `PSF_Image` is stored as given; any other non-None value is
        wrapped in a new `PSF_Image` whose pixelscale is the image
        pixelscale divided by `psf_upscale`. `None` clears the PSF.

        """
        if psf is None:
            self._psf = None
            return
        if isinstance(psf, PSF_Image):
            self._psf = psf
            return

        self._psf = PSF_Image(
            data=psf,
            psf_upscale=psf_upscale,
            pixelscale=self.pixelscale / psf_upscale,
            identity=self.identity,
        )

    def set_mask(self, mask):
        """
        Set the boolean mask which will indicate which pixels to ignore. A mask value of True means the pixel will be ignored.

        Accepts a tensor or anything `torch.as_tensor` can convert
        (including the nested lists produced by `get_state`). Passing
        `None` clears the stored mask.

        Raises:

          SpecificationConflict: if the mask shape differs from the data shape
        """
        if mask is None:
            self._mask = None
            return
        # Convert before the shape check: plain lists have no `.shape`.
        mask = torch.as_tensor(mask, dtype=torch.bool, device=AP_config.ap_device)
        if mask.shape != self.data.shape:
            raise SpecificationConflict(
                f"mask must have same shape as data ({mask.shape} vs {self.data.shape})"
            )
        self._mask = mask

    def to(self, dtype=None, device=None):
        """Converts the stored `Target_Image` data, variance, psf, etc to a
        given data type and device.

        When `dtype` or `device` is not given, the corresponding
        `AP_config` default is used for the weight, PSF and mask. The
        mask always stays boolean. Returns `self`.

        """
        super().to(dtype=dtype, device=device)
        # Fall back to the configured defaults only when nothing was requested.
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device

        if self.has_weight:
            self._weight = self._weight.to(dtype=dtype, device=device)
        if self.has_psf:
            self._psf = self._psf.to(dtype=dtype, device=device)
        if self.has_mask:
            self._mask = self.mask.to(dtype=torch.bool, device=device)
        return self

    def or_mask(self, mask):
        """
        Combines the currently stored mask with a provided new mask using the boolean `or` operator.
        """
        self._mask = torch.logical_or(self.mask, mask)

    def and_mask(self, mask):
        """
        Combines the currently stored mask with a provided new mask using the boolean `and` operator.
        """
        self._mask = torch.logical_and(self.mask, mask)

    def copy(self, **kwargs):
        """Produce a copy of this image with all of the same properties. This
        can be used when one wishes to make temporary modifications to
        an image and then will want the original again.

        The stored mask, PSF and weight are forwarded to the parent
        `copy` together with any extra keyword arguments.

        """
        return super().copy(
            mask=self._mask,
            psf=self._psf,
            weight=self._weight,
            **kwargs,
        )

    def blank_copy(self, **kwargs):
        """Produces a blank copy of the image which has the same properties
        except that its data is now filled with zeros.

        Only the mask and PSF are forwarded here; the weight is not.

        """
        return super().blank_copy(mask=self._mask, psf=self._psf, **kwargs)

    def get_window(self, window, **kwargs):
        """Get a sub-region of the image as defined by a window on the sky.

        The weight and mask (when present) are sliced with the indices
        of `window` inside this image; the PSF is passed on unchanged.
        """
        indices = self.window.get_self_indices(window)
        return super().get_window(
            window=window,
            weight=self._weight[indices] if self.has_weight else None,
            mask=self._mask[indices] if self.has_mask else None,
            psf=self._psf,
            **kwargs,
        )

    def jacobian_image(
        self,
        parameters: Optional[List[str]] = None,
        data: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Construct a blank `Jacobian_Image` object formatted like this current `Target_Image` object. Mostly used internally.

        With `parameters=None` an empty parameter list and no data are
        used (any `data` passed is discarded). Otherwise, when `data`
        is omitted, a zero tensor of shape `(*image_shape, len(parameters))`
        is allocated.
        """
        if parameters is None:
            data = None
            parameters = []
        elif data is None:
            data = torch.zeros(
                (*self.data.shape, len(parameters)),
                dtype=AP_config.ap_dtype,
                device=AP_config.ap_device,
            )
        return Jacobian_Image(
            parameters=parameters,
            target_identity=self.identity,
            data=data,
            header=self.header,
            **kwargs,
        )

    def model_image(self, data: Optional[torch.Tensor] = None, **kwargs):
        """
        Construct a blank `Model_Image` object formatted like this current `Target_Image` object. Mostly used internally.

        When `data` is omitted a zero tensor shaped like the image
        data is used.
        """
        return Model_Image(
            data=torch.zeros_like(self.data) if data is None else data,
            header=self.header,
            target_identity=self.identity,
            **kwargs,
        )

    def reduce(self, scale, **kwargs):
        """Returns a new `Target_Image` object with a reduced resolution
        compared to the current image. `scale` should be an integer
        indicating how much to reduce the resolution. If the
        `Target_Image` was originally (48,48) pixels across with a
        pixelscale of 1 and `reduce(2)` is called then the image will
        be (24,24) pixels and the pixelscale will be 2. If `reduce(3)`
        is called then the returned image will be (16,16) pixels
        across and the pixelscale will be 3.

        Trailing rows/columns that do not fill a whole `scale` x
        `scale` bin are dropped. Variances are summed within each bin,
        and a reduced pixel is masked if any pixel in its bin was
        masked.

        """
        MS = self.data.shape[0] // scale
        NS = self.data.shape[1] // scale

        return super().reduce(
            scale=scale,
            variance=(
                self.variance[: MS * scale, : NS * scale]
                .reshape(MS, scale, NS, scale)
                .sum(axis=(1, 3))
                if self.has_variance
                else None
            ),
            mask=(
                self.mask[: MS * scale, : NS * scale]
                .reshape(MS, scale, NS, scale)
                .amax(axis=(1, 3))
                if self.has_mask
                else None
            ),
            psf=self.psf.reduce(scale) if self.has_psf else None,
            **kwargs,
        )

    def expand(self, padding):
        """
        `Target_Image` doesn't have expand yet.
        """
        raise NotImplementedError("expand not available for Target_Image yet")

    def get_state(self):
        """Return the parent state extended with `weight`, `mask` (as
        nested lists) and `psf` (as its own state) when they are present.
        """
        state = super().get_state()

        if self.has_weight:
            state["weight"] = self.weight.detach().cpu().tolist()
        if self.has_mask:
            state["mask"] = self.mask.detach().cpu().tolist()
        if self.has_psf:
            state["psf"] = self.psf.get_state()

        return state

    def set_state(self, state):
        """Restore the image from a state as produced by `get_state`.

        Missing `weight` or `mask` entries clear the stored values; the
        PSF is only touched when a `psf` entry is present.
        """
        super().set_state(state)

        self.weight = state.get("weight", None)
        self.mask = state.get("mask", None)
        if "psf" in state:
            self.psf = PSF_Image(state=state["psf"])

    def get_fits_state(self):
        """Return the parent FITS states with extra entries tagged
        `WEIGHT` and `MASK` (mask stored as integers), followed by the
        PSF's own FITS states, for whichever of these are present.
        """
        states = super().get_fits_state()
        if self.has_weight:
            states.append(
                {
                    "DATA": self.weight.detach().cpu().numpy(),
                    "HEADER": {"IMAGE": "WEIGHT"},
                }
            )
        if self.has_mask:
            states.append(
                {
                    "DATA": self.mask.detach().cpu().numpy().astype(int),
                    "HEADER": {"IMAGE": "MASK"},
                }
            )
        if self.has_psf:
            states += self.psf.get_fits_state()

        return states

    def set_fits_state(self, states):
        """Restore the image from FITS states as produced by `get_fits_state`."""
        super().set_fits_state(states)
        for state in states:
            tag = state["HEADER"]["IMAGE"]
            if tag == "WEIGHT":
                self.weight = np.array(state["DATA"], dtype=np.float64)
            # `get_fits_state` writes "MASK"; the lowercase spelling is
            # still accepted for backward compatibility.
            if tag in ("MASK", "mask"):
                self.mask = np.array(state["DATA"], dtype=bool)
            if tag == "PSF":
                self.psf = PSF_Image(fits_state=states)
class Target_Image_List(Image_List, Target_Image):
    """A list of `Target_Image` objects handled as a single target.

    Per-image quantities (variance, weight, mask, psf, ...) are exposed
    as tuples with one entry per held image, and the `has_*` flags are
    True when any held image has the quantity.
    """

    def __init__(self, *args, **kwargs):
        """Build the list and verify every member is a `Target_Image`.

        Raises:

          InvalidImage: if any held image is not a `Target_Image`
        """
        super().__init__(*args, **kwargs)
        for member in self.image_list:
            if isinstance(member, Target_Image):
                continue
            raise InvalidImage(
                f"Target_Image_List can only hold Target_Image objects, not {tuple(type(image) for image in self.image_list)}"
            )

    @property
    def variance(self):
        """Tuple holding the variance of every held image."""
        return tuple([member.variance for member in self.image_list])

    @variance.setter
    def variance(self, variance):
        # Values are paired with images positionally.
        for member, value in zip(self.image_list, variance):
            member.set_variance(value)

    @property
    def has_variance(self):
        """True when at least one held image has a stored variance."""
        for member in self.image_list:
            if member.has_variance:
                return True
        return False

    @property
    def weight(self):
        """Tuple holding the weight of every held image."""
        return tuple([member.weight for member in self.image_list])

    @weight.setter
    def weight(self, weight):
        # Values are paired with images positionally.
        for member, value in zip(self.image_list, weight):
            member.set_weight(value)

    @property
    def has_weight(self):
        """True when at least one held image has a stored weight."""
        for member in self.image_list:
            if member.has_weight:
                return True
        return False

    def jacobian_image(self, parameters: List[str], data: Optional[List[torch.Tensor]] = None):
        """Build a `Jacobian_Image_List` with one Jacobian image per held
        image; `data`, when given, supplies one entry per image.
        """
        per_image = [None] * len(self.image_list) if data is None else data
        jacobians = [
            member.jacobian_image(parameters, block)
            for member, block in zip(self.image_list, per_image)
        ]
        return Jacobian_Image_List(jacobians)

    def model_image(self, data: Optional[List[torch.Tensor]] = None):
        """Build a `Model_Image_List` with one model image per held image;
        `data`, when given, supplies one entry per image.
        """
        per_image = [None] * len(self.image_list) if data is None else data
        models = [
            member.model_image(data=block)
            for member, block in zip(self.image_list, per_image)
        ]
        return Model_Image_List(models)

    def match_indices(self, other):
        """Find the positions in this list of images whose identity matches.

        For a `Target_Image_List`, returns a list with one entry per
        image in `other` (index in this list, or None when absent).
        For a single `Target_Image`, returns the index or None. For
        anything else an empty list is returned.
        """

        def first_match(candidate):
            # Position of the first held image sharing the candidate's identity.
            return next(
                (
                    position
                    for position, held in enumerate(self.image_list)
                    if candidate.identity == held.identity
                ),
                None,
            )

        # The list type must be tested first: it is also a Target_Image.
        if isinstance(other, Target_Image_List):
            return [first_match(candidate) for candidate in other.image_list]
        if isinstance(other, Target_Image):
            return first_match(other)
        return []

    def __isub__(self, other):
        """In-place subtraction, matching images by identity.

        Target images are matched on `identity`, model images on
        `target_identity`; any other operand is paired positionally.
        Images of a `Target_Image_List` with no match in this list are
        appended to it.
        """
        if isinstance(other, Target_Image_List):
            for incoming in other.image_list:
                held = next(
                    (
                        member
                        for member in self.image_list
                        if incoming.identity == member.identity
                    ),
                    None,
                )
                if held is None:
                    self.image_list.append(incoming)
                else:
                    held -= incoming
            return self

        if isinstance(other, Target_Image):
            # Only the first matching image is updated.
            for held in self.image_list:
                if other.identity == held.identity:
                    held -= other
                    break
            return self

        if isinstance(other, Model_Image_List):
            for incoming in other.image_list:
                # Only the first matching image is updated.
                for held in self.image_list:
                    if incoming.target_identity == held.identity:
                        held -= incoming
                        break
            return self

        if isinstance(other, Model_Image):
            # Every matching image is updated (no early exit).
            for held in self.image_list:
                if other.target_identity == held.identity:
                    held -= other
            return self

        for held, incoming in zip(self.image_list, other):
            held -= incoming
        return self

    def __iadd__(self, other):
        """In-place addition, matching images by identity.

        Target images are matched on `identity`, model images on
        `target_identity`; any other operand is paired positionally.
        Images of a `Target_Image_List` with no match in this list are
        appended to it.
        """
        if isinstance(other, Target_Image_List):
            for incoming in other.image_list:
                held = next(
                    (
                        member
                        for member in self.image_list
                        if incoming.identity == member.identity
                    ),
                    None,
                )
                if held is None:
                    self.image_list.append(incoming)
                else:
                    held += incoming
            return self

        if isinstance(other, Target_Image):
            # Every matching image is updated (no early exit).
            for held in self.image_list:
                if other.identity == held.identity:
                    held += other
            return self

        if isinstance(other, Model_Image_List):
            for incoming in other.image_list:
                # Only the first matching image is updated.
                for held in self.image_list:
                    if incoming.target_identity == held.identity:
                        held += incoming
                        break
            return self

        if isinstance(other, Model_Image):
            # Every matching image is updated (no early exit).
            for held in self.image_list:
                if other.target_identity == held.identity:
                    held += other
            return self

        for held, incoming in zip(self.image_list, other):
            held += incoming
        return self

    @property
    def mask(self):
        """Tuple holding the mask of every held image."""
        return tuple([member.mask for member in self.image_list])

    @mask.setter
    def mask(self, mask):
        # Values are paired with images positionally.
        for member, value in zip(self.image_list, mask):
            member.set_mask(value)

    @property
    def has_mask(self):
        """True when at least one held image has a stored mask."""
        for member in self.image_list:
            if member.has_mask:
                return True
        return False

    @property
    def psf(self):
        """Tuple holding the PSF of every held image."""
        return tuple([member.psf for member in self.image_list])

    @psf.setter
    def psf(self, psf):
        # Values are paired with images positionally.
        for member, value in zip(self.image_list, psf):
            member.set_psf(value)

    @property
    def has_psf(self):
        """True when at least one held image has a stored PSF."""
        for member in self.image_list:
            if member.has_psf:
                return True
        return False

    @property
    def psf_border(self):
        """Tuple holding the `psf_border` of every held image."""
        return tuple([member.psf_border for member in self.image_list])

    @property
    def psf_border_int(self):
        """Tuple holding the `psf_border_int` of every held image."""
        return tuple([member.psf_border_int for member in self.image_list])

    def set_variance(self, variance, img):
        """Set the variance of the held image at index `img`."""
        target = self.image_list[img]
        target.set_variance(variance)

    def set_psf(self, psf, img):
        """Set the PSF of the held image at index `img`."""
        target = self.image_list[img]
        target.set_psf(psf)

    def set_mask(self, mask, img):
        """Set the mask of the held image at index `img`."""
        target = self.image_list[img]
        target.set_mask(mask)

    def or_mask(self, mask):
        """Not available for image lists."""
        raise NotImplementedError()

    def and_mask(self, mask):
        """Not available for image lists."""
        raise NotImplementedError()