from typing import Optional
from types import FunctionType
from copy import deepcopy
from collections import OrderedDict
import torch
import numpy as np
from ..utils.conversions.optimization import (
    boundaries,
    inv_boundaries,
    d_boundaries_dval,
    d_inv_boundaries_dval,
    cyclic_boundaries,
)
from .. import AP_config
from .base import Node
from ..errors import InvalidParameter
__all__ = ["Parameter_Node"]
    
class Parameter_Node(Node):
    """A node representing parameters and their relative structure.
    The Parameter_Node object stores all information relevant for the
    parameters of a model. At a high level the Parameter_Node
    accomplishes two tasks. The first task is to store the actual
    parameter values, these are represented as pytorch tensors which
    can have any shape; these are leaf nodes. The second task is to
    store the relationship between parameters in a graph structure;
    these are branch nodes. The two tasks are handled by the same type
    of object since there is some overlap between them where a branch
    node acts like a leaf node in certain contexts.
    There are various quantities that a Parameter_Node tracks which
    can be provided as arguments or updated later.
    Args:
      value: The value of a node represents the tensor which will be used by models to compute their projection into the pixels of an image. These can be quite complex; see further down for more details.
      cyclic (bool): Records if the value of a node is cyclic, meaning that if it is updated outside its limits it will be wrapped back into the limits.
      limits (Tuple[Tensor or None, Tensor or None]): Tracks if a parameter has constraints on the range of values it can take. The first element is the lower limit, the second element is the upper limit. The two elements should either be None (no limit) or tensors with the same shape as the value.
      units (str): The units of the parameter value.
      uncertainty (Tensor or None): Represents the uncertainty of the parameter value. This should be None (no uncertainty) or a Tensor which can be broadcast to the same shape as the value.
      prof (Tensor or None): This is a profile of values which has no explicit meaning, but can be used to store information which should be kept alongside the value. For example, in a spline model the positions of the spline points may be stored as a ``prof`` while the flux at each point is the value to be optimized.
      shape (Tuple or None): Can be used to set the shape of the value (number of elements/dimensions). If not provided, the shape will be set the first time a value is given. Once a shape has been set, if a value is given which cannot be coerced into that shape, an error will be thrown.
    The ``value`` of a Parameter_Node is somewhat complicated; there
    are a number of states it can take on. The most straightforward is
    a Tensor: if a Tensor (or an iterable such as a list or
    numpy.ndarray) is provided, then the node is required to be a leaf
    node and it will store the value to be accessed later by other
    parts of AstroPhot. Another option is to set the value as another
    node (the two will automatically be linked); in this case the
    node's ``value`` is just a wrapper which calls for the ``value``
    of the linked node. Finally, the value may be a function, which
    allows arbitrarily complex values to be computed from other
    nodes' values. The function must take the current Parameter_Node
    instance as an argument and return a Tensor. Here are some
    examples of the various ways of interacting with the ``value`` for
    a hypothetical parameter ``P``::

      P.value = 1. # Will create a tensor with value 1.
      P.value = P2 # calling P.value will actually call P2.value
      def compute_value(param):
        return param["P2"].value**2
      P.value = compute_value # calling P.value will call the function as: compute_value(P) which will return P2.value**2
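
    A minimal construction sketch; the keyword arguments are those
    documented above and the specific numbers are purely illustrative::

      import torch
      x = Parameter_Node("x", value=1.0, limits=(0.0, None), units="arcsec")
      angle = Parameter_Node("angle", value=0.1, cyclic=True, limits=(0.0, 2 * torch.pi), units="rad")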
    
    """
    def __init__(self, name, **kwargs):
        super().__init__(name, **kwargs)
        if "state" in kwargs:
            return
        temp_locked = self.locked
        self.locked = False
        self._value = None
        self.prof = kwargs.get("prof", None)
        self.limits = kwargs.get("limits", [None, None])
        self.cyclic = kwargs.get("cyclic", False)
        self.shape = kwargs.get("shape", None)
        self.value = kwargs.get("value", None)
        self.units = kwargs.get("units", "none")
        self.uncertainty = kwargs.get("uncertainty", None)
        self.to()
        self.locked = temp_locked
    @property
    def value(self):
        """The ``value`` of a Parameter_Node is somewhat complicated, there
        are a number of states it can take on. The most
        straightforward is just a Tensor, if a Tensor (or just an
        iterable like a list or numpy.ndarray) is provided then the
        node is required to be a leaf node and it will store the value
        to be accessed later by other parts of AstroPhot. Another
        option is to set the value as another node (they will
        automatically be linked), in this case the node's ``value`` is
        just a wrapper to call for the ``value`` of the linked
        node. Finally, the value may be a function which allows for
        arbitrarily complex values to be computed from other node's
        values. The function must take as an argument the current
        Parameter_Node instance and return a Tensor. Here are some
        examples of the various ways of interacting with the ``value``
        for a hypothetical parameter ``P``::
          P.value = 1. # Will create a tensor with value 1.
          P.value = P2 # calling P.value will actually call P2.value
          def compute_value(param):
            return param["P2"].value**2
          P.value = compute_value # calling P.value will call the function as: compute_value(P) which will return P2.value**2
        """
        if isinstance(self._value, Parameter_Node):
            return self._value.value
        if isinstance(self._value, FunctionType):
            return self._value(self)
        return self._value
    @property
    def mask(self):
        """The mask tensor is stored internally and it cuts out some values
        from the parameter. This is used by the ``vector`` methods in
        the class to give the parameter DAG a dynamic shape.
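
        For example, a sketch with a two element leaf parameter::

          P = Parameter_Node("center", value=[0.0, 1.0])
          P.vector_set_mask([True, False])
          P.mask  # tensor([True, False]); P.vector_values() now has one element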
        """
        if not self.leaf:
            return self.vector_mask()
        try:
            return self._mask
        except AttributeError:
            return torch.ones(self.shape, dtype = torch.bool, device = AP_config.ap_device)
        
    @property
    def identities(self):
        """This creates a numpy array of strings which uniquely identify
        every element in the parameter vector. For example a
        ``center`` parameter with two components [x,y] would have
        identities be ``np.array(["123456:0", "123456:1"])`` where the
        first part is the unique id for the Parameter_Node object and
        the second number indexes where in the value tensor it refers
        to.
        """
        if self.leaf:
            idstr = str(self.identity)
            return np.array(tuple(f"{idstr}:{i}" for i in range(self.size)))
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.identities for node in flat.values())
        if len(vec) > 0:
            return np.concatenate(vec)
        return np.array(())
    
    @property
    def names(self):
        """Returns a numpy array of names for all the elements of the
        ``vector`` representation where the name is determined by the
        name of the parameters. Note that this does not create a
        unique name for each element and this should only be used for
        graphical purposes on small parameter DAGs.
        """
        if self.leaf:
            S = self.size
            if S == 1:
                return np.array((self.name,))
            return np.array(tuple(f"{self.name}:{i}" for i in range(self.size)))
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.names for node in flat.values())
        if len(vec) > 0:
            return np.concatenate(vec)
        return np.array(())
    
    def vector_values(self):
        """The vector representation is for values which correspond to
        fundamental inputs to the parameter DAG. Since the DAG may
        have linked nodes, or functions which produce values derived
        from other node values, the collection of all "values" is not
        necessarily of use for some methods such as fitting
        algorithms. The vector representation is useful for optimizers
        as it gives a fundamental representation of the parameter
        DAG. The vector_values function returns a vector of the
        ``value`` for each leaf node.
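
        A sketch pairing the value vector with its identities::

          vals = P.vector_values()     # 1D tensor of all unmasked leaf values
          ids = P.vector_identities()  # matching identity string for each element
          assert len(vals) == len(ids)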
        """
        if self.leaf:
            return self.value[self.mask].flatten()
        
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.vector_values() for node in flat.values())
        if len(vec) > 0:
            return torch.cat(vec)
        return torch.tensor((), dtype = AP_config.ap_dtype, device = AP_config.ap_device) 
    
    def vector_uncertainty(self):
        """This returns a vector (see vector_values) with the uncertainty for
        each leaf node.
        """
        if self.leaf:
            if self._uncertainty is None:
                # Default to an uncertainty of ones; set the private attribute
                # directly so this works even when the node is locked
                self._uncertainty = torch.ones_like(self.value)
            return self.uncertainty[self.mask].flatten()
        
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.vector_uncertainty() for node in flat.values())
        if len(vec) > 0:
            return torch.cat(vec)
        return torch.tensor((), dtype = AP_config.ap_dtype, device = AP_config.ap_device) 
    def vector_representation(self):
        """This returns a vector (see vector_values) with the representation
        for each leaf node. The representation is an alternative view
        of each value which is mapped into the (-inf, inf) range where
        optimization is more stable.
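
        A sketch of a typical optimizer round trip::

          rep = P.vector_representation()         # unconstrained view of the values
          P.vector_set_representation(rep + 0.1)  # step, mapped back within any limits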
        """
        return self.vector_transform_val_to_rep(self.vector_values()) 
    def vector_mask(self):
        """This returns a vector (see vector_values) with the mask for each
        leaf node. Note however that the mask is not itself masked,
        this vector is always the full size of the unmasked parameter
        DAG.
        """
        if self.leaf:
            return self.mask.flatten()
        
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.vector_mask() for node in flat.values())
        if len(vec) > 0:
            return torch.cat(vec)
        return torch.tensor((), dtype=torch.bool, device=AP_config.ap_device)
    def vector_identities(self):
        """This returns a vector (see vector_values) with the identities for
        each leaf node.
        """
        if self.leaf:
            return self.identities[self.vector_mask().detach().cpu().numpy()].flatten()
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.vector_identities() for node in flat.values())
        if len(vec) > 0:
            return np.concatenate(vec)
        return np.array(()) 
    def vector_names(self):
        """This returns a vector (see vector_values) with the names for each
        leaf node.
        """
        if self.leaf:
            return self.names[self.vector_mask().detach().cpu().numpy()].flatten()
        flat = self.flat(include_locked = False, include_links = False)
        vec = tuple(node.vector_names() for node in flat.values())
        if len(vec) > 0:
            return np.concatenate(vec)
        return np.array(()) 
        
    def vector_set_values(self, values):
        """This function allows one to update the full vector of values in a
        single call by providing a tensor of the appropriate size. The
        input will be separated so that the correct elements are
        passed to the correct leaf nodes.
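
        For example, shifting every unmasked value by one::

          P.vector_set_values(P.vector_values() + 1.0)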
        """
        values = torch.as_tensor(
            values, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        ).flatten()
        if self.leaf:
            self._value[self.mask] = values
            return
        mask = self.vector_mask()
        flat = self.flat(include_locked=False, include_links=False)
        loc = 0
        for node in flat.values():
            # `values` only holds unmasked elements, so this node's slice is
            # found by counting unmasked elements before and through its span
            # in the full (unmasked) vector.
            start = int(mask[:loc].sum())
            end = int(mask[:loc + node.size].sum())
            node.vector_set_values(values[start:end])
            loc += node.size
            
    def vector_set_uncertainty(self, uncertainty):
        """Update the uncertainty vector for this parameter DAG (see
        vector_set_values).
        """
        uncertainty = torch.as_tensor(
            uncertainty, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if self.leaf:
            if self._uncertainty is None:
                self._uncertainty = torch.ones_like(self.value)
            self._uncertainty[self.mask] = uncertainty
            return
        mask = self.vector_mask()
        flat = self.flat(include_locked=False, include_links=False)
        loc = 0
        for node in flat.values():
            # Slice out this node's unmasked span, as in vector_set_values
            start = int(mask[:loc].sum())
            end = int(mask[:loc + node.size].sum())
            node.vector_set_uncertainty(uncertainty[start:end])
            loc += node.size
    def vector_set_mask(self, mask):
        """Update the mask vector for this parameter DAG (see
        vector_set_values). Note again that the mask vector is always
        the full size of the DAG.
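
        For example, freezing the first element of the vector::

          mask = P.vector_mask()
          mask[0] = False
          P.vector_set_mask(mask)
          assert len(P.vector_values()) == int(mask.sum())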
        """
        mask = torch.as_tensor(mask, dtype = torch.bool, device = AP_config.ap_device)
        if self.leaf:
            self._mask = mask.reshape(self.shape)
            return
        flat = self.flat(include_locked = False, include_links = False)
        loc = 0
        for node in flat.values():
            node.vector_set_mask(mask[loc:loc+node.size])
            loc += node.size 
        
    def vector_set_representation(self, rep):
        """Update the representation vector for this parameter DAG (see
        vector_set_values).
        """        
        self.vector_set_values(self.vector_transform_rep_to_val(rep)) 
        
    
        
    def _set_val_self(self, val):
        """Handles the setting of the value for a leaf node. Ensures the
        value is a Tensor and that it has the right shape. Will also
        check the value against its limits, with different behaviour
        depending on whether it is cyclic, one-sided, or two-sided.
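
        For example, a cyclic parameter with limits ``(0, 2 * pi)`` will
        wrap a value of ``7.0`` around to ``7.0 - 2 * pi``.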
        """
        val = torch.as_tensor(
            val, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if self.shape is not None:
            self._value = val.reshape(self.shape)
        else:
            self._value = val
            self.shape = self._value.shape
                
        if self.cyclic:
            self._value = self.limits[0] + ((self._value - self.limits[0]) % (self.limits[1] - self.limits[0]))
            return
        if self.limits[0] is not None:
            if not torch.all(self._value > self.limits[0]):
                raise InvalidParameter(f"{self.name} has lower limit {self.limits[0].detach().cpu().tolist()}")
        if self.limits[1] is not None:
            if not torch.all(self._value < self.limits[1]):
                raise InvalidParameter(f"{self.name} has upper limit {self.limits[1].detach().cpu().tolist()}")
            
    def _soft_set_val_self(self, val):
        """The same as ``_set_val_self`` except that it doesn't raise an
        error when the values are set outside their range, instead it
        will push the values into the range defined by the limits.
        """
        val = torch.as_tensor(
            val, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if self.shape is not None:
            self._value = val.reshape(self.shape)
        else:
            self._value = val
            self.shape = self._value.shape
            
        if self.cyclic:
            self._value = self.limits[0] + ((self._value - self.limits[0]) % (self.limits[1] - self.limits[0]))
            return
        if self.limits[0] is not None:
            self._value = torch.maximum(self._value, self.limits[0] + torch.ones_like(self._value) * 1e-3)
        if self.limits[1] is not None:
            self._value = torch.minimum(self._value, self.limits[1] - torch.ones_like(self._value) * 1e-3)
            
        
    @value.setter
    def value(self, val):
        if self.locked and not Node.global_unlock:
            return
        if val is None:
            self._value = None
            self.shape = None
            return
        if isinstance(val, str):
            # A string value is kept as-is; get_state produces placeholder
            # strings such as "NODE:<id>" and "FUNCTION:<name>" for saved graphs
            self._value = val
            return
        if isinstance(val, Parameter_Node):
            self._value = val
            self.shape = None
            # Link only to the pointed node
            self.dump()
            self.link(val)
            return
        if isinstance(val, FunctionType):
            self._value = val
            self.shape = None
            return
        if len(self.nodes) > 0:
            self.vector_set_values(val)
            self.shape = None
            return
        self._set_val_self(val)
        self.dump()
    @property
    def shape(self):
        try:
            if isinstance(self._value, Parameter_Node):
                return self._value.shape
            if isinstance(self._value, FunctionType):
                return self.value.shape
            if self.leaf:
                return self._shape
        except AttributeError:
            pass
        return None
    @shape.setter
    def shape(self, shape):
        self._shape = shape
        
    @property
    def prof(self):
        return self._prof
    @prof.setter
    def prof(self, prof):
        if self.locked and not Node.global_unlock:
            return
        if prof is None:
            self._prof = None
            return
        self._prof = torch.as_tensor(
            prof, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
    @property
    def uncertainty(self):
        return self._uncertainty
    @uncertainty.setter
    def uncertainty(self, unc):
        if self.locked and not Node.global_unlock:
            return
        if unc is None:
            self._uncertainty = None
            return
        
        self._uncertainty = torch.as_tensor(
            unc, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        # Ensure that the uncertainty tensor has the same shape as the data
        if self.shape is not None:
            if self._uncertainty.shape != self.shape:
                self._uncertainty = self._uncertainty * torch.ones(
                    self.shape, dtype=AP_config.ap_dtype, device=AP_config.ap_device
                )
    @property
    def limits(self):
        return self._limits
    @limits.setter
    def limits(self, limits):
        if self.locked and not Node.global_unlock:
            return
        if limits[0] is None:
            low = None
        else:
            low = torch.as_tensor(
                limits[0], dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
        if limits[1] is None:
            high = None
        else:
            high = torch.as_tensor(
                limits[1], dtype=AP_config.ap_dtype, device=AP_config.ap_device
            )
        self._limits = (low, high)
    def to(self, dtype=None, device=None):
        """
        Updates the datatype and/or device of this parameter, its stored
        tensors, and recursively any sub-nodes. When an argument is not
        given, the corresponding AP_config default is used.
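
        For example::

          P.to(dtype=torch.float64)  # cast all stored tensors to float64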
        """
        if dtype is None:
            dtype = AP_config.ap_dtype
        if device is None:
            device = AP_config.ap_device
            
        if isinstance(self._value, torch.Tensor):
            self._value = self._value.to(dtype=dtype, device=device)
        elif len(self.nodes) > 0:
            for node in self.nodes.values():
                node.to(dtype, device)
        if isinstance(self._uncertainty, torch.Tensor):
            self._uncertainty = self._uncertainty.to(dtype=dtype, device=device)
        if isinstance(self.prof, torch.Tensor):
            self.prof = self.prof.to(dtype=dtype, device=device)
        return self 
    def get_state(self):
        """Return the values representing the current state of the parameter,
        this can be used to re-load the state later from memory.
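
        A save/load round trip sketch, where ``P2`` is a hypothetical
        fresh node::

          state = P.get_state()
          P2 = Parameter_Node("P2")
          P2.set_state(state)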
        """
        state = super().get_state()
        if self.value is not None:
            if isinstance(self._value, Node):
                state["value"] = "NODE:" + str(self._value.identity)
            elif isinstance(self._value, FunctionType):
                state["value"] = "FUNCTION:" + self._value.__name__
            else:
                state["value"] = self.value.detach().cpu().numpy().tolist()
        if self.shape is not None:
            state["shape"] = list(self.shape)
        if self.units is not None:
            state["units"] = self.units
        if self.uncertainty is not None:
            state["uncertainty"] = self.uncertainty.detach().cpu().numpy().tolist()
        if not (self.limits[0] is None and self.limits[1] is None):
            save_lim = []
            for i in [0, 1]:
                if self.limits[i] is None:
                    save_lim.append(None)
                else:
                    save_lim.append(self.limits[i].detach().cpu().tolist())
            state["limits"] = save_lim
        if self.cyclic:
            state["cyclic"] = self.cyclic
        if self.prof is not None:
            state["prof"] = self.prof.detach().cpu().tolist()
        return state 
    
    def set_state(self, state):
        """Update the state of the parameter given a state variable which
        holds all information about the parameter.
        """
        
        super().set_state(state)
        save_locked = self.locked
        self.locked = False
        self.units = state.get("units", "none")
        self.limits = state.get("limits", (None,None))
        self.cyclic = state.get("cyclic", False)
        self.value = state.get("value", None)
        self.uncertainty = state.get("uncertainty", None)
        self.prof = state.get("prof", None)
        self.locked = save_locked 
    def flat_detach(self):
        """Due to the system used to track and update values in the DAG, some
        parts of the computational graph used to determine gradients
        may linger after calling .backward on a model using the
        parameters. This function essentially resets all the leaf
        values so that the full computational graph is freed.
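
        A usage sketch, assuming ``loss`` was computed from the parameter
        values::

          loss.backward()
          P.flat_detach()  # leaf tensors no longer hold the autograd graph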
        """
        for P in self.flat().values():
            P.value = P.value.detach()
            if P.uncertainty is not None:
                P.uncertainty = P.uncertainty.detach()
            if P.prof is not None:
                P.prof = P.prof.detach() 
    @property
    def size(self):
        if self.leaf:
            return self.value.numel()
        return self.vector_values().numel()
        
    def __len__(self):
        """The number of elements required to fully describe the DAG. This is
        the number of elements in the vector_values tensor.
        """        
        return self.size
    def print_params(self, include_locked=True, include_prof=True, include_id=True):
        """Return a human-readable, single-line description of this node."""
        if self.leaf:
            reply = self.name
            if include_id:
                reply += f" (id-{self.identity})"
            reply += f": {self.value.detach().cpu().tolist()}"
            if self.uncertainty is not None:
                reply += f" +- {self.uncertainty.detach().cpu().tolist()}"
            reply += f" [{self.units}]"
            if not (self.limits[0] is None and self.limits[1] is None):
                low = None if self.limits[0] is None else self.limits[0].detach().cpu().tolist()
                high = None if self.limits[1] is None else self.limits[1].detach().cpu().tolist()
                reply += f", limits: ({low}, {high})"
            if self.cyclic:
                reply += ", cyclic"
            if self.locked:
                reply += ", locked"
            if include_prof and self.prof is not None:
                reply += f", prof: {self.prof.detach().cpu().tolist()}"
            return reply
        if isinstance(self._value, Parameter_Node):
            return (
                self.name
                + (f" (id-{self.identity})" if include_id else "")
                + " points to: "
                + self._value.print_params(
                    include_locked=include_locked,
                    include_prof=include_prof,
                    include_id=include_id,
                )
            )
        node_type = (
            f"function node, {self._value.__name__}"
            if isinstance(self._value, FunctionType)
            else "branch node"
        )
        return self.name + (f" (id-{self.identity}, {node_type})" if include_id else "") + ":\n"
        
    def __str__(self):
        reply = self.print_params(include_locked=True, include_prof=False, include_id=False)
        if self.leaf or isinstance(self._value, Parameter_Node):
            return reply
        reply += "\n".join(node.print_params(include_locked=True, include_prof=False, include_id=False) for node in self.flat(include_locked=True, include_links=False).values())
        return reply
    
    def __repr__(self, level = 0, indent = '  '):
        reply = indent*level + self.print_params(include_locked=True, include_prof=False, include_id=True)
        if self.leaf or isinstance(self._value, Parameter_Node):
            return reply
        reply += "\n".join(node.__repr__(level = level+1, indent=indent) for node in self.nodes.values())
        return reply