Source code for astrophot.utils.operations

from functools import lru_cache
from typing import Callable, Optional

import torch
import matplotlib.pyplot as plt
import numpy as np
from astropy.convolution import convolve, convolve_fft
from scipy.fft import next_fast_len
from scipy.special import roots_legendre


[docs] def fft_convolve_torch(img, psf, psf_fft=False, img_prepadded=False): # Ensure everything is tensor img = torch.as_tensor(img) psf = torch.as_tensor(psf) if img_prepadded: s = img.size() else: s = tuple( next_fast_len(int(d + (p + 1) / 2), real=True) for d, p in zip(img.size(), psf.size()) ) # list(int(d + (p + 1) / 2) for d, p in zip(img.size(), psf.size())) img_f = torch.fft.rfft2(img, s=s) if not psf_fft: psf_f = torch.fft.rfft2(psf, s=s) else: psf_f = psf conv_f = img_f * psf_f conv = torch.fft.irfft2(conv_f, s=s) # Roll the tensor to correct centering and crop to original image size return torch.roll( conv, shifts=(-int((psf.size()[0] - 1) / 2), -int((psf.size()[1] - 1) / 2)), dims=(0, 1), )[: img.size()[0], : img.size()[1]]
[docs] def fft_convolve_multi_torch( img, kernels, kernel_fft=False, img_prepadded=False, dtype=None, device=None ): # Ensure everything is tensor img = torch.as_tensor(img, dtype=dtype, device=device) for k in range(len(kernels)): kernels[k] = torch.as_tensor(kernels[k], dtype=dtype, device=device) if img_prepadded: s = img.size() else: s = list(int(d + (p + 1) / 2) for d, p in zip(img.size(), kernels[0].size())) img_f = torch.fft.rfft2(img, s=s) if not kernel_fft: kernels_f = list(torch.fft.rfft2(kernel, s=s) for kernel in kernels) else: psf_f = psf conv_f = img_f for kernel_f in kernels_f: conv_f *= kernel_f conv = torch.fft.irfft2(conv_f, s=s) # Roll the tensor to correct centering and crop to original image size return torch.roll( conv, shifts=( -int((sum(kernel.size()[0] for kernel in kernels) - 1) / 2), -int((sum(kernel.size()[1] for kernel in kernels) - 1) / 2), ), dims=(0, 1), )[: img.size()[0], : img.size()[1]]
[docs] def displacement_spacing(N, dtype=torch.float64, device="cpu"): return torch.linspace( -(N - 1) / (2 * N), (N - 1) / (2 * N), N, dtype=dtype, device=device )
[docs] def displacement_grid(Nx, Ny, pixelscale=None, dtype=torch.float64, device="cpu"): px = displacement_spacing(Nx, dtype=dtype, device=device) py = displacement_spacing(Ny, dtype=dtype, device=device) PX, PY = torch.meshgrid(px, py, indexing="xy") return (pixelscale @ torch.stack((PX, PY)).view(2, -1)).reshape((2, *PX.shape))
[docs] @lru_cache(maxsize=32) def quad_table(n, p, dtype, device): """ from: https://pomax.github.io/bezierinfo/legendre-gauss.html """ abscissa, weights = roots_legendre(n) w = torch.tensor(weights, dtype=dtype, device=device) a = torch.tensor(abscissa, dtype=dtype, device=device) X, Y = torch.meshgrid(a, a, indexing="xy") W = torch.outer(w, w) / 4.0 X, Y = p @ (torch.stack((X, Y)).view(2, -1) / 2.0) return X, Y, W.reshape(-1)
[docs] def single_quad_integrate( X, Y, image_header, eval_brightness, eval_parameters, dtype, device, quad_level=3 ): # collect gaussian quadrature weights abscissaX, abscissaY, weight = quad_table( quad_level, image_header.pixelscale, dtype, device ) # Specify coordinates at which to evaluate function Xs = torch.repeat_interleave(X[..., None], quad_level ** 2, -1) + abscissaX Ys = torch.repeat_interleave(Y[..., None], quad_level ** 2, -1) + abscissaY # Evaluate the model at the quadrature points res = eval_brightness( X=Xs, Y=Ys, image=image_header, parameters=eval_parameters, ) # Reference flux for pixel is simply the mean of the evaluations ref = res[..., (quad_level**2) // 2] #res.mean(axis=-1) # # alternative, use midpoint # Apply the weights and reduce to original pixel space res = (res * weight).sum(axis=-1) return res, ref
[docs] def grid_integrate( X, Y, image_header, eval_brightness, eval_parameters, dtype, device, quad_level=3, gridding=5, _current_depth=1, max_depth=2, reference=None, ): """The grid_integrate function performs adaptive quadrature integration over a given pixel grid, offering precision control where it is needed most. Args: X (torch.Tensor): A 2D tensor representing the x-coordinates of the grid on which the function will be integrated. Y (torch.Tensor): A 2D tensor representing the y-coordinates of the grid on which the function will be integrated. image_header (ImageHeader): An object containing meta-information about the image. eval_brightness (callable): A function that evaluates the brightness at each grid point. This function should be compatible with PyTorch tensor operations. eval_parameters (Parameter_Group): An object containing parameters that are passed to the eval_brightness function. dtype (torch.dtype): The data type of the output tensor. The dtype argument should be a valid PyTorch data type. device (torch.device): The device on which to perform the computations. The device argument should be a valid PyTorch device. quad_level (int, optional): The initial level of quadrature used in the integration. Defaults to 3. gridding (int, optional): The factor by which the grid is subdivided when the integration error for a pixel is above the allowed threshold. Defaults to 5. _current_depth (int, optional): The current depth level of the grid subdivision. Used for recursive calls to the function. Defaults to 1. max_depth (int, optional): The maximum depth level of grid subdivision. Once this level is reached, no further subdivision is performed. Defaults to 2. reference (torch.Tensor or None, optional): A scalar value that represents the allowed threshold for the integration error. Returns: torch.Tensor: A tensor of the same shape as X and Y that represents the result of the integration on the grid. This function operates by first performing a quadrature integration over the given pixels. If the maximum depth level has been reached, it simply returns the result. Otherwise, it calculates the integration error for each pixel and selects those that have an error above the allowed threshold. For pixels that have low error, the result is set as computed. For those with high error, it sets up a finer sampling grid and recursively evaluates the quadrature integration on it. Finally, it integrates the results from the finer sampling grid back to the current resolution. """ # perform quadrature integration on the given pixels res, ref = single_quad_integrate( X, Y, image_header, eval_brightness, eval_parameters, dtype, device, quad_level=quad_level, ) # if the max depth is reached, simply return the integrated pixels if _current_depth >= max_depth: return res # Begin integral integral = torch.zeros_like(X) # Select pixels which have errors above the allowed threshold select = torch.abs((res - ref)) > reference # For pixels with low error, set the results as computed integral[torch.logical_not(select)] = res[torch.logical_not(select)] # Set up sub-gridding to super resolve problem pixels stepx, stepy = displacement_grid( gridding, gridding, image_header.pixelscale, dtype, device ) # Write out the coordinates for the super resolved pixels subgridX = torch.repeat_interleave( X[select].unsqueeze(-1), gridding ** 2, -1 ) + stepx.reshape(-1) subgridY = torch.repeat_interleave( Y[select].unsqueeze(-1), gridding ** 2, -1 ) + stepy.reshape(-1) # Recursively evaluate the quadrature integration on the finer sampling grid subgridres = grid_integrate( subgridX, subgridY, image_header.rescale_pixel(1/gridding), eval_brightness, eval_parameters, dtype, device, quad_level=quad_level, gridding=gridding, _current_depth=_current_depth+1, max_depth=max_depth, reference=reference * gridding**2, ) # Integrate the finer sampling grid back to current resolution integral[select] = subgridres.sum(axis=(-1,)) return integral