Source code for astrophot.fit.base

from time import time
from typing import Any, Union, Sequence, Optional

import numpy as np
import torch
from scipy.optimize import minimize
from scipy.special import gammainc

from .. import AP_config


__all__ = ["BaseOptimizer"]


class BaseOptimizer(object):
    """
    Base optimizer object that other optimizers inherit from. Ensures
    consistent signature for the classes.

    Parameters:
        model: an AstroPhot_Model object that will have its (unlocked) parameters optimized [AstroPhot_Model]
        initial_state: optional initialization for the parameters as a 1D tensor [tensor]
        relative_tolerance: tolerance for counting success steps as: 0 < (Chi2^2 - Chi1^2)/Chi1^2 < tol [float]
        fit_window: optional window to restrict the fit; it is intersected with the model window [Window]
        max_iter (keyword argument): maximum allowed number of iterations [int]
    """

    def __init__(
        self,
        model: "AstroPhot_Model",
        initial_state: Sequence = None,
        relative_tolerance: float = 1e-3,
        fit_window: Optional["Window"] = None,
        **kwargs,
    ) -> None:
        """
        Initializes a new instance of the class.

        Args:
            model (object): The model whose parameters will be optimized.
            initial_state (Optional[Sequence]): Initial parameter vector; anything
                accepted by ``torch.as_tensor`` works. If ``None``, ``model.initialize()``
                is called and ``model.parameters.vector_representation()`` is used.
            relative_tolerance (float): The relative tolerance for the optimization.
            fit_window (Optional[Window]): Region to fit. If ``None`` the model's own
                window is used; otherwise it is intersected (``&``) with the model's window.
            **kwargs (dict): Optional settings read here: ``verbose`` (int, default 0),
                ``max_iter`` (int, default ``100 * len(initial_state)``) and
                ``save_steps`` (default ``None``).

        Attributes:
            model (object): The model being optimized.
            fit_window (Window): The window over which the fit is performed.
            verbose (int): The verbosity level.
            current_state (Tensor): The current parameter vector, converted to
                ``AP_config.ap_dtype`` on ``AP_config.ap_device``.
            max_iter (int): The maximum number of iterations.
            iteration (int): The current iteration number.
            save_steps (Optional[str]): Save intermediate results to this path.
            relative_tolerance (float): The relative tolerance for the optimization.
            lambda_history (List[ndarray]): Recorded parameter states (one per step);
                ``res()`` picks from this list.
            loss_history (List[float]): Recorded losses, aligned with ``lambda_history``.
            message (str): An informational message.
        """
        self.model = model
        self.verbose = kwargs.get("verbose", 0)

        if fit_window is None:
            self.fit_window = self.model.window
        else:
            self.fit_window = fit_window & self.model.window

        if initial_state is None:
            self.model.initialize()
            initial_state = self.model.parameters.vector_representation()
        # One conversion covers both branches: user-supplied sequences and the
        # model's own vector are brought to the configured dtype/device here.
        self.current_state = torch.as_tensor(
            initial_state, dtype=AP_config.ap_dtype, device=AP_config.ap_device
        )
        if self.verbose > 1:
            AP_config.ap_logger.info(f"initial state: {self.current_state}")

        self.max_iter = kwargs.get("max_iter", 100 * len(self.current_state))
        self.iteration = 0
        self.save_steps = kwargs.get("save_steps", None)

        self.relative_tolerance = relative_tolerance
        self.lambda_history = []
        self.loss_history = []
        self.message = ""

    def fit(self) -> "BaseOptimizer":
        """
        Run the optimization. Must be implemented by subclasses.

        Raises:
            NotImplementedError: Error is raised if this method is not implemented in a subclass of BaseOptimizer.
        """
        raise NotImplementedError(
            "Please use a subclass of BaseOptimizer for optimization"
        )

    def step(self, current_state: torch.Tensor = None) -> None:
        """
        Take a single optimization step. Must be implemented by subclasses.

        Args:
            current_state (torch.Tensor, optional): Current state of the model parameters. Defaults to None.

        Raises:
            NotImplementedError: Error is raised if this method is not implemented in a subclass of BaseOptimizer.
        """
        raise NotImplementedError(
            "Please use a subclass of BaseOptimizer for optimization"
        )

    def chi2min(self) -> float:
        """
        Returns the minimum value of the chi^2 loss in the loss history.

        NaN entries are ignored (``np.nanmin``).

        Returns:
            float: Minimum value of chi^2 loss.
        """
        return np.nanmin(self.loss_history)

    def res(self) -> np.ndarray:
        """
        Returns the recorded parameter state with the lowest finite loss.

        Entries of ``loss_history`` that are NaN or infinite are ignored. The i-th
        loss is assumed to belong to the i-th entry of ``lambda_history``; subclasses
        are responsible for keeping the two lists aligned.

        If no finite loss has been recorded, a warning is logged and the current
        state is returned instead.

        Returns:
            ndarray: The parameter state at which the minimum finite loss was achieved.
        """
        finite = np.isfinite(self.loss_history)
        if np.sum(finite) == 0:
            AP_config.ap_logger.warning(
                "Getting optimizer res with no real loss history, using current state"
            )
            return self.current_state.detach().cpu().numpy()
        return np.array(self.lambda_history)[finite][
            np.argmin(np.array(self.loss_history)[finite])
        ]

    def res_loss(self):
        """
        Returns the lowest finite loss in the loss history.

        Entries that are NaN or infinite are ignored. If no finite loss has been
        recorded, numpy raises ``ValueError`` (minimum of an empty array).

        Returns:
            float: Minimum finite loss value.
        """
        finite = np.isfinite(self.loss_history)
        return np.min(np.array(self.loss_history)[finite])

    @staticmethod
    def chi2contour(n_params: int, confidence: float = 0.682689492137) -> float:
        """
        Calculates the chi^2 contour for the given number of parameters.

        Solves ``gammainc(n_params / 2, x / 2) == confidence`` for ``x`` by
        minimizing the squared residual. Several scipy methods are tried in turn
        until one reports success.

        Args:
            n_params (int): The number of parameters.
            confidence (float, optional): The confidence interval (default is 0.682689492137).

        Returns:
            float: The calculated chi^2 contour value.

        Raises:
            RuntimeError: If unable to compute the Chi^2 contour for the given number of parameters.
        """

        def _f(x: float, nu: int) -> float:
            """Squared distance between the chi^2 CDF at x and the target confidence."""
            return (gammainc(nu / 2, x / 2) - confidence) ** 2

        for method in ["L-BFGS-B", "Powell", "Nelder-Mead"]:
            res = minimize(_f, x0=n_params, args=(n_params,), method=method, tol=1e-8)
            if res.success:
                return res.x[0]
        # Previously referenced an undefined name ``ndf``, which turned this
        # documented RuntimeError into a NameError.
        raise RuntimeError(f"Unable to compute Chi^2 contour for ndf: {n_params}")