about summary refs log tree commit diff
path: root/users/Profpatsch/my-prelude
diff options
context:
space:
mode:
Diffstat (limited to 'users/Profpatsch/my-prelude')
-rw-r--r--users/Profpatsch/my-prelude/default.nix1
-rw-r--r--users/Profpatsch/my-prelude/my-prelude.cabal5
-rw-r--r--users/Profpatsch/my-prelude/src/MyPrelude.hs288
-rw-r--r--users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs313
4 files changed, 486 insertions, 121 deletions
diff --git a/users/Profpatsch/my-prelude/default.nix b/users/Profpatsch/my-prelude/default.nix
index 7af3a899f30a..1f68cfd16e75 100644
--- a/users/Profpatsch/my-prelude/default.nix
+++ b/users/Profpatsch/my-prelude/default.nix
@@ -28,6 +28,7 @@ pkgs.haskellPackages.mkDerivation {
     pkgs.haskellPackages.pa-pretty
     pkgs.haskellPackages.pa-field-parser
     pkgs.haskellPackages.aeson-better-errors
+    pkgs.haskellPackages.foldl
     pkgs.haskellPackages.resource-pool
     pkgs.haskellPackages.error
     pkgs.haskellPackages.hs-opentelemetry-api
diff --git a/users/Profpatsch/my-prelude/my-prelude.cabal b/users/Profpatsch/my-prelude/my-prelude.cabal
index f9b0e11831cf..49746cd432c7 100644
--- a/users/Profpatsch/my-prelude/my-prelude.cabal
+++ b/users/Profpatsch/my-prelude/my-prelude.cabal
@@ -48,6 +48,8 @@ common common-options
     -- to enable the `type` keyword in import lists (ormolu uses this automatically)
     ExplicitNamespaces
 
+    -- allows defining pattern synonyms, but also the `import Foo (pattern FooPattern)` import syntax
+    PatternSynonyms
   default-language: GHC2021
 
 
@@ -83,6 +85,7 @@ library
      , aeson-better-errors
      , bytestring
      , containers
+     , foldl
      , unordered-containers
      , resource-pool
      , resourcet
@@ -101,9 +104,11 @@ library
      , PyF
      , semigroupoids
      , selective
+     , template-haskell
      , text
      , these
      , unix
      , unliftio
      , validation-selective
      , vector
+     , ghc-boot
diff --git a/users/Profpatsch/my-prelude/src/MyPrelude.hs b/users/Profpatsch/my-prelude/src/MyPrelude.hs
index 7857ace61fdb..cd246d172881 100644
--- a/users/Profpatsch/my-prelude/src/MyPrelude.hs
+++ b/users/Profpatsch/my-prelude/src/MyPrelude.hs
@@ -1,11 +1,7 @@
 {-# LANGUAGE ImplicitParams #-}
 {-# LANGUAGE LambdaCase #-}
 {-# LANGUAGE MagicHash #-}
-{-# LANGUAGE OverloadedStrings #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE RankNTypes #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# OPTIONS_GHC -fexpose-all-unfoldings #-}
+{-# LANGUAGE ViewPatterns #-}
 
 module MyPrelude
   ( -- * Text conversions
@@ -15,6 +11,7 @@ module MyPrelude
     fmt,
     textToString,
     stringToText,
+    stringToBytesUtf8,
     showToText,
     textToBytesUtf8,
     textToBytesUtf8Lazy,
@@ -42,6 +39,7 @@ module MyPrelude
     HasField,
 
     -- * Control flow
+    doAs,
     (&),
     (<&>),
     (<|>),
@@ -91,6 +89,9 @@ module MyPrelude
     failure,
     successes,
     failures,
+    traverseValidate,
+    traverseValidateM,
+    traverseValidateM_,
     eitherToValidation,
     eitherToListValidation,
     validationToEither,
@@ -100,15 +101,28 @@ module MyPrelude
     validationToThese,
     thenThese,
     thenValidate,
+    thenValidateM,
     NonEmpty ((:|)),
+    pattern IsEmpty,
+    pattern IsNonEmpty,
     singleton,
     nonEmpty,
     nonEmptyDef,
+    overNonEmpty,
+    zipNonEmpty,
+    zipWithNonEmpty,
+    zip3NonEmpty,
+    zipWith3NonEmpty,
+    zip4NonEmpty,
     toList,
-    toNonEmptyDefault,
+    lengthNatural,
     maximum1,
     minimum1,
+    maximumBy1,
+    minimumBy1,
+    Vector,
     Generic,
+    Lift,
     Semigroup,
     sconcat,
     Monoid,
@@ -120,6 +134,7 @@ module MyPrelude
     Identity (Identity, runIdentity),
     Natural,
     intToNatural,
+    Scientific,
     Contravariant,
     contramap,
     (>$<),
@@ -132,10 +147,16 @@ module MyPrelude
     Category,
     (>>>),
     (&>>),
+    Any,
 
     -- * Enum definition
     inverseFunction,
     inverseMap,
+    enumerateAll,
+
+    -- * Map helpers
+    mapFromListOn,
+    mapFromListOnMerge,
 
     -- * Error handling
     HasCallStack,
@@ -145,6 +166,7 @@ where
 
 import Control.Applicative ((<|>))
 import Control.Category (Category, (>>>))
+import Control.Foldl.NonEmpty qualified as Foldl1
 import Control.Monad (guard, join, unless, when)
 import Control.Monad.Catch (MonadThrow (throwM))
 import Control.Monad.Except
@@ -164,13 +186,15 @@ import Data.Char qualified
 import Data.Coerce (Coercible, coerce)
 import Data.Data (Proxy (Proxy))
 import Data.Error
-import Data.Foldable (Foldable (foldMap', toList), fold, foldl', for_, traverse_)
+import Data.Foldable (Foldable (foldMap', toList), fold, foldl', for_, sequenceA_, traverse_)
 import Data.Foldable qualified as Foldable
 import Data.Function ((&))
 import Data.Functor ((<&>))
 import Data.Functor.Contravariant (Contravariant (contramap), (>$<))
 import Data.Functor.Identity (Identity (runIdentity))
+import Data.List (zip4)
 import Data.List.NonEmpty (NonEmpty ((:|)), nonEmpty)
+import Data.List.NonEmpty qualified as NonEmpty
 import Data.Map.Strict
   ( Map,
   )
@@ -178,7 +202,8 @@ import Data.Map.Strict qualified as Map
 import Data.Maybe (fromMaybe, mapMaybe)
 import Data.Maybe qualified as Maybe
 import Data.Profunctor (Profunctor, dimap, lmap, rmap)
-import Data.Semigroup (Max (Max, getMax), Min (Min, getMin), sconcat)
+import Data.Scientific (Scientific)
+import Data.Semigroup (sconcat)
 import Data.Semigroup.Foldable (Foldable1 (fold1), foldMap1)
 import Data.Semigroup.Traversable (Traversable1)
 import Data.Semigroupoid (Semigroupoid (o))
@@ -192,14 +217,17 @@ import Data.Text.Lazy qualified
 import Data.Text.Lazy.Encoding qualified
 import Data.These (These (That, These, This))
 import Data.Traversable (for)
+import Data.Vector (Vector)
 import Data.Void (Void, absurd)
 import Data.Word (Word8)
 import GHC.Exception (errorCallWithCallStackException)
-import GHC.Exts (RuntimeRep, TYPE, raise#)
+import GHC.Exts (Any, RuntimeRep, TYPE, raise#)
 import GHC.Generics (Generic)
 import GHC.Natural (Natural)
 import GHC.Records (HasField)
 import GHC.Stack (HasCallStack)
+import GHC.Utils.Encoding qualified as GHC
+import Language.Haskell.TH.Syntax (Lift)
 import PyF (fmt)
 import System.Exit qualified
 import System.IO qualified
@@ -212,6 +240,21 @@ import Validation
     validationToEither,
   )
 
+-- | Mark a `do`-block with the type of the Monad/Applicativ it uses.
+-- Only intended for reading ease and making code easier to understand,
+-- especially do-blocks that use unconventional monads (like Maybe or List).
+--
+-- Example:
+--
+-- @
+-- doAs @Maybe $ do
+--  a <- Just 'a'
+--  b <- Just 'b'
+--  pure (a, b)
+-- @
+doAs :: forall m a. m a -> m a
+doAs = id
+
 -- | Forward-applying 'contramap', like '&'/'$' and '<&>'/'<$>' but for '>$<'.
 (>&<) :: (Contravariant f) => f b -> (a -> b) -> f a
 (>&<) = flip contramap
@@ -222,10 +265,10 @@ infixl 5 >&<
 --
 -- Specialized examples:
 --
--- @@
+-- @
 -- for functions : (a -> b) -> (b -> c) -> (a -> c)
 -- for Folds: Fold a b -> Fold b c -> Fold a c
--- @@
+-- @
 (&>>) :: (Semigroupoid s) => s a b -> s b c -> s a c
 (&>>) = flip Data.Semigroupoid.o
 
@@ -266,26 +309,51 @@ bytesToTextUtf8LenientLazy :: Data.ByteString.Lazy.ByteString -> Data.Text.Lazy.
 bytesToTextUtf8LenientLazy =
   Data.Text.Lazy.Encoding.decodeUtf8With Data.Text.Encoding.Error.lenientDecode
 
--- | Make a lazy text strict
+-- | Make a lazy 'Text' strict.
 toStrict :: Data.Text.Lazy.Text -> Text
 toStrict = Data.Text.Lazy.toStrict
 
--- | Make a strict text lazy
+-- | Make a strict 'Text' lazy.
 toLazy :: Text -> Data.Text.Lazy.Text
 toLazy = Data.Text.Lazy.fromStrict
 
+-- | Make a lazy 'ByteString' strict.
 toStrictBytes :: Data.ByteString.Lazy.ByteString -> ByteString
 toStrictBytes = Data.ByteString.Lazy.toStrict
 
+-- | Make a strict 'ByteString' lazy.
 toLazyBytes :: ByteString -> Data.ByteString.Lazy.ByteString
 toLazyBytes = Data.ByteString.Lazy.fromStrict
 
+-- | Convert a (performant) 'Text' into an (imperformant) list-of-char 'String'.
+--
+-- Some libraries (like @time@ or @network-uri@) still use the `String` as their interface. We only want to convert to string at the edges, otherwise use 'Text'.
+--
+-- ATTN: Don’t use `String` in code if you can avoid it, prefer `Text` instead.
 textToString :: Text -> String
 textToString = Data.Text.unpack
 
+-- | Convert an (imperformant) list-of-char 'String' into a (performant) 'Text' .
+--
+-- Some libraries (like @time@ or @network-uri@) still use the `String` as their interface. We want to convert 'String' to 'Text' as soon as possible and only use 'Text' in our code.
+--
+-- ATTN: Don’t use `String` in code if you can avoid it, prefer `Text` instead.
 stringToText :: String -> Text
 stringToText = Data.Text.pack
 
+-- | Encode a String to an UTF-8 encoded Bytestring
+--
+-- ATTN: Don’t use `String` in code if you can avoid it, prefer `Text` instead.
+stringToBytesUtf8 :: String -> ByteString
+stringToBytesUtf8 = GHC.utf8EncodeString
+
+-- | Like `show`, but generate a 'Text'
+--
+-- ATTN: This goes via `String` and thus is fairly inefficient.
+-- We should add a good display library at one point.
+--
+-- ATTN: unlike `show`, this forces the whole @'a
+-- so only use if you want to display the whole thing.
 showToText :: (Show a) => a -> Text
 showToText = stringToText . show
 
@@ -299,8 +367,20 @@ showToText = stringToText . show
 -- >>> charToWordUnsafe ','
 -- 44
 charToWordUnsafe :: Char -> Word8
-charToWordUnsafe = fromIntegral . Data.Char.ord
 {-# INLINE charToWordUnsafe #-}
+charToWordUnsafe = fromIntegral . Data.Char.ord
+
+pattern IsEmpty :: [a]
+pattern IsEmpty <- (null -> True)
+  where
+    IsEmpty = []
+
+pattern IsNonEmpty :: NonEmpty a -> [a]
+pattern IsNonEmpty n <- (nonEmpty -> Just n)
+  where
+    IsNonEmpty n = toList n
+
+{-# COMPLETE IsEmpty, IsNonEmpty #-}
 
 -- | Single element in a (non-empty) list.
 singleton :: a -> NonEmpty a
@@ -313,19 +393,69 @@ nonEmptyDef def xs =
     Nothing -> def :| []
     Just ne -> ne
 
--- | Construct a non-empty list, given a default value if the ist list was empty.
-toNonEmptyDefault :: a -> [a] -> NonEmpty a
-toNonEmptyDefault def xs = case xs of
-  [] -> def :| []
-  (x : xs') -> x :| xs'
+-- | If the list is not empty, run the given function with a NonEmpty list, otherwise just return []
+overNonEmpty :: (Applicative f) => (NonEmpty a -> f [b]) -> [a] -> f [b]
+overNonEmpty f xs = case xs of
+  IsEmpty -> pure []
+  IsNonEmpty xs' -> f xs'
+
+-- | Zip two non-empty lists.
+zipNonEmpty :: NonEmpty a -> NonEmpty b -> NonEmpty (a, b)
+{-# INLINE zipNonEmpty #-}
+zipNonEmpty ~(a :| as) ~(b :| bs) = (a, b) :| zip as bs
+
+-- | Zip two non-empty lists, combining them with the given function
+zipWithNonEmpty :: (a -> b -> c) -> NonEmpty a -> NonEmpty b -> NonEmpty c
+{-# INLINE zipWithNonEmpty #-}
+zipWithNonEmpty = NonEmpty.zipWith
+
+-- | Zip three non-empty lists.
+zip3NonEmpty :: NonEmpty a -> NonEmpty b -> NonEmpty c -> NonEmpty (a, b, c)
+{-# INLINE zip3NonEmpty #-}
+zip3NonEmpty ~(a :| as) ~(b :| bs) ~(c :| cs) = (a, b, c) :| zip3 as bs cs
 
--- | @O(n)@. Get the maximum element from a non-empty structure.
+-- | Zip three non-empty lists, combining them with the given function
+zipWith3NonEmpty :: (a -> b -> c -> d) -> NonEmpty a -> NonEmpty b -> NonEmpty c -> NonEmpty d
+{-# INLINE zipWith3NonEmpty #-}
+zipWith3NonEmpty f ~(x :| xs) ~(y :| ys) ~(z :| zs) = f x y z :| zipWith3 f xs ys zs
+
+-- | Zip four non-empty lists
+zip4NonEmpty :: NonEmpty a -> NonEmpty b -> NonEmpty c -> NonEmpty d -> NonEmpty (a, b, c, d)
+{-# INLINE zip4NonEmpty #-}
+zip4NonEmpty ~(a :| as) ~(b :| bs) ~(c :| cs) ~(d :| ds) = (a, b, c, d) :| zip4 as bs cs ds
+
+-- | We don’t want to use Foldable’s `length`, because it is too polymorphic and can lead to bugs.
+-- Only list-y things should have a length.
+class (Foldable f) => Lengthy f
+
+instance Lengthy []
+
+instance Lengthy NonEmpty
+
+instance Lengthy Vector
+
+lengthNatural :: (Lengthy f) => f a -> Natural
+lengthNatural xs =
+  xs
+    & Foldable.length
+    -- length can never be negative or something went really, really wrong
+    & fromIntegral @Int @Natural
+
+-- | @O(n)@. Get the maximum element from a non-empty structure (strict).
 maximum1 :: (Foldable1 f, Ord a) => f a -> a
-maximum1 xs = xs & foldMap1 Max & getMax
+maximum1 = Foldl1.fold1 Foldl1.maximum
 
--- | @O(n)@. Get the minimum element from a non-empty structure.
+-- | @O(n)@. Get the maximum element from a non-empty structure, using the given comparator (strict).
+maximumBy1 :: (Foldable1 f) => (a -> a -> Ordering) -> f a -> a
+maximumBy1 f = Foldl1.fold1 (Foldl1.maximumBy f)
+
+-- | @O(n)@. Get the minimum element from a non-empty structure (strict).
 minimum1 :: (Foldable1 f, Ord a) => f a -> a
-minimum1 xs = xs & foldMap1 Min & getMin
+minimum1 = Foldl1.fold1 Foldl1.minimum
+
+-- | @O(n)@. Get the minimum element from a non-empty structure, using the given comparator (strict).
+minimumBy1 :: (Foldable1 f) => (a -> a -> Ordering) -> f a -> a
+minimumBy1 f = Foldl1.fold1 (Foldl1.minimumBy f)
 
 -- | Annotate a 'Maybe' with an error message and turn it into an 'Either'.
 annotate :: err -> Maybe a -> Either err a
@@ -355,8 +485,48 @@ findMaybe mPred list =
         Just a -> mPred a
         Nothing -> Nothing
 
+-- | 'traverse' with a function returning 'Either' and collect all errors that happen, if they happen.
+--
+-- Does not shortcut on error, so will always traverse the whole list/'Traversable' structure.
+--
+-- This is a useful error handling function in many circumstances,
+-- because it won’t only return the first error that happens, but rather all of them.
+traverseValidate :: forall t a err b. (Traversable t) => (a -> Either err b) -> t a -> Either (NonEmpty err) (t b)
+traverseValidate f as =
+  as
+    & traverse @t @(Validation _) (eitherToListValidation . f)
+    & validationToEither
+
+-- | 'traverse' with a function returning 'm Either' and collect all errors that happen, if they happen.
+--
+-- Does not shortcut on error, so will always traverse the whole list/'Traversable' structure.
+--
+-- This is a useful error handling function in many circumstances,
+-- because it won’t only return the first error that happens, but rather all of them.
+traverseValidateM :: forall t m a err b. (Traversable t, Applicative m) => (a -> m (Either err b)) -> t a -> m (Either (NonEmpty err) (t b))
+traverseValidateM f as =
+  as
+    & traverse @t @m (\a -> a & f <&> eitherToListValidation)
+    <&> sequenceA @t @(Validation _)
+    <&> validationToEither
+
+-- | 'traverse_' with a function returning 'm Either' and collect all errors that happen, if they happen.
+--
+-- Does not shortcut on error, so will always traverse the whole list/'Traversable' structure.
+--
+-- This is a useful error handling function in many circumstances,
+-- because it won’t only return the first error that happens, but rather all of them.
+traverseValidateM_ :: forall t m a err. (Traversable t, Applicative m) => (a -> m (Either err ())) -> t a -> m (Either (NonEmpty err) ())
+traverseValidateM_ f as =
+  as
+    & traverse @t @m (\a -> a & f <&> eitherToListValidation)
+    <&> sequenceA_ @t @(Validation _)
+    <&> validationToEither
+
 -- | Like 'eitherToValidation', but puts the Error side into a NonEmpty list
 -- to make it combine with other validations.
+--
+-- See also 'validateEithers', if you have a list of Either and want to collect all errors.
 eitherToListValidation :: Either a c -> Validation (NonEmpty a) c
 eitherToListValidation = first singleton . eitherToValidation
 
@@ -388,15 +558,26 @@ thenThese f x = do
   th <- x
   join <$> traverse f th
 
--- | Nested validating bind-like combinator inside some other @m@.
+-- | Nested validating bind-like combinator.
 --
 -- Use if you want to collect errors, and want to chain multiple functions returning 'Validation'.
 thenValidate ::
+  (a -> Validation err b) ->
+  Validation err a ->
+  Validation err b
+thenValidate f = \case
+  Success a -> f a
+  Failure err -> Failure err
+
+-- | Nested validating bind-like combinator inside some other @m@.
+--
+-- Use if you want to collect errors, and want to chain multiple functions returning 'Validation'.
+thenValidateM ::
   (Monad m) =>
   (a -> m (Validation err b)) ->
   m (Validation err a) ->
   m (Validation err b)
-thenValidate f x =
+thenValidateM f x =
   eitherToValidation <$> do
     x' <- validationToEither <$> x
     case x' of
@@ -429,23 +610,23 @@ exitWithMessage msg = do
 --
 -- … since @(Semigroup err => Validation err a)@ is a @Semigroup@/@Monoid@ itself.
 traverseFold :: (Applicative ap, Traversable t, Monoid m) => (a -> ap m) -> t a -> ap m
+{-# INLINE traverseFold #-}
 traverseFold f xs =
   -- note: could be weakened to (Foldable t) via `getAp . foldMap (Ap . f)`
   fold <$> traverse f xs
-{-# INLINE traverseFold #-}
 
 -- | Like 'traverseFold', but fold over a semigroup instead of a Monoid, by providing a starting element.
 traverseFoldDefault :: (Applicative ap, Traversable t, Semigroup m) => m -> (a -> ap m) -> t a -> ap m
+{-# INLINE traverseFoldDefault #-}
 traverseFoldDefault def f xs = foldDef def <$> traverse f xs
   where
     foldDef = foldr (<>)
-{-# INLINE traverseFoldDefault #-}
 
 -- | Same as 'traverseFold', but with a 'Semigroup' and 'Traversable1' restriction.
 traverseFold1 :: (Applicative ap, Traversable1 t, Semigroup s) => (a -> ap s) -> t a -> ap s
+{-# INLINE traverseFold1 #-}
 -- note: cannot be weakened to (Foldable1 t) because there is no `Ap` for Semigroup (No `Apply` typeclass)
 traverseFold1 f xs = fold1 <$> traverse f xs
-{-# INLINE traverseFold1 #-}
 
 -- | Use this in places where the code is still to be implemented.
 --
@@ -527,18 +708,31 @@ inverseFunction f k = Map.lookup k $ inverseMap f
 -- it returns a mapping from all possible outputs to their possible inputs.
 --
 -- This has the same restrictions of 'inverseFunction'.
-inverseMap ::
-  forall a k.
-  (Bounded a, Enum a, Ord k) =>
-  (a -> k) ->
-  Map k a
-inverseMap f =
-  universe
-    <&> (\a -> (f a, a))
-    & Map.fromList
-  where
-    universe :: [a]
-    universe = [minBound .. maxBound]
+inverseMap :: forall a k. (Bounded a, Enum a, Ord k) => (a -> k) -> Map k a
+inverseMap f = enumerateAll <&> (\a -> (f a, a)) & Map.fromList
+
+-- | All possible values in this enum.
+enumerateAll :: (Enum a, Bounded a) => [a]
+enumerateAll = [minBound .. maxBound]
+
+-- | Create a 'Map' from a list of values, extracting the map key from each value. Like 'Map.fromList'.
+--
+-- Attention: if the key is not unique, the earliest value with the key will be in the map.
+mapFromListOn :: (Ord key) => (a -> key) -> [a] -> Map key a
+mapFromListOn f xs = xs <&> (\x -> (f x, x)) & Map.fromList
+
+-- | Create a 'Map' from a list of values, merging multiple values at the same key with '<>' (left-to-right)
+--
+-- `f` has to extract the key and value. Value must be mergable.
+--
+-- Attention: if the key is not unique, the earliest value with the key will be in the map.
+mapFromListOnMerge :: (Ord key, Semigroup s) => (a -> (key, s)) -> [a] -> Map key s
+mapFromListOnMerge f xs =
+  xs
+    <&> (\x -> f x)
+    & Map.fromListWith
+      -- we have to flip (`<>`) because `Map.fromListWith` merges its values “the other way around”
+      (flip (<>))
 
 -- | If the predicate is true, return the @m@, else 'mempty'.
 --
@@ -570,12 +764,18 @@ ifTrue pred' m = if pred' then m else mempty
 -- >>> import Data.Monoid (Sum(..))
 --
 -- >>> :{ mconcat [
--- unknown command '{'
+--   ifExists (Just [1]),
+--   [2, 3, 4],
+--   ifExists Nothing,
+-- ]
+-- :}
+-- [1,2,3,4]
 --
 -- Or any other Monoid:
 --
--- >>> mconcat [ Sum 1, ifExists Sum (Just 2), Sum 3 ]
+-- >>> mconcat [ Sum 1, ifExists (Just (Sum 2)), Sum 3 ]
+
 -- Sum {getSum = 6}
 
-ifExists :: (Monoid m) => (a -> m) -> Maybe a -> m
-ifExists = foldMap
+ifExists :: (Monoid m) => Maybe m -> m
+ifExists = fold
diff --git a/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs b/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs
index 78e3897ef5f3..bd8ddd12f775 100644
--- a/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs
+++ b/users/Profpatsch/my-prelude/src/Postgres/MonadPostgres.hs
@@ -1,14 +1,15 @@
+{-# LANGUAGE AllowAmbiguousTypes #-}
 {-# LANGUAGE DeriveAnyClass #-}
 {-# LANGUAGE QuasiQuotes #-}
-{-# LANGUAGE TemplateHaskell #-}
 {-# OPTIONS_GHC -Wno-orphans #-}
 
 module Postgres.MonadPostgres where
 
 import AtLeast (AtLeast)
 import Control.Exception
+import Control.Foldl qualified as Fold
 import Control.Monad.Except
-import Control.Monad.Logger (MonadLogger, logDebug, logWarn)
+import Control.Monad.Logger.CallStack (MonadLogger, logDebug, logWarn)
 import Control.Monad.Reader (MonadReader (ask), ReaderT (..))
 import Control.Monad.Trans.Resource
 import Data.Aeson (FromJSON)
@@ -28,7 +29,7 @@ import Database.PostgreSQL.Simple.FromRow qualified as PG
 import Database.PostgreSQL.Simple.ToField (ToField)
 import Database.PostgreSQL.Simple.ToRow (ToRow (toRow))
 import Database.PostgreSQL.Simple.Types (Query (..))
-import GHC.Records (HasField (..))
+import GHC.Records (getField)
 import Label
 import OpenTelemetry.Trace.Core qualified as Otel hiding (inSpan, inSpan')
 import OpenTelemetry.Trace.Monad qualified as Otel
@@ -42,7 +43,7 @@ import Tool
 import UnliftIO (MonadUnliftIO (withRunInIO))
 import UnliftIO.Process qualified as Process
 import UnliftIO.Resource qualified as Resource
-import Prelude hiding (span)
+import Prelude hiding (init, span)
 
 -- | Postgres queries/commands that can be executed within a running transaction.
 --
@@ -52,28 +53,46 @@ class (Monad m) => MonadPostgres (m :: Type -> Type) where
   -- | Execute an INSERT, UPDATE, or other SQL query that is not expected to return results.
 
   -- Returns the number of rows affected.
-  execute :: (ToRow params, Typeable params) => Query -> params -> Transaction m (Label "numberOfRowsAffected" Natural)
-
-  -- | Execute an INSERT, UPDATE, or other SQL query that is not expected to return results. Does not take parameters.
-
-  -- Returns the number of rows affected.
-  execute_ :: Query -> Transaction m (Label "numberOfRowsAffected" Natural)
+  execute ::
+    (ToRow params, Typeable params) =>
+    Query ->
+    params ->
+    Transaction m (Label "numberOfRowsAffected" Natural)
 
   -- | Execute a multi-row INSERT, UPDATE, or other SQL query that is not expected to return results.
   --
-  -- Returns the number of rows affected. If the list of parameters is empty, this function will simply return 0 without issuing the query to the backend. If this is not desired, consider using the 'PG.Values' constructor instead.
-  executeMany :: (ToRow params, Typeable params) => Query -> [params] -> Transaction m (Label "numberOfRowsAffected" Natural)
-
-  -- | Execute INSERT ... RETURNING, UPDATE ... RETURNING, or other SQL query that accepts multi-row input and is expected to return results. Note that it is possible to write query conn "INSERT ... RETURNING ..." ... in cases where you are only inserting a single row, and do not need functionality analogous to 'executeMany'.
+  -- Returns the number of rows affected. If the list of parameters is empty,
+  -- this function will simply return 0 without issuing the query to the backend.
+  -- If this is not desired, consider using the 'PG.Values' constructor instead.
+  executeMany ::
+    (ToRow params, Typeable params) =>
+    Query ->
+    NonEmpty params ->
+    Transaction m (Label "numberOfRowsAffected" Natural)
+
+  -- | Execute INSERT ... RETURNING, UPDATE ... RETURNING,
+  -- or other SQL query that accepts multi-row input and is expected to return results.
+  -- Note that it is possible to write query conn "INSERT ... RETURNING ..." ...
+  -- in cases where you are only inserting a single row,
+  -- and do not need functionality analogous to 'executeMany'.
   --
   -- If the list of parameters is empty, this function will simply return [] without issuing the query to the backend. If this is not desired, consider using the 'PG.Values' constructor instead.
-  executeManyReturningWith :: (ToRow q) => Query -> [q] -> Decoder r -> Transaction m [r]
+  executeManyReturningWith :: (ToRow q) => Query -> NonEmpty q -> Decoder r -> Transaction m [r]
 
   -- | Run a query, passing parameters and result row parser.
-  queryWith :: (PG.ToRow params, Typeable params, Typeable r) => PG.Query -> params -> Decoder r -> Transaction m [r]
+  queryWith ::
+    (PG.ToRow params, Typeable params, Typeable r) =>
+    PG.Query ->
+    params ->
+    Decoder r ->
+    Transaction m [r]
 
   -- | Run a query without any parameters and result row parser.
-  queryWith_ :: (Typeable r) => PG.Query -> Decoder r -> Transaction m [r]
+  queryWith_ ::
+    (Typeable r) =>
+    PG.Query ->
+    Decoder r ->
+    Transaction m [r]
 
   -- | Run a query, passing parameters, and fold over the resulting rows.
   --
@@ -82,13 +101,15 @@ class (Monad m) => MonadPostgres (m :: Type -> Type) where
   --
   -- When dealing with small results, it may be simpler (and perhaps faster) to use query instead.
   --
-  -- This fold is _not_ strict. The stream consumer is responsible for forcing the evaluation of its result to avoid space leaks.
+  -- This fold is _not_ strict. The stream consumer is responsible
+  -- for forcing the evaluation of its result to avoid space leaks.
   --
   -- If you can, prefer aggregating in the database itself.
-  foldRows ::
-    (FromRow row, ToRow params, Typeable row, Typeable params) =>
+  foldRowsWithAcc ::
+    (ToRow params, Typeable row, Typeable params) =>
     Query ->
     params ->
+    Decoder row ->
     a ->
     (a -> row -> Transaction m a) ->
     Transaction m a
@@ -109,12 +130,23 @@ class (Monad m) => MonadPostgres (m :: Type -> Type) where
   -- Only handlers should run transactions.
   runTransaction :: Transaction m a -> m a
 
--- | Run a query, passing parameters.
-query :: forall m params r. (PG.ToRow params, PG.FromRow r, Typeable params, Typeable r, MonadPostgres m) => PG.Query -> params -> Transaction m [r]
+-- | Run a query, passing parameters. Prefer 'queryWith' if possible.
+query ::
+  forall m params r.
+  (PG.ToRow params, PG.FromRow r, Typeable params, Typeable r, MonadPostgres m) =>
+  PG.Query ->
+  params ->
+  Transaction m [r]
 query qry params = queryWith qry params (Decoder PG.fromRow)
 
--- | Run a query without any parameters.
-query_ :: forall m r. (Typeable r, PG.FromRow r, MonadPostgres m) => PG.Query -> Transaction m [r]
+-- | Run a query without any parameters. Prefer 'queryWith' if possible.
+--
+-- TODO: I think(?) this can always be replaced by passing @()@ to 'query', remove?
+query_ ::
+  forall m r.
+  (Typeable r, PG.FromRow r, MonadPostgres m) =>
+  PG.Query ->
+  Transaction m [r]
 query_ qry = queryWith_ qry (Decoder PG.fromRow)
 
 -- TODO: implement via fold, so that the result doesn’t have to be realized in memory
@@ -153,7 +185,10 @@ querySingleRowMaybe qry params = do
     -- that a database function can error out, should probably handled by the instances.
     more -> throwM $ SingleRowError {numberOfRowsReturned = (List.length more)}
 
-ensureSingleRow :: (MonadThrow m) => [a] -> m a
+ensureSingleRow ::
+  (MonadThrow m) =>
+  [a] ->
+  m a
 ensureSingleRow = \case
   -- TODO: Should we MonadThrow this here? It’s really an implementation detail of MonadPostgres
   -- that a database function can error out, should probably handled by the instances.
@@ -167,6 +202,52 @@ ensureSingleRow = \case
             List.length more
         }
 
+ensureNoneOrSingleRow ::
+  (MonadThrow m) =>
+  [a] ->
+  m (Maybe a)
+ensureNoneOrSingleRow = \case
+  -- TODO: Should we MonadThrow this here? It’s really an implementation detail of MonadPostgres
+  -- that a database function can error out, should probably handled by the instances.
+  [] -> pure Nothing
+  [one] -> pure $ Just one
+  more ->
+    throwM $
+      SingleRowError
+        { numberOfRowsReturned =
+            -- TODO: this is VERY bad, because it requires to parse the full database output, even if there’s 10000000000 elements
+            List.length more
+        }
+
+-- | Run a query, passing parameters, and fold over the resulting rows.
+--
+-- This doesn’t have to realize the full list of results in memory,
+-- rather results are streamed incrementally from the database.
+--
+-- When dealing with small results, it may be simpler (and perhaps faster) to use query instead.
+--
+-- The results are folded strictly by the 'Fold.Fold' that is passed.
+--
+-- If you can, prefer aggregating in the database itself.
+foldRowsWith ::
+  forall row params m b.
+  ( MonadPostgres m,
+    PG.ToRow params,
+    Typeable row,
+    Typeable params
+  ) =>
+  PG.Query ->
+  params ->
+  Decoder row ->
+  Fold.Fold row b ->
+  Transaction m b
+foldRowsWith qry params decoder = Fold.purely f
+  where
+    f :: forall x. (x -> row -> x) -> x -> (x -> b) -> Transaction m b
+    f acc init extract = do
+      x <- foldRowsWithAcc qry params decoder init (\a r -> pure $ acc a r)
+      pure $ extract x
+
 newtype Transaction m a = Transaction {unTransaction :: (ReaderT Connection m a)}
   deriving newtype
     ( Functor,
@@ -180,9 +261,6 @@ newtype Transaction m a = Transaction {unTransaction :: (ReaderT Connection m a)
       Otel.MonadTracer
     )
 
-runTransaction' :: Connection -> Transaction m a -> m a
-runTransaction' conn transaction = runReaderT transaction.unTransaction conn
-
 -- | [Resource Pool](http://hackage.haskell.org/package/resource-pool-0.2.3.2/docs/Data-Pool.html) configuration.
 data PoolingInfo = PoolingInfo
   { -- | Minimal amount of resources that are
@@ -237,17 +315,41 @@ initMonadPostgres logInfoFn connectInfo poolingInfo = do
       IO ()
     destroyPGConnPool p = Pool.destroyAllResources p
 
+-- | Improve a possible error message, by adding some context to it.
+--
+-- The given Exception type is caught, 'show'n and pretty-printed.
+--
+-- In case we get an `IOError`, we display it in a reasonable fashion.
+addErrorInformation ::
+  forall exc a.
+  (Exception exc) =>
+  Text.Text ->
+  IO a ->
+  IO a
+addErrorInformation msg io =
+  io
+    & try @exc
+    <&> first (showPretty >>> newError >>> errorContext msg)
+    & try @IOError
+    <&> first (showToError >>> errorContext "IOError" >>> errorContext msg)
+    <&> join @(Either Error)
+    >>= unwrapIOError
+
 -- | Catch any Postgres exception that gets thrown,
 -- print the query that was run and the query parameters,
 -- then rethrow inside an 'Error'.
 handlePGException ::
   forall a params tools m.
-  (ToRow params, MonadUnliftIO m, MonadLogger m, HasField "pgFormat" tools Tool) =>
+  ( ToRow params,
+    MonadUnliftIO m,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool
+  ) =>
   tools ->
   Text ->
   Query ->
   -- | Depending on whether we used `format` or `formatMany`.
-  Either params [params] ->
+  Either params (NonEmpty params) ->
   IO a ->
   Transaction m a
 handlePGException tools queryType query' params io = do
@@ -289,7 +391,11 @@ withPGTransaction connPool f =
     connPool
     (\conn -> Postgres.withTransaction conn (f conn))
 
-runPGTransactionImpl :: (MonadUnliftIO m) => m (Pool Postgres.Connection) -> Transaction m a -> m a
+runPGTransactionImpl ::
+  (MonadUnliftIO m) =>
+  m (Pool Postgres.Connection) ->
+  Transaction m a ->
+  m a
 {-# INLINE runPGTransactionImpl #-}
 runPGTransactionImpl zoom (Transaction transaction) = do
   pool <- zoom
@@ -337,7 +443,7 @@ executeManyImpl ::
   m tools ->
   m DebugLogDatabaseQueries ->
   Query ->
-  [params] ->
+  NonEmpty params ->
   Transaction m (Label "numberOfRowsAffected" Natural)
 executeManyImpl zoomTools zoomDebugLogDatabaseQueries qry params =
   Otel.inSpan' "Postgres Query (execute)" Otel.defaultSpanArguments $ \span -> do
@@ -345,7 +451,7 @@ executeManyImpl zoomTools zoomDebugLogDatabaseQueries qry params =
     logDatabaseQueries <- lift @Transaction zoomDebugLogDatabaseQueries
     traceQueryIfEnabled tools span logDatabaseQueries qry (HasMultiParams params)
     conn <- Transaction ask
-    PG.executeMany conn qry params
+    PG.executeMany conn qry (params & toList)
       & handlePGException tools "executeMany" qry (Right params)
       >>= toNumberOfRowsAffected "executeManyImpl"
 
@@ -364,7 +470,7 @@ executeManyReturningWithImpl ::
   m tools ->
   m DebugLogDatabaseQueries ->
   Query ->
-  [params] ->
+  NonEmpty params ->
   Decoder r ->
   Transaction m [r]
 {-# INLINE executeManyReturningWithImpl #-}
@@ -374,33 +480,45 @@ executeManyReturningWithImpl zoomTools zoomDebugLogDatabaseQueries qry params (D
     logDatabaseQueries <- lift @Transaction zoomDebugLogDatabaseQueries
     traceQueryIfEnabled tools span logDatabaseQueries qry (HasMultiParams params)
     conn <- Transaction ask
-    PG.returningWith fromRow conn qry params
+    PG.returningWith fromRow conn qry (params & toList)
       & handlePGException tools "executeManyReturning" qry (Right params)
 
-foldRowsImpl ::
-  (FromRow row, ToRow params, MonadUnliftIO m, MonadLogger m, HasField "pgFormat" tools Tool) =>
+foldRowsWithAccImpl ::
+  ( ToRow params,
+    MonadUnliftIO m,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool,
+    Otel.MonadTracer m
+  ) =>
   m tools ->
+  m DebugLogDatabaseQueries ->
   Query ->
   params ->
+  Decoder row ->
   a ->
   (a -> row -> Transaction m a) ->
   Transaction m a
-{-# INLINE foldRowsImpl #-}
-foldRowsImpl zoomTools qry params accumulator f = do
-  conn <- Transaction ask
-  tools <- lift @Transaction zoomTools
-  withRunInIO
-    ( \runInIO ->
-        do
-          PG.fold
-            conn
-            qry
-            params
-            accumulator
-            (\acc row -> runInIO $ f acc row)
-            & handlePGException tools "fold" qry (Left params)
-            & runInIO
-    )
+{-# INLINE foldRowsWithAccImpl #-}
+foldRowsWithAccImpl zoomTools zoomDebugLogDatabaseQueries qry params (Decoder rowParser) accumulator f = do
+  Otel.inSpan' "Postgres Query (foldRowsWithAcc)" Otel.defaultSpanArguments $ \span -> do
+    tools <- lift @Transaction zoomTools
+    logDatabaseQueries <- lift @Transaction zoomDebugLogDatabaseQueries
+    traceQueryIfEnabled tools span logDatabaseQueries qry (HasSingleParam params)
+    conn <- Transaction ask
+    withRunInIO
+      ( \runInIO ->
+          do
+            PG.foldWithOptionsAndParser
+              PG.defaultFoldOptions
+              rowParser
+              conn
+              qry
+              params
+              accumulator
+              (\acc row -> runInIO $ f acc row)
+              & handlePGException tools "fold" qry (Left params)
+              & runInIO
+      )
 
 pgFormatQueryNoParams' ::
   (MonadIO m, MonadLogger m, HasField "pgFormat" tools Tool) =>
@@ -410,18 +528,38 @@ pgFormatQueryNoParams' ::
 pgFormatQueryNoParams' tools q =
   lift $ pgFormatQueryByteString tools q.fromQuery
 
-pgFormatQuery :: (ToRow params, MonadIO m) => Query -> params -> Transaction m ByteString
+pgFormatQuery ::
+  (ToRow params, MonadIO m) =>
+  Query ->
+  params ->
+  Transaction m ByteString
 pgFormatQuery qry params = Transaction $ do
   conn <- ask
   liftIO $ PG.formatQuery conn qry params
 
-pgFormatQueryMany :: (MonadIO m, ToRow params) => Query -> [params] -> Transaction m ByteString
+pgFormatQueryMany ::
+  (MonadIO m, ToRow params) =>
+  Query ->
+  NonEmpty params ->
+  Transaction m ByteString
 pgFormatQueryMany qry params = Transaction $ do
   conn <- ask
-  liftIO $ PG.formatMany conn qry params
+  liftIO $
+    PG.formatMany
+      conn
+      qry
+      ( params
+          -- upstream is partial on empty list, see https://github.com/haskellari/postgresql-simple/issues/129
+          & toList
+      )
 
 queryWithImpl ::
-  (ToRow params, MonadUnliftIO m, MonadLogger m, HasField "pgFormat" tools Tool, Otel.MonadTracer m) =>
+  ( ToRow params,
+    MonadUnliftIO m,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool,
+    Otel.MonadTracer m
+  ) =>
   m tools ->
   m DebugLogDatabaseQueries ->
   Query ->
@@ -438,7 +576,15 @@ queryWithImpl zoomTools zoomDebugLogDatabaseQueries qry params (Decoder fromRow)
     PG.queryWith fromRow conn qry params
       & handlePGException tools "query" qry (Left params)
 
-queryWithImpl_ :: (MonadUnliftIO m, MonadLogger m, HasField "pgFormat" tools Tool) => m tools -> Query -> Decoder r -> Transaction m [r]
+queryWithImpl_ ::
+  ( MonadUnliftIO m,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool
+  ) =>
+  m tools ->
+  Query ->
+  Decoder r ->
+  Transaction m [r]
 {-# INLINE queryWithImpl_ #-}
 queryWithImpl_ zoomTools qry (Decoder fromRow) = do
   tools <- lift @Transaction zoomTools
@@ -446,18 +592,6 @@ queryWithImpl_ zoomTools qry (Decoder fromRow) = do
   liftIO (PG.queryWith_ fromRow conn qry)
     & handlePGException tools "query" qry (Left ())
 
-pgQuery :: (ToRow params, FromRow r, MonadUnliftIO m, MonadLogger m, HasField "pgFormat" tools Tool) => tools -> Query -> params -> Transaction m [r]
-pgQuery tools qry params = do
-  conn <- Transaction ask
-  PG.query conn qry params
-    & handlePGException tools "query" qry (Left params)
-
-pgQuery_ :: (FromRow r, MonadUnliftIO m, MonadLogger m, HasField "pgFormat" tools Tool) => tools -> Query -> Transaction m [r]
-pgQuery_ tools qry = do
-  conn <- Transaction ask
-  PG.query_ conn qry
-    & handlePGException tools "query_" qry (Left ())
-
 data SingleRowError = SingleRowError
   { -- | How many columns were actually returned by the query
     numberOfRowsReturned :: Int
@@ -467,12 +601,30 @@ data SingleRowError = SingleRowError
 instance Exception SingleRowError where
   displayException (SingleRowError {..}) = [fmt|Single row expected from SQL query result, {numberOfRowsReturned} rows were returned instead."|]
 
-pgFormatQuery' :: (MonadIO m, ToRow params, MonadLogger m, HasField "pgFormat" tools Tool) => tools -> Query -> params -> Transaction m Text
+pgFormatQuery' ::
+  ( MonadIO m,
+    ToRow params,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool
+  ) =>
+  tools ->
+  Query ->
+  params ->
+  Transaction m Text
 pgFormatQuery' tools q p =
   pgFormatQuery q p
     >>= lift . pgFormatQueryByteString tools
 
-pgFormatQueryMany' :: (MonadIO m, ToRow params, MonadLogger m, HasField "pgFormat" tools Tool) => tools -> Query -> [params] -> Transaction m Text
+pgFormatQueryMany' ::
+  ( MonadIO m,
+    ToRow params,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool
+  ) =>
+  tools ->
+  Query ->
+  NonEmpty params ->
+  Transaction m Text
 pgFormatQueryMany' tools q p =
   pgFormatQueryMany q p
     >>= lift . pgFormatQueryByteString tools
@@ -481,7 +633,14 @@ pgFormatQueryMany' tools q p =
 postgresToolsParser :: ToolParserT IO (Label "pgFormat" Tool)
 postgresToolsParser = label @"pgFormat" <$> readTool "pg_format"
 
-pgFormatQueryByteString :: (MonadIO m, MonadLogger m, HasField "pgFormat" tools Tool) => tools -> ByteString -> m Text
+pgFormatQueryByteString ::
+  ( MonadIO m,
+    MonadLogger m,
+    HasField "pgFormat" tools Tool
+  ) =>
+  tools ->
+  ByteString ->
+  m Text
 pgFormatQueryByteString tools queryBytes = do
   do
     (exitCode, stdout, stderr) <-
@@ -492,8 +651,8 @@ pgFormatQueryByteString tools queryBytes = do
     case exitCode of
       ExitSuccess -> pure (stdout & stringToText)
       ExitFailure status -> do
-        $logWarn [fmt|pg_format failed with status {status} while formatting the query, using original query string. Is there a syntax error?|]
-        $logDebug
+        logWarn [fmt|pg_format failed with status {status} while formatting the query, using original query string. Is there a syntax error?|]
+        logDebug
           ( prettyErrorTree
               ( nestedMultiError
                   "pg_format output"
@@ -502,7 +661,7 @@ pgFormatQueryByteString tools queryBytes = do
                   )
               )
           )
-        $logDebug [fmt|pg_format stdout: stderr|]
+        logDebug [fmt|pg_format stdout: stderr|]
         pure (queryBytes & bytesToTextUtf8Lenient)
 
 data DebugLogDatabaseQueries
@@ -517,7 +676,7 @@ data DebugLogDatabaseQueries
 data HasQueryParams param
   = HasNoParams
   | HasSingleParam param
-  | HasMultiParams [param]
+  | HasMultiParams (NonEmpty param)
 
 -- | Log the postgres query depending on the given setting
 traceQueryIfEnabled ::