# Levenberg-Marquardt algorithm
import os
from time import time
from typing import List, Callable, Optional, Union, Sequence, Any, Tuple
import torch
from torch.autograd.functional import jacobian
import numpy as np
import matplotlib.pyplot as plt
from .base import BaseOptimizer
from .. import AP_config
__all__ = ["oldLM", "LM_Constraint"]
@torch.no_grad()
@torch.jit.script
def Broyden_step(J, h, Yp, Yph):
delta = torch.matmul(J, h)
# avoid constructing a second giant Jacobian-sized temporary; apply the rank-one update one column at a time
for j in range(J.shape[1]):
    J[:, j] += (Yph - Yp - delta) * h[j] / torch.dot(h, h)  # divide by h^T h per the standard Broyden rank-one update
return J
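# The loop in Broyden_step applies a rank-one secant update while only holding a
# single pixel-length vector in memory at a time. For reference, the sketch below
# shows the equivalent outer-product form (illustrative only, not used by the
# optimizer); it materializes an (n_pixels, n_params) temporary, so it is only
# practical for small problems. The function name is hypothetical.
def _broyden_step_outer(J, h, Yp, Yph):
    # difference between the observed change in the model and the change predicted by J
    residual = Yph - Yp - torch.matmul(J, h)
    # Broyden rank-one update, chosen so the secant condition J_new @ h = Yph - Yp holds
    return J + torch.outer(residual, h) / torch.dot(h, h)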
class oldLM(BaseOptimizer):
"""based heavily on:
@article{gavin2019levenberg,
title={The Levenberg-Marquardt algorithm for nonlinear least squares curve-fitting problems},
author={Gavin, Henri P},
journal={Department of Civil and Environmental Engineering, Duke University},
volume={19},
year={2019}
}
The Levenberg-Marquardt algorithm bridges the gap between a
gradient descent optimizer and a Newton's Method optimizer. The
full Hessian for the Newton's Method update is too expensive to
evaluate with automatic differentiation (memory scales roughly as
parameters^2 * pixels^2), so it is approximated using the
Jacobian of the image pixels with respect to the parameters of
the model. Automatic differentiation provides an exact Jacobian,
as opposed to a finite-differences approximation.
Once a Hessian H and gradient G have been determined, the update
step h is the solution to the linear equation:
(H + L*I)h = G
where L is the Levenberg-Marquardt damping parameter and I is the
identity matrix. For small L this is essentially Newton's method;
for large L it reduces to a small gradient descent step
(approximately h = G/L). The method implemented here is modified
from Gavin 2019.
Args:
model (AstroPhot_Model): object with which to perform optimization
initial_state (Optional[Sequence]): an initial state for optimization
epsilon4 (Optional[float]): approximation accuracy requirement; any step with rho < epsilon4 will be rejected. Default 0.1
epsilon5 (Optional[float]): numerical stability factor, added to the diagonal of the Hessian. Default 1e-8
constraints (Optional[Union[LM_Constraint,tuple[LM_Constraint]]]): Constraint objects which control the fitting process.
L0 (Optional[float]): initial value for L factor in (H + L*I)h = G. Default 1e-3.
Lup (Optional[float]): amount to increase L when rejecting an update step. Default 11.
Ldn (Optional[float]): amount to decrease L when accepting an update step. Default 9.
"""
def __init__(
self,
model: "AstroPhot_Model",
initial_state: Optional[Sequence] = None,
max_iter: int = 100,
fit_parameters_identity: Optional[tuple] = None,
**kwargs,
):
super().__init__(
model,
initial_state,
max_iter=max_iter,
fit_parameters_identity=fit_parameters_identity,
**kwargs,
)
# Set optimizer parameters
self.epsilon4 = kwargs.get("epsilon4", 0.1)
self.epsilon5 = kwargs.get("epsilon5", 1e-8)
self.Lup = kwargs.get("Lup", 11.0)
self.Ldn = kwargs.get("Ldn", 9.0)
self.L = kwargs.get("L0", 1e-3)
self.use_broyden = kwargs.get("use_broyden", False)
# Initialize optimizer attributes
self.Y = self.model.target[self.fit_window].flatten("data")
# 1 / sigma^2
self.W = (
1.0 / self.model.target[self.fit_window].flatten("variance")
if model.target.has_variance
else 1.0
)
# degrees of freedom = number of pixels - number of fit parameters
self.ndf = len(self.Y) - len(self.current_state)
self.J = None
self.full_jac = False
self.current_Y = None
self.prev_Y = [None, None]
if self.model.target.has_mask:
self.mask = self.model.target[self.fit_window].flatten("mask")
# subtract masked pixels from degrees of freedom
self.ndf -= torch.sum(self.mask)
self.L_history = []
self.decision_history = []
self.rho_history = []
self._count_converged = 0
self.ndf = kwargs.get("ndf", self.ndf)
self._covariance_matrix = None
# update attributes with constraints
self.constraints = kwargs.get("constraints", None)
if self.constraints is not None and isinstance(self.constraints, LM_Constraint):
self.constraints = (self.constraints,)
if self.constraints is not None:
for con in self.constraints:
self.Y = torch.cat((self.Y, con.reference_value))
self.W = torch.cat((self.W, 1 / con.weight))
self.ndf -= con.reduce_ndf
if self.model.target.has_mask:
self.mask = torch.cat(
(
self.mask,
torch.zeros_like(con.reference_value, dtype=torch.bool),
)
)
def L_up(self, Lup=None):
if Lup is None:
Lup = self.Lup
self.L = min(1e9, self.L * Lup)
def L_dn(self, Ldn=None):
if Ldn is None:
Ldn = self.Ldn
self.L = max(1e-9, self.L / Ldn)
def step(self, current_state=None) -> None:
"""
Levenberg-Marquardt update step
"""
if current_state is not None:
self.current_state = current_state
if self.iteration > 0:
if self.verbose > 0:
AP_config.ap_logger.info("---------iter---------")
else:
if self.verbose > 0:
AP_config.ap_logger.info("---------init---------")
h = self.update_h()
if self.verbose > 1:
AP_config.ap_logger.info(f"h: {h.detach().cpu().numpy()}")
self.update_Yp(h)
loss = self.update_chi2()
if self.verbose > 0:
AP_config.ap_logger.info(f"LM loss: {loss.item()}")
if self.iteration == 0:
self.prev_Y[1] = self.current_Y
self.loss_history.append(loss.detach().cpu().item())
self.L_history.append(self.L)
self.lambda_history.append(
np.copy((self.current_state + h).detach().cpu().numpy())
)
if self.iteration > 0 and not torch.isfinite(loss):
if self.verbose > 0:
AP_config.ap_logger.warning("nan loss")
self.decision_history.append("nan")
self.rho_history.append(None)
self._count_reject += 1
self.iteration += 1
self.L_up()
return
elif self.iteration > 0:
lossmin = np.nanmin(self.loss_history[:-1])
rho = self.rho(lossmin, loss, h)
if self.verbose > 1:
AP_config.ap_logger.debug(
f"LM loss: {loss.item()}, best loss: {np.nanmin(self.loss_history[:-1])}, loss diff: {np.nanmin(self.loss_history[:-1]) - loss.item()}, L: {self.L}"
)
self.rho_history.append(rho)
if self.verbose > 1:
AP_config.ap_logger.debug(f"rho: {rho.item()}")
if rho > self.epsilon4:
if self.verbose > 0:
AP_config.ap_logger.info("accept")
self.decision_history.append("accept")
self.prev_Y[0] = self.prev_Y[1]
self.prev_Y[1] = torch.clone(self.current_Y)
self.current_state += h
self.L_dn()
self._count_reject = 0
if 0 < ((lossmin - loss) / loss) < self.relative_tolerance:
self._count_finish += 1
else:
self._count_finish = 0
else:
if self.verbose > 0:
AP_config.ap_logger.info("reject")
self.decision_history.append("reject")
self.L_up()
self._count_reject += 1
return
else:
self.decision_history.append("init")
self.rho_history.append(None)
if (
(not self.use_broyden)
or self.J is None
or self.iteration < 2
or "reset" in self.decision_history[-2:]
or rho < self.epsilon4
or self._count_reject > 0
or self.iteration >= (2 * len(self.current_state))
or self.decision_history[-1] == "nan"
):
if self.verbose > 1:
AP_config.ap_logger.debug("full jac")
self.update_J_AD()
else:
if self.verbose > 1:
AP_config.ap_logger.debug("Broyden jac")
self.update_J_Broyden(h, self.prev_Y[0], self.current_Y)
self.update_hess()
self.update_grad(self.prev_Y[1])
self.iteration += 1
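# To summarize the control flow above: each call to step() proposes h from the
# damped linear system, evaluates the gain ratio rho, and either accepts the step
# (the parameters move and L is divided by Ldn) or rejects it (L is multiplied by
# Lup). A full autograd Jacobian is recomputed when Broyden updates are disabled,
# the previous step was rejected, rho fell below epsilon4, or the iteration count
# exceeds twice the number of parameters; otherwise the cheaper Broyden rank-one
# update is applied.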
def fit(self):
self.iteration = 0
self._count_reject = 0
self._count_finish = 0
self.grad_only = False
start_fit = time()
try:
while True:
if self.verbose > 0:
AP_config.ap_logger.info(f"L: {self.L}")
# take LM step
self.step()
# Save the state of the model
if (
self.save_steps is not None
and self.decision_history[-1] == "accept"
):
self.model.save(
os.path.join(
self.save_steps,
f"{self.model.name}_Iteration_{self.iteration:03d}.yaml",
)
)
lam, L, loss = self.progress_history()
# Check for convergence
if (
self.decision_history.count("accept") > 2
and self.decision_history[-1] == "accept"
and L[-1] < 0.1
and ((loss[-2] - loss[-1]) / loss[-1])
< (self.relative_tolerance / 10)
):
self._count_converged += 1
elif self.iteration >= self.max_iter:
self.message = (
self.message + f"fail max iterations reached: {self.iteration}"
)
break
elif not torch.all(torch.isfinite(self.current_state)):
self.message = self.message + "fail non-finite step taken"
break
elif (
self.L >= (1e9 - 1)
and self._count_reject >= 8
and not self.take_low_rho_step()
):
self.message = (
self.message
+ "fail by immobility, unable to find improvement or even small bad step"
)
break
if self._count_converged >= 3:
self.message = self.message + "success"
break
lam, L, loss = self.accept_history()
if len(loss) >= 10:
loss10 = np.array(loss[-10:])
if (
np.all(
np.abs((loss10[0] - loss10[-1]) / loss10[-1])
< self.relative_tolerance
)
and L[-1] < 0.1
):
self.message = self.message + "success"
break
if (
np.all(
np.abs((loss10[0] - loss10[-1]) / loss10[-1])
< self.relative_tolerance
)
and L[-1] >= 0.1
):
self.message = (
self.message
+ "fail by immobility, possible bad area of parameter space."
)
break
except KeyboardInterrupt:
self.message = self.message + "fail interrupted"
if self.message.startswith("fail") and self._count_finish > 0:
self.message = (
self.message
+ ". possibly converged to numerical precision and could not make a better step."
)
self.model.parameters.set_values(
self.res(),
as_representation=True,
parameters_identity=self.fit_parameters_identity,
)
if self.verbose > 1:
AP_config.ap_logger.info(
f"LM Fitting complete in {time() - start_fit} sec with message: {self.message}"
)
return self
def update_uncertainty(self):
# set the uncertainty for each parameter
cov = self.covariance_matrix
if torch.all(torch.isfinite(cov)):
try:
self.model.parameters.set_uncertainty(
torch.sqrt(
torch.abs(torch.diag(cov))
),
as_representation=False,
parameters_identity=self.fit_parameters_identity,
)
except RuntimeError as e:
AP_config.ap_logger.warning(f"Unable to update uncertainty due to: {e}")
@torch.no_grad()
def undo_step(self) -> None:
AP_config.ap_logger.info("undoing step, trying to recover")
assert (
self.decision_history.count("accept") >= 2
), "cannot undo with not enough accepted steps, retry with new parameters"
assert len(self.decision_history) == len(self.lambda_history)
assert len(self.decision_history) == len(self.L_history)
found_accept = False
for i in reversed(range(len(self.decision_history))):
if not found_accept and self.decision_history[i] == "accept":
found_accept = True
continue
if self.decision_history[i] != "accept":
continue
self.current_state = torch.tensor(
self.lambda_history[i],
dtype=AP_config.ap_dtype,
device=AP_config.ap_device,
)
self.L = self.L_history[i] * self.Lup
def take_low_rho_step(self) -> bool:
for i in reversed(range(len(self.decision_history))):
if "accept" in self.decision_history[i]:
return False
if self.rho_history[i] is not None and self.rho_history[i] > 0:
if self.verbose > 0:
AP_config.ap_logger.info(
f"taking a low rho step for some progress: {self.rho_history[i]}"
)
self.current_state = torch.tensor(
self.lambda_history[i],
dtype=AP_config.ap_dtype,
device=AP_config.ap_device,
)
self.L = self.L_history[i]
self.loss_history.append(self.loss_history[i])
self.L_history.append(self.L)
self.lambda_history.append(
np.copy((self.current_state).detach().cpu().numpy())
)
self.decision_history.append("low rho accept")
self.rho_history.append(self.rho_history[i])
with torch.no_grad():
self.update_Yp(torch.zeros_like(self.current_state))
self.prev_Y[0] = self.prev_Y[1]
self.prev_Y[1] = self.current_Y
self.update_J_AD()
self.update_hess()
self.update_grad(self.prev_Y[1])
self.iteration += 1
self._count_reject = 0
return True
@torch.no_grad()
def update_h(self) -> torch.Tensor:
"""Solves the LM update linear equation (H + L*I)h = G to determine
the proposal for how to adjust the parameters to decrease the
chi2.
"""
h = torch.zeros_like(self.current_state)
if self.iteration == 0:
return h
h = torch.linalg.solve(
(
self.hess
+ self.L**2
* torch.eye(
len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device
)
)
* (
1
+ self.L**2
* torch.eye(
len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device
)
) ** 2
/ (1 + self.L**2),
self.grad,
)
return h
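# For reference, the textbook Levenberg damping would simply solve
#     h = torch.linalg.solve(self.hess + self.L * identity, self.grad)
# The expression above instead inflates the diagonal of the damped Hessian by a
# factor of roughly (1 + L^2) and shrinks the off-diagonal terms by 1 / (1 + L^2),
# so small L recovers the Newton step while large L tends toward a short, nearly
# pure gradient-descent step.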
@torch.no_grad()
def update_Yp(self, h):
"""
Updates the current model values for each pixel
"""
# Sample model at proposed state
self.current_Y = self.model(
parameters=self.current_state + h,
as_representation=True,
parameters_identity=self.fit_parameters_identity,
window=self.fit_window,
).flatten("data")
# Add constraint evaluations
if self.constraints is not None:
for con in self.constraints:
self.current_Y = torch.cat((self.current_Y, con(self.model)))
@torch.no_grad()
def update_chi2(self):
"""
Updates the chi squared / ndf value
"""
# Apply mask if needed
if self.model.target.has_mask:
loss = (
torch.sum(
((self.Y - self.current_Y) ** 2 * self.W)[
torch.logical_not(self.mask)
]
)
/ self.ndf
)
else:
loss = torch.sum((self.Y - self.current_Y) ** 2 * self.W) / self.ndf
return loss
def update_J_AD(self) -> None:
"""
Update the Jacobian using automatic differentiation; this produces an exact Jacobian at the current state.
"""
# Free up memory
del self.J
if "cpu" not in AP_config.ap_device:
torch.cuda.empty_cache()
# Compute jacobian on image
self.J = self.model.jacobian(
torch.clone(self.current_state).detach(),
as_representation=True,
parameters_identity=self.fit_parameters_identity,
window=self.fit_window,
).flatten("data")
# compute the constraint jacobian if needed
if self.constraints is not None:
for con in self.constraints:
self.J = torch.cat((self.J, con.jacobian(self.model)))
# Apply mask if needed
if self.model.target.has_mask:
self.J[self.mask] = 0.0
# Note that the most recent jacobian was a full autograd jacobian
self.full_jac = True
def update_J_natural(self) -> None:
"""
Update the Jacobian using automatic differentiation; this produces an exact Jacobian at the current state. Use this method to get the Jacobian in the natural parameter space instead of the representation space.
"""
# Free up memory
del self.J
if "cpu" not in AP_config.ap_device:
torch.cuda.empty_cache()
# Compute jacobian on image
self.J = self.model.jacobian(
torch.clone(
self.model.parameters.transform(
self.current_state,
to_representation=False,
parameters_identity=self.fit_parameters_identity,
)
).detach(),
as_representation=False,
parameters_identity=self.fit_parameters_identity,
window=self.fit_window,
).flatten("data")
# compute the constraint jacobian if needed
if self.constraints is not None:
for con in self.constraints:
self.J = torch.cat((self.J, con.jacobian(self.model)))
# Apply mask if needed
if self.model.target.has_mask:
self.J[self.mask] = 0.0
# This Jacobian is in the natural parameter space, not the representation space used during fitting
self.full_jac = False
@torch.no_grad()
def update_J_Broyden(self, h, Yp, Yph) -> None:
"""
Use the Broyden update to approximate the new Jacobian tensor at the current state. Less accurate, but far faster.
"""
# Update the Jacobian
self.J = Broyden_step(self.J, h, Yp, Yph)
# Apply mask if needed
if self.model.target.has_mask:
self.J[self.mask] = 0.0
# compute the constraint jacobian if needed
if self.constraints is not None:
for con in self.constraints:
self.J = torch.cat((self.J, con.jacobian(self.model)))
# Note that the most recent jacobian update was with Broyden step
self.full_jac = False
@torch.no_grad()
def update_hess(self) -> None:
"""
Update the Hessian using the jacobian most recently computed on the image.
"""
if isinstance(self.W, float):
self.hess = torch.matmul(self.J.T, self.J)
else:
self.hess = torch.matmul(self.J.T, self.W.view(len(self.W), -1) * self.J)
self.hess += self.epsilon5 * torch.eye(
len(self.current_state),
dtype=AP_config.ap_dtype,
device=AP_config.ap_device,
)
@property
@torch.no_grad()
def covariance_matrix(self) -> torch.Tensor:
if self._covariance_matrix is not None:
return self._covariance_matrix
self.update_J_natural()
self.update_hess()
try:
self._covariance_matrix = 2*torch.linalg.inv(self.hess)
except Exception:
AP_config.ap_logger.warning(
"WARNING: Hessian is singular, likely at least one model is non-physical. Will massage Hessian to continue but results should be inspected."
)
self.hess += torch.eye(
len(self.grad), dtype=AP_config.ap_dtype, device=AP_config.ap_device
) * (torch.diag(self.hess) == 0)
self._covariance_matrix = 2*torch.linalg.inv(self.hess)
return self._covariance_matrix
@torch.no_grad()
def update_grad(self, Yph) -> None:
"""
Update the gradient using the model evaluation on all pixels
"""
self.grad = torch.matmul(self.J.T, self.W * (self.Y - Yph))
@torch.no_grad()
def rho(self, Xp, Xph, h) -> torch.Tensor:
return (
self.ndf
* (Xp - Xph)
/ abs(
torch.dot(
h,
self.L**2 * (torch.abs(torch.diag(self.hess) - self.epsilon5) * h)
+ self.grad,
)
)
)
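# The gain ratio follows Gavin (2019): the actual reduction in chi^2 is compared
# to the reduction predicted by the damped quadratic model,
#     rho = ndf * (Xp - Xph) / | h^T (L^2 * diag(H) * h + G) |
# where diag(H) has the epsilon5 regularization removed and Xp, Xph are the
# previous and proposed chi^2 / ndf values. Steps with rho > epsilon4 are accepted.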
def accept_history(self) -> Tuple[List[np.ndarray], List[float], List[float]]:
lambdas = []
Ls = []
losses = []
for l in range(len(self.decision_history)):
if "accept" in self.decision_history[l] and np.isfinite(
self.loss_history[l]
):
lambdas.append(self.lambda_history[l])
Ls.append(self.L_history[l])
losses.append(self.loss_history[l])
return lambdas, Ls, losses
def progress_history(self) -> Tuple[List[np.ndarray], List[float], List[float]]:
lambdas = []
Ls = []
losses = []
for l in range(len(self.decision_history)):
if self.decision_history[l] == "accept":
lambdas.append(self.lambda_history[l])
Ls.append(self.L_history[l])
losses.append(self.loss_history[l])
return lambdas, Ls, losses
class LM_Constraint:
"""Add an arbitrary constraint to the LM optimization algorithm.
Expresses a constraint between parameters in the LM optimization
routine. Constraints may be used to bias parameters to have
certain behaviour, for example you may require the radius of one
model to be larger than that of another, or may require two models
to have the same position on the sky. The constraints defined in
this object are fuzzy constraints and so can be broken to some
degree; the amount of constraint breaking is determined by how
informative the data is and how strongly the constraint weight is
set. To create a constraint, first construct a function which
takes as argument a 1D tensor of the model parameters and gives as
output a real number (or 1D tensor of real numbers) which is zero
when the constraint is satisfied and non-zero increasing based on
how much the constraint is violated. For example:
def example_constraint(P):
    return (P[1] - P[0]) * (P[1] > P[0]).int()
which enforces that parameter 1 is less than parameter 0. Note
that we do not use any control-flow "if" statements and instead
incorporate the condition through multiplication; this is
important as it allows pytorch to compute derivatives through the
expression, and it performs far faster on GPU since no
communication is needed back and forth to handle the
if-statement. Keep this in mind while constructing your
constraint function. Also, make sure that any math operations are
performed by pytorch so it can construct a computational
graph. Beyond the requirement that the constraint be
differentiable, there is no limitation on what constraints can be
built with this system.
Args:
constraint_func (Callable[torch.Tensor, torch.Tensor]): python function which takes in a 1D tensor of parameters and generates real values in a tensor.
constraint_args (Optional[tuple]): An optional tuple of arguments for the constraint function that will be unpacked when calling the function.
weight (torch.Tensor): The weight of this constraint in the range (0,inf). Smaller values mean a stronger constraint, larger values mean a weaker constraint. Default 1.
representation_parameters (bool): if the constraint_func expects the parameters in the form of their representation or their standard value. Default False
out_len (int): the length of the output tensor by constraint_func. Default 1
reference_value (torch.Tensor): The value at which the constraint is satisfied. Default 0.
reduce_ndf (float): Amount by which to reduce the degrees of freedom. Default 0.
"""
def __init__(
self,
constraint_func: Callable[[torch.Tensor, Any], torch.Tensor],
constraint_args: tuple = (),
representation_parameters: bool = False,
out_len: int = 1,
reduce_ndf: float = 0.0,
weight: Optional[torch.Tensor] = None,
reference_value: Optional[torch.Tensor] = None,
**kwargs,
):
self.constraint_func = constraint_func
self.constraint_args = constraint_args
self.representation_parameters = representation_parameters
self.out_len = out_len
self.reduce_ndf = reduce_ndf
self.reference_value = torch.as_tensor(
reference_value if reference_value is not None else torch.zeros(out_len),
dtype=AP_config.ap_dtype,
device=AP_config.ap_device,
)
self.weight = torch.as_tensor(
weight if weight is not None else torch.ones(out_len),
dtype=AP_config.ap_dtype,
device=AP_config.ap_device,
)
def jacobian(self, model: "AstroPhot_Model"):
jac = jacobian(
lambda P: self.constraint_func(P, *self.constraint_args),
model.parameters.get_vector(
as_representation=self.representation_parameters
),
strategy="forward-mode",
vectorize=True,
create_graph=False,
)
return jac.reshape(-1, np.sum(model.parameters.vector_len()))
def __call__(self, model: "AstroPhot_Model"):
return self.constraint_func(
model.parameters.get_vector(
as_representation=self.representation_parameters
),
*self.constraint_args,
)