diff options
Diffstat (limited to 'users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs')
-rw-r--r-- | users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs | 379 |
1 files changed, 379 insertions, 0 deletions
diff --git a/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs b/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs new file mode 100644 index 000000000000..e602ee287fa2 --- /dev/null +++ b/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs @@ -0,0 +1,379 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE FunctionalDependencies #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE TypeFamilyDependencies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# OPTIONS_GHC -Wno-orphans #-} + +module Postgres.MonadPostgres where + +import Control.Exception +import Control.Monad.Except +import Control.Monad.Logger.CallStack +import Control.Monad.Reader (MonadReader (ask), ReaderT (..)) +import Data.Error.Tree +import Data.Int (Int64) +import Data.Kind (Type) +import Data.List qualified as List +import Data.Typeable (Typeable) +import Database.PostgreSQL.Simple (Connection, FormatError, FromRow, Query, QueryError, ResultError, SqlError, ToRow) +import Database.PostgreSQL.Simple qualified as PG +import Database.PostgreSQL.Simple.FromRow qualified as PG +import Database.PostgreSQL.Simple.ToField (ToField) +import Database.PostgreSQL.Simple.ToRow (ToRow (toRow)) +import Database.PostgreSQL.Simple.Types (fromQuery) +import GHC.Records (HasField (..)) +import Label +import PossehlAnalyticsPrelude +import Postgres.Decoder +import Pretty (showPretty) +import System.Exit (ExitCode (..)) +import Tool +import UnliftIO (MonadUnliftIO (withRunInIO)) +import UnliftIO.Process qualified as Process + +-- | Postgres queries/commands that can be executed within a running transaction. +-- +-- These are implemented with the @postgresql-simple@ primitives of the same name +-- and will behave the same unless othewise documented. +class Monad m => MonadPostgres (m :: Type -> Type) where + -- | Execute an INSERT, UPDATE, or other SQL query that is not expected to return results. + -- + -- Returns the number of rows affected. + execute :: (ToRow params, Typeable params) => Query -> params -> Transaction m (Label "numberOfRowsAffected" Natural) + + -- | Execute an INSERT, UPDATE, or other SQL query that is not expected to return results. Does not perform parameter substitution. + -- + -- Returns the number of rows affected. + execute_ :: Query -> Transaction m (Label "numberOfRowsAffected" Natural) + + -- | Execute a multi-row INSERT, UPDATE, or other SQL query that is not expected to return results. + -- + -- Returns the number of rows affected. If the list of parameters is empty, this function will simply return 0 without issuing the query to the backend. If this is not desired, consider using the 'PG.Values' constructor instead. + executeMany :: (ToRow params, Typeable params) => Query -> [params] -> Transaction m (Label "numberOfRowsAffected" Natural) + + -- | Execute INSERT ... RETURNING, UPDATE ... RETURNING, or other SQL query that accepts multi-row input and is expected to return results. Note that it is possible to write query conn "INSERT ... RETURNING ..." ... in cases where you are only inserting a single row, and do not need functionality analogous to 'executeMany'. + -- + -- If the list of parameters is empty, this function will simply return [] without issuing the query to the backend. If this is not desired, consider using the 'PG.Values' constructor instead. + executeManyReturningWith :: (ToRow q) => Query -> [q] -> Decoder r -> Transaction m [r] + + -- | Run a query, passing parameters and result row parser. + queryWith :: (PG.ToRow params, Typeable params, Typeable r) => PG.Query -> params -> Decoder r -> Transaction m [r] + + -- | Run a query without any parameters and result row parser. + queryWith_ :: (Typeable r) => PG.Query -> Decoder r -> Transaction m [r] + + -- | Run a query, passing parameters, and fold over the resulting rows. + -- + -- This doesn’t have to realize the full list of results in memory, + -- rather results are streamed incrementally from the database. + -- + -- When dealing with small results, it may be simpler (and perhaps faster) to use query instead. + -- + -- This fold is _not_ strict. The stream consumer is responsible for forcing the evaluation of its result to avoid space leaks. + -- + -- If you can, prefer aggregating in the database itself. + foldRows :: + (FromRow row, ToRow params, Typeable row, Typeable params) => + Query -> + params -> + a -> + (a -> row -> Transaction m a) -> + Transaction m a + + -- | Run a given transaction in a transaction block, rolling back the transaction + -- if any exception (postgres or Haskell Exception) is thrown during execution. + -- + -- Re-throws the exception. + -- + -- Don’t do any long-running things on the Haskell side during a transaction, + -- because it will block a database connection and potentially also lock + -- database tables from being written or read by other clients. + -- + -- Nonetheless, try to push transactions as far out to the handlers as possible, + -- don’t do something like @runTransaction $ query …@, because it will lead people + -- to accidentally start nested transactions (the inner transaction is run on a new connections, + -- thus can’t see any changes done by the outer transaction). + -- Only handlers should run transactions. + runTransaction :: Transaction m a -> m a + +-- | Run a query, passing parameters. +query :: forall m params r. (PG.ToRow params, PG.FromRow r, Typeable params, Typeable r, MonadPostgres m) => PG.Query -> params -> Transaction m [r] +query qry params = queryWith qry params (Decoder PG.fromRow) + +-- | Run a query without any parameters. +query_ :: forall m r. (Typeable r, PG.FromRow r, MonadPostgres m) => PG.Query -> Transaction m [r] +query_ qry = queryWith_ qry (Decoder PG.fromRow) + +-- TODO: implement via fold, so that the result doesn’t have to be realized in memory +querySingleRow :: + ( MonadPostgres m, + ToRow qParams, + Typeable qParams, + FromRow a, + Typeable a, + MonadThrow m + ) => + Query -> + qParams -> + Transaction m a +querySingleRow qry params = do + query qry params >>= ensureSingleRow + +-- TODO: implement via fold, so that the result doesn’t have to be realized in memory +querySingleRowMaybe :: + ( MonadPostgres m, + ToRow qParams, + Typeable qParams, + FromRow a, + Typeable a, + MonadThrow m + ) => + Query -> + qParams -> + Transaction m (Maybe a) +querySingleRowMaybe qry params = do + rows <- query qry params + case rows of + [] -> pure Nothing + [one] -> pure (Just one) + -- TODO: Should we MonadThrow this here? It’s really an implementation detail of MonadPostgres + -- that a database function can error out, should probably handled by the instances. + more -> throwM $ SingleRowError {numberOfRowsReturned = (List.length more)} + +ensureSingleRow :: MonadThrow m => [a] -> m a +ensureSingleRow = \case + -- TODO: Should we MonadThrow this here? It’s really an implementation detail of MonadPostgres + -- that a database function can error out, should probably handled by the instances. + [] -> throwM (SingleRowError {numberOfRowsReturned = 0}) + [one] -> pure one + more -> + throwM $ + SingleRowError + { numberOfRowsReturned = + -- TODO: this is VERY bad, because it requires to parse the full database output, even if there’s 10000000000 elements + List.length more + } + +newtype Transaction m a = Transaction {unTransaction :: (ReaderT Connection m a)} + deriving newtype + ( Functor, + Applicative, + Monad, + MonadThrow, + MonadLogger, + MonadIO, + MonadUnliftIO, + MonadTrans + ) + +runTransaction' :: Connection -> Transaction m a -> m a +runTransaction' conn transaction = runReaderT transaction.unTransaction conn + +-- | Catch any Postgres exception that gets thrown, +-- print the query that was run and the query parameters, +-- then rethrow inside an 'Error'. +handlePGException :: + forall a params m. + (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => + Text -> + Query -> + -- | Depending on whether we used `format` or `formatMany`. + Either params [params] -> + IO a -> + Transaction m a +handlePGException queryType query' params io = do + withRunInIO $ \unliftIO -> + io + `catches` [ Handler $ unliftIO . logQueryException @SqlError, + Handler $ unliftIO . logQueryException @QueryError, + Handler $ unliftIO . logQueryException @ResultError, + Handler $ unliftIO . logFormatException + ] + where + -- TODO: use throwInternalError here (after pulling it into the MonadPostgres class) + throwAsError = unwrapIOError . Left . newError + throwErr err = liftIO $ throwAsError $ prettyErrorTree $ nestedMultiError "A Postgres query failed" err + logQueryException :: Exception e => e -> Transaction m a + logQueryException exc = do + formattedQuery <- case params of + Left one -> pgFormatQuery' query' one + Right many -> pgFormatQueryMany' query' many + throwErr + ( singleError [fmt|Query Type: {queryType}|] + :| [ nestedError "Exception" (exc & showPretty & newError & singleError), + nestedError "Query" (formattedQuery & newError & singleError) + ] + ) + logFormatException :: FormatError -> Transaction m a + logFormatException fe = throwErr (fe & showPretty & newError & singleError & singleton) + +pgExecute :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> params -> Transaction m (Label "numberOfRowsAffected" Natural) +pgExecute qry params = do + conn <- Transaction ask + PG.execute conn qry params + & handlePGException "execute" qry (Left params) + >>= toNumberOfRowsAffected "pgExecute" + +pgExecute_ :: (MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> Transaction m (Label "numberOfRowsAffected" Natural) +pgExecute_ qry = do + conn <- Transaction ask + PG.execute_ conn qry + & handlePGException "execute_" qry (Left ()) + >>= toNumberOfRowsAffected "pgExecute_" + +pgExecuteMany :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> [params] -> Transaction m (Label "numberOfRowsAffected" Natural) +pgExecuteMany qry params = + do + conn <- Transaction ask + PG.executeMany conn qry params + & handlePGException "executeMany" qry (Right params) + >>= toNumberOfRowsAffected "pgExecuteMany" + +toNumberOfRowsAffected :: MonadIO m => Text -> Int64 -> m (Label "numberOfRowsAffected" Natural) +toNumberOfRowsAffected functionName i64 = + i64 + & intToNatural + & annotate [fmt|{functionName}: postgres returned a negative number of rows affected: {i64}|] + -- we throw this directly in IO here, because we don’t want to e.g. have to propagate MonadThrow through user code (it’s an assertion) + & unwrapIOError + & liftIO + <&> label @"numberOfRowsAffected" + +pgExecuteManyReturningWith :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> [params] -> Decoder r -> Transaction m [r] +pgExecuteManyReturningWith qry params (Decoder fromRow) = + do + conn <- Transaction ask + PG.returningWith fromRow conn qry params + & handlePGException "executeManyReturning" qry (Right params) + +pgFold :: + (FromRow row, ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => + Query -> + params -> + a -> + (a -> row -> Transaction m a) -> + Transaction m a +pgFold qry params accumulator f = do + conn <- Transaction ask + + withRunInIO + ( \runInIO -> + do + PG.fold + conn + qry + params + accumulator + (\acc row -> runInIO $ f acc row) + & handlePGException "fold" qry (Left params) + & runInIO + ) + +pgFormatQuery :: (ToRow params, MonadIO m) => Query -> params -> Transaction m ByteString +pgFormatQuery qry params = Transaction $ do + conn <- ask + liftIO $ PG.formatQuery conn qry params + +pgFormatQueryMany :: (MonadIO m, ToRow params) => Query -> [params] -> Transaction m ByteString +pgFormatQueryMany qry params = Transaction $ do + conn <- ask + liftIO $ PG.formatMany conn qry params + +pgQueryWith :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> params -> Decoder r -> Transaction m [r] +pgQueryWith qry params (Decoder fromRow) = do + conn <- Transaction ask + PG.queryWith fromRow conn qry params + & handlePGException "query" qry (Left params) + +pgQueryWith_ :: (MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> Decoder r -> Transaction m [r] +pgQueryWith_ qry (Decoder fromRow) = do + conn <- Transaction ask + liftIO (PG.queryWith_ fromRow conn qry) + & handlePGException "query" qry (Left ()) + +pgQuery :: (ToRow params, FromRow r, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> params -> Transaction m [r] +pgQuery qry params = do + conn <- Transaction ask + PG.query conn qry params + & handlePGException "query" qry (Left params) + +pgQuery_ :: (FromRow r, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> Transaction m [r] +pgQuery_ qry = do + conn <- Transaction ask + PG.query_ conn qry + & handlePGException "query_" qry (Left ()) + +data SingleRowError = SingleRowError + { -- | How many columns were actually returned by the query + numberOfRowsReturned :: Int + } + deriving stock (Show) + +instance Exception SingleRowError where + displayException (SingleRowError {..}) = [fmt|Single row expected from SQL query result, {numberOfRowsReturned} rows were returned instead."|] + +pgFormatQueryNoParams' :: (MonadIO m, MonadLogger m, MonadTools m) => Query -> Transaction m Text +pgFormatQueryNoParams' q = + lift $ pgFormatQueryByteString q.fromQuery + +pgFormatQuery' :: (MonadIO m, ToRow params, MonadLogger m, MonadTools m) => Query -> params -> Transaction m Text +pgFormatQuery' q p = + pgFormatQuery q p + >>= lift . pgFormatQueryByteString + +pgFormatQueryMany' :: (MonadIO m, ToRow params, MonadLogger m, MonadTools m) => Query -> [params] -> Transaction m Text +pgFormatQueryMany' q p = + pgFormatQueryMany q p + >>= lift . pgFormatQueryByteString + +-- | Tools required at runtime +data Tools = Tools + { pgFormat :: Tool + } + deriving stock (Show) + +class Monad m => MonadTools m where + getTools :: m Tools + +initMonadTools :: Label "envvar" Text -> IO Tools +initMonadTools var = + Tool.readTools (label @"toolsEnvVar" var.envvar) toolParser + where + toolParser = do + pgFormat <- readTool "pg_format" + pure $ Tools {..} + +pgFormatQueryByteString :: (MonadIO m, MonadLogger m, MonadTools m) => ByteString -> m Text +pgFormatQueryByteString queryBytes = do + do + tools <- getTools + (exitCode, stdout, stderr) <- + Process.readProcessWithExitCode + tools.pgFormat.toolPath + ["-"] + (queryBytes & bytesToTextUtf8Lenient & textToString) + case exitCode of + ExitSuccess -> pure (stdout & stringToText) + ExitFailure status -> do + logWarn [fmt|pg_format failed with status {status} while formatting the query, using original query string. Is there a syntax error?|] + logDebug + ( prettyErrorTree + ( nestedMultiError + "pg_format output" + ( nestedError "stdout" (singleError (stdout & stringToText & newError)) + :| [(nestedError "stderr" (singleError (stderr & stringToText & newError)))] + ) + ) + ) + logDebug [fmt|pg_format stdout: stderr|] + pure (queryBytes & bytesToTextUtf8Lenient) + +instance (ToField t1) => ToRow (Label l1 t1) where + toRow t2 = toRow $ PG.Only $ getField @l1 t2 + +instance (ToField t1, ToField t2) => ToRow (T2 l1 t1 l2 t2) where + toRow t2 = toRow (getField @l1 t2, getField @l2 t2) + +instance (ToField t1, ToField t2, ToField t3) => ToRow (T3 l1 t1 l2 t2 l3 t3) where + toRow t3 = toRow (getField @l1 t3, getField @l2 t3, getField @l3 t3) |