from typing import Optional, Union, Any, Sequence, Tuple
from copy import deepcopy
import torch
from torch.nn.functional import pad
import numpy as np
from astropy.io import fits
from astropy.wcs import WCS as AstropyWCS
from .window_object import Window, Window_List
from .image_header import Image_Header
from .. import AP_config
from ..errors import SpecificationConflict, ConflicingWCS, InvalidData, InvalidWindow
# Public names exported by ``from <this module> import *``.
__all__ = ["Image", "Image_List"]
class Image(object):
    """Core class to represent images with pixel values, pixel scale,
    and a window defining the spatial coordinates on the sky.

    It supports arithmetic operations with other image objects while
    preserving logical image boundaries. It also provides methods for
    determining the coordinate locations of pixels.

    Parameters:
        data: the matrix of pixel values for the image
        pixelscale: the length of one side of a pixel in arcsec/pixel
        window: an AstroPhot Window object which defines the spatial coordinates on the sky
        filename: a filename from which to load the image.
        zeropoint: photometric zero point for converting from pixel flux to magnitude
        metadata: Any information the user wishes to associate with this image, stored in a python dictionary
        origin: The origin of the image in the coordinate system.
    """

    def __init__(
        self,
        *,
        data: Optional[Union[torch.Tensor, np.ndarray]] = None,
        header: Optional[Image_Header] = None,
        wcs: Optional[AstropyWCS] = None,
        pixelscale: Optional[Union[float, torch.Tensor]] = None,
        window: Optional[Window] = None,
        filename: Optional[str] = None,
        zeropoint: Optional[Union[float, torch.Tensor]] = None,
        metadata: Optional[dict] = None,
        origin: Optional[Sequence] = None,
        center: Optional[Sequence] = None,
        identity: Optional[str] = None,
        state: Optional[dict] = None,
        fits_state: Optional[Sequence[dict]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize an Image.

        The header is chosen first, checked in this order:
        ``state`` -> ``fits_state`` -> ``header is None`` (build a new
        ``Image_Header`` from the remaining arguments) -> ``header`` as-is.
        Pixel data is then taken from ``filename`` if given, else from
        ``state``, else from ``data``, else filled with zeros matching the
        window's pixel shape.

        Parameters:
        -----------
        data : torch.Tensor, numpy.ndarray or None, optional
            The image data. Default is None.
        header : Image_Header or None, optional
            A ready-made header used as-is. When given, the other
            header-building arguments (pixelscale, wcs, window, ...) are not
            used.
        wcs : astropy.wcs.wcs.WCS or None, optional
            A WCS object which defines a coordinate system for the image. Note that AstroPhot only handles basic WCS conventions. It will use the WCS object to get `wcs.pixel_to_world(-0.5, -0.5)` to determine the position of the origin in world coordinates. It will also extract the `pixel_scale_matrix` to index pixels going forward.
        pixelscale : float or None, optional
            The physical scale of the pixels in the image, in units of arcseconds. Default is None.
        window : Window or None, optional
            A Window object defining the area of the image to use. Default is None.
        filename : str or None, optional
            The name of a file containing the image data. Default is None.
        zeropoint : float or None, optional
            The image's zeropoint, used for flux calibration. Default is None.
        metadata : dict or None, optional
            Any information the user wishes to associate with this image, stored in a python dictionary. Default is None.
        origin : numpy.ndarray or None, optional
            The origin of the image in the coordinate system, as a 1D array of length 2. Default is None.
        center : numpy.ndarray or None, optional
            The center of the image in the coordinate system, as a 1D array of length 2. Default is None.
        identity : str or None, optional
            Forwarded to the new Image_Header when one is built.
        state : dict or None, optional
            A dictionary as produced by :meth:`get_state` (keys "header"
            and "data").
        fits_state : sequence of dict or None, optional
            Entries with "DATA" and "HEADER" keys, as produced by
            :meth:`get_fits_state`.
        **kwargs
            Forwarded to ``Image_Header`` when a new header is built.

        Raises:
        -------
        InvalidData
            If no header, state or fits_state is given and none of
            ``data``, ``window`` or ``filename`` is provided.
        """
        self._data = None
        if state is not None:
            self.header = Image_Header(state=state["header"])
        elif fits_state is not None:
            # NOTE(review): set_fits_state() calls self.header.set_fits_state(),
            # but no header has been assigned yet on this path, so this looks
            # like it raises AttributeError. Confirm how an Image_Header should
            # be constructed from a FITS state before relying on this path.
            self.set_fits_state(fits_state)
            return
        elif header is None:
            if data is None and window is None and filename is None:
                raise InvalidData("Image must have either data or a window to construct itself.")
            self.header = Image_Header(
                data_shape=None if data is None else data.shape,
                pixelscale=pixelscale,
                wcs=wcs,
                window=window,
                filename=filename,
                zeropoint=zeropoint,
                metadata=metadata,
                origin=origin,
                center=center,
                identity=identity,
                **kwargs,
            )
        else:
            self.header = header

        if filename is not None:
            self.load(filename)
        elif state is not None:
            self.set_state(state)
        elif data is None:
            # window.pixel_shape is ordered opposite to the tensor axes,
            # hence the flip before allocating the zero-filled data.
            self.data = torch.zeros(
                torch.flip(self.window.pixel_shape, (0,)).detach().cpu().tolist(),
                dtype=AP_config.ap_dtype,
                device=AP_config.ap_device,
            )
        else:
            self.data = data
        self.to()
        # Invariant: np.flip(data.shape[:2]) == window.pixel_shape. It is not
        # asserted here because the check forces a GPU -> CPU transfer.

    @property
    def north(self):
        return self.header.north

    @property
    def pixel_area(self):
        return self.header.pixel_area

    @property
    def pixel_length(self):
        return self.header.pixel_length

    # Coordinate transforms: thin wrappers delegating to the image window.
    def world_to_plane(self, *args, **kwargs):
        return self.window.world_to_plane(*args, **kwargs)

    def plane_to_world(self, *args, **kwargs):
        return self.window.plane_to_world(*args, **kwargs)

    def plane_to_pixel(self, *args, **kwargs):
        return self.window.plane_to_pixel(*args, **kwargs)

    def pixel_to_plane(self, *args, **kwargs):
        return self.window.pixel_to_plane(*args, **kwargs)

    def plane_to_pixel_delta(self, *args, **kwargs):
        return self.window.plane_to_pixel_delta(*args, **kwargs)

    def pixel_to_plane_delta(self, *args, **kwargs):
        return self.window.pixel_to_plane_delta(*args, **kwargs)

    def world_to_pixel(self, *args, **kwargs):
        return self.window.world_to_pixel(*args, **kwargs)

    def pixel_to_world(self, *args, **kwargs):
        return self.window.pixel_to_world(*args, **kwargs)

    # Coordinate meshgrids: delegate to the header.
    def get_coordinate_meshgrid(self):
        return self.header.get_coordinate_meshgrid()

    def get_coordinate_corner_meshgrid(self):
        return self.header.get_coordinate_corner_meshgrid()

    def get_coordinate_simps_meshgrid(self):
        return self.header.get_coordinate_simps_meshgrid()

    @property
    def origin(self) -> torch.Tensor:
        """
        Returns the origin (bottom-left corner) of the image window.
        Returns:
            torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the origin.
        """
        return self.header.window.origin

    @property
    def shape(self) -> torch.Tensor:
        """
        Returns the shape (size) of the image window.
        Returns:
            torch.Tensor: A 1D tensor of shape (2,) containing the (width, height) of the window in pixels.
        """
        return self.header.window.shape

    @property
    def center(self) -> torch.Tensor:
        """
        Returns the center of the image window.
        Returns:
            torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the center.
        """
        return self.header.window.center

    @property
    def size(self) -> torch.Tensor:
        """
        Returns the size of the image window, the number of pixels in the image.
        Returns:
            torch.Tensor: A 0D tensor containing the number of pixels.
        """
        return self.header.window.size

    @property
    def window(self):
        return self.header.window

    @property
    def pixelscale(self):
        return self.header.pixelscale

    @property
    def zeropoint(self):
        return self.header.zeropoint

    @property
    def metadata(self):
        return self.header.metadata

    @property
    def identity(self):
        return self.header.identity

    @property
    def data(self) -> torch.Tensor:
        """
        Returns the image data.
        """
        return self._data

    @data.setter
    def data(self, data) -> None:
        """Set the image data (shape must match the current data)."""
        self.set_data(data)

    def set_data(
        self,
        data: Union[torch.Tensor, np.ndarray, None],
        require_shape: bool = True,
    ):
        """
        Set the image data.

        Args:
            data (torch.Tensor, numpy.ndarray, array-like or None): The image
                data. ``None`` stores an empty tensor.
            require_shape (bool): Whether to check that the shape of the data
                is the same as the current data.
        Raises:
            SpecificationConflict: If `require_shape` is `True` and the shape
                of the data is different from the current data.
        """
        # Convert first so the shape check works for every accepted input
        # (None, lists, numpy arrays, tensors). Assign ``_data`` directly:
        # going through the ``data`` property would re-enter this method with
        # require_shape=True.
        if data is None:
            new_data = torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        elif isinstance(data, torch.Tensor):
            new_data = data.to(dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        else:
            new_data = torch.as_tensor(data, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        if self._data is not None and require_shape and new_data.shape != self._data.shape:
            raise SpecificationConflict(
                f"Attempting to change image data with tensor that has a different shape! ({new_data.shape} vs {self._data.shape}) Use 'require_shape = False' if this is desired behaviour."
            )
        self._data = new_data

    def copy(self, **kwargs):
        """Produce a copy of this image with all of the same properties. This
        can be used when one wishes to make temporary modifications to
        an image and then will want the original again.
        """
        return self.__class__(
            data=torch.clone(self.data),
            header=self.header.copy(**kwargs),
            **kwargs,
        )

    def blank_copy(self, **kwargs):
        """Produces a blank copy of the image which has the same properties
        except that its data is now filled with zeros.
        """
        return self.__class__(
            data=torch.zeros_like(self.data),
            header=self.header.copy(**kwargs),
            **kwargs,
        )

    def get_window(self, window, **kwargs):
        """Get a sub-region of the image as defined by a window on the sky."""
        return self.__class__(
            data=self.data[self.window.get_self_indices(window)],
            header=self.header.get_window(window, **kwargs),
            **kwargs,
        )

    def to(self, dtype=None, device=None):
        """Move data and header to ``dtype``/``device`` (AP_config defaults
        when None). Returns self."""
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        if self._data is not None:
            self._data = self._data.to(dtype=dtype, device=device)
        self.header.to(dtype=dtype, device=device)
        return self

    def crop(self, pixels):
        """Trim pixels from the edges of the image, in place.

        Args:
            pixels: sequence of tensors (``.int()`` is called on each entry).
                - length 1: trim ``pixels[0]`` from every edge.
                - length 2: trim ``pixels[0]`` from both ends of the last data
                  axis (columns) and ``pixels[1]`` from both ends of the first
                  data axis (rows).
                - length 4: trim ``pixels[0]``/``pixels[1]`` from the low/high
                  end of the columns and ``pixels[2]``/``pixels[3]`` from the
                  low/high end of the rows.
                Other lengths leave the data untouched; the header is still
                cropped via ``header.crop``.
        Returns:
            self
        """
        if len(pixels) == 1:  # same crop in all dimension
            self.set_data(
                self.data[
                    pixels[0].int() : (self.data.shape[0] - pixels[0]).int(),
                    pixels[0].int() : (self.data.shape[1] - pixels[0]).int(),
                ],
                require_shape=False,
            )
        elif len(pixels) == 2:  # different crop in each dimension
            self.set_data(
                self.data[
                    pixels[1].int() : (self.data.shape[0] - pixels[1]).int(),
                    pixels[0].int() : (self.data.shape[1] - pixels[0]).int(),
                ],
                require_shape=False,
            )
        elif len(pixels) == 4:  # different crop on all sides
            self.set_data(
                self.data[
                    pixels[2].int() : (self.data.shape[0] - pixels[3]).int(),
                    pixels[0].int() : (self.data.shape[1] - pixels[1]).int(),
                ],
                require_shape=False,
            )
        self.header = self.header.crop(pixels)
        return self

    def flatten(self, attribute: str = "data") -> torch.Tensor:
        """Return the named attribute (default: the pixel data) reshaped to 1-D."""
        return getattr(self, attribute).reshape(-1)

    def reduce(self, scale: int, **kwargs):
        """This operation will downsample an image by the factor given. If
        scale = 2 then 2x2 blocks of pixels will be summed together to
        form individual larger pixels. A new image object will be
        returned with the appropriate pixelscale and data tensor. Note
        that the window does not change in this operation since the
        pixels are condensed, but the pixel size is increased
        correspondingly.

        Trailing rows/columns that do not fill a complete scale x scale block
        are dropped from the data before summing.

        Parameters:
            scale: factor by which to condense the image pixels. Each scale X
                scale region will be summed. Accepts a Python int, a numpy
                integer, or a single-element integer-dtype tensor. [int]
        Raises:
            SpecificationConflict: if scale is not an integer or is < 1.
        """
        int_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
        is_int_tensor = (
            isinstance(scale, torch.Tensor) and scale.dtype in int_dtypes and scale.numel() == 1
        )
        if not isinstance(scale, (int, np.integer)) and not is_int_tensor:
            raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}")
        scale = int(scale)
        if scale < 1:
            raise SpecificationConflict(f"Reduce scale must be a positive integer! not {scale}")
        if scale == 1:
            return self
        MS = self.data.shape[0] // scale
        NS = self.data.shape[1] // scale
        return self.__class__(
            data=self.data[: MS * scale, : NS * scale]
            .reshape(MS, scale, NS, scale)
            .sum(dim=(1, 3)),
            header=self.header.rescale_pixel(scale, **kwargs),
            **kwargs,
        )

    def expand(self, padding: Tuple[float]) -> None:
        """Pad the image with zeros on each side, in place.

        Args:
            padding tuple[float]: length 4 tuple with amounts to pad each
                dimension in physical units. The order follows
                ``torch.nn.functional.pad`` (last data axis low/high, then
                first data axis low/high).
        Raises:
            SpecificationConflict: if any padding value is negative.
        """
        padding = np.array(padding)
        if np.any(padding < 0):
            raise SpecificationConflict("negative padding not allowed in expand method")
        # NOTE(review): assumes padding / pixelscale yields one pixel count per
        # entry; confirm for non-scalar pixelscale values.
        pad_boundaries = tuple(np.int64(np.round(padding / self.pixelscale)))
        # The padded tensor has a new shape, so bypass the shape-preserving
        # ``data`` setter.
        self.set_data(
            pad(self.data, pad=pad_boundaries, mode="constant", value=0),
            require_shape=False,
        )
        self.header.expand(padding)

    def get_state(self):
        """Return a JSON-friendly dict with the class name, data and header state."""
        state = {}
        state["type"] = self.__class__.__name__
        state["data"] = self.data.detach().cpu().tolist()
        state["header"] = self.header.get_state()
        return state

    def set_state(self, state):
        """Restore data and header from a dict produced by :meth:`get_state`."""
        self.set_data(state["data"], require_shape=False)
        self.header.set_state(state["header"])

    def get_fits_state(self):
        """Return a list with one {"DATA", "HEADER"} entry marked as PRIMARY."""
        states = [{}]
        states[0]["DATA"] = self.data.detach().cpu().numpy()
        states[0]["HEADER"] = self.header.get_fits_state()
        states[0]["HEADER"]["IMAGE"] = "PRIMARY"
        return states

    def set_fits_state(self, states):
        """Load data and header from the first entry whose HEADER["IMAGE"] is "PRIMARY"."""
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                self.set_data(np.array(state["DATA"], dtype=np.float64), require_shape=False)
                self.header.set_fits_state(state["HEADER"])
                break

    def save(self, filename=None, overwrite=True):
        """Build an HDUList from :meth:`get_fits_state`, write it if
        ``filename`` is given, and return it."""
        states = self.get_fits_state()
        img_list = [fits.PrimaryHDU(states[0]["DATA"], header=fits.Header(states[0]["HEADER"]))]
        for state in states[1:]:
            img_list.append(fits.ImageHDU(state["DATA"], header=fits.Header(state["HEADER"])))
        hdul = fits.HDUList(img_list)
        if filename is not None:
            hdul.writeto(filename, overwrite=overwrite)
        return hdul

    def load(self, filename):
        """Load data and header from a FITS file."""
        # Use a context manager so the file handle is always released; the
        # primary data is copied (np.array) inside set_fits_state before close.
        with fits.open(filename) as hdul:
            states = [{"DATA": hdu.data, "HEADER": hdu.header} for hdu in hdul]
            self.set_fits_state(states)

    def __sub__(self, other):
        if isinstance(other, Image):
            # Restrict to the overlap with ``other`` and subtract there.
            new_img = self[other.window].copy()
            new_img.data -= other.data[self.window.get_other_indices(other)]
            return new_img
        else:
            new_img = self.copy()
            new_img.data -= other
            return new_img

    def __add__(self, other):
        if isinstance(other, Image):
            # Restrict to the overlap with ``other`` and add there.
            new_img = self[other.window].copy()
            new_img.data += other.data[self.window.get_other_indices(other)]
            return new_img
        else:
            new_img = self.copy()
            new_img.data += other
            return new_img

    def __iadd__(self, other):
        if isinstance(other, Image):
            # Only the overlapping pixels of self are updated.
            self.data[other.window.get_other_indices(self)] += other.data[
                self.window.get_other_indices(other)
            ]
        else:
            self.data += other
        return self

    def __isub__(self, other):
        if isinstance(other, Image):
            # Only the overlapping pixels of self are updated.
            self.data[other.window.get_other_indices(self)] -= other.data[
                self.window.get_other_indices(other)
            ]
        else:
            self.data -= other
        return self

    def __getitem__(self, *args):
        if len(args) == 1 and isinstance(args[0], Window):
            return self.get_window(args[0])
        if len(args) == 1 and isinstance(args[0], Image):
            return self.get_window(args[0].window)
        raise ValueError("Unrecognized Image getitem request!")

    def __str__(self):
        return f"image pixelscale: {self.pixelscale.detach().cpu().numpy()} origin: {self.origin.detach().cpu().numpy()} shape: {self.shape.detach().cpu().numpy()}"

    def __repr__(self):
        return f"image pixelscale: {self.pixelscale.detach().cpu().numpy()} origin: {self.origin.detach().cpu().numpy()} shape: {self.shape.detach().cpu().numpy()} center: {self.center.detach().cpu().numpy()}\ndata: {self.data.detach().cpu().numpy()}"
class Image_List(Image):
    """A collection of Image objects handled together.

    Properties (pixelscale, zeropoint, data, meshgrids) return one entry per
    image as a tuple. Copies, windows, reductions and arithmetic return a new
    ``Image_List``. All images must share the same tangent plane; this is
    checked by :meth:`check_wcs` at construction.

    Parameters:
        image_list: iterable of Image objects (stored as a list).
        window: optional Window_List; each image is restricted to the
            corresponding window.
    """

    def __init__(self, image_list, window=None):
        # The base-class __init__ is not called: an Image_List keeps no header
        # or data tensor of its own and derives everything from image_list.
        self.image_list = list(image_list)
        self.check_wcs()
        self.window = window

    def check_wcs(self):
        """Ensure the WCS systems being used by all the windows in this list
        are consistent with each other. They should all project world
        coordinates onto the same tangent plane.

        Raises:
            ConflicingWCS: if reference world coordinates, reference tangent
                plane coordinates, or projections differ between images.
        """
        ref = torch.stack(tuple(I.window.reference_radec for I in self.image_list))
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (world) coordinate mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        ref = torch.stack(tuple(I.window.reference_planexy for I in self.image_list))
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (tangent plane) coordinate mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        if len(set(I.window.projection for I in self.image_list)) > 1:
            raise ConflicingWCS(
                "Projection mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )

    @property
    def window(self):
        return Window_List(list(image.window for image in self.image_list))

    @window.setter
    def window(self, window):
        """Restrict each image to the matching entry of a Window_List (None is a no-op)."""
        if window is None:
            return
        if not isinstance(window, Window_List):
            raise InvalidWindow("Target_List must take a Window_List object as its window")
        for i, image in enumerate(self.image_list):
            self.image_list[i] = image[window.window_list[i]]

    @property
    def pixelscale(self):
        return tuple(image.pixelscale for image in self.image_list)

    @property
    def zeropoint(self):
        return tuple(image.zeropoint for image in self.image_list)

    @property
    def data(self):
        return tuple(image.data for image in self.image_list)

    @data.setter
    def data(self, data):
        for image, dat in zip(self.image_list, data):
            image.data = dat

    def copy(self):
        return self.__class__(
            tuple(image.copy() for image in self.image_list),
        )

    def blank_copy(self):
        return self.__class__(
            tuple(image.blank_copy() for image in self.image_list),
        )

    def get_window(self, window):
        return self.__class__(
            tuple(image[win] for image, win in zip(self.image_list, window)),
        )

    def index(self, other):
        """Return the position of the image whose identity matches ``other``.

        Raises:
            ValueError: if no image in the list has the same identity.
            NotImplementedError: if ``other`` is not an Image with an identity.
        """
        if isinstance(other, Image) and hasattr(other, "identity"):
            for i, self_image in enumerate(self.image_list):
                if other.identity == self_image.identity:
                    return i
            raise ValueError("Could not find identity match between image list and input image")
        raise NotImplementedError(f"Image_List cannot get index for {type(other)}")

    def to(self, dtype=None, device=None):
        """Move every image to ``dtype``/``device``.

        ``None`` falls back to the AP_config defaults; explicitly supplied
        values are honoured. Returns self.
        """
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        for image in self.image_list:
            image.to(dtype=dtype, device=device)
        return self

    def crop(self, *pixels):
        raise NotImplementedError("Crop function not available for Image_List object")

    def get_coordinate_meshgrid(self):
        return tuple(image.get_coordinate_meshgrid() for image in self.image_list)

    def get_coordinate_corner_meshgrid(self):
        return tuple(image.get_coordinate_corner_meshgrid() for image in self.image_list)

    def get_coordinate_simps_meshgrid(self):
        return tuple(image.get_coordinate_simps_meshgrid() for image in self.image_list)

    def flatten(self, attribute="data"):
        """Concatenate every image's flattened attribute into one 1-D tensor."""
        return torch.cat(tuple(image.flatten(attribute) for image in self.image_list))

    def reduce(self, scale):
        if scale == 1:
            return self
        return self.__class__(
            tuple(image.reduce(scale) for image in self.image_list),
        )

    def __sub__(self, other):
        # ``other`` is either another Image_List or an iterable of operands,
        # one per image.
        others = other.image_list if isinstance(other, Image_List) else other
        return self.__class__([a - b for a, b in zip(self.image_list, others)])

    def __add__(self, other):
        others = other.image_list if isinstance(other, Image_List) else other
        return self.__class__([a + b for a, b in zip(self.image_list, others)])

    def __isub__(self, other):
        # Relies on Image.__isub__ mutating each image in place; the rebinding
        # of the loop variable is otherwise discarded.
        others = other.image_list if isinstance(other, Image_List) else other
        for self_image, other_image in zip(self.image_list, others):
            self_image -= other_image
        return self

    def __iadd__(self, other):
        # Relies on Image.__iadd__ mutating each image in place.
        others = other.image_list if isinstance(other, Image_List) else other
        for self_image, other_image in zip(self.image_list, others):
            self_image += other_image
        return self

    def save(self, filename=None, overwrite=True):
        raise NotImplementedError("Save/load not yet available for image lists")

    def load(self, filename):
        raise NotImplementedError("Save/load not yet available for image lists")

    def __getitem__(self, *args):
        if len(args) == 1 and isinstance(args[0], Window):
            return self.get_window(args[0])
        if len(args) == 1 and isinstance(args[0], Image):
            return self.get_window(args[0].window)
        if all(isinstance(arg, (int, slice)) for arg in args):
            return self.image_list.__getitem__(*args)
        raise ValueError("Unrecognized Image_List getitem request!")

    def __str__(self):
        return "image list of:\n" + "\n".join(image.__str__() for image in self.image_list)

    def __repr__(self):
        return "image list of:\n" + "\n".join(image.__repr__() for image in self.image_list)

    def __iter__(self):
        return iter(self.image_list)