From 46251292e34cad18d5c1e67a21f9dacc726883d7 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 14:30:58 -0400 Subject: [PATCH 01/30] added mock structured --- .../Database/Persist/Postgresql.hs | 32 ++++--------------- .../Database/Persist/Postgresql/Internal.hs | 30 ++++++++++++++++- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index f40196c12..5c91ab533 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -762,31 +762,6 @@ mayDefault def = case def of Nothing -> "" Just d -> " DEFAULT " <> d -type SafeToRemove = Bool - -data AlterColumn - = ChangeType Column SqlType Text - | IsNull Column - | NotNull Column - | Add' Column - | Drop Column SafeToRemove - | Default Column Text - | NoDefault Column - | Update' Column Text - | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade - | DropReference ConstraintNameDB - deriving Show - -data AlterTable - = AddUniqueConstraint ConstraintNameDB [FieldNameDB] - | DropConstraint ConstraintNameDB - deriving Show - -data AlterDB = AddTable Text - | AlterColumn EntityNameDB AlterColumn - | AlterTable EntityNameDB AlterTable - deriving Show - -- | Returns all of the columns in the given table currently in the database. getColumns :: (Text -> IO Statement) -> EntityDef -> [Column] @@ -1557,7 +1532,12 @@ mockMigrate :: [EntityDef] -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [(Bool, Text)]) -mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ do +mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ mockMigrateStructured allDefs entity + +mockMigrateStructured :: [EntityDef] + -> EntityDef + -> IO (Either [Text] [AlterDB]) +mockMigrateStructured allDefs entity = do case partitionEithers [] of ([], old'') -> return $ Right $ migrationText False old'' (errs, _) -> return $ Left errs diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 0fbdcb771..c19c566d3 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -5,6 +5,10 @@ module Database.Persist.Postgresql.Internal ( P(..) , PgInterval(..) , getGetter + , AlterDB (..) + , AlterTable (..) + , AlterColumn (..) + , SafeToRemove ) where import qualified Database.PostgreSQL.Simple as PG @@ -14,6 +18,7 @@ import qualified Database.PostgreSQL.Simple.ToField as PGTF import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS import qualified Database.PostgreSQL.Simple.Types as PG +import Database.Persist.Sql import qualified Blaze.ByteString.Builder.Char8 as BBB import qualified Data.Attoparsec.ByteString.Char8 as P import Data.Bits ((.&.)) @@ -281,4 +286,27 @@ instance PersistField PgInterval where instance PersistFieldSql PgInterval where sqlType _ = SqlOther "interval" - +type SafeToRemove = Bool + +data AlterColumn + = ChangeType Column SqlType Text + | IsNull Column + | NotNull Column + | Add' Column + | Drop Column SafeToRemove + | Default Column Text + | NoDefault Column + | Update' Column Text + | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade + | DropReference ConstraintNameDB + deriving Show + +data AlterTable + = AddUniqueConstraint ConstraintNameDB [FieldNameDB] + | DropConstraint ConstraintNameDB + deriving Show + +data AlterDB = AddTable Text + | AlterColumn EntityNameDB AlterColumn + | AlterTable EntityNameDB AlterTable + deriving Show \ No newline at end of file From 206db47f22a10673493a53f7f0097c97bc6fb247 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 15:40:09 -0400 Subject: [PATCH 02/30] builds --- .../Database/Persist/Postgresql.hs | 342 -------------- .../Database/Persist/Postgresql/Internal.hs | 417 +++++++++++++++++- 2 files changed, 412 insertions(+), 347 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 5c91ab533..702895cba 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -709,59 +709,6 @@ mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference escapedParentFields = map escapeF parentfields -addTable :: [Column] -> EntityDef -> AlterDB -addTable cols entity = - AddTable $ T.concat - -- Lower case e: see Database.Persist.Sql.Migration - [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name - , "(" - , idtxt - , if null nonIdCols then "" else "," - , T.intercalate "," $ map showColumn nonIdCols - , ")" - ] - where - nonIdCols = - case entityPrimary entity of - Just _ -> - cols - _ -> - filter keepField cols - where - keepField c = - Just (cName c) /= fmap fieldDB (getEntityIdField entity) - && not (safeToRemove entity (cName c)) - - name = - getEntityDBName entity - idtxt = - case getEntityId entity of - EntityIdNaturalKey pdef -> - T.concat - [ " PRIMARY KEY (" - , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef - , ")" - ] - EntityIdField field -> - let defText = defaultAttribute $ fieldAttrs field - sType = fieldSqlType field - in T.concat - [ escapeF $ fieldDB field - , maySerial sType defText - , " PRIMARY KEY UNIQUE" - , mayDefault defText - ] - -maySerial :: SqlType -> Maybe Text -> Text -maySerial SqlInt64 Nothing = " SERIAL8 " -maySerial sType _ = " " <> showSqlType sType - -mayDefault :: Maybe Text -> Text -mayDefault def = case def of - Nothing -> "" - Just d -> " DEFAULT " <> d - -- | Returns all of the columns in the given table currently in the database. getColumns :: (Text -> IO Statement) -> EntityDef -> [Column] @@ -845,55 +792,6 @@ getColumns getter def cols = do Left e -> Left e Right c -> Right $ Left c --- | Check if a column name is listed as the "safe to remove" in the entity --- list. -safeToRemove :: EntityDef -> FieldNameDB -> Bool -safeToRemove def (FieldNameDB colName) - = any (elem FieldAttrSafeToRemove . fieldAttrs) - $ filter ((== FieldNameDB colName) . fieldDB) - $ allEntityFields - where - allEntityFields = - getEntityFieldsDatabase def <> case getEntityId def of - EntityIdField fdef -> - [fdef] - _ -> - [] - -getAlters :: [EntityDef] - -> EntityDef - -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) - -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) - -> ([AlterColumn], [AlterTable]) -getAlters defs def (c1, u1) (c2, u2) = - (getAltersC c1 c2, getAltersU u1 u2) - where - getAltersC [] old = - map (\x -> Drop x $ safeToRemove def $ cName x) old - getAltersC (new:news) old = - let (alters, old') = findAlters defs def new old - in alters ++ getAltersC news old' - - getAltersU - :: [(ConstraintNameDB, [FieldNameDB])] - -> [(ConstraintNameDB, [FieldNameDB])] - -> [AlterTable] - getAltersU [] old = - map DropConstraint $ filter (not . isManual) $ map fst old - getAltersU ((name, cols):news) old = - case lookup name old of - Nothing -> - AddUniqueConstraint name cols : getAltersU news old - Just ocols -> - let old' = filter (\(x, _) -> x /= name) old - in if sort cols == sort ocols - then getAltersU news old' - else DropConstraint name - : AddUniqueConstraint name cols - : getAltersU news old' - - -- Don't drop constraints which were manually added. - isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x getColumn :: (Text -> IO Statement) @@ -1076,169 +974,6 @@ getColumn getter tableName' [ PersistText columnName getColumn _ _ columnName _ = return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show columnName --- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer -sqlTypeEq :: SqlType -> SqlType -> Bool -sqlTypeEq x y = - let - -- Non exhaustive helper to map postgres aliases to the same name. Based on - -- https://www.postgresql.org/docs/9.5/datatype.html. - -- This prevents needless `ALTER TYPE`s when the type is the same. - normalize "int8" = "bigint" - normalize "serial8" = "bigserial" - normalize v = v - in - normalize (T.toCaseFold (showSqlType x)) == normalize (T.toCaseFold (showSqlType y)) - -findAlters - :: [EntityDef] - -- ^ The list of all entity definitions that persistent is aware of. - -> EntityDef - -- ^ The entity definition for the entity that we're working on. - -> Column - -- ^ The column that we're searching for potential alterations for. - -> [Column] - -> ([AlterColumn], [Column]) -findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName _maxLen ref) cols = - case List.find (\c -> cName c == name) cols of - Nothing -> - ([Add' col], cols) - Just (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> - let refDrop Nothing = [] - refDrop (Just ColumnReference {crConstraintName=cname}) = - [DropReference cname] - - refAdd Nothing = [] - refAdd (Just colRef) = - case find ((== crTableName colRef) . getEntityDBName) defs of - Just refdef - | Just _oldName /= fmap fieldDB (getEntityIdField edef) - -> - [AddReference - (crTableName colRef) - (crConstraintName colRef) - [name] - (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) - (crFieldCascade colRef) - ] - Just _ -> [] - Nothing -> - error $ "could not find the entityDef for reftable[" - ++ show (crTableName colRef) ++ "]" - modRef = - if equivalentRef ref ref' - then [] - else refDrop ref' ++ refAdd ref - modNull = case (isNull, isNull') of - (True, False) -> do - guard $ Just name /= fmap fieldDB (getEntityIdField edef) - pure (IsNull col) - (False, True) -> - let up = case def of - Nothing -> id - Just s -> (:) (Update' col s) - in up [NotNull col] - _ -> [] - modType - | sqlTypeEq sqltype sqltype' = [] - -- When converting from Persistent pre-2.0 databases, we - -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is - -- treated as UTC. - | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = - [ChangeType col sqltype $ T.concat - [ " USING " - , escapeF name - , " AT TIME ZONE 'UTC'" - ]] - | otherwise = [ChangeType col sqltype ""] - modDef = - if def == def' - || isJust (T.stripPrefix "nextval" =<< def') - then [] - else - case def of - Nothing -> [NoDefault col] - Just s -> [Default col s] - dropSafe = - if safeToRemove edef name - then error "wtf" [Drop col True] - else [] - in - ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe - , filter (\c -> cName c /= name) cols - ) - --- We check if we should alter a foreign key. This is almost an equality check, --- except we consider 'Nothing' and 'Just Restrict' equivalent. -equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool -equivalentRef Nothing Nothing = True -equivalentRef (Just cr1) (Just cr2) = - crTableName cr1 == crTableName cr2 - && crConstraintName cr1 == crConstraintName cr2 - && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2) - && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2) - where - eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool - eqCascade Nothing Nothing = True - eqCascade Nothing (Just Restrict) = True - eqCascade (Just Restrict) Nothing = True - eqCascade (Just cs1) (Just cs2) = cs1 == cs2 - eqCascade _ _ = False -equivalentRef _ _ = False - --- | Get the references to be added to a table for the given column. -getAddReference - :: [EntityDef] - -> EntityDef - -> FieldNameDB - -> ColumnReference - -> Maybe AlterDB -getAddReference allDefs entity cname cr@ColumnReference {crTableName = s, crConstraintName=constraintName} = do - guard $ Just cname /= fmap fieldDB (getEntityIdField entity) - pure $ AlterColumn - table - (AddReference s constraintName [cname] id_ (crFieldCascade cr) - ) - where - table = getEntityDBName entity - id_ = - fromMaybe - (error $ "Could not find ID of entity " ++ show s) - $ do - entDef <- find ((== s) . getEntityDBName) allDefs - return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef - -showColumn :: Column -> Text -showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = T.concat - [ escapeF n - , " " - , showSqlType sqlType' - , " " - , if nu then "NULL" else "NOT NULL" - , case def of - Nothing -> "" - Just s -> " DEFAULT " <> s - , case gen of - Nothing -> "" - Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" - ] - -showSqlType :: SqlType -> Text -showSqlType SqlString = "VARCHAR" -showSqlType SqlInt32 = "INT4" -showSqlType SqlInt64 = "INT8" -showSqlType SqlReal = "DOUBLE PRECISION" -showSqlType (SqlNumeric s prec) = T.concat [ "NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")" ] -showSqlType SqlDay = "DATE" -showSqlType SqlTime = "TIME" -showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE" -showSqlType SqlBlob = "BYTEA" -showSqlType SqlBool = "BOOLEAN" - --- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682 -showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" - -showSqlType (SqlOther t) = t - showAlterDb :: AlterDB -> (Bool, Text) showAlterDb (AddTable s) = (False, s) showAlterDb (AlterColumn t ac) = @@ -1493,84 +1228,12 @@ defaultPostgresConfHooks = PostgresConfHooks } -refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB -refName (EntityNameDB table) (FieldNameDB column) = - let overhead = T.length $ T.concat ["_", "_fkey"] - (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column) - in ConstraintNameDB $ T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"] - - where - - -- Postgres automatically truncates too long foreign keys to a combination of - -- truncatedTableName + "_" + truncatedColumnName + "_fkey" - -- This works fine for normal use cases, but it creates an issue for Persistent - -- Because after running the migrations, Persistent sees the truncated foreign key constraint - -- doesn't have the expected name, and suggests that you migrate again - -- To workaround this, we copy the Postgres truncation approach before sending foreign key constraints to it. - -- - -- I believe this will also be an issue for extremely long table names, - -- but it's just much more likely to exist with foreign key constraints because they're usually tablename * 2 in length - - -- Approximation of the algorithm Postgres uses to truncate identifiers - -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080 - shortenNames :: Int -> (Int, Int) -> (Int, Int) - shortenNames overhead (x, y) - | x + y + overhead <= maximumIdentifierLength = (x, y) - | x > y = shortenNames overhead (x - 1, y) - | otherwise = shortenNames overhead (x, y - 1) - --- | Postgres' default maximum identifier length in bytes --- (You can re-compile Postgres with a new limit, but I'm assuming that virtually noone does this). --- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -maximumIdentifierLength :: Int -maximumIdentifierLength = 63 - -udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB]) -udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud) - mockMigrate :: [EntityDef] -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [(Bool, Text)]) mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ mockMigrateStructured allDefs entity -mockMigrateStructured :: [EntityDef] - -> EntityDef - -> IO (Either [Text] [AlterDB]) -mockMigrateStructured allDefs entity = do - case partitionEithers [] of - ([], old'') -> return $ Right $ migrationText False old'' - (errs, _) -> return $ Left errs - where - name = getEntityDBName entity - migrationText exists' old'' = - if not exists' - then createText newcols fdefs udspair - else let (acs, ats) = getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats - in acs' ++ ats' - where - old' = partitionEithers old'' - (newcols', udefs, fdefs) = postgresMkColumns allDefs entity - newcols = filter (not . safeToRemove entity . cName) newcols' - udspair = map udToPair udefs - -- Check for table existence if there are no columns, workaround - -- for https://github.com/yesodweb/persistent/issues/152 - - createText newcols fdefs udspair = - (addTable newcols entity) : uniques ++ references ++ foreignsAlt - where - uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] - references = - mapMaybe - (\Column { cName, cReference } -> - getAddReference allDefs entity cName =<< cReference - ) - newcols - foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs - -- | Mock a migration even when the database is not present. -- This function performs the same functionality of 'printMigration' -- with the difference that an actual database is not needed. @@ -1948,11 +1611,6 @@ migrateEnableExtension extName = WriterT $ WriterT $ do then return (((), []) , [(False, "CREATe EXTENSION \"" <> extName <> "\"")]) else return (((), []), []) -postgresMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -postgresMkColumns allDefs t = - mkColumns allDefs t - $ setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides - -- | Wrapper for persistent SqlBackends that carry the corresponding -- `Postgresql.Connection`. -- diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index c19c566d3..f744a111c 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} @@ -9,6 +10,18 @@ module Database.Persist.Postgresql.Internal , AlterTable (..) , AlterColumn (..) , SafeToRemove + , mockMigrateStructured + , addTable + , findAlters + , maySerial + , mayDefault + , showSqlType + , showColumn + , getAddReference + , udToPair + , safeToRemove + , postgresMkColumns + , getAlters ) where import qualified Database.PostgreSQL.Simple as PG @@ -19,26 +32,28 @@ import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS import qualified Database.PostgreSQL.Simple.Types as PG import Database.Persist.Sql +import qualified Data.List.NonEmpty as NEL import qualified Blaze.ByteString.Builder.Char8 as BBB import qualified Data.Attoparsec.ByteString.Char8 as P import Data.Bits ((.&.)) import Data.ByteString (ByteString) import qualified Data.ByteString.Builder as BB +import Data.List as List (find, sort) import qualified Data.ByteString.Char8 as B8 import Data.Char (ord) import Data.Data (Typeable) import Data.Fixed (Fixed(..), Pico) import Data.Int (Int64) +import Data.Either (partitionEithers) +import Control.Monad import qualified Data.IntMap as I -import Data.Maybe (fromMaybe) +import Data.Maybe +import qualified Database.Persist.Sql.Util as Util import Data.String.Conversions.Monomorphic (toStrictByteString) import Data.Text (Text) import qualified Data.Text as T -import qualified Data.Text.Encoding as T import Data.Time (NominalDiffTime, localTimeToUTC, utc) -import Database.Persist.Sql - -- | Newtype used to avoid orphan instances for @postgresql-simple@ classes. -- -- @since 2.13.2.0 @@ -309,4 +324,396 @@ data AlterTable data AlterDB = AddTable Text | AlterColumn EntityNameDB AlterColumn | AlterTable EntityNameDB AlterTable - deriving Show \ No newline at end of file + deriving Show + +mockMigrateStructured :: [EntityDef] + -> EntityDef + -> IO (Either [Text] [AlterDB]) +mockMigrateStructured allDefs entity = do + case partitionEithers [] of + ([], old'') -> return $ Right $ migrationText False old'' + (errs, _) -> return $ Left errs + where + name = getEntityDBName entity + migrationText exists' old'' = + if not exists' + then createText newcols fdefs udspair + else let (acs, ats) = getAlters allDefs entity (newcols, udspair) old' + acs' = map (AlterColumn name) acs + ats' = map (AlterTable name) ats + in acs' ++ ats' + where + old' = partitionEithers old'' + (newcols', udefs, fdefs) = postgresMkColumns allDefs entity + newcols = filter (not . safeToRemove entity . cName) newcols' + udspair = map udToPair udefs + -- Check for table existence if there are no columns, workaround + -- for https://github.com/yesodweb/persistent/issues/152 + + createText newcols fdefs udspair = + (addTable newcols entity) : uniques ++ references ++ foreignsAlt + where + uniques = flip concatMap udspair $ \(uname, ucols) -> + [AlterTable name $ AddUniqueConstraint uname ucols] + references = + mapMaybe + (\Column { cName, cReference } -> + getAddReference allDefs entity cName =<< cReference + ) + newcols + foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs + + + +addTable :: [Column] -> EntityDef -> AlterDB +addTable cols entity = + AddTable $ T.concat + -- Lower case e: see Database.Persist.Sql.Migration + [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! + , escapeE name + , "(" + , idtxt + , if null nonIdCols then "" else "," + , T.intercalate "," $ map showColumn nonIdCols + , ")" + ] + where + nonIdCols = + case entityPrimary entity of + Just _ -> + cols + _ -> + filter keepField cols + where + keepField c = + Just (cName c) /= fmap fieldDB (getEntityIdField entity) + && not (safeToRemove entity (cName c)) + + name = + getEntityDBName entity + idtxt = + case getEntityId entity of + EntityIdNaturalKey pdef -> + T.concat + [ " PRIMARY KEY (" + , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef + , ")" + ] + EntityIdField field -> + let defText = defaultAttribute $ fieldAttrs field + sType = fieldSqlType field + in T.concat + [ escapeF $ fieldDB field + , maySerial sType defText + , " PRIMARY KEY UNIQUE" + , mayDefault defText + ] + +maySerial :: SqlType -> Maybe Text -> Text +maySerial SqlInt64 Nothing = " SERIAL8 " +maySerial sType _ = " " <> showSqlType sType + +mayDefault :: Maybe Text -> Text +mayDefault def = case def of + Nothing -> "" + Just d -> " DEFAULT " <> d + + +getAlters :: [EntityDef] + -> EntityDef + -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) + -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) + -> ([AlterColumn], [AlterTable]) +getAlters defs def (c1, u1) (c2, u2) = + (getAltersC c1 c2, getAltersU u1 u2) + where + getAltersC [] old = + map (\x -> Drop x $ safeToRemove def $ cName x) old + getAltersC (new:news) old = + let (alters, old') = findAlters defs def new old + in alters ++ getAltersC news old' + + getAltersU + :: [(ConstraintNameDB, [FieldNameDB])] + -> [(ConstraintNameDB, [FieldNameDB])] + -> [AlterTable] + getAltersU [] old = + map DropConstraint $ filter (not . isManual) $ map fst old + getAltersU ((name, cols):news) old = + case lookup name old of + Nothing -> + AddUniqueConstraint name cols : getAltersU news old + Just ocols -> + let old' = filter (\(x, _) -> x /= name) old + in if sort cols == sort ocols + then getAltersU news old' + else DropConstraint name + : AddUniqueConstraint name cols + : getAltersU news old' + + -- Don't drop constraints which were manually added. + isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x + + +-- | Postgres' default maximum identifier length in bytes +-- (You can re-compile Postgres with a new limit, but I'm assuming that virtually noone does this). +-- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +maximumIdentifierLength :: Int +maximumIdentifierLength = 63 + +-- | Intelligent comparison of SQL types, to account for SqlInt32 vs SqlOther integer +sqlTypeEq :: SqlType -> SqlType -> Bool +sqlTypeEq x y = + let + -- Non exhaustive helper to map postgres aliases to the same name. Based on + -- https://www.postgresql.org/docs/9.5/datatype.html. + -- This prevents needless `ALTER TYPE`s when the type is the same. + normalize "int8" = "bigint" + normalize "serial8" = "bigserial" + normalize v = v + in + normalize (T.toCaseFold (showSqlType x)) == normalize (T.toCaseFold (showSqlType y)) + +-- We check if we should alter a foreign key. This is almost an equality check, +-- except we consider 'Nothing' and 'Just Restrict' equivalent. +equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool +equivalentRef Nothing Nothing = True +equivalentRef (Just cr1) (Just cr2) = + crTableName cr1 == crTableName cr2 + && crConstraintName cr1 == crConstraintName cr2 + && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2) + && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2) + where + eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool + eqCascade Nothing Nothing = True + eqCascade Nothing (Just Restrict) = True + eqCascade (Just Restrict) Nothing = True + eqCascade (Just cs1) (Just cs2) = cs1 == cs2 + eqCascade _ _ = False +equivalentRef _ _ = False + +refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB +refName (EntityNameDB table) (FieldNameDB column) = + let overhead = T.length $ T.concat ["_", "_fkey"] + (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column) + in ConstraintNameDB $ T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"] + + where + + -- Postgres automatically truncates too long foreign keys to a combination of + -- truncatedTableName + "_" + truncatedColumnName + "_fkey" + -- This works fine for normal use cases, but it creates an issue for Persistent + -- Because after running the migrations, Persistent sees the truncated foreign key constraint + -- doesn't have the expected name, and suggests that you migrate again + -- To workaround this, we copy the Postgres truncation approach before sending foreign key constraints to it. + -- + -- I believe this will also be an issue for extremely long table names, + -- but it's just much more likely to exist with foreign key constraints because they're usually tablename * 2 in length + + -- Approximation of the algorithm Postgres uses to truncate identifiers + -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080 + shortenNames :: Int -> (Int, Int) -> (Int, Int) + shortenNames overhead (x, y) + | x + y + overhead <= maximumIdentifierLength = (x, y) + | x > y = shortenNames overhead (x - 1, y) + | otherwise = shortenNames overhead (x, y - 1) + +postgresMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) +postgresMkColumns allDefs t = + mkColumns allDefs t + $ setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides + + +-- | Check if a column name is listed as the "safe to remove" in the entity +-- list. +safeToRemove :: EntityDef -> FieldNameDB -> Bool +safeToRemove def (FieldNameDB colName) + = any (elem FieldAttrSafeToRemove . fieldAttrs) + $ filter ((== FieldNameDB colName) . fieldDB) + $ allEntityFields + where + allEntityFields = + getEntityFieldsDatabase def <> case getEntityId def of + EntityIdField fdef -> + [fdef] + _ -> + [] + + +udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB]) +udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud) + + +-- | Get the references to be added to a table for the given column. +getAddReference + :: [EntityDef] + -> EntityDef + -> FieldNameDB + -> ColumnReference + -> Maybe AlterDB +getAddReference allDefs entity cname cr@ColumnReference {crTableName = s, crConstraintName=constraintName} = do + guard $ Just cname /= fmap fieldDB (getEntityIdField entity) + pure $ AlterColumn + table + (AddReference s constraintName [cname] id_ (crFieldCascade cr) + ) + where + table = getEntityDBName entity + id_ = + fromMaybe + (error $ "Could not find ID of entity " ++ show s) + $ do + entDef <- find ((== s) . getEntityDBName) allDefs + return $ NEL.toList $ Util.dbIdColumnsEsc escapeF entDef + +mkForeignAlt + :: EntityDef + -> ForeignDef + -> Maybe AlterDB +mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference + where + tableName_ = getEntityDBName entity + addReference = + AddReference + (foreignRefTableDBName fdef) + constraintName + childfields + escapedParentFields + (foreignFieldCascade fdef) + constraintName = + foreignConstraintNameDBName fdef + (childfields, parentfields) = + unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef)) + escapedParentFields = + map escapeF parentfields + +escapeC :: ConstraintNameDB -> Text +escapeC = escapeWith escape + +escapeE :: EntityNameDB -> Text +escapeE = escapeWith escape + +escapeF :: FieldNameDB -> Text +escapeF = escapeWith escape + + +escape :: Text -> Text +escape s = + T.pack $ '"' : go (T.unpack s) ++ "\"" + where + go "" = "" + go ('"':xs) = "\"\"" ++ go xs + go (x:xs) = x : go xs + +showColumn :: Column -> Text +showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = T.concat + [ escapeF n + , " " + , showSqlType sqlType' + , " " + , if nu then "NULL" else "NOT NULL" + , case def of + Nothing -> "" + Just s -> " DEFAULT " <> s + , case gen of + Nothing -> "" + Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" + ] + + +showSqlType :: SqlType -> Text +showSqlType SqlString = "VARCHAR" +showSqlType SqlInt32 = "INT4" +showSqlType SqlInt64 = "INT8" +showSqlType SqlReal = "DOUBLE PRECISION" +showSqlType (SqlNumeric s prec) = T.concat [ "NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")" ] +showSqlType SqlDay = "DATE" +showSqlType SqlTime = "TIME" +showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE" +showSqlType SqlBlob = "BYTEA" +showSqlType SqlBool = "BOOLEAN" + +-- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682 +showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" + +showSqlType (SqlOther t) = t + + + +findAlters + :: [EntityDef] + -- ^ The list of all entity definitions that persistent is aware of. + -> EntityDef + -- ^ The entity definition for the entity that we're working on. + -> Column + -- ^ The column that we're searching for potential alterations for. + -> [Column] + -> ([AlterColumn], [Column]) +findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName _maxLen ref) cols = + case List.find (\c -> cName c == name) cols of + Nothing -> + ([Add' col], cols) + Just (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> + let refDrop Nothing = [] + refDrop (Just ColumnReference {crConstraintName=cname}) = + [DropReference cname] + + refAdd Nothing = [] + refAdd (Just colRef) = + case find ((== crTableName colRef) . getEntityDBName) defs of + Just refdef + | Just _oldName /= fmap fieldDB (getEntityIdField edef) + -> + [AddReference + (crTableName colRef) + (crConstraintName colRef) + [name] + (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) + (crFieldCascade colRef) + ] + Just _ -> [] + Nothing -> + error $ "could not find the entityDef for reftable[" + ++ show (crTableName colRef) ++ "]" + modRef = + if equivalentRef ref ref' + then [] + else refDrop ref' ++ refAdd ref + modNull = case (isNull, isNull') of + (True, False) -> do + guard $ Just name /= fmap fieldDB (getEntityIdField edef) + pure (IsNull col) + (False, True) -> + let up = case def of + Nothing -> id + Just s -> (:) (Update' col s) + in up [NotNull col] + _ -> [] + modType + | sqlTypeEq sqltype sqltype' = [] + -- When converting from Persistent pre-2.0 databases, we + -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is + -- treated as UTC. + | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = + [ChangeType col sqltype $ T.concat + [ " USING " + , escapeF name + , " AT TIME ZONE 'UTC'" + ]] + | otherwise = [ChangeType col sqltype ""] + modDef = + if def == def' + || isJust (T.stripPrefix "nextval" =<< def') + then [] + else + case def of + Nothing -> [NoDefault col] + Just s -> [Default col s] + dropSafe = + if safeToRemove edef name + then error "wtf" [Drop col True] + else [] + in + ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe + , filter (\c -> cName c /= name) cols + ) \ No newline at end of file From 7cfed0f66043ace1bf4ee564415cda31044f63b6 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 15:42:52 -0400 Subject: [PATCH 03/30] re-export escape --- .../Database/Persist/Postgresql.hs | 19 ------------------- .../Database/Persist/Postgresql/Internal.hs | 4 ++++ 2 files changed, 4 insertions(+), 19 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 702895cba..67ee9c251 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -1098,24 +1098,6 @@ tableName = escapeE . tableDBName fieldName :: (PersistEntity record) => EntityField record typ -> Text fieldName = escapeF . fieldDBName -escapeC :: ConstraintNameDB -> Text -escapeC = escapeWith escape - -escapeE :: EntityNameDB -> Text -escapeE = escapeWith escape - -escapeF :: FieldNameDB -> Text -escapeF = escapeWith escape - - -escape :: Text -> Text -escape s = - T.pack $ '"' : go (T.unpack s) ++ "\"" - where - go "" = "" - go ('"':xs) = "\"\"" ++ go xs - go (x:xs) = x : go xs - -- | Information required to connect to a PostgreSQL database -- using @persistent@'s generic facilities. These values are the -- same that are given to 'withPostgresqlPool'. @@ -1227,7 +1209,6 @@ defaultPostgresConfHooks = PostgresConfHooks , pgConfHooksAfterCreate = const $ pure () } - mockMigrate :: [EntityDef] -> (Text -> IO Statement) -> EntityDef diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index f744a111c..b590b581d 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -22,6 +22,10 @@ module Database.Persist.Postgresql.Internal , safeToRemove , postgresMkColumns , getAlters + , escapeC + , escapeE + , escapeF + , escape ) where import qualified Database.PostgreSQL.Simple as PG From 8cf506af5b1954ed4dc6306e4d57d0877930659b Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 15:45:49 -0400 Subject: [PATCH 04/30] move show --- .../Database/Persist/Postgresql.hs | 114 ----------------- .../Database/Persist/Postgresql/Internal.hs | 118 +++++++++++++++++- 2 files changed, 117 insertions(+), 115 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 67ee9c251..e6281c763 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -974,120 +974,6 @@ getColumn getter tableName' [ PersistText columnName getColumn _ _ columnName _ = return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show columnName -showAlterDb :: AlterDB -> (Bool, Text) -showAlterDb (AddTable s) = (False, s) -showAlterDb (AlterColumn t ac) = - (isUnsafe ac, showAlter t ac) - where - isUnsafe (Drop _ safeRemove) = not safeRemove - isUnsafe _ = False -showAlterDb (AlterTable t at) = (False, showAlterTable t at) - -showAlterTable :: EntityNameDB -> AlterTable -> Text -showAlterTable table (AddUniqueConstraint cname cols) = T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD CONSTRAINT " - , escapeC cname - , " UNIQUE(" - , T.intercalate "," $ map escapeF cols - , ")" - ] -showAlterTable table (DropConstraint cname) = T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP CONSTRAINT " - , escapeC cname - ] - -showAlter :: EntityNameDB -> AlterColumn -> Text -showAlter table (ChangeType c t extra) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " TYPE " - , showSqlType t - , extra - ] -showAlter table (IsNull c) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " DROP NOT NULL" - ] -showAlter table (NotNull c) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " SET NOT NULL" - ] -showAlter table (Add' col) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD COLUMN " - , showColumn col - ] -showAlter table (Drop c _) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP COLUMN " - , escapeF (cName c) - ] -showAlter table (Default c s) = - T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " SET DEFAULT " - , s - ] -showAlter table (NoDefault c) = T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " DROP DEFAULT" - ] -showAlter table (Update' c s) = T.concat - [ "UPDATE " - , escapeE table - , " SET " - , escapeF (cName c) - , "=" - , s - , " WHERE " - , escapeF (cName c) - , " IS NULL" - ] -showAlter table (AddReference reftable fkeyname t2 id2 cascade) = T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD CONSTRAINT " - , escapeC fkeyname - , " FOREIGN KEY(" - , T.intercalate "," $ map escapeF t2 - , ") REFERENCES " - , escapeE reftable - , "(" - , T.intercalate "," id2 - , ")" - ] <> renderFieldCascade cascade -showAlter table (DropReference cname) = T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP CONSTRAINT " - , escapeC cname - ] - -- | Get the SQL string for the table that a PersistEntity represents. -- Useful for raw SQL queries. tableName :: (PersistEntity record) => record -> Text diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index b590b581d..9cee5e878 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -17,12 +17,14 @@ module Database.Persist.Postgresql.Internal , mayDefault , showSqlType , showColumn + , showAlter + , showAlterDb + , showAlterTable , getAddReference , udToPair , safeToRemove , postgresMkColumns , getAlters - , escapeC , escapeE , escapeF , escape @@ -609,6 +611,120 @@ escape s = go ('"':xs) = "\"\"" ++ go xs go (x:xs) = x : go xs +showAlterDb :: AlterDB -> (Bool, Text) +showAlterDb (AddTable s) = (False, s) +showAlterDb (AlterColumn t ac) = + (isUnsafe ac, showAlter t ac) + where + isUnsafe (Drop _ safeRemove) = not safeRemove + isUnsafe _ = False +showAlterDb (AlterTable t at) = (False, showAlterTable t at) + +showAlterTable :: EntityNameDB -> AlterTable -> Text +showAlterTable table (AddUniqueConstraint cname cols) = T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC cname + , " UNIQUE(" + , T.intercalate "," $ map escapeF cols + , ")" + ] +showAlterTable table (DropConstraint cname) = T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] + +showAlter :: EntityNameDB -> AlterColumn -> Text +showAlter table (ChangeType c t extra) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " TYPE " + , showSqlType t + , extra + ] +showAlter table (IsNull c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP NOT NULL" + ] +showAlter table (NotNull c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " SET NOT NULL" + ] +showAlter table (Add' col) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD COLUMN " + , showColumn col + ] +showAlter table (Drop c _) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP COLUMN " + , escapeF (cName c) + ] +showAlter table (Default c s) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " SET DEFAULT " + , s + ] +showAlter table (NoDefault c) = T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP DEFAULT" + ] +showAlter table (Update' c s) = T.concat + [ "UPDATE " + , escapeE table + , " SET " + , escapeF (cName c) + , "=" + , s + , " WHERE " + , escapeF (cName c) + , " IS NULL" + ] +showAlter table (AddReference reftable fkeyname t2 id2 cascade) = T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC fkeyname + , " FOREIGN KEY(" + , T.intercalate "," $ map escapeF t2 + , ") REFERENCES " + , escapeE reftable + , "(" + , T.intercalate "," id2 + , ")" + ] <> renderFieldCascade cascade +showAlter table (DropReference cname) = T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] + showColumn :: Column -> Text showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = T.concat [ escapeF n From 5a3c8afa393cc6c2599b4a1a3818f1b07efe9f81 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 16:21:58 -0400 Subject: [PATCH 05/30] started adding structured representation to AddTable --- .../Database/Persist/Postgresql/Internal.hs | 51 ++++++++++++++----- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 9cee5e878..8c1b8c59d 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -307,8 +307,14 @@ instance PersistField PgInterval where instance PersistFieldSql PgInterval where sqlType _ = SqlOther "interval" +-- | Indicates whether a Postgres Column is safe to drop. +-- +-- @since 2.17.0.0 type SafeToRemove = Bool +-- | Represents a change to a Postgres column in a DB statement. +-- +-- @since 2.17.0.0 data AlterColumn = ChangeType Column SqlType Text | IsNull Column @@ -322,16 +328,28 @@ data AlterColumn | DropReference ConstraintNameDB deriving Show +-- | Represents a change to a Postgres table in a DB statement. +-- +-- @since 2.17.0.0 data AlterTable = AddUniqueConstraint ConstraintNameDB [FieldNameDB] | DropConstraint ConstraintNameDB deriving Show -data AlterDB = AddTable Text +-- | Represents a change to a Postgres DB in a statement. +-- +-- @since 2.17.0.0 +data AlterDB = AddTable Text EntityNameDB EntityIdDef | AlterColumn EntityNameDB AlterColumn | AlterTable EntityNameDB AlterTable deriving Show +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its +-- current state in the database to the state described in +-- Haskell. +-- +-- @since 2.17.0.0 mockMigrateStructured :: [EntityDef] -> EntityDef -> IO (Either [Text] [AlterDB]) @@ -369,20 +387,14 @@ mockMigrateStructured allDefs entity = do newcols foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs - - +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its current state +-- in the database to the state described in Haskell. +-- +-- @since 2.17.0.0 addTable :: [Column] -> EntityDef -> AlterDB addTable cols entity = - AddTable $ T.concat - -- Lower case e: see Database.Persist.Sql.Migration - [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name - , "(" - , idtxt - , if null nonIdCols then "" else "," - , T.intercalate "," $ map showColumn nonIdCols - , ")" - ] + AddTable rawText name where nonIdCols = case entityPrimary entity of @@ -414,6 +426,16 @@ addTable cols entity = , " PRIMARY KEY UNIQUE" , mayDefault defText ] + rawText = T.concat + -- Lower case e: see Database.Persist.Sql.Migration + [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! + , escapeE name + , "(" + , idtxt + , if null nonIdCols then "" else "," + , T.intercalate "," $ map showColumn nonIdCols + , ")" + ] maySerial :: SqlType -> Maybe Text -> Text maySerial SqlInt64 Nothing = " SERIAL8 " @@ -612,7 +634,8 @@ escape s = go (x:xs) = x : go xs showAlterDb :: AlterDB -> (Bool, Text) -showAlterDb (AddTable s) = (False, s) +showAlterDb (AddTable s _) = (False, s) + showAlterDb (AlterColumn t ac) = (isUnsafe ac, showAlter t ac) where From df42045f58630d4ca8918f5b1822de6a5af8b786 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 16:35:47 -0400 Subject: [PATCH 06/30] make AddTable structured --- .../Database/Persist/Postgresql/Internal.hs | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 8c1b8c59d..77bbedc6d 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -339,7 +339,7 @@ data AlterTable -- | Represents a change to a Postgres DB in a statement. -- -- @since 2.17.0.0 -data AlterDB = AddTable Text EntityNameDB EntityIdDef +data AlterDB = AddTable EntityNameDB EntityIdDef [Column] | AlterColumn EntityNameDB AlterColumn | AlterTable EntityNameDB AlterTable deriving Show @@ -394,7 +394,7 @@ mockMigrateStructured allDefs entity = do -- @since 2.17.0.0 addTable :: [Column] -> EntityDef -> AlterDB addTable cols entity = - AddTable rawText name + AddTable name entityId nonIdCols where nonIdCols = case entityPrimary entity of @@ -406,36 +406,8 @@ addTable cols entity = keepField c = Just (cName c) /= fmap fieldDB (getEntityIdField entity) && not (safeToRemove entity (cName c)) - - name = - getEntityDBName entity - idtxt = - case getEntityId entity of - EntityIdNaturalKey pdef -> - T.concat - [ " PRIMARY KEY (" - , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef - , ")" - ] - EntityIdField field -> - let defText = defaultAttribute $ fieldAttrs field - sType = fieldSqlType field - in T.concat - [ escapeF $ fieldDB field - , maySerial sType defText - , " PRIMARY KEY UNIQUE" - , mayDefault defText - ] - rawText = T.concat - -- Lower case e: see Database.Persist.Sql.Migration - [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name - , "(" - , idtxt - , if null nonIdCols then "" else "," - , T.intercalate "," $ map showColumn nonIdCols - , ")" - ] + entityId = getEntityId entity + name = getEntityDBName entity maySerial :: SqlType -> Maybe Text -> Text maySerial SqlInt64 Nothing = " SERIAL8 " @@ -634,7 +606,35 @@ escape s = go (x:xs) = x : go xs showAlterDb :: AlterDB -> (Bool, Text) -showAlterDb (AddTable s _) = (False, s) +showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) + where + idtxt = + case entityId of + EntityIdNaturalKey pdef -> + T.concat + [ " PRIMARY KEY (" + , T.intercalate "," $ map (escapeF . fieldDB) $ NEL.toList $ compositeFields pdef + , ")" + ] + EntityIdField field -> + let defText = defaultAttribute $ fieldAttrs field + sType = fieldSqlType field + in T.concat + [ escapeF $ fieldDB field + , maySerial sType defText + , " PRIMARY KEY UNIQUE" + , mayDefault defText + ] + rawText = T.concat + -- Lower case e: see Database.Persist.Sql.Migration + [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! + , escapeE name + , "(" + , idtxt + , if null nonIdCols then "" else "," + , T.intercalate "," $ map showColumn nonIdCols + , ")" + ] showAlterDb (AlterColumn t ac) = (isUnsafe ac, showAlter t ac) From 43afe362f1921f53e60cf08e5809f317adc5ae19 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 17:03:16 -0400 Subject: [PATCH 07/30] remove unneeded lines --- persistent-postgresql/Database/Persist/Postgresql/Internal.hs | 2 -- 1 file changed, 2 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 77bbedc6d..12a8152d8 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -781,8 +781,6 @@ showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" showSqlType (SqlOther t) = t - - findAlters :: [EntityDef] -- ^ The list of all entity definitions that persistent is aware of. From c12707c2aafe0a405f7c5f22e5a7ad6f66b2a608 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 17:07:09 -0400 Subject: [PATCH 08/30] ran fourmolu --- .../Database/Persist/Postgresql.hs | 1557 +++++++++-------- .../Database/Persist/Postgresql/Internal.hs | 800 +++++---- 2 files changed, 1283 insertions(+), 1074 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index e6281c763..33034d11c 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -28,13 +28,10 @@ module Database.Persist.Postgresql ( withPostgresqlPool , withPostgresqlPoolWithVersion , withPostgresqlPoolWithConf - , withPostgresqlPoolModified , withPostgresqlPoolModifiedWithVersion - , withPostgresqlConn , withPostgresqlConnWithVersion - , createPostgresqlPool , createPostgresqlPoolModified , createPostgresqlPoolModifiedWithVersion @@ -60,10 +57,9 @@ module Database.Persist.Postgresql , fieldName , mockMigration , migrateEnableExtension - , PostgresConfHooks(..) + , PostgresConfHooks (..) , defaultPostgresConfHooks - - , RawPostgresql(..) + , RawPostgresql (..) , createRawPostgresqlPool , createRawPostgresqlPoolModified , createRawPostgresqlPoolModifiedWithVersion @@ -76,7 +72,7 @@ import qualified Database.PostgreSQL.LibPQ as LibPQ import qualified Database.PostgreSQL.Simple as PG import qualified Database.PostgreSQL.Simple.FromField as PGFF import qualified Database.PostgreSQL.Simple.Internal as PG -import Database.PostgreSQL.Simple.Ok (Ok(..)) +import Database.PostgreSQL.Simple.Ok (Ok (..)) import qualified Database.PostgreSQL.Simple.Transaction as PG import qualified Database.PostgreSQL.Simple.Types as PG @@ -84,16 +80,16 @@ import Control.Arrow import Control.Exception (Exception, throw, throwIO) import Control.Monad import Control.Monad.Except -import Control.Monad.IO.Unlift (MonadIO(..), MonadUnliftIO) +import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO) import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) import Control.Monad.Trans.Class (lift) -import Control.Monad.Trans.Reader (ReaderT(..), asks, runReaderT) +import Control.Monad.Trans.Reader (ReaderT (..), asks, runReaderT) #if !MIN_VERSION_base(4,12,0) import Control.Monad.Trans.Reader (withReaderT) #endif -import Control.Monad.Trans.Writer (WriterT(..), runWriterT) +import Control.Monad.Trans.Writer (WriterT (..), runWriterT) import qualified Data.List.NonEmpty as NEL -import Data.Proxy (Proxy(..)) +import Data.Proxy (Proxy (..)) import Data.Acquire (Acquire, mkAcquire, with) import Data.Aeson @@ -106,8 +102,8 @@ import qualified Data.Conduit.List as CL import Data.Data (Data) import Data.Either (partitionEithers) import Data.Function (on) -import Data.Int (Int64) import Data.IORef +import Data.Int (Int64) import Data.List as List (find, foldl', groupBy, sort) import qualified Data.List as List import Data.List.NonEmpty (NonEmpty) @@ -132,7 +128,10 @@ import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util import Database.Persist.SqlBackend import Database.Persist.SqlBackend.StatementCache - (StatementCache, mkSimpleStatementCache, mkStatementCache) + ( StatementCache + , mkSimpleStatementCache + , mkStatementCache + ) import System.IO.Unsafe (unsafePerformIO) -- | A @libpq@ connection string. A simple example of connection @@ -149,7 +148,7 @@ data PostgresServerVersionError = PostgresServerVersionError String instance Show PostgresServerVersionError where show (PostgresServerVersionError uniqueMsg) = - "Unexpected PostgreSQL server version, got " <> uniqueMsg + "Unexpected PostgreSQL server version, got " <> uniqueMsg instance Exception PostgresServerVersionError -- | Create a PostgreSQL connection pool and run the given action. The pool is @@ -158,53 +157,61 @@ instance Exception PostgresServerVersionError -- have been released. -- The provided action should use 'runSqlConn' and *not* 'runReaderT' because -- the former brackets the database action with transaction begin/commit. -withPostgresqlPool :: (MonadLoggerIO m, MonadUnliftIO m) - => ConnectionString - -- ^ Connection string to the database. - -> Int - -- ^ Number of connections to be kept open in - -- the pool. - -> (Pool SqlBackend -> m a) - -- ^ Action to be executed that uses the - -- connection pool. - -> m a +withPostgresqlPool + :: (MonadLoggerIO m, MonadUnliftIO m) + => ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in + -- the pool. + -> (Pool SqlBackend -> m a) + -- ^ Action to be executed that uses the + -- connection pool. + -> m a withPostgresqlPool ci = withPostgresqlPoolWithVersion getServerVersion ci -- | Same as 'withPostgresPool', but takes a callback for obtaining -- the server version (to work around an Amazon Redshift bug). -- -- @since 2.6.2 -withPostgresqlPoolWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO (Maybe Double)) - -- ^ Action to perform to get the server version. - -> ConnectionString - -- ^ Connection string to the database. - -> Int - -- ^ Number of connections to be kept open in - -- the pool. - -> (Pool SqlBackend -> m a) - -- ^ Action to be executed that uses the - -- connection pool. - -> m a +withPostgresqlPoolWithVersion + :: (MonadUnliftIO m, MonadLoggerIO m) + => (PG.Connection -> IO (Maybe Double)) + -- ^ Action to perform to get the server version. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in + -- the pool. + -> (Pool SqlBackend -> m a) + -- ^ Action to be executed that uses the + -- connection pool. + -> m a withPostgresqlPoolWithVersion getVerDouble ci = do - let getVer = oldGetVersionToNew getVerDouble - withSqlPool $ open' (const $ return ()) getVer id ci + let + getVer = oldGetVersionToNew getVerDouble + withSqlPool $ open' (const $ return ()) getVer id ci -- | Same as 'withPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'. -- -- @since 2.11.0.0 -withPostgresqlPoolWithConf :: (MonadUnliftIO m, MonadLoggerIO m) - => PostgresConf -- ^ Configuration for connecting to Postgres - -> PostgresConfHooks -- ^ Record of callback functions - -> (Pool SqlBackend -> m a) - -- ^ Action to be executed that uses the - -- connection pool. - -> m a +withPostgresqlPoolWithConf + :: (MonadUnliftIO m, MonadLoggerIO m) + => PostgresConf + -- ^ Configuration for connecting to Postgres + -> PostgresConfHooks + -- ^ Record of callback functions + -> (Pool SqlBackend -> m a) + -- ^ Action to be executed that uses the + -- connection pool. + -> m a withPostgresqlPoolWithConf conf hooks = do - let getVer = pgConfHooksGetServerVersion hooks - modConn = pgConfHooksAfterCreate hooks - let logFuncToBackend = open' modConn getVer id (pgConnStr conf) - withSqlPoolWithConfig logFuncToBackend (postgresConfToConnectionPoolConfig conf) + let + getVer = pgConfHooksGetServerVersion hooks + modConn = pgConfHooksAfterCreate hooks + let + logFuncToBackend = open' modConn getVer id (pgConnStr conf) + withSqlPoolWithConfig logFuncToBackend (postgresConfToConnectionPoolConfig conf) -- | Same as 'withPostgresqlPool', but with the 'createPostgresqlPoolModified' -- feature. @@ -212,9 +219,12 @@ withPostgresqlPoolWithConf conf hooks = do -- @since 2.13.5.0 withPostgresqlPoolModified :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> (Pool SqlBackend -> m t) -> m t withPostgresqlPoolModified = withPostgresqlPoolModifiedWithVersion getServerVersion @@ -225,26 +235,31 @@ withPostgresqlPoolModified = withPostgresqlPoolModifiedWithVersion getServerVers -- @since 2.13.5.0 withPostgresqlPoolModifiedWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO (Maybe Double)) -- ^ Action to perform to get the server version. - -> (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => (PG.Connection -> IO (Maybe Double)) + -- ^ Action to perform to get the server version. + -> (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> (Pool SqlBackend -> m t) -> m t withPostgresqlPoolModifiedWithVersion getVerDouble modConn ci = do - withSqlPool (open' modConn (oldGetVersionToNew getVerDouble) id ci) + withSqlPool (open' modConn (oldGetVersionToNew getVerDouble) id ci) -- | Create a PostgreSQL connection pool. Note that it's your -- responsibility to properly close the connection pool when -- unneeded. Use 'withPostgresqlPool' for an automatic resource -- control. -createPostgresqlPool :: (MonadUnliftIO m, MonadLoggerIO m) - => ConnectionString - -- ^ Connection string to the database. - -> Int - -- ^ Number of connections to be kept open - -- in the pool. - -> m (Pool SqlBackend) +createPostgresqlPool + :: (MonadUnliftIO m, MonadLoggerIO m) + => ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open + -- in the pool. + -> m (Pool SqlBackend) createPostgresqlPool = createPostgresqlPoolModified (const $ return ()) -- | Same as 'createPostgresqlPool', but additionally takes a callback function @@ -257,9 +272,12 @@ createPostgresqlPool = createPostgresqlPoolModified (const $ return ()) -- @since 2.1.3 createPostgresqlPoolModified :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> m (Pool SqlBackend) createPostgresqlPoolModified = createPostgresqlPoolModifiedWithVersion getServerVersion @@ -270,10 +288,14 @@ createPostgresqlPoolModified = createPostgresqlPoolModifiedWithVersion getServer -- @since 2.6.2 createPostgresqlPoolModifiedWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO (Maybe Double)) -- ^ Action to perform to get the server version. - -> (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => (PG.Connection -> IO (Maybe Double)) + -- ^ Action to perform to get the server version. + -> (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> m (Pool SqlBackend) createPostgresqlPoolModifiedWithVersion = createPostgresqlPoolTailored open' @@ -286,62 +308,79 @@ createPostgresqlPoolModifiedWithVersion = createPostgresqlPoolTailored open' -- @since 2.13.6 createPostgresqlPoolTailored :: (MonadUnliftIO m, MonadLoggerIO m) - => - ( (PG.Connection -> IO ()) - -> (PG.Connection -> IO (NonEmpty Word)) - -> ((PG.Connection -> SqlBackend) -> PG.Connection -> SqlBackend) - -> ConnectionString -> LogFunc -> IO SqlBackend - ) -- ^ Action that creates a postgresql connection (please see documentation on the un-exported @open'@ function in this same module. - -> (PG.Connection -> IO (Maybe Double)) -- ^ Action to perform to get the server version. - -> (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => ( (PG.Connection -> IO ()) + -> (PG.Connection -> IO (NonEmpty Word)) + -> ((PG.Connection -> SqlBackend) -> PG.Connection -> SqlBackend) + -> ConnectionString + -> LogFunc + -> IO SqlBackend + ) + -- ^ Action that creates a postgresql connection (please see documentation on the un-exported @open'@ function in this same module. + -> (PG.Connection -> IO (Maybe Double)) + -- ^ Action to perform to get the server version. + -> (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> m (Pool SqlBackend) createPostgresqlPoolTailored createConnection getVerDouble modConn ci = do - let getVer = oldGetVersionToNew getVerDouble - createSqlPool $ createConnection modConn getVer id ci + let + getVer = oldGetVersionToNew getVerDouble + createSqlPool $ createConnection modConn getVer id ci -- | Same as 'createPostgresqlPool', but can be configured with 'PostgresConf' and 'PostgresConfHooks'. -- -- @since 2.11.0.0 createPostgresqlPoolWithConf :: (MonadUnliftIO m, MonadLoggerIO m) - => PostgresConf -- ^ Configuration for connecting to Postgres - -> PostgresConfHooks -- ^ Record of callback functions + => PostgresConf + -- ^ Configuration for connecting to Postgres + -> PostgresConfHooks + -- ^ Record of callback functions -> m (Pool SqlBackend) createPostgresqlPoolWithConf conf hooks = do - let getVer = pgConfHooksGetServerVersion hooks - modConn = pgConfHooksAfterCreate hooks - createSqlPoolWithConfig (open' modConn getVer id (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf) + let + getVer = pgConfHooksGetServerVersion hooks + modConn = pgConfHooksAfterCreate hooks + createSqlPoolWithConfig + (open' modConn getVer id (pgConnStr conf)) + (postgresConfToConnectionPoolConfig conf) postgresConfToConnectionPoolConfig :: PostgresConf -> ConnectionPoolConfig postgresConfToConnectionPoolConfig conf = - ConnectionPoolConfig - { connectionPoolConfigStripes = pgPoolStripes conf - , connectionPoolConfigIdleTimeout = fromInteger $ pgPoolIdleTimeout conf - , connectionPoolConfigSize = pgPoolSize conf - } + ConnectionPoolConfig + { connectionPoolConfigStripes = pgPoolStripes conf + , connectionPoolConfigIdleTimeout = fromInteger $ pgPoolIdleTimeout conf + , connectionPoolConfigSize = pgPoolSize conf + } -- | Same as 'withPostgresqlPool', but instead of opening a pool -- of connections, only one connection is opened. -- The provided action should use 'runSqlConn' and *not* 'runReaderT' because -- the former brackets the database action with transaction begin/commit. -withPostgresqlConn :: (MonadUnliftIO m, MonadLoggerIO m) - => ConnectionString -> (SqlBackend -> m a) -> m a +withPostgresqlConn + :: (MonadUnliftIO m, MonadLoggerIO m) + => ConnectionString + -> (SqlBackend -> m a) + -> m a withPostgresqlConn = withPostgresqlConnWithVersion getServerVersion -- | Same as 'withPostgresqlConn', but takes a callback for obtaining -- the server version (to work around an Amazon Redshift bug). -- -- @since 2.6.2 -withPostgresqlConnWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO (Maybe Double)) - -> ConnectionString - -> (SqlBackend -> m a) - -> m a +withPostgresqlConnWithVersion + :: (MonadUnliftIO m, MonadLoggerIO m) + => (PG.Connection -> IO (Maybe Double)) + -> ConnectionString + -> (SqlBackend -> m a) + -> m a withPostgresqlConnWithVersion getVerDouble = do - let getVer = oldGetVersionToNew getVerDouble - withSqlConn . open' (const $ return ()) getVer id + let + getVer = oldGetVersionToNew getVerDouble + withSqlConn . open' (const $ return ()) getVer id open' :: (PG.Connection -> IO ()) @@ -351,7 +390,9 @@ open' -- this is just 'id', since the desired backend type is 'SqlBackend'. -- But some callers want a @'RawPostgresql' 'SqlBackend'@, and will -- pass in 'withRawConnection'. - -> ConnectionString -> LogFunc -> IO backend + -> ConnectionString + -> LogFunc + -> IO backend open' modConn getVer constructor cstr logFunc = do conn <- PG.connectPostgreSQL cstr modConn conn @@ -364,25 +405,31 @@ open' modConn getVer constructor cstr logFunc = do -- @since 2.13.6 getServerVersion :: PG.Connection -> IO (Maybe Double) getServerVersion conn = do - [PG.Only version] <- PG.query_ conn "show server_version"; - let version' = rational version - --- λ> rational "9.8.3" - --- Right (9.8,".3") - --- λ> rational "9.8.3.5" - --- Right (9.8,".3.5") - case version' of - Right (a,_) -> return $ Just a - Left err -> throwIO $ PostgresServerVersionError err + [PG.Only version] <- PG.query_ conn "show server_version" + let + version' = rational version + --- λ> rational "9.8.3" + --- Right (9.8,".3") + --- λ> rational "9.8.3.5" + --- Right (9.8,".3.5") + case version' of + Right (a, _) -> return $ Just a + Left err -> throwIO $ PostgresServerVersionError err getServerVersionNonEmpty :: PG.Connection -> IO (NonEmpty Word) getServerVersionNonEmpty conn = do - [PG.Only version] <- PG.query_ conn "show server_version"; - case AT.parseOnly parseVersion (T.pack version) of - Left err -> throwIO $ PostgresServerVersionError $ "Parse failure on: " <> version <> ". Error: " <> err - Right versionComponents -> case NEL.nonEmpty versionComponents of - Nothing -> throwIO $ PostgresServerVersionError $ "Empty Postgres version string: " <> version - Just neVersion -> pure neVersion - + [PG.Only version] <- PG.query_ conn "show server_version" + case AT.parseOnly parseVersion (T.pack version) of + Left err -> + throwIO $ + PostgresServerVersionError $ + "Parse failure on: " <> version <> ". Error: " <> err + Right versionComponents -> case NEL.nonEmpty versionComponents of + Nothing -> + throwIO $ + PostgresServerVersionError $ + "Empty Postgres version string: " <> version + Just neVersion -> pure neVersion where -- Partially copied from the `versions` package -- Typically server_version gives e.g. 12.3 @@ -394,9 +441,10 @@ getServerVersionNonEmpty conn = do -- so depending upon that we have to choose how the sql query is generated. -- upsertFunction :: Double -> Maybe (EntityDef -> Text -> Text) upsertFunction :: a -> NonEmpty Word -> Maybe a -upsertFunction f version = if (version >= postgres9dot5) - then Just f - else Nothing +upsertFunction f version = + if (version >= postgres9dot5) + then Just f + else Nothing where postgres9dot5 :: NonEmpty Word @@ -408,14 +456,16 @@ postgres9dot5 = 9 NEL.:| [5] minimumPostgresVersion :: NonEmpty Word minimumPostgresVersion = 9 NEL.:| [4] -oldGetVersionToNew :: (PG.Connection -> IO (Maybe Double)) -> (PG.Connection -> IO (NonEmpty Word)) +oldGetVersionToNew + :: (PG.Connection -> IO (Maybe Double)) -> (PG.Connection -> IO (NonEmpty Word)) oldGetVersionToNew oldFn = \conn -> do - mDouble <- oldFn conn - case mDouble of - Nothing -> pure minimumPostgresVersion - Just double -> do - let (major, minor) = properFraction double - pure $ major NEL.:| [floor minor] + mDouble <- oldFn conn + case mDouble of + Nothing -> pure minimumPostgresVersion + Just double -> do + let + (major, minor) = properFraction double + pure $ major NEL.:| [floor minor] -- | Generate a 'SqlBackend' from a 'PG.Connection'. openSimpleConn :: LogFunc -> PG.Connection -> IO SqlBackend @@ -425,7 +475,11 @@ openSimpleConn = openSimpleConnWithVersion getServerVersion -- obtaining the server version. -- -- @since 2.9.1 -openSimpleConnWithVersion :: (PG.Connection -> IO (Maybe Double)) -> LogFunc -> PG.Connection -> IO SqlBackend +openSimpleConnWithVersion + :: (PG.Connection -> IO (Maybe Double)) + -> LogFunc + -> PG.Connection + -> IO SqlBackend openSimpleConnWithVersion getVerDouble logFunc conn = do smap <- newIORef mempty serverVersion <- oldGetVersionToNew getVerDouble conn @@ -439,53 +493,66 @@ underlyingConnectionKey = unsafePerformIO Vault.newKey -- provided isn't backed by postgresql-simple. -- -- @since 2.13.0 -getSimpleConn :: (BackendCompatible SqlBackend backend) => backend -> Maybe PG.Connection +getSimpleConn + :: (BackendCompatible SqlBackend backend) => backend -> Maybe PG.Connection getSimpleConn = Vault.lookup underlyingConnectionKey <$> getConnVault -- | Create the backend given a logging function, server version, mutable statement cell, -- and connection. -- -- @since 2.13.6 -createBackend :: LogFunc -> NonEmpty Word - -> IORef (Map.Map Text Statement) -> PG.Connection -> SqlBackend +createBackend + :: LogFunc + -> NonEmpty Word + -> IORef (Map.Map Text Statement) + -> PG.Connection + -> SqlBackend createBackend logFunc serverVersion smap conn = maybe id setConnPutManySql (upsertFunction putManySql serverVersion) $ - maybe id setConnUpsertSql (upsertFunction upsertSql' serverVersion) $ - setConnInsertManySql insertManySql' $ - maybe id setConnRepsertManySql (upsertFunction repsertManySql serverVersion) $ - modifyConnVault (Vault.insert underlyingConnectionKey conn) $ mkSqlBackend MkSqlBackendArgs - { connPrepare = prepare' conn - , connStmtMap = smap - , connInsertSql = insertSql' - , connClose = PG.close conn - , connMigrateSql = migrate' - , connBegin = \_ mIsolation -> case mIsolation of - Nothing -> PG.begin conn - Just iso -> PG.beginLevel (case iso of - ReadUncommitted -> PG.ReadCommitted -- PG Upgrades uncommitted reads to committed anyways - ReadCommitted -> PG.ReadCommitted - RepeatableRead -> PG.RepeatableRead - Serializable -> PG.Serializable) conn - , connCommit = const $ PG.commit conn - , connRollback = const $ PG.rollback conn - , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName - , connEscapeRawName = escape - , connNoLimit = "LIMIT ALL" - , connRDBMS = "postgresql" - , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL" - , connLogFunc = logFunc - } + maybe id setConnUpsertSql (upsertFunction upsertSql' serverVersion) $ + setConnInsertManySql insertManySql' $ + maybe id setConnRepsertManySql (upsertFunction repsertManySql serverVersion) $ + modifyConnVault (Vault.insert underlyingConnectionKey conn) $ + mkSqlBackend + MkSqlBackendArgs + { connPrepare = prepare' conn + , connStmtMap = smap + , connInsertSql = insertSql' + , connClose = PG.close conn + , connMigrateSql = migrate' + , connBegin = \_ mIsolation -> case mIsolation of + Nothing -> PG.begin conn + Just iso -> + PG.beginLevel + ( case iso of + ReadUncommitted -> PG.ReadCommitted -- PG Upgrades uncommitted reads to committed anyways + ReadCommitted -> PG.ReadCommitted + RepeatableRead -> PG.RepeatableRead + Serializable -> PG.Serializable + ) + conn + , connCommit = const $ PG.commit conn + , connRollback = const $ PG.rollback conn + , connEscapeFieldName = escapeF + , connEscapeTableName = escapeE . getEntityDBName + , connEscapeRawName = escape + , connNoLimit = "LIMIT ALL" + , connRDBMS = "postgresql" + , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL" + , connLogFunc = logFunc + } prepare' :: PG.Connection -> Text -> IO Statement prepare' conn sql = do - let query = PG.Query (T.encodeUtf8 sql) - return Statement - { stmtFinalize = return () - , stmtReset = return () - , stmtExecute = execute' conn query - , stmtQuery = withStmt' conn query - } + let + query = PG.Query (T.encodeUtf8 sql) + return + Statement + { stmtFinalize = return () + , stmtReset = return () + , stmtExecute = execute' conn query + , stmtQuery = withStmt' conn query + } insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult insertSql' ent vals = @@ -496,19 +563,21 @@ insertSql' ent vals = ISRSingle (sql <> " RETURNING " <> escapeF (fieldDB field)) where (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF) - sql = T.concat - [ "INSERT INTO " - , escapeE $ getEntityDBName ent - , if null (getEntityFields ent) - then " DEFAULT VALUES" - else T.concat - [ "(" - , T.intercalate "," fieldNames - , ") VALUES(" - , T.intercalate "," placeholders - , ")" - ] - ] + sql = + T.concat + [ "INSERT INTO " + , escapeE $ getEntityDBName ent + , if null (getEntityFields ent) + then " DEFAULT VALUES" + else + T.concat + [ "(" + , T.intercalate "," fieldNames + , ") VALUES(" + , T.intercalate "," placeholders + , ")" + ] + ] upsertSql' :: EntityDef -> NonEmpty (FieldNameHS, FieldNameDB) -> Text -> Text upsertSql' ent uniqs updateVal = @@ -540,61 +609,63 @@ insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult insertManySql' ent valss = ISRSingle sql where - (fieldNames, placeholders)= unzip (Util.mkInsertPlaceholders ent escapeF) - sql = T.concat - [ "INSERT INTO " - , escapeE (getEntityDBName ent) - , "(" - , T.intercalate "," fieldNames - , ") VALUES (" - , T.intercalate "),(" $ replicate (length valss) $ T.intercalate "," placeholders - , ") RETURNING " - , Util.commaSeparated $ NEL.toList $ Util.dbIdColumnsEsc escapeF ent - ] - + (fieldNames, placeholders) = unzip (Util.mkInsertPlaceholders ent escapeF) + sql = + T.concat + [ "INSERT INTO " + , escapeE (getEntityDBName ent) + , "(" + , T.intercalate "," fieldNames + , ") VALUES (" + , T.intercalate "),(" $ replicate (length valss) $ T.intercalate "," placeholders + , ") RETURNING " + , Util.commaSeparated $ NEL.toList $ Util.dbIdColumnsEsc escapeF ent + ] execute' :: PG.Connection -> PG.Query -> [PersistValue] -> IO Int64 execute' conn query vals = PG.execute conn query (map P vals) -withStmt' :: MonadIO m - => PG.Connection - -> PG.Query - -> [PersistValue] - -> Acquire (ConduitM () [PersistValue] m ()) +withStmt' + :: (MonadIO m) + => PG.Connection + -> PG.Query + -> [PersistValue] + -> Acquire (ConduitM () [PersistValue] m ()) withStmt' conn query vals = pull `fmap` mkAcquire openS closeS where openS = do - -- Construct raw query - rawquery <- PG.formatQuery conn query (map P vals) + -- Construct raw query + rawquery <- PG.formatQuery conn query (map P vals) - -- Take raw connection - (rt, rr, rc, ids) <- PG.withConnection conn $ \rawconn -> do + -- Take raw connection + (rt, rr, rc, ids) <- PG.withConnection conn $ \rawconn -> do -- Execute query mret <- LibPQ.exec rawconn rawquery case mret of - Nothing -> do - merr <- LibPQ.errorMessage rawconn - fail $ case merr of - Nothing -> "Postgresql.withStmt': unknown error" - Just e -> "Postgresql.withStmt': " ++ B8.unpack e - Just ret -> do - -- Check result status - status <- LibPQ.resultStatus ret - case status of - LibPQ.TuplesOk -> return () - _ -> PG.throwResultError "Postgresql.withStmt': bad result status " ret status - - -- Get number and type of columns - cols <- LibPQ.nfields ret - oids <- forM [0..cols-1] $ \col -> fmap ((,) col) (LibPQ.ftype ret col) - -- Ready to go! - rowRef <- newIORef (LibPQ.Row 0) - rowCount <- LibPQ.ntuples ret - return (ret, rowRef, rowCount, oids) - let getters - = map (\(col, oid) -> getGetter oid $ PG.Field rt col oid) ids - return (rt, rr, rc, getters) + Nothing -> do + merr <- LibPQ.errorMessage rawconn + fail $ case merr of + Nothing -> "Postgresql.withStmt': unknown error" + Just e -> "Postgresql.withStmt': " ++ B8.unpack e + Just ret -> do + -- Check result status + status <- LibPQ.resultStatus ret + case status of + LibPQ.TuplesOk -> return () + _ -> PG.throwResultError "Postgresql.withStmt': bad result status " ret status + + -- Get number and type of columns + cols <- LibPQ.nfields ret + oids <- forM [0 .. cols - 1] $ \col -> fmap ((,) col) (LibPQ.ftype ret col) + -- Ready to go! + rowRef <- newIORef (LibPQ.Row 0) + rowCount <- LibPQ.ntuples ret + return (ret, rowRef, rowCount, oids) + let + getters = + map (\(col, oid) -> getGetter oid $ PG.Field rt col oid) ids + return (rt, rr, rc, getters) closeS (ret, _, _, _) = LibPQ.unsafeFreeResult ret @@ -605,34 +676,36 @@ withStmt' conn query vals = Just z -> yield z >> pull x pullS (ret, rowRef, rowCount, getters) = do - row <- atomicModifyIORef rowRef (\r -> (r+1, r)) + row <- atomicModifyIORef rowRef (\r -> (r + 1, r)) if row == rowCount - then return Nothing - else fmap Just $ forM (zip getters [0..]) $ \(getter, col) -> do - mbs <- LibPQ.getvalue' ret row col - case mbs of - Nothing -> - -- getvalue' verified that the value is NULL. - -- However, that does not mean that there are - -- no NULL values inside the value (e.g., if - -- we're dealing with an array of optional values). - return PersistNull - Just bs -> do - ok <- PGFF.runConversion (getter mbs) conn - bs `seq` case ok of - Errors (exc:_) -> throw exc - Errors [] -> error "Got an Errors, but no exceptions" - Ok v -> return v - -doesTableExist :: (Text -> IO Statement) - -> EntityNameDB - -> IO Bool + then return Nothing + else fmap Just $ forM (zip getters [0 ..]) $ \(getter, col) -> do + mbs <- LibPQ.getvalue' ret row col + case mbs of + Nothing -> + -- getvalue' verified that the value is NULL. + -- However, that does not mean that there are + -- no NULL values inside the value (e.g., if + -- we're dealing with an array of optional values). + return PersistNull + Just bs -> do + ok <- PGFF.runConversion (getter mbs) conn + bs `seq` case ok of + Errors (exc : _) -> throw exc + Errors [] -> error "Got an Errors, but no exceptions" + Ok v -> return v + +doesTableExist + :: (Text -> IO Statement) + -> EntityNameDB + -> IO Bool doesTableExist getter (EntityNameDB name) = do stmt <- getter sql with (stmtQuery stmt vals) (\src -> runConduit $ src .| start) where - sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" - <> " AND schemaname != 'information_schema' AND tablename=?" + sql = + "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" + <> " AND schemaname != 'information_schema' AND tablename=?" vals = [PersistText name] start = await >>= maybe (error "No results when checking doesTableExist") start' @@ -641,10 +714,11 @@ doesTableExist getter (EntityNameDB name) = do start' res = error $ "doesTableExist returned unexpected result: " ++ show res finish x = await >>= maybe (return x) (error "Too many rows returned in doesTableExist") -migrate' :: [EntityDef] - -> (Text -> IO Statement) - -> EntityDef - -> IO (Either [Text] CautiousMigration) +migrate' + :: [EntityDef] + -> (Text -> IO Statement) + -> EntityDef + -> IO (Either [Text] CautiousMigration) migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do old <- getColumns getter entity newcols' case partitionEithers old of @@ -662,27 +736,28 @@ migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do | not exists' = createText newcols fdefs udspair | otherwise = - let (acs, ats) = + let + (acs, ats) = getAlters allDefs entity (newcols, udspair) old' acs' = map (AlterColumn name) acs ats' = map (AlterTable name) ats - in + in acs' ++ ats' - where - old' = partitionEithers old'' - newcols = filter (not . safeToRemove entity . cName) newcols' - udspair = map udToPair udefs - -- Check for table existence if there are no columns, workaround - -- for https://github.com/yesodweb/persistent/issues/152 + where + old' = partitionEithers old'' + newcols = filter (not . safeToRemove entity . cName) newcols' + udspair = map udToPair udefs + -- Check for table existence if there are no columns, workaround + -- for https://github.com/yesodweb/persistent/issues/152 createText newcols fdefs_ udspair = (addTable newcols entity) : uniques ++ references ++ foreignsAlt where uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] + [AlterTable name $ AddUniqueConstraint uname ucols] references = mapMaybe - (\Column { cName, cReference } -> + ( \Column{cName, cReference} -> getAddReference allDefs entity cName =<< cReference ) newcols @@ -705,60 +780,70 @@ mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference constraintName = foreignConstraintNameDBName fdef (childfields, parentfields) = - unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef)) + unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef)) escapedParentFields = map escapeF parentfields -- | Returns all of the columns in the given table currently in the database. -getColumns :: (Text -> IO Statement) - -> EntityDef -> [Column] - -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] +getColumns + :: (Text -> IO Statement) + -> EntityDef + -> [Column] + -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] getColumns getter def cols = do - let sqlv = T.concat - [ "SELECT " - , "column_name " - , ",is_nullable " - , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below - , ",column_default " - , ",generation_expression " - , ",numeric_precision " - , ",numeric_scale " - , ",character_maximum_length " - , "FROM information_schema.columns " - , "WHERE table_catalog=current_database() " - , "AND table_schema=current_schema() " - , "AND table_name=? " - ] + let + sqlv = + T.concat + [ "SELECT " + , "column_name " + , ",is_nullable " + , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below + , ",column_default " + , ",generation_expression " + , ",numeric_precision " + , ",numeric_scale " + , ",character_maximum_length " + , "FROM information_schema.columns " + , "WHERE table_catalog=current_database() " + , "AND table_schema=current_schema() " + , "AND table_name=? " + ] --- DOMAINS Postgres supports the concept of domains, which are data types --- with optional constraints. An app might make an "email" domain over the --- varchar type, with a CHECK that the emails are valid In this case the --- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN --- foo TYPE email This code exists to use the domain name (email), instead --- of the underlying type (varchar). This is tested in --- EquivalentTypeTest.hs + -- DOMAINS Postgres supports the concept of domains, which are data types + -- with optional constraints. An app might make an "email" domain over the + -- varchar type, with a CHECK that the emails are valid In this case the + -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN + -- foo TYPE email This code exists to use the domain name (email), instead + -- of the underlying type (varchar). This is tested in + -- EquivalentTypeTest.hs stmt <- getter sqlv - let vals = + let + vals = [ PersistText $ unEntityNameDB $ getEntityDBName def ] - columns <- with (stmtQuery stmt vals) (\src -> runConduit $ src .| processColumns .| CL.consume) - let sqlc = T.concat - [ "SELECT " - , "c.constraint_name, " - , "c.column_name " - , "FROM information_schema.key_column_usage AS c, " - , "information_schema.table_constraints AS k " - , "WHERE c.table_catalog=current_database() " - , "AND c.table_catalog=k.table_catalog " - , "AND c.table_schema=current_schema() " - , "AND c.table_schema=k.table_schema " - , "AND c.table_name=? " - , "AND c.table_name=k.table_name " - , "AND c.constraint_name=k.constraint_name " - , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " - , "ORDER BY c.constraint_name, c.column_name" - ] + columns <- + with + (stmtQuery stmt vals) + (\src -> runConduit $ src .| processColumns .| CL.consume) + let + sqlc = + T.concat + [ "SELECT " + , "c.constraint_name, " + , "c.column_name " + , "FROM information_schema.key_column_usage AS c, " + , "information_schema.table_constraints AS k " + , "WHERE c.table_catalog=current_database() " + , "AND c.table_catalog=k.table_catalog " + , "AND c.table_schema=current_schema() " + , "AND c.table_schema=k.table_schema " + , "AND c.table_name=? " + , "AND c.table_name=k.table_name " + , "AND c.constraint_name=k.constraint_name " + , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " + , "ORDER BY c.constraint_name, c.column_name" + ] stmt' <- getter sqlc @@ -766,9 +851,9 @@ getColumns getter def cols = do return $ columns ++ us where refMap = - fmap (\cr -> (crTableName cr, crConstraintName cr)) - $ Map.fromList - $ List.foldl' ref [] cols + fmap (\cr -> (crTableName cr, crConstraintName cr)) $ + Map.fromList $ + List.foldl' ref [] cols where ref rs c = maybe rs (\r -> (unFieldNameDB $ cName c, r) : rs) (cReference c) @@ -780,199 +865,220 @@ getColumns getter def cols = do [PersistByteString con, PersistByteString col] -> (T.decodeUtf8 con, T.decodeUtf8 col) o -> - error $ "unexpected datatype returned for postgres o="++show o + error $ "unexpected datatype returned for postgres o=" ++ show o helperU = do rows <- getAll .| CL.consume - return $ map (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) - $ groupBy ((==) `on` fst) rows + return $ + map + (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) $ + groupBy ((==) `on` fst) rows processColumns = CL.mapM $ \x'@((PersistText cname) : _) -> do - col <- liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) + col <- + liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) pure $ case col of Left e -> Left e Right c -> Right $ Left c - getColumn :: (Text -> IO Statement) -> EntityNameDB -> [PersistValue] -> Maybe (EntityNameDB, ConstraintNameDB) -> IO (Either Text Column) -getColumn getter tableName' [ PersistText columnName - , PersistText isNullable - , PersistText typeName - , defaultValue - , generationExpression - , numericPrecision - , numericScale - , maxlen - ] refName_ = runExceptT $ do - defaultValue' <- - case defaultValue of - PersistNull -> - pure Nothing - PersistText t -> - pure $ Just t - _ -> - throwError $ T.pack $ "Invalid default column: " ++ show defaultValue - - generationExpression' <- - case generationExpression of - PersistNull -> - pure Nothing - PersistText t -> - pure $ Just t - _ -> - throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression - - let typeStr = - case maxlen of - PersistInt64 n -> - T.concat [typeName, "(", T.pack (show n), ")"] +getColumn + getter + tableName' + [ PersistText columnName + , PersistText isNullable + , PersistText typeName + , defaultValue + , generationExpression + , numericPrecision + , numericScale + , maxlen + ] + refName_ = runExceptT $ do + defaultValue' <- + case defaultValue of + PersistNull -> + pure Nothing + PersistText t -> + pure $ Just t _ -> - typeName - - t <- getType typeStr - - let cname = FieldNameDB columnName - - ref <- lift $ fmap join $ traverse (getRef cname) refName_ - - return Column - { cName = cname - , cNull = isNullable == "YES" - , cSqlType = t - , cDefault = fmap stripSuffixes defaultValue' - , cGenerated = fmap stripSuffixes generationExpression' - , cDefaultConstraintName = Nothing - , cMaxLen = Nothing - , cReference = fmap (\(a,b,c,d) -> ColumnReference a b (mkCascade c d)) ref - } - - where + throwError $ T.pack $ "Invalid default column: " ++ show defaultValue + + generationExpression' <- + case generationExpression of + PersistNull -> + pure Nothing + PersistText t -> + pure $ Just t + _ -> + throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression + + let + typeStr = + case maxlen of + PersistInt64 n -> + T.concat [typeName, "(", T.pack (show n), ")"] + _ -> + typeName + + t <- getType typeStr + + let + cname = FieldNameDB columnName + + ref <- lift $ fmap join $ traverse (getRef cname) refName_ + + return + Column + { cName = cname + , cNull = isNullable == "YES" + , cSqlType = t + , cDefault = fmap stripSuffixes defaultValue' + , cGenerated = fmap stripSuffixes generationExpression' + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = fmap (\(a, b, c, d) -> ColumnReference a b (mkCascade c d)) ref + } + where + mkCascade updText delText = + FieldCascade + { fcOnUpdate = parseCascade updText + , fcOnDelete = parseCascade delText + } - mkCascade updText delText = - FieldCascade - { fcOnUpdate = parseCascade updText - , fcOnDelete = parseCascade delText - } + parseCascade txt = + case txt of + "NO ACTION" -> + Nothing + "CASCADE" -> + Just Cascade + "SET NULL" -> + Just SetNull + "SET DEFAULT" -> + Just SetDefault + "RESTRICT" -> + Just Restrict + _ -> + error $ "Unexpected value in parseCascade: " <> show txt - parseCascade txt = - case txt of - "NO ACTION" -> - Nothing - "CASCADE" -> - Just Cascade - "SET NULL" -> - Just SetNull - "SET DEFAULT" -> - Just SetDefault - "RESTRICT" -> - Just Restrict - _ -> - error $ "Unexpected value in parseCascade: " <> show txt - - stripSuffixes t = - loop' - [ "::character varying" - , "::text" - ] - where - loop' [] = t - loop' (p:ps) = - case T.stripSuffix p t of - Nothing -> loop' ps - Just t' -> t' - - getRef cname (_, refName') = do - let sql = T.concat - [ "SELECT DISTINCT " - , "ccu.table_name, " - , "tc.constraint_name, " - , "rc.update_rule, " - , "rc.delete_rule " - , "FROM information_schema.constraint_column_usage ccu " - , "INNER JOIN information_schema.key_column_usage kcu " - , " ON ccu.constraint_name = kcu.constraint_name " - , "INNER JOIN information_schema.table_constraints tc " - , " ON tc.constraint_name = kcu.constraint_name " - , "LEFT JOIN information_schema.referential_constraints AS rc" - , " ON rc.constraint_name = ccu.constraint_name " - , "WHERE tc.constraint_type='FOREIGN KEY' " - , "AND kcu.ordinal_position=1 " - , "AND kcu.table_name=? " - , "AND kcu.column_name=? " - , "AND tc.constraint_name=?" + stripSuffixes t = + loop' + [ "::character varying" + , "::text" ] - stmt <- getter sql - cntrs <- - with - (stmtQuery stmt - [ PersistText $ unEntityNameDB tableName' - , PersistText $ unFieldNameDB cname - , PersistText $ unConstraintNameDB refName' + where + loop' [] = t + loop' (p : ps) = + case T.stripSuffix p t of + Nothing -> loop' ps + Just t' -> t' + + getRef cname (_, refName') = do + let + sql = + T.concat + [ "SELECT DISTINCT " + , "ccu.table_name, " + , "tc.constraint_name, " + , "rc.update_rule, " + , "rc.delete_rule " + , "FROM information_schema.constraint_column_usage ccu " + , "INNER JOIN information_schema.key_column_usage kcu " + , " ON ccu.constraint_name = kcu.constraint_name " + , "INNER JOIN information_schema.table_constraints tc " + , " ON tc.constraint_name = kcu.constraint_name " + , "LEFT JOIN information_schema.referential_constraints AS rc" + , " ON rc.constraint_name = ccu.constraint_name " + , "WHERE tc.constraint_type='FOREIGN KEY' " + , "AND kcu.ordinal_position=1 " + , "AND kcu.table_name=? " + , "AND kcu.column_name=? " + , "AND tc.constraint_name=?" + ] + stmt <- getter sql + cntrs <- + with + ( stmtQuery + stmt + [ PersistText $ unEntityNameDB tableName' + , PersistText $ unFieldNameDB cname + , PersistText $ unConstraintNameDB refName' + ] + ) + (\src -> runConduit $ src .| CL.consume) + case cntrs of + [] -> + return Nothing + [ [ PersistText table + , PersistText constraint + , PersistText updRule + , PersistText delRule + ] + ] -> + return $ + Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) + xs -> + error $ + mconcat + [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " + , T.unpack (unEntityNameDB tableName') + , " and column: " + , T.unpack (unFieldNameDB cname) + , " but got: " + , show xs + ] + + getType "int4" = pure SqlInt32 + getType "int8" = pure SqlInt64 + getType "varchar" = pure SqlString + getType "text" = pure SqlString + getType "date" = pure SqlDay + getType "bool" = pure SqlBool + getType "timestamptz" = pure SqlDayTime + getType "float4" = pure SqlReal + getType "float8" = pure SqlReal + getType "bytea" = pure SqlBlob + getType "time" = pure SqlTime + getType "numeric" = getNumeric numericPrecision numericScale + getType a = pure $ SqlOther a + + getNumeric (PersistInt64 a) (PersistInt64 b) = + pure $ SqlNumeric (fromIntegral a) (fromIntegral b) + getNumeric PersistNull PersistNull = + throwError $ + T.concat + [ "No precision and scale were specified for the column: " + , columnName + , " in table: " + , unEntityNameDB tableName' + , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383," + , " which is probably not what you intended." + , " Specify the values as numeric(total_digits, digits_after_decimal_place)." + ] + getNumeric a b = + throwError $ + T.concat + [ "Can not get numeric field precision for the column: " + , columnName + , " in table: " + , unEntityNameDB tableName' + , ". Expected an integer for both precision and scale, " + , "got: " + , T.pack $ show a + , " and " + , T.pack $ show b + , ", respectively." + , " Specify the values as numeric(total_digits, digits_after_decimal_place)." ] - ) - (\src -> runConduit $ src .| CL.consume) - case cntrs of - [] -> - return Nothing - [[PersistText table, PersistText constraint, PersistText updRule, PersistText delRule]] -> - return $ Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) - xs -> - error $ mconcat - [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " - , T.unpack (unEntityNameDB tableName') - , " and column: " - , T.unpack (unFieldNameDB cname) - , " but got: " - , show xs - ] - - getType "int4" = pure SqlInt32 - getType "int8" = pure SqlInt64 - getType "varchar" = pure SqlString - getType "text" = pure SqlString - getType "date" = pure SqlDay - getType "bool" = pure SqlBool - getType "timestamptz" = pure SqlDayTime - getType "float4" = pure SqlReal - getType "float8" = pure SqlReal - getType "bytea" = pure SqlBlob - getType "time" = pure SqlTime - getType "numeric" = getNumeric numericPrecision numericScale - getType a = pure $ SqlOther a - - getNumeric (PersistInt64 a) (PersistInt64 b) = - pure $ SqlNumeric (fromIntegral a) (fromIntegral b) - - getNumeric PersistNull PersistNull = throwError $ T.concat - [ "No precision and scale were specified for the column: " - , columnName - , " in table: " - , unEntityNameDB tableName' - , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383," - , " which is probably not what you intended." - , " Specify the values as numeric(total_digits, digits_after_decimal_place)." - ] - - getNumeric a b = throwError $ T.concat - [ "Can not get numeric field precision for the column: " - , columnName - , " in table: " - , unEntityNameDB tableName' - , ". Expected an integer for both precision and scale, " - , "got: " - , T.pack $ show a - , " and " - , T.pack $ show b - , ", respectively." - , " Specify the values as numeric(total_digits, digits_after_decimal_place)." - ] - getColumn _ _ columnName _ = - return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show columnName + return $ + Left $ + T.pack $ + "Invalid result from information_schema: " ++ show columnName -- | Get the SQL string for the table that a PersistEntity represents. -- Useful for raw SQL queries. @@ -988,47 +1094,55 @@ fieldName = escapeF . fieldDBName -- using @persistent@'s generic facilities. These values are the -- same that are given to 'withPostgresqlPool'. data PostgresConf = PostgresConf - { pgConnStr :: ConnectionString - -- ^ The connection string. - - -- TODO: Currently stripes, idle timeout, and pool size are all separate fields - -- When Persistent next does a large breaking release (3.0?), we should consider making these just a single ConnectionPoolConfig value - -- - -- Currently there the idle timeout is an Integer, rather than resource-pool's NominalDiffTime type. - -- This is because the time package only recently added the Read instance for NominalDiffTime. - -- Future TODO: Consider removing the Read instance, and/or making the idle timeout a NominalDiffTime. + { pgConnStr :: ConnectionString + -- ^ The connection string. + , -- TODO: Currently stripes, idle timeout, and pool size are all separate fields + -- When Persistent next does a large breaking release (3.0?), we should consider making these just a single ConnectionPoolConfig value + -- + -- Currently there the idle timeout is an Integer, rather than resource-pool's NominalDiffTime type. + -- This is because the time package only recently added the Read instance for NominalDiffTime. + -- Future TODO: Consider removing the Read instance, and/or making the idle timeout a NominalDiffTime. - , pgPoolStripes :: Int + pgPoolStripes :: Int -- ^ How many stripes to divide the pool into. See "Data.Pool" for details. -- @since 2.11.0.0 , pgPoolIdleTimeout :: Integer -- Ideally this would be a NominalDiffTime, but that type lacks a Read instance https://github.com/haskell/time/issues/130 + -- ^ How long connections can remain idle before being disposed of, in seconds. -- @since 2.11.0.0 , pgPoolSize :: Int - -- ^ How many connections should be held in the connection pool. - } deriving (Show, Read, Data) + -- ^ How many connections should be held in the connection pool. + } + deriving (Show, Read, Data) instance FromJSON PostgresConf where parseJSON v = modifyFailure ("Persistent: error loading PostgreSQL conf: " ++) $ - flip (withObject "PostgresConf") v $ \o -> do - let defaultPoolConfig = defaultConnectionPoolConfig - database <- o .: "database" - host <- o .: "host" - port <- o .:? "port" .!= 5432 - user <- o .: "user" - password <- o .: "password" - poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig) - poolStripes <- o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig) - poolIdleTimeout <- o .:? "idleTimeout" .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig) - let ci = PG.ConnectInfo - { PG.connectHost = host - , PG.connectPort = port - , PG.connectUser = user - , PG.connectPassword = password - , PG.connectDatabase = database - } - cstr = PG.postgreSQLConnectionString ci - return $ PostgresConf cstr poolStripes poolIdleTimeout poolSize + flip (withObject "PostgresConf") v $ \o -> do + let + defaultPoolConfig = defaultConnectionPoolConfig + database <- o .: "database" + host <- o .: "host" + port <- o .:? "port" .!= 5432 + user <- o .: "user" + password <- o .: "password" + poolSize <- o .:? "poolsize" .!= (connectionPoolConfigSize defaultPoolConfig) + poolStripes <- + o .:? "stripes" .!= (connectionPoolConfigStripes defaultPoolConfig) + poolIdleTimeout <- + o + .:? "idleTimeout" + .!= (floor $ connectionPoolConfigIdleTimeout defaultPoolConfig) + let + ci = + PG.ConnectInfo + { PG.connectHost = host + , PG.connectPort = port + , PG.connectUser = user + , PG.connectPassword = password + , PG.connectDatabase = database + } + cstr = PG.postgreSQLConnectionString ci + return $ PostgresConf cstr poolStripes poolIdleTimeout poolSize instance PersistConfig PostgresConf where type PersistConfigBackend PostgresConf = SqlPersistT type PersistConfigPool PostgresConf = ConnectionPool @@ -1038,67 +1152,70 @@ instance PersistConfig PostgresConf where applyEnv c0 = do env <- getEnvironment - return $ addUser env - $ addPass env - $ addDatabase env - $ addPort env - $ addHost env c0 + return $ + addUser env $ + addPass env $ + addDatabase env $ + addPort env $ + addHost env c0 where addParam param val c = - c { pgConnStr = B8.concat [pgConnStr c, " ", param, "='", pgescape val, "'"] } + c{pgConnStr = B8.concat [pgConnStr c, " ", param, "='", pgescape val, "'"]} pgescape = B8.pack . go - where - go ('\'':rest) = '\\' : '\'' : go rest - go ('\\':rest) = '\\' : '\\' : go rest - go ( x :rest) = x : go rest - go [] = [] + where + go ('\'' : rest) = '\\' : '\'' : go rest + go ('\\' : rest) = '\\' : '\\' : go rest + go (x : rest) = x : go rest + go [] = [] maybeAddParam param envvar env = maybe id (addParam param) $ - lookup envvar env + lookup envvar env - addHost = maybeAddParam "host" "PGHOST" - addPort = maybeAddParam "port" "PGPORT" - addUser = maybeAddParam "user" "PGUSER" - addPass = maybeAddParam "password" "PGPASS" - addDatabase = maybeAddParam "dbname" "PGDATABASE" + addHost = maybeAddParam "host" "PGHOST" + addPort = maybeAddParam "port" "PGPORT" + addUser = maybeAddParam "user" "PGUSER" + addPass = maybeAddParam "password" "PGPASS" + addDatabase = maybeAddParam "dbname" "PGDATABASE" -- | Hooks for configuring the Persistent/its connection to Postgres -- -- @since 2.11.0 data PostgresConfHooks = PostgresConfHooks - { pgConfHooksGetServerVersion :: PG.Connection -> IO (NonEmpty Word) - -- ^ Function to get the version of Postgres - -- - -- The default implementation queries the server with "show server_version". - -- Some variants of Postgres, such as Redshift, don't support showing the version. - -- It's recommended you return a hardcoded version in those cases. - -- - -- @since 2.11.0 - , pgConfHooksAfterCreate :: PG.Connection -> IO () - -- ^ Action to perform after a connection is created. - -- - -- Typical uses of this are modifying the connection (e.g. to set the schema) or logging a connection being created. - -- - -- The default implementation does nothing. - -- - -- @since 2.11.0 - } + { pgConfHooksGetServerVersion :: PG.Connection -> IO (NonEmpty Word) + -- ^ Function to get the version of Postgres + -- + -- The default implementation queries the server with "show server_version". + -- Some variants of Postgres, such as Redshift, don't support showing the version. + -- It's recommended you return a hardcoded version in those cases. + -- + -- @since 2.11.0 + , pgConfHooksAfterCreate :: PG.Connection -> IO () + -- ^ Action to perform after a connection is created. + -- + -- Typical uses of this are modifying the connection (e.g. to set the schema) or logging a connection being created. + -- + -- The default implementation does nothing. + -- + -- @since 2.11.0 + } -- | Default settings for 'PostgresConfHooks'. See the individual fields of 'PostgresConfHooks' for the default values. -- -- @since 2.11.0 defaultPostgresConfHooks :: PostgresConfHooks -defaultPostgresConfHooks = PostgresConfHooks - { pgConfHooksGetServerVersion = getServerVersionNonEmpty - , pgConfHooksAfterCreate = const $ pure () - } - -mockMigrate :: [EntityDef] - -> (Text -> IO Statement) - -> EntityDef - -> IO (Either [Text] [(Bool, Text)]) +defaultPostgresConfHooks = + PostgresConfHooks + { pgConfHooksGetServerVersion = getServerVersionNonEmpty + , pgConfHooksAfterCreate = const $ pure () + } + +mockMigrate + :: [EntityDef] + -> (Text -> IO Statement) + -> EntityDef + -> IO (Either [Text] [(Bool, Text)]) mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ mockMigrateStructured allDefs entity -- | Mock a migration even when the database is not present. @@ -1107,30 +1224,33 @@ mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ mockMigrateStruct mockMigration :: Migration -> IO () mockMigration mig = do smap <- newIORef mempty - let sqlbackend = - mkSqlBackend MkSqlBackendArgs - { connPrepare = \_ -> do - return Statement - { stmtFinalize = return () - , stmtReset = return () - , stmtExecute = undefined - , stmtQuery = \_ -> return $ return () - } - , connInsertSql = undefined - , connStmtMap = smap - , connClose = undefined - , connMigrateSql = mockMigrate - , connBegin = undefined - , connCommit = undefined - , connRollback = undefined - , connEscapeFieldName = escapeF - , connEscapeTableName = escapeE . getEntityDBName - , connEscapeRawName = escape - , connNoLimit = undefined - , connRDBMS = undefined - , connLimitOffset = undefined - , connLogFunc = undefined - } + let + sqlbackend = + mkSqlBackend + MkSqlBackendArgs + { connPrepare = \_ -> do + return + Statement + { stmtFinalize = return () + , stmtReset = return () + , stmtExecute = undefined + , stmtQuery = \_ -> return $ return () + } + , connInsertSql = undefined + , connStmtMap = smap + , connClose = undefined + , connMigrateSql = mockMigrate + , connBegin = undefined + , connCommit = undefined + , connRollback = undefined + , connEscapeFieldName = escapeF + , connEscapeTableName = escapeE . getEntityDBName + , connEscapeRawName = escape + , connNoLimit = undefined + , connRDBMS = undefined + , connLimitOffset = undefined + , connLogFunc = undefined + } result = runReaderT $ runWriterT $ runWriterT mig resp <- result sqlbackend mapM_ T.putStrLn $ map snd $ snd resp @@ -1139,7 +1259,10 @@ putManySql :: EntityDef -> Int -> Text putManySql ent n = putManySql' conflictColumns fields ent n where fields = getEntityFields ent - conflictColumns = concatMap (map (escapeF . snd) . NEL.toList . uniqueFields) (getEntityUniques ent) + conflictColumns = + concatMap + (map (escapeF . snd) . NEL.toList . uniqueFields) + (getEntityUniques ent) repsertManySql :: EntityDef -> Int -> Text repsertManySql ent n = putManySql' conflictColumns fields ent n @@ -1153,16 +1276,23 @@ repsertManySql ent n = putManySql' conflictColumns fields ent n -- -- @since 2.12.1.0 data HandleUpdateCollision record where - -- | Copy the field directly from the record. - CopyField :: EntityField record typ -> HandleUpdateCollision record - -- | Only copy the field if it is not equal to the provided value. - CopyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record + -- | Copy the field directly from the record. + CopyField :: EntityField record typ -> HandleUpdateCollision record + -- | Only copy the field if it is not equal to the provided value. + CopyUnlessEq + :: (PersistField typ) + => EntityField record typ + -> typ + -> HandleUpdateCollision record -- | Copy the field into the database only if the value in the -- corresponding record is non-@NULL@. -- -- @since 2.12.1.0 -copyUnlessNull :: PersistField typ => EntityField record (Maybe typ) -> HandleUpdateCollision record +copyUnlessNull + :: (PersistField typ) + => EntityField record (Maybe typ) + -> HandleUpdateCollision record copyUnlessNull field = CopyUnlessEq field Nothing -- | Copy the field into the database only if the value in the @@ -1173,7 +1303,10 @@ copyUnlessNull field = CopyUnlessEq field Nothing -- 'upsertManyWhere' function. -- -- @since 2.12.1.0 -copyUnlessEmpty :: (Monoid.Monoid typ, PersistField typ) => EntityField record typ -> HandleUpdateCollision record +copyUnlessEmpty + :: (Monoid.Monoid typ, PersistField typ) + => EntityField record typ + -> HandleUpdateCollision record copyUnlessEmpty field = CopyUnlessEq field Monoid.mempty -- | Copy the field into the database only if the field is not equal to the @@ -1184,13 +1317,18 @@ copyUnlessEmpty field = CopyUnlessEq field Monoid.mempty -- 'upsertMany' function. -- -- @since 2.12.1.0 -copyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record +copyUnlessEq + :: (PersistField typ) + => EntityField record typ + -> typ + -> HandleUpdateCollision record copyUnlessEq = CopyUnlessEq -- | Copy the field directly from the record. -- -- @since 2.12.1.0 -copyField :: PersistField typ => EntityField record typ -> HandleUpdateCollision record +copyField + :: (PersistField typ) => EntityField record typ -> HandleUpdateCollision record copyField = CopyField -- | Postgres specific 'upsertWhere'. This method does the following: @@ -1208,20 +1346,20 @@ copyField = CopyField -- -- @since 2.12.1.0 upsertWhere - :: ( backend ~ PersistEntityBackend record - , PersistEntity record - , PersistEntityBackend record ~ SqlBackend - , MonadIO m - , PersistStore backend - , BackendCompatible SqlBackend backend - , OnlyOneUniqueKey record - ) - => record - -> [Update record] - -> [Filter record] - -> ReaderT backend m () + :: ( backend ~ PersistEntityBackend record + , PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , MonadIO m + , PersistStore backend + , BackendCompatible SqlBackend backend + , OnlyOneUniqueKey record + ) + => record + -> [Update record] + -> [Filter record] + -> ReaderT backend m () upsertWhere record updates filts = - upsertManyWhere [record] [] updates filts + upsertManyWhere [record] [] updates filts -- | Postgres specific 'upsertManyWhere'. -- @@ -1255,20 +1393,20 @@ upsertWhere record updates filts = -- -- @ -- upsertManyWhere --- [record] -- (1) --- [ copyField recordField1 -- (2) --- , copyUnlessEq recordField2 -- (3) --- , copyUnlessNull recordField3 -- (4) +-- [record] -- (1) +-- [ copyField recordField1 -- (2) +-- , copyUnlessEq recordField2 -- (3) +-- , copyUnlessNull recordField3 -- (4) -- ] --- [recordField4 =. arbitraryValue] -- (5) --- [recordField4 !=. anotherValue] -- (6) +-- [recordField4 =. arbitraryValue] -- (5) +-- [recordField4 !=. anotherValue] -- (6) -- @ --- +-- -- 1. new records to insert if there's no conflicts -- 2. for each conflicting existing row, replace the value of recordField1 with the one present in the conflicting new record -- 3. only replace the existing value if it's different from the one present in the conflicting new record -- 4. only replace the existing value if the new value is non-NULL (i.e. don't replace existing values with NULLs.) --- +-- -- 5. update recordField4 with an arbitrary new value -- 6. only apply the above updates for conflicting rows that meet this condition -- @@ -1298,14 +1436,14 @@ upsertWhere record updates filts = -- -- @since 2.12.1.0 upsertManyWhere - :: forall record backend m. - ( backend ~ PersistEntityBackend record - , BackendCompatible SqlBackend backend - , PersistEntityBackend record ~ SqlBackend - , PersistEntity record - , OnlyOneUniqueKey record - , MonadIO m - ) + :: forall record backend m + . ( backend ~ PersistEntityBackend record + , BackendCompatible SqlBackend backend + , PersistEntityBackend record ~ SqlBackend + , PersistEntity record + , OnlyOneUniqueKey record + , MonadIO m + ) => [record] -- ^ A list of the records you want to insert, or update -> [HandleUpdateCollision record] @@ -1319,7 +1457,8 @@ upsertManyWhere upsertManyWhere [] _ _ _ = return () upsertManyWhere records fieldValues updates filters = do conn <- asks projectBackend - let uniqDef = onlyOneUniqueDef (Proxy :: Proxy record) + let + uniqDef = onlyOneUniqueDef (Proxy :: Proxy record) uncurry rawExecute $ mkBulkUpsertQuery records conn fieldValues updates filters uniqDef @@ -1348,16 +1487,19 @@ excludeNotEqualToOriginal field = } where bsForExcludedField = - T.encodeUtf8 - $ "EXCLUDED." - <> fieldName field + T.encodeUtf8 $ + "EXCLUDED." + <> fieldName field -- | This creates the query for 'upsertManyWhere'. If you -- provide an empty list of updates to perform, then it will generate -- a dummy/no-op update using the first field of the record. This avoids -- duplicate key exceptions. mkBulkUpsertQuery - :: (PersistEntity record, PersistEntityBackend record ~ SqlBackend, OnlyOneUniqueKey record) + :: ( PersistEntity record + , PersistEntityBackend record ~ SqlBackend + , OnlyOneUniqueKey record + ) => [record] -- ^ A list of the records you want to insert, or update -> SqlBackend @@ -1373,7 +1515,7 @@ mkBulkUpsertQuery -- a catch-all. How frustrating! -> (Text, [PersistValue]) mkBulkUpsertQuery records conn fieldValues updates filters uniqDef = - (q, recordValues <> updsValues <> copyUnlessValues <> whereVals) + (q, recordValues <> updsValues <> copyUnlessValues <> whereVals) where mfieldDef x = case x of CopyField rec -> Right (fieldDbToText (persistFieldDef rec)) @@ -1385,39 +1527,43 @@ mkBulkUpsertQuery records conn fieldValues updates filters uniqDef = map (escapeF . snd) $ NEL.toList $ uniqueFields uniqDef firstField = case entityFieldNames of [] -> error "The entity you're trying to insert does not have any fields." - (field:_) -> field + (field : _) -> field entityFieldNames = map fieldDbToText (getEntityFields entityDef') nameOfTable = escapeE . getEntityDBName $ entityDef' copyUnlessValues = map snd fieldsToMaybeCopy recordValues = concatMap (map toPersistValue . toPersistFields) records recordPlaceholders = - Util.commaSeparated - $ map (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) - $ records + Util.commaSeparated $ + map + (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) $ + records mkCondFieldSet n _ = T.concat [ n , "=COALESCE(" - , "NULLIF(" - , "EXCLUDED." - , n - , "," - , "?" - , ")" - , "," - , nameOfTable - , "." - , n - ,")" + , "NULLIF(" + , "EXCLUDED." + , n + , "," + , "?" + , ")" + , "," + , nameOfTable + , "." + , n + , ")" ] condFieldSets = map (uncurry mkCondFieldSet) fieldsToMaybeCopy fieldSets = map (\n -> T.concat [n, "=EXCLUDED.", n, ""]) updateFieldNames - upds = map (Util.mkUpdateText' (escapeF) (\n -> T.concat [nameOfTable, ".", n])) updates + upds = + map + (Util.mkUpdateText' (escapeF) (\n -> T.concat [nameOfTable, ".", n])) + updates updsValues = map (\(Update _ val _) -> toPersistValue val) updates (wher, whereVals) = if null filters - then ("", []) - else (filterClauseWithVals (Just PrefixTableName) conn filters) + then ("", []) + else (filterClauseWithVals (Just PrefixTableName) conn filters) updateText = case fieldSets <> upds <> condFieldSets of [] -> @@ -1430,18 +1576,19 @@ mkBulkUpsertQuery records conn fieldValues updates filters uniqDef = T.concat [firstField, "=", nameOfTable, ".", firstField] xs -> Util.commaSeparated xs - q = T.concat - [ "INSERT INTO " - , nameOfTable - , Util.parenWrapped . Util.commaSeparated $ entityFieldNames - , " VALUES " - , recordPlaceholders - , " ON CONFLICT " - , Util.parenWrapped $ Util.commaSeparated $ conflictColumns - , " DO UPDATE SET " - , updateText - , wher - ] + q = + T.concat + [ "INSERT INTO " + , nameOfTable + , Util.parenWrapped . Util.commaSeparated $ entityFieldNames + , " VALUES " + , recordPlaceholders + , " ON CONFLICT " + , Util.parenWrapped $ Util.commaSeparated $ conflictColumns + , " DO UPDATE SET " + , updateText + , wher + ] putManySql' :: [Text] -> [FieldDef] -> EntityDef -> Int -> Text putManySql' conflictColumns (filter isFieldNotGenerated -> fields) ent n = q @@ -1454,29 +1601,34 @@ putManySql' conflictColumns (filter isFieldNotGenerated -> fields) ent n = q placeholders = map (const "?") fields updates = map (mkAssignment . fieldDbToText) fields - q = T.concat - [ "INSERT INTO " - , table - , Util.parenWrapped columns - , " VALUES " - , Util.commaSeparated . replicate n - . Util.parenWrapped . Util.commaSeparated $ placeholders - , " ON CONFLICT " - , Util.parenWrapped . Util.commaSeparated $ conflictColumns - , " DO UPDATE SET " - , Util.commaSeparated updates - ] - + q = + T.concat + [ "INSERT INTO " + , table + , Util.parenWrapped columns + , " VALUES " + , Util.commaSeparated + . replicate n + . Util.parenWrapped + . Util.commaSeparated + $ placeholders + , " ON CONFLICT " + , Util.parenWrapped . Util.commaSeparated $ conflictColumns + , " DO UPDATE SET " + , Util.commaSeparated updates + ] -- | Enable a Postgres extension. See https://www.postgresql.org/docs/current/static/contrib.html -- for a list. migrateEnableExtension :: Text -> Migration migrateEnableExtension extName = WriterT $ WriterT $ do - res :: [Single Int] <- - rawSql "SELECT COUNT(*) FROM pg_catalog.pg_extension WHERE extname = ?" [PersistText extName] - if res == [Single 0] - then return (((), []) , [(False, "CREATe EXTENSION \"" <> extName <> "\"")]) - else return (((), []), []) + res :: [Single Int] <- + rawSql + "SELECT COUNT(*) FROM pg_catalog.pg_extension WHERE extname = ?" + [PersistText extName] + if res == [Single 0] + then return (((), []), [(False, "CREATe EXTENSION \"" <> extName <> "\"")]) + else return (((), []), []) -- | Wrapper for persistent SqlBackends that carry the corresponding -- `Postgresql.Connection`. @@ -1503,22 +1655,24 @@ withRawConnection :: (PG.Connection -> SqlBackend) -> PG.Connection -> RawPostgresql SqlBackend -withRawConnection f conn = RawPostgresql - { persistentBackend = f conn - , rawPostgresqlConnection = conn - } +withRawConnection f conn = + RawPostgresql + { persistentBackend = f conn + , rawPostgresqlConnection = conn + } -- | Create a PostgreSQL connection pool which also exposes the -- raw connection. The raw counterpart to 'createPostgresqlPool'. -- -- @since 2.13.1.0 -createRawPostgresqlPool :: (MonadUnliftIO m, MonadLoggerIO m) - => ConnectionString - -- ^ Connection string to the database. - -> Int - -- ^ Number of connections to be kept open - -- in the pool. - -> m (Pool (RawPostgresql SqlBackend)) +createRawPostgresqlPool + :: (MonadUnliftIO m, MonadLoggerIO m) + => ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open + -- in the pool. + -> m (Pool (RawPostgresql SqlBackend)) createRawPostgresqlPool = createRawPostgresqlPoolModified (const $ return ()) -- | The raw counterpart to 'createPostgresqlPoolModified'. @@ -1526,9 +1680,12 @@ createRawPostgresqlPool = createRawPostgresqlPoolModified (const $ return ()) -- @since 2.13.1.0 createRawPostgresqlPoolModified :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> m (Pool (RawPostgresql SqlBackend)) createRawPostgresqlPoolModified = createRawPostgresqlPoolModifiedWithVersion getServerVersion @@ -1537,27 +1694,37 @@ createRawPostgresqlPoolModified = createRawPostgresqlPoolModifiedWithVersion get -- @since 2.13.1.0 createRawPostgresqlPoolModifiedWithVersion :: (MonadUnliftIO m, MonadLoggerIO m) - => (PG.Connection -> IO (Maybe Double)) -- ^ Action to perform to get the server version. - -> (PG.Connection -> IO ()) -- ^ Action to perform after connection is created. - -> ConnectionString -- ^ Connection string to the database. - -> Int -- ^ Number of connections to be kept open in the pool. + => (PG.Connection -> IO (Maybe Double)) + -- ^ Action to perform to get the server version. + -> (PG.Connection -> IO ()) + -- ^ Action to perform after connection is created. + -> ConnectionString + -- ^ Connection string to the database. + -> Int + -- ^ Number of connections to be kept open in the pool. -> m (Pool (RawPostgresql SqlBackend)) createRawPostgresqlPoolModifiedWithVersion getVerDouble modConn ci = do - let getVer = oldGetVersionToNew getVerDouble - createSqlPool $ open' modConn getVer withRawConnection ci + let + getVer = oldGetVersionToNew getVerDouble + createSqlPool $ open' modConn getVer withRawConnection ci -- | The raw counterpart to 'createPostgresqlPoolWithConf'. -- -- @since 2.13.1.0 createRawPostgresqlPoolWithConf :: (MonadUnliftIO m, MonadLoggerIO m) - => PostgresConf -- ^ Configuration for connecting to Postgres - -> PostgresConfHooks -- ^ Record of callback functions + => PostgresConf + -- ^ Configuration for connecting to Postgres + -> PostgresConfHooks + -- ^ Record of callback functions -> m (Pool (RawPostgresql SqlBackend)) createRawPostgresqlPoolWithConf conf hooks = do - let getVer = pgConfHooksGetServerVersion hooks - modConn = pgConfHooksAfterCreate hooks - createSqlPoolWithConfig (open' modConn getVer withRawConnection (pgConnStr conf)) (postgresConfToConnectionPoolConfig conf) + let + getVer = pgConfHooksGetServerVersion hooks + modConn = pgConfHooksAfterCreate hooks + createSqlPoolWithConfig + (open' modConn getVer withRawConnection (pgConnStr conf)) + (postgresConfToConnectionPoolConfig conf) #if MIN_VERSION_base(4,12,0) instance (PersistCore b) => PersistCore (RawPostgresql b) where @@ -1583,7 +1750,6 @@ deriving instance (ToJSON (BackendKey b)) => ToJSON (BackendKey (RawPostgresql b deriving instance (FromJSON (BackendKey b)) => FromJSON (BackendKey (RawPostgresql b)) #endif - #if MIN_VERSION_base(4,12,0) $(pure []) @@ -1632,4 +1798,3 @@ instance (PersistUniqueWrite b) => PersistUniqueWrite (RawPostgresql b) where upsertBy uniq rec = withReaderT persistentBackend . upsertBy uniq rec putMany = withReaderT persistentBackend . putMany #endif - diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 12a8152d8..1434894a1 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -3,8 +3,8 @@ {-# LANGUAGE ViewPatterns #-} module Database.Persist.Postgresql.Internal - ( P(..) - , PgInterval(..) + ( P (..) + , PgInterval (..) , getGetter , AlterDB (..) , AlterTable (..) @@ -16,7 +16,7 @@ module Database.Persist.Postgresql.Internal , maySerial , mayDefault , showSqlType - , showColumn + , showColumn , showAlter , showAlterDb , showAlterTable @@ -37,162 +37,172 @@ import qualified Database.PostgreSQL.Simple.ToField as PGTF import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS import qualified Database.PostgreSQL.Simple.Types as PG -import Database.Persist.Sql -import qualified Data.List.NonEmpty as NEL import qualified Blaze.ByteString.Builder.Char8 as BBB +import Control.Monad import qualified Data.Attoparsec.ByteString.Char8 as P import Data.Bits ((.&.)) import Data.ByteString (ByteString) import qualified Data.ByteString.Builder as BB -import Data.List as List (find, sort) import qualified Data.ByteString.Char8 as B8 import Data.Char (ord) import Data.Data (Typeable) -import Data.Fixed (Fixed(..), Pico) -import Data.Int (Int64) import Data.Either (partitionEithers) -import Control.Monad +import Data.Fixed (Fixed (..), Pico) +import Data.Int (Int64) import qualified Data.IntMap as I +import Data.List as List (find, sort) +import qualified Data.List.NonEmpty as NEL import Data.Maybe -import qualified Database.Persist.Sql.Util as Util import Data.String.Conversions.Monomorphic (toStrictByteString) import Data.Text (Text) import qualified Data.Text as T import Data.Time (NominalDiffTime, localTimeToUTC, utc) +import Database.Persist.Sql +import qualified Database.Persist.Sql.Util as Util -- | Newtype used to avoid orphan instances for @postgresql-simple@ classes. -- -- @since 2.13.2.0 -newtype P = P { unP :: PersistValue } +newtype P = P {unP :: PersistValue} instance PGTF.ToField P where - toField (P (PersistText t)) = PGTF.toField t + toField (P (PersistText t)) = PGTF.toField t toField (P (PersistByteString bs)) = PGTF.toField (PG.Binary bs) - toField (P (PersistInt64 i)) = PGTF.toField i - toField (P (PersistDouble d)) = PGTF.toField d - toField (P (PersistRational r)) = PGTF.Plain $ - BBB.fromString $ - show (fromRational r :: Pico) -- FIXME: Too Ambigous, can not select precision without information about field - toField (P (PersistBool b)) = PGTF.toField b - toField (P (PersistDay d)) = PGTF.toField d - toField (P (PersistTimeOfDay t)) = PGTF.toField t - toField (P (PersistUTCTime t)) = PGTF.toField t - toField (P PersistNull) = PGTF.toField PG.Null - toField (P (PersistList l)) = PGTF.toField $ listToJSON l - toField (P (PersistMap m)) = PGTF.toField $ mapToJSON m - toField (P (PersistLiteral_ DbSpecific s)) = PGTF.toField (Unknown s) - toField (P (PersistLiteral_ Unescaped l)) = PGTF.toField (UnknownLiteral l) + toField (P (PersistInt64 i)) = PGTF.toField i + toField (P (PersistDouble d)) = PGTF.toField d + toField (P (PersistRational r)) = + PGTF.Plain $ + BBB.fromString $ + show (fromRational r :: Pico) -- FIXME: Too Ambigous, can not select precision without information about field + toField (P (PersistBool b)) = PGTF.toField b + toField (P (PersistDay d)) = PGTF.toField d + toField (P (PersistTimeOfDay t)) = PGTF.toField t + toField (P (PersistUTCTime t)) = PGTF.toField t + toField (P PersistNull) = PGTF.toField PG.Null + toField (P (PersistList l)) = PGTF.toField $ listToJSON l + toField (P (PersistMap m)) = PGTF.toField $ mapToJSON m + toField (P (PersistLiteral_ DbSpecific s)) = PGTF.toField (Unknown s) + toField (P (PersistLiteral_ Unescaped l)) = PGTF.toField (UnknownLiteral l) toField (P (PersistLiteral_ Escaped e)) = PGTF.toField (Unknown e) - toField (P (PersistArray a)) = PGTF.toField $ PG.PGArray $ P <$> a - toField (P (PersistObjectId _)) = + toField (P (PersistArray a)) = PGTF.toField $ PG.PGArray $ P <$> a + toField (P (PersistObjectId _)) = error "Refusing to serialize a PersistObjectId to a PostgreSQL value" instance PGFF.FromField P where fromField field mdata = fmap P $ case mdata of - -- If we try to simply decode based on oid, we will hit unexpected null - -- errors. - Nothing -> pure PersistNull - data' -> getGetter (PGFF.typeOid field) field data' + -- If we try to simply decode based on oid, we will hit unexpected null + -- errors. + Nothing -> pure PersistNull + data' -> getGetter (PGFF.typeOid field) field data' -newtype Unknown = Unknown { unUnknown :: ByteString } - deriving (Eq, Show, Read, Ord) +newtype Unknown = Unknown {unUnknown :: ByteString} + deriving (Eq, Show, Read, Ord) instance PGFF.FromField Unknown where fromField f mdata = - case mdata of - Nothing -> PGFF.returnError PGFF.UnexpectedNull f "Database.Persist.Postgresql/PGFF.FromField Unknown" - Just dat -> return (Unknown dat) + case mdata of + Nothing -> + PGFF.returnError + PGFF.UnexpectedNull + f + "Database.Persist.Postgresql/PGFF.FromField Unknown" + Just dat -> return (Unknown dat) instance PGTF.ToField Unknown where toField (Unknown a) = PGTF.Escape a -newtype UnknownLiteral = UnknownLiteral { unUnknownLiteral :: ByteString } - deriving (Eq, Show, Read, Ord, Typeable) +newtype UnknownLiteral = UnknownLiteral {unUnknownLiteral :: ByteString} + deriving (Eq, Show, Read, Ord, Typeable) instance PGFF.FromField UnknownLiteral where fromField f mdata = - case mdata of - Nothing -> PGFF.returnError PGFF.UnexpectedNull f "Database.Persist.Postgresql/PGFF.FromField UnknownLiteral" - Just dat -> return (UnknownLiteral dat) + case mdata of + Nothing -> + PGFF.returnError + PGFF.UnexpectedNull + f + "Database.Persist.Postgresql/PGFF.FromField UnknownLiteral" + Just dat -> return (UnknownLiteral dat) instance PGTF.ToField UnknownLiteral where toField (UnknownLiteral a) = PGTF.Plain $ BB.byteString a type Getter a = PGFF.FieldParser a -convertPV :: PGFF.FromField a => (a -> b) -> Getter b +convertPV :: (PGFF.FromField a) => (a -> b) -> Getter b convertPV f = (fmap f .) . PGFF.fromField builtinGetters :: I.IntMap (Getter PersistValue) -builtinGetters = I.fromList - [ (k PS.bool, convertPV PersistBool) - , (k PS.bytea, convertPV (PersistByteString . unBinary)) - , (k PS.char, convertPV PersistText) - , (k PS.name, convertPV PersistText) - , (k PS.int8, convertPV PersistInt64) - , (k PS.int2, convertPV PersistInt64) - , (k PS.int4, convertPV PersistInt64) - , (k PS.text, convertPV PersistText) - , (k PS.xml, convertPV (PersistByteString . unUnknown)) - , (k PS.float4, convertPV PersistDouble) - , (k PS.float8, convertPV PersistDouble) - , (k PS.money, convertPV PersistRational) - , (k PS.bpchar, convertPV PersistText) - , (k PS.varchar, convertPV PersistText) - , (k PS.date, convertPV PersistDay) - , (k PS.time, convertPV PersistTimeOfDay) - , (k PS.timestamp, convertPV (PersistUTCTime. localTimeToUTC utc)) - , (k PS.timestamptz, convertPV PersistUTCTime) - , (k PS.interval, convertPV (PersistLiteralEscaped . pgIntervalToBs)) - , (k PS.bit, convertPV PersistInt64) - , (k PS.varbit, convertPV PersistInt64) - , (k PS.numeric, convertPV PersistRational) - , (k PS.void, \_ _ -> return PersistNull) - , (k PS.json, convertPV (PersistByteString . unUnknown)) - , (k PS.jsonb, convertPV (PersistByteString . unUnknown)) - , (k PS.unknown, convertPV (PersistByteString . unUnknown)) - - -- Array types: same order as above. - -- The OIDs were taken from pg_type. - , (1000, listOf PersistBool) - , (1001, listOf (PersistByteString . unBinary)) - , (1002, listOf PersistText) - , (1003, listOf PersistText) - , (1016, listOf PersistInt64) - , (1005, listOf PersistInt64) - , (1007, listOf PersistInt64) - , (1009, listOf PersistText) - , (143, listOf (PersistByteString . unUnknown)) - , (1021, listOf PersistDouble) - , (1022, listOf PersistDouble) - , (1023, listOf PersistUTCTime) - , (1024, listOf PersistUTCTime) - , (791, listOf PersistRational) - , (1014, listOf PersistText) - , (1015, listOf PersistText) - , (1182, listOf PersistDay) - , (1183, listOf PersistTimeOfDay) - , (1115, listOf PersistUTCTime) - , (1185, listOf PersistUTCTime) - , (1187, listOf (PersistLiteralEscaped . pgIntervalToBs)) - , (1561, listOf PersistInt64) - , (1563, listOf PersistInt64) - , (1231, listOf PersistRational) - -- no array(void) type - , (2951, listOf (PersistLiteralEscaped . unUnknown)) - , (199, listOf (PersistByteString . unUnknown)) - , (3807, listOf (PersistByteString . unUnknown)) - -- no array(unknown) either - ] - where - k (PGFF.typoid -> i) = PG.oid2int i - -- A @listOf f@ will use a @PGArray (Maybe T)@ to convert - -- the values to Haskell-land. The @Maybe@ is important - -- because the usual way of checking NULLs - -- (c.f. withStmt') won't check for NULL inside - -- arrays---or any other compound structure for that matter. - listOf f = convertPV (PersistList . map (nullable f) . PG.fromPGArray) - where nullable = maybe PersistNull +builtinGetters = + I.fromList + [ (k PS.bool, convertPV PersistBool) + , (k PS.bytea, convertPV (PersistByteString . unBinary)) + , (k PS.char, convertPV PersistText) + , (k PS.name, convertPV PersistText) + , (k PS.int8, convertPV PersistInt64) + , (k PS.int2, convertPV PersistInt64) + , (k PS.int4, convertPV PersistInt64) + , (k PS.text, convertPV PersistText) + , (k PS.xml, convertPV (PersistByteString . unUnknown)) + , (k PS.float4, convertPV PersistDouble) + , (k PS.float8, convertPV PersistDouble) + , (k PS.money, convertPV PersistRational) + , (k PS.bpchar, convertPV PersistText) + , (k PS.varchar, convertPV PersistText) + , (k PS.date, convertPV PersistDay) + , (k PS.time, convertPV PersistTimeOfDay) + , (k PS.timestamp, convertPV (PersistUTCTime . localTimeToUTC utc)) + , (k PS.timestamptz, convertPV PersistUTCTime) + , (k PS.interval, convertPV (PersistLiteralEscaped . pgIntervalToBs)) + , (k PS.bit, convertPV PersistInt64) + , (k PS.varbit, convertPV PersistInt64) + , (k PS.numeric, convertPV PersistRational) + , (k PS.void, \_ _ -> return PersistNull) + , (k PS.json, convertPV (PersistByteString . unUnknown)) + , (k PS.jsonb, convertPV (PersistByteString . unUnknown)) + , (k PS.unknown, convertPV (PersistByteString . unUnknown)) + , -- Array types: same order as above. + -- The OIDs were taken from pg_type. + (1000, listOf PersistBool) + , (1001, listOf (PersistByteString . unBinary)) + , (1002, listOf PersistText) + , (1003, listOf PersistText) + , (1016, listOf PersistInt64) + , (1005, listOf PersistInt64) + , (1007, listOf PersistInt64) + , (1009, listOf PersistText) + , (143, listOf (PersistByteString . unUnknown)) + , (1021, listOf PersistDouble) + , (1022, listOf PersistDouble) + , (1023, listOf PersistUTCTime) + , (1024, listOf PersistUTCTime) + , (791, listOf PersistRational) + , (1014, listOf PersistText) + , (1015, listOf PersistText) + , (1182, listOf PersistDay) + , (1183, listOf PersistTimeOfDay) + , (1115, listOf PersistUTCTime) + , (1185, listOf PersistUTCTime) + , (1187, listOf (PersistLiteralEscaped . pgIntervalToBs)) + , (1561, listOf PersistInt64) + , (1563, listOf PersistInt64) + , (1231, listOf PersistRational) + , -- no array(void) type + (2951, listOf (PersistLiteralEscaped . unUnknown)) + , (199, listOf (PersistByteString . unUnknown)) + , (3807, listOf (PersistByteString . unUnknown)) + -- no array(unknown) either + ] + where + k (PGFF.typoid -> i) = PG.oid2int i + -- A @listOf f@ will use a @PGArray (Maybe T)@ to convert + -- the values to Haskell-land. The @Maybe@ is important + -- because the usual way of checking NULLs + -- (c.f. withStmt') won't check for NULL inside + -- arrays---or any other compound structure for that matter. + listOf f = convertPV (PersistList . map (nullable f) . PG.fromPGArray) + where + nullable = maybe PersistNull -- | Get the field parser corresponding to the given 'PG.Oid'. -- @@ -201,9 +211,10 @@ builtinGetters = I.fromList -- -- @since 2.13.2.0 getGetter :: PG.Oid -> Getter PersistValue -getGetter oid - = fromMaybe defaultGetter $ I.lookup (PG.oid2int oid) builtinGetters - where defaultGetter = convertPV (PersistLiteralEscaped . unUnknown) +getGetter oid = + fromMaybe defaultGetter $ I.lookup (PG.oid2int oid) builtinGetters + where + defaultGetter = convertPV (PersistLiteralEscaped . unUnknown) unBinary :: PG.Binary a -> a unBinary (PG.Binary x) = x @@ -211,8 +222,8 @@ unBinary (PG.Binary x) = x -- | Represent Postgres interval using NominalDiffTime -- -- @since 2.11.0.0 -newtype PgInterval = PgInterval { getPgInterval :: NominalDiffTime } - deriving (Eq, Show) +newtype PgInterval = PgInterval {getPgInterval :: NominalDiffTime} + deriving (Eq, Show) pgIntervalToBs :: PgInterval -> ByteString pgIntervalToBs = toStrictByteString . show . getPgInterval @@ -222,14 +233,13 @@ instance PGTF.ToField PgInterval where instance PGFF.FromField PgInterval where fromField f mdata = - if PGFF.typeOid f /= PS.typoid PS.interval - then PGFF.returnError PGFF.Incompatible f "" - else case mdata of - Nothing -> PGFF.returnError PGFF.UnexpectedNull f "" - Just dat -> case P.parseOnly (nominalDiffTime <* P.endOfInput) dat of - Left msg -> PGFF.returnError PGFF.ConversionFailed f msg - Right t -> return $ PgInterval t - + if PGFF.typeOid f /= PS.typoid PS.interval + then PGFF.returnError PGFF.Incompatible f "" + else case mdata of + Nothing -> PGFF.returnError PGFF.UnexpectedNull f "" + Just dat -> case P.parseOnly (nominalDiffTime <* P.endOfInput) dat of + Left msg -> PGFF.returnError PGFF.ConversionFailed f msg + Right t -> return $ PgInterval t where toPico :: Integer -> Pico toPico = MkFixed @@ -237,27 +247,32 @@ instance PGFF.FromField PgInterval where -- Taken from Database.PostgreSQL.Simple.Time.Internal.Parser twoDigits :: P.Parser Int twoDigits = do - a <- P.digit - b <- P.digit - let c2d c = ord c .&. 15 - return $! c2d a * 10 + c2d b + a <- P.digit + b <- P.digit + let + c2d c = ord c .&. 15 + return $! c2d a * 10 + c2d b -- Taken from Database.PostgreSQL.Simple.Time.Internal.Parser seconds :: P.Parser Pico seconds = do - real <- twoDigits - mc <- P.peekChar - case mc of - Just '.' -> do - t <- P.anyChar *> P.takeWhile1 P.isDigit - return $! parsePicos (fromIntegral real) t - _ -> return $! fromIntegral real - where - parsePicos :: Int64 -> B8.ByteString -> Pico - parsePicos a0 t = toPico (fromIntegral (t' * 10^n)) - where n = max 0 (12 - B8.length t) - t' = B8.foldl' (\a c -> 10 * a + fromIntegral (ord c .&. 15)) a0 - (B8.take 12 t) + real <- twoDigits + mc <- P.peekChar + case mc of + Just '.' -> do + t <- P.anyChar *> P.takeWhile1 P.isDigit + return $! parsePicos (fromIntegral real) t + _ -> return $! fromIntegral real + where + parsePicos :: Int64 -> B8.ByteString -> Pico + parsePicos a0 t = toPico (fromIntegral (t' * 10 ^ n)) + where + n = max 0 (12 - B8.length t) + t' = + B8.foldl' + (\a c -> 10 * a + fromIntegral (ord c .&. 15)) + a0 + (B8.take 12 t) parseSign :: P.Parser Bool parseSign = P.choice [P.char '-' >> return True, return False] @@ -266,9 +281,9 @@ instance PGFF.FromField PgInterval where -- For example, nominalDay is stored as 24:00:00 interval :: P.Parser (Bool, Int, Int, Pico) interval = do - s <- parseSign - h <- P.decimal <* P.char ':' - m <- twoDigits <* P.char ':' + s <- parseSign + h <- P.decimal <* P.char ':' + m <- twoDigits <* P.char ':' ss <- seconds if m < 60 && ss <= 60 then return (s, h, m, ss) @@ -276,36 +291,43 @@ instance PGFF.FromField PgInterval where nominalDiffTime :: P.Parser NominalDiffTime nominalDiffTime = do - (s, h, m, ss) <- interval - let pico = ss + 60 * (fromIntegral m) + 60 * 60 * (fromIntegral (abs h)) - return . fromRational . toRational $ if s then (-pico) else pico - -fromPersistValueError :: Text -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64" - -> Text -- ^ Database type(s), should appear different from Haskell name, e.g. "integer" or "INT", not "Int". - -> PersistValue -- ^ Incorrect value - -> Text -- ^ Error message -fromPersistValueError haskellType databaseType received = T.concat - [ "Failed to parse Haskell type `" - , haskellType - , "`; expected " - , databaseType - , " from database, but received: " - , T.pack (show received) - , ". Potential solution: Check that your database schema matches your Persistent model definitions." - ] + (s, h, m, ss) <- interval + let + pico = ss + 60 * (fromIntegral m) + 60 * 60 * (fromIntegral (abs h)) + return . fromRational . toRational $ if s then (-pico) else pico + +fromPersistValueError + :: Text + -- ^ Haskell type, should match Haskell name exactly, e.g. "Int64" + -> Text + -- ^ Database type(s), should appear different from Haskell name, e.g. "integer" or "INT", not "Int". + -> PersistValue + -- ^ Incorrect value + -> Text + -- ^ Error message +fromPersistValueError haskellType databaseType received = + T.concat + [ "Failed to parse Haskell type `" + , haskellType + , "`; expected " + , databaseType + , " from database, but received: " + , T.pack (show received) + , ". Potential solution: Check that your database schema matches your Persistent model definitions." + ] instance PersistField PgInterval where toPersistValue = PersistLiteralEscaped . pgIntervalToBs fromPersistValue (PersistLiteral_ DbSpecific bs) = fromPersistValue (PersistLiteralEscaped bs) fromPersistValue x@(PersistLiteral_ Escaped bs) = - case P.parseOnly (P.signed P.rational <* P.char 's' <* P.endOfInput) bs of - Left _ -> Left $ fromPersistValueError "PgInterval" "Interval" x - Right i -> Right $ PgInterval i + case P.parseOnly (P.signed P.rational <* P.char 's' <* P.endOfInput) bs of + Left _ -> Left $ fromPersistValueError "PgInterval" "Interval" x + Right i -> Right $ PgInterval i fromPersistValue x = Left $ fromPersistValueError "PgInterval" "Interval" x instance PersistFieldSql PgInterval where - sqlType _ = SqlOther "interval" + sqlType _ = SqlOther "interval" -- | Indicates whether a Postgres Column is safe to drop. -- @@ -326,7 +348,7 @@ data AlterColumn | Update' Column Text | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade | DropReference ConstraintNameDB - deriving Show + deriving (Show) -- | Represents a change to a Postgres table in a DB statement. -- @@ -334,25 +356,27 @@ data AlterColumn data AlterTable = AddUniqueConstraint ConstraintNameDB [FieldNameDB] | DropConstraint ConstraintNameDB - deriving Show + deriving (Show) -- | Represents a change to a Postgres DB in a statement. -- -- @since 2.17.0.0 -data AlterDB = AddTable EntityNameDB EntityIdDef [Column] - | AlterColumn EntityNameDB AlterColumn - | AlterTable EntityNameDB AlterTable - deriving Show - --- | Returns a structured representation of all of the --- DB changes required to migrate the Entity from its --- current state in the database to the state described in +data AlterDB + = AddTable EntityNameDB EntityIdDef [Column] + | AlterColumn EntityNameDB AlterColumn + | AlterTable EntityNameDB AlterTable + deriving (Show) + +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its +-- current state in the database to the state described in -- Haskell. -- -- @since 2.17.0.0 -mockMigrateStructured :: [EntityDef] - -> EntityDef - -> IO (Either [Text] [AlterDB]) +mockMigrateStructured + :: [EntityDef] + -> EntityDef + -> IO (Either [Text] [AlterDB]) mockMigrateStructured allDefs entity = do case partitionEithers [] of ([], old'') -> return $ Right $ migrationText False old'' @@ -362,32 +386,35 @@ mockMigrateStructured allDefs entity = do migrationText exists' old'' = if not exists' then createText newcols fdefs udspair - else let (acs, ats) = getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats - in acs' ++ ats' - where - old' = partitionEithers old'' - (newcols', udefs, fdefs) = postgresMkColumns allDefs entity - newcols = filter (not . safeToRemove entity . cName) newcols' - udspair = map udToPair udefs - -- Check for table existence if there are no columns, workaround - -- for https://github.com/yesodweb/persistent/issues/152 + else + let + (acs, ats) = getAlters allDefs entity (newcols, udspair) old' + acs' = map (AlterColumn name) acs + ats' = map (AlterTable name) ats + in + acs' ++ ats' + where + old' = partitionEithers old'' + (newcols', udefs, fdefs) = postgresMkColumns allDefs entity + newcols = filter (not . safeToRemove entity . cName) newcols' + udspair = map udToPair udefs + -- Check for table existence if there are no columns, workaround + -- for https://github.com/yesodweb/persistent/issues/152 createText newcols fdefs udspair = (addTable newcols entity) : uniques ++ references ++ foreignsAlt where uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] + [AlterTable name $ AddUniqueConstraint uname ucols] references = mapMaybe - (\Column { cName, cReference } -> + ( \Column{cName, cReference} -> getAddReference allDefs entity cName =<< cReference ) newcols foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs --- | Returns a structured representation of all of the +-- | Returns a structured representation of all of the -- DB changes required to migrate the Entity from its current state -- in the database to the state described in Haskell. -- @@ -405,7 +432,7 @@ addTable cols entity = where keepField c = Just (cName c) /= fmap fieldDB (getEntityIdField entity) - && not (safeToRemove entity (cName c)) + && not (safeToRemove entity (cName c)) entityId = getEntityId entity name = getEntityDBName entity @@ -418,20 +445,22 @@ mayDefault def = case def of Nothing -> "" Just d -> " DEFAULT " <> d - -getAlters :: [EntityDef] - -> EntityDef - -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) - -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) - -> ([AlterColumn], [AlterTable]) +getAlters + :: [EntityDef] + -> EntityDef + -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) + -> ([Column], [(ConstraintNameDB, [FieldNameDB])]) + -> ([AlterColumn], [AlterTable]) getAlters defs def (c1, u1) (c2, u2) = (getAltersC c1 c2, getAltersU u1 u2) where getAltersC [] old = map (\x -> Drop x $ safeToRemove def $ cName x) old - getAltersC (new:news) old = - let (alters, old') = findAlters defs def new old - in alters ++ getAltersC news old' + getAltersC (new : news) old = + let + (alters, old') = findAlters defs def new old + in + alters ++ getAltersC news old' getAltersU :: [(ConstraintNameDB, [FieldNameDB])] @@ -439,22 +468,24 @@ getAlters defs def (c1, u1) (c2, u2) = -> [AlterTable] getAltersU [] old = map DropConstraint $ filter (not . isManual) $ map fst old - getAltersU ((name, cols):news) old = + getAltersU ((name, cols) : news) old = case lookup name old of Nothing -> AddUniqueConstraint name cols : getAltersU news old Just ocols -> - let old' = filter (\(x, _) -> x /= name) old - in if sort cols == sort ocols + let + old' = filter (\(x, _) -> x /= name) old + in + if sort cols == sort ocols then getAltersU news old' - else DropConstraint name - : AddUniqueConstraint name cols - : getAltersU news old' + else + DropConstraint name + : AddUniqueConstraint name cols + : getAltersU news old' -- Don't drop constraints which were manually added. isManual (ConstraintNameDB x) = "__manual_" `T.isPrefixOf` x - -- | Postgres' default maximum identifier length in bytes -- (You can re-compile Postgres with a new limit, but I'm assuming that virtually noone does this). -- See https://www.postgresql.org/docs/11/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS @@ -468,69 +499,71 @@ sqlTypeEq x y = -- Non exhaustive helper to map postgres aliases to the same name. Based on -- https://www.postgresql.org/docs/9.5/datatype.html. -- This prevents needless `ALTER TYPE`s when the type is the same. - normalize "int8" = "bigint" + normalize "int8" = "bigint" normalize "serial8" = "bigserial" - normalize v = v - in - normalize (T.toCaseFold (showSqlType x)) == normalize (T.toCaseFold (showSqlType y)) + normalize v = v + in + normalize (T.toCaseFold (showSqlType x)) + == normalize (T.toCaseFold (showSqlType y)) -- We check if we should alter a foreign key. This is almost an equality check, -- except we consider 'Nothing' and 'Just Restrict' equivalent. equivalentRef :: Maybe ColumnReference -> Maybe ColumnReference -> Bool equivalentRef Nothing Nothing = True equivalentRef (Just cr1) (Just cr2) = - crTableName cr1 == crTableName cr2 - && crConstraintName cr1 == crConstraintName cr2 - && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2) - && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2) + crTableName cr1 == crTableName cr2 + && crConstraintName cr1 == crConstraintName cr2 + && eqCascade (fcOnUpdate $ crFieldCascade cr1) (fcOnUpdate $ crFieldCascade cr2) + && eqCascade (fcOnDelete $ crFieldCascade cr1) (fcOnDelete $ crFieldCascade cr2) where eqCascade :: Maybe CascadeAction -> Maybe CascadeAction -> Bool - eqCascade Nothing Nothing = True + eqCascade Nothing Nothing = True eqCascade Nothing (Just Restrict) = True eqCascade (Just Restrict) Nothing = True - eqCascade (Just cs1) (Just cs2) = cs1 == cs2 - eqCascade _ _ = False + eqCascade (Just cs1) (Just cs2) = cs1 == cs2 + eqCascade _ _ = False equivalentRef _ _ = False refName :: EntityNameDB -> FieldNameDB -> ConstraintNameDB refName (EntityNameDB table) (FieldNameDB column) = - let overhead = T.length $ T.concat ["_", "_fkey"] + let + overhead = T.length $ T.concat ["_", "_fkey"] (fromTable, fromColumn) = shortenNames overhead (T.length table, T.length column) - in ConstraintNameDB $ T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"] - - where - - -- Postgres automatically truncates too long foreign keys to a combination of - -- truncatedTableName + "_" + truncatedColumnName + "_fkey" - -- This works fine for normal use cases, but it creates an issue for Persistent - -- Because after running the migrations, Persistent sees the truncated foreign key constraint - -- doesn't have the expected name, and suggests that you migrate again - -- To workaround this, we copy the Postgres truncation approach before sending foreign key constraints to it. - -- - -- I believe this will also be an issue for extremely long table names, - -- but it's just much more likely to exist with foreign key constraints because they're usually tablename * 2 in length - - -- Approximation of the algorithm Postgres uses to truncate identifiers - -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080 - shortenNames :: Int -> (Int, Int) -> (Int, Int) - shortenNames overhead (x, y) - | x + y + overhead <= maximumIdentifierLength = (x, y) - | x > y = shortenNames overhead (x - 1, y) - | otherwise = shortenNames overhead (x, y - 1) - -postgresMkColumns :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) + in + ConstraintNameDB $ + T.concat [T.take fromTable table, "_", T.take fromColumn column, "_fkey"] + where + -- Postgres automatically truncates too long foreign keys to a combination of + -- truncatedTableName + "_" + truncatedColumnName + "_fkey" + -- This works fine for normal use cases, but it creates an issue for Persistent + -- Because after running the migrations, Persistent sees the truncated foreign key constraint + -- doesn't have the expected name, and suggests that you migrate again + -- To workaround this, we copy the Postgres truncation approach before sending foreign key constraints to it. + -- + -- I believe this will also be an issue for extremely long table names, + -- but it's just much more likely to exist with foreign key constraints because they're usually tablename * 2 in length + + -- Approximation of the algorithm Postgres uses to truncate identifiers + -- See makeObjectName https://github.com/postgres/postgres/blob/5406513e997f5ee9de79d4076ae91c04af0c52f6/src/backend/commands/indexcmds.c#L2074-L2080 + shortenNames :: Int -> (Int, Int) -> (Int, Int) + shortenNames overhead (x, y) + | x + y + overhead <= maximumIdentifierLength = (x, y) + | x > y = shortenNames overhead (x - 1, y) + | otherwise = shortenNames overhead (x, y - 1) + +postgresMkColumns + :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) postgresMkColumns allDefs t = - mkColumns allDefs t - $ setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides - + mkColumns allDefs t $ + setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides -- | Check if a column name is listed as the "safe to remove" in the entity -- list. safeToRemove :: EntityDef -> FieldNameDB -> Bool -safeToRemove def (FieldNameDB colName) - = any (elem FieldAttrSafeToRemove . fieldAttrs) - $ filter ((== FieldNameDB colName) . fieldDB) - $ allEntityFields +safeToRemove def (FieldNameDB colName) = + any (elem FieldAttrSafeToRemove . fieldAttrs) $ + filter ((== FieldNameDB colName) . fieldDB) $ + allEntityFields where allEntityFields = getEntityFieldsDatabase def <> case getEntityId def of @@ -539,11 +572,9 @@ safeToRemove def (FieldNameDB colName) _ -> [] - udToPair :: UniqueDef -> (ConstraintNameDB, [FieldNameDB]) udToPair ud = (uniqueDBName ud, map snd $ NEL.toList $ uniqueFields ud) - -- | Get the references to be added to a table for the given column. getAddReference :: [EntityDef] @@ -551,12 +582,13 @@ getAddReference -> FieldNameDB -> ColumnReference -> Maybe AlterDB -getAddReference allDefs entity cname cr@ColumnReference {crTableName = s, crConstraintName=constraintName} = do +getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConstraintName = constraintName} = do guard $ Just cname /= fmap fieldDB (getEntityIdField entity) - pure $ AlterColumn - table - (AddReference s constraintName [cname] id_ (crFieldCascade cr) - ) + pure $ + AlterColumn + table + ( AddReference s constraintName [cname] id_ (crFieldCascade cr) + ) where table = getEntityDBName entity id_ = @@ -583,7 +615,7 @@ mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference constraintName = foreignConstraintNameDBName fdef (childfields, parentfields) = - unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef)) + unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef)) escapedParentFields = map escapeF parentfields @@ -596,14 +628,13 @@ escapeE = escapeWith escape escapeF :: FieldNameDB -> Text escapeF = escapeWith escape - escape :: Text -> Text escape s = T.pack $ '"' : go (T.unpack s) ++ "\"" where go "" = "" - go ('"':xs) = "\"\"" ++ go xs - go (x:xs) = x : go xs + go ('"' : xs) = "\"\"" ++ go xs + go (x : xs) = x : go xs showAlterDb :: AlterDB -> (Bool, Text) showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) @@ -617,25 +648,27 @@ showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) , ")" ] EntityIdField field -> - let defText = defaultAttribute $ fieldAttrs field + let + defText = defaultAttribute $ fieldAttrs field sType = fieldSqlType field - in T.concat + in + T.concat [ escapeF $ fieldDB field , maySerial sType defText , " PRIMARY KEY UNIQUE" , mayDefault defText ] - rawText = T.concat - -- Lower case e: see Database.Persist.Sql.Migration - [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! - , escapeE name - , "(" - , idtxt - , if null nonIdCols then "" else "," - , T.intercalate "," $ map showColumn nonIdCols - , ")" - ] - + rawText = + T.concat + -- Lower case e: see Database.Persist.Sql.Migration + [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION! + , escapeE name + , "(" + , idtxt + , if null nonIdCols then "" else "," + , T.intercalate "," $ map showColumn nonIdCols + , ")" + ] showAlterDb (AlterColumn t ac) = (isUnsafe ac, showAlter t ac) where @@ -644,21 +677,23 @@ showAlterDb (AlterColumn t ac) = showAlterDb (AlterTable t at) = (False, showAlterTable t at) showAlterTable :: EntityNameDB -> AlterTable -> Text -showAlterTable table (AddUniqueConstraint cname cols) = T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD CONSTRAINT " - , escapeC cname - , " UNIQUE(" - , T.intercalate "," $ map escapeF cols - , ")" - ] -showAlterTable table (DropConstraint cname) = T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP CONSTRAINT " - , escapeC cname - ] +showAlterTable table (AddUniqueConstraint cname cols) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC cname + , " UNIQUE(" + , T.intercalate "," $ map escapeF cols + , ")" + ] +showAlterTable table (DropConstraint cname) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] showAlter :: EntityNameDB -> AlterColumn -> Text showAlter table (ChangeType c t extra) = @@ -710,75 +745,78 @@ showAlter table (Default c s) = , " SET DEFAULT " , s ] -showAlter table (NoDefault c) = T.concat - [ "ALTER TABLE " - , escapeE table - , " ALTER COLUMN " - , escapeF (cName c) - , " DROP DEFAULT" - ] -showAlter table (Update' c s) = T.concat - [ "UPDATE " - , escapeE table - , " SET " - , escapeF (cName c) - , "=" - , s - , " WHERE " - , escapeF (cName c) - , " IS NULL" - ] -showAlter table (AddReference reftable fkeyname t2 id2 cascade) = T.concat - [ "ALTER TABLE " - , escapeE table - , " ADD CONSTRAINT " - , escapeC fkeyname - , " FOREIGN KEY(" - , T.intercalate "," $ map escapeF t2 - , ") REFERENCES " - , escapeE reftable - , "(" - , T.intercalate "," id2 - , ")" - ] <> renderFieldCascade cascade -showAlter table (DropReference cname) = T.concat - [ "ALTER TABLE " - , escapeE table - , " DROP CONSTRAINT " - , escapeC cname - ] +showAlter table (NoDefault c) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ALTER COLUMN " + , escapeF (cName c) + , " DROP DEFAULT" + ] +showAlter table (Update' c s) = + T.concat + [ "UPDATE " + , escapeE table + , " SET " + , escapeF (cName c) + , "=" + , s + , " WHERE " + , escapeF (cName c) + , " IS NULL" + ] +showAlter table (AddReference reftable fkeyname t2 id2 cascade) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " ADD CONSTRAINT " + , escapeC fkeyname + , " FOREIGN KEY(" + , T.intercalate "," $ map escapeF t2 + , ") REFERENCES " + , escapeE reftable + , "(" + , T.intercalate "," id2 + , ")" + ] + <> renderFieldCascade cascade +showAlter table (DropReference cname) = + T.concat + [ "ALTER TABLE " + , escapeE table + , " DROP CONSTRAINT " + , escapeC cname + ] showColumn :: Column -> Text -showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = T.concat - [ escapeF n - , " " - , showSqlType sqlType' - , " " - , if nu then "NULL" else "NOT NULL" - , case def of - Nothing -> "" - Just s -> " DEFAULT " <> s - , case gen of - Nothing -> "" - Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" - ] - +showColumn (Column n nu sqlType' def gen _defConstraintName _maxLen _ref) = + T.concat + [ escapeF n + , " " + , showSqlType sqlType' + , " " + , if nu then "NULL" else "NOT NULL" + , case def of + Nothing -> "" + Just s -> " DEFAULT " <> s + , case gen of + Nothing -> "" + Just s -> " GENERATED ALWAYS AS (" <> s <> ") STORED" + ] showSqlType :: SqlType -> Text showSqlType SqlString = "VARCHAR" showSqlType SqlInt32 = "INT4" showSqlType SqlInt64 = "INT8" showSqlType SqlReal = "DOUBLE PRECISION" -showSqlType (SqlNumeric s prec) = T.concat [ "NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")" ] +showSqlType (SqlNumeric s prec) = T.concat ["NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")"] showSqlType SqlDay = "DATE" showSqlType SqlTime = "TIME" showSqlType SqlDayTime = "TIMESTAMP WITH TIME ZONE" showSqlType SqlBlob = "BYTEA" showSqlType SqlBool = "BOOLEAN" - -- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682 showSqlType (SqlOther (T.toLower -> "integer")) = "INT4" - showSqlType (SqlOther t) = t findAlters @@ -794,62 +832,68 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName case List.find (\c -> cName c == name) cols of Nothing -> ([Add' col], cols) - Just (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> - let refDrop Nothing = [] - refDrop (Just ColumnReference {crConstraintName=cname}) = + Just + (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> + let + refDrop Nothing = [] + refDrop (Just ColumnReference{crConstraintName = cname}) = [DropReference cname] refAdd Nothing = [] refAdd (Just colRef) = case find ((== crTableName colRef) . getEntityDBName) defs of Just refdef - | Just _oldName /= fmap fieldDB (getEntityIdField edef) - -> - [AddReference - (crTableName colRef) - (crConstraintName colRef) - [name] - (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) - (crFieldCascade colRef) - ] + | Just _oldName /= fmap fieldDB (getEntityIdField edef) -> + [ AddReference + (crTableName colRef) + (crConstraintName colRef) + [name] + (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) + (crFieldCascade colRef) + ] Just _ -> [] Nothing -> - error $ "could not find the entityDef for reftable[" - ++ show (crTableName colRef) ++ "]" + error $ + "could not find the entityDef for reftable[" + ++ show (crTableName colRef) + ++ "]" modRef = if equivalentRef ref ref' then [] else refDrop ref' ++ refAdd ref modNull = case (isNull, isNull') of - (True, False) -> do - guard $ Just name /= fmap fieldDB (getEntityIdField edef) - pure (IsNull col) - (False, True) -> - let up = case def of - Nothing -> id - Just s -> (:) (Update' col s) - in up [NotNull col] - _ -> [] + (True, False) -> do + guard $ Just name /= fmap fieldDB (getEntityIdField edef) + pure (IsNull col) + (False, True) -> + let + up = case def of + Nothing -> id + Just s -> (:) (Update' col s) + in + up [NotNull col] + _ -> [] modType | sqlTypeEq sqltype sqltype' = [] -- When converting from Persistent pre-2.0 databases, we -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is -- treated as UTC. | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = - [ChangeType col sqltype $ T.concat - [ " USING " - , escapeF name - , " AT TIME ZONE 'UTC'" - ]] + [ ChangeType col sqltype $ + T.concat + [ " USING " + , escapeF name + , " AT TIME ZONE 'UTC'" + ] + ] | otherwise = [ChangeType col sqltype ""] modDef = if def == def' || isJust (T.stripPrefix "nextval" =<< def') then [] - else - case def of - Nothing -> [NoDefault col] - Just s -> [Default col s] + else case def of + Nothing -> [NoDefault col] + Just s -> [Default col s] dropSafe = if safeToRemove edef name then error "wtf" [Drop col True] @@ -857,4 +901,4 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName in ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe , filter (\c -> cName c /= name) cols - ) \ No newline at end of file + ) From 3f4594bcd38c9ecd9eeda9e89903bdbc8c0ae00f Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 17:22:37 -0400 Subject: [PATCH 09/30] added changelog --- persistent/ChangeLog.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/persistent/ChangeLog.md b/persistent/ChangeLog.md index de47d1f4d..927628bda 100644 --- a/persistent/ChangeLog.md +++ b/persistent/ChangeLog.md @@ -1,5 +1,12 @@ # Changelog for persistent +# 2.17.1.0 + +* [#1600](https://github.com/yesodweb/persistent/pull/1600) + * Add `mockMigrateStructured` to `Database.Persist.Postgresql.Internal`. + This allows you to access a structured representation of the mock migrations + for use in yor application. + # 2.17.0.0 * [#1595](https://github.com/yesodweb/persistent/pull/1595) From f5521518646b1f2dba08ccacb569c88f86f22c10 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 17:56:13 -0400 Subject: [PATCH 10/30] Ran fourmolu again --- .../Database/Persist/Postgresql.hs | 20 +-- .../Database/Persist/Postgresql/Internal.hs | 136 +++++++++--------- 2 files changed, 78 insertions(+), 78 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 33034d11c..d51cdf093 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -868,10 +868,10 @@ getColumns getter def cols = do error $ "unexpected datatype returned for postgres o=" ++ show o helperU = do rows <- getAll .| CL.consume - return $ - map - (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) $ - groupBy ((==) `on` fst) rows + return + $ map + (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) + $ groupBy ((==) `on` fst) rows processColumns = CL.mapM $ \x'@((PersistText cname) : _) -> do col <- @@ -1019,8 +1019,8 @@ getColumn , PersistText delRule ] ] -> - return $ - Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) + return $ + Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) xs -> error $ mconcat @@ -1533,10 +1533,10 @@ mkBulkUpsertQuery records conn fieldValues updates filters uniqDef = copyUnlessValues = map snd fieldsToMaybeCopy recordValues = concatMap (map toPersistValue . toPersistFields) records recordPlaceholders = - Util.commaSeparated $ - map - (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) $ - records + Util.commaSeparated + $ map + (Util.parenWrapped . Util.commaSeparated . map (const "?") . toPersistFields) + $ records mkCondFieldSet n _ = T.concat [ n diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 1434894a1..87243ee7b 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -834,71 +834,71 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName ([Add' col], cols) Just (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> - let - refDrop Nothing = [] - refDrop (Just ColumnReference{crConstraintName = cname}) = - [DropReference cname] - - refAdd Nothing = [] - refAdd (Just colRef) = - case find ((== crTableName colRef) . getEntityDBName) defs of - Just refdef - | Just _oldName /= fmap fieldDB (getEntityIdField edef) -> - [ AddReference - (crTableName colRef) - (crConstraintName colRef) - [name] - (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) - (crFieldCascade colRef) - ] - Just _ -> [] - Nothing -> - error $ - "could not find the entityDef for reftable[" - ++ show (crTableName colRef) - ++ "]" - modRef = - if equivalentRef ref ref' - then [] - else refDrop ref' ++ refAdd ref - modNull = case (isNull, isNull') of - (True, False) -> do - guard $ Just name /= fmap fieldDB (getEntityIdField edef) - pure (IsNull col) - (False, True) -> - let - up = case def of - Nothing -> id - Just s -> (:) (Update' col s) - in - up [NotNull col] - _ -> [] - modType - | sqlTypeEq sqltype sqltype' = [] - -- When converting from Persistent pre-2.0 databases, we - -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is - -- treated as UTC. - | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = - [ ChangeType col sqltype $ - T.concat - [ " USING " - , escapeF name - , " AT TIME ZONE 'UTC'" - ] - ] - | otherwise = [ChangeType col sqltype ""] - modDef = - if def == def' - || isJust (T.stripPrefix "nextval" =<< def') - then [] - else case def of - Nothing -> [NoDefault col] - Just s -> [Default col s] - dropSafe = - if safeToRemove edef name - then error "wtf" [Drop col True] - else [] - in - ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe - , filter (\c -> cName c /= name) cols - ) + let + refDrop Nothing = [] + refDrop (Just ColumnReference{crConstraintName = cname}) = + [DropReference cname] + + refAdd Nothing = [] + refAdd (Just colRef) = + case find ((== crTableName colRef) . getEntityDBName) defs of + Just refdef + | Just _oldName /= fmap fieldDB (getEntityIdField edef) -> + [ AddReference + (crTableName colRef) + (crConstraintName colRef) + [name] + (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) + (crFieldCascade colRef) + ] + Just _ -> [] + Nothing -> + error $ + "could not find the entityDef for reftable[" + ++ show (crTableName colRef) + ++ "]" + modRef = + if equivalentRef ref ref' + then [] + else refDrop ref' ++ refAdd ref + modNull = case (isNull, isNull') of + (True, False) -> do + guard $ Just name /= fmap fieldDB (getEntityIdField edef) + pure (IsNull col) + (False, True) -> + let + up = case def of + Nothing -> id + Just s -> (:) (Update' col s) + in + up [NotNull col] + _ -> [] + modType + | sqlTypeEq sqltype sqltype' = [] + -- When converting from Persistent pre-2.0 databases, we + -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is + -- treated as UTC. + | sqltype == SqlDayTime && sqltype' == SqlOther "timestamp" = + [ ChangeType col sqltype $ + T.concat + [ " USING " + , escapeF name + , " AT TIME ZONE 'UTC'" + ] + ] + | otherwise = [ChangeType col sqltype ""] + modDef = + if def == def' + || isJust (T.stripPrefix "nextval" =<< def') + then [] + else case def of + Nothing -> [NoDefault col] + Just s -> [Default col s] + dropSafe = + if safeToRemove edef name + then error "wtf" [Drop col True] + else [] + in + ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe + , filter (\c -> cName c /= name) cols + ) From 8692779c43911c2ce8441cab21827cdbcfa3de63 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 18:03:05 -0400 Subject: [PATCH 11/30] manually add restyle --- persistent-postgresql/Database/Persist/Postgresql/Internal.hs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 87243ee7b..5deaede17 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -587,8 +587,7 @@ getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConst pure $ AlterColumn table - ( AddReference s constraintName [cname] id_ (crFieldCascade cr) - ) + (AddReference s constraintName [cname] id_ (crFieldCascade cr)) where table = getEntityDBName entity id_ = From a117b2af41f16fce0a6778cd409408715bd5b114 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 18:10:39 -0400 Subject: [PATCH 12/30] fix version @since, add newtype to SafeToRemove --- .../Database/Persist/Postgresql/Internal.hs | 21 ++++++++++--------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 5deaede17..272ceeb73 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -331,12 +331,13 @@ instance PersistFieldSql PgInterval where -- | Indicates whether a Postgres Column is safe to drop. -- --- @since 2.17.0.0 -type SafeToRemove = Bool +-- @since 2.17.1.0 +newtype SafeToRemove = SafeToRemove Bool + deriving (Show) -- | Represents a change to a Postgres column in a DB statement. -- --- @since 2.17.0.0 +-- @since 2.17.1.0 data AlterColumn = ChangeType Column SqlType Text | IsNull Column @@ -352,7 +353,7 @@ data AlterColumn -- | Represents a change to a Postgres table in a DB statement. -- --- @since 2.17.0.0 +-- @since 2.17.1.0 data AlterTable = AddUniqueConstraint ConstraintNameDB [FieldNameDB] | DropConstraint ConstraintNameDB @@ -360,7 +361,7 @@ data AlterTable -- | Represents a change to a Postgres DB in a statement. -- --- @since 2.17.0.0 +-- @since 2.17.1.0 data AlterDB = AddTable EntityNameDB EntityIdDef [Column] | AlterColumn EntityNameDB AlterColumn @@ -372,7 +373,7 @@ data AlterDB -- current state in the database to the state described in -- Haskell. -- --- @since 2.17.0.0 +-- @since 2.17.1.0 mockMigrateStructured :: [EntityDef] -> EntityDef @@ -418,7 +419,7 @@ mockMigrateStructured allDefs entity = do -- DB changes required to migrate the Entity from its current state -- in the database to the state described in Haskell. -- --- @since 2.17.0.0 +-- @since 2.17.1.0 addTable :: [Column] -> EntityDef -> AlterDB addTable cols entity = AddTable name entityId nonIdCols @@ -455,7 +456,7 @@ getAlters defs def (c1, u1) (c2, u2) = (getAltersC c1 c2, getAltersU u1 u2) where getAltersC [] old = - map (\x -> Drop x $ safeToRemove def $ cName x) old + map (\x -> Drop x $ SafeToRemove $ safeToRemove def $ cName x) old getAltersC (new : news) old = let (alters, old') = findAlters defs def new old @@ -671,7 +672,7 @@ showAlterDb (AddTable name entityId nonIdCols) = (False, rawText) showAlterDb (AlterColumn t ac) = (isUnsafe ac, showAlter t ac) where - isUnsafe (Drop _ safeRemove) = not safeRemove + isUnsafe (Drop _ (SafeToRemove safeRemove)) = not safeRemove isUnsafe _ = False showAlterDb (AlterTable t at) = (False, showAlterTable t at) @@ -895,7 +896,7 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName Just s -> [Default col s] dropSafe = if safeToRemove edef name - then error "wtf" [Drop col True] + then error "wtf" [Drop col (SafeToRemove True)] else [] in ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe From d7170d4957dcf967ad4e6823d3c51d6006c7ac8d Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 18:12:59 -0400 Subject: [PATCH 13/30] updated constructors to better names --- .../Database/Persist/Postgresql/Internal.hs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 272ceeb73..a198f647b 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -342,11 +342,11 @@ data AlterColumn = ChangeType Column SqlType Text | IsNull Column | NotNull Column - | Add' Column + | AddColumn Column | Drop Column SafeToRemove | Default Column Text | NoDefault Column - | Update' Column Text + | UpdateNullToValue Column Text | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade | DropReference ConstraintNameDB deriving (Show) @@ -722,7 +722,7 @@ showAlter table (NotNull c) = , escapeF (cName c) , " SET NOT NULL" ] -showAlter table (Add' col) = +showAlter table (AddColumn col) = T.concat [ "ALTER TABLE " , escapeE table @@ -753,7 +753,7 @@ showAlter table (NoDefault c) = , escapeF (cName c) , " DROP DEFAULT" ] -showAlter table (Update' c s) = +showAlter table (UpdateNullToValue c s) = T.concat [ "UPDATE " , escapeE table @@ -831,7 +831,7 @@ findAlters findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName _maxLen ref) cols = case List.find (\c -> cName c == name) cols of Nothing -> - ([Add' col], cols) + ([AddColumn col], cols) Just (Column _oldName isNull' sqltype' def' _gen' _defConstraintName' _maxLen' ref') -> let @@ -869,7 +869,7 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName let up = case def of Nothing -> id - Just s -> (:) (Update' col s) + Just s -> (:) (UpdateNullToValue col s) in up [NotNull col] _ -> [] From e982aa03c8d9ce9d63fa5caae222baccd5b4d02e Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 18:27:02 -0400 Subject: [PATCH 14/30] remove unneeded partition --- .../Database/Persist/Postgresql/Internal.hs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index a198f647b..9bdc369aa 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -378,10 +378,8 @@ mockMigrateStructured :: [EntityDef] -> EntityDef -> IO (Either [Text] [AlterDB]) -mockMigrateStructured allDefs entity = do - case partitionEithers [] of - ([], old'') -> return $ Right $ migrationText False old'' - (errs, _) -> return $ Left errs +mockMigrateStructured allDefs entity = + return $ Right $ migrationText False [] where name = getEntityDBName entity migrationText exists' old'' = From a88fdfb8314a7a0658eae2fa66bb0ef01c676784 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 18:34:33 -0400 Subject: [PATCH 15/30] move to NEL for reference --- .../Database/Persist/Postgresql.hs | 18 +++++++------ .../Database/Persist/Postgresql/Internal.hs | 27 ++++++++++--------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index d51cdf093..33e66cb09 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -767,16 +767,18 @@ mkForeignAlt :: EntityDef -> ForeignDef -> Maybe AlterDB -mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference +mkForeignAlt entity fdef = case childfields of + [] -> Nothing + f : r -> Just $ AlterColumn tableName_ (addReference) + where + addReference = AddReference + (foreignRefTableDBName fdef) + constraintName + (f NEL.:| r) + escapedParentFields + (foreignFieldCascade fdef) where tableName_ = getEntityDBName entity - addReference = - AddReference - (foreignRefTableDBName fdef) - constraintName - childfields - escapedParentFields - (foreignFieldCascade fdef) constraintName = foreignConstraintNameDBName fdef (childfields, parentfields) = diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 9bdc369aa..fa0addb14 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -347,7 +347,7 @@ data AlterColumn | Default Column Text | NoDefault Column | UpdateNullToValue Column Text - | AddReference EntityNameDB ConstraintNameDB [FieldNameDB] [Text] FieldCascade + | AddReference EntityNameDB ConstraintNameDB (NEL.NonEmpty FieldNameDB) [Text] FieldCascade | DropReference ConstraintNameDB deriving (Show) @@ -586,7 +586,7 @@ getAddReference allDefs entity cname cr@ColumnReference{crTableName = s, crConst pure $ AlterColumn table - (AddReference s constraintName [cname] id_ (crFieldCascade cr)) + (AddReference s constraintName (cname NEL.:| []) id_ (crFieldCascade cr)) where table = getEntityDBName entity id_ = @@ -600,16 +600,19 @@ mkForeignAlt :: EntityDef -> ForeignDef -> Maybe AlterDB -mkForeignAlt entity fdef = pure $ AlterColumn tableName_ addReference +mkForeignAlt entity fdef = case NEL.nonEmpty childfields of + Nothing -> Nothing + Just childfields' -> Just $ AlterColumn tableName_ addReference + where + addReference = + AddReference + (foreignRefTableDBName fdef) + constraintName + childfields' + escapedParentFields + (foreignFieldCascade fdef) where tableName_ = getEntityDBName entity - addReference = - AddReference - (foreignRefTableDBName fdef) - constraintName - childfields - escapedParentFields - (foreignFieldCascade fdef) constraintName = foreignConstraintNameDBName fdef (childfields, parentfields) = @@ -770,7 +773,7 @@ showAlter table (AddReference reftable fkeyname t2 id2 cascade) = , " ADD CONSTRAINT " , escapeC fkeyname , " FOREIGN KEY(" - , T.intercalate "," $ map escapeF t2 + , T.intercalate "," $ map escapeF $ NEL.toList t2 , ") REFERENCES " , escapeE reftable , "(" @@ -845,7 +848,7 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName [ AddReference (crTableName colRef) (crConstraintName colRef) - [name] + (name NEL.:| []) (NEL.toList $ Util.dbIdColumnsEsc escapeF refdef) (crFieldCascade colRef) ] From f0b001e366a40fadc7f1713fb042689d8ac5a7dd Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Tue, 1 Jul 2025 18:34:50 -0400 Subject: [PATCH 16/30] fourmolu --- .../Database/Persist/Postgresql.hs | 21 ++++++------ .../Database/Persist/Postgresql/Internal.hs | 33 +++++++++++-------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 33e66cb09..d7ed1937e 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -767,16 +767,17 @@ mkForeignAlt :: EntityDef -> ForeignDef -> Maybe AlterDB -mkForeignAlt entity fdef = case childfields of - [] -> Nothing - f : r -> Just $ AlterColumn tableName_ (addReference) - where - addReference = AddReference - (foreignRefTableDBName fdef) - constraintName - (f NEL.:| r) - escapedParentFields - (foreignFieldCascade fdef) +mkForeignAlt entity fdef = case childfields of + [] -> Nothing + f : r -> Just $ AlterColumn tableName_ (addReference) + where + addReference = + AddReference + (foreignRefTableDBName fdef) + constraintName + (f NEL.:| r) + escapedParentFields + (foreignFieldCascade fdef) where tableName_ = getEntityDBName entity constraintName = diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index fa0addb14..0cb9010f0 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -347,7 +347,12 @@ data AlterColumn | Default Column Text | NoDefault Column | UpdateNullToValue Column Text - | AddReference EntityNameDB ConstraintNameDB (NEL.NonEmpty FieldNameDB) [Text] FieldCascade + | AddReference + EntityNameDB + ConstraintNameDB + (NEL.NonEmpty FieldNameDB) + [Text] + FieldCascade | DropReference ConstraintNameDB deriving (Show) @@ -378,7 +383,7 @@ mockMigrateStructured :: [EntityDef] -> EntityDef -> IO (Either [Text] [AlterDB]) -mockMigrateStructured allDefs entity = +mockMigrateStructured allDefs entity = return $ Right $ migrationText False [] where name = getEntityDBName entity @@ -600,17 +605,17 @@ mkForeignAlt :: EntityDef -> ForeignDef -> Maybe AlterDB -mkForeignAlt entity fdef = case NEL.nonEmpty childfields of - Nothing -> Nothing - Just childfields' -> Just $ AlterColumn tableName_ addReference - where - addReference = - AddReference - (foreignRefTableDBName fdef) - constraintName - childfields' - escapedParentFields - (foreignFieldCascade fdef) +mkForeignAlt entity fdef = case NEL.nonEmpty childfields of + Nothing -> Nothing + Just childfields' -> Just $ AlterColumn tableName_ addReference + where + addReference = + AddReference + (foreignRefTableDBName fdef) + constraintName + childfields' + escapedParentFields + (foreignFieldCascade fdef) where tableName_ = getEntityDBName entity constraintName = @@ -773,7 +778,7 @@ showAlter table (AddReference reftable fkeyname t2 id2 cascade) = , " ADD CONSTRAINT " , escapeC fkeyname , " FOREIGN KEY(" - , T.intercalate "," $ map escapeF $ NEL.toList t2 + , T.intercalate "," $ map escapeF $ NEL.toList t2 , ") REFERENCES " , escapeE reftable , "(" From 49f19024c65f3e94b64a99a0a9f8213f86736d7b Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 09:30:29 -0400 Subject: [PATCH 17/30] move structured migration to internal --- .../Database/Persist/Postgresql.hs | 358 +--------------- .../Database/Persist/Postgresql/Internal.hs | 382 ++++++++++++++++++ 2 files changed, 383 insertions(+), 357 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index d7ed1937e..5006e5696 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -695,73 +695,12 @@ withStmt' conn query vals = Errors [] -> error "Got an Errors, but no exceptions" Ok v -> return v -doesTableExist - :: (Text -> IO Statement) - -> EntityNameDB - -> IO Bool -doesTableExist getter (EntityNameDB name) = do - stmt <- getter sql - with (stmtQuery stmt vals) (\src -> runConduit $ src .| start) - where - sql = - "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" - <> " AND schemaname != 'information_schema' AND tablename=?" - vals = [PersistText name] - - start = await >>= maybe (error "No results when checking doesTableExist") start' - start' [PersistInt64 0] = finish False - start' [PersistInt64 1] = finish True - start' res = error $ "doesTableExist returned unexpected result: " ++ show res - finish x = await >>= maybe (return x) (error "Too many rows returned in doesTableExist") - migrate' :: [EntityDef] -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] CautiousMigration) -migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ do - old <- getColumns getter entity newcols' - case partitionEithers old of - ([], old'') -> do - exists' <- - if null old - then doesTableExist getter name - else return True - return $ Right $ migrationText exists' old'' - (errs, _) -> return $ Left errs - where - name = getEntityDBName entity - (newcols', udefs, fdefs) = postgresMkColumns allDefs entity - migrationText exists' old'' - | not exists' = - createText newcols fdefs udspair - | otherwise = - let - (acs, ats) = - getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats - in - acs' ++ ats' - where - old' = partitionEithers old'' - newcols = filter (not . safeToRemove entity . cName) newcols' - udspair = map udToPair udefs - -- Check for table existence if there are no columns, workaround - -- for https://github.com/yesodweb/persistent/issues/152 - - createText newcols fdefs_ udspair = - (addTable newcols entity) : uniques ++ references ++ foreignsAlt - where - uniques = flip concatMap udspair $ \(uname, ucols) -> - [AlterTable name $ AddUniqueConstraint uname ucols] - references = - mapMaybe - ( \Column{cName, cReference} -> - getAddReference allDefs entity cName =<< cReference - ) - newcols - foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs_ +migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ migrateStructured allDefs getter entity mkForeignAlt :: EntityDef @@ -787,301 +726,6 @@ mkForeignAlt entity fdef = case childfields of escapedParentFields = map escapeF parentfields --- | Returns all of the columns in the given table currently in the database. -getColumns - :: (Text -> IO Statement) - -> EntityDef - -> [Column] - -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] -getColumns getter def cols = do - let - sqlv = - T.concat - [ "SELECT " - , "column_name " - , ",is_nullable " - , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below - , ",column_default " - , ",generation_expression " - , ",numeric_precision " - , ",numeric_scale " - , ",character_maximum_length " - , "FROM information_schema.columns " - , "WHERE table_catalog=current_database() " - , "AND table_schema=current_schema() " - , "AND table_name=? " - ] - - -- DOMAINS Postgres supports the concept of domains, which are data types - -- with optional constraints. An app might make an "email" domain over the - -- varchar type, with a CHECK that the emails are valid In this case the - -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN - -- foo TYPE email This code exists to use the domain name (email), instead - -- of the underlying type (varchar). This is tested in - -- EquivalentTypeTest.hs - - stmt <- getter sqlv - let - vals = - [ PersistText $ unEntityNameDB $ getEntityDBName def - ] - columns <- - with - (stmtQuery stmt vals) - (\src -> runConduit $ src .| processColumns .| CL.consume) - let - sqlc = - T.concat - [ "SELECT " - , "c.constraint_name, " - , "c.column_name " - , "FROM information_schema.key_column_usage AS c, " - , "information_schema.table_constraints AS k " - , "WHERE c.table_catalog=current_database() " - , "AND c.table_catalog=k.table_catalog " - , "AND c.table_schema=current_schema() " - , "AND c.table_schema=k.table_schema " - , "AND c.table_name=? " - , "AND c.table_name=k.table_name " - , "AND c.constraint_name=k.constraint_name " - , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " - , "ORDER BY c.constraint_name, c.column_name" - ] - - stmt' <- getter sqlc - - us <- with (stmtQuery stmt' vals) (\src -> runConduit $ src .| helperU) - return $ columns ++ us - where - refMap = - fmap (\cr -> (crTableName cr, crConstraintName cr)) $ - Map.fromList $ - List.foldl' ref [] cols - where - ref rs c = - maybe rs (\r -> (unFieldNameDB $ cName c, r) : rs) (cReference c) - getAll = - CL.mapM $ \x -> - pure $ case x of - [PersistText con, PersistText col] -> - (con, col) - [PersistByteString con, PersistByteString col] -> - (T.decodeUtf8 con, T.decodeUtf8 col) - o -> - error $ "unexpected datatype returned for postgres o=" ++ show o - helperU = do - rows <- getAll .| CL.consume - return - $ map - (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) - $ groupBy ((==) `on` fst) rows - processColumns = - CL.mapM $ \x'@((PersistText cname) : _) -> do - col <- - liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) - pure $ case col of - Left e -> Left e - Right c -> Right $ Left c - -getColumn - :: (Text -> IO Statement) - -> EntityNameDB - -> [PersistValue] - -> Maybe (EntityNameDB, ConstraintNameDB) - -> IO (Either Text Column) -getColumn - getter - tableName' - [ PersistText columnName - , PersistText isNullable - , PersistText typeName - , defaultValue - , generationExpression - , numericPrecision - , numericScale - , maxlen - ] - refName_ = runExceptT $ do - defaultValue' <- - case defaultValue of - PersistNull -> - pure Nothing - PersistText t -> - pure $ Just t - _ -> - throwError $ T.pack $ "Invalid default column: " ++ show defaultValue - - generationExpression' <- - case generationExpression of - PersistNull -> - pure Nothing - PersistText t -> - pure $ Just t - _ -> - throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression - - let - typeStr = - case maxlen of - PersistInt64 n -> - T.concat [typeName, "(", T.pack (show n), ")"] - _ -> - typeName - - t <- getType typeStr - - let - cname = FieldNameDB columnName - - ref <- lift $ fmap join $ traverse (getRef cname) refName_ - - return - Column - { cName = cname - , cNull = isNullable == "YES" - , cSqlType = t - , cDefault = fmap stripSuffixes defaultValue' - , cGenerated = fmap stripSuffixes generationExpression' - , cDefaultConstraintName = Nothing - , cMaxLen = Nothing - , cReference = fmap (\(a, b, c, d) -> ColumnReference a b (mkCascade c d)) ref - } - where - mkCascade updText delText = - FieldCascade - { fcOnUpdate = parseCascade updText - , fcOnDelete = parseCascade delText - } - - parseCascade txt = - case txt of - "NO ACTION" -> - Nothing - "CASCADE" -> - Just Cascade - "SET NULL" -> - Just SetNull - "SET DEFAULT" -> - Just SetDefault - "RESTRICT" -> - Just Restrict - _ -> - error $ "Unexpected value in parseCascade: " <> show txt - - stripSuffixes t = - loop' - [ "::character varying" - , "::text" - ] - where - loop' [] = t - loop' (p : ps) = - case T.stripSuffix p t of - Nothing -> loop' ps - Just t' -> t' - - getRef cname (_, refName') = do - let - sql = - T.concat - [ "SELECT DISTINCT " - , "ccu.table_name, " - , "tc.constraint_name, " - , "rc.update_rule, " - , "rc.delete_rule " - , "FROM information_schema.constraint_column_usage ccu " - , "INNER JOIN information_schema.key_column_usage kcu " - , " ON ccu.constraint_name = kcu.constraint_name " - , "INNER JOIN information_schema.table_constraints tc " - , " ON tc.constraint_name = kcu.constraint_name " - , "LEFT JOIN information_schema.referential_constraints AS rc" - , " ON rc.constraint_name = ccu.constraint_name " - , "WHERE tc.constraint_type='FOREIGN KEY' " - , "AND kcu.ordinal_position=1 " - , "AND kcu.table_name=? " - , "AND kcu.column_name=? " - , "AND tc.constraint_name=?" - ] - stmt <- getter sql - cntrs <- - with - ( stmtQuery - stmt - [ PersistText $ unEntityNameDB tableName' - , PersistText $ unFieldNameDB cname - , PersistText $ unConstraintNameDB refName' - ] - ) - (\src -> runConduit $ src .| CL.consume) - case cntrs of - [] -> - return Nothing - [ [ PersistText table - , PersistText constraint - , PersistText updRule - , PersistText delRule - ] - ] -> - return $ - Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) - xs -> - error $ - mconcat - [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " - , T.unpack (unEntityNameDB tableName') - , " and column: " - , T.unpack (unFieldNameDB cname) - , " but got: " - , show xs - ] - - getType "int4" = pure SqlInt32 - getType "int8" = pure SqlInt64 - getType "varchar" = pure SqlString - getType "text" = pure SqlString - getType "date" = pure SqlDay - getType "bool" = pure SqlBool - getType "timestamptz" = pure SqlDayTime - getType "float4" = pure SqlReal - getType "float8" = pure SqlReal - getType "bytea" = pure SqlBlob - getType "time" = pure SqlTime - getType "numeric" = getNumeric numericPrecision numericScale - getType a = pure $ SqlOther a - - getNumeric (PersistInt64 a) (PersistInt64 b) = - pure $ SqlNumeric (fromIntegral a) (fromIntegral b) - getNumeric PersistNull PersistNull = - throwError $ - T.concat - [ "No precision and scale were specified for the column: " - , columnName - , " in table: " - , unEntityNameDB tableName' - , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383," - , " which is probably not what you intended." - , " Specify the values as numeric(total_digits, digits_after_decimal_place)." - ] - getNumeric a b = - throwError $ - T.concat - [ "Can not get numeric field precision for the column: " - , columnName - , " in table: " - , unEntityNameDB tableName' - , ". Expected an integer for both precision and scale, " - , "got: " - , T.pack $ show a - , " and " - , T.pack $ show b - , ", respectively." - , " Specify the values as numeric(total_digits, digits_after_decimal_place)." - ] -getColumn _ _ columnName _ = - return $ - Left $ - T.pack $ - "Invalid result from information_schema: " ++ show columnName -- | Get the SQL string for the table that a PersistEntity represents. -- Useful for raw SQL queries. diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 0cb9010f0..74536e439 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -1,4 +1,5 @@ {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} @@ -10,6 +11,7 @@ module Database.Persist.Postgresql.Internal , AlterTable (..) , AlterColumn (..) , SafeToRemove + , migrateStructured , mockMigrateStructured , addTable , findAlters @@ -59,6 +61,15 @@ import qualified Data.Text as T import Data.Time (NominalDiffTime, localTimeToUTC, utc) import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util +import Control.Arrow +import Control.Monad.Except +import Data.Acquire (with) +import Data.Conduit +import qualified Data.Conduit.List as CL +import Data.Function (on) +import Data.List as List ( foldl', groupBy) +import qualified Data.Map as Map +import qualified Data.Text.Encoding as T -- | Newtype used to avoid orphan instances for @postgresql-simple@ classes. -- @@ -373,6 +384,61 @@ data AlterDB | AlterTable EntityNameDB AlterTable deriving (Show) +-- | Returns a structured representation of all of the +-- DB changes required to migrate the Entity from its +-- current state in the database to the state described in +-- Haskell. +-- +-- @since 2.17.1.0 +migrateStructured + :: [EntityDef] + -> (Text -> IO Statement) + -> EntityDef + -> IO (Either [Text] [AlterDB]) +migrateStructured allDefs getter entity = do + old <- getColumns getter entity newcols' + case partitionEithers old of + ([], old'') -> do + exists' <- + if null old + then doesTableExist getter name + else return True + return $ Right $ migrationText exists' old'' + (errs, _) -> return $ Left errs + where + name = getEntityDBName entity + (newcols', udefs, fdefs) = postgresMkColumns allDefs entity + migrationText exists' old'' + | not exists' = + createText newcols fdefs udspair + | otherwise = + let + (acs, ats) = + getAlters allDefs entity (newcols, udspair) old' + acs' = map (AlterColumn name) acs + ats' = map (AlterTable name) ats + in + acs' ++ ats' + where + old' = partitionEithers old'' + newcols = filter (not . safeToRemove entity . cName) newcols' + udspair = map udToPair udefs + -- Check for table existence if there are no columns, workaround + -- for https://github.com/yesodweb/persistent/issues/152 + + createText newcols fdefs_ udspair = + (addTable newcols entity) : uniques ++ references ++ foreignsAlt + where + uniques = flip concatMap udspair $ \(uname, ucols) -> + [AlterTable name $ AddUniqueConstraint uname ucols] + references = + mapMaybe + ( \Column{cName, cReference} -> + getAddReference allDefs entity cName =<< cReference + ) + newcols + foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs_ + -- | Returns a structured representation of all of the -- DB changes required to migrate the Entity from its -- current state in the database to the state described in @@ -908,3 +974,319 @@ findAlters defs edef col@(Column name isNull sqltype def _gen _defConstraintName ( modRef ++ modDef ++ modNull ++ modType ++ dropSafe , filter (\c -> cName c /= name) cols ) + +-- | Returns all of the columns in the given table currently in the database. +getColumns + :: (Text -> IO Statement) + -> EntityDef + -> [Column] + -> IO [Either Text (Either Column (ConstraintNameDB, [FieldNameDB]))] +getColumns getter def cols = do + let + sqlv = + T.concat + [ "SELECT " + , "column_name " + , ",is_nullable " + , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below + , ",column_default " + , ",generation_expression " + , ",numeric_precision " + , ",numeric_scale " + , ",character_maximum_length " + , "FROM information_schema.columns " + , "WHERE table_catalog=current_database() " + , "AND table_schema=current_schema() " + , "AND table_name=? " + ] + + -- DOMAINS Postgres supports the concept of domains, which are data types + -- with optional constraints. An app might make an "email" domain over the + -- varchar type, with a CHECK that the emails are valid In this case the + -- generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN + -- foo TYPE email This code exists to use the domain name (email), instead + -- of the underlying type (varchar). This is tested in + -- EquivalentTypeTest.hs + + stmt <- getter sqlv + let + vals = + [ PersistText $ unEntityNameDB $ getEntityDBName def + ] + columns <- + with + (stmtQuery stmt vals) + (\src -> runConduit $ src .| processColumns .| CL.consume) + let + sqlc = + T.concat + [ "SELECT " + , "c.constraint_name, " + , "c.column_name " + , "FROM information_schema.key_column_usage AS c, " + , "information_schema.table_constraints AS k " + , "WHERE c.table_catalog=current_database() " + , "AND c.table_catalog=k.table_catalog " + , "AND c.table_schema=current_schema() " + , "AND c.table_schema=k.table_schema " + , "AND c.table_name=? " + , "AND c.table_name=k.table_name " + , "AND c.constraint_name=k.constraint_name " + , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') " + , "ORDER BY c.constraint_name, c.column_name" + ] + + stmt' <- getter sqlc + + us <- with (stmtQuery stmt' vals) (\src -> runConduit $ src .| helperU) + return $ columns ++ us + where + refMap = + fmap (\cr -> (crTableName cr, crConstraintName cr)) $ + Map.fromList $ + List.foldl' ref [] cols + where + ref rs c = + maybe rs (\r -> (unFieldNameDB $ cName c, r) : rs) (cReference c) + getAll = + CL.mapM $ \x -> + pure $ case x of + [PersistText con, PersistText col] -> + (con, col) + [PersistByteString con, PersistByteString col] -> + (T.decodeUtf8 con, T.decodeUtf8 col) + o -> + error $ "unexpected datatype returned for postgres o=" ++ show o + helperU = do + rows <- getAll .| CL.consume + return + $ map + (Right . Right . (ConstraintNameDB . fst . head &&& map (FieldNameDB . snd))) + $ groupBy ((==) `on` fst) rows + processColumns = + CL.mapM $ \x'@((PersistText cname) : _) -> do + col <- + liftIO $ getColumn getter (getEntityDBName def) x' (Map.lookup cname refMap) + pure $ case col of + Left e -> Left e + Right c -> Right $ Left c + + +getColumn + :: (Text -> IO Statement) + -> EntityNameDB + -> [PersistValue] + -> Maybe (EntityNameDB, ConstraintNameDB) + -> IO (Either Text Column) +getColumn + getter + tableName' + [ PersistText columnName + , PersistText isNullable + , PersistText typeName + , defaultValue + , generationExpression + , numericPrecision + , numericScale + , maxlen + ] + refName_ = runExceptT $ do + defaultValue' <- + case defaultValue of + PersistNull -> + pure Nothing + PersistText t -> + pure $ Just t + _ -> + throwError $ T.pack $ "Invalid default column: " ++ show defaultValue + + generationExpression' <- + case generationExpression of + PersistNull -> + pure Nothing + PersistText t -> + pure $ Just t + _ -> + throwError $ T.pack $ "Invalid generated column: " ++ show generationExpression + + let + typeStr = + case maxlen of + PersistInt64 n -> + T.concat [typeName, "(", T.pack (show n), ")"] + _ -> + typeName + + t <- getType typeStr + + let + cname = FieldNameDB columnName + + ref <- lift $ fmap join $ traverse (getRef cname) refName_ + + return + Column + { cName = cname + , cNull = isNullable == "YES" + , cSqlType = t + , cDefault = fmap stripSuffixes defaultValue' + , cGenerated = fmap stripSuffixes generationExpression' + , cDefaultConstraintName = Nothing + , cMaxLen = Nothing + , cReference = fmap (\(a, b, c, d) -> ColumnReference a b (mkCascade c d)) ref + } + where + mkCascade updText delText = + FieldCascade + { fcOnUpdate = parseCascade updText + , fcOnDelete = parseCascade delText + } + + parseCascade txt = + case txt of + "NO ACTION" -> + Nothing + "CASCADE" -> + Just Cascade + "SET NULL" -> + Just SetNull + "SET DEFAULT" -> + Just SetDefault + "RESTRICT" -> + Just Restrict + _ -> + error $ "Unexpected value in parseCascade: " <> show txt + + stripSuffixes t = + loop' + [ "::character varying" + , "::text" + ] + where + loop' [] = t + loop' (p : ps) = + case T.stripSuffix p t of + Nothing -> loop' ps + Just t' -> t' + + getRef cname (_, refName') = do + let + sql = + T.concat + [ "SELECT DISTINCT " + , "ccu.table_name, " + , "tc.constraint_name, " + , "rc.update_rule, " + , "rc.delete_rule " + , "FROM information_schema.constraint_column_usage ccu " + , "INNER JOIN information_schema.key_column_usage kcu " + , " ON ccu.constraint_name = kcu.constraint_name " + , "INNER JOIN information_schema.table_constraints tc " + , " ON tc.constraint_name = kcu.constraint_name " + , "LEFT JOIN information_schema.referential_constraints AS rc" + , " ON rc.constraint_name = ccu.constraint_name " + , "WHERE tc.constraint_type='FOREIGN KEY' " + , "AND kcu.ordinal_position=1 " + , "AND kcu.table_name=? " + , "AND kcu.column_name=? " + , "AND tc.constraint_name=?" + ] + stmt <- getter sql + cntrs <- + with + ( stmtQuery + stmt + [ PersistText $ unEntityNameDB tableName' + , PersistText $ unFieldNameDB cname + , PersistText $ unConstraintNameDB refName' + ] + ) + (\src -> runConduit $ src .| CL.consume) + case cntrs of + [] -> + return Nothing + [ [ PersistText table + , PersistText constraint + , PersistText updRule + , PersistText delRule + ] + ] -> + return $ + Just (EntityNameDB table, ConstraintNameDB constraint, updRule, delRule) + xs -> + error $ + mconcat + [ "Postgresql.getColumn: error fetching constraints. Expected a single result for foreign key query for table: " + , T.unpack (unEntityNameDB tableName') + , " and column: " + , T.unpack (unFieldNameDB cname) + , " but got: " + , show xs + ] + + getType "int4" = pure SqlInt32 + getType "int8" = pure SqlInt64 + getType "varchar" = pure SqlString + getType "text" = pure SqlString + getType "date" = pure SqlDay + getType "bool" = pure SqlBool + getType "timestamptz" = pure SqlDayTime + getType "float4" = pure SqlReal + getType "float8" = pure SqlReal + getType "bytea" = pure SqlBlob + getType "time" = pure SqlTime + getType "numeric" = getNumeric numericPrecision numericScale + getType a = pure $ SqlOther a + + getNumeric (PersistInt64 a) (PersistInt64 b) = + pure $ SqlNumeric (fromIntegral a) (fromIntegral b) + getNumeric PersistNull PersistNull = + throwError $ + T.concat + [ "No precision and scale were specified for the column: " + , columnName + , " in table: " + , unEntityNameDB tableName' + , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383," + , " which is probably not what you intended." + , " Specify the values as numeric(total_digits, digits_after_decimal_place)." + ] + getNumeric a b = + throwError $ + T.concat + [ "Can not get numeric field precision for the column: " + , columnName + , " in table: " + , unEntityNameDB tableName' + , ". Expected an integer for both precision and scale, " + , "got: " + , T.pack $ show a + , " and " + , T.pack $ show b + , ", respectively." + , " Specify the values as numeric(total_digits, digits_after_decimal_place)." + ] +getColumn _ _ columnName _ = + return $ + Left $ + T.pack $ + "Invalid result from information_schema: " ++ show columnName + +doesTableExist + :: (Text -> IO Statement) + -> EntityNameDB + -> IO Bool +doesTableExist getter (EntityNameDB name) = do + stmt <- getter sql + with (stmtQuery stmt vals) (\src -> runConduit $ src .| start) + where + sql = + "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'" + <> " AND schemaname != 'information_schema' AND tablename=?" + vals = [PersistText name] + + start = await >>= maybe (error "No results when checking doesTableExist") start' + start' [PersistInt64 0] = finish False + start' [PersistInt64 1] = finish True + start' res = error $ "doesTableExist returned unexpected result: " ++ show res + finish x = await >>= maybe (return x) (error "Too many rows returned in doesTableExist") From 05d54d7da2a4eb6318ad2bd2b295e31f52d0165f Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 09:30:46 -0400 Subject: [PATCH 18/30] fourmolu --- .../Database/Persist/Postgresql.hs | 1 - .../Database/Persist/Postgresql/Internal.hs | 24 +++++++++---------- 2 files changed, 11 insertions(+), 14 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 5006e5696..729cc19bf 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -726,7 +726,6 @@ mkForeignAlt entity fdef = case childfields of escapedParentFields = map escapeF parentfields - -- | Get the SQL string for the table that a PersistEntity represents. -- Useful for raw SQL queries. tableName :: (PersistEntity record) => record -> Text diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 74536e439..5ce5cf346 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -1,5 +1,5 @@ -{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ViewPatterns #-} @@ -40,36 +40,35 @@ import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS import qualified Database.PostgreSQL.Simple.Types as PG import qualified Blaze.ByteString.Builder.Char8 as BBB +import Control.Arrow import Control.Monad +import Control.Monad.Except +import Data.Acquire (with) import qualified Data.Attoparsec.ByteString.Char8 as P import Data.Bits ((.&.)) import Data.ByteString (ByteString) import qualified Data.ByteString.Builder as BB import qualified Data.ByteString.Char8 as B8 import Data.Char (ord) +import Data.Conduit +import qualified Data.Conduit.List as CL import Data.Data (Typeable) import Data.Either (partitionEithers) import Data.Fixed (Fixed (..), Pico) +import Data.Function (on) import Data.Int (Int64) import qualified Data.IntMap as I -import Data.List as List (find, sort) +import Data.List as List (find, foldl', groupBy, sort) import qualified Data.List.NonEmpty as NEL +import qualified Data.Map as Map import Data.Maybe import Data.String.Conversions.Monomorphic (toStrictByteString) import Data.Text (Text) import qualified Data.Text as T +import qualified Data.Text.Encoding as T import Data.Time (NominalDiffTime, localTimeToUTC, utc) import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util -import Control.Arrow -import Control.Monad.Except -import Data.Acquire (with) -import Data.Conduit -import qualified Data.Conduit.List as CL -import Data.Function (on) -import Data.List as List ( foldl', groupBy) -import qualified Data.Map as Map -import qualified Data.Text.Encoding as T -- | Newtype used to avoid orphan instances for @postgresql-simple@ classes. -- @@ -395,7 +394,7 @@ migrateStructured -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [AlterDB]) -migrateStructured allDefs getter entity = do +migrateStructured allDefs getter entity = do old <- getColumns getter entity newcols' case partitionEithers old of ([], old'') -> do @@ -1071,7 +1070,6 @@ getColumns getter def cols = do Left e -> Left e Right c -> Right $ Left c - getColumn :: (Text -> IO Statement) -> EntityNameDB From 7439e38b19c160775dc261116df6fd03ad7b0776 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 09:32:41 -0400 Subject: [PATCH 19/30] remove unused --- .../Database/Persist/Postgresql.hs | 41 +------------------ 1 file changed, 2 insertions(+), 39 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 729cc19bf..c7b270b20 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -76,13 +76,11 @@ import Database.PostgreSQL.Simple.Ok (Ok (..)) import qualified Database.PostgreSQL.Simple.Transaction as PG import qualified Database.PostgreSQL.Simple.Types as PG -import Control.Arrow import Control.Exception (Exception, throw, throwIO) import Control.Monad import Control.Monad.Except -import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO) +import Control.Monad.IO.Unlift (MonadUnliftIO) import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) -import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Reader (ReaderT (..), asks, runReaderT) #if !MIN_VERSION_base(4,12,0) import Control.Monad.Trans.Reader (withReaderT) @@ -91,25 +89,19 @@ import Control.Monad.Trans.Writer (WriterT (..), runWriterT) import qualified Data.List.NonEmpty as NEL import Data.Proxy (Proxy (..)) -import Data.Acquire (Acquire, mkAcquire, with) +import Data.Acquire (Acquire, mkAcquire) import Data.Aeson import Data.Aeson.Types (modifyFailure) import qualified Data.Attoparsec.Text as AT import Data.ByteString (ByteString) import qualified Data.ByteString.Char8 as B8 import Data.Conduit -import qualified Data.Conduit.List as CL import Data.Data (Data) import Data.Either (partitionEithers) -import Data.Function (on) import Data.IORef import Data.Int (Int64) -import Data.List as List (find, foldl', groupBy, sort) -import qualified Data.List as List import Data.List.NonEmpty (NonEmpty) import qualified Data.Map as Map -import Data.Maybe -import Data.Monoid ((<>)) import qualified Data.Monoid as Monoid import Data.Pool (Pool) import Data.Text (Text) @@ -127,11 +119,6 @@ import Database.Persist.Postgresql.Internal import Database.Persist.Sql import qualified Database.Persist.Sql.Util as Util import Database.Persist.SqlBackend -import Database.Persist.SqlBackend.StatementCache - ( StatementCache - , mkSimpleStatementCache - , mkStatementCache - ) import System.IO.Unsafe (unsafePerformIO) -- | A @libpq@ connection string. A simple example of connection @@ -702,30 +689,6 @@ migrate' -> IO (Either [Text] CautiousMigration) migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ migrateStructured allDefs getter entity -mkForeignAlt - :: EntityDef - -> ForeignDef - -> Maybe AlterDB -mkForeignAlt entity fdef = case childfields of - [] -> Nothing - f : r -> Just $ AlterColumn tableName_ (addReference) - where - addReference = - AddReference - (foreignRefTableDBName fdef) - constraintName - (f NEL.:| r) - escapedParentFields - (foreignFieldCascade fdef) - where - tableName_ = getEntityDBName entity - constraintName = - foreignConstraintNameDBName fdef - (childfields, parentfields) = - unzip (map (\((_, b), (_, d)) -> (b, d)) (foreignFields fdef)) - escapedParentFields = - map escapeF parentfields - -- | Get the SQL string for the table that a PersistEntity represents. -- Useful for raw SQL queries. tableName :: (PersistEntity record) => record -> Text From 77efee2508cc5796d2b90d18e4dd6f18af1cf7dc Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 09:39:11 -0400 Subject: [PATCH 20/30] Removed redundant code from mock migration --- .../Database/Persist/Postgresql/Internal.hs | 21 +++++-------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 5ce5cf346..4ac9246e4 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -439,31 +439,20 @@ migrateStructured allDefs getter entity = do foreignsAlt = mapMaybe (mkForeignAlt entity) fdefs_ -- | Returns a structured representation of all of the --- DB changes required to migrate the Entity from its --- current state in the database to the state described in --- Haskell. +-- DB changes required to migrate the Entity to the state +-- described in Haskell, assuming it currently does not +-- exist in the database. -- -- @since 2.17.1.0 mockMigrateStructured :: [EntityDef] -> EntityDef -> IO (Either [Text] [AlterDB]) -mockMigrateStructured allDefs entity = - return $ Right $ migrationText False [] +mockMigrateStructured allDefs entity = return $ Right migrationText where name = getEntityDBName entity - migrationText exists' old'' = - if not exists' - then createText newcols fdefs udspair - else - let - (acs, ats) = getAlters allDefs entity (newcols, udspair) old' - acs' = map (AlterColumn name) acs - ats' = map (AlterTable name) ats - in - acs' ++ ats' + migrationText = createText newcols fdefs udspair where - old' = partitionEithers old'' (newcols', udefs, fdefs) = postgresMkColumns allDefs entity newcols = filter (not . safeToRemove entity . cName) newcols' udspair = map udToPair udefs From f5b0912dd1a85162bd390a5de3c76d0287ff1cdd Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 14:13:15 -0400 Subject: [PATCH 21/30] tests compile --- persistent-postgresql/Database/Persist/Postgresql.hs | 2 +- persistent-postgresql/Database/Persist/Postgresql/Internal.hs | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index c7b270b20..f1dafa4c0 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -79,7 +79,7 @@ import qualified Database.PostgreSQL.Simple.Types as PG import Control.Exception (Exception, throw, throwIO) import Control.Monad import Control.Monad.Except -import Control.Monad.IO.Unlift (MonadUnliftIO) +import Control.Monad.IO.Unlift (MonadIO (..),MonadUnliftIO) import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) import Control.Monad.Trans.Reader (ReaderT (..), asks, runReaderT) #if !MIN_VERSION_base(4,12,0) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 4ac9246e4..f47390e9c 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -43,6 +43,8 @@ import qualified Blaze.ByteString.Builder.Char8 as BBB import Control.Arrow import Control.Monad import Control.Monad.Except +import Control.Monad.Trans.Class (lift) +import Control.Monad.IO.Unlift (MonadIO(..)) import Data.Acquire (with) import qualified Data.Attoparsec.ByteString.Char8 as P import Data.Bits ((.&.)) From 5a69456405e63ad63dd4ad32758078b130864653 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 14:13:35 -0400 Subject: [PATCH 22/30] Fourmolu --- persistent-postgresql/Database/Persist/Postgresql.hs | 2 +- persistent-postgresql/Database/Persist/Postgresql/Internal.hs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index f1dafa4c0..461f23e89 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -79,7 +79,7 @@ import qualified Database.PostgreSQL.Simple.Types as PG import Control.Exception (Exception, throw, throwIO) import Control.Monad import Control.Monad.Except -import Control.Monad.IO.Unlift (MonadIO (..),MonadUnliftIO) +import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO) import Control.Monad.Logger (MonadLoggerIO, runNoLoggingT) import Control.Monad.Trans.Reader (ReaderT (..), asks, runReaderT) #if !MIN_VERSION_base(4,12,0) diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index f47390e9c..203154a5a 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -43,8 +43,8 @@ import qualified Blaze.ByteString.Builder.Char8 as BBB import Control.Arrow import Control.Monad import Control.Monad.Except +import Control.Monad.IO.Unlift (MonadIO (..)) import Control.Monad.Trans.Class (lift) -import Control.Monad.IO.Unlift (MonadIO(..)) import Data.Acquire (with) import qualified Data.Attoparsec.ByteString.Char8 as P import Data.Bits ((.&.)) From 32a6d686126cc9efc89e9e6c22ee79832b715c94 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Wed, 2 Jul 2025 14:36:10 -0400 Subject: [PATCH 23/30] fixed typo --- persistent/ChangeLog.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/persistent/ChangeLog.md b/persistent/ChangeLog.md index 927628bda..c52a3f7b1 100644 --- a/persistent/ChangeLog.md +++ b/persistent/ChangeLog.md @@ -3,9 +3,9 @@ # 2.17.1.0 * [#1600](https://github.com/yesodweb/persistent/pull/1600) - * Add `mockMigrateStructured` to `Database.Persist.Postgresql.Internal`. - This allows you to access a structured representation of the mock migrations - for use in yor application. + * Add `migrateStructured` to `Database.Persist.Postgresql.Internal`. + This allows you to access a structured representation of the proposed migrations + for use in your application. # 2.17.0.0 From 85a22aeab12fa41a78858069f594ea0cc79efd51 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Thu, 3 Jul 2025 08:01:00 -0400 Subject: [PATCH 24/30] update type --- persistent-postgresql/Database/Persist/Postgresql.hs | 2 +- persistent-postgresql/Database/Persist/Postgresql/Internal.hs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index 461f23e89..ad51fdf59 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -825,7 +825,7 @@ mockMigrate -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [(Bool, Text)]) -mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ mockMigrateStructured allDefs entity +mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ return $ Right $ mockMigrateStructured allDefs entity -- | Mock a migration even when the database is not present. -- This function performs the same functionality of 'printMigration' diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs index 203154a5a..7058145e5 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal.hs @@ -449,8 +449,8 @@ migrateStructured allDefs getter entity = do mockMigrateStructured :: [EntityDef] -> EntityDef - -> IO (Either [Text] [AlterDB]) -mockMigrateStructured allDefs entity = return $ Right migrationText + -> [AlterDB] +mockMigrateStructured allDefs entity = migrationText where name = getEntityDBName entity migrationText = createText newcols fdefs udspair From 9929a469a9797851523736cc2e9035aa8e0bc55a Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Thu, 3 Jul 2025 08:01:30 -0400 Subject: [PATCH 25/30] fourmolu --- persistent-postgresql/Database/Persist/Postgresql.hs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index ad51fdf59..c253b870b 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -825,7 +825,11 @@ mockMigrate -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [(Bool, Text)]) -mockMigrate allDefs _ entity = fmap (fmap $ map showAlterDb) $ return $ Right $ mockMigrateStructured allDefs entity +mockMigrate allDefs _ entity = + fmap (fmap $ map showAlterDb) $ + return $ + Right $ + mockMigrateStructured allDefs entity -- | Mock a migration even when the database is not present. -- This function performs the same functionality of 'printMigration' From f0f82f9b51601add3a6d2130c689b5602e0ae066 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Mon, 7 Jul 2025 13:13:40 -0400 Subject: [PATCH 26/30] update changelog --- persistent-postgresql/ChangeLog.md | 6 +++++- persistent/ChangeLog.md | 7 ------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/persistent-postgresql/ChangeLog.md b/persistent-postgresql/ChangeLog.md index 14cc261c3..b91d977df 100644 --- a/persistent-postgresql/ChangeLog.md +++ b/persistent-postgresql/ChangeLog.md @@ -1,7 +1,11 @@ # Changelog for persistent-postgresql -## Unreleased +## 2.13.7.0 +* [#1600](https://github.com/yesodweb/persistent/pull/1600) + * Add `migrateStructured` to `Database.Persist.Postgresql.Internal`. + This allows you to access a structured representation of the proposed migrations + for use in your application. * [#1547](https://github.com/yesodweb/persistent/pull/1547) * Bump `libpq` bounds diff --git a/persistent/ChangeLog.md b/persistent/ChangeLog.md index c52a3f7b1..de47d1f4d 100644 --- a/persistent/ChangeLog.md +++ b/persistent/ChangeLog.md @@ -1,12 +1,5 @@ # Changelog for persistent -# 2.17.1.0 - -* [#1600](https://github.com/yesodweb/persistent/pull/1600) - * Add `migrateStructured` to `Database.Persist.Postgresql.Internal`. - This allows you to access a structured representation of the proposed migrations - for use in your application. - # 2.17.0.0 * [#1595](https://github.com/yesodweb/persistent/pull/1595) From a11ccc21eb5876e851dda2fdf1a2e5b29338c94f Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Mon, 12 Jan 2026 14:31:33 -0500 Subject: [PATCH 27/30] thread default through --- .../Database/Persist/Postgresql.hs | 18 +++--- .../Persist/Postgresql/Internal/Migration.hs | 44 +++++++++------ persistent-postgresql/test/MigrationSpec.hs | 56 +++++++++++++++++-- persistent/Database/Persist/Sql.hs | 4 +- persistent/Database/Persist/Sql/Internal.hs | 34 ++++++++++- 5 files changed, 122 insertions(+), 34 deletions(-) diff --git a/persistent-postgresql/Database/Persist/Postgresql.hs b/persistent-postgresql/Database/Persist/Postgresql.hs index c253b870b..226f55e73 100644 --- a/persistent-postgresql/Database/Persist/Postgresql.hs +++ b/persistent-postgresql/Database/Persist/Postgresql.hs @@ -506,7 +506,7 @@ createBackend logFunc serverVersion smap conn = , connStmtMap = smap , connInsertSql = insertSql' , connClose = PG.close conn - , connMigrateSql = migrate' + , connMigrateSql = migrate' emptyBackendSpecificOverrides , connBegin = \_ mIsolation -> case mIsolation of Nothing -> PG.begin conn Just iso -> @@ -683,11 +683,14 @@ withStmt' conn query vals = Ok v -> return v migrate' - :: [EntityDef] + :: BackendSpecificOverrides + -> [EntityDef] -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] CautiousMigration) -migrate' allDefs getter entity = fmap (fmap $ map showAlterDb) $ migrateStructured allDefs getter entity +migrate' overrides allDefs getter entity = + fmap (fmap $ map showAlterDb) $ + migrateStructured overrides allDefs getter entity -- | Get the SQL string for the table that a PersistEntity represents. -- Useful for raw SQL queries. @@ -821,15 +824,16 @@ defaultPostgresConfHooks = } mockMigrate - :: [EntityDef] + :: BackendSpecificOverrides + -> [EntityDef] -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [(Bool, Text)]) -mockMigrate allDefs _ entity = +mockMigrate overrides allDefs _ entity = fmap (fmap $ map showAlterDb) $ return $ Right $ - mockMigrateStructured allDefs entity + mockMigrateStructured overrides allDefs entity -- | Mock a migration even when the database is not present. -- This function performs the same functionality of 'printMigration' @@ -852,7 +856,7 @@ mockMigration mig = do , connInsertSql = undefined , connStmtMap = smap , connClose = undefined - , connMigrateSql = mockMigrate + , connMigrateSql = mockMigrate emptyBackendSpecificOverrides , connBegin = undefined , connCommit = undefined , connRollback = undefined diff --git a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs index 321a41445..051244e16 100644 --- a/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs +++ b/persistent-postgresql/Database/Persist/Postgresql/Internal/Migration.hs @@ -39,12 +39,13 @@ import qualified Database.Persist.Sql.Util as Util -- -- @since 2.17.1.0 migrateStructured - :: [EntityDef] + :: BackendSpecificOverrides + -> [EntityDef] -> (Text -> IO Statement) -> EntityDef -> IO (Either [Text] [AlterDB]) -migrateStructured allDefs getter entity = - migrateEntitiesStructured getter allDefs [entity] +migrateStructured overrides allDefs getter entity = + migrateEntitiesStructured overrides getter allDefs [entity] -- | Returns a structured representation of all of the DB changes required to -- migrate the listed entities from their current state in the database to the @@ -54,15 +55,16 @@ migrateStructured allDefs getter entity = -- -- @since 2.14.1.0 migrateEntitiesStructured - :: (Text -> IO Statement) + :: BackendSpecificOverrides + -> (Text -> IO Statement) -> [EntityDef] -> [EntityDef] -> IO (Either [Text] [AlterDB]) -migrateEntitiesStructured getStmt allDefs defsToMigrate = do +migrateEntitiesStructured overrides getStmt allDefs defsToMigrate = do r <- collectSchemaState getStmt (map getEntityDBName defsToMigrate) pure $ case r of Right schemaState -> - migrateEntitiesFromSchemaState schemaState allDefs defsToMigrate + migrateEntitiesFromSchemaState overrides schemaState allDefs defsToMigrate Left err -> Left [err] @@ -73,11 +75,12 @@ migrateEntitiesStructured getStmt allDefs defsToMigrate = do -- -- @since 2.17.1.0 mockMigrateStructured - :: [EntityDef] + :: BackendSpecificOverrides + -> [EntityDef] -> EntityDef -> [AlterDB] -mockMigrateStructured allDefs entity = - migrateEntityFromSchemaState EntityDoesNotExist allDefs entity +mockMigrateStructured overrides allDefs entity = + migrateEntityFromSchemaState overrides EntityDoesNotExist allDefs entity -- | In order to ensure that generating migrations is fast and avoids N+1 -- queries, we split it into two phases. The first phase involves querying the @@ -532,11 +535,12 @@ mapLeft _ (Right x) = Right x mapLeft f (Left x) = Left (f x) migrateEntitiesFromSchemaState - :: SchemaState + :: BackendSpecificOverrides + -> SchemaState -> [EntityDef] -> [EntityDef] -> Either [Text] [AlterDB] -migrateEntitiesFromSchemaState (SchemaState schemaStateMap) allDefs defsToMigrate = +migrateEntitiesFromSchemaState overrides (SchemaState schemaStateMap) allDefs defsToMigrate = let go :: EntityDef -> Either Text [AlterDB] go entity = do @@ -544,7 +548,7 @@ migrateEntitiesFromSchemaState (SchemaState schemaStateMap) allDefs defsToMigrat name = getEntityDBName entity case Map.lookup name schemaStateMap of Just entityState -> - Right $ migrateEntityFromSchemaState entityState allDefs entity + Right $ migrateEntityFromSchemaState overrides entityState allDefs entity Nothing -> Left $ T.pack $ "No entry for entity in schemaState: " <> show name in @@ -553,11 +557,12 @@ migrateEntitiesFromSchemaState (SchemaState schemaStateMap) allDefs defsToMigrat (errs, _) -> Left errs migrateEntityFromSchemaState - :: EntitySchemaState + :: BackendSpecificOverrides + -> EntitySchemaState -> [EntityDef] -> EntityDef -> [AlterDB] -migrateEntityFromSchemaState schemaState allDefs entity = +migrateEntityFromSchemaState overrides schemaState allDefs entity = case schemaState of EntityDoesNotExist -> (addTable newcols entity) : uniques ++ references ++ foreignsAlt @@ -577,7 +582,7 @@ migrateEntityFromSchemaState schemaState allDefs entity = acs' ++ ats' where name = getEntityDBName entity - (newcols', udefs, fdefs) = postgresMkColumns allDefs entity + (newcols', udefs, fdefs) = postgresMkColumns overrides allDefs entity newcols = filter (not . safeToRemove entity . cName) newcols' udspair = map udToPair udefs @@ -822,10 +827,13 @@ refName (EntityNameDB table) (FieldNameDB column) = | otherwise = shortenNames overhead (x, y - 1) postgresMkColumns - :: [EntityDef] -> EntityDef -> ([Column], [UniqueDef], [ForeignDef]) -postgresMkColumns allDefs t = + :: BackendSpecificOverrides + -> [EntityDef] + -> EntityDef + -> ([Column], [UniqueDef], [ForeignDef]) +postgresMkColumns overrides allDefs t = mkColumns allDefs t $ - setBackendSpecificForeignKeyName refName emptyBackendSpecificOverrides + setBackendSpecificForeignKeyName refName overrides -- | Check if a column name is listed as the "safe to remove" in the entity -- list. diff --git a/persistent-postgresql/test/MigrationSpec.hs b/persistent-postgresql/test/MigrationSpec.hs index 7600d06cc..b256561b6 100644 --- a/persistent-postgresql/test/MigrationSpec.hs +++ b/persistent-postgresql/test/MigrationSpec.hs @@ -67,7 +67,7 @@ FKChildV1 sql=migration_fk_child -- Simulate creating a new FK field on an existing table FKChildV2 sql=migration_fk_child - parentId FKParentId + parentId FKParentId OnUpdateNoAction ExplicitPrimaryKey sql=explicit_primary_key Id Text @@ -585,7 +585,12 @@ spec = describe "MigrationSpec" $ do getter <- getStmtGetter result <- - liftIO $ migrateEntitiesStructured getter allEntityDefs allEntityDefs + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + allEntityDefs + allEntityDefs cleanDB @@ -602,7 +607,12 @@ spec = describe "MigrationSpec" $ do getter <- getStmtGetter result <- - liftIO $ migrateEntitiesStructured getter allEntityDefs allEntityDefs + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + allEntityDefs + allEntityDefs cleanDB @@ -614,7 +624,12 @@ spec = describe "MigrationSpec" $ do Right alters -> do traverse_ (flip rawExecute [] . snd . showAlterDb) alters result2 <- - liftIO $ migrateEntitiesStructured getter allEntityDefs allEntityDefs + liftIO $ + migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + allEntityDefs + allEntityDefs result2 `shouldBe` Right [] it "suggests FK constraints for new fields first time" $ runConnAssert $ do @@ -624,6 +639,37 @@ spec = describe "MigrationSpec" $ do result <- liftIO $ migrateEntitiesStructured + emptyBackendSpecificOverrides + getter + (fkChildV2EntityDef : allEntityDefs) + [fkChildV2EntityDef] + + cleanDB + + case result of + Right [] -> + pure () + Left err -> + expectationFailure $ show err + Right alters -> + map (snd . showAlterDb) alters + `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL" + , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE RESTRICT ON UPDATE NO ACTION" + ] + + it "Uses overrides for empty cascade action" $ runConnAssert $ do + migrateManually + + getter <- getStmtGetter + + let + overrideWithDefault = + ( setBackendSpecificForeignKeyCascadeDefault Cascade emptyBackendSpecificOverrides + ) + result <- + liftIO $ + migrateEntitiesStructured + overrideWithDefault getter (fkChildV2EntityDef : allEntityDefs) [fkChildV2EntityDef] @@ -638,5 +684,5 @@ spec = describe "MigrationSpec" $ do Right alters -> map (snd . showAlterDb) alters `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL" - , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE RESTRICT ON UPDATE RESTRICT" + , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE CASCADE ON UPDATE NO ACTION" ] diff --git a/persistent/Database/Persist/Sql.hs b/persistent/Database/Persist/Sql.hs index e450cf9e5..7dd6404f6 100644 --- a/persistent/Database/Persist/Sql.hs +++ b/persistent/Database/Persist/Sql.hs @@ -9,7 +9,7 @@ -- Then, you'll use the operations module Database.Persist.Sql ( -- * 'RawSql' and 'PersistFieldSql' - module Database.Persist.Sql.Class + module Database.Persist.Sql.Class -- * Running actions @@ -61,6 +61,8 @@ module Database.Persist.Sql , emptyBackendSpecificOverrides , getBackendSpecificForeignKeyName , setBackendSpecificForeignKeyName + , getBackendSpecificForeignKeyCascadeDefault + , setBackendSpecificForeignKeyCascadeDefault , defaultAttribute -- * Internal diff --git a/persistent/Database/Persist/Sql/Internal.hs b/persistent/Database/Persist/Sql/Internal.hs index 8aa6adfa0..04d5e390b 100644 --- a/persistent/Database/Persist/Sql/Internal.hs +++ b/persistent/Database/Persist/Sql/Internal.hs @@ -10,6 +10,8 @@ module Database.Persist.Sql.Internal , BackendSpecificOverrides (..) , getBackendSpecificForeignKeyName , setBackendSpecificForeignKeyName + , getBackendSpecificForeignKeyCascadeDefault + , setBackendSpecificForeignKeyCascadeDefault , emptyBackendSpecificOverrides ) where @@ -36,6 +38,7 @@ import Database.Persist.Types data BackendSpecificOverrides = BackendSpecificOverrides { backendSpecificForeignKeyName :: Maybe (EntityNameDB -> FieldNameDB -> ConstraintNameDB) + , backendSpecificForeignKeyCascadeDefault :: CascadeAction } -- | If the override is defined, then this returns a function that accepts an @@ -61,6 +64,29 @@ setBackendSpecificForeignKeyName setBackendSpecificForeignKeyName func bso = bso{backendSpecificForeignKeyName = Just func} +-- | If the override is defined, then this returns a function that accepts an +-- entity name and field name and provides the 'ConstraintNameDB' for the +-- foreign key constraint. +-- +-- An abstract accessor for the 'BackendSpecificOverrides' +-- +-- @since 2.13.0.0 +getBackendSpecificForeignKeyCascadeDefault + :: BackendSpecificOverrides + -> CascadeAction +getBackendSpecificForeignKeyCascadeDefault = + backendSpecificForeignKeyCascadeDefault + +-- | Set the backend's foreign key generation function to this value. +-- +-- @since 2.13.0.0 +setBackendSpecificForeignKeyCascadeDefault + :: CascadeAction + -> BackendSpecificOverrides + -> BackendSpecificOverrides +setBackendSpecificForeignKeyCascadeDefault action bso = + bso{backendSpecificForeignKeyCascadeDefault = action} + findMaybe :: (a -> Maybe b) -> [a] -> Maybe b findMaybe p = listToMaybe . mapMaybe p @@ -68,7 +94,7 @@ findMaybe p = listToMaybe . mapMaybe p -- -- @since 2.11 emptyBackendSpecificOverrides :: BackendSpecificOverrides -emptyBackendSpecificOverrides = BackendSpecificOverrides Nothing +emptyBackendSpecificOverrides = BackendSpecificOverrides Nothing Restrict defaultAttribute :: [FieldAttr] -> Maybe Text defaultAttribute = findMaybe $ \case @@ -171,9 +197,11 @@ mkColumns allDefs t overrides = -- explicitly makes migrations run smoother. overrideNothings (FieldCascade{fcOnUpdate = upd, fcOnDelete = del}) = FieldCascade - { fcOnUpdate = upd <|> Just Restrict - , fcOnDelete = del <|> Just Restrict + { fcOnUpdate = upd <|> Just defaultAction + , fcOnDelete = del <|> Just defaultAction } + where + defaultAction = (backendSpecificForeignKeyCascadeDefault overrides) ref :: FieldNameDB From 147fb8790559b1333a1026bc28319f11aa87c93b Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Mon, 12 Jan 2026 14:47:46 -0500 Subject: [PATCH 28/30] update test --- persistent-postgresql/test/MigrationSpec.hs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/persistent-postgresql/test/MigrationSpec.hs b/persistent-postgresql/test/MigrationSpec.hs index b256561b6..071ef0733 100644 --- a/persistent-postgresql/test/MigrationSpec.hs +++ b/persistent-postgresql/test/MigrationSpec.hs @@ -67,7 +67,7 @@ FKChildV1 sql=migration_fk_child -- Simulate creating a new FK field on an existing table FKChildV2 sql=migration_fk_child - parentId FKParentId OnUpdateNoAction + parentId FKParentId ExplicitPrimaryKey sql=explicit_primary_key Id Text @@ -654,7 +654,7 @@ spec = describe "MigrationSpec" $ do Right alters -> map (snd . showAlterDb) alters `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL" - , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE RESTRICT ON UPDATE NO ACTION" + , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE RESTRICT ON UPDATE RESTRICT" ] it "Uses overrides for empty cascade action" $ runConnAssert $ do @@ -684,5 +684,5 @@ spec = describe "MigrationSpec" $ do Right alters -> map (snd . showAlterDb) alters `shouldBe` [ "ALTER TABLE \"migration_fk_child\" ADD COLUMN \"parent_id\" INT8 NOT NULL" - , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE CASCADE ON UPDATE NO ACTION" + , "ALTER TABLE \"migration_fk_child\" ADD CONSTRAINT \"migration_fk_child_parent_id_fkey\" FOREIGN KEY(\"parent_id\") REFERENCES \"migration_fk_parent\"(\"id\") ON DELETE CASCADE ON UPDATE CASCADE" ] From fd8399fc8c73d23bb797854b9a6ff026d93dd7df Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Mon, 12 Jan 2026 14:53:21 -0500 Subject: [PATCH 29/30] update changelog, comments --- persistent-postgresql/ChangeLog.md | 5 +++++ persistent-postgresql/persistent-postgresql.cabal | 4 ++-- persistent/ChangeLog.md | 4 +++- persistent/Database/Persist/Sql/Internal.hs | 13 +++++-------- persistent/persistent.cabal | 2 +- 5 files changed, 16 insertions(+), 12 deletions(-) diff --git a/persistent-postgresql/ChangeLog.md b/persistent-postgresql/ChangeLog.md index 181def641..8d6ec1613 100644 --- a/persistent-postgresql/ChangeLog.md +++ b/persistent-postgresql/ChangeLog.md @@ -1,5 +1,10 @@ # Changelog for persistent-postgresql +# 2.14.3.0 + +* [#1616](https://github.com/yesodweb/persistent/pull/1616) + * Allow overriding the default cascade option for foreign keys. + # 2.14.2.0 * [#1614](https://github.com/yesodweb/persistent/pull/1614) diff --git a/persistent-postgresql/persistent-postgresql.cabal b/persistent-postgresql/persistent-postgresql.cabal index 9b2c9ee85..8314c528d 100644 --- a/persistent-postgresql/persistent-postgresql.cabal +++ b/persistent-postgresql/persistent-postgresql.cabal @@ -1,5 +1,5 @@ name: persistent-postgresql -version: 2.14.2.0 +version: 2.14.3.0 license: MIT license-file: LICENSE author: Felipe Lessa, Michael Snoyman @@ -28,7 +28,7 @@ library , file-embed >=0.0.16 , monad-logger >=0.3.25 , mtl - , persistent >=2.18 && <3 + , persistent >=2.18.1 && <3 , postgresql-libpq >=0.9.4.2 && <0.12 , postgresql-simple >=0.6.1 && <0.8 , postgresql-simple-interval >=1 && <1.1 diff --git a/persistent/ChangeLog.md b/persistent/ChangeLog.md index e0c208d03..a7703759e 100644 --- a/persistent/ChangeLog.md +++ b/persistent/ChangeLog.md @@ -1,6 +1,8 @@ # Changelog for persistent -# Unreleased +# 2.18.1.0 +* [#1616](https://github.com/yesodweb/persistent/pull/1616) + * Allow overriding the default cascade option for foreign keys. * [#1608](https://github.com/yesodweb/persistent/pull/1608) * Improves documentation on getBy with nullable fields * Updates the warning text present when you try to make a Unique field that is nullable diff --git a/persistent/Database/Persist/Sql/Internal.hs b/persistent/Database/Persist/Sql/Internal.hs index 04d5e390b..3b3f7beae 100644 --- a/persistent/Database/Persist/Sql/Internal.hs +++ b/persistent/Database/Persist/Sql/Internal.hs @@ -64,22 +64,19 @@ setBackendSpecificForeignKeyName setBackendSpecificForeignKeyName func bso = bso{backendSpecificForeignKeyName = Just func} --- | If the override is defined, then this returns a function that accepts an --- entity name and field name and provides the 'ConstraintNameDB' for the --- foreign key constraint. --- --- An abstract accessor for the 'BackendSpecificOverrides' +-- | If the override is defined, then this specifies what cascade action +-- should be used if there is none defined for the column. -- --- @since 2.13.0.0 +-- @since 2.18.1.0 getBackendSpecificForeignKeyCascadeDefault :: BackendSpecificOverrides -> CascadeAction getBackendSpecificForeignKeyCascadeDefault = backendSpecificForeignKeyCascadeDefault --- | Set the backend's foreign key generation function to this value. +-- | Set the backend's default cascade action. -- --- @since 2.13.0.0 +-- @since 2.18.1.0 setBackendSpecificForeignKeyCascadeDefault :: CascadeAction -> BackendSpecificOverrides diff --git a/persistent/persistent.cabal b/persistent/persistent.cabal index 8db185165..d30976467 100644 --- a/persistent/persistent.cabal +++ b/persistent/persistent.cabal @@ -1,5 +1,5 @@ name: persistent -version: 2.18.0.0 +version: 2.18.1.0 license: MIT license-file: LICENSE author: Michael Snoyman From acadebd5eeb803f4b50e68d24274230f678adee8 Mon Sep 17 00:00:00 2001 From: Kuba Karpierz Date: Mon, 12 Jan 2026 14:58:19 -0500 Subject: [PATCH 30/30] styling --- persistent-postgresql/test/MigrationSpec.hs | 3 +-- persistent/Database/Persist/Sql.hs | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/persistent-postgresql/test/MigrationSpec.hs b/persistent-postgresql/test/MigrationSpec.hs index 071ef0733..e4fa141dc 100644 --- a/persistent-postgresql/test/MigrationSpec.hs +++ b/persistent-postgresql/test/MigrationSpec.hs @@ -664,8 +664,7 @@ spec = describe "MigrationSpec" $ do let overrideWithDefault = - ( setBackendSpecificForeignKeyCascadeDefault Cascade emptyBackendSpecificOverrides - ) + setBackendSpecificForeignKeyCascadeDefault Cascade emptyBackendSpecificOverrides result <- liftIO $ migrateEntitiesStructured diff --git a/persistent/Database/Persist/Sql.hs b/persistent/Database/Persist/Sql.hs index 7dd6404f6..76de6661b 100644 --- a/persistent/Database/Persist/Sql.hs +++ b/persistent/Database/Persist/Sql.hs @@ -9,7 +9,7 @@ -- Then, you'll use the operations module Database.Persist.Sql ( -- * 'RawSql' and 'PersistFieldSql' - module Database.Persist.Sql.Class + module Database.Persist.Sql.Class -- * Running actions