Probabilistic Effects. λθ

CPSing Monad Bayes

  • Implemented a CPS version of the StateT monad transformer and integrated it into monad-bayes:
-- | The pure state monad: the CPS 'StateT' transformer over 'Identity'.
type State s = StateT s Identity

{-# INLINE state #-}
-- | Embed a pure state transition. Its result and new state are handed
-- directly to the continuation (the pair is matched lazily).
state :: Monad m => (s -> (a, s)) -> StateT s m a
state step = StateT (\s0 k -> let (x, s1) = step s0 in k x s1)

{-# INLINE runState #-}
-- | Run a pure stateful computation, returning the result and final state.
runState :: State s a -> s -> (a, s)
runState mx s0 = runIdentity (runStateT mx s0)

{-# INLINE evalState #-}
-- | Run a pure stateful computation and keep only its result.
evalState :: State s a -> s -> a
evalState mx s0 = runIdentity (evalStateT mx s0)

{-# INLINE execState #-}
-- | Run a pure stateful computation and keep only its final state.
execState :: State s a -> s -> s
execState mx s0 = runIdentity (execStateT mx s0)

-- | A state transformer in continuation-passing style. Given an initial
-- state and a continuation that receives the result together with the new
-- state, it produces the final base-monad computation.
newtype StateT s m a = StateT
  { unStateT :: forall r. s -> (a -> s -> m r) -> m r }
  deriving Functor

{-# INLINE runStateT #-}
-- | Run from an initial state, collecting the result and the final state.
runStateT :: Monad m => StateT s m a -> s -> m (a, s)
runStateT mx s0 = unStateT mx s0 (\x s1 -> return (x, s1))

{-# INLINE evalStateT #-}
-- | Run from an initial state, keeping only the result.
evalStateT :: Monad m => StateT s m a -> s -> m a
evalStateT mx s0 = unStateT mx s0 (\x _ -> return x)

{-# INLINE execStateT #-}
-- | Run from an initial state, keeping only the final state.
execStateT :: Monad m => StateT s m a -> s -> m s
execStateT mx s0 = unStateT mx s0 (\_ s1 -> return s1)

{-# INLINE mapStateT #-}
-- | Transform the underlying computation. The wrapped computation is run to
-- an @m (a, s)@, @f@ maps that to @n (b, s)@, and the continuation resumes
-- with the new result and state.
mapStateT :: (Monad m, Monad n) => (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT (\s k -> f (runStateT m s) >>= \(b, s') -> k b s')

{-# INLINE modify #-}
-- | Apply a function to the state and return unit.
modify :: (Monad m) => (s -> s) -> StateT s m ()
modify g = StateT (\st k -> k () (g st))

{-# INLINE liftListen #-}
-- | Lift a @listen@-style operation through 'StateT'. The inner computation
-- is run to @m (a, s)@, listened to, and the continuation receives the result
-- paired with the output together with the new state. The pattern is lazy.
liftListen :: (Monad m) => Listen w m (a,s) -> Listen w (StateT s m) a
liftListen listenInner m = StateT (\s k ->
    listenInner (runStateT m s) >>= \ ~((a, s'), w) -> k (a, w) s')

{-# INLINE liftPass #-}
-- | Lift a @pass@-style operation through 'StateT'. The inner computation is
-- run with a continuation that packs the state next to the result and the
-- output-modifying function, @pass@ is applied, and the continuation resumes.
liftPass :: (Monad m) => Pass w m (a,s) -> Pass w (StateT s m) a
liftPass passInner m = StateT (\s k ->
    let inner = unStateT m s (\(x, f) s' -> return ((x, s'), f))
     in passInner inner >>= uncurry k)

-- | 'pure' hands the value and the untouched state to the continuation.
-- 'liftA2' threads the state from the first computation into the second.
instance Applicative (StateT s m) where
  {-# INLINE pure #-}
  pure x = StateT (\s k -> k x s)
  {-# INLINE liftA2 #-}
  liftA2 f mx my = StateT (\s0 k ->
    unStateT mx s0 (\x s1 ->
      unStateT my s1 (\y s2 -> k (f x y) s2)))

-- | Bind feeds the intermediate state into the next step. No constraint on
-- @m@ is required because all sequencing happens through continuations.
instance Monad (StateT s m) where
  {-# INLINE return #-}
  return = pure
  {-# INLINE (>>=) #-}
  mx >>= f = StateT (\s0 k -> unStateT mx s0 (\x s1 -> unStateT (f x) s1 k))

-- | Value recursion: the knot is tied over the pair returned by 'runStateT',
-- and the continuation is applied only afterwards.
instance MonadFix m => MonadFix (StateT s m) where
  {-# INLINE mfix #-}
  mfix f = StateT (\s k ->
    let step ~(x, _) = runStateT (f x) s
     in mfix step >>= uncurry k)

-- | Run the base action, then resume with its result and the unchanged state.
instance MonadTrans (StateT s) where
  {-# INLINE lift #-}
  lift action = StateT (\s k -> do
    x <- action
    k x s)

instance MonadIO m => MonadIO (StateT s m) where
  liftIO io = lift (liftIO io)

-- | Failure ignores both the state and the continuation.
instance MonadFail m => MonadFail (StateT s m) where
  fail msg = StateT (\_ _ -> Fail.fail msg)

-- | Choice passes the same state and the same continuation to both branches,
-- so the continuation runs inside each alternative.
-- NOTE(review): for base monads where @(a <|> b) >>= k@ differs from
-- @(a >>= k) <|> (b >>= k)@ (e.g. Maybe), this can behave differently from a
-- non-CPS StateT instance — confirm this is intended.
instance Alternative m => Alternative (StateT s m) where
  empty = StateT (\_ _ -> empty)
  mx <|> my = StateT (\s k -> unStateT mx s k <|> unStateT my s k)

instance MonadState s (StateT s m) where
  get = StateT (\s k -> k s s)
  put s' = StateT (\_ k -> k () s')

-- | Writer operations are lifted from the base monad. 'listen' and 'pass'
-- go through 'liftListen' and 'liftPass'.
instance MonadWriter w m => MonadWriter w (StateT s m) where
  writer w = lift (writer w)
  tell w = lift (tell w)
  listen = liftListen listen
  pass = liftPass pass

This resulted in a big performance gain of roughly 10s:

  • Replaced ListT with LogicT (which is written in CPS style):
-- | A population represented as a 'Weighted' computation over @LogicT m@.
-- 'LogicT' takes the place of the previous 'ListT'-based representation.
-- Presumably each LogicT branch is one weighted particle — confirm against
-- Bayes.Population. All listed instances are derived from the wrapped type.
newtype Population m a = Population (Weighted (LogicT m) a)
  deriving (Functor, Applicative, Monad, MonadIO, MonadSample, MonadCond, MonadInfer)

-- | Run a base-monad action producing a list, then branch nondeterministically
-- over its elements, in list order, as alternatives in 'LogicT'. An empty
-- list yields 'empty'.
-- (The spelling of the name is kept as-is for existing callers.)
nondetermistically :: Monad m => m [a] -> LogicT m a
nondetermistically mx = do
  xs <- lift mx
  foldr (\x rest -> pure x <|> rest) empty xs

This resulted in a further performance gain of roughly 3.6s:

  • Implemented CPS versions of the Trace data type from Bayes.Traced.Common, trying both a pure variant (TraceCPS) and a monad-transformer variant (TraceTCPS):
-- | Trace (original data type)

data Trace a = Trace
  { variables :: [Double]
    -- ^ recorded values, concatenated in execution order when traces are
    -- sequenced (presumably the random draws — confirm)
  , output    :: a
    -- ^ the result of the computation
  , density   :: Log Double
    -- ^ density, multiplied across sequenced traces
  }

-- | Map over the output, keeping the recorded variables and density.
instance Functor Trace where
  fmap f (Trace vs x d) = Trace vs (f x) d

-- | 'pure' records no variables and has density 1. '<*>' concatenates the
-- variable lists and multiplies the densities; both arguments are matched
-- lazily.
instance Applicative Trace where
  pure x = Trace [] x 1
  ~(Trace vf f df) <*> ~(Trace vx x dx) = Trace (vf ++ vx) (f x) (df * dx)

-- | Bind runs the continuation on the output, then prepends the earlier
-- variables and multiplies in the earlier density.
instance Monad Trace where
  prev >>= k =
    let next = k (output prev)
     in next { variables = variables prev ++ variables next
             , density   = density prev * density next
             }

-- | TraceCPS

newtype TraceCPS a = TraceCPS
  { unTraceCPS :: forall r. ([Double] -> a -> Log Double -> r) -> r }
  deriving Functor

-- | Materialise a 'TraceCPS' by using the 'Trace' constructor itself as the
-- final continuation.
runTraceCPS :: TraceCPS a -> Trace a
runTraceCPS m = unTraceCPS m Trace

-- | Variables are concatenated and densities multiplied, with the left
-- computation's contributions first.
instance Applicative TraceCPS where
  pure x = TraceCPS (\k -> k [] x 1)
  mf <*> mx = TraceCPS (\k ->
    unTraceCPS mf (\vf f df ->
      unTraceCPS mx (\vx x dx -> k (vf ++ vx) (f x) (df * dx))))

instance Monad TraceCPS where
  return = pure
  m >>= f = TraceCPS (\k ->
    unTraceCPS m (\v x d ->
      unTraceCPS (f x) (\v' y d' -> k (v ++ v') y (d * d'))))

-- | TraceTCPS: the monad-transformer counterpart of 'TraceCPS'. The
-- continuation receives the accumulated variable list, the output and the
-- accumulated density, and returns a computation in the base monad @m@.
newtype TraceTCPS m a =
  TraceTCPS { unTraceTCPS :: forall r. ([Double] -> a -> Log Double -> m r) -> m r}
  deriving Functor

-- | Apply a transformation of the base monad to the whole traced computation.
--
-- The computation is first run to completion in @m@, collecting the
-- variables, output and density of each result. @f@ is then applied once to
-- that entire computation, and only afterwards is the continuation resumed
-- with each result. For a natural @f@, @runTraceTCPS (hoist f m)@ should
-- therefore agree with @f (runTraceTCPS m)@.
--
-- NOTE(review): the previous definition applied @f@ separately to every call
-- of the continuation (@f (k ds x d)@). That transforms the remainder of the
-- program after each individual result, rather than the trace distribution
-- as a whole. It also shadowed Prelude's 'log'. The 'Monad' constraint is
-- needed to bind the intermediate results; all callers use a monad here.
hoist :: Monad m => (forall x. m x -> m x) -> TraceTCPS m a -> TraceTCPS m a
hoist f (TraceTCPS mx) = TraceTCPS (\k ->
  f (mx (\vs x d -> return (vs, x, d))) >>= \(vs, x, d) -> k vs x d)

-- | Run a 'TraceTCPS' computation to an @m (Trace a)@. It finishes with a
-- continuation that packs the accumulated variables, output and density into
-- a 'Trace'.
runTraceTCPS :: Monad m => TraceTCPS m a -> m (Trace a)
runTraceTCPS m = unTraceTCPS m (\vs x d -> return (Trace vs x d))

-- | Variables are concatenated and densities multiplied, with the left
-- computation's contributions first. No constraint on @m@ is needed.
instance Applicative (TraceTCPS m) where
  pure x = TraceTCPS (\k -> k [] x 1)
  mf <*> mx = TraceTCPS (\k ->
    unTraceTCPS mf (\vf f df ->
      unTraceTCPS mx (\vx x dx -> k (vf ++ vx) (f x) (df * dx))))

instance Monad (TraceTCPS m) where
  return = pure
  m >>= f = TraceTCPS (\k ->
    unTraceTCPS m (\v x d ->
      unTraceTCPS (f x) (\v' y d' -> k (v ++ v') y (d * d'))))

-- | Run a base-monad action and resume with its result. Nothing is recorded:
-- the variable list is empty and the density is 1.
instance MonadTrans TraceTCPS where
  {-# INLINE lift #-}
  lift action = TraceTCPS (\k -> do
    x <- action
    k [] x 1)

-- | Sampling primitives are performed in the base monad via 'lift'.
instance MonadSample m => MonadSample (TraceTCPS m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

-- | Conditioning is delegated to the base monad via 'lift'.
instance MonadCond m => MonadCond (TraceTCPS m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (TraceTCPS m)

This led to the following changes to the Traced data types in Bayes.Traced.Basic and Bayes.Traced.Static (related functions were also modified, but those changes are omitted here):

-- | Bayes.Traced.Basic.Traced (original data type)
data Traced m a = Traced
  { model     :: Weighted (FreeSampler Identity) a
    -- ^ the model, as a weighted computation over @FreeSampler Identity@
  , traceDist :: m (Trace a)
    -- ^ presumably the current distribution over execution traces — confirm
  }

-- | Bayes.Traced.Basic.Traced (CPS'd version)
data Traced m a = Traced
  { model     :: Weighted (FreeSampler Identity) a
    -- ^ the model, as a weighted computation over @FreeSampler Identity@
  , traceDist :: TraceTCPS m a
    -- ^ the trace component, now in CPS form; presumably a distribution
    -- over execution traces — confirm
  }

-- | Bayes.Traced.Static.Traced (Original data type)
data Traced m a = Traced
  { model     :: Weighted (FreeSampler m) a
    -- ^ the model, as a weighted computation over @FreeSampler m@
  , traceDist :: m (Trace a)
    -- ^ presumably the current distribution over execution traces — confirm
  }

-- | Bayes.Traced.Static.Traced (CPS'd version)
data Traced m a = Traced
  { model     :: Weighted (FreeSampler m) a
    -- ^ the model, as a weighted computation over @FreeSampler m@
  , traceDist :: TraceTCPS m a
    -- ^ the trace component, now in CPS form; presumably a distribution
    -- over execution traces — confirm
  }

Overall, these changes to Trace did not result in any runtime improvement.

  • CPS’d the Monad.Coroutine library. Re-running against MonadLibCPS, the current fastest branch (now 5.26s, previously recorded at 4.82s), the CPS’d coroutine version was slightly faster (5.17s). However, the difference is too small to say which one is better. Interestingly, the profiling report changes completely.
-- | Original version

newtype Coroutine s m a = Coroutine
  { resume :: m (Either (s (Coroutine s m a)) a)
    -- ^ Run to the next step: 'Left' holds the suspension functor wrapping
    -- the rest of the coroutine, 'Right' the final result.
  }

-- | Map over the final result, including inside any suspension.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f t = Coroutine (fmap (either (Left . fmap (fmap f)) (Right . f)) (resume t))

instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure x = return x
  mf <*> mx = ap mf mx

-- | Sequencing. A finished step continues with the next coroutine; a
-- suspended step is re-suspended with the bind pushed inside the functor.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  return x = Coroutine (return (Right x))
  t >>= f = Coroutine (resume t >>= either (return . Left . fmap (>>= f)) (resume . f))
  t >> u = Coroutine (resume t >>= either (return . Left . fmap (>> u)) (const (resume u)))

-- | Change the base monad with a natural transformation, applied to each
-- step of the coroutine (including the steps behind every suspension).
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort = Coroutine (liftM (either (Left . fmap (mapMonad f)) Right) (f (resume cort)))

-- | Run one step of @c@ inside the enclosing coroutine. If it suspends, hand
-- the suspension to @spring@; otherwise return its result.
bounce :: (Monad m, Functor s) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> Coroutine s m x
bounce spring c = do
  step <- lift (resume c)
  either spring return step

-- | Drive a coroutine to completion in the base monad, resolving every
-- suspension with @spring@.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = resume c >>= \step -> case step of
      Left s  -> go (spring s)
      Right x -> return x

-- | Suspend until a value is supplied, then return that value.
await :: Monad m => Coroutine (Await x) m x
await = suspend (Await (\x -> return x))

-- | Suspend immediately with the given functor value.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left

-- | Lift a 'PairBinder' on the base monad to one on 'Coroutine'. Both
-- coroutines are stepped and @binder@ combines the two outcomes: if both
-- finished, @f@ continues; otherwise the result is re-suspended.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 = Coroutine (binder combine (resume t1) (resume t2))
  where
    combine e1 e2 = case (e1, e2) of
      (Right x, Right y) -> resume (f x y)
      (Left s, Right y)  -> return (Left (fmap (flip f y =<<) s))
      (Right x, Left s)  -> return (Left (fmap (f x =<<) s))
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))

-- | CPS'd version

newtype Coroutine s m a = Coroutine
  { resume :: forall r. (s (Coroutine s m a) -> m r) -> (a -> m r) -> m r
    -- ^ Resume with two continuations: the first is called on suspension,
    -- the second with the final result.
  }

-- | Reify the next step as an 'Either': 'Left' for a suspension, 'Right'
-- for a final result.
runCoroutine' :: Monad m => Coroutine s m a -> m (Either (s (Coroutine s m a)) a)
runCoroutine' c = resume c (return . Left) (return . Right)

-- | Map over the final result, including inside any suspension.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f t = Coroutine (\onSuspend onDone ->
    resume t (onSuspend . fmap (fmap f)) (onDone . f))

instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure x = return x
  mf <*> mx = ap mf mx

-- | Sequencing purely through continuations. A suspension is re-suspended
-- with the bind pushed inside the functor; a result feeds the next coroutine.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  return x = Coroutine (\_ onDone -> onDone x)
  {-# INLINABLE (>>=) #-}
  t >>= f = Coroutine (\onSuspend onDone ->
    resume t (onSuspend . fmap (>>= f)) (\x -> resume (f x) onSuspend onDone))
  {-# INLINABLE (>>) #-}
  t >> u = Coroutine (\onSuspend onDone ->
    resume t (onSuspend . fmap (>> u)) (\_ -> resume u onSuspend onDone))

-- | Change the base monad with a natural transformation. Each step is
-- reified with 'runCoroutine'', transformed with @f@, and then dispatched to
-- the matching continuation.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort =
  Coroutine (\onSuspend onDone ->
    f (runCoroutine' cort) >>= either (onSuspend . fmap (mapMonad f)) onDone)

-- | Run one step of @c@ inside the enclosing coroutine. If it suspends, hand
-- the suspension to @spring@; otherwise return its result.
bounce :: (Monad m, Functor s) =>
          (s (Coroutine s m a) -> Coroutine s m a) -> Coroutine s m a -> Coroutine s m a
bounce spring c = do
  step <- lift (runCoroutine' c)
  either spring return step

-- | Drive a coroutine to completion in the base monad, resolving every
-- suspension with @spring@.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = runCoroutine' c >>= \step -> case step of
      Left s  -> go (spring s)
      Right x -> return x

-- | Suspend until a value is supplied, then return that value.
await :: Monad m => Coroutine (Await x) m x
await = suspend (Await (\x -> return x))

-- | Suspend immediately by calling the suspension continuation.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend s = Coroutine (\onSuspend _ -> onSuspend s)

-- | Lift a 'PairBinder' on the base monad to one on 'Coroutine'. Both
-- coroutines are stepped with 'runCoroutine'', @binder@ combines the two
-- outcomes, and the combined outcome is dispatched to the continuations.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 =
  Coroutine (\onSuspend onDone ->
    binder combine (runCoroutine' t1) (runCoroutine' t2) >>= either onSuspend onDone)
  where
    combine e1 e2 = case (e1, e2) of
      (Right x, Right y) -> runCoroutine' (f x y)
      (Left s, Right y)  -> return (Left (fmap (flip f y =<<) s))
      (Right x, Left s)  -> return (Left (fmap (f x =<<) s))
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))

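  • Improved the CPS’d definitions in the Monad.Coroutine library so that they rely less on the runCoroutine' function. Note that the new liftBinder no longer takes a PairBinder m argument: it always resumes the first coroutine before the second.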
-- Previous CPS'd version
-- | Run one step of @c@ inside the enclosing coroutine (via 'runCoroutine'').
-- If it suspends, hand the suspension to @spring@; otherwise return the result.
bounce :: (Monad m, Functor s) =>
          (s (Coroutine s m a) -> Coroutine s m a) -> Coroutine s m a -> Coroutine s m a
bounce spring c = do
  step <- lift (runCoroutine' c)
  either spring return step

-- | Drive a coroutine to completion in the base monad, resolving every
-- suspension with @spring@.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = runCoroutine' c >>= \step -> case step of
      Left s  -> go (spring s)
      Right x -> return x

-- | Lift a 'PairBinder' on the base monad to one on 'Coroutine'. Both
-- coroutines are reified with 'runCoroutine'' and combined by @binder@.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 =
  Coroutine (\onSuspend onDone ->
    binder combine (runCoroutine' t1) (runCoroutine' t2) >>= either onSuspend onDone)
  where
    combine e1 e2 = case (e1, e2) of
      (Right x, Right y) -> runCoroutine' (f x y)
      (Left s, Right y)  -> return (Left (fmap (flip f y =<<) s))
      (Right x, Left s)  -> return (Left (fmap (f x =<<) s))
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))

-- CPS'd version
-- | Step @c@. On suspension, resume @spring s@ with the same continuations;
-- otherwise pass the result straight on. This works directly on the CPS
-- representation, without 'runCoroutine''.
bounce :: (Monad m, Functor s) =>
          (s (Coroutine s m a) -> Coroutine s m a) -> Coroutine s m a -> Coroutine s m a
bounce spring c =
  Coroutine (\onSuspend onDone ->
    let viaSpring s = resume (spring s) onSuspend onDone
     in resume c viaSpring onDone)

-- | Run to completion in the base monad, resolving each suspension with
-- @spring@ and returning the final result.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where go c = resume c (\s -> go (spring s)) return

-- | Combine two coroutines with a binary continuation, entirely in CPS.
--
-- Unlike the earlier version, this takes no @PairBinder m@ argument. @t1@ is
-- always resumed first, and @t2@ is resumed from inside @t1@'s continuations.
-- If both finish, @f x y@ continues; if either suspends, the combination is
-- re-suspended, pushing the rest of the work inside the suspension functor.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder (Coroutine s m)
liftBinder f t1 t2 =
  Coroutine (\onSuspend onDone ->
    let -- t1 suspended with s1: step t2 as well
        afterSuspended s1 =
          resume t2 (\s2 -> onSuspend (fmap (liftBinder f (suspend s1)) s2))
                    (\y  -> onSuspend (fmap (flip f y =<<) s1))
        -- t1 finished with x: step t2
        afterDone x =
          resume t2 (\s2 -> onSuspend (fmap (f x =<<) s2))
                    (\y  -> resume (f x y) onSuspend onDone)
     in resume t1 afterSuspended afterDone)

This gives the following benchmarks and profiling report:

  • Added an INLINABLE pragma to proper in Bayes.Population.

  • Added -O2, -funfolding-use-threshold=16, and -fexcess-precision flags.

  • Started using larger arguments for the benchmark program to increase the base run-time, so that optimisations that improve run-time are more noticeable. Previously, a data file of 500 data points was used, with 100 Metropolis–Hastings steps, 100 particle-filter time steps and 100 particles, i.e. (n_mhsteps, n_timesteps, n_particles) = (100, 100, 100). Now we benchmark with (100, 100, 300) and (300, 100, 100).

-- | Run inference with the given (MH steps, time steps, particles)
-- configuration and print the resulting particle weightings.
runPmmh :: (Int, Int, Int) -> IO ()
runPmmh (mhSteps, timeSteps, particles) =
  inferModel mhSteps timeSteps particles >>= print

-- | Benchmark entry point. Benchmarks 'runPmmh' under two
-- @(n_mhsteps, n_timesteps, n_particles)@ configurations: one with more
-- particles and one with more Metropolis–Hastings steps.
main :: IO ()
main = defaultMain
  [ bgroup "runPmmh"
      [ pmmhBench "(100, 100, 300)" (100, 100, 300)
      , pmmhBench "(300, 100, 100)" (300, 100, 100)
      ]
  ]
  where
    -- Labels are spelled out explicitly so they match earlier benchmark runs.
    pmmhBench label cfg = bench label (whnfIO (runPmmh cfg))

(100, 100, 300)

(300, 100, 100)

  • Added the -optc-O3 and -optc-ffast-math flags (which GHC passes through to the C compiler).

(100, 100, 300)

(300, 100, 100)

Last updated on 13 Nov 2020
Published on 13 Nov 2020