from typing import Optional, Union, Any, Sequence, Tuple
from copy import deepcopy
import torch
from torch.nn.functional import pad
import numpy as np
from astropy.io import fits
from astropy.wcs import WCS as AstropyWCS
from .window_object import Window, Window_List
from .image_header import Image_Header
from .. import AP_config
from ..errors import SpecificationConflict, ConflicingWCS, InvalidData, InvalidWindow
# Public names exported by ``from <this module> import *``.
__all__ = [
    "Image",
    "Image_List",
]
class Image(object):
    """Core class to represent images with pixel values, pixel scale,
       and a window defining the spatial coordinates on the sky.
       It supports arithmetic operations with other image objects while preserving logical image boundaries.
       It also provides methods for determining the coordinate locations of pixels.

    Parameters:
        data: the matrix of pixel values for the image
        pixelscale: the length of one side of a pixel in arcsec/pixel
        window: an AstroPhot Window object which defines the spatial coordinates on the sky
        filename: a filename from which to load the image.
        zeropoint: photometric zero point for converting from pixel flux to magnitude
        metadata: Any information the user wishes to associate with this image, stored in a python dictionary
        origin: The origin of the image in the coordinate system.
    """

    def __init__(
        self,
        *,
        data: Optional[torch.Tensor] = None,
        header: Optional[Image_Header] = None,
        wcs: Optional[AstropyWCS] = None,
        pixelscale: Optional[Union[float, torch.Tensor]] = None,
        window: Optional[Window] = None,
        filename: Optional[str] = None,
        zeropoint: Optional[Union[float, torch.Tensor]] = None,
        metadata: Optional[dict] = None,
        origin: Optional[Sequence] = None,
        center: Optional[Sequence] = None,
        identity: Optional[str] = None,
        state: Optional[dict] = None,
        fits_state: Optional[Sequence[dict]] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize an instance of the Image class.

        The header is taken from the first available source, checked in
        this order: ``state``, ``fits_state``, ``header``, and finally the
        individual header arguments (``wcs``, ``pixelscale``, ``window``, ...).

        Parameters:
        -----------
        data : torch.Tensor, numpy.ndarray or None, optional
            The image data. If None (and no file or state is given) a zero
            image matching the window is allocated. Default is None.
        header : Image_Header or None, optional
            A pre-built header. When given, the individual header arguments
            below are not used. Default is None.
        wcs : astropy.wcs.wcs.WCS or None, optional
            A WCS object which defines a coordinate system for the image. Note that AstroPhot only handles basic WCS conventions. It will use the WCS object to get `wcs.pixel_to_world(-0.5, -0.5)` to determine the position of the origin in world coordinates. It will also extract the `pixel_scale_matrix` to index pixels going forward.
        pixelscale : float or None, optional
            The physical scale of the pixels in the image, in units of arcseconds. Default is None.
        window : Window or None, optional
            A Window object defining the area of the image to use. Default is None.
        filename : str or None, optional
            The name of a FITS file containing the image data. Default is None.
        zeropoint : float or None, optional
            The image's zeropoint, used for flux calibration. Default is None.
        metadata : dict or None, optional
            Any information the user wishes to associate with this image, stored in a python dictionary. Default is None.
        origin : numpy.ndarray or None, optional
            The origin of the image in the coordinate system, as a 1D array of length 2. Default is None.
        center : numpy.ndarray or None, optional
            The center of the image in the coordinate system, as a 1D array of length 2. Default is None.
        identity : str or None, optional
            Identifier passed to the header (exposed as ``self.identity``). Default is None.
        state : dict or None, optional
            A state dictionary as produced by ``get_state``. Default is None.
        fits_state : list of dict or None, optional
            A list of ``{"DATA": ..., "HEADER": ...}`` entries as produced by
            ``get_fits_state``; the entry whose header has ``IMAGE == "PRIMARY"``
            supplies the data and header. Default is None.
        **kwargs
            Forwarded to ``Image_Header`` when the header is built here.

        Raises:
        -------
        InvalidData
            If no header source is given and none of ``data``, ``window`` or
            ``filename`` is provided.
        """
        self._data = None

        if state is not None:
            self.header = Image_Header(state=state["header"])
        elif fits_state is not None:
            # No header exists yet: set_fits_state builds one from the PRIMARY
            # entry (previously this path crashed on the missing attribute).
            self.header = None
            self.set_fits_state(fits_state)
            return
        elif header is None:
            if data is None and window is None and filename is None:
                raise InvalidData("Image must have either data or a window to construct itself.")
            self.header = Image_Header(
                data_shape=None if data is None else data.shape,
                pixelscale=pixelscale,
                wcs=wcs,
                window=window,
                filename=filename,
                zeropoint=zeropoint,
                metadata=metadata,
                origin=origin,
                center=center,
                identity=identity,
                **kwargs,
            )
        else:
            self.header = header

        if filename is not None:
            self.load(filename)
        elif state is not None:
            self.set_state(state)
        else:
            if data is None:
                # window.pixel_shape is ordered opposite to the tensor shape, so
                # flip it to get the (rows, columns) shape of the data tensor.
                self.data = torch.zeros(
                    torch.flip(self.window.pixel_shape, (0,)).detach().cpu().tolist(),
                    dtype=AP_config.ap_dtype,
                    device=AP_config.ap_device,
                )
            else:
                self.data = data
            self.to()
        # NOTE: the flipped data shape should equal self.window.pixel_shape. This
        # is not asserted here because checking it forces a GPU -> CPU sync.

    @property
    def north(self):
        """North direction, as reported by the header."""
        return self.header.north

    @property
    def pixel_area(self):
        """Area of a single pixel, as reported by the header."""
        return self.header.pixel_area

    @property
    def pixel_length(self):
        """Side length of a single pixel, as reported by the header."""
        return self.header.pixel_length

    def world_to_plane(self, *args, **kwargs):
        """Delegate to ``self.window.world_to_plane``."""
        return self.window.world_to_plane(*args, **kwargs)

    def plane_to_world(self, *args, **kwargs):
        """Delegate to ``self.window.plane_to_world``."""
        return self.window.plane_to_world(*args, **kwargs)

    def plane_to_pixel(self, *args, **kwargs):
        """Delegate to ``self.window.plane_to_pixel``."""
        return self.window.plane_to_pixel(*args, **kwargs)

    def pixel_to_plane(self, *args, **kwargs):
        """Delegate to ``self.window.pixel_to_plane``."""
        return self.window.pixel_to_plane(*args, **kwargs)

    def plane_to_pixel_delta(self, *args, **kwargs):
        """Delegate to ``self.window.plane_to_pixel_delta``."""
        return self.window.plane_to_pixel_delta(*args, **kwargs)

    def pixel_to_plane_delta(self, *args, **kwargs):
        """Delegate to ``self.window.pixel_to_plane_delta``."""
        return self.window.pixel_to_plane_delta(*args, **kwargs)

    def world_to_pixel(self, *args, **kwargs):
        """Delegate to ``self.window.world_to_pixel``."""
        return self.window.world_to_pixel(*args, **kwargs)

    def pixel_to_world(self, *args, **kwargs):
        """Delegate to ``self.window.pixel_to_world``."""
        return self.window.pixel_to_world(*args, **kwargs)

    @property
    def origin(self) -> torch.Tensor:
        """
        Returns the origin (bottom-left corner) of the image window.
        Returns:
            torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the origin.
        """
        return self.header.window.origin

    @property
    def shape(self) -> torch.Tensor:
        """
        Returns the shape (size) of the image window.
        Returns:
            torch.Tensor: A 1D tensor of shape (2,) containing the (width, height) of the window in pixels.
        """
        return self.header.window.shape

    @property
    def center(self) -> torch.Tensor:
        """
        Returns the center of the image window.
        Returns:
            torch.Tensor: A 1D tensor of shape (2,) containing the (x, y) coordinates of the center.
        """
        return self.header.window.center

    @property
    def size(self) -> torch.Tensor:
        """
        Returns the size of the image window, the number of pixels in the image.
        Returns:
            torch.Tensor: A 0D tensor containing the number of pixels.
        """
        return self.header.window.size

    @property
    def window(self):
        """The Window held by the header."""
        return self.header.window

    @property
    def pixelscale(self):
        """The pixel scale held by the header."""
        return self.header.pixelscale

    @property
    def zeropoint(self):
        """The photometric zeropoint held by the header."""
        return self.header.zeropoint

    @property
    def metadata(self):
        """The user metadata held by the header."""
        return self.header.metadata

    @property
    def identity(self):
        """The identity held by the header."""
        return self.header.identity

    @property
    def data(self) -> torch.Tensor:
        """
        Returns the image data.
        """
        return self._data

    @data.setter
    def data(self, data) -> None:
        """Set the image data (the shape must match the current data)."""
        self.set_data(data)

    def set_data(self, data: Union[torch.Tensor, np.ndarray, None], require_shape: bool = True):
        """
        Set the image data.

        The data is converted to ``AP_config.ap_dtype`` on ``AP_config.ap_device``
        before it is stored. ``None`` is stored as an empty tensor.

        Args:
            data (torch.Tensor, numpy.ndarray, nested sequence or None): The image data.
            require_shape (bool): Whether to check that the shape of the data is the same as the current data.
        Raises:
            SpecificationConflict: If `require_shape` is `True` and the shape of the data is different from the current data.
        """
        # Convert first so the shape check works for None / lists / arrays alike.
        if data is None:
            new_data = torch.tensor((), dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        elif isinstance(data, torch.Tensor):
            new_data = data.to(dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        else:
            new_data = torch.as_tensor(data, dtype=AP_config.ap_dtype, device=AP_config.ap_device)

        if self._data is not None and require_shape and new_data.shape != self._data.shape:
            raise SpecificationConflict(
                f"Attempting to change image data with tensor that has a different shape! ({new_data.shape} vs {self._data.shape}) Use 'require_shape = False' if this is desired behaviour."
            )
        self._data = new_data

    def copy(self, **kwargs):
        """Produce a copy of this image with all of the same properties. This
        can be used when one wishes to make temporary modifications to
        an image and then will want the original again.

        ``kwargs`` are forwarded to ``header.copy`` and to the constructor.
        """
        return self.__class__(
            data=torch.clone(self.data),
            header=self.header.copy(**kwargs),
            **kwargs,
        )

    def blank_copy(self, **kwargs):
        """Produces a blank copy of the image which has the same properties
        except that its data is now filled with zeros.

        ``kwargs`` are forwarded to ``header.copy`` and to the constructor.
        """
        return self.__class__(
            data=torch.zeros_like(self.data),
            header=self.header.copy(**kwargs),
            **kwargs,
        )

    def get_window(self, window, **kwargs):
        """Get a sub-region of the image as defined by a window on the sky.

        The returned image's data is an indexed slice of ``self.data``. For
        basic slicing this is a view, so in-place edits can propagate back.
        Call ``copy()`` on the result if independence is required.
        """
        return self.__class__(
            data=self.data[self.window.get_self_indices(window)],
            header=self.header.get_window(window, **kwargs),
            **kwargs,
        )

    def to(self, dtype=None, device=None):
        """Move data and header to ``dtype``/``device``.

        Both default to the values in ``AP_config``. Returns ``self``.
        """
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        if self._data is not None:
            self._data = self._data.to(dtype=dtype, device=device)
        self.header.to(dtype=dtype, device=device)
        return self

    def crop(self, pixels):
        """Remove pixels from the edges of the image, in place.

        Parameters:
            pixels: sequence of pixel counts to remove, in one of three forms:
                - ``[n]``: remove ``n`` from both ends of both axes;
                - ``[nx, ny]``: remove ``nx`` from both ends of the x (column)
                  axis and ``ny`` from both ends of the y (row) axis;
                - ``[x_low, x_high, y_low, y_high]``: a separate amount per side.
                Entries may be numbers or 0-d tensors.

        Returns:
            self, with cropped data and header.

        Raises:
            SpecificationConflict: if ``pixels`` does not have 1, 2 or 4 entries
            (checked before anything is modified).
        """
        n_rows, n_cols = self.data.shape[0], self.data.shape[1]
        if len(pixels) == 1:
            x_low = x_high = y_low = y_high = pixels[0]
        elif len(pixels) == 2:
            x_low = x_high = pixels[0]
            y_low = y_high = pixels[1]
        elif len(pixels) == 4:
            x_low, x_high, y_low, y_high = pixels
        else:
            raise SpecificationConflict(f"crop expects 1, 2 or 4 pixel values, not {len(pixels)}")
        # Data is indexed [row (y), column (x)]. int() truncates just like the
        # tensor .int() cast, and also accepts plain Python numbers.
        self.set_data(
            self.data[
                int(y_low) : int(n_rows - y_high),
                int(x_low) : int(n_cols - x_high),
            ],
            require_shape=False,
        )
        self.header = self.header.crop(pixels)
        return self

    def flatten(self, attribute: str = "data") -> torch.Tensor:
        """Return ``getattr(self, attribute)`` reshaped to 1D."""
        return getattr(self, attribute).reshape(-1)

    def get_coordinate_meshgrid(self):
        """Delegate to ``self.header.get_coordinate_meshgrid``."""
        return self.header.get_coordinate_meshgrid()

    def get_coordinate_corner_meshgrid(self):
        """Delegate to ``self.header.get_coordinate_corner_meshgrid``."""
        return self.header.get_coordinate_corner_meshgrid()

    def get_coordinate_simps_meshgrid(self):
        """Delegate to ``self.header.get_coordinate_simps_meshgrid``."""
        return self.header.get_coordinate_simps_meshgrid()

    def reduce(self, scale: int, **kwargs):
        """This operation will downsample an image by the factor given. If
        scale = 2 then 2x2 blocks of pixels will be summed together to
        form individual larger pixels. A new image object will be
        returned with the appropriate pixelscale and data tensor. Note
        that the window does not change in this operation since the
        pixels are condensed, but the pixel size is increased
        correspondingly.

        Rows/columns beyond the last full ``scale`` block are dropped from the data.

        Parameters:
            scale: factor by which to condense the image pixels. Each scale X scale region will be summed [int, or integer-dtype tensor]

        Returns:
            ``self`` when ``scale == 1``, otherwise a new image.

        Raises:
            SpecificationConflict: if ``scale`` is not a positive integer.
        """
        integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
        if isinstance(scale, torch.Tensor):
            if scale.dtype not in integer_dtypes:
                raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}")
            scale = int(scale)
        elif not isinstance(scale, int):
            raise SpecificationConflict(f"Reduce scale must be an integer! not {type(scale)}")
        if scale < 1:
            raise SpecificationConflict(f"Reduce scale must be a positive integer! not {scale}")
        if scale == 1:
            return self
        n_rows = self.data.shape[0] // scale
        n_cols = self.data.shape[1] // scale
        return self.__class__(
            data=self.data[: n_rows * scale, : n_cols * scale]
            .reshape(n_rows, scale, n_cols, scale)
            .sum(axis=(1, 3)),
            header=self.header.rescale_pixel(scale, **kwargs),
            **kwargs,
        )

    def expand(self, padding: Tuple[float]) -> None:
        """Pad the image with zeros, in place.

        Args:
          padding tuple[float]: length 4 tuple with amounts to pad each dimension in physical units.
            The amounts are converted to pixels by dividing by ``self.pixelscale``
            and rounding. They are used in ``torch.nn.functional.pad`` order,
            which pads the last (column) axis first.
            NOTE(review): assumes ``padding / self.pixelscale`` broadcasts to 4
            values; confirm for non-scalar pixel scales.

        Raises:
          SpecificationConflict: if any padding amount is negative.
        """
        padding = np.array(padding)
        if np.any(padding < 0):
            raise SpecificationConflict("negative padding not allowed in expand method")
        pad_boundaries = tuple(np.int64(np.round(padding / self.pixelscale)))
        # The padded tensor has a new shape, so the shape check must be disabled
        # (the `data` setter would otherwise raise for any non-zero padding).
        self.set_data(
            pad(self.data, pad=pad_boundaries, mode="constant", value=0),
            require_shape=False,
        )
        self.header.expand(padding)

    def get_state(self):
        """Return a JSON-friendly dict with the class name, data (as nested lists) and header state."""
        state = {}
        state["type"] = self.__class__.__name__
        state["data"] = self.data.detach().cpu().tolist()
        state["header"] = self.header.get_state()
        return state

    def set_state(self, state):
        """Restore data and header from a dict produced by ``get_state``."""
        self.set_data(state["data"], require_shape=False)
        self.header.set_state(state["header"])

    def get_fits_state(self):
        """Return a list with one ``{"DATA", "HEADER"}`` dict for the primary image.

        The header is tagged with ``IMAGE = "PRIMARY"`` so that
        ``set_fits_state`` can find it again.
        """
        states = [{}]
        states[0]["DATA"] = self.data.detach().cpu().numpy()
        states[0]["HEADER"] = self.header.get_fits_state()
        states[0]["HEADER"]["IMAGE"] = "PRIMARY"
        return states

    def set_fits_state(self, states):
        """Load data and header from the entry whose header has ``IMAGE == "PRIMARY"``.

        If this image has no header yet (e.g. while constructing from
        ``fits_state``), a header sized to the data is built first. It is then
        filled from the FITS header.

        Raises:
            InvalidData: if no entry is tagged as the primary image.
        """
        for state in states:
            # .get works for both plain dicts and astropy Header objects and
            # skips HDUs that lack the IMAGE keyword instead of raising KeyError.
            if state["HEADER"].get("IMAGE") == "PRIMARY":
                data = np.array(state["DATA"], dtype=np.float64)
                self.set_data(data, require_shape=False)
                if getattr(self, "header", None) is None:
                    self.header = Image_Header(data_shape=data.shape)
                self.header.set_fits_state(state["HEADER"])
                return
        raise InvalidData("No HDU with IMAGE = 'PRIMARY' found in the FITS state.")

    def save(self, filename=None, overwrite=True):
        """Build an ``HDUList`` from ``get_fits_state`` and write it if ``filename`` is given.

        The first state becomes the primary HDU and any further states become
        image extensions. Returns the ``HDUList``.
        """
        states = self.get_fits_state()
        img_list = [fits.PrimaryHDU(states[0]["DATA"], header=fits.Header(states[0]["HEADER"]))]
        for state in states[1:]:
            img_list.append(fits.ImageHDU(state["DATA"], header=fits.Header(state["HEADER"])))
        hdul = fits.HDUList(img_list)
        if filename is not None:
            hdul.writeto(filename, overwrite=overwrite)
        return hdul

    def load(self, filename):
        """Load image data and header from a FITS file written by ``save``."""
        # Use a context manager so the file handle is always released. The
        # state is applied inside the block, while the HDU data is still readable.
        with fits.open(filename) as hdul:
            states = [{"DATA": hdu.data, "HEADER": hdu.header} for hdu in hdul]
            self.set_fits_state(states)

    def __sub__(self, other):
        """Return a new image: ``self - other``.

        For an Image ``other`` the result covers ``self``'s region under
        ``other.window``. Otherwise ``other`` is subtracted from a full copy.
        """
        if isinstance(other, Image):
            # copy() is needed because the windowed slice may share memory with self.
            new_img = self[other.window].copy()
            new_img.data -= other.data[self.window.get_other_indices(other)]
            return new_img
        else:
            new_img = self.copy()
            new_img.data -= other
            return new_img

    def __add__(self, other):
        """Return a new image: ``self + other`` (same windowing rules as ``__sub__``)."""
        if isinstance(other, Image):
            new_img = self[other.window].copy()
            new_img.data += other.data[self.window.get_other_indices(other)]
            return new_img
        else:
            new_img = self.copy()
            new_img.data += other
            return new_img

    def __iadd__(self, other):
        """In place ``self += other``; for an Image only the overlapping pixels change."""
        if isinstance(other, Image):
            self.data[other.window.get_other_indices(self)] += other.data[
                self.window.get_other_indices(other)
            ]
        else:
            self.data += other
        return self

    def __isub__(self, other):
        """In place ``self -= other``; for an Image only the overlapping pixels change."""
        if isinstance(other, Image):
            self.data[other.window.get_other_indices(self)] -= other.data[
                self.window.get_other_indices(other)
            ]
        else:
            self.data -= other
        return self

    def __getitem__(self, *args):
        """``image[window]`` or ``image[other_image]`` returns the sub-image on that window."""
        if len(args) == 1 and isinstance(args[0], Window):
            return self.get_window(args[0])
        if len(args) == 1 and isinstance(args[0], Image):
            return self.get_window(args[0].window)
        raise ValueError("Unrecognized Image getitem request!")

    def __str__(self):
        return f"image pixelscale: {self.pixelscale.detach().cpu().numpy()} origin: {self.origin.detach().cpu().numpy()} shape: {self.shape.detach().cpu().numpy()}"

    def __repr__(self):
        return f"image pixelscale: {self.pixelscale.detach().cpu().numpy()} origin: {self.origin.detach().cpu().numpy()} shape: {self.shape.detach().cpu().numpy()} center: {self.center.detach().cpu().numpy()}\ndata: {self.data.detach().cpu().numpy()}"
class Image_List(Image):
    """A collection of Image objects that share one tangent plane.

    Most Image operations are applied to every contained image. Properties
    such as ``data``, ``pixelscale`` and ``zeropoint`` return a tuple with one
    entry per image.

    Parameters:
        image_list: iterable of Image objects.
        window: optional Window_List. When given, each image is cut down to the
            matching entry of ``window.window_list``.
    """

    def __init__(self, image_list, window=None):
        self.image_list = list(image_list)
        self.check_wcs()
        self.window = window

    def check_wcs(self):
        """Ensure the WCS systems being used by all the windows in this list
        are consistent with each other. They should all project world
        coordinates onto the same tangent plane.

        An empty list has nothing to compare and is accepted.

        Raises:
            ConflicingWCS: if reference coordinates or projections differ.
        """
        if not self.image_list:
            # torch.stack cannot stack an empty sequence; nothing to check anyway.
            return
        ref = torch.stack(tuple(I.window.reference_radec for I in self.image_list))
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (world) coordinate mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        ref = torch.stack(tuple(I.window.reference_planexy for I in self.image_list))
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (tangent plane) coordinate mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        if len(set(I.window.projection for I in self.image_list)) > 1:
            raise ConflicingWCS(
                "Projection mismatch! All images in Image_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )

    @property
    def window(self):
        """A Window_List built from the windows of the contained images."""
        return Window_List(list(image.window for image in self.image_list))

    @window.setter
    def window(self, window):
        """Cut each image down to its matching window; ``None`` leaves the images unchanged."""
        if window is None:
            return
        if not isinstance(window, Window_List):
            raise InvalidWindow("Target_List must take a Window_List object as its window")
        for i, image in enumerate(self.image_list):
            self.image_list[i] = image[window.window_list[i]]

    @property
    def pixelscale(self):
        """Tuple of pixel scales, one per image."""
        return tuple(image.pixelscale for image in self.image_list)

    @property
    def zeropoint(self):
        """Tuple of zeropoints, one per image."""
        return tuple(image.zeropoint for image in self.image_list)

    @property
    def data(self):
        """Tuple of data tensors, one per image."""
        return tuple(image.data for image in self.image_list)

    @data.setter
    def data(self, data):
        """Assign one data entry per image (pairs with zip, so extra entries are ignored)."""
        for image, dat in zip(self.image_list, data):
            image.data = dat

    def copy(self, **kwargs):
        """Return a new list holding ``copy(**kwargs)`` of every image."""
        return self.__class__(
            tuple(image.copy(**kwargs) for image in self.image_list),
        )

    def blank_copy(self, **kwargs):
        """Return a new list holding ``blank_copy(**kwargs)`` of every image."""
        return self.__class__(
            tuple(image.blank_copy(**kwargs) for image in self.image_list),
        )

    def get_window(self, window):
        """Return a new list with each image cut to the matching window in ``window``."""
        return self.__class__(
            tuple(image[win] for image, win in zip(self.image_list, window)),
        )

    def index(self, other):
        """Return the position of the image whose ``identity`` matches ``other``'s.

        Raises:
            ValueError: if no image has a matching identity.
            NotImplementedError: if ``other`` is not an Image.
        """
        if isinstance(other, Image) and hasattr(other, "identity"):
            for i, self_image in enumerate(self.image_list):
                if other.identity == self_image.identity:
                    return i
            raise ValueError("Could not find identity match between image list and input image")
        raise NotImplementedError(f"Image_List cannot get index for {type(other)}")

    def to(self, dtype=None, device=None):
        """Move every image to ``dtype``/``device``.

        Both default to the values in ``AP_config``. Returns ``self``.
        """
        # Only fill in defaults when the caller did not choose a value
        # (the checks were previously inverted and overrode explicit arguments).
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        for image in self.image_list:
            image.to(dtype=dtype, device=device)
        return self

    def crop(self, *pixels):
        raise NotImplementedError("Crop function not available for Image_List object")

    def get_coordinate_meshgrid(self):
        """Tuple of coordinate meshgrids, one per image."""
        return tuple(image.get_coordinate_meshgrid() for image in self.image_list)

    def get_coordinate_corner_meshgrid(self):
        """Tuple of corner coordinate meshgrids, one per image."""
        return tuple(image.get_coordinate_corner_meshgrid() for image in self.image_list)

    def get_coordinate_simps_meshgrid(self):
        """Tuple of Simpson-rule coordinate meshgrids, one per image."""
        return tuple(image.get_coordinate_simps_meshgrid() for image in self.image_list)

    def flatten(self, attribute="data"):
        """Concatenate the flattened ``attribute`` of every image into one tensor."""
        return torch.cat(tuple(image.flatten(attribute) for image in self.image_list))

    def reduce(self, scale):
        """Return ``self`` for ``scale == 1``, otherwise a new list of reduced images."""
        if scale == 1:
            return self
        return self.__class__(
            tuple(image.reduce(scale) for image in self.image_list),
        )

    def _pair_with(self, other):
        """Pair each contained image with its counterpart in ``other``.

        ``other`` is either another Image_List (paired image by image) or an
        iterable with one entry per image. Pairing uses ``zip``, so extra
        entries on either side are ignored.
        """
        others = other.image_list if isinstance(other, Image_List) else other
        return zip(self.image_list, others)

    def __sub__(self, other):
        return self.__class__([self_image - other_image for self_image, other_image in self._pair_with(other)])

    def __add__(self, other):
        return self.__class__([self_image + other_image for self_image, other_image in self._pair_with(other)])

    def __isub__(self, other):
        for self_image, other_image in self._pair_with(other):
            self_image -= other_image  # Image.__isub__ mutates the image in place
        return self

    def __iadd__(self, other):
        for self_image, other_image in self._pair_with(other):
            self_image += other_image  # Image.__iadd__ mutates the image in place
        return self

    def save(self, filename=None, overwrite=True):
        raise NotImplementedError("Save/load not yet available for image lists")

    def load(self, filename):
        raise NotImplementedError("Save/load not yet available for image lists")

    def __getitem__(self, *args):
        """Index by Window or Image (returns a sub-list) or by int/slice (returns the image(s))."""
        if len(args) == 1 and isinstance(args[0], Window):
            return self.get_window(args[0])
        if len(args) == 1 and isinstance(args[0], Image):
            return self.get_window(args[0].window)
        if all(isinstance(arg, (int, slice)) for arg in args):
            return self.image_list.__getitem__(*args)
        raise ValueError("Unrecognized Image_List getitem request!")

    def __str__(self):
        return "image list of:\n" + "\n".join(image.__str__() for image in self.image_list)

    def __repr__(self):
        return "image list of:\n" + "\n".join(image.__repr__() for image in self.image_list)

    def __iter__(self):
        """Iterate over the contained images."""
        return iter(self.image_list)