from functools import lru_cache
import torch
import numpy as np
from scipy.special import binom
from ..utils.decorators import ignore_numpy_warnings, default_internal
from ._shared_methods import select_target
from .psf_model_object import PSF_Model
from ..param import Param_Unlock, Param_SoftLimits
from .. import AP_config
from ..errors import SpecificationConflict
__all__ = ("Zernike_PSF",)
[docs]
class Zernike_PSF(PSF_Model):
    model_type = f"zernike {PSF_Model.model_type}"
    parameter_specs = {
        "Anm": {"units": "flux/arcsec^2"},
    }
    _parameter_order = PSF_Model._parameter_order + ("Anm",)
    useable = True
    model_integrated = False
    
    def __init__(self, *, name=None, order_n=2, r_scale=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.order_n = int(order_n)
        self.r_scale = r_scale
        self.nm_list = self.iter_nm(self.order_n)
[docs]
    @torch.no_grad()
    @ignore_numpy_warnings
    @select_target
    @default_internal
    def initialize(self, target=None, parameters=None, **kwargs):
        super().initialize(target=target, parameters=parameters)
        # List the coefficients to use
        self.nm_list = self.iter_nm(self.order_n)
        # Set the scale radius for the Zernike area
        if self.r_scale is None:
            self.r_scale = torch.max(self.window.shape) / 2
        # Check if user has already set the coefficients
        if parameters["Anm"].value is not None:
            if len(self.nm_list) != len(
                parameters["Anm"].value
            ):
                raise SpecificationConflict(f"nm_list length ({len(self.nm_list)}) must match coefficients ({len(parameters['Anm'].value)})")
            return
        # Set the default coefficients to zeros
        with Param_Unlock(parameters["Anm"]), Param_SoftLimits(parameters["Anm"]):
            parameters["Anm"].value = torch.zeros(len(self.nm_list))
            if parameters["Anm"].uncertainty is None:
                parameters["Anm"].uncertainty = self.default_uncertainty * torch.ones_like(parameters["Anm"].value)
            # Set the zero order zernike polynomial to the average in the image
            if self.nm_list[0] == (0, 0):
                parameters["Anm"].value[0] = (
                    torch.median(target[self.window].data) / target.pixel_area
                ) 
[docs]
    def iter_nm(self, n):
        nm = []
        for n_i in range(n + 1):
            for m_i in range(-n_i, n_i + 1, 2):
                nm.append((n_i, m_i))
        return nm 
[docs]
    @staticmethod
    @lru_cache(maxsize=1024)
    def coefficients(n, m):
        C = []
        for k in range(int((n - abs(m)) / 2) + 1):
            C.append(
                (
                    k,
                    (-1) ** k
                    * binom(n - k, k)
                    * binom(n - 2 * k, (n - abs(m)) / 2 - k),
                )
            )
        return C 
[docs]
    def Z_n_m(self, rho, phi, n, m, efficient=True):
        Z = torch.zeros_like(rho)
        if efficient:
            T_cache = {0: None}
            R_cache = {}
        for k, c in self.coefficients(n, m):
            if efficient:
                if (n - 2 * k) not in R_cache:
                    R_cache[n - 2 * k] = rho ** (n - 2 * k)
                R = R_cache[n - 2 * k]
                if m not in T_cache:
                    if m < 0:
                        T_cache[m] = torch.sin(abs(m) * phi)
                    elif m > 0:
                        T_cache[m] = torch.cos(m * phi)
                T = T_cache[m]
            else:
                R = rho ** (n - 2 * k)
                if m < 0:
                    T = torch.sin(abs(m) * phi)
                elif m > 0:
                    T = torch.cos(m * phi)
            if m == 0:
                Z += c * R
            elif m < 0:
                Z += c * R * T
            else:
                Z += c * R * T
        return Z 
[docs]
    @default_internal
    def evaluate_model(self, X=None, Y=None, image=None, parameters=None):
        if X is None:
            Coords = image.get_coordinate_meshgrid()
            X, Y = Coords - parameters["center"].value[..., None, None]
        phi = self.angular_metric(X, Y, image, parameters)
        r = self.radius_metric(X, Y, image, parameters)
        r = r / self.r_scale
        G = torch.zeros_like(X)
        i = 0
        A = image.pixel_area * parameters["Anm"].value
        for n, m in self.nm_list:
            G += A[i] * self.Z_n_m(r, phi, n, m)
            i += 1
        G[r > 1] = 0.0
        return G