from copy import deepcopy
from typing import Optional, Sequence
from collections import OrderedDict
import torch
import numpy as np
import matplotlib.pyplot as plt
from .core_model import AstroPhot_Model
from ..image import (
Image,
Model_Image,
Model_Image_List,
Target_Image,
Image_List,
Window,
Window_List,
Jacobian_Image,
)
from ..utils.decorators import ignore_numpy_warnings, default_internal
from ._shared_methods import select_target
from ..param import Parameter_Node
from ..errors import InvalidTarget
from .. import AP_config
# Public names exported by a star-import of this module.
__all__ = ["Group_Model"]
class Group_Model(AstroPhot_Model):
    """Model object which represents a list of other models.

    For each general AstroPhot model method, this calls all the
    appropriate models from its list and combines their output into a
    single summed model. This class should be used when describing any
    system more complex than makes sense to represent with a single
    light distribution.

    Args:
      name (str): unique name for the full group model
      target (Target_Image): the target image that this group model is trying to fit to
      models (Optional[Sequence[AstroPhot_Model]]): list of AstroPhot_Model objects which will combine for the group model
      locked (bool): if the whole group of models should be locked

    """

    model_type = f"group {AstroPhot_Model.model_type}"
    useable = True

    def __init__(
        self,
        *,
        name: Optional[str] = None,
        models: Optional[Sequence[AstroPhot_Model]] = None,
        **kwargs,
    ):
        """Build the group, register any initial sub models and compute
        the enclosing window.

        Args:
          name (Optional[str]): name for the group model, forwarded to the parent constructor.
          models (Optional[Sequence[AstroPhot_Model]]): sub models to register immediately.
          **kwargs: forwarded to the parent constructor. If a "filename" entry is present, the saved state is loaded once construction is complete.

        """
        super().__init__(name=name, models=models, **kwargs)
        self._param_tuple = None
        # Sub models keyed by their name, kept in insertion order.
        self.models = OrderedDict()
        if models is not None:
            self.add_model(models)
        self._psf_mode = "none"
        self.update_window()
        if "filename" in kwargs:
            self.load(kwargs["filename"], new_name=name)

    def add_model(self, model):
        """Adds a new model to the group model list. Ensures that the same
        model isn't added a second time.

        Parameters:
          model: a model object to add to the model list. A tuple or list of models (possibly nested) is added element by element.

        Raises:
          KeyError: if a *different* model with the same name is already in the group.

        """
        if isinstance(model, (tuple, list)):
            for mod in model:
                self.add_model(mod)
            return
        # Re-adding the very same object is allowed; only a name clash
        # between two distinct objects is an error.
        if model.name in self.models and model is not self.models[model.name]:
            raise KeyError(
                f"{self.name} already has model with name {model.name}, every model must have a unique name."
            )
        self.models[model.name] = model
        self.parameters.link(model.parameters)
        self.update_window()

    def update_window(self, include_locked: bool = False):
        """Makes a new window object which encloses all the windows of the
        sub models in this group model object.

        Args:
          include_locked (bool): when False (default), locked sub models are skipped and do not contribute to the group window.

        Raises:
          NotImplementedError: if the group target is an Image_List and a sub model target is neither an Image_List nor a Target_Image.

        """
        if isinstance(self.target, Image_List):
            # Window_List if target is a Target_Image_List: one slot per
            # image in the group target, filled as sub models are visited.
            new_window = [None] * len(self.target.image_list)
            for model in self.models.values():
                if model.locked and not include_locked:
                    continue
                if isinstance(model.target, Image_List):
                    for target, window in zip(model.target, model.window):
                        index = self.target.index(target)
                        if new_window[index] is None:
                            new_window[index] = window.copy()
                        else:
                            new_window[index] |= window
                elif isinstance(model.target, Target_Image):
                    index = self.target.index(model.target)
                    if new_window[index] is None:
                        new_window[index] = model.window.copy()
                    else:
                        new_window[index] |= model.window
                else:
                    raise NotImplementedError(
                        f"Group_Model cannot construct a window for itself using {type(model.target)} object. Must be a Target_Image"
                    )
            new_window = Window_List(new_window)
        else:
            # Single-image target: start from a copy of the first
            # contributing window and grow it with ``|=``. Stays None
            # when no sub model contributes.
            new_window = None
            for model in self.models.values():
                if model.locked and not include_locked:
                    continue
                if new_window is None:
                    new_window = model.window.copy()
                else:
                    new_window |= model.window
        self.window = new_window

    @torch.no_grad()
    @ignore_numpy_warnings
    @select_target
    @default_internal
    def initialize(
        self, target: Optional[Image] = None, parameters=None, **kwargs
    ):
        """Initialize each model in this group.

        Does this by iteratively initializing a model then subtracting
        it from a copy of the target, so each later sub model is
        initialized on what the earlier ones left behind.

        Args:
          target (Optional["Target_Image"]): A Target_Image instance to use as the source for initializing the model parameters on this image.
          parameters: parameter container indexed by sub model name. NOTE(review): the decorators presumably fill in defaults for ``target`` and ``parameters`` when they are None - confirm.
          **kwargs: accepted for interface compatibility; not used here.

        """
        self._param_tuple = None
        super().initialize(target=target, parameters=parameters)

        target_copy = target.copy()
        for model in self.models.values():
            model.initialize(
                target=target_copy, parameters=parameters[model.name]
            )
            target_copy -= model(parameters=parameters[model.name])

    def sample(
        self,
        image: Optional[Image] = None,
        window: Optional[Window] = None,
        parameters: Optional["Parameter_Node"] = None,
    ):
        """Sample the group model on an image.

        Produces the flux values for each pixel associated with the
        models in this group. Each model is called individually and the
        results are added together in one larger image.

        Args:
          image (Optional["Model_Image"]): Image to sample on, overrides the windows for each sub model, they will all be evaluated over this entire image. If left as none then each sub model will be evaluated in its window.
          window (Optional[Window]): window passed on to the sub models. For a Window_List, each sub model receives only the entries matching its own target.
          parameters (Optional["Parameter_Node"]): parameters to evaluate with, indexed by sub model name. Defaults to the group's own parameters.

        Returns:
          The image holding the summed sub model samples.

        """
        self._param_tuple = None
        sample_window = image is None
        if sample_window:
            image = self.make_model_image(window=window)
        if parameters is None:
            parameters = self.parameters

        for model in self.models.values():
            if isinstance(window, Window_List):
                # Pick out the window(s) that belong to this sub model's
                # target(s).
                indices = self.target.match_indices(model.target)
                if isinstance(indices, (tuple, list)):
                    use_window = Window_List(
                        window_list=[window.window_list[ind] for ind in indices]
                    )
                else:
                    use_window = window.window_list[indices]
            else:
                use_window = window
            if sample_window:
                # Will sample the model fit window then add to the image
                image += model(
                    window=use_window, parameters=parameters[model.name]
                )
            else:
                # Will sample the entire image. The return value is
                # discarded; presumably the sub model accumulates into
                # ``image`` in place - TODO confirm.
                model(
                    image, window=use_window, parameters=parameters[model.name]
                )

        return image

    @torch.no_grad()
    def jacobian(
        self,
        parameters: Optional[torch.Tensor] = None,
        as_representation: bool = False,
        pass_jacobian: Optional[Jacobian_Image] = None,
        window: Optional[Window] = None,
        **kwargs,
    ):
        """Compute the jacobian for this model.

        Done by first constructing a full jacobian (Npixels *
        Nparameters) of zeros then call the jacobian method of each sub
        model and add it in to the total.

        Args:
          parameters (Optional[torch.Tensor]): 1D parameter vector to overwrite current values
          as_representation (bool): Indicates if the "parameters" argument is in the form of the real values, or as representations in the (-inf,inf) range. Default False
          pass_jacobian (Optional["Jacobian_Image"]): A Jacobian image pre-constructed to be passed along instead of constructing new Jacobians
          window (Optional[Window]): window to evaluate in. Defaults to the group's window.
          **kwargs: accepted for interface compatibility; not used here.

        Returns:
          The Jacobian image containing the contributions of all sub models.

        """
        if window is None:
            window = self.window
        self._param_tuple = None

        if parameters is not None:
            if as_representation:
                self.parameters.vector_set_representation(parameters)
            else:
                self.parameters.vector_set_values(parameters)

        if pass_jacobian is None:
            jac_img = self.target[window].jacobian_image(
                parameters=self.parameters.vector_identities()
            )
        else:
            jac_img = pass_jacobian

        for model in self.models.values():
            if isinstance(model, Group_Model):
                # Nested groups write into the shared Jacobian image
                # themselves, so their return value is not added again.
                model.jacobian(
                    as_representation=as_representation,
                    pass_jacobian=jac_img,
                    window=window,
                )
            else:  # fixme, maybe make pass_jacobian be filled internally to each model
                jac_img += model.jacobian(
                    as_representation=as_representation,
                    pass_jacobian=jac_img,
                    window=window,
                )

        return jac_img

    def __iter__(self):
        """Iterate over the sub models in insertion order."""
        return iter(self.models.values())

    @property
    def psf_mode(self):
        """PSF mode of the group; setting it also sets it on every sub model."""
        return self._psf_mode

    @psf_mode.setter
    def psf_mode(self, value):
        self._psf_mode = value
        for model in self.models.values():
            model.psf_mode = value

    @property
    def target(self):
        """Target image of the group, or None if none has been set yet.

        Setting it also sets the target of every sub model. Raises
        InvalidTarget if the value is neither None nor a Target_Image.
        """
        return getattr(self, "_target", None)

    @target.setter
    def target(self, tar):
        if not (tar is None or isinstance(tar, Target_Image)):
            raise InvalidTarget("Group_Model target must be a Target_Image instance.")
        self._target = tar
        # ``models`` may not exist yet if the target is assigned before
        # ``__init__`` has created it.
        if hasattr(self, "models"):
            for model in self.models.values():
                model.target = tar

    def get_state(self, save_params=True):
        """Returns a dictionary with information about the state of the model
        and its parameters.

        Args:
          save_params (bool): if True, the state of the group's parameters is stored under "parameters". Sub model states are always collected without their own parameters.

        Returns:
          dict: the state, with one entry per sub model under "models".

        """
        state = super().get_state(save_params=save_params)
        if save_params:
            state["parameters"] = self.parameters.get_state()
        sub_states = state.setdefault("models", {})
        for model in self.models.values():
            sub_states[model.name] = model.get_state(save_params=False)
        return state

    def load(self, filename="AstroPhot.yaml", new_name=None):
        """Loads an AstroPhot state file and updates this model with the
        loaded parameters.

        Sub models already in the group (matched by name) are updated in
        place; any other sub model found in the state is created and
        added to the group.

        Args:
          filename: passed to ``AstroPhot_Model.load`` to obtain the state.
          new_name (Optional[str]): name to give this model. Defaults to the name stored in the state.

        """
        state = AstroPhot_Model.load(filename)

        if new_name is None:
            new_name = state["name"]
        self.name = new_name

        if isinstance(state["parameters"], Parameter_Node):
            self.parameters = state["parameters"]
        else:
            self.parameters = Parameter_Node(self.name, state=state["parameters"])

        for model in state["models"]:
            model_state = state["models"][model]
            # Hand each sub model its own node from the group parameters.
            model_state["parameters"] = self.parameters[model]
            # Match on the model's current name, not on the dict key.
            own_model = next(
                (mod for mod in self.models.values() if mod.name == model),
                None,
            )
            if own_model is not None:
                own_model.load(model_state)
            else:
                self.add_model(
                    AstroPhot_Model(
                        name=model, filename=model_state, target=self.target
                    )
                )
        self.update_window()