about summary refs log tree commit diff
path: root/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs
diff options
context:
space:
mode:
Diffstat (limited to 'users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs')
-rw-r--r--users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs379
1 files changed, 379 insertions, 0 deletions
diff --git a/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs b/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs
new file mode 100644
index 0000000000..e602ee287f
--- /dev/null
+++ b/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs
@@ -0,0 +1,379 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
+{-# LANGUAGE FunctionalDependencies #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE TypeFamilyDependencies #-}
+{-# LANGUAGE UndecidableInstances #-}
+{-# OPTIONS_GHC -Wno-orphans #-}
+
+module Postgres.MonadPostgres where
+
+import Control.Exception
+import Control.Monad.Except
+import Control.Monad.Logger.CallStack
+import Control.Monad.Reader (MonadReader (ask), ReaderT (..))
+import Data.Error.Tree
+import Data.Int (Int64)
+import Data.Kind (Type)
+import Data.List qualified as List
+import Data.Typeable (Typeable)
+import Database.PostgreSQL.Simple (Connection, FormatError, FromRow, Query, QueryError, ResultError, SqlError, ToRow)
+import Database.PostgreSQL.Simple qualified as PG
+import Database.PostgreSQL.Simple.FromRow qualified as PG
+import Database.PostgreSQL.Simple.ToField (ToField)
+import Database.PostgreSQL.Simple.ToRow (ToRow (toRow))
+import Database.PostgreSQL.Simple.Types (fromQuery)
+import GHC.Records (HasField (..))
+import Label
+import PossehlAnalyticsPrelude
+import Postgres.Decoder
+import Pretty (showPretty)
+import System.Exit (ExitCode (..))
+import Tool
+import UnliftIO (MonadUnliftIO (withRunInIO))
+import UnliftIO.Process qualified as Process
+
+-- | Postgres queries/commands that can be executed within a running transaction.
+--
+-- These are implemented with the @postgresql-simple@ primitives of the same name
+-- and will behave the same unless othewise documented.
+class Monad m => MonadPostgres (m :: Type -> Type) where
+  -- | Execute an INSERT, UPDATE, or other SQL query that is not expected to return results.
+  --
+  -- Returns the number of rows affected.
+  execute :: (ToRow params, Typeable params) => Query -> params -> Transaction m (Label "numberOfRowsAffected" Natural)
+
+  -- | Execute an INSERT, UPDATE, or other SQL query that is not expected to return results. Does not perform parameter substitution.
+  --
+  -- Returns the number of rows affected.
+  execute_ :: Query -> Transaction m (Label "numberOfRowsAffected" Natural)
+
+  -- | Execute a multi-row INSERT, UPDATE, or other SQL query that is not expected to return results.
+  --
+  -- Returns the number of rows affected. If the list of parameters is empty, this function will simply return 0 without issuing the query to the backend. If this is not desired, consider using the 'PG.Values' constructor instead.
+  executeMany :: (ToRow params, Typeable params) => Query -> [params] -> Transaction m (Label "numberOfRowsAffected" Natural)
+
+  -- | Execute INSERT ... RETURNING, UPDATE ... RETURNING, or other SQL query that accepts multi-row input and is expected to return results. Note that it is possible to write query conn "INSERT ... RETURNING ..." ... in cases where you are only inserting a single row, and do not need functionality analogous to 'executeMany'.
+  --
+  -- If the list of parameters is empty, this function will simply return [] without issuing the query to the backend. If this is not desired, consider using the 'PG.Values' constructor instead.
+  executeManyReturningWith :: (ToRow q) => Query -> [q] -> Decoder r -> Transaction m [r]
+
+  -- | Run a query, passing parameters and result row parser.
+  queryWith :: (PG.ToRow params, Typeable params, Typeable r) => PG.Query -> params -> Decoder r -> Transaction m [r]
+
+  -- | Run a query without any parameters and result row parser.
+  queryWith_ :: (Typeable r) => PG.Query -> Decoder r -> Transaction m [r]
+
+  -- | Run a query, passing parameters, and fold over the resulting rows.
+  --
+  -- This doesn’t have to realize the full list of results in memory,
+  -- rather results are streamed incrementally from the database.
+  --
+  -- When dealing with small results, it may be simpler (and perhaps faster) to use query instead.
+  --
+  -- This fold is _not_ strict. The stream consumer is responsible for forcing the evaluation of its result to avoid space leaks.
+  --
+  -- If you can, prefer aggregating in the database itself.
+  foldRows ::
+    (FromRow row, ToRow params, Typeable row, Typeable params) =>
+    Query ->
+    params ->
+    a ->
+    (a -> row -> Transaction m a) ->
+    Transaction m a
+
+  -- | Run a given transaction in a transaction block, rolling back the transaction
+  -- if any exception (postgres or Haskell Exception) is thrown during execution.
+  --
+  -- Re-throws the exception.
+  --
+  -- Don’t do any long-running things on the Haskell side during a transaction,
+  -- because it will block a database connection and potentially also lock
+  -- database tables from being written or read by other clients.
+  --
+  -- Nonetheless, try to push transactions as far out to the handlers as possible,
+  -- don’t do something like @runTransaction $ query …@, because it will lead people
+  -- to accidentally start nested transactions (the inner transaction is run on a new connections,
+  -- thus can’t see any changes done by the outer transaction).
+  -- Only handlers should run transactions.
+  runTransaction :: Transaction m a -> m a
+
+-- | Run a query, passing parameters.
+query :: forall m params r. (PG.ToRow params, PG.FromRow r, Typeable params, Typeable r, MonadPostgres m) => PG.Query -> params -> Transaction m [r]
+query qry params = queryWith qry params (Decoder PG.fromRow)
+
+-- | Run a query without any parameters.
+query_ :: forall m r. (Typeable r, PG.FromRow r, MonadPostgres m) => PG.Query -> Transaction m [r]
+query_ qry = queryWith_ qry (Decoder PG.fromRow)
+
+-- TODO: implement via fold, so that the result doesn’t have to be realized in memory
+querySingleRow ::
+  ( MonadPostgres m,
+    ToRow qParams,
+    Typeable qParams,
+    FromRow a,
+    Typeable a,
+    MonadThrow m
+  ) =>
+  Query ->
+  qParams ->
+  Transaction m a
+querySingleRow qry params = do
+  query qry params >>= ensureSingleRow
+
+-- TODO: implement via fold, so that the result doesn’t have to be realized in memory
+querySingleRowMaybe ::
+  ( MonadPostgres m,
+    ToRow qParams,
+    Typeable qParams,
+    FromRow a,
+    Typeable a,
+    MonadThrow m
+  ) =>
+  Query ->
+  qParams ->
+  Transaction m (Maybe a)
+querySingleRowMaybe qry params = do
+  rows <- query qry params
+  case rows of
+    [] -> pure Nothing
+    [one] -> pure (Just one)
+    -- TODO: Should we MonadThrow this here? It’s really an implementation detail of MonadPostgres
+    -- that a database function can error out, should probably handled by the instances.
+    more -> throwM $ SingleRowError {numberOfRowsReturned = (List.length more)}
+
+ensureSingleRow :: MonadThrow m => [a] -> m a
+ensureSingleRow = \case
+  -- TODO: Should we MonadThrow this here? It’s really an implementation detail of MonadPostgres
+  -- that a database function can error out, should probably handled by the instances.
+  [] -> throwM (SingleRowError {numberOfRowsReturned = 0})
+  [one] -> pure one
+  more ->
+    throwM $
+      SingleRowError
+        { numberOfRowsReturned =
+            -- TODO: this is VERY bad, because it requires to parse the full database output, even if there’s 10000000000 elements
+            List.length more
+        }
+
+newtype Transaction m a = Transaction {unTransaction :: (ReaderT Connection m a)}
+  deriving newtype
+    ( Functor,
+      Applicative,
+      Monad,
+      MonadThrow,
+      MonadLogger,
+      MonadIO,
+      MonadUnliftIO,
+      MonadTrans
+    )
+
+runTransaction' :: Connection -> Transaction m a -> m a
+runTransaction' conn transaction = runReaderT transaction.unTransaction conn
+
+-- | Catch any Postgres exception that gets thrown,
+-- print the query that was run and the query parameters,
+-- then rethrow inside an 'Error'.
+handlePGException ::
+  forall a params m.
+  (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) =>
+  Text ->
+  Query ->
+  -- | Depending on whether we used `format` or `formatMany`.
+  Either params [params] ->
+  IO a ->
+  Transaction m a
+handlePGException queryType query' params io = do
+  withRunInIO $ \unliftIO ->
+    io
+      `catches` [ Handler $ unliftIO . logQueryException @SqlError,
+                  Handler $ unliftIO . logQueryException @QueryError,
+                  Handler $ unliftIO . logQueryException @ResultError,
+                  Handler $ unliftIO . logFormatException
+                ]
+  where
+    -- TODO: use throwInternalError here (after pulling it into the MonadPostgres class)
+    throwAsError = unwrapIOError . Left . newError
+    throwErr err = liftIO $ throwAsError $ prettyErrorTree $ nestedMultiError "A Postgres query failed" err
+    logQueryException :: Exception e => e -> Transaction m a
+    logQueryException exc = do
+      formattedQuery <- case params of
+        Left one -> pgFormatQuery' query' one
+        Right many -> pgFormatQueryMany' query' many
+      throwErr
+        ( singleError [fmt|Query Type: {queryType}|]
+            :| [ nestedError "Exception" (exc & showPretty & newError & singleError),
+                 nestedError "Query" (formattedQuery & newError & singleError)
+               ]
+        )
+    logFormatException :: FormatError -> Transaction m a
+    logFormatException fe = throwErr (fe & showPretty & newError & singleError & singleton)
+
+pgExecute :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> params -> Transaction m (Label "numberOfRowsAffected" Natural)
+pgExecute qry params = do
+  conn <- Transaction ask
+  PG.execute conn qry params
+    & handlePGException "execute" qry (Left params)
+    >>= toNumberOfRowsAffected "pgExecute"
+
+pgExecute_ :: (MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> Transaction m (Label "numberOfRowsAffected" Natural)
+pgExecute_ qry = do
+  conn <- Transaction ask
+  PG.execute_ conn qry
+    & handlePGException "execute_" qry (Left ())
+    >>= toNumberOfRowsAffected "pgExecute_"
+
+pgExecuteMany :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> [params] -> Transaction m (Label "numberOfRowsAffected" Natural)
+pgExecuteMany qry params =
+  do
+    conn <- Transaction ask
+    PG.executeMany conn qry params
+      & handlePGException "executeMany" qry (Right params)
+      >>= toNumberOfRowsAffected "pgExecuteMany"
+
+toNumberOfRowsAffected :: MonadIO m => Text -> Int64 -> m (Label "numberOfRowsAffected" Natural)
+toNumberOfRowsAffected functionName i64 =
+  i64
+    & intToNatural
+    & annotate [fmt|{functionName}: postgres returned a negative number of rows affected: {i64}|]
+    -- we throw this directly in IO here, because we don’t want to e.g. have to propagate MonadThrow through user code (it’s an assertion)
+    & unwrapIOError
+    & liftIO
+    <&> label @"numberOfRowsAffected"
+
+pgExecuteManyReturningWith :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> [params] -> Decoder r -> Transaction m [r]
+pgExecuteManyReturningWith qry params (Decoder fromRow) =
+  do
+    conn <- Transaction ask
+    PG.returningWith fromRow conn qry params
+      & handlePGException "executeManyReturning" qry (Right params)
+
+pgFold ::
+  (FromRow row, ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) =>
+  Query ->
+  params ->
+  a ->
+  (a -> row -> Transaction m a) ->
+  Transaction m a
+pgFold qry params accumulator f = do
+  conn <- Transaction ask
+
+  withRunInIO
+    ( \runInIO ->
+        do
+          PG.fold
+            conn
+            qry
+            params
+            accumulator
+            (\acc row -> runInIO $ f acc row)
+            & handlePGException "fold" qry (Left params)
+            & runInIO
+    )
+
+pgFormatQuery :: (ToRow params, MonadIO m) => Query -> params -> Transaction m ByteString
+pgFormatQuery qry params = Transaction $ do
+  conn <- ask
+  liftIO $ PG.formatQuery conn qry params
+
+pgFormatQueryMany :: (MonadIO m, ToRow params) => Query -> [params] -> Transaction m ByteString
+pgFormatQueryMany qry params = Transaction $ do
+  conn <- ask
+  liftIO $ PG.formatMany conn qry params
+
+pgQueryWith :: (ToRow params, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> params -> Decoder r -> Transaction m [r]
+pgQueryWith qry params (Decoder fromRow) = do
+  conn <- Transaction ask
+  PG.queryWith fromRow conn qry params
+    & handlePGException "query" qry (Left params)
+
+pgQueryWith_ :: (MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> Decoder r -> Transaction m [r]
+pgQueryWith_ qry (Decoder fromRow) = do
+  conn <- Transaction ask
+  liftIO (PG.queryWith_ fromRow conn qry)
+    & handlePGException "query" qry (Left ())
+
+pgQuery :: (ToRow params, FromRow r, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> params -> Transaction m [r]
+pgQuery qry params = do
+  conn <- Transaction ask
+  PG.query conn qry params
+    & handlePGException "query" qry (Left params)
+
+pgQuery_ :: (FromRow r, MonadUnliftIO m, MonadLogger m, MonadTools m) => Query -> Transaction m [r]
+pgQuery_ qry = do
+  conn <- Transaction ask
+  PG.query_ conn qry
+    & handlePGException "query_" qry (Left ())
+
+data SingleRowError = SingleRowError
+  { -- | How many columns were actually returned by the query
+    numberOfRowsReturned :: Int
+  }
+  deriving stock (Show)
+
+instance Exception SingleRowError where
+  displayException (SingleRowError {..}) = [fmt|Single row expected from SQL query result, {numberOfRowsReturned} rows were returned instead."|]
+
+pgFormatQueryNoParams' :: (MonadIO m, MonadLogger m, MonadTools m) => Query -> Transaction m Text
+pgFormatQueryNoParams' q =
+  lift $ pgFormatQueryByteString q.fromQuery
+
+pgFormatQuery' :: (MonadIO m, ToRow params, MonadLogger m, MonadTools m) => Query -> params -> Transaction m Text
+pgFormatQuery' q p =
+  pgFormatQuery q p
+    >>= lift . pgFormatQueryByteString
+
+pgFormatQueryMany' :: (MonadIO m, ToRow params, MonadLogger m, MonadTools m) => Query -> [params] -> Transaction m Text
+pgFormatQueryMany' q p =
+  pgFormatQueryMany q p
+    >>= lift . pgFormatQueryByteString
+
+-- | Tools required at runtime
+data Tools = Tools
+  { pgFormat :: Tool
+  }
+  deriving stock (Show)
+
+class Monad m => MonadTools m where
+  getTools :: m Tools
+
+initMonadTools :: Label "envvar" Text -> IO Tools
+initMonadTools var =
+  Tool.readTools (label @"toolsEnvVar" var.envvar) toolParser
+  where
+    toolParser = do
+      pgFormat <- readTool "pg_format"
+      pure $ Tools {..}
+
+pgFormatQueryByteString :: (MonadIO m, MonadLogger m, MonadTools m) => ByteString -> m Text
+pgFormatQueryByteString queryBytes = do
+  do
+    tools <- getTools
+    (exitCode, stdout, stderr) <-
+      Process.readProcessWithExitCode
+        tools.pgFormat.toolPath
+        ["-"]
+        (queryBytes & bytesToTextUtf8Lenient & textToString)
+    case exitCode of
+      ExitSuccess -> pure (stdout & stringToText)
+      ExitFailure status -> do
+        logWarn [fmt|pg_format failed with status {status} while formatting the query, using original query string. Is there a syntax error?|]
+        logDebug
+          ( prettyErrorTree
+              ( nestedMultiError
+                  "pg_format output"
+                  ( nestedError "stdout" (singleError (stdout & stringToText & newError))
+                      :| [(nestedError "stderr" (singleError (stderr & stringToText & newError)))]
+                  )
+              )
+          )
+        logDebug [fmt|pg_format stdout: stderr|]
+        pure (queryBytes & bytesToTextUtf8Lenient)
+
+instance (ToField t1) => ToRow (Label l1 t1) where
+  toRow t2 = toRow $ PG.Only $ getField @l1 t2
+
+instance (ToField t1, ToField t2) => ToRow (T2 l1 t1 l2 t2) where
+  toRow t2 = toRow (getField @l1 t2, getField @l2 t2)
+
+instance (ToField t1, ToField t2, ToField t3) => ToRow (T3 l1 t1 l2 t2 l3 t3) where
+  toRow t3 = toRow (getField @l1 t3, getField @l2 t3, getField @l3 t3)