Handrolling Monad Transformer Stacks
Unrolling MTL stacks
Stacked monad transformers do not inline well, so code built on the MTL library often needs an optimisation pass.
Unrolling means flattening a stack of transformers into a single hand-written monad.
For example, consider the following MTL monad stack.
-- RWS monad: A monad containing an environment of type r, output of type w, and updatable state of type s.
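type RWS r w s = RWST r w s Identity
-- (A type synonym cannot carry a deriving clause; RWST already provides the MonadState s, MonadWriter w, and Monad instances.)
-- Note the argument order of RWS: environment (), then the log RNA, then the state DNA.
newtype DRM a = DRM { unDRM :: ErrorT Finish (RWS () RNA DNA) a }
  deriving (Functor, Applicative, Monad, MonadState DNA, MonadWriter RNA, MonadError Finish)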
-- | Inductive case: This tells us that as we know there is a MonadState and MonadWriter instance
-- somewhere in the stack (i.e. RWST) where the state is type DNA and the writer's log is type RNA,
-- then DRM (or equivalently, ErrorT) is also an instance of MonadState and MonadWriter.
-- | Base case: This tells us that as we know there is a MonadError instance somewhere in the stack
-- (i.e. the outermost transformer ErrorT) where the error is type Finish, then trivially DRM
-- (or equivalently ErrorT) is an instance of MonadError.
Unrolling this looks like:
-- | Take all involved monad transformers' associated parameters (the state DNA, the log RNA, and
-- the error Finish) from the monad stack and create a type alias containing these.
type DRM = DRMonad Finish DNA RNA
-- | Redefine monad stack newtype to be parameterised by each of the monads' associated parameters
-- from the stack (e, s, and w).
-- | Running the original monad stack would yield a type "s -> (Either e a, s, w)"
newtype DRMonad e s w a = DRMonad { runDRMonad :: s -> (Either e a, s, w) }
-- | We then create a monad instance for it, where (>>=) is defined to capture all of our
-- monads' effects.
instance (Monoid w, Error e) => Monad (DRMonad e s w) where
  return x = DRMonad (\s -> (Right x, s, mempty))
  (>>=)    = bindDRMonad
{-# INLINE bindDRMonad #-}
-- | This first performs the State monad's bind operation
bindDRMonad :: Monoid w => DRMonad e s w a -> (a -> DRMonad e s w b) -> DRMonad e s w b
bindDRMonad m f = DRMonad $
  -- Perform the State monad's effect
  -- > Run the DRMonad m on state s to get (Either e a, s, w)
  \s -> let (x', s', w) = (runDRMonad m) s
        -- Perform the Either monad's effect
        -- > In the Left case, perform no further effects and return error in the triple
        in case x' of
             Left e  -> (Left e, s', w)
             -- > In the Right case, apply the (>>=)'d function f to a to get (DRMonad e s w b)
             Right a -> let m' = f a
                            -- Perform the Writer monad's effect
                            -- > Run the new DRMonad m' on the new state s' to get (Either e b, s, w)
                            (x'', s'', w') = (runDRMonad m') s'
                            -- > Append the new message onto the writer's log
                        in (x'', s'', w `mappend` w')
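Once the stack is unrolled, the operations that the derived MonadState, MonadWriter and MonadError instances used to supply have to be written by hand as well. The following is a minimal, self-contained sketch of what that might look like. The names getDR, putDR, tellDR, throwDR and example are hypothetical and not part of the original code, the Error e constraint is dropped for brevity, and Functor and Applicative instances are included on the assumption of a recent GHC, where they are superclasses of Monad.

```haskell
-- Unrolled monad: state in, (result-or-error, new state, log) out.
newtype DRMonad e s w a = DRMonad { runDRMonad :: s -> (Either e a, s, w) }

instance Functor (DRMonad e s w) where
  fmap f (DRMonad g) = DRMonad $ \s ->
    let (x, s', w) = g s in (fmap f x, s', w)

instance Monoid w => Applicative (DRMonad e s w) where
  pure x = DRMonad $ \s -> (Right x, s, mempty)
  mf <*> mx = mf >>= \f -> fmap f mx

instance Monoid w => Monad (DRMonad e s w) where
  m >>= f = DRMonad $ \s ->
    case runDRMonad m s of
      (Left e,  s', w) -> (Left e, s', w)          -- stop at the first error
      (Right a, s', w) ->
        let (x, s'', w') = runDRMonad (f a) s'     -- thread the state
        in  (x, s'', w `mappend` w')               -- accumulate the log

-- Hypothetical hand-written primitives (assumed names).
getDR :: Monoid w => DRMonad e s w s
getDR = DRMonad $ \s -> (Right s, s, mempty)

putDR :: Monoid w => s -> DRMonad e s w ()
putDR s = DRMonad $ \_ -> (Right (), s, mempty)

tellDR :: w -> DRMonad e s w ()
tellDR w = DRMonad $ \s -> (Right (), s, w)

throwDR :: Monoid w => e -> DRMonad e s w a
throwDR e = DRMonad $ \s -> (Left e, s, mempty)

-- Small usage example: the final tellDR is never reached.
example :: DRMonad String Int [String] ()
example = do
  n <- getDR
  tellDR ["read " ++ show n]
  putDR (n + 1)
  _ <- throwDR "done"
  tellDR ["unreachable"]
```

Running `runDRMonad example 0` yields the error together with the state and log accumulated before it, mirroring what ErrorT layered over RWS produces.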
Q: Do you have to concretely know all the monads in a monad transformer stack in order to handroll your own? For example, what if we don’t know the monad m in this stack?
newtype Weighted m a = Weighted (StateT (Log Double) m a)
deriving (Functor, Applicative, Monad, MonadIO, MonadTrans, MonadSample)
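A: Nope, you can hand-roll partially. Just put a Monad constraint on the unknown m when it sits at the base of the stack, as it does here, or a MonadTrans constraint (or something similar) on the unknown layer if it’s midway through the stack.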
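As a rough sketch of partial hand-rolling, the StateT layer can be unrolled into a function while the base monad m stays abstract behind a Monad constraint. Everything named here (WeightedU, liftW, scale, example) is hypothetical, and the state type is left as a parameter s standing in for Log Double so that the sketch needs nothing beyond base.

```haskell
-- StateT s m a unrolled by hand; m is still unknown.
newtype WeightedU s m a = WeightedU { runWeightedU :: s -> m (a, s) }

instance Functor m => Functor (WeightedU s m) where
  fmap f (WeightedU g) = WeightedU $ \s -> fmap (\(a, s') -> (f a, s')) (g s)

instance Monad m => Applicative (WeightedU s m) where
  pure a = WeightedU $ \s -> return (a, s)
  WeightedU mf <*> WeightedU ma = WeightedU $ \s -> do
    (f, s')  <- mf s
    (a, s'') <- ma s'
    return (f a, s'')

instance Monad m => Monad (WeightedU s m) where
  WeightedU g >>= f = WeightedU $ \s -> do
    (a, s') <- g s            -- the only use of the unknown monad's bind
    runWeightedU (f a) s'

-- Hand-written counterpart of lift.
liftW :: Monad m => m a -> WeightedU s m a
liftW m = WeightedU $ \s -> do
  a <- m
  return (a, s)

-- Hypothetical primitive: multiply the accumulated weight.
scale :: (Num s, Monad m) => s -> WeightedU s m ()
scale w = WeightedU $ \s -> return ((), s * w)

-- Usage with Maybe chosen as the base monad.
example :: WeightedU Double Maybe Char
example = do
  scale 2
  c <- liftW (Just 'x')
  scale 3
  return c
```

The known layer is flattened, while the effects of m are still reached only through its Monad interface.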
Unrolling Ordinary Monad Transformer Stacks
To wrap a monad transformer m2T around a monad m1 means that we expect there to exist a function runM2T which takes m2T (m1 …) and returns the monad (m1 …) containing the effectful context of monad m2.
data Maybe a = Just a | Nothing
newtype MaybeT m a = MaybeT { runMaybeT :: m (Maybe a) }
data List a = Cons a (List a) | Nil
newtype ListT m a = ListT { runListT :: m [a] }
data Either e a = Left e | Right a
newtype ExceptT e m a = ExceptT { runExceptT :: m (Either e a) }
newtype State s a = State { runState :: s -> (a, s) }
newtype StateT s m a = StateT { runStateT :: s -> m (a, s) }
newtype Writer w a = Writer { runWriter :: (a, w) }
newtype WriterT w m a = WriterT { runWriterT :: m (a, w) }
newtype Reader r a = Reader { runReader :: r -> a }
newtype ReaderT r m a = ReaderT { runReaderT :: r -> m a }
newtype LMS s a = LMS { unLMS :: ListT (MaybeT (State s)) a }
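Reading the run functions from the outside in gives the unrolled representation: for LMS, runListT yields a MaybeT (State s) [a], runMaybeT yields a State s (Maybe [a]), and runState yields s -> (Maybe [a], s). As a smaller worked sketch, here is the two-layer stack MaybeT (State s) unrolled in the same way. The names MS, getMS, putMS, failMS and demo are hypothetical.

```haskell
-- MaybeT (State s) a, unrolled: runMaybeT then runState gives s -> (Maybe a, s).
newtype MS s a = MS { runMS :: s -> (Maybe a, s) }

instance Functor (MS s) where
  fmap f (MS g) = MS $ \s -> let (r, s') = g s in (fmap f r, s')

instance Applicative (MS s) where
  pure a = MS $ \s -> (Just a, s)
  MS mf <*> MS ma = MS $ \s -> case mf s of
    (Nothing, s') -> (Nothing, s')
    (Just f,  s') -> let (r, s'') = ma s' in (fmap f r, s'')

instance Monad (MS s) where
  MS g >>= f = MS $ \s -> case g s of
    (Nothing, s') -> (Nothing, s')     -- failure keeps the state reached so far
    (Just a,  s') -> runMS (f a) s'

-- Hypothetical primitives.
getMS :: MS s s
getMS = MS $ \s -> (Just s, s)

putMS :: s -> MS s ()
putMS s = MS $ \_ -> (Just (), s)

failMS :: MS s a
failMS = MS $ \s -> (Nothing, s)

-- The second putMS is skipped because failMS aborts first.
demo :: MS Int ()
demo = do
  putMS 1
  _ <- failMS
  putMS 2
```

Because State is the inner layer, the state written before the failure survives, just as it does in the stacked version.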
Continuation Passing Style
-- Kind of ContT, as given in the library documentation:
--   ContT :: forall k. k -> (k -> *) -> * -> *
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
type Cont r = ContT r Identity
-- Equivalently, Cont written out directly without the transformer:
newtype Cont r a = Cont { runCont :: (a -> r) -> r }
instance Monad (Cont r) where
  return a = Cont $ \k -> k a
  (Cont c) >>= f = Cont $ \k -> c (\a -> runCont (f a) k)
Every monad can be embedded into the continuation passing monad Cont by setting the “result” type of the Cont monad to be that monad.
type MCPS a = forall r. Cont (M r) a
runMCPS m = runCont m return
This is essentially the same as what ContT achieves if we set m to M.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
An important point is that M’s (>>=) operation isn’t used. In many situations, this lets us avoid a layer of interpretation in much of the code which uses M.
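To make the embedding concrete, here is a small sketch of a generic CPS wrapper around an arbitrary monad m. The names CPS, liftCPS, runCPS, sumCPS and failCPS are hypothetical. Note that none of the three instances carries a Monad m constraint: the underlying bind appears only in liftCPS, at the boundary where an m action enters the CPS world.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Same shape as "forall r. Cont (m r) a".
newtype CPS m a = CPS { unCPS :: forall r. (a -> m r) -> m r }

instance Functor (CPS m) where
  fmap f (CPS c) = CPS (\k -> c (k . f))

instance Applicative (CPS m) where
  pure a = CPS (\k -> k a)
  CPS cf <*> CPS ca = CPS (\k -> cf (\f -> ca (k . f)))

-- Bind is plain continuation plumbing; m's own (>>=) is not mentioned.
instance Monad (CPS m) where
  CPS c >>= f = CPS (\k -> c (\a -> unCPS (f a) k))

-- Entering the CPS world: the one place m's (>>=) is needed.
liftCPS :: Monad m => m a -> CPS m a
liftCPS m = CPS (m >>=)

-- Leaving it: supply return as the final continuation.
runCPS :: Monad m => CPS m a -> m a
runCPS c = unCPS c return

-- Usage with Maybe as the underlying monad.
sumCPS :: CPS Maybe Int
sumCPS = do
  x <- liftCPS (Just 1)
  y <- liftCPS (Just 2)
  return (x + y)

failCPS :: CPS Maybe Int
failCPS = do
  x <- liftCPS (Just 1)
  y <- liftCPS Nothing
  return (x + y)
```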
As an example, consider the following code which would clearly benefit from using the Maybe monad.
liftMaybe2 :: (a -> b -> Maybe c) -> Maybe a -> Maybe b -> Maybe c
liftMaybe2 f a b =
  case a of
    Nothing -> Nothing
    Just a' -> case b of
      Nothing -> Nothing
      Just b' -> f a' b'
Using the Maybe monad, it looks like:
liftMaybe2 :: (a -> b -> Maybe c) -> Maybe a -> Maybe b -> Maybe c
liftMaybe2 f a b = do
  a' <- a
  b' <- b
  f a' b'
The original code must constantly check the returned value for a failure, and it may need to propagate a Nothing through an arbitrarily large amount of code.
Ideally, failure-free code should run as normal, and failure should be dealt with only when it occurs, by failing immediately rather than propagating the failure. This is what continuation passing style gets us.
Explicitly expanding Cont gets us:
newtype MaybeCPS a = MaybeCPS { unMaybeCPS :: forall r. (a -> Maybe r) -> Maybe r }
runMaybeCPS m = unMaybeCPS m return
-- This is essentially the same as the Cont definitions of return and (>>=)
instance Monad MaybeCPS where
  return a = MaybeCPS (\k -> k a)
  MaybeCPS m >>= f = MaybeCPS (\k -> m (\a -> (unMaybeCPS (f a)) k))
Note how this code is just normal CPS code and completely independent of the Maybe monad - there are no case analyses, so failure-free code will run as normal CPS code.
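For reference, a self-contained sketch of how liftMaybe2 might look when routed through MaybeCPS. The helpers embedMaybe and liftMaybe2CPS are hypothetical names, and the Functor and Applicative instances are added on the assumption of a recent GHC. The only case analysis is in embedMaybe, at the boundary where an ordinary Maybe value enters the CPS world.

```haskell
{-# LANGUAGE RankNTypes #-}

newtype MaybeCPS a = MaybeCPS { unMaybeCPS :: forall r. (a -> Maybe r) -> Maybe r }

instance Functor MaybeCPS where
  fmap f (MaybeCPS m) = MaybeCPS (\k -> m (k . f))

instance Applicative MaybeCPS where
  pure a = MaybeCPS (\k -> k a)
  MaybeCPS mf <*> MaybeCPS ma = MaybeCPS (\k -> mf (\f -> ma (k . f)))

instance Monad MaybeCPS where
  MaybeCPS m >>= f = MaybeCPS (\k -> m (\a -> unMaybeCPS (f a) k))

-- Boundary in: Nothing drops the continuation, Just a calls it.
embedMaybe :: Maybe a -> MaybeCPS a
embedMaybe Nothing  = MaybeCPS (\_ -> Nothing)
embedMaybe (Just a) = MaybeCPS (\k -> k a)

-- Boundary out: finish with Just as the final continuation.
runMaybeCPS :: MaybeCPS a -> Maybe a
runMaybeCPS m = unMaybeCPS m Just

-- Same behaviour as liftMaybe2, but the binds in between never inspect a Maybe.
liftMaybe2CPS :: (a -> b -> Maybe c) -> Maybe a -> Maybe b -> Maybe c
liftMaybe2CPS f a b = runMaybeCPS $ do
  a' <- embedMaybe a
  b' <- embedMaybe b
  embedMaybe (f a' b')
```

For a function this small the conversion overhead likely outweighs any gain; the approach is more plausibly worthwhile when long chains of binds sit between the two boundaries.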
This works because we’re basically specializing (>>=).
- Before, we had the following, where the monadic bind does a case analysis on how to proceed:
m >>= f = case m of
  Just a  -> f a
  Nothing -> Nothing
- In the CPS representation, we’re specializing bind to the two possible constructors (Just and Nothing), and the case analysis is now “built in” to return' and mzero':
return' a := (return a >>=) = (Just a >>=)  = \f -> f a
mzero'    := (mzero >>=)    = (Nothing >>=) = \f -> Nothing
Embedding a monad in Cont specializes (>>=) to its primitive operations, for example, mzero or get. This is close to specifying the operational semantics of (>>=) directly and can hence be used to implement the monad in the first place as well.
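The same idea can be sketched for State, where the primitives are get and put. Unfolding Cont (State s r) a gives (a -> s -> (r, s)) -> s -> (r, s), and each primitive is then written as "itself already bound to a continuation". The names StateCPS, getCPS, putCPS, runStateCPS and tick are hypothetical.

```haskell
{-# LANGUAGE RankNTypes #-}

-- Cont (State s r) a with the State newtype unfolded.
newtype StateCPS s a = StateCPS
  { unStateCPS :: forall r. (a -> s -> (r, s)) -> s -> (r, s) }

instance Functor (StateCPS s) where
  fmap f (StateCPS m) = StateCPS (\k -> m (k . f))

instance Applicative (StateCPS s) where
  pure a = StateCPS (\k -> k a)
  StateCPS mf <*> StateCPS ma = StateCPS (\k -> mf (\f -> ma (k . f)))

-- Bind never touches the state; it only composes continuations.
instance Monad (StateCPS s) where
  StateCPS m >>= f = StateCPS (\k -> m (\a -> unStateCPS (f a) k))

-- (get >>=): pass the current state to the continuation as its result.
getCPS :: StateCPS s s
getCPS = StateCPS (\k s -> k s s)

-- (put s' >>=): continue with the state replaced.
putCPS :: s -> StateCPS s ()
putCPS s' = StateCPS (\k _ -> k () s')

-- Finish by pairing the result with the final state.
runStateCPS :: StateCPS s a -> s -> (a, s)
runStateCPS m = unStateCPS m (\a s -> (a, s))

-- Usage: return the counter, then increment it.
tick :: StateCPS Int Int
tick = do
  n <- getCPS
  putCPS (n + 1)
  return n
```

Here the state threading lives entirely in getCPS, putCPS and the final continuation, which is the sense in which the primitives define the monad.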
instance MonadPlus MaybeCPS where
  mzero = MaybeCPS (\_ -> Nothing)
  m `mplus` n = case runMaybeCPS m of
    Nothing -> n
    Just a  -> return a
- The function mplus is the only place where we use case analysis, so it is the only place where we inspect what was returned; to do this we have to actually run the computation. We therefore only deal with effects when we need to.
- The function mzero discards its continuation. This is the typical pattern for aborting a CPS computation, and it leads to immediate termination of the computation.