from typing import List, Optional, Union
import torch
import numpy as np
from torch.nn.functional import avg_pool2d
from .image_object import Image, Image_List
from .image_header import Image_Header
from .model_image import Model_Image
from .jacobian_image import Jacobian_Image
from astropy.io import fits
from .. import AP_config
from ..errors import SpecificationConflict, InvalidData
__all__ = ["PSF_Image"]
[docs]
class PSF_Image(Image):
    """Image subclass that stores a model of a point spread function (PSF).

    A point spread function describes how an imaging system responds to a
    point source. The PSF pixel data must have an odd length along every
    axis; ``set_data`` raises if it does not. On construction the window
    reference points are reset: the ``radec``, ``planexy`` and ``imagexy``
    references become ``(0, 0)`` and the ``imageij`` reference becomes
    ``(shape - 1) / 2`` with the axis order reversed.

    Attributes:
        data (torch.Tensor): The image data of the PSF.
        identity (str): The identity of the image, passed on as
            ``target_identity`` to the images built by ``jacobian_image``
            and ``model_image``.
        has_mask (bool): Class-level flag, set to False here.
        has_variance (bool): Class-level flag, set to False here.

    Methods defined on this class include:
        set_data: Stores new pixel data and checks that the shape is odd.
        normalize: Rescales the pixel data in place to sum to 1.
        psf_border_int: Integer border size used for padding before convolution.
        jacobian_image: Builds a ``Jacobian_Image`` shaped like this PSF.
        model_image: Builds a ``Model_Image`` shaped like this PSF.
        get_fits_state / set_fits_state: Convert to and from a FITS state list.
    """

    has_mask = False
    has_variance = False

    def __init__(self, *args, **kwargs):
        """Build the image through the parent class, then reset the window references.

        Args:
            *args: Positional arguments forwarded unchanged to the parent constructor.
            **kwargs: Keyword arguments forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Midpoint of the pixel grid in index units; np.flip reverses the
        # axis order relative to ``self.data.shape``.
        last_index = np.array(self.data.shape, dtype=float) - 1.0
        grid_midpoint = np.flip(last_index) / 2

        self.window.reference_radec = (0, 0)
        self.window.reference_planexy = (0, 0)
        self.window.reference_imageij = grid_midpoint
        self.window.reference_imagexy = (0, 0)
[docs]
    def set_data(
        self, data: Union[torch.Tensor, np.ndarray], require_shape: bool = True
    ):
        """Store new PSF pixel data, then validate what was stored.

        The arguments are handed to the parent class ``set_data`` first; the
        checks below run on ``self.data`` afterwards.

        Args:
            data: New pixel values for the PSF.
            require_shape: Forwarded unchanged to the parent ``set_data``.

        Raises:
            SpecificationConflict: If any axis of the stored data has an
                even length.
        """
        super().set_data(data=data, require_shape=require_shape)

        # Every axis length must be odd.
        if any(axis_length % 2 != 1 for axis_length in self.data.shape):
            raise SpecificationConflict(f"psf must have odd shape, not {self.data.shape}")

        # Negative pixels are accepted, but reported through the logger.
        negative_pixels = self.data < 0
        if torch.any(negative_pixels):
            AP_config.ap_logger.warning("psf data should be non-negative")
[docs]
    def normalize(self):
        """Rescale the PSF pixel data in place so that its values sum to 1."""
        total = torch.sum(self.data)
        # In-place division keeps the same underlying tensor object.
        self.data /= total
    @property
    def psf_border_int(self):
        """Integer border size of the PSF, used for padding before convolution.

        For each axis length ``n`` of ``self.data`` the border is
        ``ceil((1 + n) / 2)``. The result lists the axes in reverse order
        relative to ``self.data.shape``.

        Returns:
            torch.Tensor: Integer tensor with one border size per axis,
            created on ``AP_config.ap_device``.
        """
        axis_lengths = torch.tensor(
            self.data.shape,
            dtype=AP_config.ap_dtype,
            device=AP_config.ap_device,
        )
        reversed_lengths = torch.flip(axis_lengths, (0,))
        half_extent = (1 + reversed_lengths) / 2
        return torch.ceil(half_extent).int()
[docs]
    def _save_image_list(self, image_list):
        """Append this PSF to ``image_list`` as a FITS image HDU.

        The HDU header comes from ``self.header._save_image_list()`` with its
        ``"IMAGE"`` entry set to ``"PSF"``. The pixel data is detached, moved
        to the CPU and converted to a numpy array before being stored.

        Args:
            image_list (list): List that receives the new HDU; it is
                modified in place.
        """
        psf_header = self.header._save_image_list()
        psf_header["IMAGE"] = "PSF"

        pixels = self.data.detach().cpu().numpy()
        psf_hdu = fits.ImageHDU(pixels, header=psf_header)
        image_list.append(psf_hdu)
[docs]
    def jacobian_image(
        self,
        parameters: Optional[List[str]] = None,
        data: Optional[torch.Tensor] = None,
        **kwargs,
    ):
        """Construct a `Jacobian_Image` formatted like this `PSF_Image`. Mostly used internally.

        Args:
            parameters: Names of the parameters the jacobian refers to. When
                None, an empty parameter list is used and any ``data``
                argument is discarded.
            data: Jacobian values. When None (and ``parameters`` is given), a
                zero tensor of shape ``(*self.data.shape, len(parameters))``
                is created with the configured dtype and device.
            **kwargs: Extra keyword arguments forwarded to `Jacobian_Image`.

        Returns:
            Jacobian_Image: A new object sharing this image's header and
            carrying this image's identity as ``target_identity``.
        """
        if parameters is None:
            # Without parameters there is nothing to hold values for.
            parameters, data = [], None
        elif data is None:
            blank_shape = (*self.data.shape, len(parameters))
            data = torch.zeros(
                blank_shape,
                dtype=AP_config.ap_dtype,
                device=AP_config.ap_device,
            )

        return Jacobian_Image(
            parameters=parameters,
            target_identity=self.identity,
            data=data,
            header=self.header,
            **kwargs,
        )
    
[docs]
    def model_image(self, data: Optional[torch.Tensor] = None, **kwargs):
        """Construct a `Model_Image` formatted like this `PSF_Image`. Mostly used internally.

        Args:
            data: Pixel values for the new image. When None, a zero tensor
                matching ``self.data`` is used.
            **kwargs: Extra keyword arguments forwarded to `Model_Image`.

        Returns:
            Model_Image: A new object sharing this image's header and
            carrying this image's identity as ``target_identity``.
        """
        if data is None:
            model_data = torch.zeros_like(self.data)
        else:
            model_data = data

        return Model_Image(
            data=model_data,
            header=self.header,
            target_identity=self.identity,
            **kwargs,
        )
[docs]
    def expand(self, padding):
        """Reject expansion, which is not available for PSF images.

        Args:
            padding: Ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = "expand not available for PSF_Image"
        raise NotImplementedError(message)
    
