Source code for astrophot.models.group_model_object

from copy import deepcopy
from typing import Optional, Sequence
from collections import OrderedDict

import torch
import numpy as np
import matplotlib.pyplot as plt

from .core_model import AstroPhot_Model
from ..image import (
    Image,
    Model_Image,
    Model_Image_List,
    Target_Image,
    Image_List,
    Window,
    Window_List,
    Jacobian_Image,
)
from ..utils.decorators import ignore_numpy_warnings, default_internal
from ._shared_methods import select_target
from ..param import Parameter_Node
from ..errors import InvalidTarget
from .. import AP_config

__all__ = ["Group_Model"]

[docs] class Group_Model(AstroPhot_Model): """Model object which represents a list of other models. For each general AstroPhot model method, this calls all the appropriate models from its list and combines their output into a single summed model. This class shoould be used when describing any system more comlex than makes sense to represent with a single light distribution. Args: name (str): unique name for the full group model target (Target_Image): the target image that this group model is trying to fit to models (Optional[Sequence[AstroPhot_Model]]): list of AstroPhot_Model objects which will combine for the group model locked (bool): if the whole group of models should be locked """ model_type = f"group {AstroPhot_Model.model_type}" useable = True def __init__( self, *, name: Optional[str] = None, models: Optional[Sequence[AstroPhot_Model]] = None, **kwargs, ): super().__init__(name=name, models=models, **kwargs) self._param_tuple = None self.models = OrderedDict() if models is not None: self.add_model(models) self._psf_mode = "none" self.update_window() if "filename" in kwargs: self.load(kwargs["filename"], new_name=name)
[docs] def add_model(self, model): """Adds a new model to the group model list. Ensures that the same model isn't added a second time. Parameters: model: a model object to add to the model list. """ if isinstance(model, (tuple, list)): for mod in model: self.add_model(mod) return if model.name in self.models and model is not self.models[model.name]: raise KeyError( f"{self.name} already has model with name {model.name}, every model must have a unique name." ) self.models[model.name] = model self.parameters.link(model.parameters) self.update_window()
[docs] def update_window(self, include_locked: bool = False): """Makes a new window object which encloses all the windows of the sub models in this group model object. """ if isinstance( self.target, Image_List ): # Window_List if target is a Target_Image_List new_window = [None] * len(self.target.image_list) for model in self.models.values(): if model.locked and not include_locked: continue if isinstance(model.target, Image_List): for target, window in zip(model.target, model.window): index = self.target.index(target) if new_window[index] is None: new_window[index] = window.copy() else: new_window[index] |= window elif isinstance(model.target, Target_Image): index = self.target.index(model.target) if new_window[index] is None: new_window[index] = model.window.copy() else: new_window[index] |= model.window else: raise NotImplementedError( f"Group_Model cannot construct a window for itself using {type(model.target)} object. Must be a Target_Image" ) new_window = Window_List(new_window) else: new_window = None for model in self.models.values(): if model.locked and not include_locked: continue if new_window is None: new_window = model.window.copy() else: new_window |= model.window self.window = new_window
[docs] @torch.no_grad() @ignore_numpy_warnings @select_target @default_internal def initialize( self, target: Optional[Image] = None, parameters=None, **kwargs ): """ Initialize each model in this group. Does this by iteratively initializing a model then subtracting it from a copy of the target. Args: target (Optional["Target_Image"]): A Target_Image instance to use as the source for initializing the model parameters on this image. """ self._param_tuple = None super().initialize(target=target, parameters=parameters) target_copy = target.copy() for model in self.models.values(): model.initialize( target=target_copy, parameters=parameters[model.name] ) target_copy -= model(parameters=parameters[model.name])
[docs] def sample( self, image: Optional[Image] = None, window: Optional[Window] = None, parameters: Optional["Parameter_Node"] = None, ): """Sample the group model on an image. Produces the flux values for each pixel associated with the models in this group. Each model is called individually and the results are added together in one larger image. Args: image (Optional["Model_Image"]): Image to sample on, overrides the windows for each sub model, they will all be evaluated over this entire image. If left as none then each sub model will be evaluated in its window. """ self._param_tuple = None if image is None: sample_window = True image = self.make_model_image(window=window) else: sample_window = False if parameters is None: parameters = self.parameters for model in self.models.values(): if window is not None and isinstance(window, Window_List): indices = self.target.match_indices(model.target) if isinstance(indices, (tuple, list)): use_window = Window_List( window_list=list(window.window_list[ind] for ind in indices) ) else: use_window = window.window_list[indices] else: use_window = window if sample_window: # Will sample the model fit window then add to the image image += model( window=use_window, parameters=parameters[model.name] ) else: # Will sample the entire image model( image, window=use_window, parameters=parameters[model.name] ) return image
[docs] @torch.no_grad() def jacobian( self, parameters: Optional[torch.Tensor] = None, as_representation: bool = False, pass_jacobian: Optional[Jacobian_Image] = None, window: Optional[Window] = None, **kwargs, ): """Compute the jacobian for this model. Done by first constructing a full jacobian (Npixels * Nparameters) of zeros then call the jacobian method of each sub model and add it in to the total. Args: parameters (Optional[torch.Tensor]): 1D parameter vector to overwrite current values as_representation (bool): Indiates if the "parameters" argument is in the form of the real values, or as representations in the (-inf,inf) range. Default False pass_jacobian (Optional["Jacobian_Image"]): A Jacobian image pre-constructed to be passed along instead of constructing new Jacobians """ if window is None: window = self.window self._param_tuple = None if parameters is not None: if as_representation: self.parameters.vector_set_representation(parameters) else: self.parameters.vector_set_values(parameters) if pass_jacobian is None: jac_img = self.target[window].jacobian_image( parameters=self.parameters.vector_identities() ) else: jac_img = pass_jacobian for model in self.models.values(): if isinstance(model, Group_Model): model.jacobian( as_representation=as_representation, pass_jacobian=jac_img, window=window, ) else: # fixme, maybe make pass_jacobian be filled internally to each model jac_img += model.jacobian( as_representation=as_representation, pass_jacobian=jac_img, window=window, ) return jac_img
def __iter__(self): return (mod for mod in self.models.values()) @property def psf_mode(self): return self._psf_mode @psf_mode.setter def psf_mode(self, value): self._psf_mode = value for model in self.models.values(): model.psf_mode = value @property def target(self): try: return self._target except AttributeError: return None @target.setter def target(self, tar): if not (tar is None or isinstance(tar, Target_Image)): raise InvalidTarget("Group_Model target must be a Target_Image instance.") self._target = tar if hasattr(self, "models"): for model in self.models.values(): model.target = tar
[docs] def get_state(self, save_params = True): """Returns a dictionary with information about the state of the model and its parameters. """ state = super().get_state(save_params = save_params) if save_params: state["parameters"] = self.parameters.get_state() if "models" not in state: state["models"] = {} for model in self.models.values(): state["models"][model.name] = model.get_state(save_params = False) return state
[docs] def load(self, filename="AstroPhot.yaml", new_name = None): """Loads an AstroPhot state file and updates this model with the loaded parameters. """ state = AstroPhot_Model.load(filename) if new_name is None: new_name = state["name"] self.name = new_name if isinstance(state["parameters"], Parameter_Node): self.parameters = state["parameters"] else: self.parameters = Parameter_Node(self.name, state = state["parameters"]) for model in state["models"]: state["models"][model]["parameters"] = self.parameters[model] for own_model in self.models.values(): if model == own_model.name: own_model.load(state["models"][model]) break else: self.add_model( AstroPhot_Model( name=model, filename=state["models"][model], target=self.target ) ) self.update_window()