Source code for astrophot.image.psf_image

from typing import List, Optional, Union

import torch
import numpy as np
from torch.nn.functional import avg_pool2d

from .image_object import Image, Image_List
from .image_header import Image_Header
from .model_image import Model_Image
from .jacobian_image import Jacobian_Image
from astropy.io import fits
from .. import AP_config
from ..errors import SpecificationConflict, InvalidData

__all__ = ["PSF_Image"]

[docs] class PSF_Image(Image): """Image object which represents a model of PSF (Point Spread Function). PSF_Image inherits from the base Image class and represents the model of a point spread function. The point spread function characterizes the response of an imaging system to a point source or point object. The shape of the PSF data must be odd. Attributes: data (torch.Tensor): The image data of the PSF. identity (str): The identity of the image. Default is None. Methods: psf_border_int: Calculates and returns the convolution border size of the PSF image in integer format. psf_border: Calculates and returns the convolution border size of the PSF image in the units of pixelscale. _save_image_list: Saves the image list to the PSF HDU header. reduce: Reduces the size of the image using a given scale factor. """ has_mask = False has_variance = False def __init__(self, *args, **kwargs): """ Initializes the PSF_Image class. Args: *args: Variable length argument list. **kwargs: Arbitrary keyword arguments. band (str, optional): The band of the image. Default is None. """ super().__init__(*args, **kwargs) self.window.reference_radec = (0,0) self.window.reference_planexy = (0,0) self.window.reference_imageij = np.flip(np.array(self.data.shape, dtype = float) - 1.) / 2 self.window.reference_imagexy = (0,0)
[docs] def set_data( self, data: Union[torch.Tensor, np.ndarray], require_shape: bool = True ): super().set_data(data = data, require_shape = require_shape) if torch.any( (torch.tensor(self.data.shape) % 2) != 1 ): raise SpecificationConflict(f"psf must have odd shape, not {self.data.shape}") if torch.any(self.data < 0): AP_config.ap_logger.warning("psf data should be non-negative")
[docs] def normalize(self): """Normalizes the PSF image to have a sum of 1.""" self.data /= torch.sum(self.data)
@property def psf_border_int(self): """Calculates and returns the border size of the PSF image in integer format. This is the border used for padding before convolution. Returns: torch.Tensor: The border size of the PSF image in integer format. """ return torch.ceil( ( 1 + torch.flip( torch.tensor( self.data.shape, dtype=AP_config.ap_dtype, device=AP_config.ap_device, ), (0,), ) ) / 2 ).int()
[docs] def _save_image_list(self, image_list): """Saves the image list to the PSF HDU header. Args: image_list (list): The list of images to be saved. psf_header (astropy.io.fits.Header): The header of the PSF HDU. """ img_header = self.header._save_image_list() img_header["IMAGE"] = "PSF" image_list.append( fits.ImageHDU(self.data.detach().cpu().numpy(), header=img_header) )
[docs] def jacobian_image( self, parameters: Optional[List[str]] = None, data: Optional[torch.Tensor] = None, **kwargs, ): """ Construct a blank `Jacobian_Image` object formatted like this current `PSF_Image` object. Mostly used internally. """ if parameters is None: data = None parameters = [] elif data is None: data = torch.zeros( (*self.data.shape, len(parameters)), dtype=AP_config.ap_dtype, device=AP_config.ap_device, ) return Jacobian_Image( parameters=parameters, target_identity=self.identity, data=data, header=self.header, **kwargs, )
[docs] def model_image(self, data: Optional[torch.Tensor] = None, **kwargs): """ Construct a blank `Model_Image` object formatted like this current `Target_Image` object. Mostly used internally. """ return Model_Image( data=torch.zeros_like(self.data) if data is None else data, header=self.header, target_identity=self.identity, **kwargs, )
[docs] def expand(self, padding): raise NotImplementedError("expand not available for PSF_Image")
[docs] def get_fits_state(self): states = [{}] states[0]["DATA"] = self.data.detach().cpu().numpy() states[0]["HEADER"] = self.header.get_fits_state() states[0]["HEADER"]["IMAGE"] = "PSF" return states
[docs] def set_fits_state(self, states): for state in states: if state["HEADER"]["IMAGE"] == "PSF": self.set_data(np.array(state["DATA"], dtype=np.float64), require_shape=False) self.header = Image_Header(fits_state = state["HEADER"]) break