{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE TupleSections #-}

module Control.Monad.ConsumableStack
    ( ConsumableStackT
    , ConsumableStack
    , pop
    , runConsumableStackT
    , runConsumableStack
    , ConsumableStackError (..)
    )
where

import Control.Monad ((>=>))
import Control.Monad.Identity (Identity, runIdentity)
import Control.Monad.Trans.Class (MonadTrans (lift))
import Data.Bifunctor (first)
import Data.List (uncons)

-- | A monad transformer with a pop-only stack that is expected to be fully
--   consumed ("popped").
--
--   A computation is a function from the current stack to an action in @m@.
--   That action yields either 'Nothing', meaning the computation was aborted
--   (e.g. by popping an empty stack), or the result paired with the stack
--   that remains.
newtype ConsumableStackT e m r
    = ConsumableStackT ([e] -> m (Maybe (r, [e])))

-- | A 'ConsumableStackT' over 'Identity', i.e. a pure consumable-stack
--   computation.
type ConsumableStack e = ConsumableStackT e Identity

-- | Expose the underlying stack-transforming function of a
--   'ConsumableStackT' computation.
unwrapCST :: ConsumableStackT s m a -> [s] -> m (Maybe (a, [s]))
unwrapCST cst = case cst of
    ConsumableStackT step -> step

-- | Maps over the result only. The leftover stack, and any abort signalled
--   by 'Nothing', pass through unchanged.
instance (Functor m) => Functor (ConsumableStackT s m) where
    fmap f (ConsumableStackT step) =
        ConsumableStackT $ \stack -> fmap mapResult (step stack)
      where
        -- Apply f to the result component, keeping the remaining stack.
        mapResult = fmap (first f)

-- Note: an (Applicative m) constraint is not enough here. The second
--   computation has to run on the stack left over by the first one, and must
--   not run at all if the first one aborted with 'Nothing'. Both facts are
--   only available inside @m@, so combining the two needs (>>=). The StateT
--   instance from transformers likewise requires (Monad m).
instance (Monad m) => Applicative (ConsumableStackT s m) where
    -- Lift the value through @m@ without touching the stack.
    pure x = lift (pure x)

    -- Sequence left to right via the Monad instance: run the computation that
    -- produces the function, then map that function over the argument
    -- computation.
    sf <*> sx = do
        f <- sf
        f <$> sx

-- | Sequencing threads the remaining stack from one computation into the
--   next, and short-circuits as soon as a computation yields 'Nothing'.
instance (Monad m) => Monad (ConsumableStackT s m) where
    ConsumableStackT step >>= k = ConsumableStackT (step >=> proceed)
      where
        -- An aborted first step aborts the whole bind. Otherwise, feed its
        -- result to the continuation together with the leftover stack.
        proceed = maybe (return Nothing) (\(x, rest) -> unwrapCST (k x) rest)

-- | Lifting runs the inner action and leaves the stack unchanged.
instance MonadTrans (ConsumableStackT s) where
    lift action = ConsumableStackT $ \stack ->
        fmap (\x -> Just (x, stack)) action

-- | Why running a consumable stack computation failed.
data ConsumableStackError
    = ConsumableStackDepletedEarly
    -- ^ A 'pop' was attempted on an empty stack before the computation
    --   finished.
    | ConsumableStackNotFullyConsumed
    -- ^ The computation finished but left elements on the stack.
    deriving (Eq, Show)

-- | Monads that give access to a pop-only stack of @e@ values. The element
--   type @e@ is determined by the monad @n@.
class (Monad n) => MonadConsumableStack e n | n -> e where
    -- | Pop the head off the stack and return it.
    pop :: n e

-- | Popping from an empty stack aborts the computation, which
--   'runConsumableStackT' reports as 'ConsumableStackDepletedEarly'.
instance (Monad m) => MonadConsumableStack s (ConsumableStackT s m) where
    pop = ConsumableStackT $ \stack -> pure (uncons stack)

-- | Run a consumable stack monad transformer on an initial stack.
--
--   Yields @'Right' result@ only when the computation completes and has
--   popped every element. Otherwise yields 'Left' carrying either
--   'ConsumableStackDepletedEarly' (the stack ran out before the computation
--   finished) or 'ConsumableStackNotFullyConsumed' (elements were left over).
runConsumableStackT
    :: (Functor m)
    => ConsumableStackT s m a
    -> [s]
    -> m (Either ConsumableStackError a)
runConsumableStackT sx stack = classify <$> unwrapCST sx stack
  where
    classify outcome = case outcome of
        Nothing -> Left ConsumableStackDepletedEarly
        Just (x, leftover)
            | null leftover -> Right x
            | otherwise -> Left ConsumableStackNotFullyConsumed

-- | Run a pure consumable stack computation on an initial stack. See
--   'runConsumableStackT' for the meaning of the result.
runConsumableStack
    :: ConsumableStack s a
    -> [s]
    -> Either ConsumableStackError a
runConsumableStack sx stack = runIdentity (runConsumableStackT sx stack)