Source code for astrophot.fit.mhmcmc

# Metropolis-Hasting Markov-Chain Monte-Carlo
import os
from time import time
from typing import Optional, Sequence
import torch
from tqdm import tqdm
import numpy as np
from .base import BaseOptimizer
from .. import AP_config

__all__ = ["MHMCMC"]


[docs] class MHMCMC(BaseOptimizer): """Metropolis-Hastings Markov-Chain Monte-Carlo sampler, based on: https://en.wikipedia.org/wiki/Metropolis-Hastings_algorithm . This is a naive implimentation of a standard MCMC, it is far from optimal and should not be used for anything but the most basic scenarios. Args: model (AstroPhot_Model): The model which will be sampled. initial_state (Optional[Sequence]): A 1D array with the values for each parameter in the model. Note that these values should be in the form of "as_representation" in the model. max_iter (int): The number of sampling steps to perform. Default 1000 epsilon (float or array): The random step length to take at each iteration. This is the standard deviation for the normal distribution sampling. Default 1e-2 """ def __init__( self, model: "AstroPhot_Model", initial_state: Optional[Sequence] = None, max_iter: int = 1000, **kwargs, ): super().__init__(model, initial_state, max_iter=max_iter, **kwargs) self.epsilon = kwargs.get("epsilon", 1e-2) self.progress_bar = kwargs.get("progress_bar", True) self.report_after = kwargs.get("report_after", int(self.max_iter / 10)) self.chain = [] self._accepted = 0 self._sampled = 0
[docs] def fit( self, state: Optional[torch.Tensor] = None, nsamples: Optional[int] = None, restart_chain: bool = True, ): """ Performs the MCMC sampling using a Metropolis Hastings acceptance step and records the chain for later examination. """ if nsamples is None: nsamples = self.max_iter if state is None: state = self.current_state chi2 = self.sample(state) if restart_chain: self.chain = [] else: self.chain = list(self.chain) iterator = tqdm(range(nsamples)) if self.progress_bar else range(nsamples) for i in iterator: state, chi2 = self.step(state, chi2) self.append_chain(state) if i % self.report_after == 0 and i > 0 and self.verbose > 0: AP_config.ap_logger.info(f"Acceptance: {self.acceptance}") if self.verbose > 0: AP_config.ap_logger.info(f"Acceptance: {self.acceptance}") self.current_state = state self.chain = np.stack(self.chain) return self
[docs] def append_chain(self, state: torch.Tensor): """ Add a state vector to the MCMC chain """ self.chain.append( self.model.parameters.vector_transform_rep_to_val(state) .detach() .cpu() .clone() .numpy() )
[docs] @staticmethod def accept(log_alpha): """ Evaluates randomly if a given proposal is accepted. This is done in log space which is more natural for the evaluation in the step. """ return torch.log(torch.rand(log_alpha.shape)) < log_alpha
[docs] @torch.no_grad() def sample(self, state: torch.Tensor): """ Samples the model at the proposed state vector values """ return self.model.negative_log_likelihood( parameters=state, as_representation=True )
[docs] @torch.no_grad() def step(self, state: torch.Tensor, chi2: torch.Tensor) -> torch.Tensor: """ Takes one step of the HMC sampler by integrating along a path initiated with a random momentum. """ proposal_state = torch.normal(mean=state, std=self.epsilon) proposal_chi2 = self.sample(proposal_state) log_alpha = chi2 - proposal_chi2 accept = self.accept(log_alpha) self._accepted += accept self._sampled += 1 return proposal_state if accept else state, proposal_chi2 if accept else chi2
@property def acceptance(self): """ Returns the ratio of accepted states to total states sampled. """ return self._accepted / self._sampled