Inference Transformers
Inference Representation Type Classes
type R = Double

class Monad m => MonadSample m where
  random :: m R
  bernoulli :: R -> m Bool
  bernoulli p = fmap (< p) random
  -- and other default distributions:
  -- normal, gamma, beta, geometric,
  -- poisson, dirichlet

class Monad m => MonadCond m where
  score :: Log R -> m ()

class (MonadSample m, MonadCond m) => MonadInfer m
Sampler
This inference transformer is a sampler that draws concrete values for random variables from the prior. The sampler type can be constructed as a reader monad over ST whose environment is a reference to a mutable pseudo-random number generator.
newtype Sampler a = Sampler (forall s. ReaderT (GenST s) (ST s) a)

instance MonadSample Sampler where
  random = Sampler $ do
    gen <- ask
    -- MWC.uniform is the qualified name of the generator's uniform draw
    lift (MWC.uniform gen)
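The same reader-over-ST shape can be sketched without the MWC library. This is only an illustration, not the library's implementation: Gen, ToySampler, sampleFrom and twoDraws are hypothetical names, and the linear congruential generator is a toy stand-in for MWC that is not statistically sound.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.ST (ST, runST)
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.Reader (ReaderT, ask, runReaderT)
import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef)

-- Hypothetical toy generator: a linear congruential state held in an STRef.
-- It stands in for GenST s from the notes.
type Gen s = STRef s Integer

type ToySampler s = ReaderT (Gen s) (ST s)

-- Fetch the generator with ask, advance it, and map its state into [0, 1).
random :: ToySampler s Double
random = do
  gen <- ask
  lift $ do
    modifySTRef' gen (\x -> (1103515245 * x + 12345) `mod` 2147483648)
    x <- readSTRef gen
    return (fromIntegral x / 2147483648)

-- Run a sampler from a seed; the mutable generator never escapes runST.
sampleFrom :: Integer -> (forall s. ToySampler s a) -> a
sampleFrom seed m = runST (newSTRef seed >>= runReaderT m)

-- Two consecutive draws share the generator through the reader environment.
twoDraws :: (Double, Double)
twoDraws = sampleFrom 1 ((,) <$> random <*> random)

main :: IO ()
main = print twoDraws
```

Because the generator lives in the reader environment, every call to random sees the same mutable reference, which is why the two draws differ while the whole run stays deterministic for a given seed.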
ReaderT
newtype ReaderT r m a = ReaderT { runReaderT :: r -> m a }
instance (Monad m) => Monad (ReaderT r m) where
  return = lift . return
  m >>= k = ReaderT $ \r -> do
    a <- runReaderT m r
    runReaderT (k a) r
- r = resource from the environment
- m = the inner monad
- a = value returned in the monad
ask :: Monad m => ReaderT r m r
- ask = retrieve the supplied resource
Weighting
This is the weighting inference transformer W. (Theoretically it is the writer monad transformer for the monoid structure on Log R, but the state monad transformer was apparently found to be much faster.)
newtype W m a = W (StateT (Log R) m a)

runW :: Monad m => W m a -> m (a, Log R)
runW (W x) = runStateT x 1

instance MonadSample m => MonadSample (W m) where
  -- random is an action, not a function, so it is lifted directly
  random = W (lift random)

instance Monad m => MonadCond (W m) where
  score w = W (modify (* w))

hoistW :: (forall x. m x -> n x) -> W m a -> W n a
hoistW f (W m) = W (mapStateT f m)
W m a is an m-computation returning pairs (a, Log R): the result value and the accumulated weight, stored in the log domain. Weighting a representation m equips it with a conditioning operation, making W m into a conditioning representation in which the function score can be used.
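The way score accumulates a weight can be sketched in a self-contained form. The following is a simplification under stated assumptions, not the library code: the weight is a plain Double rather than Log R, the classes are cut down to one method each, and Replay, runReplay and model are hypothetical names. Replay stands in for a real sampler by replaying a fixed list of draws.

```haskell
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
import Control.Monad.Trans.Class (lift)
import Control.Monad.Trans.State.Strict (State, StateT, evalState, modify, runStateT, state)

type R = Double

-- Cut-down versions of the classes from these notes.
class Monad m => MonadSample m where
  random :: m R

class Monad m => MonadCond m where
  score :: R -> m ()  -- assumption: a plain Double weight instead of Log R

-- Hypothetical stand-in for a sampler: replays a fixed list of "uniform" draws.
newtype Replay a = Replay (State [R] a)
  deriving (Functor, Applicative, Monad)

instance MonadSample Replay where
  random = Replay (state pop)
    where
      pop (u : us) = (u, us)
      pop [] = (0.5, [])  -- arbitrary fallback once the list runs out

runReplay :: Replay a -> [R] -> a
runReplay (Replay s) = evalState s

-- The weighting transformer: the state is the accumulated weight.
newtype W m a = W (StateT R m a)
  deriving (Functor, Applicative, Monad)

runW :: W m a -> m (a, R)
runW (W x) = runStateT x 1

instance MonadSample m => MonadSample (W m) where
  random = W (lift random)

instance Monad m => MonadCond (W m) where
  score w = W (modify (* w))

-- A toy model: draw a coin bias, then observe one head.
model :: (MonadSample m, MonadCond m) => m R
model = do
  p <- random
  score p  -- likelihood of a head given bias p
  return p

main :: IO ()
main = print (runReplay (runW model) [0.25])
```

The model is written once against the two classes; choosing W Replay as the representation is what gives score a meaning, while random is simply forwarded to the inner monad.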
The function hoistW takes an inference transformation applicable to m and turns it into an inference transformation applicable to W m.
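As a small illustration of hoisting, the sketch below pushes a transformation between inner monads through the weighting layer. It assumes a plain Double weight instead of Log R, and weighted and asList are hypothetical names; maybeToList from base plays the role of the transformation on the inner monad.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Monad.Trans.State.Strict (StateT (..), mapStateT, modify)
import Data.Maybe (maybeToList)

-- Weighting with a plain Double weight (assumption; the notes use Log R).
newtype W m a = W (StateT Double m a)

runW :: W m a -> m (a, Double)
runW (W x) = runStateT x 1

hoistW :: (forall x. m x -> n x) -> W m a -> W n a
hoistW f (W m) = W (mapStateT f m)

-- A weighted computation over Maybe that scores 0.5 and returns 'x'.
weighted :: W Maybe Char
weighted = W (modify (* 0.5) >> return 'x')

-- maybeToList maps Maybe to lists, so hoistW maybeToList maps W Maybe to W [].
asList :: [(Char, Double)]
asList = runW (hoistW maybeToList weighted)

main :: IO ()
main = print asList
```

The accumulated weight is untouched by the hoist; only the inner monad changes.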
StateT
newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }
instance (Monad m) => Monad (StateT s m) where
  return a = StateT $ \s -> return (a, s)
  m >>= k = StateT $ \s -> do
    ~(a, s') <- runStateT m s
    runStateT (k a) s'
- s = the state
- m = the inner monad
- a = value returned in the monad
Population
The Pop inference transformer turns a single sample into a collection of weighted samples, called the population. It is the weighted list transformer, i.e. the composition of W with the ListT transformer.
newtype Pop m a = Pop (W (ListT m) a)
  deriving (Monad, MonadSample, MonadCond, MonadInfer)

-- The result is a list of (value, weight) pairs.
runPopulation :: Monad m => Pop m a -> m [(a, Log R)]
runPopulation (Pop p) = runListT (runW p)

hoistP :: (forall x. m x -> n x) -> Pop m a -> Pop n a
hoistP f (Pop m) = Pop (hoistW (mapListT f) m)
There are three inference transformations associated with the Pop inference transformer.
Prefixing a computation with spawn n >> multiplies the population size by n, dividing the weights by n so that the total weight is unchanged.
spawn :: Monad m => Int -> Pop m ()
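The effect of spawn n >> on the weights can be sketched with plain lists. This is a simplification, not the library code: Population, spawnList and andThen are hypothetical names, weights are plain Doubles, and the list stands in for Pop m. In Pop m the n copies would be independent runs in m, whereas here they are identical.

```haskell
-- A population as a plain list of (value, weight) pairs.
type Population a = [(a, Double)]

-- n unit particles, each carrying weight 1/n.
spawnList :: Int -> Population ()
spawnList n = replicate n ((), 1 / fromIntegral n)

-- Sequencing for weighted lists: weights multiply along each branch.
andThen :: Population a -> Population b -> Population b
andThen p q = [ (y, w * v) | (_, w) <- p, (y, v) <- q ]

main :: IO ()
main = print (spawnList 2 `andThen` [('a', 1.0), ('b', 0.5)])
```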
Calling resample draws a new population with uniform weights from the current population.
resample :: MonadSample m => Pop m a -> Pop m a
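One common way to realise such a step is multinomial resampling, sketched here on plain lists. The notes do not say which resampling scheme is used, so treat this as an assumption; Population, pick and resampleWith are hypothetical names, and the uniform draws are supplied by the caller instead of coming from MonadSample.

```haskell
type Population a = [(a, Double)]

-- Pick the particle whose cumulative-weight interval contains u * total.
pick :: Population a -> Double -> a
pick pop u = go pop (u * total)
  where
    total = sum (map snd pop)
    go [(x, _)] _ = x  -- last particle also absorbs rounding error
    go ((x, w) : rest) t
      | t < w = x
      | otherwise = go rest (t - w)
    go [] _ = error "pick: empty population"

-- Multinomial resampling driven by caller-supplied uniforms in [0, 1).
-- One particle is drawn per uniform; all new weights are equal and
-- the total weight of the population is preserved.
resampleWith :: [Double] -> Population a -> Population a
resampleWith us pop =
  [ (pick pop u, total / fromIntegral (length us)) | u <- us ]
  where
    total = sum (map snd pop)

main :: IO ()
main = print (resampleWith [0.1, 0.5, 0.9] [('a', 1.0), ('b', 3.0)])
```

Particles with larger weights are picked more often, so after resampling the information carried by the weights is represented by the multiplicity of each particle.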
Calling pushEvidence normalises the weights in the population, while simultaneously incorporating the sum of the weights as a score in m.
pushEvidence :: MonadInfer m => Pop m a -> Pop m a
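The arithmetic behind this step can be sketched on plain lists. Population and normalise are hypothetical names; the returned total stands in for the value that would be passed to score in the inner monad m.

```haskell
type Population a = [(a, Double)]

-- Split a population into its total weight (the evidence estimate)
-- and a copy whose weights sum to one.
-- Assumes the total weight is strictly positive.
normalise :: Population a -> (Double, Population a)
normalise pop = (z, [ (x, w / z) | (x, w) <- pop ])
  where
    z = sum (map snd pop)

main :: IO ()
main = print (normalise [('a', 1.0), ('b', 3.0)])
```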
ListT
newtype ListT m a = ListT { next :: m (Step m a) }
  deriving (Foldable, Traversable)

data Step m a = Cons a (ListT m a) | Nil

instance Monad m => Monad (ListT m) where
  return = pure
  ListT m >>= k = ListT $ do
    s <- m
    case s of
      Nil -> return Nil
      Cons x l' -> next (k x <|> (l' >>= k))
Sequential
Many models exhibit a sequential structure where observations are interleaved with sampling. In those models a possible inference strategy is to consider a program up to a certain point, do inference on the partial posterior it defines, then run the program a little more, do more inference, and so on.
The sequential transformer Seq, called Sequential in the code, represents a computation that can be suspended at certain points, by introducing a suspension after each score in the program. Seq is the standard coroutine transformer.
newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
  deriving (Functor, Applicative, Monad, MonadTrans, MonadIO)

-- | A point where the computation is paused.
suspend :: Monad m => Sequential m ()
suspend = Sequential await

instance MonadSample m => MonadSample (Sequential m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

-- | Execution is 'suspend'ed after each 'score'.
instance MonadCond m => MonadCond (Sequential m) where
  score w = lift (score w) >> suspend

instance MonadInfer m => MonadInfer (Sequential m)
The hoistS function applies the inference transformation only to the part of the program executed so far (i.e. it transforms the inner monad). The transformation is applied recursively through all the suspension points.
hoistS :: (Monad m, Monad n) => (forall x. m x -> n x) -> Sequential m a -> Sequential n a
hoistS f = Sequential . mapMonad f . runSequential
We have two inference transformations associated with Seq:
The advance transformation runs the program to the next suspension point.
advance :: Monad m => Sequential m a -> Sequential m a
advance = Sequential . bounce extract . runSequential
The finish transformation runs the program to the end.
-- | Remove the remaining suspension points.
finish :: Monad m => Sequential m a -> m a
finish = pogoStick extract . runSequential
Coroutine
-- Suspending, resumable monadic computations.
newtype Coroutine s m r = Coroutine {
  -- | Run the next step of a `Coroutine` computation. The result of the
  -- step execution will be either a suspension or the final coroutine result.
  resume :: m (Either (s (Coroutine s m r)) r)
  }
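To make suspension, advance and finish concrete, here is a cut-down coroutine that needs only base. It is a sketch under assumptions rather than the monad-coroutine implementation: Co, liftCo, countSuspensions and program are hypothetical names, and the suspension carries no data, mirroring Await (). The inner monad in the demo is the pair monad over [String], used as a simple log.

```haskell
-- A cut-down coroutine: each step is either a suspension or a final result.
newtype Co m a = Co { resume :: m (Either (Co m a) a) }

instance Monad m => Functor (Co m) where
  fmap f (Co step) = Co (fmap (either (Left . fmap f) (Right . f)) step)

instance Monad m => Applicative (Co m) where
  pure = Co . return . Right
  cf <*> cx = cf >>= \f -> fmap f cx

instance Monad m => Monad (Co m) where
  Co step >>= k = Co $ do
    r <- step
    case r of
      Left rest -> return (Left (rest >>= k))  -- keep the suspension
      Right a -> resume (k a)

-- Run an action of the inner monad inside the coroutine.
liftCo :: Monad m => m a -> Co m a
liftCo = Co . fmap Right

-- Pause here; the rest of the program becomes the continuation.
suspend :: Monad m => Co m ()
suspend = Co (return (Left (return ())))

-- Run through the first suspension point, leaving later ones in place.
advance :: Monad m => Co m a -> Co m a
advance c = liftCo (resume c) >>= either id return

-- Run to the end, ignoring all remaining suspension points.
finish :: Monad m => Co m a -> m a
finish (Co step) = step >>= either finish return

-- Count the suspension points met while running to completion.
countSuspensions :: Monad m => Co m a -> m Int
countSuspensions (Co step) =
  step >>= either (fmap (+ 1) . countSuspensions) (const (return 0))

-- Demo program with two suspension points, logging into the pair monad.
program :: Co ((,) [String]) Int
program = do
  liftCo (["step 1"], ())
  suspend
  liftCo (["step 2"], ())
  suspend
  return 42

main :: IO ()
main = do
  print (finish program)
  print (snd (countSuspensions program))
  print (snd (countSuspensions (advance program)))
```

Each call to advance removes exactly one suspension point, so an inference algorithm can interleave its own steps (for example resampling) between calls, and finish discards whatever suspensions are left.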
Traced
The traced transformer supports a class of algorithms called Trace MCMC, a variant of Metropolis-Hastings that aims to make the method more automatic. In ordinary Metropolis-Hastings you have to design a proposal yourself, whereas Trace MCMC gets around the need for a hand-written proposal.
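The notes give no details of the algorithm, so the following is only a generic sketch of one way a proposal can be obtained for free, not the Traced implementation. It assumes the trace is a fixed-length list of uniform draws, that a proposal redraws a single site from the prior, and that the current trace has positive weight. Under those assumptions the proposal is symmetric and the acceptance ratio reduces to a ratio of weights. Trace, mhStep and toyWeight are hypothetical names, and the random inputs are supplied by the caller.

```haskell
-- The random choices made by one run of the program.
type Trace = [Double]

-- One Metropolis-Hastings step on a trace.
-- Arguments: the weight (likelihood) of a trace, then a triple of
-- (site to redraw, fresh uniform for that site, uniform for the accept test).
mhStep :: (Trace -> Double) -> (Int, Double, Double) -> Trace -> Trace
mhStep weight (i, fresh, u) old
  | u < min 1 (weight new / weight old) = new  -- accept the proposal
  | otherwise = old                            -- reject, keep the old trace
  where
    new = take i old ++ [fresh] ++ drop (i + 1) old

-- Hypothetical weight function used only for the demo.
toyWeight :: Trace -> Double
toyWeight = product

main :: IO ()
main = do
  print (mhStep toyWeight (0, 0.9, 0.99) [0.5, 0.5])
  print (mhStep toyWeight (1, 0.1, 0.5) [0.5, 0.5])
```

The first proposal raises the weight and is accepted; the second lowers it to a fifth and, with an acceptance draw of 0.5, is rejected.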