Source code for astrophot.fit.iterative

# Apply a different optimizer iteratively
from typing import Dict, Any, Sequence, Union
import os
from time import time
import random

import numpy as np
import torch

from .base import BaseOptimizer
from ..models import AstroPhot_Model
from .lm import LM
from ..param import Param_Mask
from .. import AP_config

__all__ = ["Iter", "Iter_LM"]


class Iter(BaseOptimizer):
    """Optimizer wrapper that performs optimization iteratively.

    This optimizer applies a different optimizer to a group model
    iteratively. It can be used for complex fits or when the number of
    models to fit is too large to fit in memory.

    A usage sketch and toy illustrations of the residual-target trick
    and the convergence test follow this class definition.

    Args:
        model: An `AstroPhot_Model` object to perform optimization on.
        method: The optimizer class to apply at each iteration step.
        initial_state: Optional initial state for optimization, defaults to None.
        max_iter: Maximum number of iterations, defaults to 100.
        method_kwargs: Keyword arguments to pass to `method`.
        **kwargs: Additional keyword arguments.

    Attributes:
        ndf: Degrees of freedom of the data.
        method: The optimizer class to apply at each iteration step. Default: Levenberg-Marquardt.
        method_kwargs: Keyword arguments to pass to `method`.
        iteration: The number of iterations performed.
        lambda_history: A list of the parameter states at each iteration step.
        loss_history: A list of the losses at each iteration step.
    """

    def __init__(
        self,
        model: AstroPhot_Model,
        method: BaseOptimizer = LM,
        initial_state: np.ndarray = None,
        max_iter: int = 100,
        method_kwargs: Dict[str, Any] = {},
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(model, initial_state, max_iter=max_iter, **kwargs)

        self.method = method
        self.method_kwargs = method_kwargs
        # `method` is a class, not an instance, so test with issubclass
        if "relative_tolerance" not in method_kwargs and issubclass(method, LM):
            # Lower tolerance since it's not worth fine tuning a model
            # when its neighbors will be shifting soon anyway
            self.method_kwargs["relative_tolerance"] = 1e-3
            self.method_kwargs["max_iter"] = 15
        # ndf = # pixels - # parameters
        self.ndf = self.model.target[self.model.window].flatten("data").size(0) - len(
            self.current_state
        )
        if self.model.target.has_mask:
            # Subtract masked pixels from the degrees of freedom
            self.ndf -= torch.sum(self.model.target[self.model.window].flatten("mask")).item()

    def sub_step(self, model: "AstroPhot_Model") -> None:
        """Perform optimization for a single model.

        Args:
            model: The model to perform optimization on.
        """
        # Remove this model's contribution from the full model image
        self.Y -= model()
        # Fit the model against the residual target: data minus all other models
        initial_target = model.target
        model.target = model.target[model.window] - self.Y[model.window]
        res = self.method(model, **self.method_kwargs).fit()
        # Add the updated model back into the full model image
        self.Y += model()
        if self.verbose > 1:
            AP_config.ap_logger.info(res.message)
        # Restore the original target
        model.target = initial_target

    def step(self) -> None:
        """Perform a single iteration of optimization."""
        if self.verbose > 0:
            AP_config.ap_logger.info("--------iter-------")

        # Fit each model individually
        for model in self.model.models.values():
            if self.verbose > 0:
                AP_config.ap_logger.info(model.name)
            self.sub_step(model)
        # Update the current state
        self.current_state = self.model.parameters.vector_representation()

        # Update the loss value
        with torch.no_grad():
            if self.verbose > 0:
                AP_config.ap_logger.info("Update Chi^2 with new parameters")
            self.Y = self.model(
                parameters=self.current_state,
                as_representation=True,
            )
            D = self.model.target[self.model.window].flatten("data")
            V = (
                self.model.target[self.model.window].flatten("variance")
                if self.model.target.has_variance
                else 1.0
            )
            if self.model.target.has_mask:
                M = self.model.target[self.model.window].flatten("mask")
                loss = (
                    torch.sum((((D - self.Y.flatten("data")) ** 2) / V)[torch.logical_not(M)])
                    / self.ndf
                )
            else:
                loss = torch.sum((D - self.Y.flatten("data")) ** 2 / V) / self.ndf
        if self.verbose > 0:
            AP_config.ap_logger.info(f"Loss: {loss.item()}")
        self.lambda_history.append(np.copy(self.current_state.detach().cpu().numpy()))
        self.loss_history.append(loss.item())

        # Test for convergence: count iterations where the relative loss
        # improvement stays inside a narrow band around zero
        if self.iteration >= 2 and (
            (-self.relative_tolerance * 1e-3)
            < ((self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1])
            < (self.relative_tolerance / 10)
        ):
            self._count_finish += 1
        else:
            self._count_finish = 0

        self.iteration += 1

    def fit(self) -> "BaseOptimizer":
        """Fit the models to the target."""
        self.iteration = 0
        self.Y = self.model(parameters=self.current_state, as_representation=True)

        start_fit = time()
        try:
            while True:
                self.step()
                if self.save_steps is not None:
                    self.model.save(
                        os.path.join(
                            self.save_steps,
                            f"{self.model.name}_Iteration_{self.iteration:03d}.yaml",
                        )
                    )
                if self.iteration > 2 and self._count_finish >= 2:
                    self.message = self.message + "success"
                    break
                elif self.iteration >= self.max_iter:
                    self.message = self.message + f"fail max iterations reached: {self.iteration}"
                    break
        except KeyboardInterrupt:
            self.message = self.message + "fail interrupted"

        self.model.parameters.vector_set_representation(self.res())
        if self.verbose > 1:
            AP_config.ap_logger.info(
                f"Iter Fitting complete in {time() - start_fit} sec with message: {self.message}"
            )
        return self

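# Usage sketch for `Iter` (illustrative only, not part of the module API).
# `group_model` is a hypothetical, fully-constructed AstroPhot group model
# with sub-models and a target; its construction is omitted here.
def _example_iter_usage(group_model):
    optimizer = Iter(
        group_model,  # group model whose sub-models are fit one at a time
        method=LM,  # optimizer class applied to each sub-model
        method_kwargs={"max_iter": 15, "relative_tolerance": 1e-3},
        max_iter=100,  # maximum number of passes over all sub-models
    )
    result = optimizer.fit()
    return result.message  # e.g. "success" or "fail max iterations reached: ..."
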
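# Toy illustration of the residual-target trick used in `Iter.sub_step`:
# each sub-model is fit against the data minus every *other* model's
# contribution. Plain 1D tensors stand in for images (values invented).
def _example_residual_target():
    data = torch.tensor([5.0, 7.0, 9.0])  # observed image (flattened)
    model_a = torch.tensor([2.0, 3.0, 4.0])  # current evaluation of sub-model A
    model_b = torch.tensor([3.0, 4.0, 5.0])  # current evaluation of sub-model B

    Y = model_a + model_b  # full group-model image
    Y -= model_a  # remove A's contribution, as in sub_step
    residual_target = data - Y  # what sub-model A alone should explain
    # A would now be optimized against residual_target, then Y += model_a
    return residual_target  # tensor([2., 3., 4.])
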
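# Standalone sketch of the masked reduced chi^2 and the convergence test
# computed in `Iter.step` (toy numbers; the tolerance value is assumed).
def _example_reduced_chi2_and_convergence():
    D = torch.tensor([1.0, 2.0, 3.0, 4.0])  # data pixels
    Y = torch.tensor([1.1, 1.9, 3.2, 4.0])  # model pixels
    V = torch.tensor([0.1, 0.1, 0.2, 0.2])  # pixel variances
    M = torch.tensor([False, False, True, False])  # True = masked pixel
    ndf = int(torch.sum(~M)) - 2  # unmasked pixels minus (toy) parameter count

    loss = torch.sum((((D - Y) ** 2) / V)[~M]) / ndf

    # Convergence counts iterations where the relative loss improvement
    # falls inside (-tol * 1e-3, tol / 10); two such iterations end the fit.
    loss_history = [1.30, 1.2950, 1.2949]
    tol = 1e-3
    improvement = (loss_history[-2] - loss_history[-1]) / loss_history[-1]
    converged_step = (-tol * 1e-3) < improvement < (tol / 10)
    return loss.item(), converged_step
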
class Iter_LM(BaseOptimizer):
    """Optimization wrapper that calls the LM optimizer on subsets of variables.

    Iter_LM takes the full set of parameters for a model and breaks
    them down into chunks as specified by the user. It then calls
    Levenberg-Marquardt optimization on each subset of parameters,
    iterating through the subsets until every parameter has been
    optimized. It cycles through these chunks until convergence.

    This method is very powerful in situations where the full
    optimization problem cannot fit in memory, or where the problem is
    too complex to tackle as a single large problem. In full LM
    optimization a single problematic parameter can ripple into issues
    with every other parameter, so breaking the problem down can
    sometimes make an otherwise intractable problem easier. For small
    problems with only a few models, it is likely better to optimize
    the full problem with LM since, when it works, LM is faster than
    the Iter_LM method.

    A usage sketch and a toy illustration of the chunk selection follow
    this class definition.

    Args:
        chunks (Union[int, tuple]): Specifies how to break down the model
            parameters. If an integer, at each iteration the algorithm will
            break the parameters into groups of that size. If a tuple, it
            should be a tuple of tuples of strings giving an explicit pairing
            of parameters to optimize; note that variable-size chunks are
            allowed this way. Default: 50
        method (str): How to iterate through the chunks. Should be one of:
            random, sequential. Default: random
    """

    def __init__(
        self,
        model: "AstroPhot_Model",
        initial_state: Sequence = None,
        chunks: Union[int, tuple] = 50,
        max_iter: int = 100,
        method: str = "random",
        LM_kwargs: dict = {},
        **kwargs: Dict[str, Any],
    ) -> None:
        super().__init__(model, initial_state, max_iter=max_iter, **kwargs)

        self.chunks = chunks
        self.method = method
        self.LM_kwargs = LM_kwargs
        # ndf = # pixels - # parameters
        self.ndf = self.model.target[self.model.window].flatten("data").numel() - len(
            self.current_state
        )
        if self.model.target.has_mask:
            # Subtract masked pixels from the degrees of freedom
            self.ndf -= torch.sum(self.model.target[self.model.window].flatten("mask")).item()

    def step(self):
        """Perform one pass through all parameter chunks, fitting each with LM."""
        # These store the chunking information depending on which chunk mode is selected
        param_ids = list(self.model.parameters.vector_identities())
        init_param_ids = list(self.model.parameters.vector_identities())
        _chunk_index = 0
        _chunk_choices = None
        res = None
        if self.verbose > 0:
            AP_config.ap_logger.info("--------iter-------")
        # Loop through all the chunks
        while True:
            chunk = torch.zeros(len(init_param_ids), dtype=torch.bool, device=AP_config.ap_device)
            if isinstance(self.chunks, int):
                if len(param_ids) == 0:
                    break
                if self.method == "random":
                    # Draw a random chunk of ids
                    for pid in random.sample(param_ids, min(len(param_ids), self.chunks)):
                        chunk[init_param_ids.index(pid)] = True
                else:
                    # Draw the next chunk of ids
                    for pid in param_ids[: self.chunks]:
                        chunk[init_param_ids.index(pid)] = True
                # Remove the selected ids from the list
                for p in np.array(init_param_ids)[chunk.detach().cpu().numpy()]:
                    param_ids.pop(param_ids.index(p))
            elif isinstance(self.chunks, (tuple, list)):
                if _chunk_choices is None:
                    # Make a list of the chunks as given explicitly
                    _chunk_choices = list(range(len(self.chunks)))
                if self.method == "random":
                    if len(_chunk_choices) == 0:
                        break
                    # Select a random chunk from the given groups
                    sub_index = random.choice(_chunk_choices)
                    _chunk_choices.pop(_chunk_choices.index(sub_index))
                    for pid in self.chunks[sub_index]:
                        chunk[param_ids.index(pid)] = True
                else:
                    if _chunk_index >= len(self.chunks):
                        break
                    # Select the next chunk in order
                    for pid in self.chunks[_chunk_index]:
                        chunk[param_ids.index(pid)] = True
                    _chunk_index += 1
            else:
                raise ValueError(
                    f"Unrecognized chunks value, should be one of: int, tuple. Not: {type(self.chunks)}"
                )
            if self.verbose > 1:
                AP_config.ap_logger.info(str(chunk))
            del res
            # Mask all parameters outside the chunk, then run LM on the active subset
            with Param_Mask(self.model.parameters, chunk):
                res = LM(
                    self.model,
                    ndf=self.ndf,
                    **self.LM_kwargs,
                ).fit()
            if self.verbose > 0:
                AP_config.ap_logger.info(f"chunk loss: {res.res_loss()}")
            if self.verbose > 1:
                AP_config.ap_logger.info(f"chunk message: {res.message}")

        self.loss_history.append(res.res_loss())
        self.lambda_history.append(
            self.model.parameters.vector_representation().detach().cpu().numpy()
        )
        if self.verbose > 0:
            AP_config.ap_logger.info(f"Loss: {self.loss_history[-1]}")

        # Test for convergence
        if self.iteration >= 2 and (
            (-self.relative_tolerance * 1e-3)
            < ((self.loss_history[-2] - self.loss_history[-1]) / self.loss_history[-1])
            < (self.relative_tolerance / 10)
        ):
            self._count_finish += 1
        else:
            self._count_finish = 0

        self.iteration += 1

    def fit(self):
        """Fit the model to the target by iterating LM over parameter chunks."""
        self.iteration = 0

        start_fit = time()
        try:
            while True:
                self.step()
                if self.save_steps is not None:
                    self.model.save(
                        os.path.join(
                            self.save_steps,
                            f"{self.model.name}_Iteration_{self.iteration:03d}.yaml",
                        )
                    )
                if self.iteration > 2 and self._count_finish >= 2:
                    self.message = self.message + "success"
                    break
                elif self.iteration >= self.max_iter:
                    self.message = self.message + f"fail max iterations reached: {self.iteration}"
                    break
        except KeyboardInterrupt:
            self.message = self.message + "fail interrupted"

        self.model.parameters.vector_set_representation(self.res())
        if self.verbose > 1:
            AP_config.ap_logger.info(
                f"Iter Fitting complete in {time() - start_fit} sec with message: {self.message}"
            )
        return self

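# Usage sketch for `Iter_LM` (illustrative only). `group_model` is a
# hypothetical, fully-constructed AstroPhot model; the parameter identity
# strings in `explicit_chunks` are invented for the example. Real identities
# come from model.parameters.vector_identities().
def _example_iter_lm_usage(group_model):
    # Integer chunks: split parameters into random groups of 25 per pass.
    res = Iter_LM(group_model, chunks=25, method="random").fit()

    # Explicit chunks: tuples of parameter-identity strings; chunks may have
    # different sizes, and "sequential" visits them in the order given.
    explicit_chunks = (
        ("galaxy:center", "galaxy:PA"),
        ("galaxy:q", "galaxy:n", "galaxy:Re"),
    )
    res = Iter_LM(group_model, chunks=explicit_chunks, method="sequential").fit()
    return res.message
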
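# Toy illustration of the random chunk-mask construction in `Iter_LM.step`:
# a boolean mask flags which entries of the flat parameter vector are free
# for the next LM call, and selected ids are removed so each parameter is
# optimized exactly once per pass. The identities here are invented.
def _example_chunk_selection(chunk_size=2):
    init_param_ids = ["a", "b", "c", "d", "e"]  # toy parameter identities
    param_ids = list(init_param_ids)  # ids not yet optimized this pass

    chunk = torch.zeros(len(init_param_ids), dtype=torch.bool)
    for pid in random.sample(param_ids, min(len(param_ids), chunk_size)):
        chunk[init_param_ids.index(pid)] = True
    # Remove the selected ids so the next chunk draws from the remainder.
    for p in np.array(init_param_ids)[chunk.numpy()]:
        param_ids.remove(p)
    # `Param_Mask(model.parameters, chunk)` would expose only the flagged
    # parameters to the LM optimizer for this chunk.
    return chunk, param_ids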