# Source code for astrophot.fit.hmc

# Hamiltonian Monte-Carlo
import os
from time import time
from typing import Optional, Sequence
import warnings

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC as pyro_MCMC
from pyro.infer import HMC as pyro_HMC
from pyro.infer.mcmc.adaptation import WarmupAdapter, BlockMassMatrix
from pyro.ops.welford import WelfordCovariance

from .base import BaseOptimizer
from ..models import AstroPhot_Model
from .. import AP_config

__all__ = ["HMC"]

###########################################
# !Overwrite pyro configuration behavior!
# Currently this is the only way to provide
# the mass matrix manually.
###########################################
def new_configure(self, mass_matrix_shape, adapt_mass_matrix=True, options=None):
    """
    Sets up an initial mass matrix, keeping one that was assigned beforehand.

    Replacement for ``BlockMassMatrix.configure``. It differs from a plain
    (re)initialisation in one way: if ``self.inverse_mass_matrix`` already
    holds entries (e.g. ``HMC.fit`` assigns ``{("x",): inv_mass}`` before
    sampling starts), those entries are left untouched instead of being
    replaced by the default scaled-identity matrices. Per-site sizes and
    adaptation schemes are registered either way.

    :param dict mass_matrix_shape: a dict that maps tuples of site names to the shape of
        the corresponding mass matrix. Each tuple of site names corresponds to a block.
        A 1-element shape denotes a diagonal block, otherwise a dense block.
    :param bool adapt_mass_matrix: a flag to decide whether an adaptation scheme will be used.
    :param dict options: tensor options (e.g. dtype/device) used to construct the initial
        mass matrix. ``None`` (the default) means no extra options.
    """
    # Avoid a shared mutable default; ``{}`` keeps the torch defaults.
    options = {} if options is None else options

    # A non-empty inverse mass matrix was provided manually; do not overwrite it.
    keep_existing = bool(self.inverse_mass_matrix)

    inverse_mass_matrix = {}
    for site_names, shape in mass_matrix_shape.items():
        self._mass_matrix_size[site_names] = shape[0]
        diagonal = len(shape) == 1
        if not keep_existing:
            # Only build the default matrices when they will actually be used;
            # a dense block can be large.
            inverse_mass_matrix[site_names] = (
                torch.full(shape, self._init_scale, **options)
                if diagonal
                else torch.eye(*shape, **options) * self._init_scale
            )
        if adapt_mass_matrix:
            self._adapt_scheme[site_names] = WelfordCovariance(diagonal=diagonal)

    if not keep_existing:
        self.inverse_mass_matrix = inverse_mass_matrix


# Install the replacement on Pyro's BlockMassMatrix so that an inverse mass
# matrix assigned before sampling is kept rather than reset (see new_configure).
setattr(BlockMassMatrix, "configure", new_configure)
############################################


