import operator

import numpy as np
import torch

from .. import AP_config
from .image_object import Image, Image_List
from .window_object import Window
from ..utils.interpolate import shift_Lanczos_torch
from ..errors import InvalidData, SpecificationConflict, InvalidImage
__all__ = ["Model_Image", "Model_Image_List"]
######################################################################
[docs]
class Model_Image(Image):
    """Image object which represents the sampling of a model at the given
    coordinates of the image.

    Extra arithmetic operations are available which can update model values
    in the image. The whole model can be shifted by less than a pixel to
    account for sub-pixel accuracy.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Identity of the target this model image corresponds to. Model_Image_List
        # uses it to pair images during ``+=`` / ``-=``. ``None`` when not given.
        self.target_identity = kwargs.get("target_identity", None)
        # NOTE(review): presumably moves data to the configured dtype/device
        # (``Image.to`` is defined elsewhere) -- confirm.
        self.to()

    def clear_image(self):
        """Set every pixel to zero, keeping the data's shape, dtype and device."""
        self.data = torch.zeros_like(self.data)

    def shift_origin(self, shift, is_prepadded=True):
        """Shift the window by ``shift`` and resample the data to follow it.

        Args:
            shift: Offset in plane coordinates. It is converted to pixel units
                with ``plane_to_pixel_delta``.
            is_prepadded: Forwarded to ``shift_Lanczos_torch`` as
                ``img_prepadded``.

        Raises:
            NotImplementedError: If the shift exceeds one pixel along either
                axis. In that case the image is left unmodified.
        """
        pix_shift = self.plane_to_pixel_delta(shift)
        # Validate before touching the window so a rejected shift does not
        # leave the window moved while the data stays put.
        if torch.any(torch.abs(pix_shift) > 1):
            raise NotImplementedError(
                "Shifts larger than 1 pixel are currently not handled"
            )
        self.window.shift(shift)
        self.data = shift_Lanczos_torch(
            self.data,
            pix_shift[0],
            pix_shift[1],
            # Presumably the Lanczos kernel size: 10, capped by the smallest
            # image dimension -- confirm against shift_Lanczos_torch.
            min(min(self.data.shape), 10),
            dtype=AP_config.ap_dtype,
            device=AP_config.ap_device,
            img_prepadded=is_prepadded,
        )

    def get_window(self, window: Window, **kwargs):
        """Delegate to ``Image.get_window``, forwarding ``target_identity``."""
        return super().get_window(
            window, target_identity=self.target_identity, **kwargs
        )

    def reduce(self, scale, **kwargs):
        """Delegate to ``Image.reduce``, forwarding ``target_identity``."""
        return super().reduce(scale, target_identity=self.target_identity, **kwargs)

    def replace(self, other, data=None):
        """Overwrite pixel values in this image.

        Args:
            other: One of the following.
                An ``Image``: the pixels where the two windows overlap are
                copied from ``other``. Nothing happens if there is no overlap
                or the overlap selects no elements.
                A ``Window``: the pixels of this image inside that window are
                set to ``data``.
                Anything else: replaces ``self.data`` wholesale.
            data: Values to write; only used when ``other`` is a ``Window``.
        """
        if isinstance(other, Image):
            if self.window.overlap_frac(other.window) == 0.0:  # fixme control flow
                return
            other_indices = self.window.get_other_indices(other)
            self_indices = other.window.get_other_indices(self)
            # Guard against an empty slice on either side (e.g. edge-only overlap).
            if (
                self.data[self_indices].nelement() == 0
                or other.data[other_indices].nelement() == 0
            ):
                return
            self.data[self_indices] = other.data[other_indices]
        elif isinstance(other, Window):
            self.data[self.window.get_self_indices(other)] = data
        else:
            self.data = other

    def get_state(self):
        """Return the parent state dict extended with ``target_identity``."""
        state = super().get_state()
        state["target_identity"] = self.target_identity
        return state

    def set_state(self, state):
        """Restore from a dict produced by ``get_state``.

        ``target_identity`` falls back to ``None`` when the state lacks it.
        """
        super().set_state(state)
        # Previously read an undefined local name ``target_identity`` (NameError).
        self.target_identity = state.get("target_identity", None)

    def get_fits_state(self):
        """Return the parent FITS states, recording ``target_identity`` under
        ``TRGTID`` in the PRIMARY header."""
        states = super().get_fits_state()
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                state["HEADER"]["TRGTID"] = self.target_identity
        return states

    def set_fits_state(self, states):
        """Restore from FITS states, reading ``target_identity`` from the
        ``TRGTID`` key of the PRIMARY header."""
        super().set_fits_state(states)
        for state in states:
            if state["HEADER"]["IMAGE"] == "PRIMARY":
                self.target_identity = state["HEADER"]["TRGTID"]
