Probabilistic Effects. λθ

Documentation

1. What Does Monad-Bayes Do?

Markov chain Monte Carlo (MCMC) is a common way of inferring a posterior distribution from a likelihood and a prior. It is not a single method but a whole class of approximate inference methods. Note that while we are inferring the posterior, we are often also estimating the likelihood from the data. In general, inference is the act of estimating quantities we cannot observe directly, given some data; at the highest level, this means inferring the posterior distribution over the model's parameters.

Monad Bayes uses a particle filter to estimate the likelihood of a given model, and Metropolis-Hastings (more specifically, the trace Metropolis-Hastings algorithm) to infer the posterior distribution over any quantity of interest. (A particle filter is only needed when we do not have an exact formulation of the likelihood.) The library offers a very generic version of trace Metropolis-Hastings and a particle filter, and these two components can be combined into different algorithms, such as PMMH (particle marginal Metropolis-Hastings) or SMC. In principle, one could also extend the library with one's own variant of Metropolis-Hastings.

Particle filters and Metropolis-Hastings are just two of the basic building blocks of many statistical algorithms. The idea behind monad-bayes is therefore very general: at its core, it is an effectful, type-class-based approach to probabilistic programming. Although it ships implementations of trace Metropolis-Hastings and particle filters, these are building blocks for other algorithms, and the library can be used in many other ways.


2. Building Blocks of Monad Bayes

2.1 Sampler (Bayes.Sampler)
-- | An 'ST' based random sampler using the @mwc-random@ package.
newtype SamplerST a = SamplerST (forall s. ReaderT (GenST s) (ST s) a)

A sampler draws concrete values for random variables from a prior. It is built on the ReaderT monad transformer, where:

  • The read-only environment is the pseudo-random number generator GenST s, a mutable generator that lives in the state thread s. (The parameter s does not hold the generator's state; it is a phantom type that ties the generator to one particular ST computation.)
  • The underlying monad is ST s, the state-thread monad, which we need because the generator is mutated in place. A computation of type ST s a returns a value of type a and executes in the “thread” s. Because SamplerST is universally quantified over s (forall s. ...), the mutable generator cannot leak outside the ST computation that uses it.

The ReaderT monad then allows us to read from the random number generator (as the environment).

runReaderT :: ReaderT r m a -> r -> m a
-- or in this specific case:
runReaderT :: ReaderT (GenST s) (ST s) a -> GenST s -> ST s a
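To make the ReaderT-over-ST structure concrete, here is a minimal sketch that does not depend on mwc-random. It is a stand-in under stated assumptions, not monad-bayes code: the hypothetical Gen s type wraps a 64-bit linear congruential generator in an STRef, Sampler s is a type synonym rather than the SamplerST newtype, and the forall s moves into the hypothetical runner sampleWith, where it plays the same role of keeping the mutable generator inside runST.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad (replicateM)
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT, asks, runReaderT)
import Data.Bits (shiftR)
import Data.STRef (STRef, newSTRef, readSTRef, writeSTRef)
import Data.Word (Word64)

-- Hypothetical stand-in for mwc-random's GenST s: a mutable 64-bit LCG state
-- living in the ST state thread s.
newtype Gen s = Gen {genRef :: STRef s Word64}

-- A sampler reads the generator from the ReaderT environment and mutates it in ST s.
type Sampler s a = ReaderT (Gen s) (ST s) a

-- Draw a Double in [0, 1) by stepping the LCG (Knuth's MMIX constants).
random :: Sampler s Double
random = do
  ref <- asks genRef
  lift $ do
    x <- readSTRef ref
    let x' = 6364136223846793005 * x + 1442695040888963407
    writeSTRef ref x'
    -- Keep the top 53 bits and scale by 2^53 to land in [0, 1).
    pure (fromIntegral (x' `shiftR` 11) / 9007199254740992)

-- Draw n uniforms, threading the same mutable generator.
uniforms :: Int -> Sampler s [Double]
uniforms n = replicateM n random

-- The forall s plays the role it has in SamplerST: the generator cannot escape runST.
sampleWith :: Word64 -> (forall s. Sampler s a) -> a
sampleWith seed m = runST (newSTRef seed >>= runReaderT m . Gen)
```

For example, sampleWith 42 (uniforms 5) returns the same five values in [0, 1) every time it is called with seed 42. An LCG is used here only to stay within GHC's boot packages; it is not a generator one would use for real inference work.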
2.2 Weighted (Bayes.Weighted)
-- | Execute the program using the prior distribution, while accumulating likelihood.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
  -- StateT is more efficient than WriterT
  deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample)

A value of type Weighted m a represents a weighted sample, or particle, inside some monadic context m. It is implemented with the StateT monad transformer, whose state is the weight, of type Log Double. In principle the WriterT monad transformer could be used instead, but StateT is more efficient (as the code comment notes). Running it with runWeighted therefore yields an m-computation that returns a pair of the sample value a and its associated weight.

Wrapping a representation m in Weighted equips it with a conditioning operation, making Weighted m a conditioning representation. If m was a sampling representation, Weighted m is one too, because the sampling operation is lifted from m.

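  • runWeighted - runs the underlying StateT computation with 1 as the initial weight. This gives an m-computation returning pairs (a, Log Double) of the result a and its explicit, accumulated likelihood. (Log Double stores the number in the log domain for numerical stability, but it behaves like an ordinary multiplicative weight, which is why the initial value is 1.)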

    -- | Obtain an explicit value of the likelihood for a given value.
    runWeighted :: (Monad m) => Weighted m a -> m (a, Log Double)
    runWeighted (Weighted m) = runStateT m 1
    
  • score - the function score will multiply the current weight held within Weighted m a by a given factor.

    instance Monad m => MonadCond (Weighted m) where
      score w = Weighted (modify (* w))
    
  • hoist - the Weighted version of hoist lifts an inference transformation applicable to m into one applicable to Weighted m. When m is itself a conditioning representation, a model run in Weighted m can condition in two ways: through the score provided by the Weighted transformer, which accumulates into the explicit weight, or by lifting m's own score, which uses the ambient conditioning of m.

    hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
    hoist f (Weighted m) = Weighted (mapStateT f m)
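A minimal sketch of these three operations, assuming plain Double weights instead of Log Double (the real type stores weights in the log domain). The names mirror the text, but this is a simplified stand-in rather than the monad-bayes source.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving, RankNTypes #-}
import Control.Monad.Trans.State.Strict (StateT, mapStateT, modify, runStateT)
import Data.Functor.Identity (Identity (..))

-- Simplified Weighted: the state is a plain Double weight.
newtype Weighted m a = Weighted (StateT Double m a)
  deriving (Functor, Applicative, Monad)

-- Multiply the accumulated weight by w.
score :: Monad m => Double -> Weighted m ()
score w = Weighted (modify (* w))

-- Run with initial weight 1, returning the result and its accumulated weight.
runWeighted :: Weighted m a -> m (a, Double)
runWeighted (Weighted m) = runStateT m 1

-- Change the inner monad while leaving the weight untouched.
hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist f (Weighted m) = Weighted (mapStateT f m)

-- Two observations with likelihoods 0.5 and 0.25.
example :: Weighted Identity Char
example = do
  score 0.5
  score 0.25
  return 'x'
```

Running runWeighted example in Identity gives ('x', 0.125): each score multiplies the weight, starting from 1, and hoist only swaps the inner monad, so the weight survives unchanged.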
    

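We can construct a simple inference algorithm by interpreting the model in Weighted Sampler a; running it with runWeighted gives a weighted sampler of type Sampler (a, Log Double). Each run draws a sample from the prior and pairs it with its likelihood, which amounts to importance sampling with the prior as the proposal.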
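As a sketch of this simple algorithm (likelihood weighting), the code runs a weighted sampler several times and forms a self-normalised posterior estimate. Assumptions: the hypothetical Sampler is a State monad that reads uniforms from a supplied list instead of a pseudo-random generator, and weights are plain Doubles.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import Control.Monad (replicateM)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (State, StateT, evalState, modify, runStateT, state)

-- Hypothetical stand-in for Sampler: uniforms come from a supplied list,
-- so results are reproducible.
type Sampler = State [Double]

uniform01 :: Sampler Double
uniform01 = state next
  where
    next (u : us) = (u, us)
    next [] = error "ran out of uniforms"

newtype Weighted m a = Weighted (StateT Double m a)
  deriving (Functor, Applicative, Monad)

-- Lift a sampling action from the inner monad into Weighted.
sample :: Monad m => m a -> Weighted m a
sample = Weighted . lift

score :: Monad m => Double -> Weighted m ()
score w = Weighted (modify (* w))

runWeighted :: Weighted m a -> m (a, Double)
runWeighted (Weighted m) = runStateT m 1

-- Toy model: coin bias p ~ Uniform(0, 1); we observe one head (likelihood p).
coinModel :: Weighted Sampler Double
coinModel = do
  p <- sample uniform01
  score p
  return p

-- Run the weighted sampler n times and form the self-normalised
-- estimate of E[p | head] = sum (p * w) / sum w.
posteriorMean :: Int -> [Double] -> Double
posteriorMean n us = sum [p * w | (p, w) <- draws] / sum (map snd draws)
  where
    draws = evalState (replicateM n (runWeighted coinModel)) us
```

With the uniforms [0.25, 0.5, 0.75], posteriorMean 3 gives (0.0625 + 0.25 + 0.5625) / 1.5 = 7/12, roughly 0.583. The exact posterior here is Beta(2, 1) with mean 2/3, and the estimate approaches it as more uniforms are used.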

2.3 Population (Bayes.Population)
-- | A collection of weighted particles.
newtype Population m a = Population (Weighted (ListT m) a)

-- | Explicit representation of the weighted sample with weights in the log
-- domain.
runPopulation :: Monad m => Population m a -> m [(a, Log Double)]
runPopulation (Population m) = observeAllT $ runWeighted m

A population is a list of weighted particles, wrapped within some monadic context m. Hence, calling runPopulation will return a list of tuples containing particle values a and their associated weights Log Double. This is used during the particle filter.

  • hoist - The population version of hoist will:

    1. First, it runs the population to obtain the list of weighted particles m [(a, Log Double)].
    2. It then applies the inference transformation f, giving n [(a, Log Double)].
    3. Finally, it applies fromWeightedList, which wraps n [(a, Log Double)] as the list transformer ListT n (a, Log Double) and then uses withWeight to build Weighted (ListT n) a, i.e. a population (a list of weighted particles).
    hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> Population m a -> Population n a
    hoist f = fromWeightedList . f . runPopulation
    
    -- | Initialize 'Population' with a concrete weighted sample.
    fromWeightedList :: Monad n => n [(a, Log Double)] -> Population n a
    fromWeightedList = Population . withWeight . ListT
    
    -- | Embed a random variable with explicitly given likelihood.
    -- > runWeighted . withWeight = id
    withWeight :: (Monad n) => n (a, Log Double) -> Weighted n a
    withWeight m = Weighted $ do
      (x, w) <- lift m
      modify (* w)
      return x
    

We use three inference transformations associated with Population:

  • spawn

    spawn :: Monad m => Int -> Population m ()
    spawn n = fromWeightedList $ pure $ replicate n ((), 1 / fromIntegral n)
    

    This increases the sample size (number of particles in the population) by a given factor. The weights of the particles are adjusted such that their sum is preserved. It is therefore safe to use ‘spawn’ in arbitrary places in the program without introducing bias.

  • resample

    resampleSystematic :: MonadSample m => Population m a -> Population m a
    

    This resamples the population using the underlying monad m and a systematic resampling scheme: it draws a new population with uniform weights from the current one, preserving the total weight of the particles. Resampling remedies the situation where a single particle has a much larger weight than the others and dominates the result, making the remaining particles irrelevant; after resampling, heavy particles are duplicated and light ones tend to be dropped.

  • pushEvidence

    pushEvidence :: MonadCond m => Population m a -> Population m a
    pushEvidence = (hoist applyWeight) . extractEvidence
    

    This normalizes the weights in the population of particles, while at the same time incorporating the sum of the weights as a score in m, which is some monadic computation that can be conditioned.

    The first function extractEvidence:

    1. Takes a population (a list of weighted particles)

    2. Sums the weights of all the particles, z = sum ps. This sum is an estimate of the likelihood of the sequence of random variables sampled during the program’s execution giving rise to the program’s output.

    3. Normalizes the weights of the particles, giving ws.

    4. Explicitly separates the sum of the weights (the likelihood) into the Weighted transformer by calling factor z. When the result is consumed by applyWeight, as in pushEvidence, the Weighted layer is run from an initial weight of 1, so it ends up holding exactly z.

    5. Returns the set of particles and their normalized weights, where the likelihood z is separated and stored in the Weighted m.

      -- | Separate the sum of weights into the 'Weighted' transformer.
      -- Weights are normalized after this operation.
      extractEvidence :: Monad m => Population m a -> Population (Weighted m) a
      extractEvidence m = fromWeightedList $ do
        pop <- lift $ runPopulation m
        let (xs, ps) = unzip pop
        let z = sum ps
        let ws = map (if z > 0 then (/ z) else const (1 / fromIntegral (length ps))) ps
        factor z
        return $ zip xs ws
      

      This results in the type Population (Weighted m) a. Running it with runPopulation and then runWeighted gives m ([(a, Log Double)], Log Double), i.e. the list of particles with normalized weights together with the likelihood z.

    The second function applyWeight:

    • applyWeight runs a Weighted m computation and passes its accumulated weight to factor in m, returning m a. Hoisted to the type Population (Weighted m) a -> Population m a, it therefore moves the sum of the particle weights (the likelihood stored in Weighted m) into the monad m as a score, so the explicit weight is no longer carried around.

      -- | Use the weight as a factor in the transformed monad.
      applyWeight :: MonadCond m => Weighted m a -> m a
      applyWeight m = do
        (x, w) <- runWeighted m
        factor w
        return x
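The three transformations can be sketched on a plain list of weighted particles. This is a hypothetical pure stand-in for Population (no ListT, no inner monad) with Double weights, and the resampling step receives its single uniform draw as an argument instead of sampling it from MonadSample; the normalisation in extractEvidenceP follows the extractEvidence code above.

```haskell
-- Hypothetical pure stand-in for a population: a list of weighted particles.
type Particles a = [(a, Double)]

-- spawn: replace every particle by n copies, each carrying 1/n of its weight,
-- so the total weight is unchanged.
spawnP :: Int -> Particles a -> Particles a
spawnP n ps = [(x, w / fromIntegral n) | (x, w) <- ps, _ <- [1 .. n]]

-- Systematic resampling indices, given a single uniform draw u in [0, 1).
-- Position k is (u + k) / N; it selects the first particle whose cumulative
-- normalised weight exceeds it (min guards against floating-point round-off).
systematicIndices :: Double -> [Double] -> [Int]
systematicIndices u ws =
  [min (n - 1) (length (takeWhile (<= pos) cumulative)) | pos <- positions]
  where
    n = length ws
    total = sum ws
    cumulative = scanl1 (+) (map (/ total) ws)
    positions = [(u + fromIntegral k) / fromIntegral n | k <- [0 .. n - 1]]

-- Copy the chosen particles and give each an equal share of the original total weight.
resampleSystematicP :: Double -> Particles a -> Particles a
resampleSystematicP u ps = [(xs !! i, total / fromIntegral n) | i <- systematicIndices u ws]
  where
    (xs, ws) = unzip ps
    n = length ps
    total = sum ws

-- extractEvidence: normalise the weights and return their sum z separately
-- (z is what pushEvidence then passes to factor in the inner monad).
extractEvidenceP :: Particles a -> (Particles a, Double)
extractEvidenceP ps = (zip xs normalised, z)
  where
    (xs, ws) = unzip ps
    z = sum ws
    normalised = map (if z > 0 then (/ z) else const (1 / fromIntegral (length ws))) ws
```

For example, with weights [0.125, 0.75, 0.125] and u = 0.3, systematicIndices picks [0, 1, 1]: the heavy middle particle is duplicated, the light last one is dropped, and each survivor receives weight 1/3 of the unchanged total.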
      
2.4 Sequential (Bayes.Sequential)
-- | Represents a computation that can be suspended at certain points.
-- The intermediate monadic effects can be extracted, which is particularly
-- useful for implementation of particle filter related methods.
-- All the probabilistic effects are lifted from the transformed monad, but
-- also `suspend` is inserted after each `factor`/`score`.
newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO)

Many models exhibit a sequential structure where observations are interleaved with sampling. In those models a possible inference strategy is to consider a program up to a certain point, do inference on the partial posterior it defines, then run the program a little more, do more inference, and so on. To implement such algorithms we introduce the sequential transformer Sequential which introduces suspensions after each score in the program.

There are two hoisting functions associated with Sequential:

  • The function hoistFirst applies the inference transformation f only to the part of the program executed so far.

    -- | Transform the inner monad.
    -- This operation only applies to computation up to the first suspension.
    hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
    hoistFirst f = Sequential . Coroutine . f . resume . runSequential
    
  • The function hoist will apply the inference transformation f recursively to all the suspension points of the program.

    -- | Transform the inner monad.
    -- The transformation is applied recursively through all the suspension points.
    hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> Sequential m a -> Sequential n a
    hoist f = Sequential . mapMonad f . runSequential
    

There are two inference transformations associated with Sequential:

  • advance - the advance transformation runs the program to the next suspension point.

    -- | Execute to the next suspension point.
    -- If the computation is finished, do nothing.
    -- > finish = finish . advance
    advance :: Monad m => Sequential m a -> Sequential m a
    advance (Sequential c) = Sequential (Coroutine (resume c >>= either (\(Await k) -> resume (k ())) (return . Right)))
    
  • finish - the finish transformation runs the program to the end.

    -- | Remove the remaining suspension points.
    finish :: Monad m => Sequential m a -> m a
    finish (Sequential c) = resume c >>= either (\(Await k) -> finish (Sequential (k ()))) return
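Because the Coroutine-based code depends on monad-coroutine, here is a self-contained sketch of the same idea using a hand-rolled resumable computation: a Sequential m a either finishes with a result or suspends, handing back the rest of the program. This representation is an assumption for illustration, not the monad-bayes implementation. The toy model logs a message with Writer and then calls suspend, mimicking the suspension inserted after each score; tick is a hypothetical transformation used to show where hoistFirst and hoist act.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.Trans.Writer.Strict (Writer, runWriter, tell)

-- A computation that either finishes (Right) or suspends, returning the rest (Left).
newtype Sequential m a = Sequential {resume :: m (Either (Sequential m a) a)}

instance Monad m => Functor (Sequential m) where
  fmap f (Sequential m) = Sequential (fmap (either (Left . fmap f) (Right . f)) m)

instance Monad m => Applicative (Sequential m) where
  pure x = Sequential (return (Right x))
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad m => Monad (Sequential m) where
  Sequential m >>= k = Sequential $ do
    s <- m
    case s of
      Left rest -> return (Left (rest >>= k)) -- suspended: keep k pending
      Right a -> resume (k a) -- finished: continue with k

liftS :: Monad m => m a -> Sequential m a
liftS = Sequential . fmap Right

suspend :: Monad m => Sequential m ()
suspend = Sequential (return (Left (return ())))

-- Run to the next suspension point (do nothing if already finished).
advance :: Monad m => Sequential m a -> Sequential m a
advance (Sequential m) = Sequential (m >>= either resume (return . Right))

-- Run to the end, removing all suspension points.
finish :: Monad m => Sequential m a -> m a
finish (Sequential m) = m >>= either finish return

-- Transform only the computation up to the first suspension.
hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
hoistFirst f (Sequential m) = Sequential (f m)

-- Transform every segment, recursively through all suspension points.
hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> Sequential m a -> Sequential n a
hoist f (Sequential m) = Sequential (f (fmap (either (Left . hoist f) Right) m))

type Logger = Writer [String]

-- Each observation logs a message and then suspends.
observe :: String -> Sequential Logger ()
observe msg = liftS (tell [msg]) >> suspend

toyModel :: Sequential Logger Int
toyModel = do
  observe "obs 1"
  observe "obs 2"
  observe "obs 3"
  return 42

-- Hypothetical transformation: log "tick" after the wrapped segment has run.
tick :: Logger x -> Logger x
tick m = m <* tell ["tick"]
```

Here finish toyModel logs all three observations and returns 42, while resume (advance toyModel) runs only up to the second suspension. hoistFirst tick adds a single "tick" after the first segment, whereas hoist tick adds one after every segment, including the final one that just returns 42.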
    

3. Particle Filter

By combining the inference transformers Sequential and Population, we obtain a particle filter (SMC). Within the context of a particle filter, a sample in a population is called a particle. The particle filter is implemented as the function sir, which is a template for particle filtering that takes a custom resampling method.

-- | Sequential importance resampling - an SMC template that takes a custom resampler.
sir :: Monad m =>
  -- | resampler
  (forall x. Population m x -> Population m x) ->
  -- | number of timesteps
  Int ->
  -- | population size
  Int ->
  -- | model
  Sequential (Population m) a ->
  Population m a
sir resampler k n = finish . composeCopies k (advance . hoistFirst resampler) . Seq.hoistFirst (spawn n >>)

-- | Apply a function a given number of times.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)

  1. The algorithm starts by prepending spawn n to the model. Since hoistFirst applies a transformation only to the part of the program up to its first suspension, Seq.hoistFirst (spawn n >>) makes the model begin by creating a population of n particles, each with weight 1/n, before it reaches its first score.

    Seq.hoistFirst (spawn n >>)
    

    When applied to a model of type Sequential (Population m) a, this gives a computation in which n particles are spawned (spawn n :: Population m ()) and the model is then run from the start in each of them, up to its first suspension.

  2. The function composeCopies :: Int -> (a -> a) -> (a -> a) applies a given function k times. Here the function is advance . hoistFirst resampler: given a model of type Sequential (Population m) a, it resamples the population produced so far and then runs the model for one more time-step (up to the next suspension point). Applying composeCopies k therefore alternates resampling and advancing, once per time-step, for k time-steps.

    composeCopies k (advance . hoistFirst resampler)
    
  3. Finally, finish runs the rest of the Sequential (Population m) a program to the end, removing any remaining suspension points and leaving a value of type Population m a.
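To see the order in which sir interleaves spawning, observations and resampling, the pipeline can be replayed on a hand-rolled Sequential in which particles are replaced by log messages. This is an illustrative sketch under stated assumptions: Sequential is a minimal resumable computation rather than the Coroutine-based type, and the hypothetical resampleHook and "spawn" message stand in for resampler and spawn n. Only the composition finish . composeCopies k (advance . hoistFirst resampler) . hoistFirst init is taken from sir itself.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.Trans.Writer.Strict (Writer, runWriter, tell)

-- Minimal resumable computation: finished (Right) or suspended with the rest (Left).
newtype Sequential m a = Sequential {resume :: m (Either (Sequential m a) a)}

instance Monad m => Functor (Sequential m) where
  fmap f (Sequential m) = Sequential (fmap (either (Left . fmap f) (Right . f)) m)

instance Monad m => Applicative (Sequential m) where
  pure x = Sequential (return (Right x))
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad m => Monad (Sequential m) where
  Sequential m >>= k = Sequential (m >>= either (return . Left . (>>= k)) (resume . k))

liftS :: Monad m => m a -> Sequential m a
liftS = Sequential . fmap Right

suspend :: Monad m => Sequential m ()
suspend = Sequential (return (Left (return ())))

hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
hoistFirst f (Sequential m) = Sequential (f m)

advance :: Monad m => Sequential m a -> Sequential m a
advance (Sequential m) = Sequential (m >>= either resume (return . Right))

finish :: Monad m => Sequential m a -> m a
finish (Sequential m) = m >>= either finish return

composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)

-- Particles are replaced by log messages recording when each step happens.
type Logger = Writer [String]

observe :: String -> Sequential Logger ()
observe msg = liftS (tell [msg]) >> suspend

toyModel :: Sequential Logger Int
toyModel = observe "obs 1" >> observe "obs 2" >> observe "obs 3" >> return 42

-- Hypothetical resampler: acts after the segment it wraps has run.
resampleHook :: Logger x -> Logger x
resampleHook m = m <* tell ["resample"]

-- Same shape as sir, with spawn n replaced by a log message.
sirSketch :: Int -> Sequential Logger a -> Logger a
sirSketch k =
  finish . composeCopies k (advance . hoistFirst resampleHook) . hoistFirst (tell ["spawn"] >>)
```

With k = 2 the log is spawn, obs 1, resample, obs 2, resample, obs 3: a resampling step follows each of the first k observations, and finish runs whatever is left. With k = 3 a final resample also follows obs 3.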

4. The Free Monad Transformer

4.1 The Free Monad

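FreeF describes a single layer of the free monad Free: either a pure result or one functor-shaped step. Free itself is the fixed point of this layer, obtained by letting the recursive position x be Free f a again.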

data FreeF f a x = Pure a | Free (f x)

Wrapping a functor f in the free monad lets us treat Free f as a monad, i.e. an effectful computation. Constructing values of type Free f a simply builds a data structure in the form of a syntax tree. On its own, this tree has no computation associated with it; it exists only as data. How the tree is evaluated is determined by the interpreter functions we choose to define, which means that the syntax of our program is decoupled from its semantics.

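Binding over a free monad, (>>=) :: Free f a -> (a -> Free f b) -> Free f b, reads as: run the effectful computation Free f a, take its result a, and continue with the computation Free f b. Since the tree itself performs no computation, (>>=) actually achieves this by grafting the continuation onto every Pure leaf of the first tree; the "running" only happens when an interpreter walks the tree.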
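A minimal from-scratch sketch of the free monad (with the same shape as Free in the free package's Control.Monad.Free) shows both points: (>>=) only grafts continuations onto Pure leaves, and the same tree can be given different meanings by different interpreters. The Emit functor and both interpreters are hypothetical examples.

```haskell
-- The free monad over a functor f.
data Free f a = Pure a | Free (f (Free f a))

instance Functor f => Functor (Free f) where
  fmap g (Pure a) = Pure (g a)
  fmap g (Free fx) = Free (fmap (fmap g) fx)

instance Functor f => Applicative (Free f) where
  pure = Pure
  Pure g <*> x = fmap g x
  Free fg <*> x = Free (fmap (<*> x) fg)

instance Functor f => Monad (Free f) where
  Pure a >>= k = k a -- a leaf is replaced by the continuation
  Free fx >>= k = Free (fmap (>>= k) fx) -- a node is kept; k is pushed to its leaves

liftF :: Functor f => f a -> Free f a
liftF = Free . fmap Pure

-- Hypothetical instruction set with a single operation: emit an Int.
data Emit next = Emit Int next

instance Functor Emit where
  fmap g (Emit n next) = Emit n (g next)

emit :: Int -> Free Emit ()
emit n = liftF (Emit n ())

-- Pure syntax: building this tree performs no effects.
program :: Free Emit String
program = do
  emit 1
  emit 2
  emit 3
  return "finished"

-- Interpreter 1: collect the emitted values and the final result.
emitted :: Free Emit a -> ([Int], a)
emitted (Pure a) = ([], a)
emitted (Free (Emit n rest)) = let (ns, a) = emitted rest in (n : ns, a)

-- Interpreter 2: add the emitted values up, ignoring the result.
total :: Free Emit a -> Int
total (Pure _) = 0
total (Free (Emit n rest)) = n + total rest
```

Here emitted program evaluates to ([1, 2, 3], "finished") and total program to 6: one syntax tree, two semantics.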

4.2 The Free Monad Transformer

With ordinary free monads we can build abstract syntax trees that let us abstract away the interpreter. However, sometimes we cannot specify the whole syntax tree at once: we often want to interleave building the tree with actions in some other monad, for example to write streaming or interactive computations. The free monad transformer FreeT solves this problem by letting us mix steps that build the abstract syntax tree with actions in some base monad.

data FreeF f a x = Pure a | Free (f x)

newtype FreeT f m a = FreeT {
    runFreeT :: m (FreeF f a (FreeT f m a)) }

FreeT is very similar to the ordinary free monad, but with a monad m wrapped around each layer, where m is the extra effect. The result is a tree in which every step of computation (every node) is wrapped in an effect of m: to reach the next node, we must first run that effect.
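A self-contained sketch of FreeT, written out by hand rather than imported from the free package, makes the "effect around every node" structure visible. The countdown generator reads and updates a State counter between yields, which is exactly the interleaving of tree-building and base-monad effects described above. liftT and collect are illustrative names, not library functions.

```haskell
import Control.Monad.Trans.State.Strict (State, get, put, runState)

-- One layer of the tree: a finished result or a functor-shaped step.
data FreeF f a x = Pure a | Free (f x)

-- Every layer is wrapped in the base monad m.
newtype FreeT f m a = FreeT {runFreeT :: m (FreeF f a (FreeT f m a))}

instance (Functor f, Monad m) => Functor (FreeT f m) where
  fmap g (FreeT m) = FreeT (fmap layer m)
    where
      layer (Pure a) = Pure (g a)
      layer (Free fx) = Free (fmap (fmap g) fx)

instance (Functor f, Monad m) => Applicative (FreeT f m) where
  pure a = FreeT (return (Pure a))
  mf <*> mx = mf >>= \g -> fmap g mx

instance (Functor f, Monad m) => Monad (FreeT f m) where
  FreeT m >>= k = FreeT $ do
    layer <- m -- run the effect guarding this node
    case layer of
      Pure a -> runFreeT (k a) -- leaf: continue with k
      Free fx -> return (Free (fmap (>>= k) fx)) -- node: push k further down

-- Run a base-monad action as one step of the tree.
liftT :: Monad m => m a -> FreeT f m a
liftT = FreeT . fmap Pure

yield :: Monad m => a -> FreeT ((,) a) m ()
yield a = FreeT (return (Free (a, return ())))

-- Interpreter: run the effects and collect everything that was yielded.
collect :: Monad m => FreeT ((,) a) m r -> m ([a], r)
collect gen = do
  layer <- runFreeT gen
  case layer of
    Pure r -> return ([], r)
    Free (a, rest) -> do
      (ys, r) <- collect rest
      return (a : ys, r)

-- A generator whose steps depend on a State counter.
countdown :: FreeT ((,) Int) (State Int) String
countdown = do
  n <- liftT get
  if n <= 0
    then return "done"
    else do
      yield n
      liftT (put (n - 1))
      countdown
```

Starting the counter at 3, runState (collect countdown) 3 gives (([3, 2, 1], "done"), 0).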

4.3 An Example Use of Free Monad Transformers

For example, let’s say we want to write our own Python-style generator. (Python generators are essentially coroutines. Calling a generator function returns a generator object; each time the generator is asked for its next value, it runs the function body until it reaches the yield keyword, hands back the yielded value, and suspends, saving the function's state at that point. The next request resumes execution immediately after that yield.)

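{-# LANGUAGE ScopedTypeVariables #-}
import Control.Monad (forever)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Free (FreeF (..), FreeT (..), liftF)

type Generator a m r = FreeT ((,) a) m r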

The type Generator is the free monad transformer where the functor f is ((,) a), i.e. a tuple partially applied to the type a; the monad is some arbitrary monad m providing the effects; and r is the type of the final result. Because the functor is ((,) a), each node is a pair whose first element, of type a, is the yielded value and whose second element is the rest of the computation to be run (the continuation). The free monad transformer tree is a recursive structure that sequentially composes zero or more ((,) a)-shaped operations inside the monad m. Its leaves are given by Pure and hold the final result of type r, and its nodes are Free operations shaped by ((,) a). The result of return x is a leaf, and (>>=) grows the tree of operations at its leaves.

Let’s now look at an implementation of yield:

yield :: Monad m => a -> Generator a m ()
yield a = liftF (a, ())

The function yield takes a value of type a and lifts the pair (a, ()) into the free monad transformer type FreeT ((,) a) m (). The result is a single node whose first element is the yielded value a and whose continuation simply returns (), i.e. the end of this piece of computation. The Monad m constraint is needed because the free package's MonadFree instance for FreeT f m, which liftF relies on, requires Monad m.

Let’s now look at an implementation of prompt:

prompt :: Generator String IO ()
prompt = do
    lift $ putStrLn "Enter a string:"
    str <- lift getLine
    yield str

The value prompt is itself a free monad transformer tree whose nodes are shaped by the functor ((,) String), so each node holds a string, and whose result type r is fixed to (). It also fixes the monadic context of the tree to IO, which lets us attach IO actions to the steps of the tree. Here we attach the IO effects of printing "Enter a string:" to the terminal and reading a line of input from the user; that input is then yielded.

Next, we look at how the free monad transformer tree prompt can be run:

putStrLnAllTheThings :: Show r => FreeT ((,) String) IO r -> IO r
putStrLnAllTheThings gen = do
    x <- runFreeT gen
    case x of
        Pure r   
          -> do putStrLn $ "Result: " ++ show r
                return r
        Free (str, gen' :: FreeT ((,) String) IO r)
          -> do putStrLn $ "User entered: " ++ str
                putStrLnAllTheThings gen'

main = putStrLnAllTheThings prompt

The function putStrLnAllTheThings takes a free monad transformer tree as an argument in the same form as prompt except allowing the result type r to be arbitrary (rather than fixed as ()). Here is an example run of main:

Enter a string:
hello
User entered: hello
Result: ()

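If we wanted prompt to run forever, we could write the following, where forever runs the action it is given, ignores its result, and then repeats. Since this generator never finishes, its result type r is unconstrained, so the interpreter no longer needs a Show r constraint; the Pure case is never actually reached.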

prompt :: Generator String IO r
prompt = forever $ do
    lift $ putStrLn "Enter a string:"
    str <- lift getLine
    yield str

putStrLnAllTheThings :: FreeT ((,) String) IO r -> IO r
putStrLnAllTheThings gen = do
    x <- runFreeT gen
    case x of
        Pure r
          -> return r
        Free (str, gen' :: FreeT ((,) String) IO r)
          -> do putStrLn $ "User entered: " ++ str
                putStrLnAllTheThings gen'

main = putStrLnAllTheThings prompt

Summary: Free Monad Transformer

The free monad transformer lets every step of the computation within the tree be embedded in an arbitrary monadic effect. This lets us stay abstract both in the interpreter and in the monadic effect attached to each step.

5. Traced (Bayes.Traced.Static)

The following monad Traced is a tracing monad where only a subset of random choices are traced.

Bayes.Traced.Static

data Traced m a
  = Traced
      { model     :: Weighted (FreeSampler m) a,
        traceDist :: m (Trace a)
      }

The idea is that we have some model, such as an SIR model, which makes random decisions at certain points of the program. We want to decouple what it means to make random choices (the actual implementation of randomness) from the syntax. The syntax is the description of the model in which the stochastic elements (i.e. calls to random functions such as distributions) are just names. Supplying the syntax with randomness turns this description into a stochastic computation. This distinction can be described as the model as syntax versus the model as a realisation of a stochastic process.

We want the syntax to describe performing some action, while the actual random choices are made elsewhere. For example, a program may need to make random decisions, but we want to be able to supply that randomness from different sources, i.e. give it different ways of generating random numbers at each point. Running the program with the same deterministic number at every random decision, versus running it with -user-rand at every random decision, are two distinctly different ways of providing randomness.

This matters for tracing in probabilistic programming because inference over execution traces works by modifying the trace of the program, the trace being the record of all the random decisions made. We never change the syntax of the program; we only change the random decisions made during its execution (by randomly perturbing them). The first thing we need, then, is a way to decouple these two things, which the Traced datatype achieves with two distinct fields:

  • The model field is a description of the suspended model as syntax, separated from its random decisions. It is the abstract interpreter that defines our program, expressed as a free monad transformer that marks where the random decisions would be made. One can imagine it as a tree of random decisions waiting for their source of randomness; this tree is generated from the model specification. The model will typically draw from various distributions, each of which is a random decision, but for those draws to actually happen we first need to provide a source of randomness as a “seed”.

  • The traceDist field, which provides the model with a source of randomness in order to realise the model as a stochastic process.

We will now elaborate on what these two fields of Traced are:

5.1 FreeSampler (Bayes.Free)

Let’s first inspect the model field of Traced.

Bayes.Traced.Static

data Traced m a
  = Traced
      { model     :: Weighted (FreeSampler m) a,
        ... }

This involves us looking at what the type FreeSampler is.

Bayes.Free

-- | Random sampling functor.
newtype SamF a = Random (Double -> a)

instance Functor SamF where
  fmap f (Random k) = Random (f . k)

-- | Free monad transformer over random sampling.
newtype FreeSampler m a = FreeSampler {runFreeSampler :: FT SamF m a}
  deriving (Functor, Applicative, Monad, MonadTrans)

The type FreeSampler is essentially the free monad transformer FT (the Church-encoded free monad transformer from the free package), where the functor (the shape of the tree’s nodes) is SamF and the monad the nodes are wrapped in is left abstract.

The type SamF is the random sampling functor, which represents the idea of a random effect. It captures the notion that we provide a (random) Double and it returns some value of type a, i.e. the notion that we make a random decision at some point.

The type FT SamF m a is therefore a free monad tree of random decisions. In some sense, it is a composition of many functions that take Doubles and return values of type a. Imagine a tree where, at every step, we provide a random number as a Double, which is then mapped to a value of type a; we can then regard this a as a random value. The type a represents whatever we want samples of. We can think of Double -> a as a specification for a random number generator: it takes a seed as a Double and produces a random value of type a. Concretely, when the tree is interpreted by withPartialRandomness (discussed later), the a in each node’s Double -> a is instantiated to the rest of the computation in the monad StateT [Double] (WriterT [Double] m).

“So what does the free sampler FT SamF m a allow us to do?” Let’s take a look at the function withRandomness: this executes a computation with supplied values for random choices.

Bayes.Free

-- | Execute computation with supplied values for random choices.
withRandomness :: Monad m => [Double] -> FreeSampler m a -> m a
withRandomness randomness (FreeSampler m) = evalStateT (iterTM f m) randomness
  where
    f (Random k) = do
      xs <- get
      case xs of
        [] -> error "FreeSampler: the list of randomness was too short"
        y : ys -> put ys >> k y

It takes a list of random doubles and a free monad tree of random decisions waiting to happen. We can think of m as the computation we want to execute, with embedded sources of randomness given by the SamF functor: each node is Random k with k :: Double -> a, and we must supply a Double to k every time in order to continue. withRandomness consumes the list in order, one value per random decision, and fails with an error if the list runs out. How this list of random Doubles is generated in the first place is up to the user.
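A pure sketch of withRandomness, assuming a plain (non-Church) free monad over SamF with no inner monad m, shows how random choices are matched to list elements by position. The branching model is hypothetical; its number of random decisions depends on its first draw, which is why the number of values a run consumes can vary.

```haskell
import Control.Monad.Trans.State.Strict (evalState, get, put)

-- Random sampling functor, as in Bayes.Free.
newtype SamF a = Random (Double -> a)

instance Functor SamF where
  fmap f (Random k) = Random (f . k)

-- Simplification: a plain free monad over SamF; each Step is one random decision.
data FreeSampler a = Done a | Step (SamF (FreeSampler a))

instance Functor FreeSampler where
  fmap f (Done a) = Done (f a)
  fmap f (Step s) = Step (fmap (fmap f) s)

instance Applicative FreeSampler where
  pure = Done
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad FreeSampler where
  Done a >>= k = k a
  Step s >>= k = Step (fmap (>>= k) s)

random :: FreeSampler Double
random = Step (Random Done)

-- Walk the tree, feeding the k-th list element to the k-th random decision.
withRandomness :: [Double] -> FreeSampler a -> a
withRandomness randomness m = evalState (go m) randomness
  where
    go (Done a) = return a
    go (Step (Random k)) = do
      xs <- get
      case xs of
        [] -> error "FreeSampler: the list of randomness was too short"
        y : ys -> put ys >> go (k y)

-- Hypothetical model: a second draw happens only when the first is >= 0.5.
branching :: FreeSampler Double
branching = do
  u <- random
  if u < 0.5 then return u else (u +) <$> random
```

withRandomness [0.25] branching returns 0.25 and uses one value, while withRandomness [0.75, 0.125] branching returns 0.875 and uses two. Extra values at the end of the list are simply ignored.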

To perform inference, trace MCMC modifies this list of supplied doubles and observes how the program’s output changes. The approach given in Lightweight Implementations of Probabilistic Programming Languages treats trace MCMC as a meta-program. It takes the source code of the program and the source of randomness, and rewrites the source code at the meta level: each of the original random functions is given a name f_k and replaced with a deterministic function f'_k. When one of these deterministic functions is encountered during execution, it uses its name to look up the current value x_k in a database and returns it as the random value. For each full run of the program, the values stored in the database are manipulated in order to perform inference.

However, with the free monad transformer we build the abstract interpreter of the program within the language itself. In other words, the free monad transformer lets us internalise the approach described in the paper, without any meta-programming. Rather than naming each random function f_k and associating it with a value stored in a database, monad-bayes treats the “names” as the nodes (steps) of the free monad transformer tree, each of which corresponds to an index in a list of random doubles: the k’th random step in the tree is mapped to the k’th element of the list. So instead of associating random functions with their random doubles by name, we associate them by their order in a list, which matches the sequential nature of program execution.

5.2 Trace (Bayes.Traced.Common)

Let’s now inspect the traceDist field of Traced.

Bayes.Traced.Static

data Traced m a
  = Traced
      { ... ,
        traceDist :: m (Trace a)
      }

This involves looking at how the trace datatype Trace is defined:

Bayes.Traced.Common

-- | Collection of random variables sampled during the program's execution.
data Trace a
  = Trace
      { -- | Sequence of particular realisations of random variables sampled during the program's execution.
        variables :: !([Double]),
        -- | The program's output when run with these variables.
        output :: !a,
        -- | The probability of observing this particular sequence.
        density :: {-# UNPACK #-} !(Log Double)
      }

This contains three things:

  • The variables are the list of random doubles that we supply to the model/program during execution, where each double corresponds to a computational step of randomness in the program. This is the list that we must modify values of in order to observe new results and then perform inference over.
  • The output is the result of running a given model using the current variables (the list of random doubles).
  • The density is the probability density of the sequence of doubles in variables giving rise to the output, under the model.

6. The Trace Metropolis-Hastings Function

6.1 Trace Metropolis-Hastings

Let’s now have a look at how the Traced data type is used in the implementation of trace Metropolis-Hastings (a method of inferring the posterior), by inspecting the function mh.

Bayes.Traced.Static

-- | Full run of the Trace Metropolis-Hastings algorithm with a specified
--   number of steps.
mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced m d) = fmap (map output) (f n)
  where
    f 0 = fmap (: []) d
    f k = do
      ~(x : xs) <- f (k -1)
      y <- mhTrans m x
      return (y : x : xs)

The function mh implements the trace Metropolis-Hastings algorithm (Metropolis-Hastings being a Markov chain Monte Carlo method for obtaining a sequence of random samples from a probability distribution).

For a specified number of steps n, it repeatedly runs mhTrans on the model m (the field model :: Weighted (FreeSampler m) a of Traced) and the current trace x. The initial trace is obtained by running the field traceDist :: m (Trace a); each subsequent step starts from the trace returned by the previous step.

The helper f returns the history of all the Trace values produced, one per call to mhTrans, most recent first. Calling fmap (map output) on this list extracts the model output recorded at each Metropolis-Hastings step. Although the variables and the density of the current Trace are needed at every step (see mhTrans) to decide how the chain moves, at the end of the algorithm we only keep the outputs. Each step records the chain’s current state, so when a proposal is rejected the previous output is repeated. This history of outputs forms the Markov chain, and in theory these outputs are samples whose histogram approximates the posterior distribution, which is our end goal.

6.2 A Single Step of Trace Metropolis-Hastings

The function mhTrans performs a single step of trace Metropolis-Hastings. Let’s now take a look at it:

Bayes.Traced.Common

-- | A single Metropolis-corrected transition of single-site Trace MCMC.
mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
mhTrans m t@Trace {variables = us, density = p} = do
  let n = length us
  us' <- do -- :: m [Double]
    i <- discrete $ discreteUniformAB 0 (n - 1)
    u' <- random
    let (xs, _ : ys) = splitAt i us
    return $ xs ++ (u' : ys)
  ((b, q), vs) <- runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m
  let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
  accept <- bernoulli ratio
  return $ if accept then Trace vs b q else t

This function takes as inputs the model we’re performing inference over (i.e. the free tree containing computational steps of random decisions), and the current trace (of which we are only interested in the variables and the density).

  1. The block of do-notation following us' <- do describes the act of reassigning the value of a double in the list of random doubles. The way we do this is we first randomly select an index in this list, given by i <- discrete $ discreteUniformAB 0 (n - 1). Next, we generate a new random value, given by u' <- random. Finally, we insert the new random value in the list, given by let (xs, _ : ys) = splitAt i us; return $ xs ++ (u' : ys).
  2. Given this updated list, we re-run the model with the new source of randomness, seen in ((b, q), vs) <- runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m, where b is the model’s output, vs is the list of random doubles actually consumed while running the model, and q is the weight accumulated by Weighted during the run (the product of all the factors passed to score). Because every random decision is driven by a uniform double on [0, 1], which has density 1, q is the density of the new trace; in a typical model it is the likelihood of the observations given the sampled values.

    Note that us' and vs are the same only if us' has exactly as many entries as there are random decisions in this run of the program (i.e. the number of Random nodes visited in the free monad transformer tree). If us' is shorter, withPartialRandomness draws new random values and appends them, so vs extends us'. If us' is longer, vs contains only the prefix of us' used during execution. Either way, vs has exactly one entry per random decision made in this run.

  3. Finally, we decide whether to move the chain to the new trace. Writing n for length us and n' for length vs, the acceptance probability is min 1 (q * n / (p * n')), computed by let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))), where exp . ln converts the Log Double into an ordinary Double. The factor n / n' accounts for the proposal picking one of the n sites of the current trace uniformly, whereas the reverse move would pick one of the n' sites of the new trace. Passing ratio to bernoulli returns a boolean telling us whether to accept the proposed sequence of doubles. If we accept, we return the new trace Trace vs b q; otherwise, we keep the old trace t. Whenever q * n ≥ p * n', the proposal is always accepted; otherwise it is accepted with probability equal to the ratio, so a trace with a lower density can still be accepted.
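The proposal and acceptance steps above can be sketched in plain Haskell. This is a minimal sketch, not monad-bayes code: the names proposeSite, acceptProb and accepts are hypothetical, densities are plain Double rather than Log Double, and the random draws (the site index, the new double and the uniform for the accept test) are passed in as arguments instead of being sampled.

```haskell
-- Hypothetical sketch: densities are plain Double (not Log Double), and the
-- random draws are passed in as arguments instead of being sampled.

-- | Replace the i-th double of the trace with the fresh value u'.
proposeSite :: Int -> Double -> [Double] -> [Double]
proposeSite i u' us = case splitAt i us of
  (xs, _ : ys) -> xs ++ (u' : ys)
  (xs, []) -> xs -- index out of range: trace left unchanged

-- | min 1 ((q * n) / (p * n')), where p and n are the density and length of
-- the current trace, and q and n' those of the proposed trace.
acceptProb :: Double -> Int -> Double -> Int -> Double
acceptProb p n q n' = min 1 ((q * fromIntegral n) / (p * fromIntegral n'))

-- | A Bernoulli(r) decision made from a uniform draw in [0, 1).
accepts :: Double -> Double -> Bool
accepts uAccept r = uAccept < r
```

Keeping the randomness explicit makes each piece deterministic, so it is easy to check, for example, that a proposal whose trace is twice as long as the current one is penalised by a factor of one half.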

Summary: Trace Metropolis-Hastings

The function mh is hence an iterative process: propose a new list of random doubles by changing one entry of the list stored in the current Trace, re-run the model on that proposal to obtain its output and density, and replace the current Trace with the new one whenever the Metropolis-Hastings test accepts it. At every step the chain holds a single execution trace; because the execution trace fixes every random choice in the program, it also fixes the values of the model’s parameters. Note that we never make changes to the model itself.

Trace MCMC can be viewed as “better” than ordinary MCMC because it is more general. In ordinary Metropolis-Hastings, one has to specify what the parameters of the model are. In trace MCMC, the parameters are automatically the random doubles associated with each random computational step in the program: we are performing Metropolis-Hastings on the execution trace of the program itself. The trace records every random decision made in the program, i.e. everything about its probabilistic behaviour. Hence the algorithm works for any program written in the language, without model-specific setup. The downside of performing inference over all the random decisions made in the program is that not all of them are necessarily related to the parameters of the model we are interested in.

For example, the following block inside mhTrans, which picks uniformly at random which double in the trace to change, is a fairly naive, brute-force approach to inference.

Bayes.Traced.Common

  us' <- do -- :: m [Double]
    i <- discrete $ discreteUniformAB 0 (n - 1)
    u' <- random
    let (xs, _ : ys) = splitAt i us
    return $ xs ++ (u' : ys)

So how does one recognise which random decisions in the program actually matter for the parameters of our model? This question needs further investigation.

7. PMMH

Given a prior and a likelihood function, ordinary Metropolis-Hastings produces samples from the posterior distribution, and hence lets us approximate the posterior without an analytic formulation of the normalising constant, because the normalising constant cancels in the acceptance ratio.

PMMH (particle marginal Metropolis-Hastings), one of the PMCMC methods, is ordinary Metropolis-Hastings for the case where we do not have an analytic formulation of the likelihood function. The likelihood is instead replaced by an estimate from a particle filter: after running the filter, we can read off an estimate of the likelihood from the particle weights. At the highest level, Metropolis-Hastings is still run to draw samples from the posterior distribution. Each step proposes new parameter values, runs a particle filter to estimate their likelihood, and uses this estimate in place of the exact likelihood in the acceptance ratio, which determines whether we accept or reject the proposal. Because the particle filter’s likelihood estimate is unbiased, the chain still targets the exact posterior over the parameters.
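The following is a minimal sketch of one pseudo-marginal Metropolis-Hastings step, assuming a one-dimensional parameter and a symmetric proposal (so the proposal densities cancel). PMState, pmmhStep and their arguments are hypothetical names, not monad-bayes API: the proposed parameter, its particle-filter likelihood estimate and the uniform draw for the accept/reject test are passed in rather than sampled. The key detail is that the current state’s likelihood estimate is stored and reused rather than re-estimated.

```haskell
-- Hypothetical sketch of one pseudo-marginal MH step with a symmetric proposal.

-- | Chain state: the current parameter and the likelihood estimate that the
-- particle filter produced for it.
data PMState = PMState {theta :: Double, likHat :: Double}
  deriving (Show, Eq)

pmmhStep ::
  (Double -> Double) -> -- prior density
  Double -> -- proposed parameter
  Double -> -- particle-filter likelihood estimate for the proposal
  Double -> -- uniform draw in [0, 1) for the accept/reject test
  PMState ->
  PMState
pmmhStep prior theta' lik' u cur
  | u < ratio = PMState theta' lik'
  | otherwise = cur -- rejected: the old estimate is kept, not recomputed
  where
    ratio = min 1 ((prior theta' * lik') / (prior (theta cur) * likHat cur))
```

As far as the code quoted in this document shows, monad-bayes gets the same effect from mh itself: the density p of the current trace is stored in the Trace and reused by mhTrans, while only the proposal is re-run.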

Particle metropolis-hastings is implemented in the function pmmh:

Bayes.Inference.PMMH

pmmh ::
  MonadInfer m =>
  -- | number of Metropolis-Hastings steps
  Int ->
  -- | number of time steps
  Int ->
  -- | number of particles
  Int ->
  -- | model parameters prior
  Traced m b ->
  -- | model
  (b -> Sequential (Population m) a) ->
  m [[(a, Log Double)]]
pmmh t k n param model =
  mh t (param >>= runPopulation . pushEvidence . Pop.hoist Traced.Static.lift .
        smcSystematic k n . model)

We can think of pmmh as running Metropolis-Hastings (mh) on a computation that draws parameters from the prior (param) and scores them with an estimated likelihood (runPopulation . pushEvidence . Pop.hoist lift . smcSystematic k n . model).

Below, we elaborate on the different components of this function’s definition.

  • model - our model of type b -> Sequential (Population m) a, where b is the type of the parameters drawn from the prior, a is the type of the values returned by the model, and Sequential (Population m) a represents the particle filter computation. Since pmmh only returns values of type a, a model whose parameters we want to infer would typically include them in its output.

  • smcSystematic - the particle filter (Sequential Monte Carlo) which uses systematic resampling at each timestep as a resampling method. Hence composing smcSystematic with model is running the particle filter on the model we give it.

    sir resampler k n = finish . composeCopies k (advance . hoistFirst resampler) . Seq.hoistFirst (spawn n >>)
    
    -- | Sequential Monte Carlo with systematic resampling at each timestep.
    -- Weights are not normalized.
    smcSystematic ::
      MonadSample m =>
      -- | number of timesteps
      Int ->
      -- | number of particles
      Int ->
      -- | model
      Sequential (Population m) a ->
      Population m a
    smcSystematic = sir resampleSystematic
    

    A particle filter simulates a set of particles with associated weights. The weights are not normalised (spawn gives each of the n particles weight 1 / n, and resampling preserves the total weight), so summing them at the end gives an estimate of the likelihood of the observations given the parameters, with the latent random choices made during the program’s execution integrated out.

  • Pop.hoist Bayes.Traced.Static.lift: in general, Pop.hoist turns a natural transformation forall x. m x -> n x into a function Population m a -> Population n a. Here the transformation is lift from the MonadTrans instance of Traced, so n is Traced m and the result has type Population m a -> Population (Traced m) a.

    hoist ::
      (Monad m, Monad n) =>
      (forall x. m x -> n x) ->
      Population m a ->
      Population n a
    hoist f = fromWeightedList . f . runPopulation
    
  • pushEvidence

    -- | Push the evidence estimator as a score to the transformed monad.
    -- Weights are normalized after this operation.
    pushEvidence ::
      MonadCond m =>
      Population m a ->
      Population m a
    pushEvidence = (hoist applyWeight) . extractEvidence
    

    This will:

    1. Use extractEvidence to compute the likelihood estimate from the particle filter (the sum of the weights of all the particles), normalise the particle weights, and make this estimate explicit in the Weighted layer of Population (Weighted m) a. Recall that a Population is a list of weighted samples, but these weights are implicit in the Population transformer.

      extractEvidence   :: Monad m => Population m a -> Population (Weighted m) a
      
    2. Use hoist applyWeight to factor this likelihood as a score in the transformed monad m, which is possible due to m being an instance of MonadCond.

      hoist applyWeight :: MonadCond m => Population (Weighted m) a -> Population m a
      
  • runPopulation

    runPopulation :: Population m a -> m [(a, Log Double)]
    runPopulation (Population m) = runListT $ runWeighted m
    

    Calling runPopulation will then extract an explicit representation of the weighted particles.

  • param >>= ... . model

    In this code snippet, the prior over our parameters, param, has type Traced m b, where b is the type of the parameters. Binding with >>= draws a parameter value of type b from the prior and passes it to model :: b -> Sequential (Population m) a, and from there through the rest of the pipeline.

  • mh t ...

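    The Metropolis-Hastings function mh is run for t steps on the combined computation param >>= runPopulation . pushEvidence . Pop.hoist Traced.Static.lift . smcSystematic k n . model. Each run of this computation draws parameters from the prior, runs the particle filter, and scores the result with the particle filter’s likelihood estimate, so mh uses that estimate as the likelihood when deciding whether to accept each proposal. The result is t + 1 samples, each a weighted population of type [(a, Log Double)].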
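The resampling and evidence steps described above can be sketched on a plain list of weighted particles. This is a hypothetical, simplified stand-in: Particles, systematicAncestors, resampleSystematicList and evidenceAndNormalise are not monad-bayes functions, and weights are plain Double instead of Log Double. It mirrors the behaviour described above: resampling gives every offspring weight total / n so the total weight is preserved, and the evidence estimate is the sum of the weights, after which the weights are normalised.

```haskell
-- Hypothetical list-based stand-in for a population: (particle, weight) pairs
-- with plain Double weights. None of these names are monad-bayes API.
type Particles a = [(a, Double)]

-- | Ancestor indices chosen by systematic resampling from a single uniform
-- u in [0, 1): the n evenly spaced points (u + i) / n are matched against the
-- cumulative normalised weights. Assumes a positive total weight.
systematicAncestors :: Double -> [Double] -> [Int]
systematicAncestors u ws = go 0 cumulative points
  where
    n = length ws
    total = sum ws
    cumulative = scanl1 (+) (map (/ total) ws)
    points = [(u + fromIntegral i) / fromIntegral n | i <- [0 .. n - 1]]
    go :: Int -> [Double] -> [Double] -> [Int]
    go _ _ [] = []
    go _ [] ps = map (const (n - 1)) ps -- rounding left the last sum below 1
    go j (c : cs) pts@(p : ps)
      | p < c = j : go j (c : cs) ps
      | otherwise = go (j + 1) cs pts

-- | Resample, giving every offspring weight total / n so that the total
-- weight (the running evidence estimate) is unchanged.
resampleSystematicList :: Double -> Particles a -> Particles a
resampleSystematicList u pop =
  [(fst (pop !! j), total / fromIntegral n) | j <- systematicAncestors u weights]
  where
    weights = map snd pop
    n = length pop
    total = sum weights

-- | Split a population into its evidence estimate (the sum of the weights)
-- and the population with weights normalised to sum to 1.
evidenceAndNormalise :: Particles a -> (Double, Particles a)
evidenceAndNormalise pop = (z, [(x, w / z) | (x, w) <- pop])
  where
    z = sum (map snd pop)
```

Systematic resampling uses a single uniform draw for all n particles, which keeps each particle’s number of offspring close to its expected value.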

8. Some brainstorming on how the Free Sampler is used

Let’s start with pmmh:

-- | Particle Marginal Metropolis-Hastings sampling.
pmmh ::
  MonadInfer m =>
  -- | number of Metropolis-Hastings steps
  Int ->
  -- | number of time steps
  Int ->
  -- | number of particles
  Int ->
  -- | model parameters prior
  Traced m b ->
  -- | model
  (b -> Sequential (Population m) a) ->
  m [[(a, Log Double)]]
pmmh t k n param model =
  mh t (param >>= runPopulation . pushEvidence . Pop.hoist Bayes.Traced.Static.lift .
        smcSystematic k n . model)
  • Our model has type b -> Sequential (Population m) a.

  • The partially applied function smcSystematic k n (the particle filter) will have type Sequential (Population m) a -> Population m a.

  • Pop.hoist Bayes.Traced.Static.lift will have type Population m a -> Population (Traced m) a.

The lift used by Pop.hoist here comes from the MonadTrans instance of Traced. Its definition, together with the Traced data type, is:

data Traced m a
  = Traced
      { -- "Weighted (FreeSampler m) a" is "StateT (Log Double) (FT SamF m) a"
        model     :: Weighted (FreeSampler m) a,
        traceDist :: m (Trace a)
      }

instance MonadTrans Traced where
  lift m = Traced (lift $ lift m) (fmap pure m)
  • The model field is constructed by:

    • First lifting m a into FreeSampler m a (which is a newtype for FT SamF m a). Hence, this lift is from FT.lift.

      Using the non-Church-encoded version of the free monad transformer, this looks like the following, where Pure :: a -> FreeF f a b, so liftM Pure :: m a -> m (FreeF f a b). This creates a computation of type m (FreeF f a b) whose result is a single Pure node.

      data FreeF f a b = Pure a | Free (f b)
      
      newtype FreeT f m a = FreeT { runFreeT :: m (FreeF f a (FreeT f m a)) }
      
      instance MonadTrans (FreeT f) where
        lift = FreeT . liftM Pure
      

      Using the Church-encoded version (which monad-bayes uses), the definition is in continuation-passing style, but it is analogous to the version above: lift constructs an FT whose function takes the continuation for the Pure case and binds the lifted action to it, ignoring the continuation for Free nodes.

      newtype FT f m a = FT { runFT :: forall r. (a -> m r) -> (forall x. (x -> m r) -> f x -> m r) -> m r }
      
      instance MonadTrans (FT f) where
        lift m = FT (\mf _ -> m >>= mf)
      

      Either way, lifting constructs an FT tree consisting of a single Pure node.

    • Then lifting this into Weighted (FreeSampler m) a (which is a newtype for StateT (Log Double) (FT SamF m) a).

  • The traceDist field is constructed by mapping pure for the Trace data type over m a, returning m (Trace a): a trace whose list of variables is empty and whose density is 1.

    data Trace a
      = Trace
          { variables :: [Double],
            output :: a,
            density :: Log Double
          }
    
    instance Applicative (TraceTCPS m) where
      pure x = TraceTCPS (\k -> k [] x 1)
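The claim that lifting produces a tree with a single Pure node and no Free nodes can be checked with a small self-contained sketch. Here FT is re-declared locally following the definition quoted above (with the Functor, Applicative and Monad instances it needs); liftFT, randomFT and countNodes are hypothetical helper names. countNodes runs a tree, feeding the same double to every random choice, and counts how many SamF nodes it passes through.

```haskell
{-# LANGUAGE RankNTypes #-}

-- A minimal Church-encoded free monad transformer, declared locally to
-- mirror the FT definition quoted above (not imported from a library).
newtype FT f m a = FT
  {runFT :: forall r. (a -> m r) -> (forall x. (x -> m r) -> f x -> m r) -> m r}

instance Functor (FT f m) where
  fmap g (FT h) = FT (\kp kf -> h (kp . g) kf)

instance Applicative (FT f m) where
  pure a = FT (\kp _ -> kp a)
  FT hg <*> FT ha = FT (\kp kf -> hg (\g -> ha (kp . g) kf) kf)

instance Monad (FT f m) where
  FT h >>= g = FT (\kp kf -> h (\a -> runFT (g a) kp kf) kf)

-- The SamF functor from the text: a node waiting for one random double.
newtype SamF a = Random (Double -> a)

-- Same shape as lift for FT: bind the action straight into the Pure
-- continuation, so no Free (SamF) node is created.
liftFT :: Monad m => m a -> FT f m a
liftFT m = FT (\kp _ -> m >>= kp)

-- A single random choice: one Free node whose continuation returns the double.
randomFT :: FT SamF m Double
randomFT = FT (\kp kf -> kf kp (Random id))

-- Run the tree, feeding the same double d to every random choice and
-- counting how many SamF nodes were visited.
countNodes :: Monad m => Double -> FT SamF m a -> m (a, Int)
countNodes d t =
  runFT t (\a -> return (a, 0)) (\k (Random g) -> fmap (\(a, c) -> (a, c + 1)) (k (g d)))
```

A lifted action contributes zero nodes, whereas each randomFT contributes one. This suggests that random draws made directly in the underlying monad (and brought in via lift) are not recorded as Random nodes, and so would not be replayed by withPartialRandomness.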
    

When running mh, we take the model :: StateT (Log Double) (FT SamF m) a of the Traced data type, and pass it to mhTrans along with the most recent Trace data type computed during the previous step of metropolis-hastings.

mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced m d) = fmap (map output) (f n)
  where
    f 0 = fmap (: []) d
    f k = do
      ~(x : xs) <- f (k -1)
      y <- mhTrans m x
      return (y : x : xs)

Inside mhTrans, our model :: StateT (Log Double) (FT SamF m) a is used on the line ((b, q), vs) <- ....

mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
mhTrans m t@Trace {variables = us, density = p} = do
  let n = length us
  us' <- do -- :: m [Double]
    i <- discrete $ discreteUniformAB 0 (n - 1)
    u' <- random
    let (xs, _ : ys) = splitAt i us
    return $ xs ++ (u' : ys)
    -- Weighted.hoist (WriterT . withPartialRandomness us') 
    --   :: Weighted (FreeSampler m) a -> Weighted (WriterT [Double] m) a
  ((b, q), vs) <- runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m
  let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
  accept <- bernoulli ratio
  return $ if accept then Trace vs b q else t
  • The type of WriterT . withPartialRandomness us' is FreeSampler m a -> WriterT [Double] m a (or in other words, FT SamF m a -> WriterT [Double] m a).
  • Hence Weighted.hoist (WriterT . withPartialRandomness us') has type Weighted (FreeSampler m) a -> Weighted (WriterT [Double] m) a (or in other words, StateT (Log Double) (FT SamF m) a -> StateT (Log Double) (WriterT [Double] m) a).

The function withPartialRandomness is given below. It takes the list of randomness and a FreeSampler m a, which wraps an FT SamF m a.

newtype SamF a = Random (Double -> a)

withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double])
-- newtype FreeSampler m a = FreeSampler {runFreeSampler :: FT SamF m a}
withPartialRandomness randomness (FreeSampler m) =
-- f              :: (SamF (StateT [Double] (WriterT [Double] m) a) 
--                -> StateT [Double] (WriterT [Double] m) a)
-- iterTM         :: (f (t m a) -> t m a) -> FT f m a -> t m a
-- iterTM f       :: FT SamF (WriterT [Double] m) a -> StateT [Double] (WriterT [Double] m) a
-- hoistFT lift m :: FT SamF (WriterT [Double] m) a 
-- m              :: FT SamF m a
  runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness
  where
  -- k  :: Double -> StateT [Double] (WriterT [Double] m) a
    f (Random k) = do
      -- This block runs in StateT [Double] (WriterT [Double] m) a.
      -- StateT propagates consumed randomness while WriterT records
      -- randomness used, whether old or new.

      -- Get the state, i.e. the list of randomness "xs"
      xs <- get
      x <- case xs of
        -- If the list of randomness is empty, then let x be a random double
        [] -> random
        -- If the list of randomness is non-empty, then update the state to be the
        -- tail of the list, and return the first random double value from the list.
        y : ys -> put ys >> return y
      -- Append the random value we just retrieved to the Writer log.
      tell [x]
      -- Apply the continuation "k" to the random double, which yields a
      -- computation of type "StateT [Double] (WriterT [Double] m) a".
      k x
  • hoistFT lift m changes the base monad of the tree from m to WriterT [Double] m by applying lift to every underlying action, yielding the type FT SamF (WriterT [Double] m) a.

    FT SamF m a is the Church encoding of FreeT SamF m a, which unravels to m (FreeF SamF a (FreeT SamF m a)). So hoistFT lift m :: FT SamF (WriterT [Double] m) a corresponds to (WriterT [Double] m) (FreeF SamF a (FreeT SamF (WriterT [Double] m) a)).

    data FreeF f a x = Pure a | Free (f x)
    
    newtype FreeT f m a = FreeT { runFreeT :: m (FreeF f a (FreeT f m a)) }
    
  • f has type SamF (StateT [Double] (WriterT [Double] m) a) -> StateT [Double] (WriterT [Double] m) a. In other words, it takes a SamF node containing a function k :: Double -> StateT [Double] (WriterT [Double] m) a. The description of what f does is given in comments in the code above.

  • iterTM has type (f (t m a) -> t m a) -> FT f m a -> t m a

  • iterTM f applied to hoistFT lift m tears down the free monad transformer using iteration over a transformer. Iterating over the tree with f constructs a StateT [Double] (WriterT [Double] m) a computation.

  • evalStateT then uses randomness as the initial state to run the stateful computation which iterTM f $ hoistFT lift m constructs.

The continuation k stored in a SamF node, which in general has type Double -> a, specialises to Double -> StateT [Double] (WriterT [Double] m) a because f is defined to work in the context of StateT [Double] (WriterT [Double] m) a.
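The behaviour of withPartialRandomness can be mimicked with a pure free monad over random choices. In this hypothetical sketch, Sampler stands in for FreeSampler, and since there is no MonadSample here, fresh randomness comes from an explicit list rather than from random in the underlying monad. replay consumes doubles from the old trace while they last, then from the fresh supply, and returns the output together with every double actually used, just as vs is built in mhTrans.

```haskell
-- Hypothetical pure stand-in for FreeSampler: a tree of random choices.
data Sampler a = Done a | Random (Double -> Sampler a)

instance Functor Sampler where
  fmap f (Done a) = Done (f a)
  fmap f (Random k) = Random (fmap f . k)

instance Applicative Sampler where
  pure = Done
  sf <*> sa = sf >>= \f -> fmap f sa

instance Monad Sampler where
  Done a >>= g = g a
  Random k >>= g = Random (\u -> k u >>= g)

-- One random choice, like a single SamF node.
random :: Sampler Double
random = Random Done

-- | Analogue of withPartialRandomness: reuse doubles from the old trace
-- while they last, then take doubles from the fresh supply. Returns the
-- output and every double actually consumed (the new trace).
replay :: [Double] -> [Double] -> Sampler a -> (a, [Double])
replay _ _ (Done a) = (a, [])
replay (u : old) fresh (Random k) = record u (replay old fresh (k u))
replay [] (u : fresh) (Random k) = record u (replay [] fresh (k u))
replay [] [] (Random _) = error "replay: no fresh randomness left"

-- Prepend a consumed double to the recorded trace.
record :: Double -> (a, [Double]) -> (a, [Double])
record u (a, used) = (a, u : used)
```

For a program that branches on its first draw, the length of the returned trace depends on the path taken, which is why vs in mhTrans can be shorter or longer than the proposed list us'.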

Last updated on 13 Nov 2020
Published on 13 Nov 2020