class HMC(BaseOptimizer):
    """Hamiltonian Monte-Carlo sampler wrapper for the Pyro package.

    This MCMC algorithm uses gradients of the Chi^2 to more efficiently
    explore the probability distribution. Consider using the NUTS sampler
    instead of HMC, as it is generally better in most aspects.

    More information on HMC can be found at:
    https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo,
    https://arxiv.org/abs/1701.02434, and
    http://www.mcmchandbook.net/HandbookChapter5.pdf

    Args:
        model (AstroPhot_Model): The model which will be sampled.
        initial_state (Optional[Sequence], optional): A 1D array with the values
            for each parameter in the model. These values should be in the form
            of "as_representation" in the model. Defaults to None.
        max_iter (int, optional): The number of sampling steps to perform.
            Defaults to 1000.
        epsilon (float, optional): The length of the integration step to
            perform for each leapfrog iteration. The momentum update will be of
            order epsilon * score. Defaults to 1e-3.
        leapfrog_steps (int, optional): Number of steps to perform with leapfrog
            integrator per sample of the HMC. Defaults to 20.
        inv_mass (float or array, optional): Inverse Mass matrix (covariance
            matrix) which can tune the behavior in each dimension to ensure
            better mixing when sampling. A scalar scales the identity and a 1D
            array is used as the diagonal. When given, mass-matrix adaptation
            is disabled by default. Defaults to None, in which case Pyro's
            default initial matrix is used and adapted during warmup.
        progress_bar (bool, optional): Whether to display a progress bar during
            sampling. Defaults to True.
        prior (distribution, optional): Prior distribution for the parameters.
            Defaults to None, in which case ``fit`` builds a broad Normal prior
            centred on the current state.
        warmup (int, optional): Number of warmup steps before actual sampling
            begins. Defaults to 100.
        hmc_kwargs (dict, optional): Additional keyword arguments for the HMC
            sampler; they override the defaults set in ``fit``. Defaults to {}.
        mcmc_kwargs (dict, optional): Additional keyword arguments for the MCMC
            process; they override the defaults set in ``fit``. Defaults to {}.
    """

    def __init__(
        self,
        model: AstroPhot_Model,
        initial_state: Optional[Sequence] = None,
        max_iter: int = 1000,
        **kwargs,
    ):
        super().__init__(model, initial_state, max_iter=max_iter, **kwargs)

        self.inv_mass = kwargs.get("inv_mass", None)
        self.epsilon = kwargs.get("epsilon", 1e-3)
        self.leapfrog_steps = kwargs.get("leapfrog_steps", 20)
        self.progress_bar = kwargs.get("progress_bar", True)
        self.prior = kwargs.get("prior", None)
        self.warmup = kwargs.get("warmup", 100)
        self.hmc_kwargs = kwargs.get("hmc_kwargs", {})
        self.mcmc_kwargs = kwargs.get("mcmc_kwargs", {})
        # NOTE(review): ``fit`` below does not populate this attribute.
        self.acceptance = None

    def _inverse_mass_tensor(self, reference: torch.Tensor, full_mass: bool) -> torch.Tensor:
        """Coerce ``self.inv_mass`` into a tensor Pyro's mass-matrix adapter accepts.

        Args:
            reference (torch.Tensor): Parameter vector whose dtype, device and
                length the inverse mass matrix must match.
            full_mass (bool): True when a dense (N, N) matrix is expected,
                False when Pyro uses a diagonal (N,) vector.

        Returns:
            torch.Tensor: The inverse mass matrix. A scalar is expanded to a
            scaled identity (dense) or constant diagonal; a 1D array is turned
            into a diagonal matrix when a dense matrix is expected; anything
            else is returned as given (converted to the reference dtype/device).
        """
        inv_mass = torch.as_tensor(
            self.inv_mass, dtype=reference.dtype, device=reference.device
        )
        size = reference.numel()
        if inv_mass.dim() == 0:
            template = (
                torch.eye(size, dtype=reference.dtype, device=reference.device)
                if full_mass
                else torch.ones(size, dtype=reference.dtype, device=reference.device)
            )
            return inv_mass * template
        if full_mass and inv_mass.dim() == 1:
            return torch.diag(inv_mass)
        return inv_mass

    def fit(
        self,
        state: Optional[torch.Tensor] = None,
    ):
        """Performs MCMC sampling using Hamiltonian Monte-Carlo step.

        Records the chain for later examination.

        Args:
            state (torch.Tensor, optional): Model parameters as a 1D tensor (in
                the model's "as_representation" form) used as the starting
                point of the chain. Defaults to None, meaning the model's
                current parameter representation.

        Returns:
            HMC: An instance of the HMC class with updated chain.
        """

        def step(model, prior):
            # Pyro model: sample the parameter vector from the prior, then add
            # the model log-likelihood as a factor on the joint log-density.
            x = pyro.sample("x", prior)
            # NOTE(review): presumably resets gradient/graph state held by the
            # parameters before re-evaluating the likelihood — confirm.
            model.parameters.flat_detach()
            log_likelihood_value = -model.negative_log_likelihood(
                parameters=x, as_representation=True
            )
            # Observe the log-likelihood
            pyro.factor("obs", log_likelihood_value)

        if self.prior is None:
            # Broad default prior: Normal centred on the current state with a
            # per-parameter scale of 100 * (1 + |state|). Stored on self, so
            # later calls reuse it.
            self.prior = dist.Normal(
                self.current_state,
                torch.ones_like(self.current_state) * 1e2
                + torch.abs(self.current_state) * 1e2,
            )

        # Starting point of the chain, in representation space.
        reference = self.model.parameters.vector_representation()
        if state is None:
            init_state = reference
        else:
            init_state = torch.as_tensor(
                state, dtype=reference.dtype, device=reference.device
            )

        # Set up the HMC sampler; user-supplied hmc_kwargs override defaults.
        hmc_kwargs = {
            "jit_compile": False,
            "ignore_jit_warnings": True,
            "full_mass": True,
            "step_size": self.epsilon,
            "num_steps": self.leapfrog_steps,
            "adapt_step_size": False,
            "adapt_mass_matrix": self.inv_mass is None,
        }
        hmc_kwargs.update(self.hmc_kwargs)
        hmc_kernel = pyro_HMC(step, **hmc_kwargs)
        if self.inv_mass is not None:
            # Seed the adapter before sampling; the module-level patched
            # BlockMassMatrix.configure keeps this matrix instead of resetting it.
            hmc_kernel.mass_matrix_adapter.inverse_mass_matrix = {
                ("x",): self._inverse_mass_tensor(
                    init_state, bool(hmc_kwargs["full_mass"])
                )
            }

        # Run MCMC with the HMC sampler from the chosen starting point;
        # user-supplied mcmc_kwargs override defaults.
        mcmc_kwargs = {
            "num_samples": self.max_iter,
            "warmup_steps": self.warmup,
            "initial_params": {"x": init_state},
            "disable_progbar": not self.progress_bar,
        }
        mcmc_kwargs.update(self.mcmc_kwargs)
        mcmc = pyro_MCMC(hmc_kernel, **mcmc_kwargs)

        mcmc.run(self.model, self.prior)
        self.iteration += self.max_iter

        # Extract posterior samples and map them from representation space to
        # parameter values, row by row, in place.
        chain = mcmc.get_samples()["x"]
        with torch.no_grad():
            for i, sample in enumerate(chain):
                chain[i] = self.model.parameters.vector_transform_rep_to_val(sample)
        self.chain = chain

        return self