MySQL.hs 52 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317
  1. {-# LANGUAGE ExistentialQuantification #-}
  2. {-# LANGUAGE FlexibleContexts #-}
  3. {-# LANGUAGE GADTs #-}
  4. {-# LANGUAGE OverloadedStrings #-}
  5. {-# LANGUAGE PatternSynonyms #-}
  6. {-# LANGUAGE TypeFamilies #-}
  7. -- | A MySQL backend for @persistent@.
  8. module Database.Persist.MySQL
  9. ( withMySQLPool
  10. , withMySQLConn
  11. , createMySQLPool
  12. , module Database.Persist.Sql
  13. , MySQL.ConnectInfo(..)
  14. , MySQLBase.SSLInfo(..)
  15. , MySQL.defaultConnectInfo
  16. , MySQLBase.defaultSSLInfo
  17. , MySQLConf(..)
  18. , mockMigration
  19. -- * @ON DUPLICATE KEY UPDATE@ Functionality
  20. , insertOnDuplicateKeyUpdate
  21. , insertManyOnDuplicateKeyUpdate
  22. , HandleUpdateCollision
  23. , pattern SomeField
  24. , SomeField
  25. , copyField
  26. , copyUnlessNull
  27. , copyUnlessEmpty
  28. , copyUnlessEq
  29. ) where
  import Control.Arrow
  import Control.Monad
  import Control.Monad.IO.Class (MonadIO (..))
  import Data.Either (partitionEithers)
  import Data.Fixed (Pico)
  import Data.Function (on)
  import Data.Int (Int64)
  import Data.IORef
  import Data.List (find, intercalate, sort, groupBy)
  import Data.Monoid ((<>))
  import qualified Data.Monoid as Monoid
  import System.Environment (getEnvironment)
  import Text.Read (readMaybe)

  import qualified Blaze.ByteString.Builder.ByteString as BBS
  import qualified Blaze.ByteString.Builder.Char8 as BBB
  import Control.Monad.IO.Unlift (MonadUnliftIO)
  import Control.Monad.Logger (MonadLogger, runNoLoggingT)
  import Control.Monad.Trans.Class (lift)
  import Control.Monad.Trans.Except (runExceptT, throwE)
  import Control.Monad.Trans.Reader (runReaderT, ReaderT)
  import Control.Monad.Trans.Writer (runWriterT)
  import Data.Acquire (Acquire, mkAcquire, with)
  import Data.Aeson
  import Data.Aeson.Types (modifyFailure)
  import Data.ByteString (ByteString)
  import Data.Conduit
  import qualified Data.Conduit.List as CL
  import qualified Data.Map as Map
  import Data.Pool (Pool)
  import Data.Text (Text, pack)
  import qualified Data.Text as T
  import qualified Data.Text.Encoding as T
  import qualified Data.Text.IO as T
  import qualified Database.MySQL.Base as MySQLBase
  import qualified Database.MySQL.Base.Types as MySQLBase
  import qualified Database.MySQL.Simple as MySQL
  import qualified Database.MySQL.Simple.Param as MySQL
  import qualified Database.MySQL.Simple.Result as MySQL
  import qualified Database.MySQL.Simple.Types as MySQL

  import Database.Persist.Sql
  import Database.Persist.Sql.Types.Internal (makeIsolationLevelStatement)
  import qualified Database.Persist.Sql.Util as Util
  -- | Create a MySQL connection pool and run the given action with it.
  --
  -- The pool is released once the action finishes, so the
  -- 'ConnectionPool' handed to the action must not be used outside of it:
  -- by then it may already have been released.
  withMySQLPool
      :: (MonadLogger m, MonadUnliftIO m)
      => MySQL.ConnectInfo
      -- ^ Connection information.
      -> Int
      -- ^ Number of connections to be kept open in the pool.
      -> (Pool SqlBackend -> m a)
      -- ^ Action to be executed that uses the connection pool.
      -> m a
  withMySQLPool = withSqlPool . open'
  -- | Create a MySQL connection pool.
  --
  -- Closing the pool when it is no longer needed is the caller's
  -- responsibility; prefer 'withMySQLPool' for automatic resource control.
  createMySQLPool
      :: (MonadUnliftIO m, MonadLogger m)
      => MySQL.ConnectInfo
      -- ^ Connection information.
      -> Int
      -- ^ Number of connections to be kept open in the pool.
      -> m (Pool SqlBackend)
  createMySQLPool = createSqlPool . open'
  -- | Like 'withMySQLPool', but opens a single connection instead of a
  -- pool and hands that connection to the action.
  withMySQLConn
      :: (MonadUnliftIO m, MonadLogger m)
      => MySQL.ConnectInfo
      -- ^ Connection information.
      -> (SqlBackend -> m a)
      -- ^ Action to be executed that uses the connection.
      -> m a
  withMySQLConn connInfo = withSqlConn (open' connInfo)
  -- | Internal function that opens a connection to the MySQL server and
  -- wraps it in a 'SqlBackend'.
  --
  -- Autocommit is switched off on the new connection. 'connBegin' first
  -- applies the requested isolation level (if any) and then issues
  -- @start transaction@ explicitly.
  open' :: MySQL.ConnectInfo -> LogFunc -> IO SqlBackend
  open' ci logFunc = do
      conn <- MySQL.connect ci
      MySQLBase.autocommit conn False -- disable autocommit!
      stmtMap <- newIORef Map.empty
      let noLimit :: Text
          -- This noLimit is suggested by MySQL's own docs, see
          -- <http://dev.mysql.com/doc/refman/5.5/en/select.html>
          noLimit = "LIMIT 18446744073709551615"
          beginTransaction _ mIsolation = do
              forM_ mIsolation (MySQL.execute_ conn . makeIsolationLevelStatement)
              void (MySQL.execute_ conn "start transaction")
      pure SqlBackend
          { connPrepare = prepare' conn
          , connStmtMap = stmtMap
          , connInsertSql = insertSql'
          , connInsertManySql = Nothing
          , connUpsertSql = Nothing
          , connPutManySql = Just putManySql
          , connClose = MySQL.close conn
          , connMigrateSql = migrate' ci
          , connBegin = beginTransaction
          , connCommit = const (MySQL.commit conn)
          , connRollback = const (MySQL.rollback conn)
          , connEscapeName = pack . escapeDBName
          , connNoLimit = noLimit
          , connRDBMS = "mysql"
          , connLimitOffset = decorateSQLWithLimitOffset noLimit
          , connLogFunc = logFunc
          , connMaxParams = Nothing
          , connRepsertManySql = Just repsertManySql
          , connInsertUniqueSql = Nothing
          }
  -- | \"Prepare\" a query. MySQL prepared statements are not used; the SQL
  -- text is UTF-8 encoded once, and parameters are substituted client-side
  -- each time the statement runs. Finalizing and resetting do nothing.
  prepare' :: MySQL.Connection -> Text -> IO Statement
  prepare' conn sql =
      pure Statement
          { stmtFinalize = pure ()
          , stmtReset = pure ()
          , stmtExecute = execute' conn query
          , stmtQuery = withStmt' conn query
          }
    where
      query = MySQL.Query (T.encodeUtf8 sql)
  -- | SQL code to be executed when inserting an entity.
  --
  -- Builds @INSERT INTO tbl(col,...) VALUES(?,...)@ with one placeholder per
  -- entity field. Entities with a composite primary key report the inserted
  -- values as their key. All others fetch the generated key with
  -- @SELECT LAST_INSERT_ID()@.
  insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
  insertSql' ent vals =
      case entityPrimary ent of
          Just _ -> ISRManyKeys sql vals
          Nothing -> ISRInsertGet sql "SELECT LAST_INSERT_ID()"
    where
      fields = entityFields ent
      sql = pack $ concat
          [ "INSERT INTO "
          , escapeDBName (entityDB ent)
          , "("
          , intercalate "," [escapeDBName (fieldDB f) | f <- fields]
          , ") VALUES("
          , intercalate "," ("?" <$ fields)
          , ")"
          ]
  -- | Execute a statement that doesn't return any results, substituting the
  -- given values for its placeholders, and return what 'MySQL.execute'
  -- reports.
  execute' :: MySQL.Connection -> MySQL.Query -> [PersistValue] -> IO Int64
  execute' conn query = MySQL.execute conn query . map P
  -- | Execute a statement that does return results.
  --
  -- The whole result set is buffered client-side ('MySQLBase.storeResult')
  -- and read into memory before being streamed out. Each row is decoded
  -- with a per-column getter chosen from the column metadata. Every
  -- decoded value is forced as it is produced. The MySQL result is freed
  -- when the 'Acquire' is released.
  withStmt' :: MonadIO m
            => MySQL.Connection
            -> MySQL.Query
            -> [PersistValue]
            -> Acquire (ConduitM () [PersistValue] m ())
  withStmt' conn query vals =
      (\res -> fetchRows res >>= CL.sourceList)
          <$> mkAcquire createResult MySQLBase.freeResult
    where
      -- Render the parameters into the query text, run it, and store the
      -- complete result on the client.
      createResult = do
          rendered <- MySQL.formatQuery conn query (map P vals)
          MySQLBase.query conn rendered
          MySQLBase.storeResult conn
      fetchRows res = liftIO $ do
          fields <- MySQLBase.fetchFields res
          let getters = [maybe PersistNull (getGetter f f . Just) | f <- fields]
              -- Decode one row, forcing each value before consing it.
              step (decode, col) rest =
                  let v = decode col in v `seq` rest `seq` (v : rest)
              convert row = foldr step [] (zip getters row)
              -- An empty row from fetchRow marks the end of the result set.
              loop acc = do
                  row <- MySQLBase.fetchRow res
                  case row of
                      [] -> pure (reverse acc)
                      _ -> let decoded = convert row
                           in decoded `seq` loop (decoded : acc)
          loop []
  -- | @newtype@ around 'PersistValue' so that it can be rendered through the
  -- 'MySQL.Param' type class.
  newtype P = P PersistValue

  -- | Render each 'PersistValue' with the matching mysql-simple renderer.
  --
  -- * Lists and maps go through 'listToJSON' / 'mapToJSON'.
  -- * A 'PersistArray' is rendered exactly like a 'PersistList'.
  -- * Rationals are printed as 'Pico' values.
  -- * 'PersistDbSpecific' bytes are spliced in as 'MySQL.Plain'.
  -- * A 'PersistObjectId' is rejected with an error.
  instance MySQL.Param P where
      render (P value) = case value of
          PersistText t -> MySQL.render t
          PersistByteString bs -> MySQL.render bs
          PersistInt64 i -> MySQL.render i
          PersistDouble d -> MySQL.render d
          PersistBool b -> MySQL.render b
          PersistDay d -> MySQL.render d
          PersistTimeOfDay t -> MySQL.render t
          PersistUTCTime t -> MySQL.render t
          PersistNull -> MySQL.render MySQL.Null
          PersistList l -> MySQL.render $ listToJSON l
          PersistMap m -> MySQL.render $ mapToJSON m
          PersistRational r ->
              -- FIXME: Too ambiguous; the precision cannot be chosen without
              -- information about the field.
              MySQL.Plain $ BBB.fromString $ show (fromRational r :: Pico)
          PersistDbSpecific s -> MySQL.Plain $ BBS.fromByteString s
          PersistArray a -> MySQL.render (P (PersistList a))
          PersistObjectId _ ->
              error "Refusing to serialize a PersistObjectId to a MySQL value"
  -- | A @'Getter' a@ turns one incoming column value into an @a@. It
  -- receives the column's 'MySQLBase.Field' metadata and the raw bytes;
  -- 'withStmt'' maps a 'Nothing' cell to 'PersistNull' before any getter
  -- runs.
  type Getter a
      = MySQLBase.Field
     -> Maybe ByteString
     -> a
  -- | Build a 'Getter' by decoding with 'MySQL.convert' and then wrapping
  -- the decoded value with the given function.
  convertPV :: MySQL.Result a => (a -> b) -> Getter b
  convertPV wrap field raw = wrap (MySQL.convert field raw)
  -- | Pick the @'Getter' 'PersistValue'@ for a column based on its type,
  -- display length (for @TINY@) and character set (for string/blob types).
  -- Unsupported column types make the getter an 'error'.
  getGetter :: MySQLBase.Field -> Getter PersistValue
  getGetter field =
      case MySQLBase.fieldType field of
          -- Bool: a TINY column of length 1; any other TINY is an Int64.
          MySQLBase.Tiny
              | MySQLBase.fieldLength field == 1 -> convertPV PersistBool
              | otherwise -> convertPV PersistInt64
          -- Int64
          MySQLBase.Int24 -> convertPV PersistInt64
          MySQLBase.Short -> convertPV PersistInt64
          MySQLBase.Long -> convertPV PersistInt64
          MySQLBase.LongLong -> convertPV PersistInt64
          -- Double
          MySQLBase.Float -> convertPV PersistDouble
          MySQLBase.Double -> convertPV PersistDouble
          MySQLBase.Decimal -> convertPV PersistDouble
          MySQLBase.NewDecimal -> convertPV PersistDouble
          -- ByteString and Text
          -- The MySQL C client (and by extension the Haskell mysql package)
          -- doesn't distinguish between binary and non-binary string data at
          -- the type level (e.g. both BLOB and TEXT have the MySQLBase.Blob
          -- type). Instead, the character set distinguishes them: binary data
          -- uses character set number 63.
          -- See https://dev.mysql.com/doc/refman/5.6/en/c-api-data-structures.html (Search for "63")
          MySQLBase.VarChar -> binaryOrText
          MySQLBase.VarString -> binaryOrText
          MySQLBase.String -> binaryOrText
          MySQLBase.Blob -> binaryOrText
          MySQLBase.TinyBlob -> binaryOrText
          MySQLBase.MediumBlob -> binaryOrText
          MySQLBase.LongBlob -> binaryOrText
          -- Time-related
          MySQLBase.Time -> convertPV PersistTimeOfDay
          MySQLBase.DateTime -> convertPV PersistUTCTime
          MySQLBase.Timestamp -> convertPV PersistUTCTime
          MySQLBase.Date -> convertPV PersistDay
          MySQLBase.NewDate -> convertPV PersistDay
          MySQLBase.Year -> convertPV PersistDay
          -- Null
          MySQLBase.Null -> \_ _ -> PersistNull
          -- Controversial conversions
          MySQLBase.Set -> convertPV PersistText
          MySQLBase.Enum -> convertPV PersistText
          -- Conversion using PersistDbSpecific
          MySQLBase.Geometry -> \_ raw ->
              case raw of
                  Just g -> PersistDbSpecific g
                  Nothing -> error "Unexpected null in database specific value"
          -- Unsupported
          other -> error $ "MySQL.getGetter: type " ++
                           show other ++ " not supported."
    where
      -- Character set 63 means binary data; anything else is text.
      binaryOrText
          | MySQLBase.fieldCharSet field == 63 = convertPV PersistByteString
          | otherwise = convertPV PersistText
  286. ----------------------------------------------------------------------
  -- | Create the migration plan for the given 'PersistEntity' @val@.
  --
  -- * If neither the ID column nor any other column of the entity's table
  --   is found, the plan creates the table, then its unique constraints,
  --   then its foreign keys.
  -- * Otherwise, if every existing column was read without error, the plan
  --   is the column and table alterations computed by 'getAlters'.
  -- * Errors from reading the existing columns are returned as 'Left'.
  migrate' :: MySQL.ConnectInfo
           -> [EntityDef]
           -> (Text -> IO Statement)
           -> EntityDef
           -> IO (Either [Text] [(Bool, Text)])
  migrate' connectInfo allDefs getter val = do
      (idClmn, old) <- getColumns connectInfo getter val
      case (idClmn, old, partitionEithers old) of
          -- Nothing found, create everything
          ([], [], _) ->
              pure $ Right $ map showAlterDb $
                  addTable newcols val : uniques ++ foreigns ++ foreignsAlt
          -- No errors and something found, migrate
          (_, _, ([], old')) ->
              let (oldCols, oldUniques) = partitionEithers old'
                  (acs, ats) = getAlters allDefs name (newcols, udspair)
                                         (map forgetKnownForeignKey oldCols, oldUniques)
              in pure $ Right $ map showAlterDb $
                     map (AlterColumn name) acs ++ map (AlterTable name) ats
          -- Errors
          (_, _, (errs, _)) -> pure $ Left errs
    where
      name = entityDB val
      (newcols, udefs, fdefs) = mkColumns allDefs val
      udspair = map udToPair udefs

      -- Unique constraints for a freshly created table.
      uniques =
          [ AlterTable name $ AddUniqueConstraint uname $
                map (findTypeAndMaxLen name) ucols
          | (uname, ucols) <- udspair
          ]

      -- References carried by the new columns themselves.
      foreigns =
          [ AlterColumn name (refTblName, addReference allDefs (refName name cname) refTblName cname)
          | Column { cName = cname, cReference = Just (refTblName, _) } <- newcols
          ]

      -- References declared through the entity's foreign definitions.
      foreignsAlt =
          [ AlterColumn name (parentTbl, AddReference parentTbl (foreignConstraintNameDBName fdef) childfields parentfields)
          | fdef <- fdefs
          , let parentTbl = foreignRefTableDBName fdef
          , let (childfields, parentfields) = unzip [ (b, d) | ((_, b), (_, d)) <- foreignFields fdef ]
          ]

      -- Existing columns whose reference matches one of the entity's
      -- foreign definitions (by constraint name) are treated as having no
      -- reference when compared against the new columns.
      forgetKnownForeignKey c = case cReference c of
          Just (_, fk) | any (\f -> fk == foreignConstraintNameDBName f) fdefs ->
              c { cReference = Nothing }
          _ -> c

      findTypeAndMaxLen tblName col =
          let (col', ty) = findTypeOfColumn allDefs tblName col
              (_, ml) = findMaxLenOfColumn allDefs tblName col
          in (col', ty, ml)

      addTable :: [Column] -> EntityDef -> AlterDB
      addTable cols entity = AddTable $ concat
          -- Lower case e: see Database.Persist.Sql.Migration
          [ "CREATe TABLE "
          , escapeDBName (entityDB entity)
          , "("
          , idColumnSql
          , if null cols then [] else ","
          , intercalate "," (map showColumn cols)
          , ")"
          ]
        where
          idField = entityId entity
          idType = fieldSqlType idField
          -- Only an Int64 ID without an explicit default is auto-incremented.
          autoIncrement = case (idType, defaultAttribute (fieldAttrs idField)) of
              (SqlInt64, Nothing) -> " AUTO_INCREMENT"
              _ -> ""
          idColumnSql = case entityPrimary entity of
              Just pdef -> concat
                  [ " PRIMARY KEY ("
                  , intercalate "," (map (escapeDBName . fieldDB) (compositeFields pdef))
                  , ")"
                  ]
              Nothing -> concat
                  [ escapeDBName (fieldDB idField)
                  , " " <> showSqlType idType (findMaxLenOfField idField) False
                  , " NOT NULL"
                  , autoIncrement
                  , " PRIMARY KEY"
                  ]
  -- | Look up the 'FieldType' of column @col@ of table @name@ among
  -- @allDefs@. If the table or column is unknown, the result is an 'error'.
  findTypeOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, FieldType)
  findTypeOfColumn allDefs name col =
      case lookupType of
          Just ty -> (col, ty)
          Nothing -> error $ "Could not find type of column " ++
                             show col ++ " on table " ++ show name ++
                             " (allDefs = " ++ show allDefs ++ ")"
    where
      lookupType = do
          entDef <- find ((== name) . entityDB) allDefs
          fieldType <$> find ((== col) . fieldDB) (entityFields entDef)
  -- | Find the @maxlen@ of column @col@ of table @name@ among @allDefs@.
  -- Falls back to 200 when the table or column is unknown, or when the
  -- field has no usable @maxlen@ attribute.
  findMaxLenOfColumn :: [EntityDef] -> DBName -> DBName -> (DBName, Integer)
  findMaxLenOfColumn allDefs name col =
      case maxLen of
          Just ml -> (col, ml)
          Nothing -> (col, 200)
    where
      maxLen = do
          entDef <- find ((== name) . entityDB) allDefs
          fieldDef <- find ((== col) . fieldDB) (entityFields entDef)
          findMaxLenOfField fieldDef
  -- | Read a field's maximum length from its first attribute that starts
  -- with @maxlen=@ (case-insensitive). Yields 'Nothing' when there is no
  -- such attribute or its value is not a number.
  findMaxLenOfField :: FieldDef -> Maybe Integer
  findMaxLenOfField fieldDef =
      case filter (T.isPrefixOf "maxlen=" . T.toLower) (fieldAttrs fieldDef) of
          attr : _ -> readMaybe (T.unpack (T.drop 7 attr))
          [] -> Nothing
  -- | Build an 'AddReference' from column @cname@ to the key columns of
  -- table @reftable@, which are looked up in @allDefs@. An unknown
  -- @reftable@ makes the referenced-column list an 'error'.
  addReference :: [EntityDef] -> DBName -> DBName -> DBName -> AlterColumn
  addReference allDefs fkeyname reftable cname =
      AddReference reftable fkeyname [cname] referencedColumns
    where
      referencedColumns =
          case find ((== reftable) . entityDB) allDefs of
              Just entDef -> map fieldDB (entityKeyFields entDef)
              Nothing -> error $ "Could not find ID of entity " ++ show reftable
                                 ++ " (allDefs = " ++ show allDefs ++ ")"
  -- | Column-level alterations produced by the migration logic.
  data AlterColumn
      = Change Column
      | Add' Column
      | Drop
      | Default String
      | NoDefault
      | Update' String
      | AddReference
          DBName   -- Referenced table
          DBName   -- Foreign key name
          [DBName] -- Referencing columns
          [DBName] -- Referenced columns
        -- See the definition of the 'showAlter' function for how these
        -- fields are used.
      | DropReference DBName

  -- | An 'AlterColumn' paired with a 'DBName'. In this module that name is
  -- the column's own name for most alterations. It is the referenced table
  -- for the reference additions built by 'findAlters' and 'migrate''.
  type AlterColumn' = (DBName, AlterColumn)

  -- | Table-level alterations of unique constraints. Each column is given
  -- as (name, type, max length).
  data AlterTable
      = AddUniqueConstraint DBName [(DBName, FieldType, Integer)]
      | DropUniqueConstraint DBName

  -- | One migration step: create a table from raw SQL, or alter a column or
  -- the constraints of the named table.
  data AlterDB
      = AddTable String
      | AlterColumn DBName AlterColumn'
      | AlterTable DBName AlterTable
  -- | Reduce a 'UniqueDef' to its constraint name and the database names of
  -- the columns it covers.
  udToPair :: UniqueDef -> (DBName, [DBName])
  udToPair ud = (uniqueDBName ud, [dbName | (_, dbName) <- uniqueFields ud])
  411. ----------------------------------------------------------------------
  -- | Returns all of the 'Column'@s@ in the given table currently
  -- in the database.
  --
  -- The first component holds the result for the entity's ID column. The
  -- second holds the results for every other column, followed by the
  -- table's constraints that are neither @PRIMARY@ nor foreign keys,
  -- grouped by constraint name.
  --
  -- Each query's rows are fully consumed before 'getColumn' issues its own
  -- queries, so no queries are nested.
  --
  -- NOTE: the ID-column query selects only four columns, while 'getColumn'
  -- expects eight, so that row comes back as a 'Left'. 'migrate'' only
  -- checks whether the first list is empty.
  getColumns :: MySQL.ConnectInfo
             -> (Text -> IO Statement)
             -> EntityDef
             -> IO ( [Either Text (Either Column (DBName, [DBName]))] -- ID column
                   , [Either Text (Either Column (DBName, [DBName]))] -- everything else
                   )
  getColumns connectInfo getter def = do
      -- Find out ID column.
      stmtIdClmn <- getter $ T.concat
          [ "SELECT COLUMN_NAME, "
          , "IS_NULLABLE, "
          , "DATA_TYPE, "
          , "COLUMN_DEFAULT "
          , "FROM INFORMATION_SCHEMA.COLUMNS "
          , "WHERE TABLE_SCHEMA = ? "
          , "AND TABLE_NAME = ? "
          , "AND COLUMN_NAME = ?"
          ]
      inter1 <- with (stmtQuery stmtIdClmn vals) (\src -> runConduit $ src .| CL.consume)
      ids <- runConduitRes $ CL.sourceList inter1 .| helperClmns -- avoid nested queries
      -- Find out all columns.
      stmtClmns <- getter $ T.concat
          [ "SELECT COLUMN_NAME, "
          , "IS_NULLABLE, "
          , "DATA_TYPE, "
          , "COLUMN_TYPE, "
          , "CHARACTER_MAXIMUM_LENGTH, "
          , "NUMERIC_PRECISION, "
          , "NUMERIC_SCALE, "
          , "COLUMN_DEFAULT "
          , "FROM INFORMATION_SCHEMA.COLUMNS "
          , "WHERE TABLE_SCHEMA = ? "
          , "AND TABLE_NAME = ? "
          , "AND COLUMN_NAME <> ?"
          ]
      inter2 <- with (stmtQuery stmtClmns vals) (\src -> runConduit $ src .| CL.consume)
      cs <- runConduitRes $ CL.sourceList inter2 .| helperClmns -- avoid nested queries
      -- Find out the constraints.
      stmtCntrs <- getter $ T.concat
          [ "SELECT CONSTRAINT_NAME, "
          , "COLUMN_NAME "
          , "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
          , "WHERE TABLE_SCHEMA = ? "
          , "AND TABLE_NAME = ? "
          , "AND COLUMN_NAME <> ? "
          , "AND CONSTRAINT_NAME <> 'PRIMARY' "
          , "AND REFERENCED_TABLE_SCHEMA IS NULL "
          , "ORDER BY CONSTRAINT_NAME, "
          , "COLUMN_NAME"
          ]
      us <- with (stmtQuery stmtCntrs vals) (\src -> runConduit $ src .| helperCntrs)
      -- Return both
      return (ids, cs ++ us)
    where
      -- Schema name, table name, and the ID column's name, in that order.
      vals = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
             , PersistText $ unDBName $ entityDB def
             , PersistText $ unDBName $ fieldDB $ entityId def ]

      -- Turn each INFORMATION_SCHEMA row into a column (or an error).
      helperClmns = CL.mapM getIt .| CL.consume
        where
          getIt = fmap (either Left (Right . Left)) .
                  liftIO .
                  getColumn connectInfo getter (entityDB def)

      -- Group (constraint, column) rows by constraint name. The pattern in
      -- the comprehension replaces the former partial 'head'; 'groupBy'
      -- never yields empty groups, so no group is skipped.
      helperCntrs = do
          let check [ PersistText cntrName
                    , PersistText clmnName] = return ( cntrName, clmnName )
              check other = fail $ "helperCntrs: unexpected " ++ show other
          rows <- mapM check =<< CL.consume
          return [ Right (Right (DBName cntrName, map (DBName . snd) grp))
                 | grp@((cntrName, _) : _) <- groupBy ((==) `on` fst) rows
                 ]
  -- | Get the information about a column in a table.
  --
  -- The row must hold, in order, @COLUMN_NAME@, @IS_NULLABLE@, @DATA_TYPE@,
  -- @COLUMN_TYPE@, @CHARACTER_MAXIMUM_LENGTH@, @NUMERIC_PRECISION@,
  -- @NUMERIC_SCALE@ and @COLUMN_DEFAULT@, as selected by 'getColumns'. Any
  -- other shape is reported as a 'Left'. A second query looks up a foreign
  -- key on the column within the same schema. A reference is recorded only
  -- when the column sits at ordinal position 1 of that key.
  --
  -- An undecodable default value, or an unexpected foreign-key lookup
  -- result, is reported as a 'Left' via 'throwE'. 'fail' cannot be used for
  -- this: the 'MonadFail' instance of 'ExceptT' delegates to the inner
  -- monad, so in 'IO' it throws an exception instead of producing a 'Left'.
  getColumn :: MySQL.ConnectInfo
            -> (Text -> IO Statement)
            -> DBName
            -> [PersistValue]
            -> IO (Either Text Column)
  getColumn connectInfo getter tname [ PersistText cname
                                     , PersistText null_
                                     , PersistText dataType
                                     , PersistText colType
                                     , colMaxLen
                                     , colPrecision
                                     , colScale
                                     , default' ] =
      fmap (either (Left . pack) Right) $
      runExceptT $ do
          -- Default value
          default_ <- case default' of
              PersistNull -> return Nothing
              PersistText t -> return (Just t)
              PersistByteString bs ->
                  case T.decodeUtf8' bs of
                      Left exc -> throwE $ "Invalid default column: " ++
                                           show default' ++ " (error: " ++
                                           show exc ++ ")"
                      Right t -> return (Just t)
              _ -> throwE $ "Invalid default column: " ++ show default'
          -- Foreign key (if any)
          stmt <- lift . getter $ T.concat
              [ "SELECT REFERENCED_TABLE_NAME, "
              , "CONSTRAINT_NAME, "
              , "ORDINAL_POSITION "
              , "FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE "
              , "WHERE TABLE_SCHEMA = ? "
              , "AND TABLE_NAME = ? "
              , "AND COLUMN_NAME = ? "
              , "AND REFERENCED_TABLE_SCHEMA = ? "
              , "ORDER BY CONSTRAINT_NAME, "
              , "COLUMN_NAME"
              ]
          let vars = [ PersistText $ pack $ MySQL.connectDatabase connectInfo
                     , PersistText $ unDBName $ tname
                     , PersistText cname
                     , PersistText $ pack $ MySQL.connectDatabase connectInfo ]
          cntrs <- liftIO $ with (stmtQuery stmt vars) (\src -> runConduit $ src .| CL.consume)
          ref <- case cntrs of
              [] -> return Nothing
              [[PersistText tab, PersistText ref, PersistInt64 pos]] ->
                  return $ if pos == 1 then Just (DBName tab, DBName ref) else Nothing
              _ -> throwE "MySQL.getColumn/getRef: never here"
          let colMaxLen' = case colMaxLen of
                  PersistInt64 l -> Just (fromIntegral l)
                  _ -> Nothing
              ci = ColumnInfo
                  { ciColumnType = colType
                  , ciMaxLength = colMaxLen'
                  , ciNumericPrecision = colPrecision
                  , ciNumericScale = colScale
                  }
          -- NOTE: parseColumnType reports problems through 'fail', which in
          -- this monad is still raised as an IO exception, not a 'Left'.
          (typ, maxLen) <- parseColumnType dataType ci
          -- Okay!
          return Column
              { cName = DBName cname
              , cNull = null_ == "YES"
              , cSqlType = typ
              , cDefault = default_
              , cDefaultConstraintName = Nothing
              , cMaxLen = maxLen
              , cReference = ref
              }
  getColumn _ _ _ x =
      return $ Left $ pack $ "Invalid result from INFORMATION_SCHEMA: " ++ show x
  -- | Extra column information read from MySQL's @INFORMATION_SCHEMA@,
  -- as filled in by 'getColumn'.
  data ColumnInfo = ColumnInfo
      { ciColumnType :: Text
        -- ^ The full @COLUMN_TYPE@, e.g. @int(11)@.
      , ciMaxLength :: Maybe Integer
        -- ^ @CHARACTER_MAXIMUM_LENGTH@, when it came back as an integer.
      , ciNumericPrecision :: PersistValue
        -- ^ Raw @NUMERIC_PRECISION@ value.
      , ciNumericScale :: PersistValue
        -- ^ Raw @NUMERIC_SCALE@ value.
      }
  -- | Parse the type of column as returned by MySQL's
  -- @INFORMATION_SCHEMA@ tables.
  --
  -- The first argument is the @DATA_TYPE@ value; 'ciColumnType' holds the
  -- full @COLUMN_TYPE@.
  --
  -- MySQL 8.0.19 and later no longer report a display width for integer
  -- types, except for @tinyint(1)@. So @int@ and @bigint@ are accepted with
  -- or without their traditional width; before, they fell through to
  -- 'SqlOther'.
  --
  -- Unrecognised types become 'SqlOther' carrying the full column type. A
  -- @decimal@ column whose precision and scale are not integers makes the
  -- computation 'fail'.
  parseColumnType :: Monad m => Text -> ColumnInfo -> m (SqlType, Maybe Integer)
  -- Ints
  parseColumnType "tinyint" ci | ciColumnType ci == "tinyint(1)" = return (SqlBool, Nothing)
  parseColumnType "int" ci | ciColumnType ci `elem` ["int(11)", "int"] = return (SqlInt32, Nothing)
  parseColumnType "bigint" ci | ciColumnType ci `elem` ["bigint(20)", "bigint"] = return (SqlInt64, Nothing)
  -- Double
  parseColumnType x@("double") ci | ciColumnType ci == x = return (SqlReal, Nothing)
  parseColumnType "decimal" ci =
      case (ciNumericPrecision ci, ciNumericScale ci) of
          (PersistInt64 p, PersistInt64 s) ->
              return (SqlNumeric (fromIntegral p) (fromIntegral s), Nothing)
          _ ->
              fail "missing DECIMAL precision in DB schema"
  -- Text
  parseColumnType "varchar" ci = return (SqlString, ciMaxLength ci)
  parseColumnType "text" _ = return (SqlString, Nothing)
  -- ByteString
  parseColumnType "varbinary" ci = return (SqlBlob, ciMaxLength ci)
  parseColumnType "blob" _ = return (SqlBlob, Nothing)
  -- Time-related
  parseColumnType "time" _ = return (SqlTime, Nothing)
  parseColumnType "datetime" _ = return (SqlDayTime, Nothing)
  parseColumnType "date" _ = return (SqlDay, Nothing)
  parseColumnType _ ci = return (SqlOther (ciColumnType ci), Nothing)
  588. ----------------------------------------------------------------------
  -- | @getAlters allDefs tblName new old@ finds out what needs to
  -- be changed from @old@ to become @new@.
  --
  -- Both arguments pair the table's columns with its unique constraints,
  -- each given as (constraint name, column names).
  --
  -- * Columns are matched one by one via 'findAlters'. Old columns left
  --   unmatched are dropped, together with their references.
  -- * Unique constraints are compared by name. A constraint with a changed
  --   column set is dropped and re-added. Old constraints that are no
  --   longer declared are dropped.
  getAlters :: [EntityDef]
            -> DBName
            -> ([Column], [(DBName, [DBName])])
            -> ([Column], [(DBName, [DBName])])
            -> ([AlterColumn'], [AlterTable])
  getAlters allDefs tblName (c1, u1) (c2, u2) =
      (getAltersC c1 c2, getAltersU u1 u2)
    where
      getAltersC [] old = concatMap dropColumn old
      getAltersC (new:news) old =
          let (alters, old') = findAlters tblName allDefs new old
          in alters ++ getAltersC news old'

      dropColumn col =
          map ((,) (cName col)) $
              [DropReference n | Just (_, n) <- [cReference col]] ++
              [Drop]

      getAltersU [] old = map (DropUniqueConstraint . fst) old
      getAltersU ((name, cols):news) old =
          case lookup name old of
              Nothing ->
                  AddUniqueConstraint name (map findTypeAndMaxLen cols) : getAltersU news old
              Just ocols ->
                  let old' = filter (\(x, _) -> x /= name) old
                  -- Compare as sets: the old columns arrive in MySQL's
                  -- collation order (ORDER BY COLUMN_NAME), which need not
                  -- match Haskell's ordering (e.g. case-insensitivity, '_'
                  -- versus letters). Sorting only one side would make an
                  -- unchanged constraint look different and be re-created
                  -- on every migration.
                  in if sort cols == sort ocols
                     then getAltersU news old'
                     else DropUniqueConstraint name
                          : AddUniqueConstraint name (map findTypeAndMaxLen cols)
                          : getAltersU news old'

      findTypeAndMaxLen col =
          let (col', ty) = findTypeOfColumn allDefs tblName col
              (_, ml) = findMaxLenOfColumn allDefs tblName col
          in (col', ty, ml)
  -- | @findAlters tblName allDefs newColumn oldColumns@ finds out what needs
  -- to be changed in the columns @oldColumns@ for @newColumn@ to be
  -- supported.
  --
  -- Returns the alterations together with the old columns that
  -- @newColumn@ did not match. 'getAlters' matches the remaining new
  -- columns against those and drops whatever is left at the end.
  findAlters :: DBName -> [EntityDef] -> Column -> [Column] -> ([AlterColumn'], [Column])
  findAlters tblName allDefs col@(Column name isNull type_ def _defConstraintName maxLen ref) cols =
      case filter ((name ==) . cName) cols of
          -- The column is new: add it, plus its foreign key if it has one.
          -- In both branches the old columns are handed back untouched.
          -- Returning [] instead would make every later new column look
          -- missing (re-adding existing columns) and would hide the
          -- remaining old columns from the drop pass.
          [] -> case ref of
              Nothing -> ([(name, Add' col)], cols)
              Just (tname, _b) ->
                  let cnstr = [addReference allDefs (refName tblName name) tname name]
                  in (map ((,) tname) (Add' col : cnstr), cols)
          Column _ isNull' type_' def' _defConstraintName' maxLen' ref' : _ ->
              let -- Foreign key
                  refDrop = case (ref == ref', ref') of
                      (False, Just (_, cname)) -> [(name, DropReference cname)]
                      _ -> []
                  refAdd = case (ref == ref', ref) of
                      (False, Just (tname, _cname)) -> [(tname, addReference allDefs (refName tblName name) tname name)]
                      _ -> []
                  -- Type and nullability
                  modType
                      | showSqlType type_ maxLen False `ciEquals` showSqlType type_' maxLen' False
                        && isNull == isNull' = []
                      | otherwise = [(name, Change col)]
                  -- Default value
                  -- Avoid DEFAULT NULL, since it is always unnecessary, and is an error for text/blob fields
                  modDef
                      | def == def' = []
                      | otherwise = case def of
                          Nothing -> [(name, NoDefault)]
                          Just s -> if T.toUpper s == "NULL" then []
                                    else [(name, Default $ T.unpack s)]
              in ( refDrop ++ modType ++ modDef ++ refAdd
                 , filter ((name /=) . cName) cols )
    where
      -- Case-insensitive comparison of rendered SQL types.
      ciEquals x y = T.toCaseFold (T.pack x) == T.toCaseFold (T.pack y)
  656. ----------------------------------------------------------------------
  -- | Renders the column definition used in a @CREATE TABLE@ statement. It
  -- consists of the escaped name, the SQL type (with character set) and
  -- @NULL@/@NOT NULL@. An optional @DEFAULT@ clause and an optional
  -- @REFERENCES@ clause follow.
  showColumn :: Column -> String
  showColumn (Column n nu t def _defConstraintName maxLen ref) =
      escapeDBName n ++ " " ++ showSqlType t maxLen True ++ " " ++ nullability
          ++ defaultClause ++ referenceClause
    where
      nullability
          | nu = "NULL"
          | otherwise = "NOT NULL"
      -- DEFAULT NULL is left out: it is always unnecessary, and is an error
      -- for text/blob fields.
      defaultClause = maybe "" renderDefault def
      renderDefault s
          | T.toUpper s == "NULL" = ""
          | otherwise = " DEFAULT " ++ T.unpack s
      referenceClause = maybe "" (\(target, _) -> " REFERENCES " ++ escapeDBName target) ref
  -- | Renders an 'SqlType' in MySQL syntax.
  --
  -- The maximum length only matters for blobs and strings. It turns @BLOB@
  -- into @VARBINARY(n)@ and @TEXT@ into @VARCHAR(n)@. The flag appends
  -- @CHARACTER SET utf8@ to string types. 'showColumn' passes 'True';
  -- 'findAlters' passes 'False' when comparing types.
  showSqlType :: SqlType
              -> Maybe Integer -- ^ @maxlen@
              -> Bool          -- ^ include character set information?
              -> String
  showSqlType sqlType maxLen withCharset =
      case sqlType of
          SqlBlob -> maybe "BLOB" (\len -> "VARBINARY(" ++ show len ++ ")") maxLen
          SqlBool -> "TINYINT(1)"
          SqlDay -> "DATE"
          SqlDayTime -> "DATETIME"
          SqlInt32 -> "INT(11)"
          SqlInt64 -> "BIGINT"
          SqlReal -> "DOUBLE"
          SqlNumeric s prec -> "NUMERIC(" ++ show s ++ "," ++ show prec ++ ")"
          SqlString -> stringType ++ charset
          SqlTime -> "TIME"
          SqlOther t -> T.unpack t
    where
      stringType = maybe "TEXT" (\len -> "VARCHAR(" ++ show len ++ ")") maxLen
      charset
          | withCharset = " CHARACTER SET utf8"
          | otherwise = ""
  -- | Renders a migration step as SQL. The flag is 'True' (unsafe) only for
  -- dropping a column and 'False' for every other step.
  showAlterDb :: AlterDB -> (Bool, Text)
  showAlterDb step = case step of
      AddTable s -> (False, pack s)
      AlterColumn t colStep@(_, action) -> (dropsColumn action, pack (showAlter t colStep))
      AlterTable t at -> (False, pack (showAlterTable t at))
    where
      dropsColumn Drop = True
      dropsColumn _ = False
  -- | Renders a table-level step as a MySQL @ALTER TABLE@ statement.
  --
  -- In a unique constraint, columns whose Haskell type is @Text@, @String@ or
  -- @ByteString@ get an index prefix length equal to their maximum length.
  -- This is presumably because MySQL needs a prefix length to index
  -- TEXT/BLOB columns.
  showAlterTable :: DBName -> AlterTable -> String
  showAlterTable table step = "ALTER TABLE " ++ escapeDBName table ++ clause
    where
      clause = case step of
          AddUniqueConstraint cname cols ->
              " ADD CONSTRAINT " ++ escapeDBName cname
                  ++ " UNIQUE(" ++ intercalate "," (map uniqueColumn cols) ++ ")"
          DropUniqueConstraint cname ->
              " DROP INDEX " ++ escapeDBName cname
      uniqueColumn (colName, FTTypeCon _ tyName, maxlen)
          | tyName `elem` ["Text", "String", "ByteString"] =
              escapeDBName colName ++ "(" ++ show maxlen ++ ")"
      uniqueColumn (colName, _, _) = escapeDBName colName
  -- | Renders a column-level step as SQL.
  --
  -- The 'DBName' in the pair names the column for 'Change', 'Drop',
  -- 'Default', 'NoDefault' and @Update'@. It is not used for @Add'@ (the
  -- column definition carries its own name), 'AddReference' or
  -- 'DropReference'.
  showAlter :: DBName -> AlterColumn' -> String
  showAlter table (col, action) = case action of
      Change (Column n nu t def defConstraintName maxLen _ref) ->
          -- The reference is cleared so no inline REFERENCES clause is
          -- rendered; foreign keys are emitted as separate steps.
          alterTable
              [ " CHANGE ", escapeDBName col, " "
              , showColumn (Column n nu t def defConstraintName maxLen Nothing)
              ]
      Add' newCol ->
          alterTable [" ADD COLUMN ", showColumn newCol]
      Drop ->
          alterTable [" DROP COLUMN ", escapeDBName col]
      Default s ->
          alterTable [" ALTER COLUMN ", escapeDBName col, " SET DEFAULT ", s]
      NoDefault ->
          alterTable [" ALTER COLUMN ", escapeDBName col, " DROP DEFAULT"]
      Update' s ->
          -- Sets every NULL in the column to the given SQL expression.
          concat
              [ "UPDATE ", escapeDBName table
              , " SET ", escapeDBName col, "=", s
              , " WHERE ", escapeDBName col, " IS NULL"
              ]
      AddReference reftable fkeyname t2 id2 ->
          alterTable
              [ " ADD CONSTRAINT ", escapeDBName fkeyname
              , " FOREIGN KEY(", columnList t2
              , ") REFERENCES ", escapeDBName reftable
              , "(", columnList id2, ")"
              ]
      DropReference cname ->
          alterTable [" DROP FOREIGN KEY ", escapeDBName cname]
    where
      alterTable rest = concat ("ALTER TABLE " : escapeDBName table : rest)
      columnList = intercalate "," . map escapeDBName
  -- | Name of the foreign-key constraint for a column: the table name and
  -- the column name joined with @_@, followed by the suffix @_fkey@.
  refName :: DBName -> DBName -> DBName
  refName (DBName table) (DBName column) = DBName (table <> "_" <> column <> "_fkey")
  802. ----------------------------------------------------------------------
  -- | 'escapeDBName', returned as 'Text'.
  escape :: DBName -> Text
  escape dbName = T.pack (escapeDBName dbName)
  -- | Quotes a database identifier with backticks so it can be embedded in a
  -- query. Backticks inside the name are doubled.
  escapeDBName :: DBName -> String
  escapeDBName (DBName s) = '`' : concatMap escapeChar (T.unpack s) ++ "`"
    where
      escapeChar '`' = "``"
      escapeChar c = [c]
  -- | Information required to connect to a MySQL database using
  -- @persistent@'s generic configuration facilities ('PersistConfig'). The
  -- values are the same as those given to 'withMySQLPool'.
  --
  -- NOTE(review): the derived 'Show' renders 'myConnInfo' in full, which
  -- presumably includes the password. Avoid logging it.
  data MySQLConf = MySQLConf
      { myConnInfo :: MySQL.ConnectInfo
        -- ^ The connection information.
      , myPoolSize :: Int
        -- ^ How many connections the pool should hold.
      }
      deriving (Show)
  -- | Parses an object with the keys @database@, @host@, @port@, @user@,
  -- @password@ and @poolsize@, plus an optional @path@. When @path@ is
  -- missing, the default connection's path is used. Parse errors are
  -- prefixed with "Persistent: error loading MySQL conf: ".
  instance FromJSON MySQLConf where
      parseJSON v = modifyFailure ("Persistent: error loading MySQL conf: " ++) $
          withObject "MySQLConf" parseConf v
        where
          parseConf o = do
              database <- o .: "database"
              host <- o .: "host"
              port <- o .: "port"
              path <- o .:? "path"
              user <- o .: "user"
              password <- o .: "password"
              pool <- o .: "poolsize"
              let defaultPath = MySQL.connectPath MySQL.defaultConnectInfo
                  ci = MySQL.defaultConnectInfo
                      { MySQL.connectHost = host
                      , MySQL.connectPort = port
                      , MySQL.connectPath = maybe defaultPath id path
                      , MySQL.connectUser = user
                      , MySQL.connectPassword = password
                      , MySQL.connectDatabase = database
                      }
              return (MySQLConf ci pool)
  instance PersistConfig MySQLConf where
      type PersistConfigBackend MySQLConf = SqlPersistT
      type PersistConfigPool MySQLConf = ConnectionPool

      createPoolConfig (MySQLConf cs size) = runNoLoggingT $ createMySQLPool cs size -- FIXME

      runPool _ = runSqlPool

      loadConfig = parseJSON

      -- Overrides each connection setting from the matching MYSQL_*
      -- environment variable (MYSQL_HOST, MYSQL_PORT, MYSQL_PATH, MYSQL_USER,
      -- MYSQL_PASSWORD, MYSQL_DATABASE). Unset variables keep the configured
      -- value.
      --
      -- MYSQL_PORT is parsed eagerly. A malformed value is raised here as an
      -- 'IOError' that names the variable and its value. The alternative, a
      -- lazy 'read', fails later with a bare "no parse" error wherever the
      -- port is first forced.
      applyEnv conf = do
          env <- getEnvironment
          let ci = myConnInfo conf
              fromEnv old var = maybe old id $ lookup ("MYSQL_" ++ var) env
          port <- case lookup "MYSQL_PORT" env of
              Nothing -> return (MySQL.connectPort ci)
              Just raw ->
                  -- Same acceptance rule as 'read': exactly one parse, followed
                  -- by nothing but whitespace.
                  case [p | (p, rest) <- reads raw, ("", "") <- lex rest] of
                      [p] -> return p
                      _ -> ioError $ userError $
                          "Persistent: invalid MYSQL_PORT environment variable: " ++ show raw
          return conf
              { myConnInfo = ci
                  { MySQL.connectHost = fromEnv (MySQL.connectHost ci) "HOST"
                  , MySQL.connectPort = port
                  , MySQL.connectPath = fromEnv (MySQL.connectPath ci) "PATH"
                  , MySQL.connectUser = fromEnv (MySQL.connectUser ci) "USER"
                  , MySQL.connectPassword = fromEnv (MySQL.connectPassword ci) "PASSWORD"
                  , MySQL.connectDatabase = fromEnv (MySQL.connectDatabase ci) "DATABASE"
                  }
              }
  -- | Migration hook used by 'mockMigration'. It produces the steps that
  -- create the table for @val@ from scratch, without looking at any
  -- database.
  --
  -- The steps are, in order:
  --
  -- * the table itself ('addTable');
  -- * its unique constraints;
  -- * a foreign key for every column that references another table;
  -- * the entity's explicitly declared foreign keys.
  --
  -- Each step is rendered with 'showAlterDb'. The connection info and the
  -- statement getter are ignored, and the result is always 'Right'.
  mockMigrate :: MySQL.ConnectInfo
              -> [EntityDef]
              -> (Text -> IO Statement)
              -> EntityDef
              -> IO (Either [Text] [(Bool, Text)])
  mockMigrate _connectInfo allDefs _getter val =
      return $ Right $ map showAlterDb $
          addTable newcols val : uniques ++ foreigns ++ foreignsAlt
      {- FIXME: leftover branches from a version that compared against the
         existing table (redundant; kept for reference only).
      -- No errors and something found, migrate
      (_, _, ([], old')) -> do
          let excludeForeignKeys (xs,ys) = (map (\c -> case cReference c of
                  Just (_,fk) -> case find (\f -> fk == foreignConstraintNameDBName f) fdefs of
                      Just _ -> c { cReference = Nothing }
                      Nothing -> c
                  Nothing -> c) xs,ys)
              (acs, ats) = getAlters allDefs name (newcols, udspair) $ excludeForeignKeys $ partitionEithers old'
              acs' = map (AlterColumn name) acs
              ats' = map (AlterTable name) ats
          return $ Right $ map showAlterDb $ acs' ++ ats'
      -- Errors
      (_, _, (errs, _)) -> return $ Left errs
      -}
    where
      name = entityDB val
      (newcols, udefs, fdefs) = mkColumns allDefs val
      -- One unique constraint per uniqueness definition of the entity.
      uniques =
          [ AlterTable name (AddUniqueConstraint uname (map (findTypeAndMaxLen name) ucols))
          | (uname, ucols) <- map udToPair udefs
          ]
      -- A foreign key, named with 'refName', for each referencing column.
      foreigns =
          [ AlterColumn name (refTblName, addReference allDefs (refName name cname) refTblName cname)
          | Column { cName = cname, cReference = Just (refTblName, _) } <- newcols
          ]
      -- The entity's explicitly declared (possibly multi-column) foreign keys.
      foreignsAlt = map foreignAlter fdefs
      foreignAlter fdef =
          let refTable = foreignRefTableDBName fdef
              (childfields, parentfields) =
                  unzip [ (child, parent) | ((_, child), (_, parent)) <- foreignFields fdef ]
          in AlterColumn name
              (refTable, AddReference refTable (foreignConstraintNameDBName fdef) childfields parentfields)
      findTypeAndMaxLen tblName col =
          let (col', ty) = findTypeOfColumn allDefs tblName col
              (_, ml) = findMaxLenOfColumn allDefs tblName col
          in (col', ty, ml)
  -- | Prints the SQL that a migration would run, without needing a database.
  --
  -- The migration runs against a stub 'SqlBackend' whose migration hook is
  -- 'mockMigrate', so every entity is treated as if its table did not exist
  -- yet. Prepared statements are stubs whose 'stmtExecute' is 'undefined'.
  -- Every other backend operation is 'undefined', or 'Nothing' for the
  -- optional ones, so the migration must not rely on them. Each generated
  -- statement is printed on its own line.
  mockMigration :: Migration -> IO ()
  mockMigration mig = do
      statementCache <- newIORef Map.empty
      let stubStatement = Statement
              { stmtFinalize = return ()
              , stmtReset = return ()
              , stmtExecute = undefined
              , stmtQuery = \_ -> return $ return ()
              }
          backend = SqlBackend
              { connPrepare = \_ -> return stubStatement
              , connInsertManySql = Nothing
              , connInsertSql = undefined
              , connStmtMap = statementCache
              , connClose = undefined
              , connMigrateSql = mockMigrate undefined
              , connBegin = undefined
              , connCommit = undefined
              , connRollback = undefined
              , connEscapeName = undefined
              , connNoLimit = undefined
              , connRDBMS = undefined
              , connLimitOffset = undefined
              , connLogFunc = undefined
              , connUpsertSql = undefined
              , connPutManySql = undefined
              , connMaxParams = Nothing
              , connRepsertManySql = Nothing
              , connInsertUniqueSql = Nothing
              }
      (_, steps) <- runReaderT (runWriterT (runWriterT mig)) backend
      mapM_ (T.putStrLn . snd) steps
  -- | MySQL-specific 'upsert_' that issues a single query. The record is
  -- inserted into the database. If it already exists (a duplicate key), the
  -- given updates are applied to the existing row instead.
  --
  -- This is 'insertManyOnDuplicateKeyUpdate' for one record, with no fields
  -- copied from the record.
  insertOnDuplicateKeyUpdate
      :: ( backend ~ PersistEntityBackend record
         , PersistEntity record
         , MonadIO m
         , PersistStore backend
         , BackendCompatible SqlBackend backend
         )
      => record
      -> [Update record]
      -> ReaderT backend m ()
  insertOnDuplicateKeyUpdate record updates =
      insertManyOnDuplicateKeyUpdate [record] [] updates
  -- | Describes how 'insertManyOnDuplicateKeyUpdate' updates a column of an
  -- existing row when an inserted record collides with it. Each value
  -- becomes one column assignment of MySQL's
  -- @INSERT ... ON DUPLICATE KEY UPDATE@.
  --
  -- @since 2.8.0
  data HandleUpdateCollision record where
      -- | Copy the field directly from the inserted record.
      CopyField :: EntityField record typ -> HandleUpdateCollision record
      -- | Copy the field from the inserted record only if its value is not
      -- equal to the provided one; otherwise the existing value is kept.
      CopyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record
  -- | An alias for 'HandleUpdateCollision'. The type was previously only
  -- used to copy a single value, but has been expanded to handle more
  -- complex queries.
  --
  -- @since 2.6.2
  type SomeField = HandleUpdateCollision

  -- | Backwards-compatible name for 'CopyField'.
  pattern SomeField :: EntityField record typ -> SomeField record
  pattern SomeField x = CopyField x
  {-# DEPRECATED SomeField "The type SomeField is deprecated. Use the type HandleUpdateCollision instead, and use the function copyField instead of the data constructor." #-}
  -- | Copy the field into the database only if the value in the inserted
  -- record is non-@NULL@ (not 'Nothing'). Otherwise the existing value is
  -- kept.
  --
  -- @since 2.6.2
  copyUnlessNull :: PersistField typ => EntityField record (Maybe typ) -> HandleUpdateCollision record
  copyUnlessNull field = copyUnlessEq field Nothing
  -- | Copy the field into the database only if the value in the inserted
  -- record is non-empty, where "empty" means the 'Monoid' definition of
  -- 'mempty'. Useful for 'Text', 'String', 'ByteString', etc.
  --
  -- The resulting 'HandleUpdateCollision' is meant to be passed to
  -- 'insertManyOnDuplicateKeyUpdate'.
  --
  -- @since 2.6.2
  copyUnlessEmpty :: (Monoid.Monoid typ, PersistField typ) => EntityField record typ -> HandleUpdateCollision record
  copyUnlessEmpty = flip CopyUnlessEq Monoid.mempty
  -- | Copy the field into the database only if the value in the inserted
  -- record is not equal to the provided value. This is useful to avoid
  -- copying placeholder ("nullary") values into the database.
  --
  -- The resulting 'HandleUpdateCollision' is meant to be passed to
  -- 'insertManyOnDuplicateKeyUpdate'.
  --
  -- @since 2.6.2
  copyUnlessEq :: PersistField typ => EntityField record typ -> typ -> HandleUpdateCollision record
  copyUnlessEq field value = CopyUnlessEq field value
  -- | Copy the field directly from the inserted record.
  --
  -- @since 3.0
  copyField :: PersistField typ => EntityField record typ -> HandleUpdateCollision record
  copyField field = CopyField field
  -- | Do a bulk insert on the given records in the first parameter. If a key
  -- conflicts with a record currently in the database, the second and third
  -- parameters determine what happens.
  --
  -- The second parameter lists fields to copy from the record being
  -- inserted onto the preexisting row, optionally only under a condition
  -- (see 'HandleUpdateCollision').
  --
  -- The third parameter is a list of updates to perform that are independent
  -- of the value being inserted. You can use this to increment a counter
  -- value. These updates are only applied to rows that already exist in the
  -- database, i.e. when the inserted record conflicts with an existing key.
  --
  -- An empty list of records does nothing and sends no query.
  --
  -- === __More details on 'HandleUpdateCollision' usage__
  --
  -- The @['HandleUpdateCollision']@ parameter allows you to specify which
  -- fields (and under which conditions) will be copied from the inserted
  -- rows. For a brief example, consider the following data model and
  -- existing data set:
  --
  -- @
  -- Item
  --   name Text
  --   description Text
  --   price Double Maybe
  --   quantity Int Maybe
  --
  --   Primary name
  -- @
  --
  -- > items:
  -- > +------+-------------+-------+----------+
  -- > | name | description | price | quantity |
  -- > +------+-------------+-------+----------+
  -- > | foo  | very good   |       |    3     |
  -- > | bar  |             | 3.99  |          |
  -- > +------+-------------+-------+----------+
  --
  -- This record type has a single natural key on @itemName@. Let's suppose
  -- that we download a CSV of new items to store into the database. Here's
  -- our CSV:
  --
  -- > name,description,price,quantity
  -- > foo,,2.50,6
  -- > bar,even better,,5
  -- > yes,wow,,
  --
  -- We parse that into a list of Haskell records:
  --
  -- @
  -- records =
  --   [ Item { itemName = "foo", itemDescription = ""
  --          , itemPrice = Just 2.50, itemQuantity = Just 6
  --          }
  --   , Item "bar" "even better" Nothing (Just 5)
  --   , Item "yes" "wow" Nothing Nothing
  --   ]
  -- @
  --
  -- The new CSV data is partial. It only includes __updates__ from the
  -- upstream vendor. Our CSV library parses the missing description field as
  -- an empty string. We don't want to override the existing description, so
  -- we can use the 'copyUnlessEmpty' function to say: "Don't update when the
  -- value is empty."
  --
  -- Likewise, the new row for @bar@ includes a quantity, but no price. We do
  -- not want to overwrite the existing price in the database with a @NULL@
  -- value, so we can use 'copyUnlessNull' to copy only the values that are
  -- present.
  --
  -- The final code looks like this:
  --
  -- @
  -- 'insertManyOnDuplicateKeyUpdate' records
  --   [ 'copyUnlessEmpty' ItemDescription
  --   , 'copyUnlessNull' ItemPrice
  --   , 'copyUnlessNull' ItemQuantity
  --   ]
  --   []
  -- @
  --
  -- Once we run that code on the database, the new data set looks like this:
  --
  -- > items:
  -- > +------+-------------+-------+----------+
  -- > | name | description | price | quantity |
  -- > +------+-------------+-------+----------+
  -- > | foo  | very good   | 2.50  |    6     |
  -- > | bar  | even better | 3.99  |    5     |
  -- > | yes  | wow         |       |          |
  -- > +------+-------------+-------+----------+
  insertManyOnDuplicateKeyUpdate
      :: forall record backend m.
         ( backend ~ PersistEntityBackend record
         , BackendCompatible SqlBackend backend
         , PersistEntity record
         , MonadIO m
         )
      => [record] -- ^ A list of the records you want to insert, or update
      -> [HandleUpdateCollision record] -- ^ A list of the fields you want to copy over.
      -> [Update record] -- ^ A list of the updates to apply that aren't dependent on the record being inserted.
      -> ReaderT backend m ()
  insertManyOnDuplicateKeyUpdate records fieldValues updates
      | null records = return ()
      | otherwise = rawExecute sql params
    where
      (sql, params) = mkBulkInsertQuery records fieldValues updates
  -- | Builds the query text and parameters for
  -- 'insertManyOnDuplicateKeyUpdate':
  --
  -- > INSERT INTO tbl (f1,...)  VALUES (?,...),... ON DUPLICATE KEY UPDATE ...
  --
  -- The @UPDATE@ part lists, in order:
  --
  -- * the unconditional copies (@f=VALUES(f)@);
  -- * the explicit updates;
  -- * the conditional copies (@f=COALESCE(NULLIF(VALUES(f),?),f)@), which
  --   keep the current value when the inserted one equals the given value.
  --
  -- The parameters follow the placeholders: every record's field values,
  -- then the explicit updates' values, then the conditional copies'
  -- comparison values.
  --
  -- With nothing to update, a no-op assignment of the first field to itself
  -- is emitted so that duplicate keys do not raise exceptions. In that case
  -- an entity without fields is an 'error'.
  mkBulkInsertQuery
      :: PersistEntity record
      => [record] -- ^ A list of the records you want to insert, or update
      -> [HandleUpdateCollision record] -- ^ A list of the fields you want to copy over.
      -> [Update record] -- ^ A list of the updates to apply that aren't dependent on the record being inserted.
      -> (Text, [PersistValue])
  mkBulkInsertQuery records fieldValues updates =
      (sql, recordValues ++ updateValues ++ conditionalValues)
    where
      entDef = entityDef records
      escapeField = T.pack . escapeDBName . fieldDB
      tableName = T.pack (escapeDBName (entityDB entDef))
      columnNames = map escapeField (entityFields entDef)

      -- Unconditional copies go Right; conditional ones go Left together
      -- with the value they are compared against.
      classify handler = case handler of
          CopyField field -> Right (escapeField (persistFieldDef field))
          CopyUnlessEq field val -> Left (escapeField (persistFieldDef field), toPersistValue val)
      (conditionalCopies, plainCopies) = partitionEithers (map classify fieldValues)
      conditionalValues = map snd conditionalCopies

      recordFields = map toPersistFields records
      recordValues = concatMap (map toPersistValue) recordFields
      placeholderRow fields = Util.parenWrapped (Util.commaSeparated (map (const "?") fields))
      valuesClause = Util.commaSeparated (map placeholderRow recordFields)

      plainAssignments = [ T.concat [n, "=VALUES(", n, ")"] | n <- plainCopies ]
      updateAssignments = map (Util.mkUpdateText' (pack . escapeDBName) id) updates
      updateValues = map (\(Update _ val _) -> toPersistValue val) updates
      conditionalAssignments =
          [ T.concat [n, "=COALESCE(", "NULLIF(", "VALUES(", n, "),", "?", "),", n, ")"]
          | (n, _) <- conditionalCopies
          ]

      updateClause = case plainAssignments ++ updateAssignments ++ conditionalAssignments of
          [] -> case columnNames of
              [] -> error "The entity you're trying to insert does not have any fields."
              (firstField : _) -> T.concat [firstField, "=", firstField]
          assignments -> Util.commaSeparated assignments

      sql = T.concat
          [ "INSERT INTO "
          , tableName
          , " ("
          , Util.commaSeparated columnNames
          , ") "
          , " VALUES "
          , valuesClause
          , " ON DUPLICATE KEY UPDATE "
          , updateClause
          ]
  -- | Query text for a bulk upsert of @n@ records that inserts (and, on a
  -- duplicate key, overwrites) every field in 'entityFields'.
  putManySql :: EntityDef -> Int -> Text
  putManySql ent = putManySql' (entityFields ent) ent
  -- | Like 'putManySql', but covers the fields from 'keyAndEntityFields'.
  -- These are presumably the key columns as well as the regular fields.
  repsertManySql :: EntityDef -> Int -> Text
  repsertManySql ent = putManySql' (keyAndEntityFields ent) ent
  -- | @putManySql' fields ent n@ renders an upsert of @n@ rows into the table
  -- of @ent@:
  --
  -- > INSERT INTO tbl(f1,...) VALUES (?,...),... ON DUPLICATE KEY UPDATE f1=VALUES(f1),...
  --
  -- Every listed field is inserted. On a duplicate key, every listed field
  -- is overwritten with the value from the new row.
  putManySql' :: [FieldDef] -> EntityDef -> Int -> Text
  putManySql' fields ent n = T.concat
      [ "INSERT INTO "
      , escape (entityDB ent)
      , Util.parenWrapped (Util.commaSeparated columnNames)
      , " VALUES "
      , Util.commaSeparated (replicate n row)
      , " ON DUPLICATE KEY UPDATE "
      , Util.commaSeparated [ T.concat [c, "=VALUES(", c, ")"] | c <- columnNames ]
      ]
    where
      columnNames = map (escape . fieldDB) fields
      row = Util.parenWrapped (Util.commaSeparated (map (const "?") fields))