CPSing Monad Bayes
- Implemented and integrated a CPS version of the `StateT` monad into monad-bayes:
-- | The plain (non-transformer) state monad: 'StateT' over 'Identity'.
type State s = StateT s Identity

-- | Lift a pure state transition into 'StateT'; its result and new state are
-- handed straight to the continuation (the pair is matched lazily).
{-# INLINE state #-}
state :: Monad m => (s -> (a, s)) -> StateT s m a
state f = StateT $ \s k -> let (a, s') = f s in k a s'

-- | Run a 'State' computation, returning its result and final state.
{-# INLINE runState #-}
runState :: State s a -> s -> (a, s)
runState m s = runIdentity (runStateT m s)

-- | Run a 'State' computation, keeping only its result.
{-# INLINE evalState #-}
evalState :: State s a -> s -> a
evalState m s = runIdentity (evalStateT m s)

-- | Run a 'State' computation, keeping only its final state.
{-# INLINE execState #-}
execState :: State s a -> s -> s
execState m s = runIdentity (execStateT m s)
-- | State transformer in continuation-passing style. Instead of producing an
-- @m (a, s)@, a computation receives the current state and a continuation
-- that consumes the result and the next state; the answer type @r@ is
-- universally quantified.
newtype StateT s m a = StateT
  { unStateT :: forall r. s -> (a -> s -> m r) -> m r
  }
  deriving (Functor)
-- | Run from an initial state, returning both the result and the final state.
{-# INLINE runStateT #-}
runStateT :: Monad m => StateT s m a -> s -> m (a, s)
runStateT m s0 = unStateT m s0 (\a s -> return (a, s))

-- | Run from an initial state, discarding the final state.
{-# INLINE evalStateT #-}
evalStateT :: Monad m => StateT s m a -> s -> m a
evalStateT m s0 = unStateT m s0 (\a _ -> return a)

-- | Run from an initial state, discarding the result.
{-# INLINE execStateT #-}
execStateT :: Monad m => StateT s m a -> s -> m s
execStateT m s0 = unStateT m s0 (\_ s -> return s)
-- | Transform the underlying computation. The computation is first run from
-- the current state as an @m (a, s)@ and passed through @f@; the resulting
-- pair is then handed to the continuation.
{-# INLINE mapStateT #-}
mapStateT :: (Monad m, Monad n) => (m (a, s) -> n (b, s)) -> StateT s m a -> StateT s n b
mapStateT f m = StateT $ \s k ->
  f (runStateT m s) >>= \(b, s') -> k b s'
-- | Apply @f@ to the state, producing @()@.
{-# INLINE modify #-}
modify :: (Monad m) => (s -> s) -> StateT s m ()
modify f = StateT $ \s k ->
  let s' = f s
  in k () s'
-- | Lift a @listen@ operation through 'StateT'. The inner computation is run
-- to an @m (a, s)@ so @listen@ can observe its output. The result is matched
-- lazily before the continuation is resumed with the new state.
{-# INLINE liftListen #-}
liftListen :: (Monad m) => Listen w m (a,s) -> Listen w (StateT s m) a
liftListen listen m = StateT $ \s k ->
  listen (runStateT m s) >>= \ ~((a, s'), w) -> k (a, w) s'
-- | Lift a @pass@ operation through 'StateT'. The computation is run with a
-- continuation that packages result, state and output-modifier as
-- @((a, s), f)@ for @pass@. The pair it returns is then fed to the outer
-- continuation.
{-# INLINE liftPass #-}
liftPass :: (Monad m) => Pass w m (a,s) -> Pass w (StateT s m) a
liftPass pass (StateT mx) = StateT $ \s k ->
  let collect (x, f) s' = return ((x, s'), f)
  in pass (mx s collect) >>= uncurry k
-- | Sequencing threads the state left to right through continuations; no
-- constraint on @m@ is needed.
instance Applicative (StateT s m) where
  {-# INLINE pure #-}
  pure a = StateT $ \s k -> k a s
  {-# INLINE liftA2 #-}
  liftA2 f ma mb = StateT $ \s0 k ->
    unStateT ma s0 $ \a s1 ->
      unStateT mb s1 $ \b s2 ->
        k (f a b) s2
-- | Bind resumes the second computation with the first one's result and
-- state, reusing the same final continuation.
instance Monad (StateT s m) where
  {-# INLINE return #-}
  return = pure
  {-# INLINE (>>=) #-}
  act >>= f = StateT $ \s k ->
    unStateT act s $ \a s' -> unStateT (f a) s' k
-- | Value recursion: ties the knot in @m@ on the (lazily matched) result of
-- running @f x@ from the current state, then resumes the continuation.
instance MonadFix m => MonadFix (StateT s m) where
  {-# INLINE mfix #-}
  mfix f = StateT $ \s k ->
    let tie ~(x, _) = runStateT (f x) s
    in mfix tie >>= uncurry k
-- | Lifting runs the inner action and resumes with the state unchanged.
instance MonadTrans (StateT s) where
  {-# INLINE lift #-}
  lift m = StateT $ \s k -> m >>= flip k s
-- | IO is lifted through the inner monad.
instance MonadIO m => MonadIO (StateT s m) where
  liftIO io = lift (liftIO io)

-- | Failure discards both the state and the continuation.
instance MonadFail m => MonadFail (StateT s m) where
  fail msg = StateT $ \_ _ -> Fail.fail msg

-- | Both alternatives start from the same state and share the continuation.
instance Alternative m => Alternative (StateT s m) where
  empty = StateT $ \_ _ -> empty
  left <|> right = StateT $ \s k -> unStateT left s k <|> unStateT right s k

-- | 'get' passes the state as the result; 'put' replaces it.
instance MonadState s (StateT s m) where
  get = StateT $ \s k -> k s s
  put s' = StateT $ \_ k -> k () s'

-- | Writer operations are lifted; 'listen' and 'pass' go through
-- 'liftListen' and 'liftPass'.
instance MonadWriter w m => MonadWriter w (StateT s m) where
  writer wa = lift (writer wa)
  tell w = lift (tell w)
  listen = liftListen listen
  pass = liftPass pass
This resulted in a large performance gain of roughly 10s.
- Swapped out `ListT` for `LogicT` (which is in CPS style):
-- | 'Population': a 'Weighted' computation over @'LogicT' m@. 'LogicT'
-- replaced 'ListT' here; the listed instances are derived through the newtype.
newtype Population m a = Population (Weighted (LogicT m) a)
  deriving (Functor, Applicative, Monad, MonadIO, MonadSample, MonadCond, MonadInfer)
-- | Run @mx@ in the base monad and offer each element of its list as a
-- separate 'LogicT' branch, in list order (an empty list gives 'empty').
nondetermistically :: Monad m => m [a] -> LogicT m a
nondetermistically mx = do
  xs <- lift mx
  foldr ((<|>) . pure) empty xs
This resulted in a further performance gain of roughly 3.6s.
- Implemented a CPS version of the `Trace` data type of `Bayes.Traced.Common`, trying both `TraceCPS` and `TraceTCPS`.
-- | Trace (original data type): the sampled variables of a run, its output,
-- and its density (combined multiplicatively when traces are sequenced).
data Trace a = Trace
  { variables :: [Double]
  -- ^ sampled values; earlier computations' values come first when sequenced
  , output :: a
  -- ^ the value the program returned
  , density :: Log Double
  -- ^ product of the densities of the sequenced parts
  }
-- | Map over the output, leaving variables and density untouched.
instance Functor Trace where
  fmap f (Trace vs x d) = Trace vs (f x) d
-- | 'pure' has no variables and density 1. '<*>' concatenates variables and
-- multiplies densities. The lazy patterns keep '<*>' as lazy in its
-- arguments as the field-selector version.
instance Applicative Trace where
  pure a = Trace [] a 1
  (<*>) ~(Trace vf f df) ~(Trace vx x dx) = Trace (vf ++ vx) (f x) (df * dx)
-- | Bind feeds the output into @f@ and combines the two traces: variables are
-- concatenated (first trace first) and densities multiplied.
instance Monad Trace where
  t >>= f =
    next
      { variables = variables t ++ variables next
      , density = density t * density next
      }
    where
      next = f (output t)
-- | TraceCPS: 'Trace' in continuation-passing style. The continuation receives
-- the variables, the output and the density.
newtype TraceCPS a = TraceCPS
  { unTraceCPS :: forall r. ([Double] -> a -> Log Double -> r) -> r
  }
  deriving (Functor)
-- | Materialise a 'TraceCPS' by passing the 'Trace' constructor as continuation.
runTraceCPS :: TraceCPS a -> Trace a
runTraceCPS m = unTraceCPS m (\vs x d -> Trace vs x d)
-- | Mirrors the 'Trace' instance: empty variables and density 1 for 'pure';
-- concatenation and multiplication when combining.
instance Applicative TraceCPS where
  pure a = TraceCPS $ \k -> k [] a 1
  mf <*> ma = TraceCPS $ \k ->
    unTraceCPS mf $ \vs f d ->
      unTraceCPS ma $ \vs' a d' ->
        k (vs ++ vs') (f a) (d * d')
-- | Bind runs @f@ on the output and merges both traces before the final
-- continuation.
instance Monad TraceCPS where
  return = pure
  ma >>= f = TraceCPS $ \k ->
    unTraceCPS ma $ \vs a d ->
      unTraceCPS (f a) $ \vs' b d' ->
        k (vs ++ vs') b (d * d')
-- | Transformer version of 'TraceCPS': the continuation receives the
-- variables, the output and the density, and returns an @m r@.
newtype TraceTCPS m a = TraceTCPS
  { unTraceTCPS :: forall r. ([Double] -> a -> Log Double -> m r) -> m r
  }
  deriving (Functor)
-- TraceTCPS

-- | Apply @f@ to the underlying computation of a 'TraceTCPS'.
--
-- The argument is first run to an @m ('Trace' a)@ with 'runTraceTCPS'. Then
-- @f@ is applied to that whole computation, and the resulting trace is handed
-- to the continuation. Hence @runTraceTCPS (hoist f t) = f (runTraceTCPS t)@,
-- the same as applying @f@ to the non-CPS @m (Trace a)@ representation.
--
-- Applying @f@ only to the continuation's result, as an earlier version did,
-- leaves effects performed by @t@ itself untouched. For example, the inner
-- action of 'lift' would never pass through @f@.
--
-- NOTE(review): this materialises the trace at each 'hoist', and it adds a
-- @Monad m@ constraint.
hoist :: Monad m => (forall x. m x -> m x) -> TraceTCPS m a -> TraceTCPS m a
hoist f t =
  TraceTCPS $ \k ->
    f (runTraceTCPS t) >>= \(Trace vs x d) -> k vs x d
-- | Materialise a 'TraceTCPS' into an @m@-computation returning a 'Trace'.
runTraceTCPS :: Monad m => TraceTCPS m a -> m (Trace a)
runTraceTCPS t = unTraceTCPS t (\vs x -> return . Trace vs x)
-- | Same combination rules as for 'Trace'; no constraint on @m@ is needed.
instance Applicative (TraceTCPS m) where
  pure a = TraceTCPS $ \k -> k [] a 1
  mf <*> ma = TraceTCPS $ \k ->
    unTraceTCPS mf $ \vs f d ->
      unTraceTCPS ma $ \vs' a d' ->
        k (vs ++ vs') (f a) (d * d')
-- | Bind merges the traces of both steps before the final continuation.
instance Monad (TraceTCPS m) where
  return = pure
  ma >>= f = TraceTCPS $ \k ->
    unTraceTCPS ma $ \vs a d ->
      unTraceTCPS (f a) $ \vs' b d' ->
        k (vs ++ vs') b (d * d')
-- | A lifted action contributes no variables and density 1.
instance MonadTrans TraceTCPS where
  {-# INLINE lift #-}
  lift m = TraceTCPS $ \k -> m >>= \a -> k [] a 1
-- | Sampling and scoring are delegated to the inner monad via 'lift'.
instance MonadSample m => MonadSample (TraceTCPS m) where
  random = lift random
  bernoulli p = lift (bernoulli p)
  categorical ps = lift (categorical ps)

instance MonadCond m => MonadCond (TraceTCPS m) where
  score w = lift (score w)

instance MonadInfer m => MonadInfer (TraceTCPS m)
This led to the following changes to the `Traced` data types of `Bayes.Traced.Basic` and `Bayes.Traced.Static` (as well as other, omitted, modifications of related functions):
-- | Bayes.Traced.Basic.Traced (original data type)
data Traced m a = Traced
  { model :: Weighted (FreeSampler Identity) a
  -- ^ the program as a 'Weighted' 'FreeSampler' over 'Identity'
  , traceDist :: m (Trace a)
  -- ^ an @m@-computation producing a 'Trace' of the program
  }
-- | Bayes.Traced.Basic.Traced (CPS'd version)
data Traced m a = Traced
  { model :: Weighted (FreeSampler Identity) a
  -- ^ the program as a 'Weighted' 'FreeSampler' over 'Identity'
  , traceDist :: TraceTCPS m a
  -- ^ the trace computation, now in continuation-passing style
  }
-- | Bayes.Traced.Static.Traced (Original data type)
data Traced m a = Traced
  { model :: Weighted (FreeSampler m) a
  -- ^ the program as a 'Weighted' 'FreeSampler' over @m@
  , traceDist :: m (Trace a)
  -- ^ an @m@-computation producing a 'Trace' of the program
  }
-- | Bayes.Traced.Static.Traced (CPS'd version)
data Traced m a = Traced
  { model :: Weighted (FreeSampler m) a
  -- ^ the program as a 'Weighted' 'FreeSampler' over @m@
  , traceDist :: TraceTCPS m a
  -- ^ the trace computation, now in continuation-passing style
  }
Overall, these changes to `Trace` did not result in any runtime improvement.
- CPS’d the `Monad.Coroutine` library. When rerunning against the fastest current branch, `MonadLibCPS` (now yielding 5.26s, previously recorded at 4.82s), the CPS’d coroutine version actually resulted in a slightly better run-time (5.17s), but the difference is so modest that it isn’t clear which one is better. Interestingly, the profiling report changes completely.
-- | Original version: 'resume' runs one step in @m@ and yields either a
-- suspension (a functor @s@ holding the rest of the coroutine) or the result.
newtype Coroutine s m a = Coroutine
  { resume :: m (Either (s (Coroutine s m a)) a)
  }
-- | Map over the final result, pushing the mapping into any suspension.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f t = Coroutine (fmap (either (Left . fmap (fmap f)) (Right . f)) (resume t))
-- | Applicative operations defined through the 'Monad' instance.
instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure = return
  mf <*> mx = do
    g <- mf
    x <- mx
    return (g x)
-- | Bind runs one step of the first coroutine. On completion it continues
-- with the second; on suspension it re-attaches the rest of the bind inside
-- the suspension functor.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  return a = Coroutine (return (Right a))
  c >>= k =
    Coroutine $ resume c >>= either (\s -> return (Left (fmap (>>= k) s))) (resume . k)
  c >> n =
    Coroutine $ resume c >>= either (\s -> return (Left (fmap (>> n) s))) (const (resume n))
-- | Change the underlying monad with @f@, applying it to every step,
-- including those inside suspensions.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort = Coroutine (liftM step (f (resume cort)))
  where
    step r = either (Left . fmap (mapMonad f)) Right r
-- | Run one step; if it suspends, continue with @spring@ applied to the
-- suspension, otherwise return the result.
bounce :: (Monad m, Functor s) => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> Coroutine s m x
bounce spring c = do
  step <- lift (resume c)
  either spring return step

-- | Drive a coroutine to completion, using @spring@ to resume every suspension.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = do
      step <- resume c
      either (go . spring) return step
-- | Suspend until a value is supplied, then return it.
await :: Monad m => Coroutine (Await x) m x
await = suspend (Await pure)

-- | A coroutine that immediately suspends with the given functor.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend = Coroutine . return . Left
-- | Lift a 'PairBinder' of @m@ to coroutines: both coroutines are stepped with
-- @binder@, and any suspension re-wraps the rest of the pairing.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 = Coroutine (binder step (resume t1) (resume t2))
  where
    step r1 r2 = case (r1, r2) of
      (Right x, Right y) -> resume (f x y)
      (Left s, Right y) -> return (Left (fmap (flip f y =<<) s))
      (Right x, Left s) -> return (Left (fmap (f x =<<) s))
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))
-- | CPS'd version: instead of returning an 'Either', 'resume' takes one
-- continuation for a suspension and one for a completed result.
newtype Coroutine s m a = Coroutine
  { resume :: forall r. (s (Coroutine s m a) -> m r) -> (a -> m r) -> m r
  }
-- | Convert back to the 'Either'-returning form of the original version.
runCoroutine' :: Monad m => Coroutine s m a -> m (Either (s (Coroutine s m a)) a)
runCoroutine' c = resume c (return . Left) (return . Right)
-- | Map the result by pre-composing both continuations.
instance (Functor s, Functor m) => Functor (Coroutine s m) where
  fmap f t = Coroutine $ \onSuspend onDone ->
    resume t (onSuspend . fmap (fmap f)) (onDone . f)
-- | Applicative operations defined through the 'Monad' instance.
instance (Functor s, Functor m, Monad m) => Applicative (Coroutine s m) where
  pure = return
  mf <*> mx = do
    g <- mf
    x <- mx
    return (g x)
-- | Bind: on suspension the rest of the bind is re-attached inside the
-- suspension; on completion the next coroutine is resumed with the same
-- continuations.
instance (Functor s, Monad m) => Monad (Coroutine s m) where
  return a = Coroutine (\_ onDone -> onDone a)
  {-# INLINABLE (>>=) #-}
  c >>= k = Coroutine $ \onSuspend onDone ->
    let suspended s = onSuspend (fmap (>>= k) s)
        finished a = resume (k a) onSuspend onDone
    in resume c suspended finished
  {-# INLINABLE (>>) #-}
  c >> n = Coroutine $ \onSuspend onDone ->
    resume c (onSuspend . fmap (>> n)) (\_ -> resume n onSuspend onDone)
-- | Change the underlying monad with @f@. Each step is run to an 'Either' in
-- @m@, transformed, and dispatched to the matching continuation.
mapMonad :: forall s m m' x. (Functor s, Monad m, Monad m') =>
            (forall y. m y -> m' y) -> Coroutine s m x -> Coroutine s m' x
mapMonad f cort = Coroutine $ \onSuspend onDone ->
  f (runCoroutine' cort) >>= either (onSuspend . fmap (mapMonad f)) onDone
-- | Run one step; continue with @spring@ on suspension, else return the result.
bounce :: (Monad m, Functor s) =>
          (s (Coroutine s m a) -> Coroutine s m a) -> Coroutine s m a -> Coroutine s m a
bounce spring c = do
  step <- lift (runCoroutine' c)
  either spring return step
-- | Drive a coroutine to completion, resuming every suspension with @spring@.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = do
      step <- runCoroutine' c
      either (go . spring) return step
-- | Suspend until a value is supplied, then return it.
await :: Monad m => Coroutine (Await x) m x
await = suspend (Await pure)

-- | A coroutine that immediately suspends with the given functor.
suspend :: (Monad m, Functor s) => s (Coroutine s m x) -> Coroutine s m x
suspend s = Coroutine (\onSuspend _ -> onSuspend s)
-- | Lifting a 'PairBinder' onto a 'Coroutine' monad transformer. Both
-- coroutines are converted with 'runCoroutine'' and stepped with @binder@.
-- The combined 'Either' is then dispatched to the continuations.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 = Coroutine $ \onSuspend onDone ->
  binder combine (runCoroutine' t1) (runCoroutine' t2) >>= either onSuspend onDone
  where
    combine r1 r2 = case (r1, r2) of
      (Right x, Right y) -> runCoroutine' (f x y)
      (Left s, Right y) -> return (Left (fmap (flip f y =<<) s))
      (Right x, Left s) -> return (Left (fmap (f x =<<) s))
      (Left s1, Left s2) -> return (Left (fmap (liftBinder binder f (suspend s1)) s2))
- Improved the CPS’d definitions in the `Monad.Coroutine` library so that they rely less on the `runCoroutine'` function.
-- Previous CPS'd version

-- | Run one step via 'runCoroutine''; continue with @spring@ on suspension.
bounce :: (Monad m, Functor s) =>
          (s (Coroutine s m a) -> Coroutine s m a) -> Coroutine s m a -> Coroutine s m a
bounce spring c =
  lift (runCoroutine' c) >>= \step -> case step of
    Left s -> spring s
    Right a -> return a
-- | Drive a coroutine to completion via 'runCoroutine'', resuming
-- suspensions with @spring@.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = loop
  where
    loop c = runCoroutine' c >>= \step -> case step of
      Left s -> loop (spring s)
      Right x -> return x
-- | Lift a 'PairBinder' of @m@ to coroutines by stepping both with @binder@
-- (via 'runCoroutine'') and dispatching the combined 'Either'.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder m -> PairBinder (Coroutine s m)
liftBinder binder f t1 t2 =
  Coroutine (\l r -> binder step (runCoroutine' t1) (runCoroutine' t2) >>= either l r)
  where
    step (Right a) (Right b) = runCoroutine' (f a b)
    step (Left sa) (Right b) = return . Left $ fmap (flip f b =<<) sa
    step (Right a) (Left sb) = return . Left $ fmap (f a =<<) sb
    step (Left sa) (Left sb) = return . Left $ fmap (liftBinder binder f (suspend sa)) sb
-- CPS'd version

-- | Run one step directly in CPS: on suspension resume @spring s@ with the
-- same continuations, on completion pass the result to @onDone@.
bounce :: (Monad m, Functor s) =>
          (s (Coroutine s m a) -> Coroutine s m a) -> Coroutine s m a -> Coroutine s m a
bounce spring c = Coroutine $ \onSuspend onDone ->
  let bounced s = resume (spring s) onSuspend onDone
  in resume c bounced onDone
-- | Drive a coroutine to completion in CPS, resuming suspensions with @spring@.
pogoStick :: Monad m => (s (Coroutine s m x) -> Coroutine s m x) -> Coroutine s m x -> m x
pogoStick spring = go
  where
    go c = resume c (\s -> go (spring s)) return
-- | Pair two coroutines in CPS. @t1@ is stepped first, then @t2@. If either
-- suspends, the rest of the pairing is re-wrapped inside the suspension; if
-- both complete, @f@ is resumed with their results.
liftBinder :: forall s m. (Functor s, Monad m) => PairBinder (Coroutine s m)
liftBinder f t1 t2 = Coroutine $ \onSuspend onDone ->
  let -- t1 suspended with s1: step t2 to decide how to re-wrap
      firstSuspended s1 =
        resume t2
          (\s2 -> onSuspend (fmap (liftBinder f (suspend s1)) s2))
          (\y -> onSuspend (fmap (flip f y =<<) s1))
      -- t1 finished with x: step t2
      firstDone x =
        resume t2
          (\s2 -> onSuspend (fmap (f x =<<) s2))
          (\y -> resume (f x y) onSuspend onDone)
  in resume t1 firstSuspended firstDone
This gives the following benchmarks and profiling report:
- Put an `INLINABLE` pragma on `proper` from `Bayes.Population`.
- Added the `-O2`, `-funfolding-use-threshold=16`, and `-fexcess-precision` flags.
- Started using larger arguments to the benchmark program to increase the base run-time, so that optimisations yielding improved run-times are more noticeable. Each configuration is written as (Metropolis–Hastings steps, particle-filter time steps, particles). Previously, a data file of 500 data points was used, with 100 Metropolis–Hastings steps, 100 particle-filter time steps, and 100 particles, i.e. (100, 100, 100). Now, we benchmark using (100, 100, 300) and (300, 100, 100).
-- | Call 'inferModel' with (MH steps, time steps, particles) and print the
-- resulting particle weightings.
runPmmh :: (Int, Int, Int) -> IO ()
runPmmh (mhSteps, timeSteps, particles) =
  inferModel mhSteps timeSteps particles >>= print
-- | Benchmark entry point: times 'runPmmh' on two configurations, each
-- written as (MH steps, time steps, particles).
main :: IO ()
main =
  defaultMain
    [ bgroup
        "runPmmh"
        [ bench "(100, 100, 300)" $ whnfIO $ runPmmh (100, 100, 300)
        , bench "(300, 100, 100)" $ whnfIO $ runPmmh (300, 100, 100)
        ]
    ]
(100, 100, 300)
(300, 100, 100)
- Added the `-optc-O3` and `-optc-ffast-math` flags.
(100, 100, 300)
(300, 100, 100)