[docs]
    def get_fits_state(self):
        """Return the PSF as a one-element list of FITS state dicts.

        Returns:
            list: A single dict with two keys. ``"DATA"`` holds the pixel
            data as a numpy array (detached and moved to the CPU).
            ``"HEADER"`` holds the result of ``self.header.get_fits_state()``
            with its ``"IMAGE"`` entry set to ``"PSF"``.
        """
        pixels = self.data.detach().cpu().numpy()

        # Tag the header state so that set_fits_state can recognise this entry.
        header_state = self.header.get_fits_state()
        header_state["IMAGE"] = "PSF"

        return [{"DATA": pixels, "HEADER": header_state}]
[docs]
    def set_fits_state(self, states):
        """Restore data and header from the first PSF entry in ``states``.

        Entries are examined in order. The first one whose
        ``["HEADER"]["IMAGE"]`` equals ``"PSF"`` is used and later entries
        are not looked at. If no entry matches, nothing is changed.

        Args:
            states: Iterable of state dicts, each with a ``"HEADER"`` mapping
                containing an ``"IMAGE"`` key, and a ``"DATA"`` entry on the
                PSF one.
        """
        # The generator is lazy, so the scan stops at the first match.
        psf_state = next(
            (entry for entry in states if entry["HEADER"]["IMAGE"] == "PSF"),
            None,
        )
        if psf_state is None:
            return

        pixels = np.array(psf_state["DATA"], dtype=np.float64)
        self.set_data(pixels, require_shape=False)
        self.header = Image_Header(fits_state=psf_state["HEADER"])