######################################################################
class Model_Image_List(Image_List, Model_Image):
    """List of Model_Image objects that are cleared, replaced and accumulated
    together.

    In-place ``+=`` and ``-=`` pair held images with incoming ones by
    ``target_identity`` when identities are available, and positionally
    otherwise.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every element must be a Model_Image so the model-specific methods
        # (clear_image, replace, target_identity) are available on it.
        if not all(isinstance(image, Model_Image) for image in self.image_list):
            raise InvalidImage(
                f"Model_Image_List can only hold Model_Image objects, not {tuple(type(image) for image in self.image_list)}"
            )

    def clear_image(self):
        """Zero the data of every held image."""
        for image in self.image_list:
            image.clear_image()

    def shift_origin(self, shift, is_prepadded=True):
        """Not supported for image lists.

        ``is_prepadded`` is accepted only so the signature matches
        ``Model_Image.shift_origin``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()

    def replace(self, other, data=None):
        """Call ``replace`` on each held image, paired positionally with
        ``other`` (and with ``data`` when it is given).

        Pairing uses ``zip``, so surplus entries on either side are ignored.
        """
        if data is None:
            for image, oth in zip(self.image_list, other):
                image.replace(oth)
        else:
            for image, oth, dat in zip(self.image_list, other, data):
                image.replace(oth, dat)

    @property
    def target_identity(self):
        """Tuple of the held images' target identities, or ``None`` if any
        held image has no identity."""
        targets = tuple(image.target_identity for image in self.image_list)
        if any(tar_id is None for tar_id in targets):
            return None
        return targets

    def _find_by_identity(self, target_identity):
        """Return the first held image whose target_identity equals
        ``target_identity``, or ``None`` if there is none."""
        for self_image in self.image_list:
            if target_identity == self_image.target_identity:
                return self_image
        return None

    def _combine(self, other, op):
        """Apply the in-place operator ``op`` between held images and ``other``.

        ``op`` is ``operator.iadd`` or ``operator.isub``. Supported ``other``:

        * Model_Image_List: each of its images is matched to a held image by
          ``target_identity``. If that image or this list lacks identities,
          it is paired by position instead, and skipped when no image exists
          at that position. An identified image with no match is appended to
          this list.
        * Model_Image: without identities it is applied to every held image.
          Otherwise it is applied to the matching held image, or appended if
          none matches.
        * Any other iterable: paired positionally via ``zip``.

        Returns:
            ``self``, as required by the augmented-assignment protocol.
        """
        # Model_Image_List subclasses Model_Image, so it must be checked first.
        if isinstance(other, Model_Image_List):
            # Iterate over *all* of other's images. The previous zip with
            # self.image_list silently dropped those beyond self's length.
            for index, other_image in enumerate(other.image_list):
                if other_image.target_identity is None or self.target_identity is None:
                    if index < len(self.image_list):
                        op(self.image_list[index], other_image)
                    continue
                match = self._find_by_identity(other_image.target_identity)
                if match is None:
                    # NOTE(review): for ``-=`` the unmatched image is appended
                    # with its original sign, i.e. it is effectively added.
                    # Confirm this is intended.
                    self.image_list.append(other_image)
                else:
                    op(match, other_image)
        elif isinstance(other, Model_Image):
            if other.target_identity is None or self.target_identity is None:
                for self_image in self.image_list:
                    op(self_image, other)
            else:
                match = self._find_by_identity(other.target_identity)
                if match is None:
                    self.image_list.append(other)
                else:
                    op(match, other)
        else:
            for self_image, other_image in zip(self.image_list, other):
                op(self_image, other_image)
        return self

    def __isub__(self, other):
        """In-place subtraction; see ``_combine`` for how images are paired."""
        return self._combine(other, operator.isub)

    def __iadd__(self, other):
        """In-place addition; see ``_combine`` for how images are paired."""
        return self._combine(other, operator.iadd)