# Source code for astrophot.image.model_image

import torch
import numpy as np

from .. import AP_config
from .image_object import Image, Image_List
from .window_object import Window
from ..utils.interpolate import shift_Lanczos_torch
from ..errors import InvalidData, SpecificationConflict, InvalidImage

__all__ = ["Model_Image", "Model_Image_List"]


######################################################################
class Model_Image(Image):
    """Image object which represents the sampling of a model at the given
    coordinates of the image.

    Extra arithmetic operations are available which can update model values
    in the image. The whole model can be shifted by less than a pixel to
    account for sub-pixel accuracy.

    In addition to the :class:`Image` arguments, an optional
    ``target_identity`` keyword is read from ``kwargs``. It is forwarded by
    :meth:`get_window` and :meth:`reduce`, and round-tripped through
    :meth:`get_state`/:meth:`set_state` and the FITS state methods
    (header key ``TRGTID``).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_identity = kwargs.get("target_identity", None)
        # NOTE(review): presumably moves data to AP_config dtype/device via
        # Image.to() defaults — confirm against Image.to.
        self.to()

    def clear_image(self):
        """Reset all pixel values to zero (same shape, dtype and device)."""
        self.data = torch.zeros_like(self.data)

    def shift_origin(self, shift, is_prepadded=True):
        """Shift the model by a sub-pixel amount.

        The window is moved by ``shift``. The pixel data are then resampled
        with ``shift_Lanczos_torch`` by the equivalent pixel displacement.

        Args:
            shift: displacement in plane coordinates. After conversion with
                ``plane_to_pixel_delta``, every component must be at most 1
                in absolute value.
            is_prepadded: forwarded to ``shift_Lanczos_torch`` as
                ``img_prepadded``.

        Raises:
            NotImplementedError: if the pixel shift exceeds 1 along any axis.
                It is raised before the window or data are modified.
        """
        # Validate before mutating anything. Previously the window was
        # shifted first, leaving it inconsistent with the unshifted data
        # whenever this check failed.
        # NOTE(review): assumes plane_to_pixel_delta depends only on the
        # pixel scale and not on the window origin — confirm.
        pix_shift = self.plane_to_pixel_delta(shift)
        if torch.any(torch.abs(pix_shift) > 1):
            raise NotImplementedError(
                "Shifts larger than 1 pixel are currently not handled"
            )
        self.window.shift(shift)
        self.data = shift_Lanczos_torch(
            self.data,
            pix_shift[0],
            pix_shift[1],
            # presumably the Lanczos kernel size: capped at 10 and at the
            # smallest image dimension
            min(min(self.data.shape), 10),
            dtype=AP_config.ap_dtype,
            device=AP_config.ap_device,
            img_prepadded=is_prepadded,
        )

    def get_window(self, window: Window, **kwargs):
        """Delegate to ``Image.get_window``, forwarding ``target_identity``."""
        return super().get_window(
            window, target_identity=self.target_identity, **kwargs
        )

    def reduce(self, scale, **kwargs):
        """Delegate to ``Image.reduce``, forwarding ``target_identity``."""
        return super().reduce(scale, target_identity=self.target_identity, **kwargs)

    def replace(self, other, data=None):
        """Overwrite pixel values of this image.

        Args:
            other: one of three kinds of value.
                - An ``Image``: its pixels in the region overlapping this
                  image are copied in. Nothing happens if there is no overlap
                  or the overlap is empty.
                - A ``Window``: ``data`` is written into that region of this
                  image.
                - Anything else: assigned directly as ``self.data``.
            data: values to write when ``other`` is a ``Window``.
        """
        if isinstance(other, Image):
            if self.window.overlap_frac(other.window) == 0.0:  # fixme control flow
                return
            other_indices = self.window.get_other_indices(other)
            self_indices = other.window.get_other_indices(self)
            # Guard against degenerate (zero-size) overlaps.
            if (
                self.data[self_indices].nelement() == 0
                or other.data[other_indices].nelement() == 0
            ):
                return
            self.data[self_indices] = other.data[other_indices]
        elif isinstance(other, Window):
            self.data[self.window.get_self_indices(other)] = data
        else:
            self.data = other

    def get_state(self):
        """Return the parent state dict extended with ``target_identity``."""
        state = super().get_state()
        state["target_identity"] = self.target_identity
        return state

    def set_state(self, state):
        """Restore from a state dict produced by :meth:`get_state`.

        A missing ``"target_identity"`` entry (e.g. a state written by the
        base ``Image``) restores as ``None``.
        """
        super().set_state(state)
        # Bug fix: this previously assigned the undefined name
        # `target_identity`, raising NameError on every call.
        self.target_identity = state.get("target_identity", None)

    def get_fits_state(self):
        """Return the parent FITS states, tagging the PRIMARY one with ``TRGTID``."""
        states = super().get_fits_state()
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                state["HEADER"]["TRGTID"] = self.target_identity
        return states

    def set_fits_state(self, states):
        """Restore from FITS states, reading ``TRGTID`` from the PRIMARY header."""
        super().set_fits_state(states)
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                self.target_identity = state["HEADER"]["TRGTID"]
######################################################################
class Model_Image_List(Image_List, Model_Image):
    """A list of :class:`Model_Image` objects handled as one model image.

    In-place ``+=`` / ``-=`` pair member images in one of two ways:
    - by ``target_identity``, when both sides carry identities;
    - positionally, otherwise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not all(isinstance(image, Model_Image) for image in self.image_list):
            raise InvalidImage(f"Model_Image_List can only hold Model_Image objects, not {tuple(type(image) for image in self.image_list)}")

    def clear_image(self):
        """Zero the pixel data of every member image."""
        for image in self.image_list:
            image.clear_image()

    def shift_origin(self, shift, is_prepadded=True):
        """Not supported for image lists.

        ``is_prepadded`` is accepted only to mirror the
        ``Model_Image.shift_origin`` signature.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError()

    def replace(self, other, data=None):
        """Call ``replace`` on each member image, pairing positionally.

        Args:
            other: iterable of replacement sources, one per member image.
            data: optional iterable of data values, one per member image.
                It is zipped alongside ``other`` when given.
        """
        if data is None:
            for image, oth in zip(self.image_list, other):
                image.replace(oth)
        else:
            for image, oth, dat in zip(self.image_list, other, data):
                image.replace(oth, dat)

    @property
    def target_identity(self):
        """Tuple of member ``target_identity`` values.

        ``None`` if any member has none.
        """
        targets = tuple(image.target_identity for image in self.image_list)
        if any(tar_id is None for tar_id in targets):
            return None
        return targets

    def _find_image_by_identity(self, target_identity):
        """Return the first member whose ``target_identity`` equals the given one, else None."""
        for self_image in self.image_list:
            if target_identity == self_image.target_identity:
                return self_image
        return None

    @staticmethod
    def _apply_inplace(self_image, other_image, subtract):
        """Apply ``self_image -= other_image`` or ``self_image += other_image``."""
        if subtract:
            self_image -= other_image
        else:
            self_image += other_image

    def _inplace_combine(self, other, subtract):
        """Shared implementation of ``__iadd__`` and ``__isub__``.

        Args:
            other: the value to combine into this list.
                - A ``Model_Image_List``: each of its images is matched to a
                  member by ``target_identity``, or positionally when either
                  identity is unavailable. Unmatched images are appended.
                - A ``Model_Image``: matched by identity (or appended), or
                  applied to every member when identities are unavailable.
                - Any other iterable: zipped positionally with the members.
            subtract: ``True`` for ``-=``, ``False`` for ``+=``.

        Returns:
            ``self``.
        """
        if isinstance(other, Model_Image_List):
            # Invariant for this loop: images are only appended when self
            # already has a full set of identities, and every appended image
            # carries a non-None identity.
            self_identity = self.target_identity
            # Bug fix: this used to iterate zip(other.image_list,
            # self.image_list), which also truncated the identity-matching
            # path. Images of `other` beyond len(self.image_list) were
            # silently dropped instead of being matched or appended.
            for index, other_image in enumerate(other.image_list):
                if other_image.target_identity is None or self_identity is None:
                    # Positional fallback. Images with no positional partner
                    # are skipped, as with zip.
                    if index < len(self.image_list):
                        self._apply_inplace(self.image_list[index], other_image, subtract)
                    continue
                match = self._find_image_by_identity(other_image.target_identity)
                if match is None:
                    # NOTE(review): when subtracting, the unmatched image is
                    # appended un-negated (and by reference), as before. This
                    # looks like a sign bug — confirm intended semantics.
                    self.image_list.append(other_image)
                else:
                    self._apply_inplace(match, other_image, subtract)
        elif isinstance(other, Model_Image):
            if other.target_identity is None or self.target_identity is None:
                # No identities to match on: apply to every member image.
                for self_image in self.image_list:
                    self._apply_inplace(self_image, other, subtract)
            else:
                match = self._find_image_by_identity(other.target_identity)
                if match is None:
                    # NOTE(review): same un-negated append caveat as above.
                    self.image_list.append(other)
                else:
                    self._apply_inplace(match, other, subtract)
        else:
            for self_image, other_image in zip(self.image_list, other):
                self._apply_inplace(self_image, other_image, subtract)
        return self

    def __isub__(self, other):
        """In-place subtraction.

        See ``_inplace_combine`` for how images are paired. The
        ``Model_Image`` branch previously raised NameError on the undefined
        names ``zip_self_image`` / ``other_image``.
        """
        return self._inplace_combine(other, subtract=True)

    def __iadd__(self, other):
        """In-place addition; see ``_inplace_combine`` for how images are paired."""
        return self._inplace_combine(other, subtract=False)