{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Server.Handlers.PasswordResetHandlers
( PasswordResetAPI
, passwordResetServer
, requestPasswordResetHandler
, confirmPasswordResetHandler
) where
import Control.Monad (unless, when)
import Control.Monad.Except (runExceptT)
import Control.Monad.IO.Class (liftIO)
import Data.Functor ((<&>))
import Data.Maybe (isJust)
import Data.Password.Argon2
( PasswordHash (unPasswordHash)
, hashPassword
, mkPassword
)
import qualified Data.Text as Text
import Data.Time (getCurrentTime)
import qualified Hasql.Session as Session
import Servant
import Server.Auth.PasswordReset
import Server.Auth.PasswordResetUtil (getTokenExpirationTime)
import qualified Server.Auth.PasswordResetUtil as Util
import Server.HandlerUtil
import System.Environment (getEnv)
import qualified UserManagement.Sessions as Sessions
import qualified UserManagement.User as User
-- | Servant server for 'PasswordResetAPI'.
--
-- The first endpoint is served by 'requestPasswordResetHandler' and the
-- second by 'confirmPasswordResetHandler', in the order the API combines them.
passwordResetServer :: Server PasswordResetAPI
passwordResetServer =
  requestPasswordResetHandler
    :<|> confirmPasswordResetHandler
-- | Start a password reset for the account registered under the given email.
--
-- If a user with that email exists:
--
-- 1. a fresh token is generated;
-- 2. the token's hash is stored together with an expiry time;
-- 3. a reset link is built from the @SERVER_HOST@ environment variable and
--    the plain token, and it is emailed to the user.
--
-- If no user matches, the handler answers 'NoContent' exactly as in the
-- success case (presumably so the response does not reveal which addresses
-- are registered).
--
-- Any failed database step yields 'errDatabaseAccessFailed'.
--
-- NOTE(review): 'getEnv' throws if @SERVER_HOST@ is unset. That surfaces as
-- an unhandled exception after the token row has already been written.
requestPasswordResetHandler :: PasswordResetRequest -> Handler NoContent
requestPasswordResetHandler PasswordResetRequest {..} = do
  conn <- tryGetDBConnection
  lookupResult <-
    liftIO $ Session.run (Sessions.getUserByEmail resetRequestEmail) conn
  case lookupResult of
    Left _ -> dbFailure
    Right Nothing -> return NoContent
    Right (Just user) -> issueReset conn user
  where
    dbFailure :: Handler NoContent
    dbFailure = throwError errDatabaseAccessFailed

    -- Persist a new token for the user, then mail them the reset link.
    -- Only the hash goes to the database; the plain token goes into the URL.
    issueReset conn user = do
      plainToken <- liftIO Util.generateResetToken
      expiry <- liftIO getTokenExpirationTime
      stored <-
        liftIO $
          Session.run
            ( Sessions.createPasswordResetToken
                (User.userID user)
                (Util.hashToken plainToken)
                expiry
            )
            conn
      case stored of
        Left _ -> dbFailure
        Right _ -> do
          host <- liftIO (getEnv "SERVER_HOST" <&> Text.pack)
          liftIO $
            Util.sendPasswordResetEmail
              (User.userEmail user)
              (User.userName user)
              (Util.createResetUrl host plainToken)
          return NoContent
-- | Complete a password reset: validate the submitted token, consume it, and
-- store the new password hash for the token's owner.
--
-- Responses:
--
-- * 400 @"Invalid token format"@ when 'Util.validateTokenFormat' rejects the
--   token;
-- * 400 @"Invalid or expired token"@ when no stored token matches its hash;
-- * 400 @"Token has already been used"@ when the stored token already has a
--   use timestamp;
-- * 400 @"Token has expired"@ when the current time is past its expiry;
-- * 'errDatabaseAccessFailed' when looking up the token, marking it used, or
--   updating the password fails;
-- * 'NoContent' on success.
--
-- NOTE(review): the used/expired checks and the mark-used write are separate
-- statements. Two concurrent requests carrying the same token could both
-- pass the checks. An atomic conditional update (for example, one that only
-- marks a still-unused row) in "UserManagement.Sessions" would close this;
-- confirm what the query does.
confirmPasswordResetHandler :: PasswordResetConfirm -> Handler NoContent
confirmPasswordResetHandler PasswordResetConfirm {..} = do
  unless (Util.validateTokenFormat resetConfirmToken) $
    throwError $ err400 {errBody = "Invalid token format"}
  conn <- tryGetDBConnection
  -- Tokens are looked up by hash, the same form in which
  -- 'requestPasswordResetHandler' stores them.
  let tokenHash = Util.hashToken resetConfirmToken

  eToken <- liftIO $ Session.run (Sessions.getPasswordResetToken tokenHash) conn
  (_, userId, _, expiresAt, _, usedAt) <- case eToken of
    Right (Just t) -> pure t
    Right Nothing -> throwError $ err400 {errBody = "Invalid or expired token"}
    Left _ -> throwError errDatabaseAccessFailed

  when (isJust usedAt) $
    throwError $ err400 {errBody = "Token has already been used"}
  now <- liftIO getCurrentTime
  when (now > expiresAt) $
    throwError $ err400 {errBody = "Token has expired"}

  -- Hash first: it touches no shared state, which keeps the window between
  -- consuming the token and writing the password as small as possible.
  hashedPassword <- hashPassword (mkPassword resetConfirmNewPassword)

  -- Consume the token before changing the password, and fail if that does
  -- not succeed. Ignoring this result would risk a changed password with a
  -- still-replayable token. Failing here instead leaves the password
  -- untouched, and the user can request a new token.
  eMarked <-
    liftIO $ Session.run (Sessions.markPasswordResetTokenUsed tokenHash) conn
  case eMarked of
    Right () -> pure ()
    Left _ -> throwError errDatabaseAccessFailed

  eUpdate <-
    liftIO $
      Session.run
        (Sessions.updateUserPWHash userId (unPasswordHash hashedPassword))
        conn
  case eUpdate of
    Right () -> pure ()
    Left _ -> throwError errDatabaseAccessFailed

  -- Housekeeping only: a failed cleanup must not fail a reset that has
  -- already succeeded, so its result is deliberately discarded.
  _ <- liftIO $ Session.run Sessions.cleanupExpiredTokens conn
  pure NoContent