12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265 |
- {-# LANGUAGE DeriveDataTypeable #-}
- {-# LANGUAGE OverloadedStrings #-}
- {-# LANGUAGE ScopedTypeVariables #-}
- {-# LANGUAGE TypeFamilies #-}
- {-# LANGUAGE ViewPatterns #-}
- -- | A postgresql backend for persistent.
- module Database.Persist.Postgresql
- ( withPostgresqlPool
- , withPostgresqlPoolWithVersion
- , withPostgresqlConn
- , withPostgresqlConnWithVersion
- , createPostgresqlPool
- , createPostgresqlPoolModified
- , createPostgresqlPoolModifiedWithVersion
- , module Database.Persist.Sql
- , ConnectionString
- , PostgresConf (..)
- , openSimpleConn
- , openSimpleConnWithVersion
- , tableName
- , fieldName
- , mockMigration
- , migrateEnableExtension
- ) where
import qualified Database.PostgreSQL.LibPQ as LibPQ

import qualified Database.PostgreSQL.Simple as PG
import qualified Database.PostgreSQL.Simple.Internal as PG
import qualified Database.PostgreSQL.Simple.FromField as PGFF
import qualified Database.PostgreSQL.Simple.ToField as PGTF
import qualified Database.PostgreSQL.Simple.Transaction as PG
import qualified Database.PostgreSQL.Simple.Types as PG
import qualified Database.PostgreSQL.Simple.TypeInfo.Static as PS
import Database.PostgreSQL.Simple.Ok (Ok (..))

import Control.Arrow
import Control.Exception (Exception, bracketOnError, throw, throwIO)
import Control.Monad (forM)
import Control.Monad.IO.Unlift (MonadIO (..), MonadUnliftIO)
import Control.Monad.Logger (MonadLogger, runNoLoggingT)
import Control.Monad.Trans.Reader (runReaderT)
import Control.Monad.Trans.Writer (WriterT(..), runWriterT)

import qualified Blaze.ByteString.Builder.Char8 as BBB
import Data.Acquire (Acquire, mkAcquire, with)
import Data.Aeson
import Data.Aeson.Types (modifyFailure)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as B8
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Data
import Data.Either (partitionEithers)
import Data.Fixed (Pico)
import Data.Function (on)
import Data.Int (Int64)
import qualified Data.IntMap as I
import Data.IORef
import Data.List (find, sort, groupBy)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NEL
import qualified Data.Map as Map
import Data.Maybe
import Data.Monoid ((<>))
import Data.Pool (Pool)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Data.Text.IO as T
import Data.Text.Read (rational)
import Data.Time (utc, localTimeToUTC)
import Data.Typeable (Typeable)
import System.Environment (getEnvironment)

import Database.Persist.Sql
import qualified Database.Persist.Sql.Util as Util
-- | A @libpq@ connection string.
--
-- A simple example is
-- @\"host=localhost port=5432 user=test dbname=test password=test\"@.
-- See libpq's documentation at
-- <https://www.postgresql.org/docs/current/static/libpq-connect.html>
-- for the details of how such strings are built.
type ConnectionString = ByteString
-- | Exception thrown when persistent is unable to find the version of the
-- PostgreSQL server. The payload describes what was received instead.
data PostgresServerVersionError = PostgresServerVersionError String
    deriving (Typeable)

instance Show PostgresServerVersionError where
    show (PostgresServerVersionError detail) =
        "Unexpected PostgreSQL server version, got " ++ detail

instance Exception PostgresServerVersionError
-- | Create a PostgreSQL connection pool and run the given action with it.
-- The pool is released once the action finishes, so the 'ConnectionPool'
-- must not be used outside the action: it may already have been released.
--
-- The server version is obtained with the default 'getServerVersion'.
withPostgresqlPool
    :: (MonadLogger m, MonadUnliftIO m)
    => ConnectionString
    -- ^ Connection string to the database.
    -> Int
    -- ^ Number of connections to be kept open in the pool.
    -> (Pool SqlBackend -> m a)
    -- ^ Action to be executed that uses the connection pool.
    -> m a
withPostgresqlPool = withPostgresqlPoolWithVersion getServerVersion
-- | Same as 'withPostgresqlPool', but takes a callback for obtaining
-- the server version (to work around an Amazon Redshift bug).
--
-- @since 2.6.2
withPostgresqlPoolWithVersion
    :: (MonadUnliftIO m, MonadLogger m)
    => (PG.Connection -> IO (Maybe Double))
    -- ^ Action to perform to get the server version.
    -> ConnectionString
    -- ^ Connection string to the database.
    -> Int
    -- ^ Number of connections to be kept open in the pool.
    -> (Pool SqlBackend -> m a)
    -- ^ Action to be executed that uses the connection pool.
    -> m a
withPostgresqlPoolWithVersion getVer ci =
    withSqlPool (open' untouched getVer ci)
  where
    -- No connection-specific tweaking is performed here.
    untouched _ = return ()
-- | Create a PostgreSQL connection pool.
--
-- Closing the pool when it is no longer needed is the caller's
-- responsibility. Use 'withPostgresqlPool' for automatic resource control.
createPostgresqlPool
    :: (MonadUnliftIO m, MonadLogger m)
    => ConnectionString
    -- ^ Connection string to the database.
    -> Int
    -- ^ Number of connections to be kept open in the pool.
    -> m (Pool SqlBackend)
createPostgresqlPool = createPostgresqlPoolModified (\_ -> return ())
-- | Same as 'createPostgresqlPool', but additionally takes a callback that
-- is run on every connection right after it is created. This could be
-- used, for example, to change the schema. For more information, see:
--
-- <https://groups.google.com/d/msg/yesodweb/qUXrEN_swEo/O0pFwqwQIdcJ>
--
-- @since 2.1.3
createPostgresqlPoolModified
    :: (MonadUnliftIO m, MonadLogger m)
    => (PG.Connection -> IO ()) -- ^ Action to perform after connection is created.
    -> ConnectionString -- ^ Connection string to the database.
    -> Int -- ^ Number of connections to be kept open in the pool.
    -> m (Pool SqlBackend)
createPostgresqlPoolModified modConn =
    createPostgresqlPoolModifiedWithVersion getServerVersion modConn
-- | Same as the other similarly-named functions in this module, but takes
-- both a callback for obtaining the server version (to work around an
-- Amazon Redshift bug) and a callback for connection-specific tweaking
-- (to change the schema).
--
-- @since 2.6.2
createPostgresqlPoolModifiedWithVersion
    :: (MonadUnliftIO m, MonadLogger m)
    => (PG.Connection -> IO (Maybe Double)) -- ^ Action to perform to get the server version.
    -> (PG.Connection -> IO ()) -- ^ Action to perform after connection is created.
    -> ConnectionString -- ^ Connection string to the database.
    -> Int -- ^ Number of connections to be kept open in the pool.
    -> m (Pool SqlBackend)
createPostgresqlPoolModifiedWithVersion getVer modConn ci =
    createSqlPool mkBackend
  where
    mkBackend = open' modConn getVer ci
-- | Same as 'withPostgresqlPool', but a single connection is opened
-- instead of a pool of connections.
withPostgresqlConn
    :: (MonadUnliftIO m, MonadLogger m)
    => ConnectionString
    -> (SqlBackend -> m a)
    -> m a
withPostgresqlConn cstr = withPostgresqlConnWithVersion getServerVersion cstr
-- | Same as 'withPostgresqlConn', but takes a callback for obtaining
-- the server version (to work around an Amazon Redshift bug).
--
-- @since 2.6.2
withPostgresqlConnWithVersion
    :: (MonadUnliftIO m, MonadLogger m)
    => (PG.Connection -> IO (Maybe Double))
    -> ConnectionString
    -> (SqlBackend -> m a)
    -> m a
withPostgresqlConnWithVersion getVer cstr =
    withSqlConn (open' (\_ -> return ()) getVer cstr)
-- | Open a connection from a connection string and wrap it in a
-- 'SqlBackend'.
--
-- After connecting, the tweaking callback is run, then the server version
-- callback, and finally the backend is built with a fresh, empty statement
-- map. If either callback throws, the freshly opened connection is closed
-- before the exception propagates, so a failing callback does not leak the
-- connection.
open'
    :: (PG.Connection -> IO ())
    -- ^ Action to perform after the connection is created.
    -> (PG.Connection -> IO (Maybe Double))
    -- ^ Action to perform to get the server version.
    -> ConnectionString
    -> LogFunc
    -> IO SqlBackend
open' modConn getVer cstr logFunc =
    -- 'bracketOnError' only runs the release action when the body throws;
    -- on success the connection stays open and is owned by the backend.
    bracketOnError (PG.connectPostgreSQL cstr) PG.close $ \conn -> do
        modConn conn
        ver <- getVer conn
        smap <- newIORef Map.empty
        return $ createBackend logFunc ver smap conn
-- | Gets the PostgreSQL server version by running @show server_version@.
--
-- Only the leading @major.minor@ part of the reported string is kept:
--
-- > λ> rational "9.8.3"
-- > Right (9.8,".3")
-- > λ> rational "9.8.3.5"
-- > Right (9.8,".3.5")
--
-- Throws 'PostgresServerVersionError' when the reported string cannot be
-- parsed as a number, or when the query does not return exactly one row.
getServerVersion :: PG.Connection -> IO (Maybe Double)
getServerVersion conn = do
    rows <- PG.query_ conn "show server_version"
    case rows of
        [PG.Only version] ->
            case rational version of
                Right (a, _) -> return $ Just a
                Left err -> throwIO $ PostgresServerVersionError err
        -- Report an unexpected result shape with the documented exception
        -- rather than an anonymous pattern-match failure.
        _ -> throwIO $ PostgresServerVersionError $
            show (length rows) ++ " rows in response to \"show server_version\""
-- | Select an upsert-related SQL generator based on the server version.
--
-- PostgreSQL >= 9.5 supports a native upsert feature, so the given
-- function is returned for those versions and 'Nothing' otherwise.
upsertFunction :: a -> Double -> Maybe a
upsertFunction f version
    | version >= 9.5 = Just f
    | otherwise = Nothing
-- | Generate a 'SqlBackend' from an existing 'PG.Connection', using the
-- default 'getServerVersion' to obtain the server version.
openSimpleConn :: LogFunc -> PG.Connection -> IO SqlBackend
openSimpleConn logFunc conn =
    openSimpleConnWithVersion getServerVersion logFunc conn
-- | Generate a 'SqlBackend' from an existing 'PG.Connection', taking a
-- callback for obtaining the server version.
--
-- @since 2.9.1
openSimpleConnWithVersion
    :: (PG.Connection -> IO (Maybe Double))
    -> LogFunc
    -> PG.Connection
    -> IO SqlBackend
openSimpleConnWithVersion getVer logFunc conn = do
    stmtMap <- newIORef Map.empty
    mVersion <- getVer conn
    pure (createBackend logFunc mVersion stmtMap conn)
-- | Create the backend given a logging function, server version, mutable
-- statement cell, and connection.
--
-- The upsert-style SQL generators are only provided when the server
-- version is known and accepted by 'upsertFunction'.
createBackend
    :: LogFunc
    -> Maybe Double
    -> IORef (Map.Map Text Statement)
    -> PG.Connection
    -> SqlBackend
createBackend logFunc serverVersion smap conn =
    SqlBackend
        { connPrepare = prepare' conn
        , connStmtMap = smap
        , connInsertSql = insertSql'
        , connInsertManySql = Just insertManySql'
        , connUpsertSql = whenSupported upsertSql'
        , connPutManySql = whenSupported putManySql
        , connClose = PG.close conn
        , connMigrateSql = migrate'
        , connBegin = const beginTx
        , connCommit = const $ PG.commit conn
        , connRollback = const $ PG.rollback conn
        , connEscapeName = escape
        , connNoLimit = "LIMIT ALL"
        , connRDBMS = "postgresql"
        , connLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
        , connLogFunc = logFunc
        , connMaxParams = Nothing
        , connRepsertManySql = whenSupported repsertManySql
        }
  where
    -- Gate a generator on the server version.
    whenSupported :: a -> Maybe a
    whenSupported f = serverVersion >>= upsertFunction f

    beginTx Nothing = PG.begin conn
    beginTx (Just iso) = PG.beginLevel (pgLevel iso) conn

    pgLevel iso = case iso of
        -- PG upgrades uncommitted reads to committed anyways
        ReadUncommitted -> PG.ReadCommitted
        ReadCommitted -> PG.ReadCommitted
        RepeatableRead -> PG.RepeatableRead
        Serializable -> PG.Serializable
-- | Build a 'Statement' for the given SQL text.
--
-- The SQL is only UTF-8 encoded into a 'PG.Query' here; finalizing and
-- resetting the statement are no-ops.
prepare' :: PG.Connection -> Text -> IO Statement
prepare' conn sql =
    return $ Statement
        { stmtFinalize = return ()
        , stmtReset = return ()
        , stmtExecute = execute' conn query
        , stmtQuery = withStmt' conn query
        }
  where
    query = PG.Query (T.encodeUtf8 sql)
-- | SQL for inserting a single row.
--
-- Entities with an explicit primary key definition get a plain @INSERT@
-- ('ISRManyKeys'); otherwise the statement additionally returns the id
-- column ('ISRSingle'). An entity without fields uses @DEFAULT VALUES@.
insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
insertSql' ent vals =
    case entityPrimary ent of
        Just _ -> ISRManyKeys insertStmt vals
        Nothing -> ISRSingle (insertStmt <> " RETURNING " <> escape (fieldDB (entityId ent)))
  where
    fields = entityFields ent

    insertStmt = T.concat ["INSERT INTO ", escape (entityDB ent), valuesClause]

    valuesClause
        | null fields = " DEFAULT VALUES"
        | otherwise = T.concat
            [ "("
            , T.intercalate "," (map (escape . fieldDB) fields)
            , ") VALUES("
            , T.intercalate "," (map (const "?") fields)
            , ")"
            ]
-- | SQL for a native upsert (@INSERT ... ON CONFLICT ... DO UPDATE@).
--
-- The third argument is spliced in verbatim as the @SET@ clause. The
-- @WHERE@ clause compares every column of the given unique definitions
-- against a @?@ placeholder.
--
-- NOTE(review): the @ON CONFLICT@ target lists the columns of /all/ of the
-- entity's uniques ('entityUniques'), not only those passed in -- confirm
-- this is intended for entities with more than one unique constraint.
upsertSql' :: EntityDef -> NonEmpty UniqueDef -> Text -> Text
upsertSql' ent uniqs updateVal =
    T.concat
        [ "INSERT INTO "
        , table
        , "("
        , commaSep (map (escape . fieldDB) fields)
        , ") VALUES ("
        , commaSep (map (const "?") fields)
        , ") ON CONFLICT ("
        , commaSep (concatMap conflictColumns (entityUniques ent))
        , ") DO UPDATE SET "
        , updateVal
        , " WHERE "
        , T.intercalate " AND " (map matchUnique (NEL.toList uniqs))
        , " RETURNING ??"
        ]
  where
    table = escape (entityDB ent)
    fields = entityFields ent
    commaSep = T.intercalate ","

    conflictColumns :: UniqueDef -> [Text]
    conflictColumns = map (escape . snd) . uniqueFields

    matchUnique :: UniqueDef -> Text
    matchUnique = T.intercalate " AND " . map (matchColumn . snd) . uniqueFields

    matchColumn :: DBName -> Text
    matchColumn field = table <> "." <> escape field <> " =?"
-- | SQL for inserting multiple rows at once and returning their primary keys.
--
-- One parenthesised group of @?@ placeholders is generated per row.
insertManySql' :: EntityDef -> [[PersistValue]] -> InsertSqlResult
insertManySql' ent valss = ISRSingle statement
  where
    fields = entityFields ent
    rowPlaceholders = T.intercalate "," (map (const "?") fields)
    statement = T.concat
        [ "INSERT INTO "
        , escape (entityDB ent)
        , "("
        , T.intercalate "," (map (escape . fieldDB) fields)
        , ") VALUES ("
        , T.intercalate "),(" (replicate (length valss) rowPlaceholders)
        , ") RETURNING "
        , Util.commaSeparated (Util.dbIdColumnsEsc escape ent)
        ]
-- | Execute a statement with the given parameters (wrapped in 'P' so they
-- can be rendered as query parameters) and return the result of
-- 'PG.execute'.
execute' :: PG.Connection -> PG.Query -> [PersistValue] -> IO Int64
execute' conn query = PG.execute conn query . map P
-- | Run a query and expose its rows as a conduit source.
--
-- Acquiring the resource formats the query with its parameters, executes
-- it on the raw @libpq@ connection and picks a getter per column based on
-- the column's type OID; releasing it frees the @libpq@ result. The source
-- yields one list of 'PersistValue's per row.
withStmt'
    :: MonadIO m
    => PG.Connection
    -> PG.Query
    -> [PersistValue]
    -> Acquire (ConduitM () [PersistValue] m ())
withStmt' conn query vals =
    fmap pull (mkAcquire openS closeS)
  where
    openS = do
        -- Construct raw query
        rawquery <- PG.formatQuery conn query (map P vals)
        -- Take raw connection and execute
        (result, rowRef, rowCount, typedCols) <-
            PG.withConnection conn (runRaw rawquery)
        let getters =
                [ getGetter conn oid (PG.Field result col oid)
                | (col, oid) <- typedCols
                ]
        return (result, rowRef, rowCount, getters)

    runRaw rawquery rawconn = do
        mresult <- LibPQ.exec rawconn rawquery
        case mresult of
            Nothing -> do
                merr <- LibPQ.errorMessage rawconn
                fail ("Postgresql.withStmt': " ++ maybe "unknown error" B8.unpack merr)
            Just result -> do
                -- Check result status
                status <- LibPQ.resultStatus result
                case status of
                    LibPQ.TuplesOk -> return ()
                    _ -> PG.throwResultError "Postgresql.withStmt': bad result status " result status
                -- Get number and type of columns
                ncols <- LibPQ.nfields result
                typedCols <- forM [0 .. ncols - 1] $ \col -> do
                    oid <- LibPQ.ftype result col
                    return (col, oid)
                -- Ready to go!
                rowRef <- newIORef (LibPQ.Row 0)
                rowCount <- LibPQ.ntuples result
                return (result, rowRef, rowCount, typedCols)

    closeS (result, _, _, _) = LibPQ.unsafeFreeResult result

    -- Yield rows until 'pullS' reports that there are none left.
    pull st = loop
      where
        loop = liftIO (pullS st) >>= maybe (return ()) (\row -> yield row >> loop)

    -- Fetch the next row, advancing the shared row counter.
    pullS (result, rowRef, rowCount, getters) = do
        row <- atomicModifyIORef rowRef (\r -> (r + 1, r))
        if row == rowCount
            then return Nothing
            else Just <$> mapM (readCell result row) (zip getters [0 ..])

    readCell result row (getter, col) = do
        mbs <- LibPQ.getvalue' result row col
        case mbs of
            Nothing ->
                -- getvalue' verified that the value is NULL.
                -- However, that does not mean that there are
                -- no NULL values inside the value (e.g., if
                -- we're dealing with an array of optional values).
                return PersistNull
            Just bs -> do
                ok <- PGFF.runConversion (getter mbs) conn
                bs `seq` case ok of
                    Errors (exc:_) -> throw exc
                    Errors [] -> error "Got an Errors, but no exceptions"
                    Ok v -> return v
-- | Wrapper around 'PersistValue' so that the 'PGTF.ToField' instance
-- below is not an orphan.
newtype P = P PersistValue

instance PGTF.ToField P where
    toField (P value) = case value of
        PersistText t -> PGTF.toField t
        PersistByteString bs -> PGTF.toField (PG.Binary bs)
        PersistInt64 i -> PGTF.toField i
        PersistDouble d -> PGTF.toField d
        -- FIXME: Too Ambigous, can not select precision without information about field
        PersistRational r ->
            PGTF.Plain (BBB.fromString (show (fromRational r :: Pico)))
        PersistBool b -> PGTF.toField b
        PersistDay d -> PGTF.toField d
        PersistTimeOfDay t -> PGTF.toField t
        PersistUTCTime t -> PGTF.toField t
        PersistNull -> PGTF.toField PG.Null
        PersistList l -> PGTF.toField (listToJSON l)
        PersistMap m -> PGTF.toField (mapToJSON m)
        PersistDbSpecific s -> PGTF.toField (Unknown s)
        PersistArray a -> PGTF.toField (PG.PGArray (map P a))
        PersistObjectId _ ->
            error "Refusing to serialize a PersistObjectId to a PostgreSQL value"
-- | Raw bytes of a value whose PostgreSQL type has no dedicated handling.
newtype Unknown = Unknown { unUnknown :: ByteString }
    deriving (Eq, Show, Read, Ord, Typeable)

-- | Accepts any non-NULL field and keeps its bytes untouched.
instance PGFF.FromField Unknown where
    fromField f =
        maybe
            (PGFF.returnError PGFF.UnexpectedNull f "Database.Persist.Postgresql/PGFF.FromField Unknown")
            (return . Unknown)

-- | Rendered via 'PGTF.Escape'.
instance PGTF.ToField Unknown where
    toField (Unknown bytes) = PGTF.Escape bytes
-- | A parser turning a database field into a Haskell value.
type Getter a = PGFF.FieldParser a

-- | Parse a field with its 'PGFF.FromField' instance and map the given
-- function over the result.
convertPV :: PGFF.FromField a => (a -> b) -> Getter b
convertPV f field mdata = f <$> PGFF.fromField field mdata
-- | Getters for the built-in PostgreSQL types, keyed by type OID.
--
-- Scalar types are keyed through their static type info; array types are
-- keyed by their literal OIDs, listed in the same order as the scalars.
--
-- @timestamp@ (without time zone) values are read as local times and
-- interpreted as UTC. The same conversion is applied to @timestamp[]@
-- (OID 1115) so that arrays agree with the scalar getter.
builtinGetters :: I.IntMap (Getter PersistValue)
builtinGetters = I.fromList
    [ (k PS.bool,        convertPV PersistBool)
    , (k PS.bytea,       convertPV (PersistByteString . unBinary))
    , (k PS.char,        convertPV PersistText)
    , (k PS.name,        convertPV PersistText)
    , (k PS.int8,        convertPV PersistInt64)
    , (k PS.int2,        convertPV PersistInt64)
    , (k PS.int4,        convertPV PersistInt64)
    , (k PS.text,        convertPV PersistText)
    , (k PS.xml,         convertPV PersistText)
    , (k PS.float4,      convertPV PersistDouble)
    , (k PS.float8,      convertPV PersistDouble)
    , (k PS.money,       convertPV PersistRational)
    , (k PS.bpchar,      convertPV PersistText)
    , (k PS.varchar,     convertPV PersistText)
    , (k PS.date,        convertPV PersistDay)
    , (k PS.time,        convertPV PersistTimeOfDay)
    , (k PS.timestamp,   convertPV (PersistUTCTime . localTimeToUTC utc))
    , (k PS.timestamptz, convertPV PersistUTCTime)
    , (k PS.bit,         convertPV PersistInt64)
    , (k PS.varbit,      convertPV PersistInt64)
    , (k PS.numeric,     convertPV PersistRational)
    , (k PS.void,        \_ _ -> return PersistNull)
    , (k PS.json,        convertPV (PersistByteString . unUnknown))
    , (k PS.jsonb,       convertPV (PersistByteString . unUnknown))
    , (k PS.unknown,     convertPV (PersistByteString . unUnknown))

    -- Array types: same order as above.
    -- The OIDs were taken from pg_type.
    , (1000, listOf PersistBool)
    , (1001, listOf (PersistByteString . unBinary))
    , (1002, listOf PersistText)
    , (1003, listOf PersistText)
    , (1016, listOf PersistInt64)
    , (1005, listOf PersistInt64)
    , (1007, listOf PersistInt64)
    , (1009, listOf PersistText)
    , (143,  listOf PersistText)
    , (1021, listOf PersistDouble)
    , (1022, listOf PersistDouble)
    , (1023, listOf PersistUTCTime)
    , (1024, listOf PersistUTCTime)
    , (791,  listOf PersistRational)
    , (1014, listOf PersistText)
    , (1015, listOf PersistText)
    , (1182, listOf PersistDay)
    , (1183, listOf PersistTimeOfDay)
    -- timestamp[]: elements carry no time zone, so they are parsed as
    -- local times and converted exactly like the scalar timestamp getter.
    , (1115, listOf (PersistUTCTime . localTimeToUTC utc))
    , (1185, listOf PersistUTCTime)
    , (1561, listOf PersistInt64)
    , (1563, listOf PersistInt64)
    , (1231, listOf PersistRational)
    -- no array(void) type
    , (2951, listOf (PersistDbSpecific . unUnknown))
    , (199,  listOf (PersistByteString . unUnknown))
    , (3807, listOf (PersistByteString . unUnknown))
    -- no array(unknown) either
    ]
  where
    -- Integer key of a statically known type.
    k (PGFF.typoid -> i) = PG.oid2int i
    -- A @listOf f@ will use a @PGArray (Maybe T)@ to convert
    -- the values to Haskell-land. The @Maybe@ is important
    -- because the usual way of checking NULLs
    -- (c.f. withStmt') won't check for NULL inside
    -- arrays---or any other compound structure for that matter.
    listOf f = convertPV (PersistList . map (maybe PersistNull f) . PG.fromPGArray)
-- | Look up the getter for a type OID in 'builtinGetters'. Types without a
-- built-in getter are read as raw bytes and wrapped in 'PersistDbSpecific'.
-- The connection argument is not used.
getGetter :: PG.Connection -> PG.Oid -> Getter PersistValue
getGetter _conn oid =
    I.findWithDefault fallback (PG.oid2int oid) builtinGetters
  where
    fallback = convertPV (PersistDbSpecific . unUnknown)
-- | Unwrap a 'PG.Binary' value.
unBinary :: PG.Binary a -> a
unBinary wrapped = case wrapped of
    PG.Binary payload -> payload
-- | Check whether a table with the given name is listed in
-- @pg_catalog.pg_tables@, ignoring the @pg_catalog@ and
-- @information_schema@ schemas.
--
-- Calls 'error' if the count query returns no row, more than one row, or a
-- count other than 0 or 1.
doesTableExist
    :: (Text -> IO Statement)
    -> DBName -- ^ table name
    -> IO Bool
doesTableExist getter (DBName name) = do
    stmt <- getter sql
    with (stmtQuery stmt vals) (\src -> runConduit $ src .| readCount)
  where
    sql = "SELECT COUNT(*) FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog'"
        <> " AND schemaname != 'information_schema' AND tablename=?"
    vals = [PersistText name]

    readCount = do
        mrow <- await
        case mrow of
            Nothing -> error "No results when checking doesTableExist"
            Just [PersistInt64 0] -> expectEnd False
            Just [PersistInt64 1] -> expectEnd True
            Just res -> error $ "doesTableExist returned unexpected result: " ++ show res

    expectEnd answer = do
        extra <- await
        case extra of
            Nothing -> return answer
            Just _ -> error "Too many rows returned in doesTableExist"
-- | Compute the migration for one entity.
--
-- The columns currently in the database are fetched first. If any of them
-- could not be interpreted, the collected errors are returned on the
-- 'Left'. Otherwise the result is either the statements that create the
-- table (when it does not exist yet) or the alterations that bring the
-- existing table in line with the entity definition, each rendered with
-- 'showAlterDb'.
migrate'
    :: [EntityDef]
    -> (Text -> IO Statement)
    -> EntityDef
    -> IO (Either [Text] [(Bool, Text)])
migrate' allDefs getter entity = fmap (fmap (map showAlterDb)) $ do
    existing <- getColumns getter entity
    case partitionEithers existing of
        ([], parsed) -> do
            -- Check for table existence if there are no columns, workaround
            -- for https://github.com/yesodweb/persistent/issues/152
            tableExists <-
                if null existing
                    then doesTableExist getter name
                    else return True
            return $ Right $ migrationSteps tableExists parsed
        (errs, _) -> return $ Left errs
  where
    name = entityDB entity

    (allNewCols, udefs, fdefs) = mkColumns allDefs entity
    -- Columns marked as safe to remove are left out of the target schema.
    newcols = filter (not . safeToRemove entity . cName) allNewCols
    udspair = map udToPair udefs

    migrationSteps tableExists parsed
        | tableExists =
            map (AlterColumn name) colAlters ++ map (AlterTable name) tblAlters
        | otherwise = createSteps
      where
        (colAlters, tblAlters) =
            getAlters allDefs entity (newcols, udspair) (partitionEithers parsed)

    createSteps =
        addTable newcols entity : uniques ++ references ++ foreignsAlt

    uniques =
        [ AlterTable name (AddUniqueConstraint uname ucols)
        | (uname, ucols) <- udspair
        ]

    -- Only columns that carry a reference contribute here.
    references =
        [ alter
        | col@Column { cName = cname, cReference = Just (refTblName, _) } <- newcols
        , Just alter <- [getAddReference allDefs name refTblName cname (cReference col)]
        ]

    foreignsAlt =
        [ AlterColumn name
            ( foreignRefTableDBName fdef
            , AddReference (foreignConstraintNameDBName fdef) childfields (map escape parentfields)
            )
        | fdef <- fdefs
        , let (childfields, parentfields) =
                unzip [ (child, parent) | ((_, child), (_, parent)) <- foreignFields fdef ]
        ]
-- | Build the @CREATE TABLE@ statement for an entity with the given columns.
--
-- With an explicit primary key definition a @PRIMARY KEY (...)@ clause
-- over its fields is emitted; otherwise the id column is declared as the
-- primary key, with its type chosen by 'maySerial' and its default by
-- 'mayDefault'.
addTable :: [Column] -> EntityDef -> AlterDB
addTable cols entity =
    AddTable $ T.concat
        -- Lower case e: see Database.Persist.Sql.Migration
        [ "CREATe TABLE " -- DO NOT FIX THE CAPITALIZATION!
        , escape (entityDB entity)
        , "("
        , keyClause
        , separator
        , T.intercalate "," (map showColumn cols)
        , ")"
        ]
  where
    separator :: Text
    separator
        | null cols = ""
        | otherwise = ","

    idField = entityId entity

    keyClause = case entityPrimary entity of
        Just pdef ->
            T.concat
                [ " PRIMARY KEY ("
                , T.intercalate "," (map (escape . fieldDB) (compositeFields pdef))
                , ")"
                ]
        Nothing ->
            let defText = defaultAttribute (fieldAttrs idField)
             in T.concat
                    [ escape (fieldDB idField)
                    , maySerial (fieldSqlType idField) defText
                    , " PRIMARY KEY UNIQUE"
                    , mayDefault defText
                    ]
-- | Column type for an id column: @SERIAL8@ for a 64-bit integer without
-- an explicit default, otherwise the plain SQL type.
maySerial :: SqlType -> Maybe Text -> Text
maySerial sType mdefault = case (sType, mdefault) of
    (SqlInt64, Nothing) -> " SERIAL8 "
    _ -> " " <> showSqlType sType
-- | Render an optional @DEFAULT@ clause; empty when there is no default.
mayDefault :: Maybe Text -> Text
mayDefault = maybe "" (" DEFAULT " <>)
-- | Flag attached to a column drop; it is computed with 'safeToRemove'.
type SafeToRemove = Bool

-- | A change to a single column.
data AlterColumn
    = ChangeType SqlType Text
    | IsNull
    | NotNull
    | Add' Column
    | Drop SafeToRemove
    | Default Text
    | NoDefault
    | Update' Text
    | AddReference DBName [DBName] [Text]
    -- ^ Constraint name, child columns, and already-escaped parent columns
    -- (as used when building foreign keys in @migrate'@).
    | DropReference DBName

-- | An 'AlterColumn' paired with a name.
-- NOTE(review): the name is a column name in @getAlters@ but the referenced
-- table name in @migrate'@'s foreign-key alterations -- confirm the intent.
type AlterColumn' = (DBName, AlterColumn)

-- | A table-level constraint change.
data AlterTable
    = AddUniqueConstraint DBName [DBName]
    | DropConstraint DBName

-- | A single migration step.
data AlterDB
    = AddTable Text
    | AlterColumn DBName AlterColumn'
    | AlterTable DBName AlterTable
-- | Returns all of the columns in the given table currently in the database.
--
-- The result lists the columns first (each either an error message or a
-- parsed 'Column'), followed by the constraints that are neither primary
-- nor foreign keys, each as a constraint name with its column names. The
-- entity's id column is excluded from both queries.
getColumns
    :: (Text -> IO Statement)
    -> EntityDef
    -> IO [Either Text (Either Column (DBName, [DBName]))]
getColumns getter def = do
    colStmt <- getter sqlv
    cs <- with (stmtQuery colStmt vals) (\src -> runConduit $ src .| columnRows)
    conStmt <- getter sqlc
    us <- with (stmtQuery conStmt vals) (\src -> runConduit $ src .| constraintGroups)
    return (cs ++ us)
  where
    vals =
        [ PersistText $ unDBName $ entityDB def
        , PersistText $ unDBName $ fieldDB (entityId def)
        ]

    sqlv = T.concat
        [ "SELECT "
        , "column_name "
        , ",is_nullable "
        , ",COALESCE(domain_name, udt_name)" -- See DOMAINS below
        , ",column_default "
        , ",numeric_precision "
        , ",numeric_scale "
        , ",character_maximum_length "
        , "FROM information_schema.columns "
        , "WHERE table_catalog=current_database() "
        , "AND table_schema=current_schema() "
        , "AND table_name=? "
        , "AND column_name <> ?"
        ]
    -- DOMAINS Postgres supports the concept of domains, which are data types with optional constraints.
    -- An app might make an "email" domain over the varchar type, with a CHECK that the emails are valid
    -- In this case the generated SQL should use the domain name: ALTER TABLE users ALTER COLUMN foo TYPE email
    -- This code exists to use the domain name (email), instead of the underlying type (varchar).
    -- This is tested in EquivalentTypeTest.hs

    sqlc = T.concat
        [ "SELECT "
        , "c.constraint_name, "
        , "c.column_name "
        , "FROM information_schema.key_column_usage c, "
        , "information_schema.table_constraints k "
        , "WHERE c.table_catalog=current_database() "
        , "AND c.table_catalog=k.table_catalog "
        , "AND c.table_schema=current_schema() "
        , "AND c.table_schema=k.table_schema "
        , "AND c.table_name=? "
        , "AND c.table_name=k.table_name "
        , "AND c.column_name <> ? "
        , "AND c.constraint_name=k.constraint_name "
        , "AND NOT k.constraint_type IN ('PRIMARY KEY', 'FOREIGN KEY') "
        , "ORDER BY c.constraint_name, c.column_name"
        ]

    -- Read every (constraint name, column name) row, in order.
    constraintRows = do
        mrow <- CL.head
        case mrow of
            Nothing -> return []
            Just [PersistText con, PersistText col] ->
                ((con, col) :) <$> constraintRows
            Just [PersistByteString con, PersistByteString col] ->
                ((T.decodeUtf8 con, T.decodeUtf8 col) :) <$> constraintRows
            Just o -> error $ "unexpected datatype returned for postgres o="++show o

    -- Group consecutive rows that share a constraint name.
    constraintGroups = do
        rows <- constraintRows
        return
            [ Right (Right (DBName con, map (DBName . snd) grp))
            | grp@((con, _) : _) <- groupBy ((==) `on` fst) rows
            ]

    -- Parse each column row with 'getColumn'.
    columnRows = do
        mrow <- CL.head
        case mrow of
            Nothing -> return []
            Just row -> do
                parsed <- liftIO $ getColumn getter (entityDB def) row
                rest <- columnRows
                return (fmap Left parsed : rest)
-- | Check whether any field of the entity with the given database column
-- name carries the @SafeToRemove@ attribute.
safeToRemove :: EntityDef -> DBName -> Bool
safeToRemove def (DBName colName) =
    or
        [ "SafeToRemove" `elem` fieldAttrs field
        | field <- entityFields def
        , fieldDB field == DBName colName
        ]
-- | Compare the wanted schema (first pair) with the existing one (second
-- pair) and return the column and constraint changes needed.
--
-- Columns are matched one by one through 'findAlters'; existing columns
-- left over at the end are dropped. Unique constraints are added when
-- missing, replaced when their (sorted) column lists differ, and dropped
-- when no longer wanted -- unless their name starts with @__manual_@.
getAlters
    :: [EntityDef]
    -> EntityDef
    -> ([Column], [(DBName, [DBName])])
    -> ([Column], [(DBName, [DBName])])
    -> ([AlterColumn'], [AlterTable])
getAlters defs def (c1, u1) (c2, u2) =
    (columnAlters c1 c2, uniqueAlters u1 u2)
  where
    columnAlters wanted present = case wanted of
        [] ->
            [ (cName col, Drop (safeToRemove def (cName col)))
            | col <- present
            ]
        new : rest ->
            let (alters, remaining) = findAlters defs (entityDB def) new present
             in alters ++ columnAlters rest remaining

    uniqueAlters
        :: [(DBName, [DBName])]
        -> [(DBName, [DBName])]
        -> [AlterTable]
    uniqueAlters [] present =
        [ DropConstraint cname
        | (cname, _) <- present
        , not (isManual cname)
        ]
    uniqueAlters ((name, cols) : rest) present =
        case lookup name present of
            Nothing -> AddUniqueConstraint name cols : uniqueAlters rest present
            Just ocols
                | sort cols == sort ocols -> uniqueAlters rest others
                | otherwise ->
                    DropConstraint name
                        : AddUniqueConstraint name cols
                        : uniqueAlters rest others
      where
        others = filter ((/= name) . fst) present

    -- Don't drop constraints which were manually added.
    isManual (DBName x) = "__manual_" `T.isPrefixOf` x
-- | Turn one seven-field metadata row (column name, nullability, type name,
-- default, numeric precision, numeric scale, maximum length) into a 'Column'.
--
-- Returns 'Left' with a message when the default value or the numeric
-- precision/scale cannot be interpreted, or when the row does not have the
-- expected shape. The supplied statement getter is used to look up whether a
-- foreign key constraint with the conventional name exists for the column.
getColumn :: (Text -> IO Statement)
          -> DBName -> [PersistValue]
          -> IO (Either Text Column)
getColumn getter tableName' [PersistText columnName, PersistText isNullable, PersistText typeName, defaultValue, numericPrecision, numericScale, maxlen] =
    case parsed of
        Left err -> return (Left err)
        Right (dflt, sqlType) -> do
            ref <- getRef cname
            return $ Right Column
                { cName = cname
                , cNull = isNullable == "YES"
                , cSqlType = sqlType
                , cDefault = fmap stripSuffixes dflt
                , cDefaultConstraintName = Nothing
                , cMaxLen = Nothing
                , cReference = ref
                }
  where
    cname = DBName columnName

    -- The default is validated first, then the type; the first failure wins.
    parsed = do
        dflt <- parsedDefault
        sqlType <- getType typeStr
        return (dflt, sqlType)

    parsedDefault = case defaultValue of
        PersistNull -> Right Nothing
        PersistText t -> Right (Just t)
        _ -> Left $ T.pack $ "Invalid default column: " ++ show defaultValue

    -- When a maximum length is reported it is appended as @type(n)@.
    typeStr = case maxlen of
        PersistInt64 n -> T.concat [typeName, "(", T.pack (show n), ")"]
        _ -> typeName

    -- Remove the first matching cast suffix from a default expression.
    stripSuffixes t =
        case mapMaybe (`T.stripSuffix` t) ["::character varying", "::text"] of
            stripped : _ -> stripped
            [] -> t

    -- Counts foreign key constraints named after this table and column.
    getRef col = do
        let ref = refName tableName' col
            sql = T.concat
                [ "SELECT COUNT(*) FROM "
                , "information_schema.table_constraints "
                , "WHERE table_catalog=current_database() "
                , "AND table_schema=current_schema() "
                , "AND table_name=? "
                , "AND constraint_type='FOREIGN KEY' "
                , "AND constraint_name=?"
                ]
        stmt <- getter sql
        with (stmtQuery stmt
                [ PersistText (unDBName tableName')
                , PersistText (unDBName ref)
                ]) $ \src -> runConduit $ src .| do
                    Just [PersistInt64 i] <- CL.head
                    return $ if i == 0 then Nothing else Just (DBName "", ref)

    getType "numeric" = getNumeric numericPrecision numericScale
    getType other = Right $ fromMaybe (SqlOther other) (lookup other knownTypes)

    knownTypes :: [(Text, SqlType)]
    knownTypes =
        [ ("int4", SqlInt32)
        , ("int8", SqlInt64)
        , ("varchar", SqlString)
        , ("text", SqlString)
        , ("date", SqlDay)
        , ("bool", SqlBool)
        , ("timestamptz", SqlDayTime)
        , ("float4", SqlReal)
        , ("float8", SqlReal)
        , ("bytea", SqlBlob)
        , ("time", SqlTime)
        ]

    getNumeric (PersistInt64 a) (PersistInt64 b) =
        Right $ SqlNumeric (fromIntegral a) (fromIntegral b)
    getNumeric PersistNull PersistNull = Left $ T.concat
        [ "No precision and scale were specified for the column: "
        , columnName
        , " in table: "
        , unDBName tableName'
        , ". Postgres defaults to a maximum scale of 147,455 and precision of 16383,"
        , " which is probably not what you intended."
        , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
        ]
    getNumeric a b = Left $ T.concat
        [ "Can not get numeric field precision for the column: "
        , columnName
        , " in table: "
        , unDBName tableName'
        , ". Expected an integer for both precision and scale, "
        , "got: "
        , T.pack $ show a
        , " and "
        , T.pack $ show b
        , ", respectively."
        , " Specify the values as numeric(total_digits, digits_after_decimal_place)."
        ]
getColumn _ _ row =
    return $ Left $ T.pack $ "Invalid result from information_schema: " ++ show row
-- | Compare two SQL types by their rendered names, ignoring case, so that
-- e.g. 'SqlInt32' and @SqlOther "integer"@ are considered equal.
sqlTypeEq :: SqlType -> SqlType -> Bool
sqlTypeEq x y = rendered x == rendered y
  where
    rendered = T.toCaseFold . showSqlType
-- | Work out how to bring one desired column in line with the existing
-- columns of a table.
--
-- Returns the alterations needed for that column together with the existing
-- columns that remain once the column of the same name has been removed. If
-- no existing column has that name, the column is simply added.
findAlters :: [EntityDef] -> DBName -> Column -> [Column] -> ([AlterColumn'], [Column])
findAlters defs _tablename col@Column{} cols =
    case filter ((== name) . cName) cols of
        [] -> ([(name, Add' col)], cols)
        old : _ ->
            ( refChanges old ++ defaultChanges old ++ nullChanges old ++ typeChanges old
            , filter ((/= name) . cName) cols
            )
  where
    name = cName col

    -- Only the constraint names of the references are compared.
    refChanges old
        | fmap snd (cReference col) == fmap snd (cReference old) = []
        | otherwise = dropRef (cReference old) ++ addRef (cReference col)

    dropRef Nothing = []
    dropRef (Just (_, cname)) = [(name, DropReference cname)]

    addRef Nothing = []
    addRef (Just (tname, a)) =
        case find ((== tname) . entityDB) defs of
            Just refdef -> [(tname, AddReference a [name] (Util.dbIdColumnsEsc escape refdef))]
            Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]"

    defaultChanges old
        | cDefault col == cDefault old = []
        | otherwise = [(name, maybe NoDefault Default (cDefault col))]

    nullChanges old =
        case (cNull col, cNull old) of
            (True, False) -> [(name, IsNull)]
            -- Becoming NOT NULL: fill existing NULLs with the default first,
            -- when there is one.
            (False, True) ->
                [ (name, Update' s) | Just s <- [cDefault col] ] ++ [(name, NotNull)]
            _ -> []

    typeChanges old
        | sqlTypeEq newType oldType = []
        -- When converting from Persistent pre-2.0 databases, we
        -- need to make sure that TIMESTAMP WITHOUT TIME ZONE is
        -- treated as UTC.
        | newType == SqlDayTime && oldType == SqlOther "timestamp" =
            [ ( name
              , ChangeType newType $ T.concat
                    [ " USING "
                    , escape name
                    , " AT TIME ZONE 'UTC'"
                    ]
              )
            ]
        | otherwise = [(name, ChangeType newType "")]
      where
        newType = cSqlType col
        oldType = cSqlType old
-- | Build the statement that adds a foreign key for a column, if the column
-- has a reference at all.
--
-- The referenced key columns are taken from the entity whose table name is
-- @reftable@; evaluating them raises an error if no such entity is known.
getAddReference :: [EntityDef] -> DBName -> DBName -> DBName -> Maybe (DBName, DBName) -> Maybe AlterDB
getAddReference allDefs table reftable cname ref = fmap addReference ref
  where
    addReference (s, _) =
        AlterColumn table (s, AddReference (refName table cname) [cname] id_)
    id_ =
        case find ((== reftable) . entityDB) allDefs of
            Just entDef -> Util.dbIdColumnsEsc escape entDef
            Nothing -> error $ "Could not find ID of entity " ++ show reftable
-- | Render a column definition: escaped name, SQL type, nullability and an
-- optional @DEFAULT@ clause.
showColumn :: Column -> Text
showColumn (Column n nu sqlType' def _defConstraintName _maxLen _ref) =
    escape n <> " " <> showSqlType sqlType' <> " " <> nullability <> defaultClause
  where
    nullability :: Text
    nullability
        | nu = "NULL"
        | otherwise = "NOT NULL"

    defaultClause :: Text
    defaultClause = maybe "" (" DEFAULT " <>) def
-- | Render a 'SqlType' as the type name used in generated SQL.
showSqlType :: SqlType -> Text
showSqlType sqlType =
    case sqlType of
        SqlString -> "VARCHAR"
        SqlInt32 -> "INT4"
        SqlInt64 -> "INT8"
        SqlReal -> "DOUBLE PRECISION"
        SqlNumeric s prec ->
            T.concat [ "NUMERIC(", T.pack (show s), ",", T.pack (show prec), ")" ]
        SqlDay -> "DATE"
        SqlTime -> "TIME"
        SqlDayTime -> "TIMESTAMP WITH TIME ZONE"
        SqlBlob -> "BYTEA"
        SqlBool -> "BOOLEAN"
        SqlOther t
            -- Added for aliasing issues re: https://github.com/yesodweb/yesod/issues/682
            | T.toLower t == "integer" -> "INT4"
            | otherwise -> t
-- | Render a migration step as SQL, paired with a flag that is 'True' only
-- for a column drop that is not marked as safe to remove.
showAlterDb :: AlterDB -> (Bool, Text)
showAlterDb alter =
    case alter of
        AddTable sql -> (False, sql)
        AlterColumn tbl change@(_, Drop safeRemove) ->
            (not safeRemove, showAlter tbl change)
        AlterColumn tbl change -> (False, showAlter tbl change)
        AlterTable tbl change -> (False, showAlterTable tbl change)
-- | Render a table-level constraint change as an @ALTER TABLE@ statement.
showAlterTable :: DBName -> AlterTable -> Text
showAlterTable table alteration = "ALTER TABLE " <> escape table <> clause
  where
    clause :: Text
    clause =
        case alteration of
            AddUniqueConstraint cname cols ->
                " ADD CONSTRAINT "
                    <> escape cname
                    <> " UNIQUE("
                    <> T.intercalate "," (map escape cols)
                    <> ")"
            DropConstraint cname ->
                " DROP CONSTRAINT " <> escape cname
-- | Render a single column alteration on the given table as SQL.
--
-- The first component of the pair is the column name, except for
-- 'AddReference', where it is used as the referenced table.
showAlter :: DBName -> AlterColumn' -> Text
showAlter table (n, change) =
    case change of
        ChangeType t extra -> alterColumn (" TYPE " <> showSqlType t <> extra)
        IsNull -> alterColumn " DROP NOT NULL"
        NotNull -> alterColumn " SET NOT NULL"
        Add' col -> alterTable <> " ADD COLUMN " <> showColumn col
        Drop _ -> alterTable <> " DROP COLUMN " <> escape n
        Default s -> alterColumn (" SET DEFAULT " <> s)
        NoDefault -> alterColumn " DROP DEFAULT"
        -- Fills in rows that currently hold NULL in this column.
        Update' s ->
            "UPDATE "
                <> escape table
                <> " SET "
                <> escape n
                <> "="
                <> s
                <> " WHERE "
                <> escape n
                <> " IS NULL"
        AddReference fkeyname t2 id2 ->
            alterTable
                <> " ADD CONSTRAINT "
                <> escape fkeyname
                <> " FOREIGN KEY("
                <> T.intercalate "," (map escape t2)
                <> ") REFERENCES "
                <> escape n
                <> "("
                <> T.intercalate "," id2
                <> ")"
        DropReference cname -> alterTable <> " DROP CONSTRAINT " <> escape cname
  where
    alterTable :: Text
    alterTable = "ALTER TABLE " <> escape table

    alterColumn :: Text -> Text
    alterColumn suffix = alterTable <> " ALTER COLUMN " <> escape n <> suffix
-- | Get the escaped SQL name of the table that a 'PersistEntity' represents.
-- Useful for raw SQL queries.
tableName :: (PersistEntity record) => record -> Text
tableName record = escape (tableDBName record)
-- | Get the escaped SQL name of the column that an 'EntityField' represents.
-- Useful for raw SQL queries.
fieldName :: (PersistEntity record) => EntityField record typ -> Text
fieldName field = escape (fieldDBName field)
-- | Wrap a database name in double quotes, doubling any double quote that
-- occurs inside it.
escape :: DBName -> Text
escape (DBName s) =
    T.pack ('"' : concatMap doubleQuote (T.unpack s) ++ "\"")
  where
    doubleQuote '"' = "\"\""
    doubleQuote c = [c]
-- | Settings needed to connect to a PostgreSQL database through
-- @persistent@'s generic facilities. They are the same values that
-- 'withPostgresqlPool' takes.
data PostgresConf = PostgresConf
    { pgConnStr :: ConnectionString
      -- ^ The connection string.
    , pgPoolSize :: Int
      -- ^ Number of connections to keep in the connection pool.
    } deriving (Show, Read, Data, Typeable)
-- | Reads the keys @database@, @host@, @user@, @password@ and @poolsize@,
-- plus an optional @port@ that defaults to 5432, and builds the connection
-- string from them. Parse failures are prefixed with a fixed message.
instance FromJSON PostgresConf where
    parseJSON =
        modifyFailure ("Persistent: error loading PostgreSQL conf: " ++)
            . withObject "PostgresConf" parseConf
      where
        parseConf o = do
            database <- o .: "database"
            host <- o .: "host"
            port <- o .:? "port" .!= 5432
            user <- o .: "user"
            password <- o .: "password"
            pool <- o .: "poolsize"
            let connInfo = PG.ConnectInfo
                    { PG.connectHost = host
                    , PG.connectPort = port
                    , PG.connectUser = user
                    , PG.connectPassword = password
                    , PG.connectDatabase = database
                    }
            return $ PostgresConf (PG.postgreSQLConnectionString connInfo) pool
-- | Pool-based configuration for PostgreSQL.
--
-- 'applyEnv' appends @key='value'@ pairs to the connection string for each
-- of the environment variables @PGHOST@, @PGPORT@, @PGDATABASE@, @PGPASS@
-- and @PGUSER@ that is set, in that order.
instance PersistConfig PostgresConf where
    type PersistConfigBackend PostgresConf = SqlPersistT
    type PersistConfigPool PostgresConf = ConnectionPool

    createPoolConfig conf =
        runNoLoggingT $ createPostgresqlPool (pgConnStr conf) (pgPoolSize conf) -- FIXME

    runPool _ = runSqlPool

    loadConfig = parseJSON

    applyEnv c0 = do
        env <- getEnvironment
        return $ foldl (override env) c0 overrides
      where
        -- Connection-string keyword paired with the environment variable
        -- that supplies it.
        -- NOTE(review): libpq itself reads PGPASSWORD, not PGPASS — confirm
        -- that PGPASS is the intended variable name.
        overrides =
            [ ("host", "PGHOST")
            , ("port", "PGPORT")
            , ("dbname", "PGDATABASE")
            , ("password", "PGPASS")
            , ("user", "PGUSER")
            ]

        -- Leave the configuration untouched when the variable is not set.
        override env conf (param, envvar) =
            case lookup envvar env of
                Nothing -> conf
                Just val ->
                    conf { pgConnStr = B8.concat [pgConnStr conf, " ", param, "='", pgescape val, "'"] }

        -- Backslash-escape single quotes and backslashes in a value.
        pgescape = B8.pack . concatMap escapeChar
          where
            escapeChar '\'' = "\\'"
            escapeChar '\\' = "\\\\"
            escapeChar c = [c]
-- | Name of the foreign key constraint for a column: @table_column_fkey@.
refName :: DBName -> DBName -> DBName
refName (DBName table) (DBName column) =
    DBName (table <> "_" <> column <> "_fkey")
-- | Reduce a unique definition to its constraint name and the database
-- names of its fields.
udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, [ col | (_, col) <- uniqueFields ud ])
-- | Migration function for use without a database: it always produces the
-- statements that create the entity's table from scratch.
--
-- The statement getter argument is ignored. The result is always 'Right':
-- the table creation, followed by unique constraints, column references and
-- foreign keys, each rendered with 'showAlterDb'. Columns marked as safe to
-- remove are left out.
mockMigrate :: [EntityDef]
            -> (Text -> IO Statement)
            -> EntityDef
            -> IO (Either [Text] [(Bool, Text)])
mockMigrate allDefs _ entity =
    return $ Right $ map showAlterDb createStatements
  where
    name = entityDB entity

    (allCols, udefs, fdefs) = mkColumns allDefs entity
    newcols = filter (not . safeToRemove entity . cName) allCols

    createStatements =
        addTable newcols entity : uniques ++ references ++ foreignsAlt

    uniques =
        [ AlterTable name (AddUniqueConstraint uname ucols)
        | (uname, ucols) <- map udToPair udefs
        ]

    references = mapMaybe columnReference newcols

    -- Only columns that carry a reference yield a statement.
    columnReference c =
        case cReference c of
            Nothing -> Nothing
            Just (refTblName, _) ->
                getAddReference allDefs name refTblName (cName c) (cReference c)

    foreignsAlt = map foreignAlter fdefs

    foreignAlter fdef =
        let (childfields, parentfields) =
                unzip [ (child, parent) | ((_, child), (_, parent)) <- foreignFields fdef ]
         in AlterColumn name
                ( foreignRefTableDBName fdef
                , AddReference (foreignConstraintNameDBName fdef) childfields (map escape parentfields)
                )
-- | Print the SQL of a migration without needing a database.
--
-- Does the same job as 'printMigration', but runs the migration against a
-- stub backend whose statements return no rows and whose migration function
-- is 'mockMigrate'. Backend fields the run is not expected to touch are left
-- 'undefined'.
mockMigration :: Migration -> IO ()
mockMigration mig = do
    smap <- newIORef Map.empty
    let mockStatement = Statement
            { stmtFinalize = return ()
            , stmtReset = return ()
            , stmtExecute = undefined
            , stmtQuery = \_ -> return $ return ()
            }
        sqlbackend = SqlBackend
            { connPrepare = \_ -> return mockStatement
            , connInsertManySql = Nothing
            , connInsertSql = undefined
            , connUpsertSql = Nothing
            , connPutManySql = Nothing
            , connStmtMap = smap
            , connClose = undefined
            , connMigrateSql = mockMigrate
            , connBegin = undefined
            , connCommit = undefined
            , connRollback = undefined
            , connEscapeName = escape
            , connNoLimit = undefined
            , connRDBMS = undefined
            , connLimitOffset = undefined
            , connLogFunc = undefined
            , connMaxParams = Nothing
            , connRepsertManySql = Nothing
            }
    (_, statements) <- runReaderT (runWriterT (runWriterT mig)) sqlbackend
    mapM_ (T.putStrLn . snd) statements
-- | Multi-row upsert statement for @n@ rows of the entity's fields, using
-- the columns of all of the entity's unique constraints as conflict target.
putManySql :: EntityDef -> Int -> Text
putManySql ent n = putManySql' uniqueColumns (entityFields ent) ent n
  where
    uniqueColumns =
        [ escape col
        | uniq <- entityUniques ent
        , (_, col) <- uniqueFields uniq
        ]
-- | Multi-row upsert statement for @n@ rows covering the key and entity
-- fields, using the entity's key columns as conflict target.
repsertManySql :: EntityDef -> Int -> Text
repsertManySql ent n = putManySql' keyColumns (keyAndEntityFields ent) ent n
  where
    keyColumns = map (escape . fieldDB) (entityKeyFields ent)
-- | Build an @INSERT ... ON CONFLICT (...) DO UPDATE SET ...@ statement.
--
-- Takes the already escaped conflict columns, the fields to insert, the
-- entity that names the table, and the number of placeholder rows to emit.
-- Every inserted field is overwritten from @EXCLUDED@ on conflict.
putManySql' :: [Text] -> [FieldDef] -> EntityDef -> Int -> Text
putManySql' conflictColumns fields ent n =
    T.concat
        [ "INSERT INTO "
        , escape (entityDB ent)
        , Util.parenWrapped (Util.commaSeparated columnNames)
        , " VALUES "
        , Util.commaSeparated (replicate n placeholderRow)
        , " ON CONFLICT "
        , Util.parenWrapped (Util.commaSeparated conflictColumns)
        , " DO UPDATE SET "
        , Util.commaSeparated assignments
        ]
  where
    columnNames = map (escape . fieldDB) fields
    -- One parenthesised group holding a @?@ per field.
    placeholderRow = Util.parenWrapped (Util.commaSeparated ("?" <$ fields))
    assignments = [ T.concat [c, "=EXCLUDED.", c] | c <- columnNames ]
-- | Enable a Postgres extension. See https://www.postgresql.org/docs/current/static/contrib.html
-- for a list.
--
-- Queries @pg_catalog.pg_extension@ for the extension name and, only when
-- the count is zero, emits a @CREATE EXTENSION@ statement flagged as safe.
-- The name is quoted with 'escape', so a double quote inside it is doubled
-- instead of ending the identifier early.
migrateEnableExtension :: Text -> Migration
migrateEnableExtension extName = WriterT $ WriterT $ do
    res :: [Single Int] <-
        rawSql "SELECT COUNT(*) FROM pg_catalog.pg_extension WHERE extname = ?" [PersistText extName]
    let statements =
            [ (False, "CREATE EXTENSION " <> escape (DBName extName))
            | res == [Single 0]
            ]
    return (((), []), statements)
|