import numpy as np
import torch
from astropy.wcs import WCS as AstropyWCS
from .. import AP_config
from ..utils.conversions.coordinates import Rotate_Cartesian
from .wcs import WCS
from ..errors import ConflicingWCS, SpecificationConflict
__all__ = ["Window", "Window_List"]
class Window(WCS):
    """class to define a window on the sky in coordinate space. These
    windows can undergo arithmetic and preserve logical behavior. Image
    objects can also be indexed using windows and will return an
    appropriate subsection of their data.
    There are several ways to tell a Window object where to
    place itself. The simplest method is to pass an
    Astropy WCS object such as::
      H = ap.image.Window(wcs = wcs)
    this will automatically place your image at the correct RA, DEC
    and assign the correct pixel scale, etc. WARNING, it will default to
    setting the reference RA DEC at the reference RA DEC of the wcs
    object; if you have multiple images you should force them all to
    have the same reference world coordinate by passing
    ``reference_radec = (ra, dec)``. See the :doc:`coordinates`
    documentation for more details. There are several other ways to
    initialize a window. If you provide ``origin_radec`` then
    it will place the image origin at the requested RA DEC
    coordinates. If you provide ``center_radec`` then it will place
    the image center at the requested RA DEC coordiantes. Note that in
    these cases the fixed point between the pixel grid and image plane
    is different (pixel origin and center respectively); so if you
    have rotated pixels in your pixel scale matrix then everything
    will be rotated about different points (pixel origin and center
    respectively). If you provide ``origin`` or ``center`` then those
    are coordiantes in the tangent plane (arcsec) and they will
    correspondingly become fixed points. For arbitrary control over
    the pixel positioning, use ``reference_imageij`` and
    ``reference_imagexy`` to fix the pixel and tangent plane
    coordinates respectively to each other, any rotation or shear will
    happen about that fixed point.
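
    As a minimal sketch (the numeric values here are arbitrary, and
    ``pixelscale`` is assumed to be accepted as a keyword argument by the
    underlying WCS), a window could be placed by its center in tangent
    plane coordinates with::

      W = ap.image.Window(
          center=(10.0, -5.0),      # arcsec in the tangent plane
          pixel_shape=(128, 128),   # number of pixels along each axis
          pixelscale=0.2,           # arcsec per pixel (assumed WCS keyword)
      )

    or pinned to the sky through the RA, DEC of its center with::

      W = ap.image.Window(
          center_radec=(120.1, 45.3),     # degrees on the sky
          pixel_shape=(128, 128),
          pixelscale=0.2,
          reference_radec=(120.0, 45.0),  # shared tangent point for all images
      )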
    Args:
      origin : Sequence or None, optional
          The origin of the image in the tangent plane coordinate system
          (arcsec), as a 1D array of length 2. Default is None.
      origin_radec : Sequence or None, optional
          The origin of the image in the world coordinate system (RA,
          DEC in degrees), as a 1D array of length 2. Default is None.
      center : Sequence or None, optional
          The center of the image in the tangent plane coordinate system
          (arcsec), as a 1D array of length 2. Default is None.
      center_radec : Sequence or None, optional
          The center of the image in the world coordinate system (RA,
          DEC in degrees), as a 1D array of length 2. Default is None.
      wcs: An astropy.wcs.WCS object which gives information about the
          origin and orientation of the window.
      reference_radec: world coordinates on the celestial sphere (RA,
          DEC in degrees) where the tangent plane makes contact. This should
          be the same for every image in multi-image analysis.
      reference_planexy: tangent plane coordinates (arcsec) where it
          makes contact with the celestial sphere. This should typically be
          (0,0), though that is not strictly enforced (it is assumed if not
          given). This reference coordinate should be the same for all
          images in multi-image analysis.
      reference_imageij: pixel coordinates about which the image is
          defined. For example in an Astropy WCS object the wcs.wcs.crpix
          array gives the pixel coordinate reference point for which the
          world coordinate mapping (wcs.wcs.crval) is defined. One may think
          of the referenced pixel location as being "pinned" to the tangent
          plane. This may be different for each image in multi-image
          analysis.
      reference_imagexy: tangent plane coordinates (arcsec) about
          which the image is defined. This is the pivot point about which the
          pixelscale matrix operates, therefore if the pixelscale matrix
          defines a rotation then this is the coordinate about which the
          rotation will be performed. This may be different for each image in
          multi-image analysis.
    """
    def __init__(
        self,
        *,
        pixel_shape=None,
        origin=None,
        origin_radec=None,
        center=None,
        center_radec=None,
        state=None,
        fits_state=None,
        wcs=None,
        **kwargs,
    ):
        # If loading from a previous state, simply update values and end init
        if state is not None:
            self.set_state(state)
            return
        if fits_state is not None:
            self.set_fits_state(fits_state)
            return
        # Collect the shape of the window
        if pixel_shape is not None:
            self.pixel_shape = pixel_shape
        elif wcs is not None:
            self.pixel_shape = wcs.pixel_shape
        else:
            raise SpecificationConflict(
                "A Window needs a shape: provide either pixel_shape or a wcs object"
            )
        # Determine relative positioning of tangent plane and pixel grid. Also world coordinates and tangent plane
        if sum(C is not None for C in (wcs, origin_radec, center_radec, origin, center)) > 1:
            raise SpecificationConflict(
                "Please provide only one reference position for the window, otherwise the placement is ambiguous"
            )
        # Image coordinates provided by WCS
        if wcs is not None:
            super().__init__(wcs=wcs, **kwargs)
        # Image reference position from RA and DEC of image origin
        elif origin_radec is not None:
            # Origin given, it is reference point
            origin_radec = torch.as_tensor(
                origin_radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            kwargs["reference_radec"] = kwargs.get("reference_radec", origin_radec)
            super().__init__(**kwargs)
            self.reference_imageij = (-0.5, -0.5)
            self.reference_imagexy = self.world_to_plane(origin_radec)
        # Image reference position from RA and DEC of image center
        elif center_radec is not None:
            pix_center = self.pixel_shape.to(dtype=AP_config.ap_dtype) / 2 - 0.5
            center_radec = torch.as_tensor(
                center_radec, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            kwargs["reference_radec"] = kwargs.get("reference_radec", center_radec)
            super().__init__(**kwargs)
            center = self.world_to_plane(center_radec)
            self.reference_imageij = pix_center
            self.reference_imagexy = center
        # Image reference position from tangent plane position of image origin
        elif origin is not None:
            kwargs.update(
                {
                    "reference_imageij": (-0.5, -0.5),
                    "reference_imagexy": origin,
                }
            )
            super().__init__(**kwargs)
        # Image reference position from tangent plane position of image center
        elif center is not None:
            pix_center = self.pixel_shape.to(dtype=AP_config.ap_dtype) / 2 - 0.5
            kwargs.update(
                {
                    "reference_imageij": pix_center,
                    "reference_imagexy": center,
                }
            )
            super().__init__(**kwargs)
        # Image origin assumed to be at tangent plane origin
        else:
            super().__init__(**kwargs)
    @property
    def shape(self):
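        """Extent of the window in tangent plane coordinates (arcsec) along
        each pixel axis, computed through the pixelscale matrix."""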
        dtype, device = self.pixelscale.dtype, self.pixelscale.device
        S1 = self.pixel_shape.to(dtype=dtype, device=device)
        S1[1] = 0.0
        S2 = self.pixel_shape.to(dtype=dtype, device=device)
        S2[0] = 0.0
        return torch.stack(
            (
                torch.linalg.norm(self.pixelscale @ S1),
                torch.linalg.norm(self.pixelscale @ S2),
            )
        )
    @shape.setter
    def shape(self, shape):
        if shape is None:
            self._pixel_shape = None
            return
        shape = torch.as_tensor(shape, dtype=self.pixelscale.dtype, device=self.pixelscale.device)
        self.pixel_shape = shape / torch.sqrt(torch.sum(self.pixelscale**2, dim=0))
    @property
    def pixel_shape(self):
        return self._pixel_shape
    @pixel_shape.setter
    def pixel_shape(self, shape):
        if shape is None:
            self._pixel_shape = None
            return
        self._pixel_shape = torch.as_tensor(shape, device=AP_config.ap_device)
        self._pixel_shape = torch.round(self.pixel_shape).to(
            dtype=torch.int32, device=AP_config.ap_device
        )
    @property
    def size(self):
        """The number of pixels in the window"""
        return torch.prod(self.pixel_shape)
    @property
    def end(self):
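        """Tangent plane displacement (arcsec) from the window origin to the
        opposite corner, i.e. the pixel shape mapped through the pixelscale
        matrix."""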
        return self.pixel_to_plane_delta(
            self.pixel_shape.to(dtype=self.pixelscale.dtype, device=self.pixelscale.device)
        )
    @property
    def origin(self):
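        """Tangent plane coordinate (arcsec) of the outer corner of the
        (0, 0) pixel, i.e. pixel coordinate (-0.5, -0.5)."""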
        return self.pixel_to_plane(-0.5 * torch.ones_like(self.reference_imageij))
    @property
    def center(self):
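        """Tangent plane coordinate (arcsec) of the window center."""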
        return self.origin + self.end / 2
    def copy(self, **kwargs):
        copy_kwargs = {"pixel_shape": torch.clone(self.pixel_shape)}
        copy_kwargs.update(kwargs)
        return super().copy(**copy_kwargs) 
    def to(self, dtype=None, device=None):
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        super().to(dtype=dtype, device=device)
        self.pixel_shape = self.pixel_shape.to(dtype=dtype, device=device) 
    def rescale_pixel(self, scale, **kwargs):
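        """Return a copy of this window with the pixel grid resampled by the
        factor ``scale``: pixels become ``scale`` times larger on a side with
        correspondingly fewer of them along each axis, covering approximately
        the same region of the tangent plane."""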
        return self.copy(
            pixelscale=self.pixelscale * scale,
            pixel_shape=self.pixel_shape // scale,
            reference_imageij=(self.reference_imageij + 0.5) / scale - 0.5,
            **kwargs,
        ) 
    @staticmethod
    @torch.no_grad()
    def _get_indices(ref_window, obj_window):
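        """Compute slices (returned as a tuple ordered y slice, x slice) into
        ``ref_window``'s pixel grid covering the region spanned by
        ``obj_window``, clipped to the bounds of ``ref_window``."""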
        other_origin_pix = torch.round(ref_window.plane_to_pixel(obj_window.origin) + 0.5).int()
        new_origin_pix = torch.maximum(torch.zeros_like(other_origin_pix), other_origin_pix)
        other_pixel_end = torch.round(
            ref_window.plane_to_pixel(obj_window.origin + obj_window.end) + 0.5
        ).int()
        new_pixel_end = torch.minimum(ref_window.pixel_shape, other_pixel_end)
        return slice(new_origin_pix[1], new_pixel_end[1]), slice(
            new_origin_pix[0], new_pixel_end[0]
        )
    def get_self_indices(self, obj):
        """
        Return an index slicing tuple into this window's pixel grid covering
        the region where ``obj`` overlaps it.
        """
        if isinstance(obj, Window):
            return self._get_indices(self, obj)
        return self._get_indices(self, obj.window) 
    def get_other_indices(self, obj):
        """
        Return an index slicing tuple into ``obj``'s pixel grid covering
        the region where this window overlaps it.
        """
        if isinstance(obj, Window):
            return self._get_indices(obj, self)
        return self._get_indices(obj.window, self) 
    def overlap_frac(self, other):
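        """Overlap fraction of two windows: the area of their intersection
        divided by the area of their union (computed in tangent plane
        coordinates)."""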
        overlap = self & other
        overlap_area = torch.prod(overlap.shape)
        full_area = torch.prod(self.shape) + torch.prod(other.shape) - overlap_area
        return overlap_area / full_area 
    def shift(self, shift):
        """
        Shift the location of the window by a specified amount in tangent plane coordinates
        """
        self.reference_imagexy = self.reference_imagexy + shift
        return self 
    def pixel_shift(self, shift):
        """
        Shift the location of the window by a specified amount in pixel grid coordinates
        """
        self.reference_imageij = self.reference_imageij - shift
        return self 
    def get_astropywcs(self, **kwargs):
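        """Construct an ``astropy.wcs.WCS`` object with a TAN projection
        matching this window's pixel scale matrix and reference position.
        Extra keyword arguments are added to the header dictionary."""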
        wargs = {
            "NAXIS": 2,
            "NAXIS1": self.pixel_shape[0].item(),
            "NAXIS2": self.pixel_shape[1].item(),
            "CTYPE1": "RA---TAN",
            "CTYPE2": "DEC--TAN",
            "CRVAL1": self.pixel_to_world(self.reference_imageij)[0].item(),
            "CRVAL2": self.pixel_to_world(self.reference_imageij)[1].item(),
            "CRPIX1": self.reference_imageij[0].item(),
            "CRPIX2": self.reference_imageij[1].item(),
            "CD1_1": self.pixelscale[0][0].item(),
            "CD1_2": self.pixelscale[0][1].item(),
            "CD2_1": self.pixelscale[1][0].item(),
            "CD2_2": self.pixelscale[1][1].item(),
        }
        wargs.update(kwargs)
        return AstropyWCS(wargs) 
    def get_state(self):
        state = super().get_state()
        state["pixel_shape"] = self.pixel_shape.detach().cpu().tolist()
        return state 
    def set_state(self, state):
        super().set_state(state)
        self.pixel_shape = torch.tensor(
            state["pixel_shape"], dtype=AP_config.ap_dtype, device=AP_config.ap_device
        ) 
    def get_fits_state(self):
        state = super().get_fits_state()
        state["PXL_SHPE"] = str(self.pixel_shape.detach().cpu().tolist())
        return state 
    def set_fits_state(self, state):
        super().set_fits_state(state)
        self.pixel_shape = torch.tensor(
            eval(state["PXL_SHPE"]), dtype=AP_config.ap_dtype, device=AP_config.ap_device
        ) 
    def crop_pixel(self, pixels):
        """
        Shrink the window by a number of pixels. Accepted formats:

        [crop all sides] or
        [crop x, crop y] or
        [crop x low, crop y low, crop x high, crop y high]
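
        For example (a sketch assuming an existing window ``W``)::

            W.crop_pixel((5,))          # remove 5 pixels from every edge
            W.crop_pixel((5, 10))       # remove 5 in x and 10 in y, from both edges
            W.crop_pixel((1, 2, 3, 4))  # remove x low, y low, x high, y high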
        """
        if len(pixels) == 1:
            self.pixel_shape = self.pixel_shape - 2 * pixels[0]
            self.reference_imageij = self.reference_imageij - pixels[0]
        elif len(pixels) == 2:
            pix_shift = torch.as_tensor(
                pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            self.pixel_shape = self.pixel_shape - 2 * pix_shift
            self.reference_imageij = self.reference_imageij - pix_shift
        elif len(pixels) == 4:  # different crop on all sides
            pixels = torch.as_tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
            self.pixel_shape = self.pixel_shape - pixels[:2] - pixels[2:]
            self.reference_imageij = self.reference_imageij - pixels[:2]
        else:
            raise ValueError(f"Unrecognized pixel crop format: {pixels}")
        return self 
    def crop_to_pixel(self, pixels):
        """
        Crop the window to a specified pixel range. Format: [[xmin, xmax], [ymin, ymax]]
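
        For example, ``W.crop_to_pixel([[10, 100], [20, 120]])`` keeps the
        pixel range [10, 100) in x and [20, 120) in y (a sketch assuming an
        existing window ``W``).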
        """
        pixels = torch.tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        self.reference_imageij = self.reference_imageij - pixels[:, 0]
        self.pixel_shape = pixels[:, 1] - pixels[:, 0]
        return self 
    def pad_pixel(self, pixels):
        """
        Grow the window by a number of pixels. Accepted formats:

        [pad all sides] or
        [pad x, pad y] or
        [pad x low, pad y low, pad x high, pad y high]
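
        For example, ``W.pad_pixel((5,))`` grows the window by 5 pixels on
        every edge (a sketch assuming an existing window ``W``).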
        """
        if len(pixels) == 1:
            self.pixel_shape = self.pixel_shape + 2 * pixels[0]
            self.reference_imageij = self.reference_imageij + pixels[0]
        elif len(pixels) == 2:
            pix_shift = torch.as_tensor(
                pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
            self.pixel_shape = self.pixel_shape + 2 * pix_shift
            self.reference_imageij = self.reference_imageij + pix_shift
        elif len(pixels) == 4:  # different pad on each side
            pixels = torch.as_tensor(pixels, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
            self.pixel_shape = self.pixel_shape + pixels[:2] + pixels[2:]
            self.reference_imageij = self.reference_imageij + pixels[:2]
        else:
            raise ValueError(f"Unrecognized pixel crop format: {pixels}")
        return self 
    @torch.no_grad()
    def get_coordinate_meshgrid(self):
        """Returns a meshgrid with tangent plane coordinates for the center
        of every pixel.
        """
        pix = self.pixel_shape.to(dtype=AP_config.ap_dtype)
        xsteps = torch.arange(pix[0], dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        ysteps = torch.arange(pix[1], dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        meshx, meshy = torch.meshgrid(
            xsteps,
            ysteps,
            indexing="xy",
        )
        Coords = self.pixel_to_plane(meshx, meshy)
        return torch.stack(Coords) 
    @torch.no_grad()
    def get_coordinate_corner_meshgrid(self):
        """Returns a meshgrid with tangent plane coordinates for the corners
        of every pixel.
        """
        pix = self.pixel_shape.to(dtype=AP_config.ap_dtype)
        xsteps = (
            torch.arange(pix[0] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - 0.5
        )
        ysteps = (
            torch.arange(pix[1] + 1, dtype=AP_config.ap_dtype, device=AP_config.ap_device) - 0.5
        )
        meshx, meshy = torch.meshgrid(
            xsteps,
            ysteps,
            indexing="xy",
        )
        Coords = self.pixel_to_plane(meshx, meshy)
        return torch.stack(Coords) 
    @torch.no_grad()
    def get_coordinate_simps_meshgrid(self):
        """Returns a meshgrid with tangent plane coordinates for performing
        Simpson's method pixel integration (all pixel corners, centers, and
        edge midpoints). This is approximately 4 times as many
        points as the standard :meth:`get_coordinate_meshgrid`.
        """
        pix = self.pixel_shape.to(dtype=AP_config.ap_dtype)
        xsteps = (
            0.5
            * torch.arange(
                2 * (pix[0]) + 1,
                dtype=AP_config.ap_dtype,
                device=AP_config.ap_device,
            )
            - 0.5
        )
        ysteps = (
            0.5
            * torch.arange(
                2 * (pix[1]) + 1,
                dtype=AP_config.ap_dtype,
                device=AP_config.ap_device,
            )
            - 0.5
        )
        meshx, meshy = torch.meshgrid(
            xsteps,
            ysteps,
            indexing="xy",
        )
        Coords = self.pixel_to_plane(meshx, meshy)
        return torch.stack(Coords) 
    # Window Comparison operators
    @torch.no_grad()
    def __eq__(self, other):
        return (
            torch.all(self.pixel_shape == other.pixel_shape)
            and torch.all(self.pixelscale == other.pixelscale)
            and (self.projection == other.projection)
            and (
                torch.all(
                    self.pixel_to_plane(torch.zeros_like(self.reference_imageij))
                    == other.pixel_to_plane(torch.zeros_like(other.reference_imageij))
                )
            )
        )  # fixme more checks?
    @torch.no_grad()
    def __ne__(self, other):
        return not self == other
    # Window interaction operators
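    # `|` / `|=` expand to the union of the two windows' extents (snapped to
    # this window's pixel grid), while `&` / `&=` shrink to their intersection.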
    @torch.no_grad()
    def __or__(self, other):
        other_origin_pix = self.plane_to_pixel(other.origin)
        new_origin_pix = torch.minimum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix)
        other_pixel_end = self.plane_to_pixel(other.origin + other.end)
        new_pixel_end = torch.maximum(
            self.pixel_shape.to(dtype=AP_config.ap_dtype), other_pixel_end
        )
        return self.copy(
            origin=self.pixel_to_plane(new_origin_pix),
            pixel_shape=new_pixel_end - new_origin_pix,
        )
    @torch.no_grad()
    def __ior__(self, other):
        other_origin_pix = self.plane_to_pixel(other.origin)
        new_origin_pix = torch.minimum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix)
        other_pixel_end = self.plane_to_pixel(other.origin + other.end)
        new_pixel_end = torch.maximum(
            self.pixel_shape.to(dtype=AP_config.ap_dtype), other_pixel_end
        )
        self.reference_imageij = self.reference_imageij - (new_origin_pix + 0.5)
        self.pixel_shape = new_pixel_end - new_origin_pix
        return self
    @torch.no_grad()
    def __and__(self, other):
        other_origin_pix = self.plane_to_pixel(other.origin)
        new_origin_pix = torch.maximum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix)
        other_pixel_end = self.plane_to_pixel(other.origin + other.end)
        new_pixel_end = torch.minimum(
            self.pixel_shape.to(dtype=AP_config.ap_dtype) - 0.5, other_pixel_end
        )
        return self.copy(
            origin=self.pixel_to_plane(new_origin_pix),
            pixel_shape=new_pixel_end - new_origin_pix,
        )
    @torch.no_grad()
    def __iand__(self, other):
        other_origin_pix = self.plane_to_pixel(other.origin)
        new_origin_pix = torch.maximum(-0.5 * torch.ones_like(other_origin_pix), other_origin_pix)
        other_pixel_end = self.plane_to_pixel(other.origin + other.end)
        new_pixel_end = torch.minimum(
            self.pixel_shape.to(dtype=AP_config.ap_dtype), other_pixel_end
        )
        self.reference_imageij = self.reference_imageij - (new_origin_pix + 0.5)
        self.pixel_shape = new_pixel_end - new_origin_pix
        return self
    def __str__(self):
        return f"window origin: {self.origin.detach().cpu().tolist()}, shape: {self.shape.detach().cpu().tolist()}, center: {self.center.detach().cpu().tolist()}, pixelscale: {self.pixelscale.detach().cpu().tolist()}"
    def __repr__(self):
        return (
            f"window pixel_shape: {self.pixel_shape.detach().cpu().tolist()}, shape: {self.shape.detach().cpu().tolist()}\n"
            + super().__repr__()
        ) 
class Window_List(Window):
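    """A container of Window objects, one per image in a multi-image analysis.

    Arithmetic and comparison operators are applied element-wise to the
    paired windows, and all windows are required to share the same tangent
    plane (see :meth:`check_wcs`).
    """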
    def __init__(self, window_list=None, state=None):
        if state is not None:
            self.set_state(state)
        else:
            if window_list is None:
                window_list = []
            self.window_list = list(window_list)
        self.check_wcs()
    def check_wcs(self):
        """Ensure the WCS systems being used by all the windows in this list
        are consistent with each other. They should all project world
        coordinates onto the same tangent plane.
        """
        refs = tuple(
            W.reference_radec for W in filter(lambda w: w is not None, self.window_list)
        )
        if len(refs) == 0:
            return
        ref = torch.stack(refs)
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (world) coordinate mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        ref = torch.stack(
            tuple(W.reference_planexy for W in filter(lambda w: w is not None, self.window_list))
        )
        if not torch.allclose(ref, ref[0]):
            raise ConflicingWCS(
                "Reference (tangent plane) coordinate mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            )
        if len(set(W.projection for W in filter(lambda w: w is not None, self.window_list))) > 1:
            raise ConflicingWCS(
                "Projection mismatch! All windows in Window_List are not on the same tangent plane! Likely serious coordinate mismatch problems. See the coordinates page in the documentation for what this means."
            ) 
    @property
    @torch.no_grad()
    def origin(self):
        return tuple(w.origin for w in self)
    @property
    @torch.no_grad()
    def shape(self):
        return tuple(w.shape for w in self)
    @property
    @torch.no_grad()
    def center(self):
        return tuple(w.center for w in self)
    def shift_origin(self, shift):
        raise NotImplementedError("shift origin not implemented for window list") 
    def copy(self):
        return self.__class__(list(w.copy() for w in self)) 
    def to(self, dtype=None, device=None):
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
        for window in self:
            window.to(dtype, device) 
    def get_state(self):
        return list(window.get_state() for window in self) 
    def set_state(self, state):
        self.window_list = list(Window(state=st) for st in state) 
    # Window interaction operators
    @torch.no_grad()
    def __or__(self, other):
        new_windows = list((sw | ow) for sw, ow in zip(self, other))
        return self.__class__(window_list=new_windows)
    @torch.no_grad()
    def __ior__(self, other):
        for sw, ow in zip(self, other):
            sw |= ow
        return self
    @torch.no_grad()
    def __and__(self, other):
        new_windows = list((sw & ow) for sw, ow in zip(self, other))
        return self.__class__(window_list=new_windows)
    @torch.no_grad()
    def __iand__(self, other):
        for sw, ow in zip(self, other):
            sw &= ow
        return self
    # Window Comparison operators
    @torch.no_grad()
    def __eq__(self, other):
        results = list((sw == ow).view(-1) for sw, ow in zip(self, other))
        return torch.all(torch.cat(results))
    @torch.no_grad()
    def __ne__(self, other):
        return not self == other
    def __len__(self):
        return len(self.window_list)
    def __iter__(self):
        return (win for win in self.window_list)
    def __str__(self):
        return "Window List: \n" + ("\n".join(list(str(window) for window in self)) + "\n")
    def __repr__(self):
        return "Window List: \n" + ("\n".join(list(repr(window) for window in self)) + "\n")