Documentation
1. What Does Monad-Bayes Do?
A general way of inferring a posterior distribution given a likelihood and a prior is approximate inference with MCMC, which is a whole class of methods for inferring the posterior. A key point to note is that although we are inferring the posterior, we are also inferring an estimate of the likelihood from the data. Inference in general is the act of predicting the values of something we don't have direct access to, given some data. At the highest level, there is inference over the posterior distribution of the parameters.
Monad-Bayes uses a particle filter to estimate the likelihood of a given model, and uses metropolis-hastings (or more specifically, the trace metropolis-hastings algorithm) to infer the posterior distribution over any desired quantity. (A particle filter is only needed when we do not have an exact formulation of the likelihood.) Monad-Bayes offers the user a very generic version of trace metropolis-hastings and a particle filter – in general, these two components can be combined to produce different algorithms (such as PMMH or SMC). In theory, one could extend this implementation with one's own altered version of metropolis-hastings. Particle filters and metropolis-hastings are just two possible building blocks for a lot of statistical algorithms. In general, the scope of monad-bayes is deliberately open-ended – at its core, it is an effectful, type-class-based approach to probabilistic programming. Whilst monad-bayes ships implementations of trace metropolis-hastings and particle filters, these are just building blocks of other algorithms and the library can be used in many other ways.
2. Building Blocks of Monad Bayes
2.1 Sampler (Bayes.Sampler)
-- | An 'ST' based random sampler using the @mwc-random@ package.
newtype SamplerST a = SamplerST (forall s. ReaderT (GenST s) (ST s) a)
A sampler is something which draws concrete values for random variables from a prior. It is constructed with the ReaderT monad transformer, where:

- The read-only environment is a pseudo-random number generator GenST s, where s is the state thread of the generator.
- The underlying monad is the state-thread monad ST s, which references the state s of the pseudo-random number generator GenST s. Since the generator is mutable, we use the ST monad. A computation of type ST s a returns a value of type a and executes in "thread" s. The universal quantification over s keeps objects inside the ST monad from leaking to the outside of the ST monad.

The ReaderT monad then allows us to read from the random number generator (as the environment).
runReaderT :: ReaderT r m a -> r -> m a
-- or in this specific case:
runReaderT :: ReaderT (GenST s) (ST s) a -> GenST s -> ST s a
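As a concrete illustration of this construction (a hypothetical sketch, not monad-bayes' actual API), the code below re-states the SamplerST newtype, builds a sampler that asks the environment for the generator, and runs it with a fresh mwc-random generator. The runner runSamplerST, the sampler uniformSampler and the fixed-seed create call are assumptions made purely for illustration.

{-# LANGUAGE RankNTypes #-}

import Control.Monad.Reader (ReaderT, ask, runReaderT)
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import System.Random.MWC (GenST, create, uniform)

newtype SamplerST a = SamplerST (forall s. ReaderT (GenST s) (ST s) a)

-- Hypothetical runner: create a generator and supply it as the environment.
runSamplerST :: SamplerST a -> a
runSamplerST (SamplerST m) = runST $ do
  gen <- create              -- fixed-seed generator, for illustration only
  runReaderT m gen

-- A sampler that reads the generator from the environment and draws a Double.
uniformSampler :: SamplerST Double
uniformSampler = SamplerST $ do
  gen <- ask
  lift (uniform gen)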
2.2 Weighted (Bayes.Weighted)
-- | Execute the program using the prior distribution, while accumulating likelihood.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
-- StateT is more efficient than WriterT
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample)
A value of type Weighted m a represents a weighted sample or particle wrapped within some monadic context m; this is implemented with the StateT monad transformer, where the state is a weight of type Log Double. Theoretically, the writer monad transformer could be used instead (but the state monad transformer is apparently faster). Hence, calling runWeighted will return a tuple containing a sample/particle value a and its associated weight Log Double.

Weighting a representation m equips it with a conditioning operation, making Weighted m a conditioning representation. If m was a sampling representation, its weighted version is also a sampling representation, obtained by lifting the sampling operation.
- runWeighted - calling runWeighted runs the StateT monad using 1 as the initial weight. This lets us obtain an m-computation returning pairs (a, Log Double) of the result type a and its explicit, accumulated log-likelihood Log Double.

  -- | Obtain an explicit value of the likelihood for a given value.
  runWeighted :: (Monad m) => Weighted m a -> m (a, Log Double)
  runWeighted (Weighted m) = runStateT m 1
- score - the function score multiplies the current weight held within Weighted m a by a given factor.

  instance Monad m => MonadCond (Weighted m) where
    score w = Weighted (modify (* w))
- hoist - the Weighted version of hoist lifts an inference transformation applicable to m into an inference transformation applicable to Weighted m. When m is already a conditioning representation, we may use the conditioning available through the weighting transformer, or hoist score r to use the ambient conditioning of m.

  hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
  hoist f (Weighted m) = Weighted (mapStateT f m)
We can construct a simple inference algorithm by interpreting the model in Weighted Sampler a, which, when run, lets us obtain a weighted sampler of type Sampler (a, Log Double).
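To make the interplay of score and runWeighted concrete, here is a minimal, self-contained sketch of the Weighted transformer mirroring the definitions above (the MonadSample/MonadCond classes are omitted); the example function and its numbers are invented for illustration.

{-# LANGUAGE GeneralizedNewtypeDeriving #-}

import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State (StateT, modify, runStateT)
import Numeric.Log (Log)

newtype Weighted m a = Weighted (StateT (Log Double) m a)
  deriving (Functor, Applicative, Monad)

-- | Obtain the result together with its accumulated likelihood.
runWeighted :: Monad m => Weighted m a -> m (a, Log Double)
runWeighted (Weighted m) = runStateT m 1

-- | Multiply the accumulated weight by a factor.
score :: Monad m => Log Double -> Weighted m ()
score w = Weighted (modify (* w))

-- Example: a "model" that conditions twice; the final weight is 0.5 * 0.2.
example :: IO (Bool, Log Double)
example = runWeighted $ do
  score 0.5
  score 0.2
  Weighted (lift (pure True))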
2.3 Population (Bayes.Population)
-- | A collection of weighted particles.
newtype Population m a = Population (Weighted (ListT m) a)
-- | Explicit representation of the weighted sample with weights in the log
-- domain.
runPopulation :: Monad m => Population m a -> m [(a, Log Double)]
runPopulation (Population m) = observeAllT $ runWeighted m
A population is a list of weighted particles, wrapped within some monadic context m. Hence, calling runPopulation will return a list of tuples containing particle values a and their associated weights Log Double. This is used during the particle filter.
- hoist - the Population version of hoist will:

  - First, run the population to acquire a list of weighted particles m [(a, Log Double)].
  - Then apply the inference transformation f to give n [(a, Log Double)].
  - And finally apply fromWeightedList, which first converts n [(a, Log Double)] to the list transformer version ListT n (a, Log Double), then applies withWeight to create the type Weighted (ListT n) a, which is a population (a list of weighted particles).

  hoist :: (forall x. m x -> n x) -> Population m a -> Population n a
  hoist f = fromWeightedList . f . runPopulation

  -- | Initialize 'Population' with a concrete weighted sample.
  fromWeightedList :: Monad n => n [(a, Log Double)] -> Population n a
  fromWeightedList = Population . withWeight . ListT

  -- | Embed a random variable with explicitly given likelihood.
  -- > runWeighted . withWeight = id
  withWeight :: (Monad n) => n (a, Log Double) -> Weighted n a
  withWeight m = Weighted $ do
    (x, w) <- lift m
    modify (* w)
    return x
We use three inference transformations associated with Population:
- spawn

  spawn :: Monad m => Int -> Population m ()
  spawn n = fromWeightedList $ pure $ replicate n ((), 1 / fromIntegral n)

  This increases the sample size (the number of particles in the population) by a given factor. The weights of the particles are adjusted such that their sum is preserved. It is therefore safe to use spawn in arbitrary places in the program without introducing bias.
- resample

  resampleSystematic :: MonadSample m => Population m a -> Population m a

  This resamples the population using the underlying monad m and a systematic resampling scheme – i.e. it draws a new population with uniform weights from the current population (the total weight of the set of particles is preserved). Resampling's purpose is to remedy situations where a single sample has a large weight compared to the other particles in the population and dominates the result, making the other particles irrelevant. (A pure sketch of systematic resampling is given after this list.)
- pushEvidence

  pushEvidence :: MonadCond m => Population m a -> Population m a
  pushEvidence = (hoist applyWeight) . extractEvidence

  This normalizes the weights in the population of particles, while at the same time incorporating the sum of the weights as a score in m, which is some monadic computation that can be conditioned.

  The first function, extractEvidence:

  - Takes a population (a list of weighted particles).
  - Sums the weights of all the particles into z = sum ps (summing these weights is an estimate of the likelihood of the sequence of random variables sampled during the program's execution giving rise to the program's output).
  - Normalizes the weights of the particles, giving ws.
  - Explicitly separates the sum of the weights of the population (the likelihood) into the Weighted transformer, by calling factor z (which I assume will always just multiply the initial weight 1 by z).
  - Returns the set of particles and their normalized weights, where the likelihood z is separated out and stored in the Weighted m.

  -- | Separate the sum of weights into the 'Weighted' transformer.
  -- Weights are normalized after this operation.
  extractEvidence :: Monad m => Population m a -> Population (Weighted m) a
  extractEvidence m = fromWeightedList $ do
    pop <- lift $ runPopulation m
    let (xs, ps) = unzip pop
    let z = sum ps
    let ws = map (if z > 0 then (/ z) else const (1 / fromIntegral (length ps))) ps
    factor z
    return $ zip xs ws

  This results in a return type Population (Weighted m) a, which, when run with runPopulation, unwraps to Weighted m [(a, Log Double)] – i.e. the list of weighted particles, with the likelihood held in the Weighted layer.

  The second function, applyWeight, when hoisted to the type Population (Weighted m) a -> Population m a, incorporates the sum of the weights of the particles in the population (the likelihood stored in the Weighted m) as a score in the monad m. This then discards the likelihood, returning just m a.

  -- | Use the weight as a factor in the transformed monad.
  applyWeight :: MonadCond m => Weighted m a -> m a
  applyWeight m = do
    (x, w) <- runWeighted m
    factor w
    return x
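To give intuition for what the systematic resampling step does to an explicit weighted sample (referenced from the resample item above), here is a pure, hypothetical sketch that works directly on a [(a, Double)] list. It takes a single uniform draw u in [0,1) as an argument instead of living in MonadSample, and it is not the library's implementation.

-- Pure sketch of systematic resampling: draw n equally spaced points through
-- the normalized cumulative weights and give every surviving particle the
-- same weight, preserving the total weight of the population.
systematicResample :: Double -> [(a, Double)] -> [(a, Double)]
systematicResample u pop =
  [ (pick ((u + fromIntegral i) / fromIntegral n), total / fromIntegral n)
  | i <- [0 .. n - 1] ]
  where
    n = length pop
    total = sum (map snd pop)
    cumulative = scanl1 (+) (map ((/ total) . snd) pop)
    -- Select the first particle whose cumulative weight reaches the point p.
    pick p = case dropWhile ((< p) . snd) (zip (map fst pop) cumulative) of
      ((x, _) : _) -> x
      []           -> fst (last pop)   -- guard against floating-point round-off

-- e.g. systematicResample 0.5 [("a", 0.9), ("b", 0.1)]
--        ==> [("a", 0.5), ("a", 0.5)]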
2.4 Sequential (Bayes.Sequential)
-- | Represents a computation that can be suspended at certain points.
-- The intermediate monadic effects can be extracted, which is particularly
-- useful for implementation of particle filter related methods.
-- All the probabilistic effects are lifted from the transformed monad, but
-- also `suspend` is inserted after each `factor`/`score`.
newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
deriving (Functor, Applicative, Monad, MonadTrans, MonadIO)
Many models exhibit a sequential structure where observations are interleaved with sampling. In those models a possible inference strategy is to consider a program up to a certain point, do inference on the partial posterior it defines, then run the program a little more, do more inference, and so on. To implement such algorithms we introduce the sequential transformer Sequential, which introduces a suspension after each score in the program.
There are two hoisting functions associated with Sequential:

- The function hoistFirst applies the inference transformation f only to the part of the program executed so far.

  -- | Transform the inner monad.
  -- This operation only applies to computation up to the first suspension.
  hoistFirst :: (forall x. m x -> m x) -> Sequential m a -> Sequential m a
  hoistFirst f = Sequential . Coroutine . f . resume . runSequential

- The function hoist applies the inference transformation f recursively to all the suspension points of the program.

  -- | Transform the inner monad.
  -- The transformation is applied recursively through all the suspension points.
  hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> Sequential m a -> Sequential n a
  hoist f = Sequential . mapMonad f . runSequential
There are two inference transformations associated with Sequential:

- advance - the advance transformation runs the program to the next suspension point.

  -- | Execute to the next suspension point.
  -- If the computation is finished, do nothing.
  -- > finish = finish . advance
  advance :: Monad m => Sequential m a -> Sequential m a
  advance = Sequential . bounce extract . runSequential

  -- Supply () to the awaited continuation.
  extract :: Await () a -> a
  extract (Await f) = f ()

- finish - the finish transformation runs the program to the end.

  -- | Remove the remaining suspension points.
  finish :: Monad m => Sequential m a -> m a
  finish = pogoStick extract . runSequential
3. Particle Filter
By combining the inference transformers Sequential and Population, we obtain a particle filter (SMC). Within the context of a particle filter, a sample in a population is called a particle. The particle filter is implemented as the function sir, which is a template for particle filtering that takes a custom resampling method.
-- | Sequential importance resampling - an SMC template that takes a custom resampler.
sir :: Monad m =>
-- | resampler
(forall x. Population m x -> Population m x) ->
-- | number of timesteps
Int ->
-- | population size
Int ->
-- | model
Sequential (Population m) a ->
Population m a
sir resampler k n = finish . composeCopies k (advance . hoistFirst resampler) . Seq.hoistFirst (spawn n >>)
-- | Apply a function a given number of times.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)
- The algorithm starts by creating a monadic computation which initialises a population of size n and awaits its next population-related computation. This monadic computation is then wrapped within the Sequential transformer, by hoisting it into the part of the program executed so far (i.e. up to the first suspension in the current Sequential monad).

  Seq.hoistFirst (spawn n >>)

  When this is applied to a model of type Sequential (Population m) a, the result represents a computation where n particles are spawned (creating Sequential (Population m) ()) and then the model is run.

- The function composeCopies :: Int -> (a -> a) -> (a -> a) applies a given function k times. In this case, the function is (advance . hoistFirst resampler) – when provided a model Sequential (Population m) a, it resamples the population and then executes the model for one time-step (i.e. executes the model to the next suspension point). Applying composeCopies k hence resamples and then progresses the model after every time step, for k time-steps.

  composeCopies k (advance . hoistFirst resampler)

- The function finish then runs the rest of the Sequential m a program to the end, removing the remaining suspension points, which leaves us with the type m a.
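As a hedged usage sketch (not taken from the library's examples), this is roughly how the sir template could be applied to a toy model. It assumes that MonadInfer, normal, score, normalPdf, sampleIO, runPopulation, resampleSystematic, sir and Log are in scope from monad-bayes (imports omitted; module layout varies between versions), and the model and its numbers are invented.

-- A toy one-observation model: prior over mu, conditioned on observing 0.5.
toyModel :: MonadInfer m => m Double
toyModel = do
  mu <- normal 0 1
  score (normalPdf mu 1 0.5)
  return mu

-- Run 100 particles through 1 resampling step and inspect the weighted sample.
runToySMC :: IO [(Double, Log Double)]
runToySMC = sampleIO . runPopulation $ sir resampleSystematic 1 100 toyModel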
4. The Free Monad Transformer
4.1 The Free Monad
FreeF is the base functor of the free monad Free; the free monad arises as its fixed point.

data FreeF f a x = Pure a | Free (f x)

By wrapping a functor within the free monad, this allows us to treat the functor (or technically Free f) as a monad, i.e. an effectful computation. By constructing values of type Free f, we are simply constructing a data type in the form of a syntax tree. The important thing to note is that this tree on its own has no computation associated with it – it exists as a data type. How we evaluate a free monad tree is determined by what interpreter functions we choose to define. This means that we have decoupled the syntax of our program from its semantics.
When binding over a free monad, (Free f a) -> (a -> Free f b) -> Free f b, this says: execute the effectful computation Free f a, extract the value a from this, and execute another effectful computation Free f b.
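For a concrete feel of this decoupling, here is a small standalone example (not from monad-bayes; it uses the free package's Control.Monad.Free): one syntax tree, two different interpreters.

{-# LANGUAGE DeriveFunctor #-}

import Control.Monad.Free (Free (..), liftF)

-- The "syntax": a functor of teletype commands.
data TeletypeF x
  = Say String x
  | Ask (String -> x)
  deriving (Functor)

say :: String -> Free TeletypeF ()
say s = liftF (Say s ())

ask :: Free TeletypeF String
ask = liftF (Ask id)

-- One program, written once, against the syntax only.
program :: Free TeletypeF String
program = do
  say "name?"
  ask

-- Interpreter 1: give the syntax IO semantics.
runIO :: Free TeletypeF a -> IO a
runIO (Pure a)         = return a
runIO (Free (Say s k)) = putStrLn s >> runIO k
runIO (Free (Ask k))   = getLine >>= runIO . k

-- Interpreter 2: give the same syntax pure semantics with canned input.
runPure :: [String] -> Free TeletypeF a -> a
runPure _         (Pure a)         = a
runPure ins       (Free (Say _ k)) = runPure ins k
runPure (i : ins) (Free (Ask k))   = runPure ins (k i)
runPure []        (Free (Ask k))   = runPure [] (k "")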
4.2 The Free Monad Transformer
Using normal free monads, we can build abstract syntax trees which let us abstract away the interpreter; however, sometimes we can't specify the syntax tree all at once. Often we want to interleave the syntax tree with some other monad to generate streaming or interactive computations. The free monad transformer FreeT solves this problem by allowing us to mix building steps of the abstract syntax tree with calling actions in some base monad.
data FreeF f a x = Pure a | Free (f x)
newtype FreeT f m a = FreeT {
runFreeT :: m (FreeF f a (FreeT f m a)) }
We can see that FreeT gives us something very similar to the normal free monad, but with a monad m on the outside, where m is the extra effect. This creates a tree where we go through it and execute some effects, but these effects are executed with respect to some other effect; so every step of computation (the nodes of the tree) is wrapped in an effect itself.
4.3 An Example Use of Free Monad Transformers
For example, let's say we want to write our own Python-style generator. (Python generators are essentially coroutines. They are functions which, when passed arguments, will run until encountering the yield keyword. Upon the yield keyword, the function returns whatever value is "yielded", and also saves its state at that point (under the name of the original function). Then, upon recalling the function, it will execute from immediately after the yield.)
type Generator a m r = FreeT ((,) a) m r
The type Generator is the free monad transformer where the functor f is ((,) a), i.e. a partially applied tuple with the type a as its first argument, the monad is some arbitrary monad m as an effect, and it produces a result of arbitrary type r. Our functor being ((,) a) means that the first element of the tuple, of type a, is the value that can be yielded, and the second element of the tuple represents the rest of the computation to be run (i.e. the continuation). The free monad transformer tree is a recursive structure that sequentially composes zero or more operations of the given functor ((,) a), embedded within the context of monad m. The leaves of this tree are given by Pure, holding the final result (of type r), and the nodes are operations Free that are shaped by ((,) a), where the first element can be a yielded value and the second element contains the rest of the computation which can be run. The result of return x is a leaf, and (>>=) grows the tree of operations at its leaves.
Let's now look at an implementation of yield:
yield :: a -> Generator a m ()
yield a = liftF (a, ())
The function yield takes a value of type a that we want to yield, and lifts (a, ()) into the free monad transformer type FreeT ((,) a) m () – when we monadically extract the value from this, we get a tuple containing the yielded value a and a unit () representing the rest (or in this case, the end) of the computation.
Let's now look at an implementation of prompt:
prompt :: Generator String IO ()
prompt = do
lift $ putStrLn "Enter a string:"
str <- lift getLine
yield str
The value prompt is a free monad transformer tree whose nodes are shaped by the functor ((,) String), meaning each node contains a String and the rest of the computation – and in this case the result type r is fixed to (). It also sets the monadic context of our tree to be IO, allowing us to attach IO actions to the operations/nodes of the tree. We choose to attach the IO effect of printing "Enter a string:" to the terminal and then prompting a terminal input from the user – this input is then yielded.
Next, we look at how the free monad transformer tree prompt can be run:

putStrLnAllTheThings :: Show r => FreeT ((,) String) IO r -> IO r
putStrLnAllTheThings gen = do
  x <- runFreeT gen
  case x of
    Pure r -> do
      putStrLn $ "Result: " ++ show r
      return r
    Free (str, gen') -> do
      -- gen' :: FreeT ((,) String) IO r
      putStrLn $ "User entered: " ++ str
      putStrLnAllTheThings gen'
main = putStrLnAllTheThings prompt
The function putStrLnAllTheThings takes a free monad transformer tree as an argument in the same form as prompt, except allowing the result type r to be arbitrary (rather than fixed as ()). Here is an example run of main:
Enter a string:
hello
User entered: hello
Result: ()
If we wanted to define prompt to run forever, we could write the following, where forever executes the operation it is passed, ignores its result, and then recurses on itself:
prompt :: Generator String IO r
prompt = forever $ do
lift $ putStrLn "Enter a string:"
str <- lift getLine
yield str
putStrLnAllTheThings :: FreeT ((,) String) IO r -> IO r
putStrLnAllTheThings gen = do
  x <- runFreeT gen
  case x of
    Pure r -> return r
    Free (str, gen') -> do
      putStrLn $ "User entered: " ++ str
      putStrLnAllTheThings gen'
main = putStrLnAllTheThings prompt
Summary: Free Monad Transformer
The free monad transformer means that every step of the computation within the tree can be automatically embedded in an arbitrary monadic effect. This lets us be both abstract in the interpreter as well as the monadic effect we want to attach to the interpreter.
5. Traced (Bayes.Traced.Static)
The following monad Traced is a tracing monad where only a subset of random choices are traced.
Bayes.Traced.Static
data Traced m a
= Traced
{ model :: Weighted (FreeSampler m) a,
traceDist :: m (Trace a)
}
The notion here is that we have some model, such as an SIR model, which makes random decisions at certain points of the program. We want to be able to decouple what it means to make random choices (the actual implementation of randomness) from the syntax. The syntax is the description of the model, but where the model's stochastic elements (i.e. calls to random functions such as distributions) are just names. By supplying the syntax with randomness, we use the description of the stochastic model in a stochastic manner. This distinction can be described as the model as syntax vs the model as a realisation of a stochastic process.
We want the syntax to describe doing some action, but we want the actual random choices to be made elsewhere. For example, consider a program that needs to make a random effect, but where we want that randomness to come from different places, i.e. we want to give it different ways of generating random numbers at each point. To elaborate, the program using the same deterministic random number at every point of random decision, versus the program using a user-supplied source of randomness at every point of random decision, are two distinctly different ways to do randomness.
The reason this is important with respect to tracing in probabilistic programming is that the way one performs inference over an execution trace requires one to modify the trace of the program, the trace being all the random decisions made. Hence, we don't make any changes to the syntax of the program, we only work with the random decisions performed during the program (by making random changes to these random decisions). The first thing we need is a way to decouple these two things, which is what the Traced datatype achieves by having two distinct fields:
- The model field, which is a description of the suspended model as syntax, separated from its random decisions. The model is the abstract interpreter which defines our program, augmented in this free monad transformer with where the random decisions to be made would be. One could imagine this as a tree of random decisions waiting for their source of randomness. This tree is generated from the model specification. The implementation of the model itself will most likely be associated with drawing from various distributions, which are themselves random decisions, but for such random actions to be performed, we first need to provide a source of randomness as a "seed".
- The traceDist field, which provides the model with a source of randomness in order to realise the model as a stochastic process.
We will now elaborate on what these two fields of Traced are.
5.1 FreeSampler (Bayes.Free)
Let's first inspect the model field of Traced.
Bayes.Traced.Static
data Traced m a
= Traced
{ model :: Weighted (FreeSampler m) a,
... }
This involves us looking at what the type FreeSampler is.
Bayes.Free
-- | Random sampling functor.
newtype SamF a = Random (Double -> a)
instance Functor SamF where
fmap f (Random k) = Random (f . k)
-- | Free monad transformer over random sampling.
newtype FreeSampler m a = FreeSampler {runFreeSampler :: FT SamF m a}
deriving (Functor, Applicative, Monad, MonadTrans)
The type FreeSampler is essentially the free monad transformer FT where the functor (the shape of the tree's nodes) is specified as the type SamF, and the monadic context we wrap our nodes in is left abstract.

The type SamF is the random sampling functor, which represents the idea of a random effect. It captures the notion that we provide a (random) Double and it returns some value of type a, i.e. the notion that we make a random decision at some point.

The type FT SamF m a is hence a free monad tree of random decisions. In some sense, this is the composition of a bunch of functions which take Doubles and return a's. Imagine a tree where at every step we provide a random number as a Double, which is then mapped to a value of type a – we can then consider this a to be a random value. The type a represents whatever we want samples of. We can think of (Double -> a) as a specification for a random number generator: a random number generator takes some seed as a Double and produces a random value of type a. Specifically in this context, the result type a will specialise to StateT [Double] (WriterT [Double] m) a, so the function becomes Double -> StateT [Double] (WriterT [Double] m) a. (This is discussed later, with the function withPartialRandomness.)
"So what does the free sampler FT SamF m a allow us to do?" Let's take a look at the function withRandomness: this executes a computation with supplied values for random choices.
Bayes.Free
-- | Execute computation with supplied values for random choices.
withRandomness :: Monad m => [Double] -> FreeSampler m a -> m a
withRandomness randomness (FreeSampler m) = evalStateT (iterTM f m) randomness
where
f (Random k) = do
xs <- get
case xs of
[] -> error "FreeSampler: the list of randomness was too short"
y : ys -> put ys >> k y
It takes a list of random doubles and a free monad tree of random decisions waiting to happen. We can imagine that m is the computation we want to execute, and embedded in this is some source of randomness given by the SamF functor, or more specifically, the Random k where k :: Double -> a – however, we need to provide a random Double to k each time in order to achieve this randomness. How this list of random Doubles is generated in the first place is up to the user.
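For instance (a hypothetical snippet, assuming FreeSampler's MonadSample instance represents each call to random as one SamF node, which is how monad-bayes defines it), a program with two random draws consumes two doubles from the supplied list:

-- Two `random` calls become two SamF nodes in the free tree.
twoDraws :: Monad m => FreeSampler m (Double, Double)
twoDraws = do
  x <- random    -- consumes the first supplied double
  y <- random    -- consumes the second supplied double
  return (x, y)

-- withRandomness [0.1, 0.9] twoDraws  should evaluate to  return (0.1, 0.9)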
What trace mcmc will do in order to perform inference is modify this supplied list of doubles and see what the consequent program output will be. The approach given in Lightweight Implementations of Probabilistic Programming Languages for probabilistic programming via trace mcmc treats trace mcmc as a meta-program. It takes the source code of the program, and it takes the source of randomness, and it does some meta-level programming to rewrite the source code, naming each of the original random functions f_k, and then replacing them with deterministic functions f'_k. When these deterministic functions are encountered in the execution trace, they deterministically use their name to look up a current value x_k in a database and return it to be used as the random double. For each full run of the program, the doubles stored in the database are manipulated in order to perform inference.
However, it turns out that if we use the free monad transformer, we are producing the abstract interpreter of the program within the language itself. In other words, the free monad transformer lets us internalize the approach described in the paper within the language, without any meta-programming. In monad-bayes, rather than naming each random function f_k to associate it with a random double stored in a database, we treat the "names" as the nodes/steps in the free monad transformer tree, each of which corresponds to an index in a list of random doubles – so the k'th step in the tree can be mapped to the k'th index in the list of random doubles. So instead of doing things by "name", we create associations between random functions and their random doubles by their order in a list. The list's elements match up with the sequential nature of the program.
5.2 Trace (Bayes.Traced.Common)
Let's now inspect the traceDist field of Traced.
Bayes.Traced.Static
data Traced m a
= Traced
{ ... ,
traceDist :: m (Trace a)
}
This involves looking at how the trace datatype Trace is defined:
Bayes.Traced.Common
-- | Collection of random variables sampled during the program's execution.
data Trace a
= Trace
{ -- | Sequence of particular realisations of random variables sampled during the program's execution.
variables :: !([Double]),
-- |
output :: !a,
-- | The probability of observing this particular sequence.
density :: {-# UNPACK #-} !(Log Double)
}
This contains three things:

- The variables are the list of random doubles that we supply to the model/program during execution, where each double corresponds to a computational step of randomness in the program. This is the list whose values we must modify in order to observe new results and then perform inference.
- The output is the result of running a given model using the current variables (the list of random doubles).
- The density is the probability density of the sequence of doubles, variables, giving rise to the output, given some external model.
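For intuition, here is what a Trace value could look like for a (made-up) model that draws two uniform randoms and returns their sum; the numbers are purely illustrative:

exampleTrace :: Trace Double
exampleTrace = Trace
  { variables = [0.3, 0.7]  -- the raw doubles fed to the model's two `random` calls
  , output    = 1.0         -- the model's result under exactly those draws
  , density   = 1           -- joint density of this sequence under the program
  }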
6. The Trace Metropolis-Hastings Function
6.1 Trace Metropolis-Hastings
Let's now have a look at how the Traced data type is used in the implementation of trace metropolis-hastings (a method of inference of the posterior), by inspecting the function mh.
Bayes.Traced.Static
-- | Full run of the Trace Metropolis-Hastings algorithm with a specified
-- number of steps.
mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced m d) = fmap (map output) (f n)
where
f 0 = fmap (: []) d
f k = do
~(x : xs) <- f (k -1)
y <- mhTrans m x
return (y : x : xs)
The function mh implements the traced metropolis-hastings algorithm (where metropolis-hastings is a Markov chain Monte Carlo method for obtaining a sequence of random samples from a probability distribution).

For a specified number of steps, n, it runs mhTrans on the model m (given by the field model :: Weighted (FreeSampler m) a of Traced) and the current trace x (obtained by extracting the result from the field traceDist :: m (Trace a) of Traced).
The function f then returns the history of all the different values of the Trace data type used during each call to mhTrans. By calling fmap (map output) on this list of Trace values, we extract all the outputs of the model produced during each metropolis-hastings step. Note that although the variables and the density of the current Trace are needed for each step of metropolis-hastings (see mhTrans) in order to "optimise" the Trace as we progress during inference, at the end of the entire metropolis-hastings algorithm we are only interested in recording all the outputs which are accepted. The history of all the outputs is what forms the Markov chain, and in theory, these outputs are the samples whose histogram will form our posterior distribution, i.e. our end goal.
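A hedged usage sketch (assuming normal, normalPdf, score, prior and sampleIO from monad-bayes are in scope; the model and numbers are invented): the output of mh is the Markov chain of accepted outputs, which we histogram to approximate the posterior.

-- Draw roughly 1000 correlated posterior samples for mu given one observation of 0.5.
-- The model lives in Traced (Weighted SamplerIO); `prior` strips the Weighted
-- layer and `sampleIO` runs the sampler.
posteriorSamples :: IO [Double]
posteriorSamples =
  sampleIO . prior . mh 1000 $ do
    mu <- normal 0 1
    score (normalPdf mu 1 0.5)
    return mu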
6.2 A Single Step of Trace Metropolis-Hastings
The function mhTrans performs a single step of trace metropolis-hastings. Let's now take a look at it:
Bayes.Traced.Common
-- | A single Metropolis-corrected transition of single-site Trace MCMC.
mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
mhTrans m t@Trace {variables = us, density = p} = do
let n = length us
us' <- do -- :: m [Double]
i <- discrete $ discreteUniformAB 0 (n - 1)
u' <- random
let (xs, _ : ys) = splitAt i us
return $ xs ++ (u' : ys)
((b, q), vs) <- runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m
let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
accept <- bernoulli ratio
return $ if accept then Trace vs b q else t
This function takes as inputs the model we’re performing inference over (i.e. the free tree containing computational steps of random decisions), and the current trace (of which we are only interested in the variables and the density).
- The block of do-notation following us' <- do describes the act of reassigning the value of one double in the list of random doubles. The way we do this is: first, randomly select an index in this list, given by i <- discrete $ discreteUniformAB 0 (n - 1). Next, generate a new random value, given by u' <- random. Finally, insert the new random value into the list, given by let (xs, _ : ys) = splitAt i us; return $ xs ++ (u' : ys).
- Given this updated list, we re-run the model with the new source of randomness, seen in ((b, q), vs) <- runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m, where b is the model's observed output, vs is the list of random doubles used to run the model, and q is the probability density of the sequence of doubles vs giving rise to the output b. (The density is the joint probability of getting all the realisations of the random variables.)

  Note that us' and vs will be the same only if the list of randomness us' is the same length as the number of random decisions made during the program (i.e. the same length as the number of nodes in the free monad transformer tree). If only a subset of random doubles is provided (i.e. the list is shorter than the tree), then the function withPartialRandomness will create new random values for us and append them to the list us'. If the list us' is too long, then the returned list vs will contain only the random doubles of us' actually used during execution (hence vs will always be the same length as the free tree).
- Then we determine whether the resulting probability density q of observing the current sequence of doubles used, vs, is better or worse than the probability density p of observing the previous sequence of doubles used, us. We do this by evaluating the "ratio" given by let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs))), and then passing the ratio to a bernoulli function, which returns a boolean telling us whether to accept or reject the newly proposed sequence of doubles. If we accept, then we return the new trace data type; otherwise, we retain the old one.
Summary: Trace Metropolis-Hastings
The function mh is hence an iterative process of randomly updating the list of random doubles as a new proposal (stored inside the Trace data type), running the model to see whether the output has a better probability density of yielding the provided list of random doubles, and updating the Trace data type whenever the probability density satisfies the acceptance ratio. In a sense, we are keeping the execution trace which is best – because the execution trace fully defines the program, we know what the parameters are. Note that we never make changes to the model itself.
The reason trace mcmc can be viewed as “better” than normal mcmc, is that it is more general. In normal metropolis hastings, one has to specify what the parameters of the model are. In trace mcmc, the parameters automatically become the random doubles associated with each random computational step in the program – what we are doing is performing metropolis-hastings on the execution trace of the program itself. The trace gives us all the information about all the random decisions performed in the program, i.e. everything about the probabilistic nature of the program. Hence, the algorithm will work for all programming languages. The downside to performing inference over all the random decisions made in the program, is that not all random decisions may be related to the parameters of the model we are interested in.
For example, the following block of code inside the function mhTrans, which randomly selects an index in the list of doubles to change, is a fairly naive, brute-force approach to inference.
Bayes.Traced.Common
us' <- do -- :: m [Double]
i <- discrete $ discreteUniformAB 0 (n - 1)
u' <- random
let (xs, _ : ys) = splitAt i us
return $ xs ++ (u' : ys)
So how does one recognise which random decisions in the program are actually important to the parameters of our model? This is a question that currently needs more looking into.
7. PMMH
Given a prior and a likelihood function, normal metropolis-hastings is an algorithm for producing samples from the posterior distribution, and hence allows us to approximate the posterior without having the analytic formulation of the normalising constant.
PMMH or PMCMC methods are normal MH/MCMC algorithms where we don’t have the analytic formulation of the likelihood function, so it is instead replaced by a particle filter as an approximation to the likelihood function (we can run a particle filter, and at the end, we can extract an estimate of the likelihood function from this). This means that at the highest level, metropolis-hastings is run to draw samples from the posterior distribution. Each step of the metropolis-hastings algorithm produces a proposal sample – in order to correctly decide whether a proposal sample is “representative enough” of the posterior distribution, each proposed sample has its likelihood calculated (by using the particle filter). This likelihood is what determines whether we accept or reject proposals.
Particle metropolis-hastings is implemented in the function pmmh
:
Bayes.Inference.PMMH
pmmh ::
MonadInfer m =>
-- | number of Metropolis-Hastings steps
Int ->
-- | number of time steps
Int ->
-- | number of particles
Int ->
-- | model parameters prior
Traced m b ->
-- | model
(b -> Sequential (Population m) a) ->
m [[(a, Log Double)]]
pmmh t k n param model =
mh t (param >>= runPopulation . pushEvidence . Pop.hoist Traced.Static.lift .
smcSystematic k n . model)
We can imagine that PMMH as a function is something which runs metropolis-hastings (mh) given a prior (param) and a likelihood function (runPopulation . pushEvidence . Pop.hoist lift . smcSystematic k n . model).
Below, we elaborate on the different components of this function’s definition.
- model - our model of type b -> Sequential (Population m) a, where b represents the initial parameters given to our model, a represents the type of sample we draw from our posterior distribution (which will most likely be the parameters of the model we're trying to infer), and Sequential (Population m) a represents the particle filter computation.

- smcSystematic - the particle filter (Sequential Monte Carlo), which uses systematic resampling at each timestep. Hence composing smcSystematic with model runs the particle filter on the model we give it.

  sir resampler k n = finish . composeCopies k (advance . hoistFirst resampler) . Seq.hoistFirst (spawn n >>)

  -- | Sequential Monte Carlo with systematic resampling at each timestep.
  -- Weights are not normalized.
  smcSystematic ::
    MonadSample m =>
    -- | number of timesteps
    Int ->
    -- | number of particles
    Int ->
    -- | model
    Sequential (Population m) a ->
    Population m a
  smcSystematic = sir resampleSystematic
What we are doing with a particle filter is simulating a set of particles with their associated weights, and it turns out that summing these weights is an estimate of the likelihood (of the sequence of random values sampled during the program's execution giving rise to the program's output).
- Pop.hoist Bayes.Traced.Static.lift has type Population m a -> Population n a – specifically, in this case it has type Population m a -> Population (Traced m) a, since lift comes from the MonadTrans instance of Traced.

  hoist :: (Monad m, Monad n) => (forall x. m x -> n x) -> Population m a -> Population n a
  hoist f = fromWeightedList . f . runPopulation
- pushEvidence

  -- | Push the evidence estimator as a score to the transformed monad.
  -- Weights are normalized after this operation.
  pushEvidence :: MonadCond m => Population m a -> Population m a
  pushEvidence = (hoist applyWeight) . extractEvidence

  This will:

  - Use extractEvidence to get the likelihood from the particle filter (as the sum of the weights of all the particles) and make this likelihood explicit in the Weighted m layer of Population (Weighted m) a. Recall that a Population is a list of weighted samples, but these weights are implicit in the Population transformer.

    extractEvidence :: Monad m => Population m a -> Population (Weighted m) a

  - Use hoist applyWeight to factor this likelihood as a score in the transformed monad m, which is possible due to m being an instance of MonadCond.

    hoist applyWeight :: MonadCond m => Population (Weighted m) a -> Population m a
- runPopulation

  runPopulation :: Population m a -> m [(a, Log Double)]
  runPopulation (Population m) = runListT $ runWeighted m

  Calling runPopulation then extracts an explicit representation of the weighted particles.
- param >>= ... . model

  In this code snippet, the prior distribution of our parameters, param, is represented in the form Traced m b, where b represents the prior parameters – this b is monadically extracted and then passed to our model :: b -> Sequential (Population m) a.
- mh t ...

  The metropolis-hastings function mh is run for t metropolis-hastings steps, using (param >>= runPopulation . pushEvidence . Pop.hoist Traced.Static.lift . smcSystematic k n . model) as a likelihood function to draw samples from the posterior distribution that are representative of the posterior.
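A hedged end-to-end sketch of calling pmmh (the parameter prior, the model and all numbers are invented; MonadInfer, uniform, normal, normalPdf, score, prior, sampleIO, Traced, Sequential, Population and Log are assumed to be in scope from monad-bayes, with imports omitted since module layout varies by version):

-- Prior over a single parameter p.
toyParam :: MonadInfer m => Traced m Double
toyParam = uniform 0 1

-- A one-observation model parameterised by p.
toyObsModel :: MonadInfer m => Double -> Sequential (Population m) Double
toyObsModel p = do
  x <- normal p 1
  score (normalPdf x 1 0.5)
  return x

-- 200 MH steps, 10 SMC timesteps, 50 particles per SMC run.
runToyPMMH :: IO [[(Double, Log Double)]]
runToyPMMH = sampleIO . prior $ pmmh 200 10 50 toyParam toyObsModel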
8. Some brainstorming on how the Free Sampler is used
Let's start off in pmmh:
-- | Particle Marginal Metropolis-Hastings sampling.
pmmh ::
MonadInfer m =>
-- | number of Metropolis-Hastings steps
Int ->
-- | number of time steps
Int ->
-- | number of particles
Int ->
-- | model parameters prior
Traced m b ->
-- | model
(b -> Sequential (Population m) a) ->
m [[(a, Log Double)]]
pmmh t k n param model =
mh t (param >>= runPopulation . pushEvidence . Pop.hoist Bayes.Traced.Static.lift .
smcSystematic k n . model)
- Our model has type b -> Sequential (Population m) a.

- The partially applied function smcSystematic k n (the particle filter) has type Sequential (Population m) a -> Population m a.

- Pop.hoist Bayes.Traced.Static.lift has type Population m a -> Population (Traced m) a.
When we call Bayes.Traced.Static.lift on m, lift has the following definition:
data Traced m a
= Traced
{ -- "Weighted (FreeSampler m a)" is "StateT (Log Double) (FT SamF m) a"
model :: Weighted (FreeSampler m) a,
traceDist :: m (Trace a)
}
instance MonadTrans Traced where
lift m = Traced (lift $ lift m) (fmap pure m)
- The model field is obtained by:

  - First lifting m a into FreeSampler m a (which is a newtype for FT SamF m a). Hence, this lift is FT's lift.

    Using the non-church-encoded version of the free monad transformer, this looks like the following, where Pure :: a -> FreeF f a b, so liftM Pure :: m a -> m (FreeF f a b). This essentially creates a value m (FreeF f a b) where the FreeF layer is a single Pure node.

    data FreeF f a b = Pure a | Free (f b)
    newtype FreeT f m a = FreeT { runFreeT :: m (FreeF f a (FreeT f m a)) }
    instance MonadTrans (FreeT f) where
      lift = FreeT . liftM Pure

    Using the church-encoded version (which monad-bayes uses), this is in continuation-passing style – we can see that it is analogous to the version above, as it constructs an FT containing a function which takes the continuation for the Pure constructor and binds it with the monad we're lifting.

    newtype FT f m a = FT { runFT :: forall r. (a -> m r) -> (forall x. (x -> m r) -> f x -> m r) -> m r }
    instance MonadTrans (FT f) where
      lift m = FT (\mf _ -> m >>= mf)

    This means that we construct an FT tree consisting of a single Pure node.

  - Then lifting this into Weighted (FreeSampler m) a (which is a newtype for StateT (Log Double) (FT SamF m) a).
- The traceDist field is obtained by mapping the pure function for the Trace data type over m a to return m (Trace a).

  data Trace a = Trace { variables :: [Double], output :: a, density :: Log Double }

  instance Applicative (TraceTCPS m) where
    pure x = TraceTCPS (\k -> k [] x 1)
When running mh, we take the model :: StateT (Log Double) (FT SamF m) a of the Traced data type, and pass it to mhTrans along with the most recent Trace data type computed during the previous step of metropolis-hastings.
mh :: MonadSample m => Int -> Traced m a -> m [a]
mh n (Traced m d) = fmap (map output) (f n)
where
f 0 = fmap (: []) d
f k = do
~(x : xs) <- f (k -1)
y <- mhTrans m x
return (y : x : xs)
Inside mhTrans, on the line ((b, q), vs) <- ..., this is where we use our model :: StateT (Log Double) (FT SamF m) a.
mhTrans :: MonadSample m => Weighted (FreeSampler m) a -> Trace a -> m (Trace a)
mhTrans m t@Trace {variables = us, density = p} = do
let n = length us
us' <- do -- :: m [Double]
i <- discrete $ discreteUniformAB 0 (n -1)
u' <- random
let (xs, _ : ys) = splitAt i us
return $ xs ++ (u' : ys)
-- Weighted.hoist (WriterT . withPartialRandomness us')
-- :: Weighted (FreeSampler m) a -> Weighted (WriterT [Double] m) a
((b, q), vs) <- runWriterT $ runWeighted $ Weighted.hoist (WriterT . withPartialRandomness us') m
let ratio = (exp . ln) $ min 1 (q * fromIntegral n / (p * fromIntegral (length vs)))
accept <- bernoulli ratio
return $ if accept then Trace vs b q else t
- The type of WriterT . withPartialRandomness us' is FreeSampler m a -> WriterT [Double] m a (or, in other words, FT SamF m a -> WriterT [Double] m a).
- Hence, Weighted.hoist (WriterT . withPartialRandomness us') has type Weighted (FreeSampler m) a -> Weighted (WriterT [Double] m) a (or, in other words, StateT (Log Double) (FT SamF m) a -> StateT (Log Double) (WriterT [Double] m) a).
The function withPartialRandomness is given below. It takes the list of randomness and a free sampler m :: FT SamF m a.
newtype SamF a = Random (Double -> a)
withPartialRandomness :: MonadSample m => [Double] -> FreeSampler m a -> m (a, [Double])
-- newtype FreeSampler m a = FreeSampler {runFreeSampler :: FT SamF m a}
withPartialRandomness randomness (FreeSampler m) =
-- f :: (SamF (StateT [Double] (WriterT [Double] m) a)
-- -> StateT [Double] (WriterT [Double] m) a)
-- iterTM :: (f (t m a) -> t m a) -> FT f m a -> t m a
-- iterTM f :: FT SamF (WriterT [Double] m) a -> StateT [Double] (WriterT [Double] m) a
-- hoistFT lift m :: FT SamF (WriterT [Double] m) a
-- m :: FT SamF m a
runWriterT $ evalStateT (iterTM f $ hoistFT lift m) randomness
where
-- k :: Double -> StateT [Double] (WriterT [Double] m) a
f (Random k) = do
-- This block runs in StateT [Double] (WriterT [Double] m) a.
-- StateT propagates consumed randomness while WriterT records
-- randomness used, whether old or new.
-- Get the state, i.e. the list of randomness "xs"
xs <- get
x <- case xs of
-- If the list of randomness is empty, then let x be a random double
[] -> random
-- If the list of randomness is non-empty, then update the state to be the
-- tail of the list, and return the first random double value from the list.
y : ys -> put ys >> return y
-- Append the random value we just retrieved to the Writer log.
tell [x]
-- Apply the function "k" to the random double, to yield the type
-- "StateT [Double] (WriterT [Double] m) a",
k x
- hoistFT lift m lifts the monad m inside FT SamF m a into WriterT, yielding the type FT SamF (WriterT [Double] m) a.

  The type FT SamF m a unravels to m (FreeF SamF a (FreeT SamF m a)). So hoistFT lift m :: FT SamF (WriterT [Double] m) a unravels to (WriterT [Double] m) (FreeF SamF a (FreeT SamF (WriterT [Double] m) a)).

  data FreeF f a x = Pure a | Free (f x)
  newtype FreeT f m a = FreeT { runFreeT :: m (FreeF f a (FreeT f m a)) }

- f has type SamF (StateT [Double] (WriterT [Double] m) a) -> StateT [Double] (WriterT [Double] m) a. In other words, it takes a SamF node containing a function k :: Double -> StateT [Double] (WriterT [Double] m) a. What f does is described in the comments in the code above.

- iterTM has type (f (t m a) -> t m a) -> FT f m a -> t m a.

- iterTM f applied to hoistFT lift m tears down the free monad transformer using iteration over a transformer. Iterating over the tree with f constructs a StateT [Double] (WriterT [Double] m) a computation.

- evalStateT then uses randomness as the initial state to run the stateful computation which iterTM f $ hoistFT lift m constructs.
The reason the random function k :: Double -> a specialises to the type Double -> StateT [Double] (WriterT [Double] m) a is because we define f to work in the context of StateT [Double] (WriterT [Double] m) a.