from typing import List, Optional
import torch
import numpy as np
from torch.nn.functional import avg_pool2d
from .image_object import Image, Image_List
from .jacobian_image import Jacobian_Image, Jacobian_Image_List
from .model_image import Model_Image, Model_Image_List
from .psf_image import PSF_Image
from astropy.io import fits
from .. import AP_config
from ..errors import SpecificationConflict, InvalidImage
__all__ = ["Target_Image", "Target_Image_List"]
class Target_Image(Image):
    """Image object which represents the data to be fit by a model. It can
    include a variance image, mask, and PSF as ancillary data which
    describe the target image.

    Target images are a basic unit of data in `AstroPhot`, they store
    the information collected from telescopes for which models are to
    be fit. There is minimal functionality in the `Target_Image`
    object itself, it is mostly defined in terms of how other objects
    interact with it.

    Basic usage:

    .. code-block:: python

        import astrophot as ap

        # Create target image
        image = ap.image.Target_Image(
            data = <pixel data>,
            wcs = <astropy WCS object>,
            variance = <pixel uncertainties>,
            psf = <point spread function as PSF_Image object>,
            mask = <pixels to ignore>,
        )

        # Display the data
        fig, ax = plt.subplots()
        ap.plots.target_image(fig, ax, image)
        plt.show()

        # Save the image
        image.save("mytarget.fits")

        # Load the image
        image2 = ap.image.Target_Image(filename = "mytarget.fits")

        # Make low resolution version
        lowrez = image.reduce(2)

    Some important information to keep in mind. First, providing an
    `astropy WCS` object is the best way to keep track of coordinates
    and pixel scale properties, especially when dealing with
    multi-band data. If images have relative positioning, rotation,
    pixel sizes, or fields of view, this will all be handled
    automatically by taking advantage of `WCS` objects. Second,
    providing accurate variance (or weight) maps is critical to getting
    a good fit to the data. This is a very common source of issues so
    it is worthwhile to review literature on how best to construct such
    a map. A good starting place is the FAQ for GALFIT:
    https://users.obs.carnegiescience.edu/peng/work/galfit/CHI2.html
    which is an excellent resource for all things image modeling. Just
    note that `AstroPhot` uses variance or weight maps, not sigma
    images. `AstroPhot` will not create a variance map for the user; by
    default it will just assume uniform variance, which is rarely
    accurate. Third, the image pixelscale must be an integer multiple
    of the PSF pixelscale. So if the image has a pixelscale of 1 then
    the PSF must have a pixelscale of 1, 1/2, 1/3, etc. for anything to
    work out. Note that if the PSF pixelscale is finer than the image,
    then all modelling will be done at the higher resolution. This is
    recommended for accuracy though it can mean higher memory
    consumption.
    """

    image_count = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ancillary data is only taken from the keyword arguments when the
        # parent constructor has not already populated it. `weight` takes
        # precedence over `variance` if both are given.
        if not self.has_weight and "weight" in kwargs:
            self.set_weight(kwargs.get("weight", None))
        elif not self.has_variance and "variance" in kwargs:
            self.set_variance(kwargs.get("variance", None))
        if not self.has_mask:
            self.set_mask(kwargs.get("mask", None))
        if not self.has_psf:
            self.set_psf(kwargs.get("psf", None), kwargs.get("psf_upscale", 1))

        # Set nan pixels to be masked automatically
        if torch.any(torch.isnan(self.data)).item():
            self.set_mask(torch.logical_or(self.mask, torch.isnan(self.data)))

    @property
    def standard_deviation(self):
        """Stores the standard deviation of the image pixels. This represents
        the uncertainty in each pixel value. It should always have the
        same shape as the image data. In the case where the standard
        deviation is not known, a tensor of ones will be created to
        stand in as the standard deviation values.

        The standard deviation is not stored directly, instead it is
        computed as :math:`\\sqrt{1/W}` where :math:`W` is the
        weights.
        """
        if self.has_variance:
            return torch.sqrt(self.variance)
        return torch.ones_like(self.data)

    @property
    def variance(self):
        """Stores the variance of the image pixels. This represents the
        uncertainty in each pixel value. It should always have the
        same shape as the image data. In the case where the variance
        is not known, a tensor of ones will be created to stand in as
        the variance values.

        The variance is not stored directly, instead it is
        computed as :math:`\\frac{1}{W}` where :math:`W` is the
        weights. Pixels with zero weight get infinite variance.
        """
        if self.has_variance:
            return torch.where(self._weight == 0, torch.inf, 1 / self._weight)
        return torch.ones_like(self.data)

    @variance.setter
    def variance(self, variance):
        self.set_variance(variance)

    @property
    def has_variance(self):
        """Returns True when the image object has stored variance values. If
        this is False and the variance property is called then a
        tensor of ones will be returned.

        Variance is derived from the stored weights, so this is the same
        as `has_weight`.
        """
        return self.has_weight

    @property
    def weight(self):
        """Stores the weight of the image pixels. This represents the
        uncertainty in each pixel value. It should always have the
        same shape as the image data. In the case where the weight
        is not known, a tensor of ones will be created to stand in as
        the weight values.

        The weights are used to proportionately scale residuals in the
        likelihood. Most commonly this shows up as a :math:`\\chi^2`
        like:

        .. math::

            \\chi^2 = (\\vec{y} - \\vec{f(\\theta)})^TW(\\vec{y} - \\vec{f(\\theta)})

        which can be optimized to find parameter values. Using the
        Jacobian, which in this case is the derivative of every pixel
        wrt every parameter, the weight matrix also appears in the
        gradient:

        .. math::

            \\vec{g} = J^TW(\\vec{y} - \\vec{f(\\theta)})

        and the hessian approximation used in Levenberg-Marquardt:

        .. math::

            H \\approx J^TWJ
        """
        if self.has_weight:
            return self._weight
        return torch.ones_like(self.data)

    @weight.setter
    def weight(self, weight):
        self.set_weight(weight)

    @property
    def has_weight(self):
        """Returns True when the image object has stored weight values. If
        this is False and the weight property is called then a
        tensor of ones will be returned.
        """
        try:
            return self._weight is not None
        except AttributeError:
            # First access before any weight was set: initialize the slot.
            self._weight = None
            return False

    @property
    def mask(self):
        """The mask stores a tensor of boolean values which indicate any
        pixels to be ignored. These pixels will be skipped in
        likelihood evaluations and in parameter optimization. It is
        common practice to mask pixels with pathological values, such
        as those caused by cosmic rays or satellites passing through
        the image.

        In a mask, a True value indicates that the pixel is masked and
        should be ignored. False indicates a normal pixel which will
        enter into most calculations.

        If no mask is provided, all pixels are assumed valid.
        """
        if self.has_mask:
            return self._mask
        return torch.zeros_like(self.data, dtype=torch.bool)

    @mask.setter
    def mask(self, mask):
        self.set_mask(mask)

    @property
    def has_mask(self):
        """
        Single boolean to indicate if a mask has been provided by the user.
        """
        try:
            return self._mask is not None
        except AttributeError:
            return False

    @property
    def psf(self):
        """Stores the point-spread-function for this target. This should be a
        `PSF_Image` object which represents the scattering of a point
        source of light. It can also be an `AstroPhot_Model` object
        which will contribute its own parameters to an optimization
        problem.

        The PSF stored for a `Target_Image` object is passed to all
        models applied to that target which have a `psf_mode` that is
        not `none`. This means they will all use the same PSF
        model. If one wishes to define a variable PSF across an image,
        then they should pass the PSF objects to the `AstroPhot_Model`'s
        directly instead of to a `Target_Image`.

        Raises:
            AttributeError: if this is called without a PSF defined
        """
        if self.has_psf:
            return self._psf
        raise AttributeError("This image does not have a PSF")

    @psf.setter
    def psf(self, psf):
        self.set_psf(psf)

    @property
    def has_psf(self):
        """Returns True when a PSF has been stored for this image."""
        try:
            return self._psf is not None
        except AttributeError:
            return False

    def set_variance(self, variance):
        """
        Provide a variance tensor for the image. Variance is equal to $\\sigma^2$. This should have the same shape as the data.

        The variance is stored internally as weights (1/variance). Any
        array-like (tensor, numpy array, nested list) is accepted;
        passing None clears the stored weights.
        """
        if variance is None:
            self._weight = None
            return
        # Convert first so that plain Python lists can be inverted.
        variance = torch.as_tensor(
            variance, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        self.set_weight(1 / variance)

    def set_weight(self, weight):
        """Provide a weight tensor for the image. Weight is equal to $\\frac{1}{\\sigma^2}$. This should have the same
        shape as the data.

        Any array-like (tensor, numpy array, nested list such as the one
        produced by `get_state`) is accepted; passing None clears the
        stored weights.

        Raises:
            SpecificationConflict: if the weight shape differs from the data shape.
        """
        if weight is None:
            self._weight = None
            return
        # Convert before the shape check: lists have no `.shape` attribute.
        weight = torch.as_tensor(
            weight, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if weight.shape != self.data.shape:
            raise SpecificationConflict(
                f"weight/variance must have same shape as data ({weight.shape} vs {self.data.shape})"
            )
        self._weight = weight

    def set_psf(self, psf, psf_upscale=1):
        """Provide a psf for the `Target_Image`. This is stored and passed to
        models which need to be convolved.

        The PSF doesn't need to have the same pixelscale as the
        image. It should be some multiple of the resolution of the
        `Target_Image` though. So if the image has a pixelscale of 1,
        the psf may have a pixelscale of 1, 1/2, 1/3, 1/4 and so on.

        Args:
            psf: None to clear, a `PSF_Image` (stored as-is, `psf_upscale`
                is then ignored), or raw data wrapped into a new `PSF_Image`.
            psf_upscale: ratio of image pixelscale to PSF pixelscale, used
                only when `psf` is raw data.
        """
        if psf is None:
            self._psf = None
            return
        if isinstance(psf, PSF_Image):
            self._psf = psf
            return

        self._psf = PSF_Image(
            data=psf,
            psf_upscale=psf_upscale,
            pixelscale=self.pixelscale / psf_upscale,
            identity=self.identity,
        )

    def set_mask(self, mask):
        """
        Set the boolean mask which will indicate which pixels to ignore. A mask value of True means the pixel will be ignored.

        Any array-like (tensor, numpy array, nested list) is accepted;
        passing None clears the mask.

        Raises:
            SpecificationConflict: if the mask shape differs from the data shape.
        """
        if mask is None:
            self._mask = None
            return
        # Convert before the shape check: lists have no `.shape` attribute.
        mask = torch.as_tensor(mask, dtype=torch.bool, device=AP_config.ap_device)
        if mask.shape != self.data.shape:
            raise SpecificationConflict(
                f"mask must have same shape as data ({mask.shape} vs {self.data.shape})"
            )
        self._mask = mask

    def to(self, dtype=None, device=None):
        """Converts the stored `Target_Image` data, variance, psf, etc to a
        given data type and device.

        Args:
            dtype: target floating dtype; defaults to `AP_config.ap_dtype`.
            device: target device; defaults to `AP_config.ap_device`.

        Returns:
            self, to allow chaining.
        """
        super().to(dtype=dtype, device=device)
        # Fall back to the configured defaults only when the caller did not
        # request something specific.
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        if self.has_weight:
            self._weight = self._weight.to(dtype=dtype, device=device)
        if self.has_psf:
            self._psf = self._psf.to(dtype=dtype, device=device)
        if self.has_mask:
            # Masks stay boolean regardless of the requested dtype.
            self._mask = self.mask.to(dtype=torch.bool, device=device)
        return self

    def or_mask(self, mask):
        """
        Combines the currently stored mask with a provided new mask using the boolean `or` operator.
        """
        self._mask = torch.logical_or(self.mask, mask)

    def and_mask(self, mask):
        """
        Combines the currently stored mask with a provided new mask using the boolean `and` operator.
        """
        self._mask = torch.logical_and(self.mask, mask)

    def copy(self, **kwargs):
        """Produce a copy of this image with all of the same properties. This
        can be used when one wishes to make temporary modifications to
        an image and then will want the original again.
        """
        return super().copy(
            mask=self._mask,
            psf=self._psf,
            weight=self._weight,
            **kwargs,
        )

    def blank_copy(self, **kwargs):
        """Produces a blank copy of the image via the parent `blank_copy`.
        The mask and PSF are carried over to the copy; the weights are not.
        """
        return super().blank_copy(mask=self._mask, psf=self._psf, **kwargs)

    def get_window(self, window, **kwargs):
        """Get a sub-region of the image as defined by a window on the sky.

        The weights and mask are cut to the same pixel indices as the
        data; the PSF is shared unchanged.
        """
        indices = self.window.get_self_indices(window)
        return super().get_window(
            window=window,
            weight=self._weight[indices] if self.has_weight else None,
            mask=self._mask[indices] if self.has_mask else None,
            psf=self._psf,
            **kwargs,
        )

    def jacobian_image(
        self,
        parameters: Optional[List[str]] = None,
        data: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """
        Construct a blank `Jacobian_Image` object formatted like this current `Target_Image` object. Mostly used internally.

        Args:
            parameters: parameter names, one per trailing data slice. When
                None, an empty parameter list is used and `data` is ignored.
            data: optional pre-filled tensor; when omitted a zero tensor of
                shape `(*self.data.shape, len(parameters))` is created.
        """
        if parameters is None:
            data = None
            parameters = []
        elif data is None:
            data = torch.zeros(
                (*self.data.shape, len(parameters)),
                dtype=AP_config.ap_dtype,
                device=AP_config.ap_device,
            )
        return Jacobian_Image(
            parameters=parameters,
            target_identity=self.identity,
            data=data,
            header=self.header,
            **kwargs,
        )

    def model_image(self, data: Optional[torch.Tensor] = None, **kwargs):
        """
        Construct a blank `Model_Image` object formatted like this current `Target_Image` object. Mostly used internally.

        If `data` is not given, the model image is filled with zeros of the
        same shape as this image's data.
        """
        return Model_Image(
            data=torch.zeros_like(self.data) if data is None else data,
            header=self.header,
            target_identity=self.identity,
            **kwargs,
        )

    def reduce(self, scale, **kwargs):
        """Returns a new `Target_Image` object with a reduced resolution
        compared to the current image. `scale` should be an integer
        indicating how much to reduce the resolution. If the
        `Target_Image` was originally (48,48) pixels across with a
        pixelscale of 1 and `reduce(2)` is called then the image will
        be (24,24) pixels and the pixelscale will be 2. If `reduce(3)`
        is called then the returned image will be (16,16) pixels
        across and the pixelscale will be 3.

        Variances of each `scale` x `scale` block are summed, and a
        reduced pixel is masked if any pixel in its block was masked.
        Rows/columns that do not fill a complete block are dropped.
        """
        MS = self.data.shape[0] // scale
        NS = self.data.shape[1] // scale

        return super().reduce(
            scale=scale,
            variance=(
                self.variance[: MS * scale, : NS * scale]
                .reshape(MS, scale, NS, scale)
                .sum(axis=(1, 3))
                if self.has_variance
                else None
            ),
            mask=(
                self.mask[: MS * scale, : NS * scale]
                .reshape(MS, scale, NS, scale)
                .amax(axis=(1, 3))
                if self.has_mask
                else None
            ),
            psf=self.psf.reduce(scale) if self.has_psf else None,
            **kwargs,
        )

    def expand(self, padding):
        """
        `Target_Image` doesn't have expand yet.
        """
        raise NotImplementedError("expand not available for Target_Image yet")

    def get_state(self):
        """Return a serializable dict of the image, adding weight and mask
        (as nested lists) and the PSF state when present."""
        state = super().get_state()

        if self.has_weight:
            state["weight"] = self.weight.detach().cpu().tolist()
        if self.has_mask:
            state["mask"] = self.mask.detach().cpu().tolist()
        if self.has_psf:
            state["psf"] = self.psf.get_state()

        return state

    def set_state(self, state):
        """Restore the image from a dict produced by `get_state`."""
        super().set_state(state)

        self.weight = state.get("weight", None)
        self.mask = state.get("mask", None)
        if "psf" in state:
            self.psf = PSF_Image(state=state["psf"])

    def get_fits_state(self):
        """Return a list of FITS HDU descriptions, appending WEIGHT and MASK
        extensions and the PSF extensions when present."""
        states = super().get_fits_state()
        if self.has_weight:
            states.append(
                {
                    "DATA": self.weight.detach().cpu().numpy(),
                    "HEADER": {"IMAGE": "WEIGHT"},
                }
            )
        if self.has_mask:
            states.append(
                {
                    # FITS has no boolean image type, store as integers.
                    "DATA": self.mask.detach().cpu().numpy().astype(int),
                    "HEADER": {"IMAGE": "MASK"},
                }
            )
        if self.has_psf:
            states += self.psf.get_fits_state()

        return states

    def set_fits_state(self, states):
        """Restore the image from HDU descriptions produced by
        `get_fits_state`."""
        super().set_fits_state(states)
        for state in states:
            image_type = state["HEADER"]["IMAGE"]
            if image_type == "WEIGHT":
                self.weight = np.array(state["DATA"], dtype=np.float64)
            # `get_fits_state` writes "MASK"; the lowercase form is kept for
            # backward compatibility with the previous lookup.
            elif image_type in ("MASK", "mask"):
                self.mask = np.array(state["DATA"], dtype=bool)
            elif image_type == "PSF":
                self.psf = PSF_Image(fits_state=states)
class Target_Image_List(Image_List, Target_Image):
    """A collection of `Target_Image` objects handled together.

    Per-image properties (`variance`, `standard_deviation`, `weight`,
    `mask`, `psf`, ...) are exposed as tuples with one entry per image in
    `image_list`, and the `has_*` properties are True if any image has the
    corresponding data.

    Raises:
        InvalidImage: if any element of the list is not a `Target_Image`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not all(isinstance(image, Target_Image) for image in self.image_list):
            raise InvalidImage(
                f"Target_Image_List can only hold Target_Image objects, not {tuple(type(image) for image in self.image_list)}"
            )

    @property
    def standard_deviation(self):
        """Tuple of per-image standard deviation tensors.

        Overridden because the single-image version would apply
        `torch.sqrt` to the tuple returned by `variance`.
        """
        return tuple(image.standard_deviation for image in self.image_list)

    @property
    def variance(self):
        """Tuple of per-image variance tensors."""
        return tuple(image.variance for image in self.image_list)

    @variance.setter
    def variance(self, variance):
        for image, var in zip(self.image_list, variance):
            image.set_variance(var)

    @property
    def has_variance(self):
        return any(image.has_variance for image in self.image_list)

    @property
    def weight(self):
        """Tuple of per-image weight tensors."""
        return tuple(image.weight for image in self.image_list)

    @weight.setter
    def weight(self, weight):
        for image, wgt in zip(self.image_list, weight):
            image.set_weight(wgt)

    @property
    def has_weight(self):
        return any(image.has_weight for image in self.image_list)

    def jacobian_image(self, parameters: List[str], data: Optional[List[torch.Tensor]] = None):
        """Build a `Jacobian_Image_List` with one blank jacobian per image.

        Args:
            parameters: parameter names passed to every image.
            data: optional per-image tensors, aligned with `image_list`.
        """
        if data is None:
            data = [None] * len(self.image_list)
        return Jacobian_Image_List(
            list(image.jacobian_image(parameters, dat) for image, dat in zip(self.image_list, data))
        )

    def model_image(self, data: Optional[List[torch.Tensor]] = None):
        """Build a `Model_Image_List` with one model image per image.

        Args:
            data: optional per-image tensors, aligned with `image_list`.
        """
        if data is None:
            data = [None] * len(self.image_list)
        return Model_Image_List(
            list(image.model_image(data=dat) for image, dat in zip(self.image_list, data))
        )

    def match_indices(self, other):
        """Find the positions in `image_list` whose identity matches `other`.

        Returns:
            For a `Target_Image_List`, a list with one entry per image of
            `other` (the matching index, or None if unmatched). For a single
            `Target_Image`, the matching index or None. An empty list for
            any other input.
        """
        indices = []
        if isinstance(other, Target_Image_List):
            for other_image in other.image_list:
                for isi, self_image in enumerate(self.image_list):
                    if other_image.identity == self_image.identity:
                        indices.append(isi)
                        break
                else:
                    indices.append(None)
        elif isinstance(other, Target_Image):
            for isi, self_image in enumerate(self.image_list):
                if other.identity == self_image.identity:
                    indices = isi
                    break
            else:
                indices = None
        return indices

    def __isub__(self, other):
        # Images are paired by identity (target images) or target_identity
        # (model images); a plain iterable is paired positionally.
        if isinstance(other, Target_Image_List):
            for other_image in other.image_list:
                for self_image in self.image_list:
                    if other_image.identity == self_image.identity:
                        self_image -= other_image
                        break
                else:
                    self.image_list.append(other_image)
        elif isinstance(other, Target_Image):
            for self_image in self.image_list:
                if other.identity == self_image.identity:
                    self_image -= other
                    break
        elif isinstance(other, Model_Image_List):
            for other_image in other.image_list:
                for self_image in self.image_list:
                    if other_image.target_identity == self_image.identity:
                        self_image -= other_image
                        break
        elif isinstance(other, Model_Image):
            for self_image in self.image_list:
                if other.target_identity == self_image.identity:
                    self_image -= other
                    break
        else:
            for self_image, other_image in zip(self.image_list, other):
                self_image -= other_image
        return self

    def __iadd__(self, other):
        # Images are paired by identity (target images) or target_identity
        # (model images); a plain iterable is paired positionally.
        if isinstance(other, Target_Image_List):
            for other_image in other.image_list:
                for self_image in self.image_list:
                    if other_image.identity == self_image.identity:
                        self_image += other_image
                        break
                else:
                    self.image_list.append(other_image)
        elif isinstance(other, Target_Image):
            for self_image in self.image_list:
                if other.identity == self_image.identity:
                    self_image += other
                    break
        elif isinstance(other, Model_Image_List):
            for other_image in other.image_list:
                for self_image in self.image_list:
                    if other_image.target_identity == self_image.identity:
                        self_image += other_image
                        break
        elif isinstance(other, Model_Image):
            for self_image in self.image_list:
                if other.target_identity == self_image.identity:
                    self_image += other
                    break
        else:
            for self_image, other_image in zip(self.image_list, other):
                self_image += other_image
        return self

    @property
    def mask(self):
        """Tuple of per-image boolean mask tensors."""
        return tuple(image.mask for image in self.image_list)

    @mask.setter
    def mask(self, mask):
        for image, M in zip(self.image_list, mask):
            image.set_mask(M)

    @property
    def has_mask(self):
        return any(image.has_mask for image in self.image_list)

    @property
    def psf(self):
        """Tuple of per-image PSFs (raises if any image lacks a PSF)."""
        return tuple(image.psf for image in self.image_list)

    @psf.setter
    def psf(self, psf):
        for image, P in zip(self.image_list, psf):
            image.set_psf(P)

    @property
    def has_psf(self):
        return any(image.has_psf for image in self.image_list)

    @property
    def psf_border(self):
        return tuple(image.psf_border for image in self.image_list)

    @property
    def psf_border_int(self):
        return tuple(image.psf_border_int for image in self.image_list)

    def set_variance(self, variance, img):
        """Set the variance of the image at position `img`."""
        self.image_list[img].set_variance(variance)

    def set_weight(self, weight, img):
        """Set the weight of the image at position `img`."""
        self.image_list[img].set_weight(weight)

    def set_psf(self, psf, img):
        """Set the PSF of the image at position `img`."""
        self.image_list[img].set_psf(psf)

    def set_mask(self, mask, img):
        """Set the mask of the image at position `img`."""
        self.image_list[img].set_mask(mask)

    def or_mask(self, mask):
        raise NotImplementedError()

    def and_mask(self, mask):
        raise NotImplementedError()