# Source code for astrophot.image.window_object

import ast

import numpy as np
import torch
from astropy.wcs import WCS as AstropyWCS

from .. import AP_config
from ..utils.conversions.coordinates import Rotate_Cartesian
from .wcs import WCS
from ..errors import ConflicingWCS, SpecificationConflict

# Public API of this module: the single-window class and its list container.
__all__ = [
    "Window",
    "Window_List",
]


class Window(WCS):
    """Class to define a window on the sky in coordinate space.

    These windows can undergo arithmetic and preserve logical behavior.
    Image objects can also be indexed using windows and will return an
    appropriate subsection of their data.

    There are several ways to tell a Window object where to place itself.
    The simplest method is to pass an Astropy WCS object such as::

      H = ap.image.Window(wcs = wcs)

    this will automatically place your image at the correct RA, DEC and
    assign the correct pixel scale, etc. WARNING, it will default to
    setting the reference RA DEC at the reference RA DEC of the wcs
    object; if you have multiple images you should force them all to have
    the same reference world coordinate by passing
    ``reference_radec = (ra, dec)``. See the :doc:`coordinates`
    documentation for more details.

    There are several other ways to initialize a window. If you provide
    ``origin_radec`` then it will place the image origin at the requested
    RA DEC coordinates. If you provide ``center_radec`` then it will place
    the image center at the requested RA DEC coordinates. Note that in
    these cases the fixed point between the pixel grid and image plane is
    different (pixel origin and center respectively); so if you have
    rotated pixels in your pixel scale matrix then everything will be
    rotated about different points (pixel origin and center
    respectively). If you provide ``origin`` or ``center`` then those are
    coordinates in the tangent plane (arcsec) and they will
    correspondingly become fixed points. For arbitrary control over the
    pixel positioning, use ``reference_imageij`` and ``reference_imagexy``
    to fix the pixel and tangent plane coordinates respectively to each
    other; any rotation or shear will happen about that fixed point.

    In this class's pixel coordinates, pixel centers sit at integer
    indices, so the window spans ``[-0.5, pixel_shape - 0.5]`` along each
    axis.

    Args:
      pixel_shape : Sequence or None, optional
        Number of pixels along each axis. If None it is taken from
        ``wcs.pixel_shape``; one of the two must be given.
      origin : Sequence or None, optional
        The origin of the image in the tangent plane coordinate system
        (arcsec), as a 1D array of length 2. Default is None.
      origin_radec : Sequence or None, optional
        The origin of the image in the world coordinate system (RA, DEC in
        degrees), as a 1D array of length 2. Default is None.
      center : Sequence or None, optional
        The center of the image in the tangent plane coordinate system
        (arcsec), as a 1D array of length 2. Default is None.
      center_radec : Sequence or None, optional
        The center of the image in the world coordinate system (RA, DEC in
        degrees), as a 1D array of length 2. Default is None.
      state : dict or None, optional
        A state as produced by :meth:`get_state`; all other arguments are
        ignored when given.
      fits_state : dict or None, optional
        A state as produced by :meth:`get_fits_state`; all other arguments
        are ignored when given.
      wcs : astropy.wcs.WCS or None, optional
        An astropy WCS object which gives information about the origin and
        orientation of the window.
      reference_radec :
        World coordinates on the celestial sphere (RA, DEC in degrees)
        where the tangent plane makes contact. This should be the same for
        every image in multi-image analysis.
      reference_planexy :
        Tangent plane coordinates (arcsec) where it makes contact with the
        celestial sphere. This should typically be (0,0) though that is not
        strictly enforced (it is assumed if not given). This reference
        coordinate should be the same for all images in multi-image
        analysis.
      reference_imageij :
        Pixel coordinates about which the image is defined. For example in
        an Astropy WCS object the wcs.wcs.crpix array gives the pixel
        coordinate reference point for which the world coordinate mapping
        (wcs.wcs.crval) is defined. One may think of the referenced pixel
        location as being "pinned" to the tangent plane. This may be
        different for each image in multi-image analysis.
      reference_imagexy :
        Tangent plane coordinates (arcsec) about which the image is
        defined. This is the pivot point about which the pixelscale matrix
        operates, therefore if the pixelscale matrix defines a rotation
        then this is the coordinate about which the rotation will be
        performed. This may be different for each image in multi-image
        analysis.

    Raises:
      SpecificationConflict: if more than one of ``wcs``, ``origin_radec``,
        ``center_radec``, ``origin``, ``center`` is given.
      ValueError: if neither ``pixel_shape`` nor ``wcs`` is given.
    """

    def __init__(
        self,
        *,
        pixel_shape=None,
        origin=None,
        origin_radec=None,
        center=None,
        center_radec=None,
        state=None,
        fits_state=None,
        wcs=None,
        **kwargs,
    ):
        # If loading from a previous state, simply update values and end init
        if state is not None:
            self.set_state(state)
            return
        if fits_state is not None:
            self.set_fits_state(fits_state)
            return

        # Collect the shape of the window
        if pixel_shape is not None:
            self.pixel_shape = pixel_shape
        elif wcs is not None:
            self.pixel_shape = wcs.pixel_shape
        else:
            # Without this check the failure was an opaque AttributeError on None.
            raise ValueError(
                "Window requires either pixel_shape or an astropy wcs to determine its size"
            )

        # Determine relative positioning of tangent plane and pixel grid.
        # Also world coordinates and tangent plane.
        n_positions = sum(
            C is not None for C in (wcs, origin_radec, center_radec, origin, center)
        )
        if n_positions > 1:
            raise SpecificationConflict(
                "Please provide only one reference position for the window, otherwise the placement is ambiguous"
            )

        if wcs is not None:
            # Image coordinates provided by WCS
            super().__init__(wcs=wcs, **kwargs)
        elif origin_radec is not None:
            # Image reference position from RA and DEC of image origin;
            # the origin is the reference point.
            origin_radec = torch.as_tensor(
                origin_radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            kwargs["reference_radec"] = kwargs.get("reference_radec", origin_radec)
            super().__init__(**kwargs)
            self.reference_imageij = (-0.5, -0.5)
            self.reference_imagexy = self.world_to_plane(origin_radec)
        elif center_radec is not None:
            # Image reference position from RA and DEC of image center
            pix_center = self.pixel_shape.to(dtype=AP_config.ap_dtype) / 2 - 0.5
            center_radec = torch.as_tensor(
                center_radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            kwargs["reference_radec"] = kwargs.get("reference_radec", center_radec)
            super().__init__(**kwargs)
            center = self.world_to_plane(center_radec)
            self.reference_imageij = pix_center
            self.reference_imagexy = center
        elif origin is not None:
            # Image reference position from tangent plane position of image origin
            kwargs.update(
                {
                    "reference_imageij": (-0.5, -0.5),
                    "reference_imagexy": origin,
                }
            )
            super().__init__(**kwargs)
        elif center is not None:
            # Image reference position from tangent plane position of image center
            pix_center = self.pixel_shape.to(dtype=AP_config.ap_dtype) / 2 - 0.5
            kwargs.update(
                {
                    "reference_imageij": pix_center,
                    "reference_imagexy": center,
                }
            )
            super().__init__(**kwargs)
        else:
            # Image origin assumed to be at tangent plane origin
            super().__init__(**kwargs)

    @property
    def shape(self):
        """Side lengths of the window in the tangent plane (arcsec).

        Entry ``i`` is the length of the tangent-plane vector
        ``pixelscale @ (pixel_shape[i] * e_i)``, i.e. the edge spanned by
        pixel axis ``i``.
        """
        dtype, device = self.pixelscale.dtype, self.pixelscale.device
        pix = self.pixel_shape.to(dtype=dtype, device=device)
        zero = torch.zeros_like(pix[0])
        # Build the two edge vectors without mutating any tensor in place.
        edge_i = self.pixelscale @ torch.stack((pix[0], zero))
        edge_j = self.pixelscale @ torch.stack((zero, pix[1]))
        return torch.stack((torch.linalg.norm(edge_i), torch.linalg.norm(edge_j)))

    @shape.setter
    def shape(self, shape):
        """Set the pixel shape from tangent-plane side lengths (arcsec).

        Each length is divided by the norm of the matching pixelscale
        column; the pixel_shape setter then rounds to whole pixels.
        """
        if shape is None:
            self._pixel_shape = None
            return
        shape = torch.as_tensor(shape, dtype=self.pixelscale.dtype, device=self.pixelscale.device)
        self.pixel_shape = shape / torch.sqrt(torch.sum(self.pixelscale**2, dim=0))

    @property
    def pixel_shape(self):
        """Number of pixels along each axis, as an int32 tensor (or None)."""
        return self._pixel_shape

    @pixel_shape.setter
    def pixel_shape(self, shape):
        if shape is None:
            self._pixel_shape = None
            return
        shape = torch.as_tensor(shape, device=AP_config.ap_device)
        # Only floating values need rounding; integer input is already whole.
        if shape.is_floating_point():
            shape = torch.round(shape)
        self._pixel_shape = shape.to(dtype=torch.int32, device=AP_config.ap_device)

    @property
    def size(self):
        """The number of pixels in the window"""
        return torch.prod(self.pixel_shape)

    @property
    def end(self):
        """Tangent-plane vector from the window origin to its opposite corner."""
        return self.pixel_to_plane_delta(
            self.pixel_shape.to(dtype=self.pixelscale.dtype, device=self.pixelscale.device)
        )

    @property
    def origin(self):
        """Tangent-plane position of the window's lower pixel corner (pixel -0.5, -0.5)."""
        return self.pixel_to_plane(-0.5 * torch.ones_like(self.reference_imageij))

    @property
    def center(self):
        """Tangent-plane position of the window center."""
        return self.origin + self.end / 2

    def copy(self, **kwargs):
        """Return a copy of this window; ``kwargs`` override copied attributes."""
        copy_kwargs = {"pixel_shape": torch.clone(self.pixel_shape)}
        copy_kwargs.update(kwargs)
        return super().copy(**copy_kwargs)

    def to(self, dtype=None, device=None):
        """Move this window's tensors to ``dtype``/``device`` (AP_config defaults).

        ``pixel_shape`` stays int32; its setter re-applies that dtype.
        Returns ``self`` for chaining.
        """
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        super().to(dtype=dtype, device=device)
        self.pixel_shape = self.pixel_shape.to(dtype=dtype, device=device)
        return self

    def rescale_pixel(self, scale, **kwargs):
        """Return a copy whose pixels are ``scale`` times larger on each side.

        The pixel count is floor-divided by ``scale``. The reference pixel is
        remapped so the lower window corner (pixel -0.5) stays fixed.
        """
        return self.copy(
            pixelscale=self.pixelscale * scale,
            pixel_shape=self.pixel_shape // scale,
            reference_imageij=(self.reference_imageij + 0.5) / scale - 0.5,
            **kwargs,
        )

    @staticmethod
    @torch.no_grad()
    def _get_indices(ref_window, obj_window):
        """Index slices into ``ref_window``'s pixel array covering ``obj_window``.

        The slices are clipped to ``ref_window``'s bounds and returned in
        array order (second pixel axis first).
        """
        other_origin_pix = torch.round(ref_window.plane_to_pixel(obj_window.origin) + 0.5).int()
        new_origin_pix = torch.maximum(torch.zeros_like(other_origin_pix), other_origin_pix)
        other_pixel_end = torch.round(
            ref_window.plane_to_pixel(obj_window.origin + obj_window.end) + 0.5
        ).int()
        new_pixel_end = torch.minimum(ref_window.pixel_shape, other_pixel_end)
        return slice(new_origin_pix[1], new_pixel_end[1]), slice(
            new_origin_pix[0], new_pixel_end[0]
        )

    def get_self_indices(self, obj):
        """Return index slices into THIS window's pixels covering ``obj``.

        ``obj`` may be a Window or anything with a ``.window`` attribute.
        """
        if isinstance(obj, Window):
            return self._get_indices(self, obj)
        return self._get_indices(self, obj.window)

    def get_other_indices(self, obj):
        """Return index slices into ``obj``'s pixels covering THIS window.

        ``obj`` may be a Window or anything with a ``.window`` attribute.
        """
        if isinstance(obj, Window):
            return self._get_indices(obj, self)
        return self._get_indices(obj.window, self)

    def overlap_frac(self, other):
        """Intersection over union of the tangent-plane areas of two windows.

        Returns 0 when the windows do not overlap.
        """
        overlap = self & other
        if torch.any(overlap.pixel_shape <= 0):
            # Disjoint windows give a non-positive pixel_shape, and ``shape``
            # (a norm) would otherwise turn that into a positive area.
            overlap_area = torch.zeros(
                (), dtype=self.pixelscale.dtype, device=self.pixelscale.device
            )
        else:
            overlap_area = torch.prod(overlap.shape)
        full_area = torch.prod(self.shape) + torch.prod(other.shape) - overlap_area
        return overlap_area / full_area

    def shift(self, shift):
        """
        Shift the location of the window by a specified amount in tangent plane coordinates
        """
        self.reference_imagexy = self.reference_imagexy + shift
        return self

    def pixel_shift(self, shift):
        """
        Shift the location of the window by a specified amount in pixel grid coordinates
        """
        self.reference_imageij = self.reference_imageij - shift
        return self

    def get_astropywcs(self, **kwargs):
        """Build an astropy WCS (TAN projection) describing this window.

        ``kwargs`` override any of the generated header keywords.
        """
        # pixelscale maps pixels to tangent-plane arcsec, but FITS CD
        # matrices are expressed in degrees per pixel.
        arcsec_to_deg = 1.0 / 3600.0
        ref_radec = self.pixel_to_world(self.reference_imageij)
        wargs = {
            "NAXIS": 2,
            "NAXIS1": self.pixel_shape[0].item(),
            "NAXIS2": self.pixel_shape[1].item(),
            "CTYPE1": "RA---TAN",
            "CTYPE2": "DEC--TAN",
            "CRVAL1": ref_radec[0].item(),
            "CRVAL2": ref_radec[1].item(),
            # NOTE(review): FITS CRPIX is 1-based; reference_imageij is
            # written unchanged. Confirm against how WCS reads wcs.wcs.crpix.
            "CRPIX1": self.reference_imageij[0].item(),
            "CRPIX2": self.reference_imageij[1].item(),
            "CD1_1": self.pixelscale[0][0].item() * arcsec_to_deg,
            "CD1_2": self.pixelscale[0][1].item() * arcsec_to_deg,
            "CD2_1": self.pixelscale[1][0].item() * arcsec_to_deg,
            "CD2_2": self.pixelscale[1][1].item() * arcsec_to_deg,
        }
        wargs.update(kwargs)
        return AstropyWCS(wargs)

    def get_state(self):
        """Return a plain-python state dict (parent WCS state plus ``pixel_shape``)."""
        state = super().get_state()
        state["pixel_shape"] = self.pixel_shape.detach().cpu().tolist()
        return state

    def set_state(self, state):
        """Restore from a dict produced by :meth:`get_state`."""
        super().set_state(state)
        self.pixel_shape = torch.tensor(
            state["pixel_shape"], dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )

    def get_fits_state(self):
        """Return header-style state; ``PXL_SHPE`` holds the pixel shape as a list string."""
        state = super().get_fits_state()
        state["PXL_SHPE"] = str(self.pixel_shape.detach().cpu().tolist())
        return state

    def set_fits_state(self, state):
        """Restore from header-style state produced by :meth:`get_fits_state`."""
        super().set_fits_state(state)
        # literal_eval parses only Python literals, so a crafted FITS header
        # cannot execute code (the previous ``eval`` could).
        self.pixel_shape = torch.tensor(
            ast.literal_eval(state["PXL_SHPE"]),
            dtype=AP_config.ap_dtype,
            device=AP_config.ap_device,
        )

    def crop_pixel(self, pixels):
        """Crop the window in place by whole pixels.

        Formats: [crop all sides] or [crop x, crop y] or
        [crop x low, crop y low, crop x high, crop y high].
        """
        if len(pixels) == 1:
            self.pixel_shape = self.pixel_shape - 2 * pixels[0]
            self.reference_imageij = self.reference_imageij - pixels[0]
        elif len(pixels) == 2:
            pix_shift = torch.as_tensor(
                pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            self.pixel_shape = self.pixel_shape - 2 * pix_shift
            self.reference_imageij = self.reference_imageij - pix_shift
        elif len(pixels) == 4:
            # different crop on all sides
            pixels = torch.as_tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
            self.pixel_shape = self.pixel_shape - pixels[:2] - pixels[2:]
            self.reference_imageij = self.reference_imageij - pixels[:2]
        else:
            raise ValueError(f"Unrecognized pixel crop format: {pixels}")
        return self

    def crop_to_pixel(self, pixels):
        """Crop the window in place to a pixel range.

        Format: [[xmin, xmax], [ymin, ymax]]; the new pixel_shape is
        ``max - min`` along each axis.
        """
        pixels = torch.tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        self.reference_imageij = self.reference_imageij - pixels[:, 0]
        self.pixel_shape = pixels[:, 1] - pixels[:, 0]
        return self

    def pad_pixel(self, pixels):
        """Pad the window in place by whole pixels.

        Formats: [pad all sides] or [pad x, pad y] or
        [pad x low, pad y low, pad x high, pad y high].
        """
        if len(pixels) == 1:
            self.pixel_shape = self.pixel_shape + 2 * pixels[0]
            self.reference_imageij = self.reference_imageij + pixels[0]
        elif len(pixels) == 2:
            pix_shift = torch.as_tensor(
                pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            self.pixel_shape = self.pixel_shape + 2 * pix_shift
            self.reference_imageij = self.reference_imageij + pix_shift
        elif len(pixels) == 4:
            # different pad on all sides
            pixels = torch.as_tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
            self.pixel_shape = self.pixel_shape + pixels[:2] + pixels[2:]
            self.reference_imageij = self.reference_imageij + pixels[:2]
        else:
            raise ValueError(f"Unrecognized pixel pad format: {pixels}")
        return self

    def _plane_meshgrid(self, xsteps, ysteps):
        """Map 1D pixel-coordinate steps to a stacked tangent-plane meshgrid."""
        meshx, meshy = torch.meshgrid(xsteps, ysteps, indexing="xy")
        return torch.stack(self.pixel_to_plane(meshx, meshy))

    @torch.no_grad()
    def get_coordinate_meshgrid(self):
        """Returns a meshgrid with tangent plane coordinates for the center
        of every pixel.

        """
        pix = self.pixel_shape.to(dtype=AP_config.ap_dtype)
        xsteps = torch.arange(pix[0], dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        ysteps = torch.arange(pix[1], dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        return self._plane_meshgrid(xsteps, ysteps)

    @torch.no_grad()
    def get_coordinate_corner_meshgrid(self):
        """Returns a meshgrid with tangent plane coordinates for the corners
        of every pixel.

        """
        pix = self.pixel_shape.to(dtype=AP_config.ap_dtype)
        # Pixel edges lie half a pixel either side of the integer centers.
        xsteps = (
            torch.arange(pix[0] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - 0.5
        )
        ysteps = (
            torch.arange(pix[1] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - 0.5
        )
        return self._plane_meshgrid(xsteps, ysteps)

    @torch.no_grad()
    def get_coordinate_simps_meshgrid(self):
        """Returns a meshgrid with tangent plane coordinates for performing
        simpsons method pixel integration (all corners, centers, and
        middle of each edge). This is approximately 4 times more points
        than the standard :meth:`get_coordinate_meshgrid`.

        """
        pix = self.pixel_shape.to(dtype=AP_config.ap_dtype)
        # Half-pixel steps from -0.5 to pixel_shape - 0.5 inclusive.
        xsteps = (
            0.5
            * torch.arange(2 * pix[0] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
            - 0.5
        )
        ysteps = (
            0.5
            * torch.arange(2 * pix[1] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
            - 0.5
        )
        return self._plane_meshgrid(xsteps, ysteps)

    # Window Comparison operators
    @torch.no_grad()
    def __eq__(self, other):
        return (
            torch.all(self.pixel_shape == other.pixel_shape)
            and torch.all(self.pixelscale == other.pixelscale)
            and (self.projection == other.projection)
            and (
                torch.all(
                    self.pixel_to_plane(torch.zeros_like(self.reference_imageij))
                    == other.pixel_to_plane(torch.zeros_like(other.reference_imageij))
                )
            )
        )  # fixme more checks?

    @torch.no_grad()
    def __ne__(self, other):
        return not self == other

    # Window interaction operators
    @torch.no_grad()
    def _self_and_other_pixel_bounds(self, other):
        """Lower/upper corners of both windows in THIS window's pixel coordinates.

        This window spans ``[-0.5, pixel_shape - 0.5]`` (pixel edges).
        """
        other_origin_pix = self.plane_to_pixel(other.origin)
        other_pixel_end = self.plane_to_pixel(other.origin + other.end)
        self_origin_pix = -0.5 * torch.ones_like(other_origin_pix)
        self_pixel_end = self.pixel_shape.to(dtype=AP_config.ap_dtype) - 0.5
        return self_origin_pix, self_pixel_end, other_origin_pix, other_pixel_end

    @torch.no_grad()
    def _union_pixel_bounds(self, other):
        """Corners (this window's pixel coords) of the box covering both windows."""
        s_lo, s_hi, o_lo, o_hi = self._self_and_other_pixel_bounds(other)
        return torch.minimum(s_lo, o_lo), torch.maximum(s_hi, o_hi)

    @torch.no_grad()
    def _intersection_pixel_bounds(self, other):
        """Corners (this window's pixel coords) of the region shared by both windows."""
        s_lo, s_hi, o_lo, o_hi = self._self_and_other_pixel_bounds(other)
        return torch.maximum(s_lo, o_lo), torch.minimum(s_hi, o_hi)

    @torch.no_grad()
    def _apply_pixel_bounds(self, new_origin_pix, new_pixel_end):
        """Re-anchor this window in place to the given pixel-coordinate corners."""
        # Shift the pixel grid so the new lower corner sits at pixel -0.5.
        self.reference_imageij = self.reference_imageij - (new_origin_pix + 0.5)
        self.pixel_shape = new_pixel_end - new_origin_pix
        return self

    @torch.no_grad()
    def __or__(self, other):
        new_origin_pix, new_pixel_end = self._union_pixel_bounds(other)
        return self.copy(
            origin=self.pixel_to_plane(new_origin_pix),
            pixel_shape=new_pixel_end - new_origin_pix,
        )

    @torch.no_grad()
    def __ior__(self, other):
        new_origin_pix, new_pixel_end = self._union_pixel_bounds(other)
        return self._apply_pixel_bounds(new_origin_pix, new_pixel_end)

    @torch.no_grad()
    def __and__(self, other):
        new_origin_pix, new_pixel_end = self._intersection_pixel_bounds(other)
        return self.copy(
            origin=self.pixel_to_plane(new_origin_pix),
            pixel_shape=new_pixel_end - new_origin_pix,
        )

    @torch.no_grad()
    def __iand__(self, other):
        new_origin_pix, new_pixel_end = self._intersection_pixel_bounds(other)
        return self._apply_pixel_bounds(new_origin_pix, new_pixel_end)

    def __str__(self):
        return f"window origin: {self.origin.detach().cpu().tolist()}, shape: {self.shape.detach().cpu().tolist()}, center: {self.center.detach().cpu().tolist()}, pixelscale: {self.pixelscale.detach().cpu().tolist()}"

    def __repr__(self):
        return (
            f"window pixel_shape: {self.pixel_shape.detach().cpu().tolist()}, shape: {self.shape.detach().cpu().tolist()}\n"
            + super().__repr__()
        )
class Window_List(Window):
    """A list of :class:`Window` objects that must share one tangent plane.

    Geometric properties (``origin``, ``shape``, ``center``) return one entry
    per window. The ``|``, ``&`` and comparison operators act pairwise on
    windows of two lists.

    Args:
      window_list : iterable of Window or None, optional
        The windows to hold; ``None`` entries are skipped by :meth:`check_wcs`.
      state : list or None, optional
        A state as produced by :meth:`get_state`; overrides ``window_list``.

    Raises:
      ConflicingWCS: if the windows do not share reference coordinates and
        projection.
    """

    def __init__(self, window_list=None, state=None):
        if state is not None:
            self.set_state(state)
        else:
            self.window_list = list(window_list) if window_list is not None else []
        self.check_wcs()

    def check_wcs(self):
        """Ensure the WCS systems being used by all the windows in this list
        are consistent with each other. They should all project world
        coordinates onto the same tangent plane.

        """
        windows = [W for W in self.window_list if W is not None]
        if not windows:
            return
        ref = torch.stack(tuple(W.reference_radec for W in windows))
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (world) coordinate mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        ref = torch.stack(tuple(W.reference_planexy for W in windows))
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (tangent plane) coordinate mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        if len(set(W.projection for W in windows)) > 1:
            raise ConflicingWCS(
                "Projection mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )

    @property
    @torch.no_grad()
    def origin(self):
        """Tuple of each window's tangent-plane origin."""
        return tuple(w.origin for w in self)

    @property
    @torch.no_grad()
    def shape(self):
        """Tuple of each window's tangent-plane side lengths."""
        return tuple(w.shape for w in self)

    @property
    @torch.no_grad()
    def center(self):
        """Tuple of each window's tangent-plane center."""
        return tuple(w.center for w in self)

    def shift_origin(self, shift):
        raise NotImplementedError("shift origin not implemented for window list")

    def copy(self):
        """Return a new list holding copies of every window."""
        return self.__class__([w.copy() for w in self])

    def to(self, dtype=None, device=None):
        """Move every window to ``dtype``/``device`` (AP_config defaults); returns ``self``."""
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        for window in self:
            window.to(dtype, device)
        return self

    def get_state(self):
        """Return a list of each window's state dict."""
        return [window.get_state() for window in self]

    def set_state(self, state):
        """Rebuild the windows from a list of window state dicts."""
        self.window_list = [Window(state=st) for st in state]

    # Window interaction operators
    @torch.no_grad()
    def __or__(self, other):
        new_windows = [sw | ow for sw, ow in zip(self, other)]
        return self.__class__(window_list=new_windows)

    @torch.no_grad()
    def __ior__(self, other):
        for sw, ow in zip(self, other):
            sw |= ow
        return self

    @torch.no_grad()
    def __and__(self, other):
        new_windows = [sw & ow for sw, ow in zip(self, other)]
        return self.__class__(window_list=new_windows)

    @torch.no_grad()
    def __iand__(self, other):
        for sw, ow in zip(self, other):
            sw &= ow
        return self

    # Window Comparison operators
    @torch.no_grad()
    def __eq__(self, other):
        """True (as a 0-d bool tensor) when lists have equal length and windows match pairwise."""
        if len(self) != len(other):
            # zip would silently ignore the extra windows
            return torch.tensor(False)
        # Window.__eq__ may return a python bool (short-circuited ``and``) or a
        # tensor, so normalise with bool() instead of calling ``.view``.
        results = torch.tensor([bool(sw == ow) for sw, ow in zip(self, other)], dtype=torch.bool)
        return torch.all(results)

    @torch.no_grad()
    def __ne__(self, other):
        return not self == other

    def __len__(self):
        return len(self.window_list)

    def __iter__(self):
        return iter(self.window_list)

    def __str__(self):
        return "Window List: \n" + ("\n".join(str(window) for window in self) + "\n")

    def __repr__(self):
        return "Window List: \n" + ("\n".join(repr(window) for window in self) + "\n")