from copy import copy
from time import time
import io
from typing import Optional
from functools import partial
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
from ..utils.conversions.optimization import cyclic_difference_np
from ..utils.conversions.dict_to_hdf5 import dict_to_hdf5, hdf5_to_dict
from ..utils.optimization import reduced_chi_squared
from ..utils.decorators import ignore_numpy_warnings, default_internal
from ..image import Model_Image, Window, Target_Image, Target_Image_List
from ..param import Parameter_Node
from ._shared_methods import select_target, select_sample
from .. import AP_config
from ..errors import NameNotAllowed, InvalidTarget, UnrecognizedModel, InvalidWindow
__all__ = ("AstroPhot_Model",)
def all_subclasses(cls):
return set(cls.__subclasses__()).union(
[s for c in cls.__subclasses__() for s in all_subclasses(c)]
)
######################################################################
[docs]
class AstroPhot_Model(object):
"""Core class for all AstroPhot models and model like objects. This
class defines the signatures to interact with AstroPhot models
both for users and internal functions.
Basic usage:
.. code-block:: python
import astrophot as ap
# Create a model object
model = ap.models.AstroPhot_Model(
name = "unique name",
model_type = <choose a model type>,
target = <Target_Image object>,
window = [[a,b],[c,d]], <widnow pixel coordinates>,
parameters = <dict of parameter specifications if desired>,
)
# Initialize parameters that weren't set on creation
model.initialize()
# Fit model to target
result = ap.fit.lm(model, verbose=1).fit()
# Plot the model
fig, ax = plt.subplots()
ap.plots.model_image(fig, ax, model)
plt.show()
# Sample the model
img = model()
pixels = img.data
AstroPhot models are one of the main ways that one interacts with
the code, either by setting model parameters or passing models to
other objects, one can perform a huge variety of fitting
tasks. The subclass `Component_Model` should be thought of as the
basic unit when constructing a model of an image while a
`Group_Model` is a composite structure that may represent a
complex object, a region of an image, or even a model spanning
many images. Constructing the `Component_Model`s is where most
work goes, these store the actual parameters that will be
optimized. It is important to remmeber that a `Component_Model`
only ever applies to a single image and a single component (star,
galaxy, or even sub-component of one of those) in that image.
A complex representation is made by stacking many
`Component_Model`s together, in total this may result in a very
large number of parameters. Trying to find starting values for all
of these parameters can be tedious and error prone, so instead all
built-in AstroPhot models can self initialize and find reasonable
starting parameters for most situations. Even still one may find
that for extremely complex fits, it is more stable to first run an
iterative fitter before global optimization to start the models in
better initial positions.
Args:
name (Optional[str]): every AstroPhot model should have a unique name
model_type (str): a model type string can determine which kind of AstroPhot model is instantiated.
target (Optional[Target_Image]): A Target_Image object which stores information about the image which the model is trying to fit.
filename (Optional[str]): name of a file to load AstroPhot parameters, window, and name. The model will still need to be told its target, device, and other information
"""
model_type = "model"
default_uncertainty = 1e-2 # During initialization, uncertainty will be assumed 1% of initial value if no uncertainty is given
useable = False
model_names = []
def __new__(cls, *, filename=None, model_type=None, **kwargs):
if filename is not None:
state = AstroPhot_Model.load(filename)
MODELS = AstroPhot_Model.List_Models()
for M in MODELS:
if M.model_type == state["model_type"]:
return super(AstroPhot_Model, cls).__new__(M)
else:
raise UnrecognizedModel(
f"Unknown AstroPhot model type: {state['model_type']}"
)
elif model_type is not None:
MODELS = AstroPhot_Model.List_Models() # all_subclasses(AstroPhot_Model)
for M in MODELS:
if M.model_type == model_type:
return super(AstroPhot_Model, cls).__new__(M)
else:
raise UnrecognizedModel(f"Unknown AstroPhot model type: {model_type}")
return super().__new__(cls)
def __init__(self, *, name=None, target=None, window=None, locked=False, **kwargs):
if not hasattr(self, "_window"):
self._window = None
if not hasattr(self, "_target"):
self._target = None
self.name = name
AP_config.ap_logger.debug("Creating model named: {self.name}")
self.parameters = Parameter_Node(self.name)
self.target = target
self.window = window
self._locked = locked
self.mask = kwargs.get("mask", None)
@property
def name(self):
"""The name for this model as a string. The name should be unique
though this is not enforced here. The name should not contain
the `|` or `:` characters as these are reserved for internal
use. If one tries to set the name of a model as `None` (for
example by not providing a name for the model) then a new
unique name will be generated. The unique name is just the
model type for this model with an extra unique id appended to
the end in the format of `[#]` where `#` is a number that
increases until a unique name is found.
"""
return self._name
@name.setter
def name(self, name):
try:
if name == self.name:
return
except AttributeError:
pass
if name is None:
i = 0
while True:
proposed_name = f"{self.model_type} [{i}]"
if proposed_name in AstroPhot_Model.model_names:
i += 1
else:
name = proposed_name
break
if ":" in name or "|" in name:
raise NameNotAllowed("characters '|' and ':' are reserved for internal model operations please do not include these in a model name")
self._name = name
AstroPhot_Model.model_names.append(name)
[docs]
@torch.no_grad()
@ignore_numpy_warnings
@select_target
@default_internal
def initialize(self, target=None, parameters=None, **kwargs):
"""When this function finishes, all parameters should have numerical
values (non None) that are reasonable estimates of the final
values.
"""
pass
[docs]
def make_model_image(self, window: Optional[Window] = None):
"""This is called to create a blank `Model_Image` object of the
correct format for this model. This is typically used
internally to construct the model image before filling the
pixel values with the model.
"""
if window is None:
window = self.window
else:
window = self.window & window
return self.target[window].model_image()
[docs]
def sample(self, image=None, window=None, parameters=None, *args, **kwargs):
"""Calling this function should fill the given image with values
sampled from the given model.
"""
pass
[docs]
def negative_log_likelihood(
self,
parameters=None,
as_representation=False,
):
"""
Compute the negative log likelihood of the model wrt the target image in the appropriate window.
"""
if parameters is not None:
if as_representation:
self.parameters.vector_set_representation(parameters)
else:
self.parameters.vector_set_values(parameters)
model = self.sample()
data = self.target[self.window]
weight = data.weight
if self.target.has_mask:
if isinstance(data, Target_Image_List):
mask = tuple(torch.logical_not(submask) for submask in data.mask)
chi2 = sum(torch.sum(((mo - da).data ** 2 * wgt)[ma]) / 2.0 for mo, da, wgt, ma in zip(model, data, weight, mask))
else:
mask = torch.logical_not(data.mask)
chi2 = torch.sum(((model - data).data ** 2 * weight)[mask]) / 2.0
else:
if isinstance(data, Target_Image_List):
chi2 = sum(torch.sum(((mo - da).data ** 2 * wgt)) / 2.0 for mo, da, wgt in zip(model, data, weight))
else:
chi2 = torch.sum(((model - data).data ** 2 * weight)) / 2.0
return chi2
[docs]
def jacobian(
self,
parameters=None,
**kwargs,
):
raise NotImplementedError("please use a subclass of AstroPhot_Model")
[docs]
@default_internal
def total_flux(self, parameters=None, window=None, image=None):
F = self(parameters = parameters, window=None, image=None)
return torch.sum(F.data)
@property
def window(self):
"""The window defines a region on the sky in which this model will be
optimized and typically evaluated. Two models with
non-overlapping windows are in effect independent of each
other. If there is another model with a window that spans both
of them, then they are tenuously conected.
If not provided, the model will assume a window equal to the
target it is fitting. Note that in this case the window is not
explicitly set to the target window, so if the model is moved
to another target then the fitting window will also change.
"""
if self._window is None:
if self.target is None:
raise ValueError(
"This model has no target or window, these must be provided by the user"
)
return self.target.window.copy()
return self._window
[docs]
def set_window(self, window):
if window is None:
# If no window given, set to none
self._window = None
elif isinstance(window, Window):
# If window object given, use that
self._window = window
elif len(window) == 2:
# If window given in pixels, use relative to target
self._window = self.target.window.copy().crop_to_pixel(window)
else:
raise InvalidWindow(f"Unrecognized window format: {str(window)}")
@window.setter
def window(self, window):
self.set_window(window)
@property
def target(self):
return self._target
@target.setter
def target(self, tar):
if not (tar is None or isinstance(tar, Target_Image)):
raise InvalidTarget("AstroPhot_Model target must be a Target_Image instance.")
self._target = tar
@property
def locked(self):
"""Set when the model should remain fixed going forward. This model
will be bypassed when fitting parameters, however it will
still be sampled for generating the model image.
Warning:
This feature is not yet fully functional and should be avoided for now. It is included here for the sake of testing.
"""
return self._locked
@locked.setter
def locked(self, val):
self._locked = val
@property
def parameter_order(self):
"""Returns the model parameters in the order they are kept for
flattening, such as when evaluating the model with a tensor of
parameter values.
"""
return tuple(P.name for P in self.parameters)
def __str__(self):
"""String representation for the model."""
return self.parameters.__str__()
def __repr__(self):
"""Detailed string representation for the model."""
return yaml.dump(self.get_state(), indent=2)
[docs]
def get_state(self, *args, **kwargs):
"""Returns a dictionary of the state of the model with its name,
type, parameters, and other important infomration. This
dictionary is what gets saved when a model saves to disk.
"""
state = {
"name": self.name,
"model_type": self.model_type,
}
return state
[docs]
def save(self, filename="AstroPhot.yaml"):
"""Saves a model object to disk. By default the file type should be
yaml, this is the only file type which gets tested, though
other file types such as json and hdf5 should work.
"""
if filename.endswith(".yaml"):
state = self.get_state()
with open(filename, "w") as f:
yaml.dump(state, f, indent=2)
elif filename.endswith(".json"):
import json
state = self.get_state()
with open(filename, "w") as f:
json.dump(state, f, indent=2)
elif filename.endswith(".hdf5"):
import h5py
state = self.get_state()
with h5py.File(filename, "w") as F:
dict_to_hdf5(F, state)
else:
if isinstance(filename, str) and "." in filename:
raise ValueError(
f"Unrecognized filename format: {filename[filename.find('.'):]}, must be one of: .json, .yaml, .hdf5"
)
else:
raise ValueError(
f"Unrecognized filename format: {str(filename)}, must be one of: .json, .yaml, .hdf5"
)
[docs]
@classmethod
def load(cls, filename="AstroPhot.yaml"):
"""
Loads a saved model object.
"""
if isinstance(filename, dict):
state = filename
elif isinstance(filename, io.TextIOBase):
state = yaml.load(filename, Loader=yaml.FullLoader)
elif filename.endswith(".yaml"):
with open(filename, "r") as f:
state = yaml.load(f, Loader=yaml.FullLoader)
elif filename.endswith(".json"):
import json
with open(filename, "r") as f:
state = json.load(f)
elif filename.endswith(".hdf5"):
import h5py
with h5py.File(filename, "r") as F:
state = hdf5_to_dict(F)
else:
if isinstance(filename, str) and "." in filename:
raise ValueError(
f"Unrecognized filename format: {filename[filename.find('.'):]}, must be one of: .json, .yaml, .hdf5"
)
else:
raise ValueError(
f"Unrecognized filename format: {str(filename)}, must be one of: .json, .yaml, .hdf5 or python dictionary."
)
return state
[docs]
@classmethod
def List_Models(cls, useable=None):
MODELS = all_subclasses(cls)
if useable is not None:
for model in list(MODELS):
if model.useable is not useable:
MODELS.remove(model)
return MODELS
[docs]
@classmethod
def List_Model_Names(cls, useable=None):
MODELS = cls.List_Models(useable=useable)
names = []
for model in MODELS:
names.append(model.model_type)
return list(sorted(names, key=lambda n: n[::-1]))
def __eq__(self, other):
return self is other
def __getitem__(self, key):
return self.parameters[key]
def __contains__(self, key):
return self.parameters.__contains__(key)
def __del__(self):
try:
i = AstroPhot_Model.model_names.index(self.name)
AstroPhot_Model.model_names.pop(i)
except:
pass
@select_sample
def __call__(
self,
image=None,
parameters=None,
window=None,
as_representation=False,
**kwargs,
):
if parameters is None:
parameters = self.parameters
elif isinstance(parameters, torch.Tensor):
if as_representation:
self.parameters.vector_set_representation(parameters)
else:
self.parameters.vector_set_values(parameters)
parameters = self.parameters
return self.sample(image=image, window=window, parameters=parameters, **kwargs)