{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Server.Handlers.PasswordResetHandlers
    ( PasswordResetAPI
    , passwordResetServer
    , requestPasswordResetHandler
    , confirmPasswordResetHandler
    ) where

import Control.Monad (unless, when)
import Control.Monad.Except (runExceptT)
import Control.Monad.IO.Class (liftIO)
import Data.Functor ((<&>))
import Data.Maybe (isJust)
import Data.Password.Argon2
    ( PasswordHash (unPasswordHash)
    , hashPassword
    , mkPassword
    )
import qualified Data.Text as Text
import Data.Time (getCurrentTime)
import qualified Hasql.Session as Session
import Servant
import Server.Auth.PasswordReset
import Server.Auth.PasswordResetUtil (getTokenExpirationTime)
import qualified Server.Auth.PasswordResetUtil as Util
import Server.HandlerUtil
import System.Environment (getEnv, lookupEnv)
import qualified UserManagement.Sessions as Sessions
import qualified UserManagement.User as User

-- | Server implementation for 'PasswordResetAPI'.
--
-- The first endpoint (taking a 'PasswordResetRequest') is served by
-- 'requestPasswordResetHandler'; the second (taking a
-- 'PasswordResetConfirm') by 'confirmPasswordResetHandler'.
passwordResetServer :: Server PasswordResetAPI
passwordResetServer =
    requestPasswordResetHandler :<|> confirmPasswordResetHandler

-- | Handler for password reset requests.
--
-- Looks up the account registered under 'resetRequestEmail'. When one
-- exists, a fresh reset token is generated, its hash is stored together with
-- the expiry time from 'getTokenExpirationTime', and a reset link built from
-- the @SERVER_HOST@ base URL and the raw token is e-mailed to the user.
--
-- The response is 'NoContent' whether or not an account exists, so the
-- endpoint does not reveal which addresses are registered.
--
-- Errors:
--
-- * 'errDatabaseAccessFailed' if the user lookup or the token insert fails.
-- * 'err500' if @SERVER_HOST@ is unset or empty. This is checked before the
--   user lookup so a misconfiguration fails identically for every address
--   (rather than only for registered ones) and never stores a token whose
--   link could not be sent.
requestPasswordResetHandler :: PasswordResetRequest -> Handler NoContent
requestPasswordResetHandler PasswordResetRequest {..} = do
    baseUrl <- lookupBaseUrl
    conn <- tryGetDBConnection

    -- Check if user exists
    eUser <- liftIO $ Session.run (Sessions.getUserByEmail resetRequestEmail) conn
    case eUser of
        Left _ -> throwError errDatabaseAccessFailed
        Right Nothing ->
            -- For security, don't reveal whether the email exists:
            -- report success but don't send an email.
            return NoContent
        Right (Just user) -> do
            token <- liftIO Util.generateResetToken
            let tokenHash = Util.hashToken token
            expiresAt <- liftIO getTokenExpirationTime

            -- Only the hash is persisted; the raw token travels in the e-mail.
            eTokenId <-
                liftIO $
                    Session.run
                        ( Sessions.createPasswordResetToken
                            (User.userID user)
                            tokenHash
                            expiresAt
                        )
                        conn
            case eTokenId of
                Left _ -> throwError errDatabaseAccessFailed
                Right _ -> do
                    let resetUrl = Util.createResetUrl baseUrl token
                    liftIO $
                        Util.sendPasswordResetEmail
                            (User.userEmail user)
                            (User.userName user)
                            resetUrl
                    return NoContent
  where
    -- Base URL for reset links, read from SERVER_HOST. An unset or empty
    -- variable is a server misconfiguration and is reported as a plain 500
    -- instead of escaping as an uncaught IOError.
    lookupBaseUrl :: Handler Text.Text
    lookupBaseUrl = do
        mHost <- liftIO (lookupEnv "SERVER_HOST")
        case mHost of
            Just host | not (null host) -> pure (Text.pack host)
            _ -> throwError err500

-- | Handler for password reset confirmations.
--
-- Validates the token's format, looks the token up by its hash, and rejects
-- it with a 400 if it is unknown, already used, or past its expiry time.
-- Otherwise the token is marked as used, the user's password hash is
-- replaced with an Argon2 hash of 'resetConfirmNewPassword', and expired
-- tokens are cleaned up on a best-effort basis.
--
-- The token is consumed /before/ the password is changed, and a failure to
-- consume it aborts the request with 'errDatabaseAccessFailed'. This way a
-- single-use token can never stay valid after a successful reset. The
-- trade-off is that if the password update itself then fails, the user has
-- to request a new token.
--
-- NOTE(review): the used/expired checks and the mark-as-used write are
-- separate database calls, so two concurrent requests carrying the same
-- token can both pass the checks. Closing that window needs an atomic
-- conditional update in the database layer.
confirmPasswordResetHandler :: PasswordResetConfirm -> Handler NoContent
confirmPasswordResetHandler PasswordResetConfirm {..} = do
    unless (Util.validateTokenFormat resetConfirmToken) $
        throwError err400 {errBody = "Invalid token format"}

    conn <- tryGetDBConnection
    let tokenHash = Util.hashToken resetConfirmToken

        -- Run a session on this connection, turning any database error into
        -- errDatabaseAccessFailed. Handler is already MonadError ServerError,
        -- so no extra ExceptT layer is needed.
        runDB :: Session.Session a -> Handler a
        runDB session =
            liftIO (Session.run session conn)
                >>= either (const (throwError errDatabaseAccessFailed)) pure

    -- Look up token
    mToken <- runDB (Sessions.getPasswordResetToken tokenHash)
    (_, userId, _, expiresAt, _, usedAt) <-
        maybe (throwError err400 {errBody = "Invalid or expired token"}) pure mToken

    -- Check if already used
    when (isJust usedAt) $
        throwError err400 {errBody = "Token has already been used"}

    -- Check expiration
    now <- liftIO getCurrentTime
    when (now > expiresAt) $
        throwError err400 {errBody = "Token has expired"}

    -- Hash the new password before touching the database, so the (slow)
    -- hash does not sit between consuming the token and updating the user.
    hashedPassword <- hashPassword (mkPassword resetConfirmNewPassword)

    -- Consume the token first; abort if it cannot be invalidated.
    runDB (Sessions.markPasswordResetTokenUsed tokenHash)
    runDB (Sessions.updateUserPWHash userId (unPasswordHash hashedPassword))

    -- Cleanup expired tokens (best effort: failure does not affect the reset)
    _ <- liftIO $ Session.run Sessions.cleanupExpiredTokens conn
    pure NoContent