# Source code for astrophot.models.psf_model_object

from typing import Optional

import torch

from .core_model import AstroPhot_Model
from ..image import (
    Image,
    Model_Image,
    Window,
    PSF_Image,
    Image_List,
)
from ._shared_methods import select_target
from ..utils.decorators import default_internal, ignore_numpy_warnings
from ..param import Param_Unlock, Param_SoftLimits, Parameter_Node
from ..errors import SpecificationConflict


__all__ = ["PSF_Model"]


class PSF_Model(AstroPhot_Model):
    """Prototype point source (typically a star) model, to be subclassed
    by other point source models which define specific behavior.

    PSF_Models behave differently than component models. For starters,
    their target image must be a PSF_Image object instead of a
    Target_Image object. The "center" parameter of a PSF_Model is locked
    at (0,0), just like a PSF_Image, so it is never fit. A PSF_Model
    will never be convolved with a PSF_Model (that's its job!), so a lot
    of the sampling method is simpler.

    """

    # Specifications for the model parameters including units, value, uncertainty, limits, locked, and cyclic
    parameter_specs = {
        "center": {
            "units": "arcsec",
            "value": (0.0, 0.0),
            "uncertainty": (0.1, 0.1),
            "locked": True,
        },
    }
    # Fixed order of parameters for all methods that interact with the list of parameters
    _parameter_order = ("center",)
    model_type = f"psf {AstroPhot_Model.model_type}"
    useable = False
    # Must be set to True or False by subclasses; `sample` raises otherwise
    model_integrated = None

    # The sampled PSF will be normalized to a total flux of 1 within the window
    normalize_psf = True

    # Method for initial sampling of model
    sampling_mode = "midpoint"  # midpoint, trapezoid, simpson

    # Level to which each pixel should be evaluated
    sampling_tolerance = 1e-2

    # Integration scope for model
    integrate_mode = "threshold"  # none, threshold, full*

    # Maximum recursion depth when performing sub pixel integration
    integrate_max_depth = 3

    # Amount by which to subdivide pixels when doing recursive pixel integration
    integrate_gridding = 5

    # The initial quadrature level for sub pixel integration. Please always choose an odd number 3 or higher
    integrate_quad_level = 3

    # Maximum size of parameter list before jacobian will be broken into smaller chunks, this is helpful for limiting the memory requirements to build a model, lower jacobian_chunksize is slower but uses less memory
    jacobian_chunksize = 10
    image_chunksize = 1000

    # Softening length used for numerical stability and/or integration stability to avoid discontinuities (near R=0)
    softening = 1e-3

    # Parameters which are treated specially by the model object and should not be updated directly when initializing
    special_kwargs = ["parameters", "filename", "model_type"]
    # Attributes recorded by `get_state` when they differ from the class default
    track_attrs = [
        "sampling_mode",
        "sampling_tolerance",
        "integrate_mode",
        "integrate_max_depth",
        "integrate_gridding",
        "integrate_quad_level",
        "jacobian_chunksize",
        "softening",
    ]

    def __init__(self, *, name=None, **kwargs):
        """Construct the model, either from keyword arguments or from a file.

        Args:
          name: Name handed to the parent constructor (and to `load`
            as `new_name` when loading from a file).
          **kwargs: Every key not listed in `special_kwargs` is set as
            an attribute on the model. `filename` triggers loading the
            model configuration from disk; `parameters` is passed to
            `build_parameter_specs` and, if it is a `torch.Tensor`,
            assigned as the parameter values.

        """
        self._target_identity = None
        super().__init__(name=name, **kwargs)

        # Set any user defined attributes for the model
        for kwarg in kwargs:  # fixme move to core model?
            # Skip parameters with special behaviour
            if kwarg in self.special_kwargs:
                continue
            # Set the model parameter
            setattr(self, kwarg, kwargs[kwarg])

        # If loading from a file, get model configuration then exit __init__
        if "filename" in kwargs:
            self.load(kwargs["filename"], new_name=name)
            return

        self.parameter_specs = self.build_parameter_specs(
            kwargs.get("parameters", None)
        )
        with torch.no_grad():
            self.build_parameters()
            if isinstance(kwargs.get("parameters", None), torch.Tensor):
                self.parameters.value = kwargs["parameters"]
        assert torch.allclose(
            self.window.center, torch.zeros_like(self.window.center)
        ), "PSF models must always be centered at (0,0)"

    # Initialization functions
    ######################################################################
    @torch.no_grad()
    @ignore_numpy_warnings
    @select_target
    @default_internal
    def initialize(
        self,
        target: Optional["PSF_Image"] = None,
        parameters: Optional[Parameter_Node] = None,
        **kwargs,
    ):
        """Initialize the model parameters.

        This base implementation only defers to the parent class
        `initialize`; the "center" parameter is locked in
        `parameter_specs`, so no center search is performed here.
        Subclasses extend this to set their own parameters.

        Args:
          target (Optional[PSF_Image]): A target image object to use as a reference when setting parameter values
          parameters (Optional[Parameter_Node]): The parameter node forwarded to the parent `initialize`

        """
        super().initialize(target=target, parameters=parameters)

    # Fit loop functions
    ######################################################################
    def evaluate_model(
        self,
        X: Optional[torch.Tensor] = None,
        Y: Optional[torch.Tensor] = None,
        image: Optional[Image] = None,
        parameters: Optional[Parameter_Node] = None,
        **kwargs,
    ):
        """Evaluate the model on every pixel in the given image. The
        basemodel object simply returns zeros, this function should be
        overloaded by subclasses.

        Args:
          X (Optional[torch.Tensor]): Coordinates relative to the model center; computed from `image` and `parameters` when either `X` or `Y` is missing
          Y (Optional[torch.Tensor]): See `X`
          image (Image): The image defining the set of pixels on which to evaluate the model
          parameters (Optional[Parameter_Node]): Parameters from which the "center" value is read when coordinates must be computed

        Returns:
          torch.Tensor: Zeros with the same shape as `X`.

        """
        if X is None or Y is None:
            Coords = image.get_coordinate_meshgrid()
            X, Y = Coords - parameters["center"].value[..., None, None]
        return torch.zeros_like(X)  # do nothing in base model

    def make_model_image(self, window: Optional[Window] = None):
        """This is called to create a blank `Model_Image` object of the
        correct format for this model. This is typically used
        internally to construct the model image before filling the
        pixel values with the model.

        Args:
          window (Optional[Window]): If given, it is intersected with the model window; otherwise the model window is used.

        """
        if window is None:
            window = self.window
        else:
            window = self.window & window
        return self.target[window].blank_copy()

    def sample(
        self,
        image: Optional[Image] = None,
        window: Optional[Window] = None,
        parameters: Optional[Parameter_Node] = None,
    ):
        """Evaluate the model on the space covered by an image object. This
        function properly calls integration methods. This should not
        be overloaded except in special cases.

        This function is designed to compute the model on a given
        image or within a specified window. It takes care of sub-pixel
        sampling, recursive integration for high curvature regions,
        and proper alignment of the computed model with the original
        pixel grid. The final model is then added to the requested
        image.

        Args:
          image (Optional[Image]): An AstroPhot Image object (likely a Model_Image)
                                     on which to evaluate the model values. If not
                                     provided, a new Model_Image object will be created.
          window (Optional[Window]): A window within which to evaluate the model.
                                   Should only be used if a subset of the full image
                                   is needed. If not provided, the entire image will
                                   be used.
          parameters (Optional[Parameter_Node]): Parameters with which to evaluate
                                   the model. Defaults to the model's own parameters.

        Returns:
          Image: The image with the computed model values.

        Raises:
          SpecificationConflict: If `model_integrated` is neither True nor False.

        """
        # Image on which to evaluate model
        if image is None:
            image = self.make_model_image(window=window)

        # Window within which to evaluate model
        if window is None:
            working_window = image.window.copy()
        else:
            working_window = window.copy()

        # Parameters with which to evaluate the model
        if parameters is None:
            parameters = self.parameters

        # Create an image to store pixel samples
        working_image = Model_Image(window=working_window)
        if self.model_integrated is True:
            # Evaluate the model on the working image. The coordinate grid
            # must come from the same image that stores the result, so that
            # the data stays consistent with `working_window` even when it
            # differs from `image.window`.
            Coords = working_image.get_coordinate_meshgrid()
            X, Y = Coords - parameters["center"].value[..., None, None]
            working_image.data = self.evaluate_model(
                X=X, Y=Y, image=working_image, parameters=parameters
            )
        elif self.model_integrated is False:
            # Evaluate the model on the image
            reference, deep = self._sample_init(
                image=working_image,
                parameters=parameters,
                center=parameters["center"].value,
            )
            # Super-resolve and integrate where needed
            deep = self._sample_integrate(
                deep,
                reference,
                working_image,
                parameters,
                center=torch.zeros_like(working_image.center),
            )
            # Add the sampled/integrated pixels to the requested image
            working_image.data += deep
        else:
            raise SpecificationConflict(
                "PSF model 'model_integrated' should be either True or False"
            )

        # normalize to total flux 1
        if self.normalize_psf:
            working_image.data /= torch.sum(working_image.data)

        if self.mask is not None:
            working_image.data = working_image.data * torch.logical_not(self.mask)

        image += working_image

        return image

    @property
    def target(self):
        """The target image for this model, or None if none has been set."""
        try:
            return self._target
        except AttributeError:
            return None

    @target.setter
    def target(self, tar):
        # If a target image list is assigned, pick out the target appropriate for this model.
        # This is resolved before validation so that the branch is reachable.
        if isinstance(tar, Image_List) and self._target_identity is not None:
            for subtar in tar:
                if subtar.identity == self._target_identity:
                    usetar = subtar
                    break
            else:
                raise KeyError(
                    f"Could not find target in Target_Image_List with matching identity to {self.name}: {self._target_identity}"
                )
        else:
            usetar = tar

        # The resolved target must be a PSF_Image (or None to clear it)
        assert usetar is None or isinstance(
            usetar, PSF_Image
        ), "PSF_Model target must be a PSF_Image"

        self._target = usetar

        # Remember the target identity to use
        try:
            self._target_identity = self._target.identity
        except AttributeError:
            pass

    def get_state(self, save_params=True):
        """Returns a dictionary with a record of the current state of the
        model.

        Specifically, the current parameter settings and the window for
        this model. From this information it is possible for the model to
        re-build itself later when loading from disk. Note that the target
        image is not saved, this must be reset when loading the model.

        Args:
          save_params (bool): When True, the parameter state is included under the "parameters" key.

        Returns:
          dict: The state from the parent class extended with "window",
          optionally "parameters", "target_identity", and any attribute in
          `track_attrs` whose value differs from the class default.

        """
        state = super().get_state()
        state["window"] = self.window.get_state()
        if save_params:
            state["parameters"] = self.parameters.get_state()
        state["target_identity"] = self._target_identity
        for key in self.track_attrs:
            if getattr(self, key) != getattr(self.__class__, key):
                state[key] = getattr(self, key)
        return state

    # Extra background methods for the basemodel
    ######################################################################
    from ._model_methods import radius_metric
    from ._model_methods import angular_metric
    from ._model_methods import _sample_init
    from ._model_methods import _sample_integrate
    from ._model_methods import _integrate_reference
    from ._model_methods import build_parameter_specs
    from ._model_methods import build_parameters
    from ._model_methods import jacobian
    from ._model_methods import _chunk_jacobian
    from ._model_methods import _chunk_image_jacobian
    from ._model_methods import load