Source code for astrophot.image.jacobian_image

import ast
import warnings
from typing import Optional, Union, List

import torch
from torch.nn.functional import pad

from .image_object import Image, Image_List
from .. import AP_config
from ..errors import SpecificationConflict, InvalidImage

# Public names exported by this module (used by `from ... import *`).
__all__ = ["Jacobian_Image", "Jacobian_Image_List"]


######################################################################
class Jacobian_Image(Image):
    """Jacobian of a model evaluated in an image.

    Image object which represents the evaluation of a jacobian on an
    image. It takes the form of a 3D (Image x Nparameters)
    tensor. This object can be added to other Jacobian images to
    build up a full jacobian for a complex model.

    Args:
        parameters: Identities of the parameters tracked by this jacobian,
            one per slice along the last axis of ``data``. Must be unique.
        target_identity: Identity of the target image this jacobian was
            evaluated on; used to match jacobians that belong together.
        **kwargs: Forwarded unchanged to ``Image.__init__``.

    Raises:
        SpecificationConflict: If ``parameters`` contains duplicates.
    """

    def __init__(
        self,
        parameters: List[str],
        target_identity: str,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.target_identity = target_identity
        # Store a private copy so later in-place appends (see __iadd__) do not
        # modify the caller's list.
        self.parameters = list(parameters)
        if len(self.parameters) != len(set(self.parameters)):
            raise SpecificationConflict("Every parameter should be unique upon jacobian creation")

    def flatten(self, attribute: str = "data"):
        """Return ``attribute`` reshaped to 2D with one column per parameter."""
        return getattr(self, attribute).reshape((-1, len(self.parameters)))

    def copy(self, **kwargs):
        """Copy this image, carrying over the parameter list and target identity."""
        return super().copy(
            parameters=self.parameters, target_identity=self.target_identity, **kwargs
        )

    def get_state(self):
        """Return the parent state dict extended with jacobian-specific fields."""
        state = super().get_state()
        state["target_identity"] = self.target_identity
        state["parameters"] = self.parameters
        return state

    def set_state(self, state):
        """Restore this object from a dict produced by ``get_state``."""
        super().set_state(state)
        self.target_identity = state["target_identity"]
        self.parameters = state["parameters"]

    def get_fits_state(self):
        """Return the parent FITS states with jacobian fields added to the PRIMARY header.

        The parameter list is stored as its ``str()`` representation under
        ``PARAMS`` and the target identity under ``TRGTID``.
        """
        states = super().get_fits_state()
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                state["HEADER"]["TRGTID"] = self.target_identity
                state["HEADER"]["PARAMS"] = str(self.parameters)
        return states

    def set_fits_state(self, states):
        """Restore this object from FITS states as written by ``get_fits_state``."""
        super().set_fits_state(states)
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                self.target_identity = state["HEADER"]["TRGTID"]
                # Same key spelling as get_fits_state so this also works when
                # the header is a plain (case-sensitive) dict.
                # NOTE(security): the header text comes from a file, so it is
                # parsed as a Python literal instead of being passed to eval().
                self.parameters = ast.literal_eval(state["HEADER"]["PARAMS"])

    def __add__(self, other):
        raise NotImplementedError("Jacobian images cannot add like this, use +=")

    def __sub__(self, other):
        raise NotImplementedError("Jacobian images cannot subtract")

    def __isub__(self, other):
        raise NotImplementedError("Jacobian images cannot subtract")

    def __iadd__(self, other):
        """Accumulate another jacobian image into this one.

        Parameters already tracked by ``self`` are summed into their existing
        slice; parameters only present in ``other`` get a new zero-filled
        slice appended along the last axis before being summed in.

        Returns:
            ``self`` in the normal case. If ``self.data`` is None, ``other``
            is returned instead, so callers must keep the result of ``+=``.

        Raises:
            InvalidImage: If ``other`` is not a ``Jacobian_Image``.
        """
        if not isinstance(other, Jacobian_Image):
            raise InvalidImage(
                f"Jacobian images can only add with each other, not: {type(other)}"
            )

        # exclude null jacobian images
        if other.data is None:
            return self
        if self.data is None:
            return other

        # Index objects for the first two axes of each data tensor.
        # NOTE(review): presumably these select the region where the two
        # windows overlap -- confirm against Window.get_other_indices.
        self_indices = other.window.get_other_indices(self)
        other_indices = self.window.get_other_indices(other)

        for i, other_identity in enumerate(other.parameters):
            if other_identity in self.parameters:
                other_loc = self.parameters.index(other_identity)
            else:
                # Unknown parameter: grow the last axis by one zero slice.
                self.set_data(
                    torch.cat(
                        (
                            self.data,
                            torch.zeros(
                                self.data.shape[0],
                                self.data.shape[1],
                                1,
                                dtype=AP_config.ap_dtype,
                                device=AP_config.ap_device,
                            ),
                        ),
                        dim=2,
                    ),
                    require_shape=False,
                )
                self.parameters.append(other_identity)
                # The freshly appended slice is the last one.
                other_loc = -1
            self.data[self_indices[0], self_indices[1], other_loc] += other.data[
                other_indices[0], other_indices[1], i
            ]
        return self
######################################################################
class Jacobian_Image_List(Image_List, Jacobian_Image):
    """For joint modelling, represents Jacobians evaluated on a list of
    images.

    Stores jacobians evaluated on a number of image objects. Since
    jacobian images are aware of the target images they were evaluated
    on, it is possible to combine this object with other
    Jacobian_Image_List objects or even Jacobian_Image objects and
    everything will be sorted into the proper locations of the list,
    and image.

    """

    def __init__(self, image_list):
        super().__init__(image_list)

    def flatten(self, attribute="data"):
        """Concatenate the flattened ``attribute`` of every sub-image.

        Raises:
            SpecificationConflict: If the sub-images do not all track the
                same parameters in the same order.
        """
        if len(self.image_list) > 1:
            for image in self.image_list[1:]:
                if self.image_list[0].parameters != image.parameters:
                    raise SpecificationConflict("Jacobian image list sub-images track different parameters. Please initialize with all parameters that will be used.")
        return torch.cat(tuple(image.flatten(attribute) for image in self.image_list))

    def __add__(self, other):
        raise NotImplementedError("Jacobian images cannot add like this, use +=")

    def __sub__(self, other):
        raise NotImplementedError("Jacobian images cannot subtract")

    def __isub__(self, other):
        raise NotImplementedError("Jacobian images cannot subtract")

    def _accumulate_image(self, other_image):
        """Add one jacobian image into the sub-image with the same target, or append it."""
        for index, self_image in enumerate(self.image_list):
            if other_image.target_identity == self_image.target_identity:
                self_image += other_image
                # `+=` may return a different object (Jacobian_Image.__iadd__
                # returns `other` when self has no data), so store the result
                # back instead of only rebinding the loop variable.
                self.image_list[index] = self_image
                return
        # No sub-image shares this target: track it as a new entry.
        self.image_list.append(other_image)

    def __iadd__(self, other):
        """Accumulate jacobians from ``other`` into this list.

        ``other`` may be a ``Jacobian_Image_List`` or a single
        ``Jacobian_Image`` (both matched to sub-images by
        ``target_identity``, appended when there is no match), or any other
        iterable of images, which is added pairwise by position.
        """
        # The list check must come first: Jacobian_Image_List is itself a
        # subclass of Jacobian_Image.
        if isinstance(other, Jacobian_Image_List):
            for other_image in other.image_list:
                self._accumulate_image(other_image)
        elif isinstance(other, Jacobian_Image):
            self._accumulate_image(other)
        else:
            for index, (self_image, other_image) in enumerate(zip(self.image_list, other)):
                self_image += other_image
                self.image_list[index] = self_image
        return self