Probabilistic Effects. λθ

Inference Transformers

Inference Representation Type Classes

type R = Double

class Monad m => MonadSample m where
   random :: m R
   bernoulli :: R -> m Bool
   bernoulli p = fmap (< p) random
   -- and other default distributions:
   -- normal, gamma, beta, geometric,
   -- poisson, dirichlet

class Monad m => MonadCond m where
   score :: Log R -> m ()

class (MonadSample m, MonadCond m) => MonadInfer m
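To see the two classes in action, here is a toy, self-contained instance. The `Enum'` type is hypothetical and not part of monad-bayes: it enumerates every branch of the program with its weight, `bernoulli` splits into two weighted branches, and `score` multiplies the weight of the current branch. Plain `Double` stands in for `Log R`, and only `bernoulli` is shown.

```haskell
type R = Double

class Monad m => MonadSample m where
  bernoulli :: R -> m Bool

class Monad m => MonadCond m where
  score :: R -> m ()

-- Hypothetical exact-enumeration representation: a weighted list of outcomes.
newtype Enum' a = Enum' { runEnum' :: [(a, R)] }

instance Functor Enum' where
  fmap f (Enum' xs) = Enum' [(f x, w) | (x, w) <- xs]

instance Applicative Enum' where
  pure x = Enum' [(x, 1)]
  Enum' fs <*> Enum' xs = Enum' [(f x, v * w) | (f, v) <- fs, (x, w) <- xs]

instance Monad Enum' where
  Enum' xs >>= k = Enum' [(y, w * v) | (x, w) <- xs, (y, v) <- runEnum' (k x)]

instance MonadSample Enum' where
  bernoulli p = Enum' [(True, p), (False, 1 - p)]

instance MonadCond Enum' where
  score w = Enum' [((), w)]

-- Is the coin fair? Prior 50/50 between a fair coin and one biased 0.9
-- towards heads; we observe a single head and score its likelihood.
model :: (MonadSample m, MonadCond m) => m Bool
model = do
  fair <- bernoulli 0.5
  let p = if fair then 0.5 else 0.9
  score p          -- likelihood of the observed head
  return fair
```

Running `runEnum' model` yields the unnormalised posterior `[(True, 0.25), (False, 0.45)]`: the biased coin is now the more likely explanation.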

Sampler

This inference transformer is a sampler that draws concrete values for random variables from the prior. Such a sampler can be constructed as a reader monad that supplies a mutable pseudo-random number generator, with the generator itself living in the ST monad.

newtype Sampler a = Sampler (forall s. ReaderT (GenST s) (ST s) a)

instance MonadSample Sampler where
   random = Sampler $ do
      gen <- ask
      lift (MWC.uniform gen)

ReaderT

newtype ReaderT r m a = ReaderT { runReaderT :: r -> m a }

instance (Monad m) => Monad (ReaderT r m) where
    return   = lift . return
    m >>= k  = ReaderT $ \ r -> do
        a <- runReaderT m r
        runReaderT (k a) r
  • r = resource from the environment
  • m = the resulting monad
  • a = value returned in the monad
ask :: Monad m => ReaderT r m r

ask retrieves the resource supplied to the computation.
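A quick check of ask, using the ReaderT from the transformers package with Identity as the inner monad (a minimal sketch; Sampler above uses the same pattern with ST as the inner monad and a generator as the resource):

```haskell
import Control.Monad.Trans.Reader (ReaderT (..), ask)
import Data.Functor.Identity (Identity (..))

-- Read the environment (here a plain Int) and use it in the computation.
double :: ReaderT Int Identity Int
double = do
  r <- ask          -- retrieve the supplied resource
  return (r * 2)
```

Evaluating `runIdentity (runReaderT double 21)` supplies 21 as the resource and yields 42.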


Weighting

This is the weighting inference transformer W. (Theoretically it is the writer monad transformer for the monoid structure on Log R, but the state monad transformer was apparently found to be much faster.)

newtype W m a = W (StateT (Log R) m a)
   deriving (Functor, Applicative, Monad, MonadTrans)

runW :: Monad m => W m a -> m (a, Log R)
runW (W x) = runStateT x 1

instance MonadSample m => MonadSample (W m) where
   random = lift random

instance Monad m => MonadCond (W m) where
   score w = W (modify (* w))

hoistW :: (forall x. m x -> n x) -> W m a -> W n a
hoistW f (W m) = W (mapStateT f m)

W m a is an m-computation returning pairs (a, Log R): the result together with the accumulated likelihood. Wrapping a representation m in W equips it with a conditioning operation, making W m a conditioning representation in which we can use score.

In general, a transformer T's hoist function turns inference transformations applicable to m into inference transformations applicable to T m; hoistW is this operation for W.
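The weight accumulation can be checked with StateT directly (a toy sketch using the same pattern as W, with plain Double standing in for Log R):

```haskell
import Control.Monad.Trans.State (StateT, modify, runStateT)
import Data.Functor.Identity (Identity (..))

-- Like runW: start the weight at 1; each score multiplies it in.
runW' :: Monad m => StateT Double m a -> m (a, Double)
runW' x = runStateT x 1

score' :: Monad m => Double -> StateT Double m ()
score' w = modify (* w)

-- Two conditioning steps; the final weight is their product.
example :: (String, Double)
example = runIdentity (runW' (score' 0.5 >> score' 0.25 >> return "done"))
-- ("done", 0.125)
```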

StateT

newtype StateT s m a = StateT { runStateT :: s -> m (a,s) }

instance (Monad m) => Monad (StateT s m) where
    return a = StateT $ \ s -> return (a, s)
    m >>= k  = StateT $ \ s -> do
        ~(a, s') <- runStateT m s
        runStateT (k a) s'
  • s = the state
  • m = the inner monad
  • a = value returned in the monad
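The state threading is easiest to see in a small example (the tick helper is hypothetical, introduced only for illustration):

```haskell
import Control.Monad.Trans.State (StateT, get, put, runStateT)
import Data.Functor.Identity (Identity (..))

-- Return the current counter value and advance the state.
tick :: Monad m => StateT Int m Int
tick = do
  n <- get
  put (n + 1)
  return n

-- Two ticks starting from 0: values (0, 1), final state 2.
ticks :: ((Int, Int), Int)
ticks = runIdentity (runStateT (do a <- tick; b <- tick; return (a, b)) 0)
```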

Population

The Pop inference transformer turns a single sample into a collection of weighted samples, called the population. It is the weighted list transformer, i.e. the composition of W with the ListT transformer.

newtype Pop m a = Pop (W (ListT m) a)
   deriving (Functor, Applicative, Monad, MonadSample, MonadCond, MonadInfer)

runPopulation :: Monad m => Pop m a -> m [(a, Log R)]
runPopulation (Pop p) = runListT (runW p)

hoistP :: (forall x. m x -> n x) -> Pop m a -> Pop n a
hoistP f (Pop m) = Pop (hoistW (mapListT f) m)

There are three inference transformations associated with the Pop inference transformer.

Calling spawn n (typically prefixed to a model as spawn n >> model) increases the population size n-fold, adjusting the weights accordingly.

spawn :: Monad m => Int -> Pop m ()

Calling resample draws a new population with uniform weights from the current population.

resample :: MonadSample m => Pop m a -> Pop m a

Calling pushEvidence normalises the weights in the population, while simultaneously incorporating the sum of the weights as a score in m.

pushEvidence :: MonadInfer m => Pop m a -> Pop m a
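The weight bookkeeping behind pushEvidence can be illustrated on a bare weighted list (a toy sketch, not the library's ListT-based implementation; Double stands in for Log R):

```haskell
-- Normalise a population's weights, returning the model evidence
-- (the sum of the unnormalised weights) alongside the new population.
pushEvidenceToy :: [(a, Double)] -> (Double, [(a, Double)])
pushEvidenceToy pop = (z, [(x, w / z) | (x, w) <- pop])
  where
    z = sum (map snd pop)
```

For instance, `pushEvidenceToy [("a", 0.3), ("b", 0.1)]` gives evidence 0.4 and normalised weights 0.75 and 0.25; in the real transformer the evidence is pushed into the inner monad as a score rather than returned directly.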

ListT

newtype ListT m a = ListT { next :: m (Step m a) }
    deriving (Foldable, Traversable)

data Step m a = Cons a (ListT m a) | Nil

instance Monad m => Monad (ListT m) where
    return = pure
    ListT m >>= k = ListT (do
        s <- m
        case s of
            Nil       -> return Nil
            Cons x l' -> next (k x <|> (l' >>= k)) )
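The listing above relies on Functor, Applicative, and Alternative instances (the bind uses `<|>`) that it omits. A self-contained version with the missing instances and a function to run the transformer, following the same Step representation:

```haskell
import Control.Applicative (Alternative (..))
import Data.Functor.Identity (Identity (..))

newtype ListT m a = ListT { next :: m (Step m a) }
data Step m a = Cons a (ListT m a) | Nil

instance Monad m => Functor (ListT m) where
  fmap f (ListT m) = ListT $ do
    s <- m
    case s of
      Nil      -> return Nil
      Cons x l -> return (Cons (f x) (fmap f l))

instance Monad m => Applicative (ListT m) where
  pure x = ListT (return (Cons x (ListT (return Nil))))
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monad m => Alternative (ListT m) where
  empty = ListT (return Nil)
  ListT a <|> b = ListT $ do
    s <- a
    case s of
      Nil      -> next b
      Cons x l -> return (Cons x (l <|> b))

instance Monad m => Monad (ListT m) where
  ListT m >>= k = ListT $ do
    s <- m
    case s of
      Nil       -> return Nil
      Cons x l' -> next (k x <|> (l' >>= k))

-- Run the transformer, collecting all results in the inner monad.
toListT :: Monad m => ListT m a -> m [a]
toListT (ListT m) = do
  s <- m
  case s of
    Nil      -> return []
    Cons x l -> fmap (x :) (toListT l)
```

For example, `runIdentity (toListT ((pure 1 <|> pure 2) >>= \x -> pure x <|> pure (10 * x)))` produces `[1, 10, 2, 20]`: bind distributes over the list while the inner monad sequences the effects.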

Sequential

Many models exhibit a sequential structure where observations are interleaved with sampling. In those models a possible inference strategy is to consider a program up to a certain point, do inference on the partial posterior it defines, then run the program a little more, do more inference, and so on.

The sequential transformer Sequential represents a computation that can be suspended at certain points, by introducing suspensions after each score in the program. Sequential is the standard coroutine transformer.

newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
   deriving (Functor, Applicative, Monad, MonadTrans, MonadIO)

-- | A point where the computation is paused.
suspend :: Monad m => Sequential m ()
suspend = Sequential await

instance MonadSample m => MonadSample (Sequential m) where
  random = lift random
  bernoulli = lift . bernoulli
  categorical = lift . categorical

-- | Execution is 'suspend'ed after each 'score'.
instance MonadCond m => MonadCond (Sequential m) where
  score w = lift (score w) >> suspend

instance MonadInfer m => MonadInfer (Sequential m)

The hoistS function applies the inference transformation only to the part of the program executed so far (i.e. it transforms the inner monad). The transformation is applied recursively through all the suspension points.

hoistS :: (Monad m, Monad n) => (forall x. m x -> n x) -> Sequential m a -> Sequential n a
hoistS f = Sequential . mapMonad f . runSequential

We have two inference transformations associated with Sequential:

The advance transformation runs the program to the next suspension point.

advance :: Monad m => Sequential m a -> Sequential m a
advance = Sequential . bounce extract . runSequential

The finish transformation runs the program to the end.

-- | Remove the remaining suspension points.
finish :: Monad m => Sequential m a -> m a
finish = pogoStick extract . runSequential
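These two transformations are the building blocks of Sequential Monte Carlo: run each particle to its next score, resample, and repeat until the program finishes. A hedged sketch of the composition (composeCopies is a hypothetical helper introduced here; the smc line is schematic and assumes the spawn, resample, advance, finish, and hoistS of the preceding sections):

```haskell
-- Apply a transformation k times in sequence.
composeCopies :: Int -> (a -> a) -> (a -> a)
composeCopies k f = foldr (.) id (replicate k f)

-- Schematic SMC with k resampling steps and n particles:
-- smc k n =
--   finish . composeCopies k (advance . hoistS resample) . hoistS (spawn n >>)
```

For instance, `composeCopies 3 (+ 2) 0` evaluates to 6.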

Coroutine

-- Suspending, resumable monadic computations.
newtype Coroutine s m r = Coroutine {
   -- | Run the next step of a `Coroutine` computation. The result of the
   -- step execution will be either a suspension or the final coroutine result.
   resume :: m (Either (s (Coroutine s m r)) r)
   }
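A minimal, self-contained version of the coroutine machinery makes the suspend/resume cycle concrete (a sketch; the real type, with the same resume shape, lives in the monad-coroutine package):

```haskell
import Control.Monad (ap)
import Data.Functor.Identity (Identity (..))

newtype Coroutine s m r = Coroutine
  { resume :: m (Either (s (Coroutine s m r)) r) }

-- The suspension functor used by Sequential: wait for a () before resuming.
newtype Await x y = Await (x -> y)

instance Functor (Await x) where
  fmap f (Await k) = Await (f . k)

instance (Functor s, Monad m) => Functor (Coroutine s m) where
  fmap f t = t >>= return . f

instance (Functor s, Monad m) => Applicative (Coroutine s m) where
  pure = Coroutine . return . Right
  (<*>) = ap

instance (Functor s, Monad m) => Monad (Coroutine s m) where
  t >>= f = Coroutine $ do
    e <- resume t
    case e of
      Left s  -> return (Left (fmap (>>= f) s))   -- push the bind past the suspension
      Right r -> resume (f r)

-- Suspend, yielding control until the driver supplies a ().
await :: Monad m => Coroutine (Await x) m x
await = Coroutine (return (Left (Await pure)))

-- Drive every suspension with the given handler until the result is reached.
pogoStick :: Monad m
          => (s (Coroutine s m r) -> Coroutine s m r)
          -> Coroutine s m r -> m r
pogoStick f t = resume t >>= either (pogoStick f . f) return

-- Example: suspend twice, then return a result.
program :: Monad m => Coroutine (Await ()) m Int
program = await >> await >> return 5

-- Resume each suspension immediately, as finish does.
runAwait :: Coroutine (Await ()) Identity r -> r
runAwait = runIdentity . pogoStick (\(Await k) -> k ())
```

`runAwait program` drives the computation through both suspension points and returns 5; advance corresponds to a single bounce off one suspension rather than the full pogoStick loop.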

Traced

The traced transformer supports a class of algorithms known as trace MCMC, a variant of Metropolis-Hastings that operates on the program's trace of random choices. Ordinary Metropolis-Hastings requires a hand-crafted proposal distribution; trace MCMC avoids this by deriving proposals automatically from perturbations of the recorded trace.

Last updated on 13 Nov 2020
Published on 13 Nov 2020