# Source code for astrophot.models.model_object

from functools import partial
from typing import Optional, Union
import io

import numpy as np
import torch

from .core_model import AstroPhot_Model
from ..image import (
    Image,
    Model_Image,
    Window,
    PSF_Image,
    Target_Image,
    Target_Image_List,
    Image,
)
from ..param import Parameter_Node, Param_Unlock, Param_SoftLimits
from ..utils.initialize import center_of_mass
from ..utils.decorators import ignore_numpy_warnings, default_internal
from ._shared_methods import select_target
from .. import AP_config
from ..errors import InvalidTarget

__all__ = ["Component_Model"]


class Component_Model(AstroPhot_Model):
    """Component_Model(name, target, window, locked, **kwargs)

    Component_Model is a base class for models that represent single
    objects or parametric forms. It provides the basis for subclassing
    models and requires the definition of parameters, initialization,
    and model evaluation functions. This class also handles
    integration, PSF convolution, and computing the Jacobian matrix.

    Attributes:
      parameter_specs (dict): Specifications for the model parameters.
      _parameter_order (tuple): Fixed order of parameters.
      psf_mode (str): Scope for PSF convolution. One of none, full. Default: none
      psf_convolve_mode (str): Technique for PSF convolution. One of fft, direct. Default: fft
      psf_subpixel_shift (str): Method used for sub pixel shifts. One of bilinear, lanczos:2, lanczos:3, lanczos:5, none. Default: bilinear
      sampling_mode (str): Method for initial sampling of model. Can be one of midpoint, trapezoid, simpsons, quad:x (where x is a positive integer). Default: midpoint
      sampling_tolerance (float): accuracy to which each pixel should be evaluated. Default: 1e-2
      integrate_mode (str): Integration scope for the model. One of none, threshold, full where threshold will select which pixels to integrate while full (in development) will integrate all pixels. Default: threshold
      integrate_max_depth (int): Maximum recursion depth when performing sub pixel integration. Default: 3
      integrate_gridding (int): Amount by which to subdivide pixels when doing recursive pixel integration. Default: 5
      integrate_quad_level (int): The initial quadrature level for sub pixel integration. Please always choose an odd number 3 or higher. Default: 3
      softening (float): Softening length used for numerical stability and integration stability to avoid discontinuities (near R=0). Effectively has units of arcsec. Default: 1e-3
      jacobian_chunksize (int): Maximum size of parameter list before jacobian will be broken into smaller chunks. Default: 10
      image_chunksize (int): Default: 1000
      special_kwargs (list): Parameters which are treated specially by the model object and should not be updated directly.
      track_attrs (list): Names of the attributes which are written to the saved state when they differ from the class default.
      useable (bool): Indicates if the model is useable.

    Methods:
      initialize: Determine initial values for the center coordinates.
      sample: Evaluate the model on the space covered by an image object.
      jacobian: Compute the Jacobian matrix for this model.

    """

    # Specifications for the model parameters including units, value, uncertainty, limits, locked, and cyclic
    parameter_specs = {
        "center": {"units": "arcsec", "uncertainty": [0.1, 0.1]},
    }
    # Fixed order of parameters for all methods that interact with the list of parameters
    _parameter_order = ("center",)

    # Scope for PSF convolution
    psf_mode = "none"  # none, full
    # Technique for PSF convolution
    psf_convolve_mode = "fft"  # fft, direct
    # Method to use when performing subpixel shifts. bilinear set by default for stability around pixel edges, though lanczos:3 is also fairly stable, and all are stable when away from pixel edges
    psf_subpixel_shift = "bilinear"  # bilinear, lanczos:2, lanczos:3, lanczos:5, none

    # Method for initial sampling of model
    sampling_mode = (
        "midpoint"  # midpoint, trapezoid, simpsons, quad:x (where x is a positive integer)
    )
    # Level to which each pixel should be evaluated
    sampling_tolerance = 1e-2

    # Integration scope for model
    integrate_mode = "threshold"  # none, threshold
    # Maximum recursion depth when performing sub pixel integration
    integrate_max_depth = 3
    # Amount by which to subdivide pixels when doing recursive pixel integration
    integrate_gridding = 5
    # The initial quadrature level for sub pixel integration. Please always choose an odd number 3 or higher
    integrate_quad_level = 3

    # Maximum size of parameter list before jacobian will be broken into smaller chunks, this is helpful for limiting the memory requirements to build a model, lower jacobian_chunksize is slower but uses less memory
    jacobian_chunksize = 10
    image_chunksize = 1000

    # Softening length used for numerical stability and/or integration stability to avoid discontinuities (near R=0)
    softening = 1e-3

    # Parameters which are treated specially by the model object and should not be updated directly when initializing
    special_kwargs = ["parameters", "filename", "model_type"]
    # Attributes recorded by get_state whenever they differ from the class level default
    track_attrs = [
        "psf_mode",
        "psf_convolve_mode",
        "psf_subpixel_shift",
        "sampling_mode",
        "sampling_tolerance",
        "integrate_mode",
        "integrate_max_depth",
        "integrate_gridding",
        "integrate_quad_level",
        "jacobian_chunksize",
        "image_chunksize",
        "softening",
    ]
    useable = False

    def __init__(self, *, name=None, **kwargs):
        """Build the model, either from keyword arguments or from a saved file.

        Args:
          name: name of the model, forwarded to the parent constructor and, when loading from file, to ``load`` as ``new_name``.
          **kwargs: forwarded to the parent constructor. Every keyword which is not listed in ``special_kwargs`` is also set as an attribute on the model. If ``filename`` is given the model is loaded from that file and no parameters are built here. If ``parameters`` is given it is passed to ``build_parameter_specs``, and when it is a ``torch.Tensor`` it is also used as the initial parameter values.
        """
        # The target setter reads this attribute, so it must exist before the
        # parent constructor gets a chance to assign a target
        self._target_identity = None
        super().__init__(name=name, **kwargs)

        self.psf = None
        self.psf_aux_image = None

        # Set any user defined attributes for the model
        for kwarg in kwargs:  # fixme move to core model?
            # Skip parameters with special behaviour
            if kwarg in self.special_kwargs:
                continue
            # Set the model parameter
            setattr(self, kwarg, kwargs[kwarg])

        # If loading from a file, get model configuration then exit __init__
        if "filename" in kwargs:
            self.load(kwargs["filename"], new_name=name)
            return

        self.parameter_specs = self.build_parameter_specs(kwargs.get("parameters", None))
        with torch.no_grad():
            self.build_parameters()
            if isinstance(kwargs.get("parameters", None), torch.Tensor):
                self.parameters.value = kwargs["parameters"]

    def set_aux_psf(self, aux_psf, add_parameters=True):
        """Set the PSF for this model as an auxiliary psf model. This psf
        model will be resampled as part of the model sampling step to
        track changes made during fitting.

        Args:
          aux_psf: The auxiliary psf model
          add_parameters: if true, the parameters of the auxiliary psf model will become model parameters for this model as well.

        """
        self._psf = aux_psf

        if add_parameters:
            self.parameters.link(aux_psf.parameters)

    @property
    def psf(self):
        """The PSF used by this model.

        Returns the PSF set on the model itself. If none has been set, the
        PSF of the target is returned instead, or None when the target
        has no ``psf`` attribute (for example when no target is set).
        """
        if self._psf is None:
            try:
                return self.target.psf
            except AttributeError:
                return None
        return self._psf

    @psf.setter
    def psf(self, val):
        """Set the PSF from None, a PSF_Image, an AstroPhot_Model, or raw pixel data.

        Raw pixel data is wrapped in a PSF_Image using the pixelscale of
        the target, and a warning is logged since that pixelscale is an
        assumption.
        """
        if val is None:
            self._psf = None
        elif isinstance(val, PSF_Image):
            self._psf = val
        elif isinstance(val, AstroPhot_Model):
            self.set_aux_psf(val)
        else:
            self._psf = PSF_Image(data=val, pixelscale=self.target.pixelscale)
            AP_config.ap_logger.warning(
                "Setting PSF with pixel matrix, assuming target pixelscale is the same as "
                "PSF pixelscale. To remove this warning, set PSFs as an ap.image.PSF_Image "
                "or ap.models.AstroPhot_Model object instead."
            )

    # Initialization functions
    ######################################################################
    @torch.no_grad()
    @ignore_numpy_warnings
    @select_target
    @default_internal
    def initialize(
        self,
        target: Optional["Target_Image"] = None,
        parameters: Optional[Parameter_Node] = None,
        **kwargs,
    ):
        """Determine initial values for the center coordinates. This is done
        with a local center of mass search which iterates by finding
        the center of light in a window, then iteratively updates
        until the iterations move by less than a pixel.

        If the center already has a value, or is locked, it is left
        untouched. If the center of mass search fails (the result is
        not finite or lands outside the window data) a warning is
        logged and the center of the window is kept.

        Args:
          target (Optional[Target_Image]): A target image object to use as a reference when setting parameter values
          parameters (Optional[Parameter_Node]): The parameter node holding the "center" parameter to initialize

        """
        super().initialize(target=target, parameters=parameters)
        # Get the sub-image area corresponding to the model image
        target_area = target[self.window]

        # Use center of window if a center hasn't been set yet
        if parameters["center"].value is None:
            with (
                Param_Unlock(parameters["center"]),
                Param_SoftLimits(parameters["center"]),
            ):
                parameters["center"].value = self.window.center
        else:
            return

        if parameters["center"].locked:
            return

        # Convert center coordinates to target area array indices
        init_icenter = target_area.plane_to_pixel(parameters["center"].value)

        # Compute center of mass in window
        COM = center_of_mass(
            (
                init_icenter[1].detach().cpu().item(),
                init_icenter[0].detach().cpu().item(),
            ),
            target_area.data.detach().cpu().numpy(),
        )
        com_index = np.asarray(COM, dtype=float)
        # Comparisons against NaN are always False, so a non-finite result
        # would slip through the bounds checks without the isfinite test
        if (
            not np.all(np.isfinite(com_index))
            or np.any(com_index < 0)
            or np.any(com_index >= np.array(target_area.data.shape))
        ):
            AP_config.ap_logger.warning("center of mass failed, using center of window")
            return
        # Swap the index order back to the one used by pixel_to_plane
        COM = (COM[1], COM[0])
        # Convert center of mass indices to coordinates
        COM_center = target_area.pixel_to_plane(
            torch.tensor(COM, dtype=AP_config.ap_dtype, device=AP_config.ap_device)
        )

        # Set the new coordinates as the model center
        parameters["center"].value = COM_center

    # Fit loop functions
    ######################################################################
    def evaluate_model(
        self,
        X: Optional[torch.Tensor] = None,
        Y: Optional[torch.Tensor] = None,
        image: Optional[Image] = None,
        parameters: Optional[Parameter_Node] = None,
        **kwargs,
    ):
        """Evaluate the model on every pixel in the given image. The
        basemodel object simply returns zeros, this function should be
        overloaded by subclasses.

        Args:
          X (Optional[torch.Tensor]): coordinates relative to the model center; computed from ``image`` and ``parameters`` when either X or Y is missing
          Y (Optional[torch.Tensor]): coordinates relative to the model center; computed from ``image`` and ``parameters`` when either X or Y is missing
          image (Image): The image defining the set of pixels on which to evaluate the model
          parameters (Optional[Parameter_Node]): parameter node holding the "center" parameter, only read when X or Y is missing

        Returns:
          torch.Tensor: zeros with the same shape as X

        """
        if X is None or Y is None:
            Coords = image.get_coordinate_meshgrid()
            X, Y = Coords - parameters["center"].value[..., None, None]
        return torch.zeros_like(X)  # do nothing in base model

    def sample(
        self,
        image: Optional[Image] = None,
        window: Optional[Window] = None,
        parameters: Optional[Parameter_Node] = None,
    ):
        """Evaluate the model on the space covered by an image object. This
        function properly calls integration methods and PSF
        convolution. This should not be overloaded except in special
        cases.

        This function is designed to compute the model on a given
        image or within a specified window. It takes care of sub-pixel
        sampling, recursive integration for high curvature regions,
        PSF convolution, and proper alignment of the computed model
        with the original pixel grid. The final model is then added to
        the requested image.

        Args:
          image (Optional[Image]): An AstroPhot Image object (likely a Model_Image)
                                     on which to evaluate the model values. If not
                                     provided, a new Model_Image object will be created.
          window (Optional[Window]): A window within which to evaluate the model.
                                   Should only be used if a subset of the full image
                                   is needed. If not provided, the entire image will
                                   be used.
          parameters (Optional[Parameter_Node]): The parameters with which to evaluate
                                   the model. If not provided, the parameters of the
                                   model itself are used.

        Returns:
          Image: The image with the computed model values.

        Raises:
          NotImplementedError: if ``psf_mode`` contains "window".

        """
        # Image on which to evaluate model
        if image is None:
            image = self.make_model_image(window=window)

        # Window within which to evaluate model
        if window is None:
            working_window = image.window.copy()
        else:
            working_window = window.copy()

        # Parameters with which to evaluate the model
        if parameters is None:
            parameters = self.parameters

        if "window" in self.psf_mode:
            raise NotImplementedError("PSF convolution in sub-window not available yet")

        if "full" in self.psf_mode:
            if isinstance(self.psf, AstroPhot_Model):
                psf = self.psf(
                    parameters=parameters[self.psf.name],
                )
            else:
                psf = self.psf
            psf_upscale = torch.round(image.pixel_length / psf.pixel_length).int()
            # Add border for psf convolution edge effects, will be cropped out later
            working_window.pad_pixel(psf.psf_border_int)
            # Make the image object to which the samples will be tracked
            working_image = Model_Image(window=working_window)
            # Sub pixel shift to align the model with the center of a pixel
            if self.psf_subpixel_shift != "none":
                pixel_center = working_image.plane_to_pixel(parameters["center"].value)
                center_shift = pixel_center - torch.round(pixel_center)
                working_image.header.pixel_shift(center_shift)
            else:
                center_shift = None

            # Evaluate the model at the current resolution
            reference, deep = self._sample_init(
                image=working_image,
                parameters=parameters,
                center=parameters["center"].value,
            )
            # If needed, super-resolve the image in areas of high curvature so pixels are properly sampled
            deep = self._sample_integrate(
                deep, reference, working_image, parameters, parameters["center"].value
            )

            # update the image with the integrated pixels
            working_image.data += deep

            # Convolve the PSF
            self._sample_convolve(working_image, center_shift, psf, self.psf_subpixel_shift)

            # Shift image back to align with original pixel grid
            if self.psf_subpixel_shift != "none":
                working_image.header.pixel_shift(-center_shift)
            # Add the sampled/integrated/convolved pixels to the requested image
            working_image = working_image.reduce(psf_upscale).crop(psf.psf_border_int)

        else:
            # Create an image to store pixel samples
            working_image = Model_Image(pixelscale=image.pixelscale, window=working_window)
            # Evaluate the model on the image
            reference, deep = self._sample_init(
                image=working_image,
                parameters=parameters,
                center=parameters["center"].value,
            )
            # Super-resolve and integrate where needed
            deep = self._sample_integrate(
                deep,
                reference,
                working_image,
                parameters,
                center=parameters["center"].value,
            )
            # Add the sampled/integrated pixels to the requested image
            working_image.data += deep

        # Masked pixels contribute nothing to the sampled model
        if self.mask is not None:
            working_image.data = working_image.data * torch.logical_not(self.mask)

        image += working_image

        return image

    @property
    def target(self):
        """The target image assigned to this model."""
        return self._target

    @target.setter
    def target(self, tar):
        """Assign the target image for this model.

        When a Target_Image_List is given and this model already
        remembers a target identity, the sub-target with the matching
        identity is selected. The identity of the assigned target is
        then remembered; assigning a target without an ``identity``
        attribute (such as None) keeps the previously remembered one.

        Raises:
          InvalidTarget: if ``tar`` is neither None nor a Target_Image, or if no image in the list matches the remembered identity.
        """
        if not (tar is None or isinstance(tar, Target_Image)):
            raise InvalidTarget("AstroPhot_Model target must be a Target_Image instance.")

        # If a target image list is assigned, pick out the target appropriate for this model
        if isinstance(tar, Target_Image_List) and self._target_identity is not None:
            for subtar in tar:
                if subtar.identity == self._target_identity:
                    usetar = subtar
                    break
            else:
                raise InvalidTarget(
                    f"Could not find target in Target_Image_List with matching identity "
                    f"to {self.name}: {self._target_identity}"
                )
        else:
            usetar = tar

        self._target = usetar

        # Remember the target identity to use
        try:
            self._target_identity = self._target.identity
        except AttributeError:
            pass

    def get_state(self, save_params=True):
        """Returns a dictionary with a record of the current state of the
        model.

        Specifically, the current parameter settings and the window for
        this model. From this information it is possible for the model to
        re-build itself later when loading from disk. Note that the target
        image is not saved, this must be reset when loading the model.

        Args:
          save_params (bool): if true, the state of the parameters is included under the "parameters" key. Default: True

        Returns:
          dict: the state from the parent class extended with the window, the target identity, optionally the parameters and the psf, and every attribute in ``track_attrs`` which differs from the class default.

        """
        state = super().get_state()
        state["window"] = self.window.get_state()
        if save_params:
            state["parameters"] = self.parameters.get_state()
        state["target_identity"] = self._target_identity
        # Only a PSF set directly on the model is saved, not one inherited from the target
        if isinstance(self._psf, (PSF_Image, AstroPhot_Model)):
            state["psf"] = self._psf.get_state()
        for key in self.track_attrs:
            if getattr(self, key) != getattr(self.__class__, key):
                state[key] = getattr(self, key)
        return state

    # Extra background methods for the basemodel
    ######################################################################
    from ._model_methods import radius_metric
    from ._model_methods import angular_metric
    from ._model_methods import _sample_init
    from ._model_methods import _sample_integrate
    from ._model_methods import _sample_convolve
    from ._model_methods import _integrate_reference
    from ._model_methods import _shift_psf
    from ._model_methods import build_parameter_specs
    from ._model_methods import build_parameters
    from ._model_methods import jacobian
    from ._model_methods import _chunk_jacobian
    from ._model_methods import _chunk_image_jacobian
    from ._model_methods import load