conn.go 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113
  1. package pq
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/md5"
  7. "crypto/sha256"
  8. "database/sql"
  9. "database/sql/driver"
  10. "encoding/binary"
  11. "errors"
  12. "fmt"
  13. "io"
  14. "net"
  15. "os"
  16. "os/user"
  17. "path"
  18. "path/filepath"
  19. "strconv"
  20. "strings"
  21. "sync"
  22. "time"
  23. "unicode"
  24. "github.com/lib/pq/oid"
  25. "github.com/lib/pq/scram"
  26. )
  // Common error types returned by the driver.
var (
	// ErrNotSupported reports an unsupported command.
	ErrNotSupported = errors.New("pq: Unsupported command")
	// ErrInFailedTransaction is returned by Commit when the transaction had
	// already failed; the transaction is rolled back before it is returned.
	ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
	// ErrSSLNotSupported reports that SSL is not enabled on the server.
	ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
	// ErrSSLKeyUnknownOwnership reports that the owner of the SSL private key
	// file could not be determined.
	ErrSSLKeyUnknownOwnership = errors.New("pq: Could not get owner information for private key, may not be properly protected")
	// ErrSSLKeyHasWorldPermissions reports an SSL private key file with overly
	// permissive access bits.
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less")
	// ErrCouldNotDetectUsername reports that no default user name could be
	// determined and none was given.
	ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
	// errUnexpectedReady is returned by simpleExec when ReadyForQuery arrives
	// with neither a result nor an error.
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	// errNoRowsAffected and errNoLastInsertID are returned by noRows, the
	// result of an empty statement.
	errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
)
  // Compile time validation that our types implement the expected interfaces.
// A type that stops satisfying an interface listed here fails the build
// instead of failing at run time.
var (
	// Driver is handed to sql.Register (see init), which needs a driver.Driver.
	_ driver.Driver = Driver{}
)
  43. // Driver is the Postgres database driver.
  44. type Driver struct{}
  45. // Open opens a new connection to the database. name is a connection string.
  46. // Most users should only use it through database/sql package from the standard
  47. // library.
  48. func (d Driver) Open(name string) (driver.Conn, error) {
  49. return Open(name)
  50. }
  // init registers the driver with database/sql under the name "postgres", so
// that sql.Open("postgres", dsn) resolves to it.
func init() {
	sql.Register("postgres", &Driver{})
}
  // parameterStatus caches per-session information the driver keeps for a
// connection (stored in conn.parameterStatus). Presumably it is filled from
// server-reported parameters; the code that updates it is not shown here.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int
	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
  62. type transactionStatus byte
  63. const (
  64. txnStatusIdle transactionStatus = 'I'
  65. txnStatusIdleInTransaction transactionStatus = 'T'
  66. txnStatusInFailedTransaction transactionStatus = 'E'
  67. )
  68. func (s transactionStatus) String() string {
  69. switch s {
  70. case txnStatusIdle:
  71. return "idle"
  72. case txnStatusIdleInTransaction:
  73. return "idle in transaction"
  74. case txnStatusInFailedTransaction:
  75. return "in a failed transaction"
  76. default:
  77. errorf("unknown transactionStatus %d", s)
  78. }
  79. panic("not reached")
  80. }
  // Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
//
// When connect_timeout is in effect, DialTimeout is used; otherwise Dial.
// If the value also implements DialerContext, DialContext is used instead of
// both (see dial).
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// DialerContext is the context-aware dialer interface. A Dialer that also
// implements it has DialContext preferred over Dial and DialTimeout.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
  // defaultDialer adapts a net.Dialer to both the Dialer and DialerContext
// interfaces. Open uses its zero value.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects to address on the named network via the wrapped net.Dialer.
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
	return d.d.Dial(network, address)
}

// DialTimeout is Dial bounded by timeout. It is implemented on top of
// DialContext with a context that expires after timeout.
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	deadlineCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()
	return d.DialContext(deadlineCtx, network, address)
}

// DialContext connects using ctx for cancellation and deadlines.
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return d.d.DialContext(ctx, network, address)
}
  // conn is a single client connection to a PostgreSQL server, together with
// the state the driver tracks for it.
type conn struct {
	// c is the network connection; buf is a buffered reader over it (both set
	// up in open).
	c net.Conn
	buf *bufio.Reader
	// namei is the counter behind generated statement names (see gname).
	namei int
	// scratch backs the outgoing message buffers returned by writeBuf, so at
	// most one such buffer should be in use at a time.
	scratch [512]byte
	// txnStatus is the last known transaction status of the session
	// (presumably updated by processReadyForQuery — defined elsewhere).
	txnStatus transactionStatus
	// txnFinish, if non-nil, is called by closeTxn when Commit or Rollback
	// returns.
	txnFinish func()
	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts values
	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int
	// parameterStatus caches session information such as the server version
	// and the session time zone.
	parameterStatus parameterStatus
	// saveMessageType and saveMessageBuffer hold a message that was received
	// but not yet consumed (simpleQuery stashes the first DataRow via
	// saveMessage so that Next can pick it up).
	saveMessageType byte
	saveMessageBuffer []byte
	// If an error is set, this connection is bad and all public-facing
	// functions should return the appropriate error by calling get()
	// (ErrBadConn) or getForNext().
	err syncErr
	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging.
	disablePreparedBinaryResult bool
	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool
	// If true this connection is in the middle of a COPY
	inCopy bool
	// If not nil, notices will be synchronously sent here
	noticeHandler func(*Error)
	// If not nil, notifications will be synchronously sent here
	notificationHandler func(*Notification)
	// GSSAPI context
	gss GSS
}
  // syncErr is a mutex-guarded error slot that can be written only once: the
// first error recorded sticks and later ones are ignored.
type syncErr struct {
	err error
	sync.Mutex
}

// get returns driver.ErrBadConn once any error has been recorded and nil
// before that. The recorded error itself is not exposed here.
func (e *syncErr) get() error {
	e.Lock()
	bad := e.err != nil
	e.Unlock()
	if !bad {
		return nil
	}
	return driver.ErrBadConn
}

// getForNext returns the recorded error itself (nil if none). Currently only
// rows.Next uses it.
func (e *syncErr) getForNext() error {
	e.Lock()
	recorded := e.err
	e.Unlock()
	return recorded
}

// set records err unless an error is already recorded. Passing nil is a
// programming error and panics.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}
	e.Lock()
	if e.err == nil {
		e.err = err
	}
	e.Unlock()
}
  173. // Handle driver-side settings in parsed connection string.
  174. func (cn *conn) handleDriverSettings(o values) (err error) {
  175. boolSetting := func(key string, val *bool) error {
  176. if value, ok := o[key]; ok {
  177. if value == "yes" {
  178. *val = true
  179. } else if value == "no" {
  180. *val = false
  181. } else {
  182. return fmt.Errorf("unrecognized value %q for %s", value, key)
  183. }
  184. }
  185. return nil
  186. }
  187. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  188. if err != nil {
  189. return err
  190. }
  191. return boolSetting("binary_parameters", &cn.binaryParameters)
  192. }
  193. func (cn *conn) handlePgpass(o values) {
  194. // if a password was supplied, do not process .pgpass
  195. if _, ok := o["password"]; ok {
  196. return
  197. }
  198. filename := os.Getenv("PGPASSFILE")
  199. if filename == "" {
  200. // XXX this code doesn't work on Windows where the default filename is
  201. // XXX %APPDATA%\postgresql\pgpass.conf
  202. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  203. userHome := os.Getenv("HOME")
  204. if userHome == "" {
  205. user, err := user.Current()
  206. if err != nil {
  207. return
  208. }
  209. userHome = user.HomeDir
  210. }
  211. filename = filepath.Join(userHome, ".pgpass")
  212. }
  213. fileinfo, err := os.Stat(filename)
  214. if err != nil {
  215. return
  216. }
  217. mode := fileinfo.Mode()
  218. if mode&(0x77) != 0 {
  219. // XXX should warn about incorrect .pgpass permissions as psql does
  220. return
  221. }
  222. file, err := os.Open(filename)
  223. if err != nil {
  224. return
  225. }
  226. defer file.Close()
  227. scanner := bufio.NewScanner(io.Reader(file))
  228. // From: https://github.com/tg/pgpass/blob/master/reader.go
  229. for scanner.Scan() {
  230. if scanText(scanner.Text(), o) {
  231. break
  232. }
  233. }
  234. }
  // getFields splits one .pgpass line into its colon-separated fields; it is a
// helper for scanText. A backslash makes the following character literal, so
// "\:" is a colon inside a field and "\\" a backslash. A lone trailing
// backslash is dropped. The result always has at least one element.
func getFields(s string) []string {
	fields := make([]string, 0, 5)
	var cur strings.Builder
	escaped := false
	for _, r := range s {
		if escaped {
			cur.WriteRune(r)
			escaped = false
			continue
		}
		switch r {
		case '\\':
			escaped = true
		case ':':
			fields = append(fields, cur.String())
			cur.Reset()
		default:
			cur.WriteRune(r)
		}
	}
	return append(fields, cur.String())
}
  256. // ScanText assists HandlePgpass in it's objective.
  257. func scanText(line string, o values) bool {
  258. hostname := o["host"]
  259. ntw, _ := network(o)
  260. port := o["port"]
  261. db := o["dbname"]
  262. username := o["user"]
  263. if len(line) == 0 || line[0] == '#' {
  264. return false
  265. }
  266. split := getFields(line)
  267. if len(split) != 5 {
  268. return false
  269. }
  270. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  271. o["password"] = split[4]
  272. return true
  273. }
  274. return false
  275. }
  276. func (cn *conn) writeBuf(b byte) *writeBuf {
  277. cn.scratch[0] = b
  278. return &writeBuf{
  279. buf: cn.scratch[:5],
  280. pos: 1,
  281. }
  282. }
  283. // Open opens a new connection to the database. dsn is a connection string.
  284. // Most users should only use it through database/sql package from the standard
  285. // library.
  286. func Open(dsn string) (_ driver.Conn, err error) {
  287. return DialOpen(defaultDialer{}, dsn)
  288. }
  289. // DialOpen opens a new connection to the database using a dialer.
  290. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
  291. c, err := NewConnector(dsn)
  292. if err != nil {
  293. return nil, err
  294. }
  295. c.Dialer(d)
  296. return c.open(context.Background())
  297. }
  298. func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
  299. // Handle any panics during connection initialization. Note that we
  300. // specifically do *not* want to use errRecover(), as that would turn any
  301. // connection errors into ErrBadConns, hiding the real error message from
  302. // the user.
  303. defer errRecoverNoErrBadConn(&err)
  304. // Create a new values map (copy). This makes it so maps in different
  305. // connections do not reference the same underlying data structure, so it
  306. // is safe for multiple connections to concurrently write to their opts.
  307. o := make(values)
  308. for k, v := range c.opts {
  309. o[k] = v
  310. }
  311. cn = &conn{
  312. opts: o,
  313. dialer: c.dialer,
  314. }
  315. err = cn.handleDriverSettings(o)
  316. if err != nil {
  317. return nil, err
  318. }
  319. cn.handlePgpass(o)
  320. cn.c, err = dial(ctx, c.dialer, o)
  321. if err != nil {
  322. return nil, err
  323. }
  324. err = cn.ssl(o)
  325. if err != nil {
  326. if cn.c != nil {
  327. cn.c.Close()
  328. }
  329. return nil, err
  330. }
  331. // cn.startup panics on error. Make sure we don't leak cn.c.
  332. panicking := true
  333. defer func() {
  334. if panicking {
  335. cn.c.Close()
  336. }
  337. }()
  338. cn.buf = bufio.NewReader(cn.c)
  339. cn.startup(o)
  340. // reset the deadline, in case one was set (see dial)
  341. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  342. err = cn.c.SetDeadline(time.Time{})
  343. }
  344. panicking = false
  345. return cn, err
  346. }
  347. func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
  348. network, address := network(o)
  349. // Zero or not specified means wait indefinitely.
  350. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  351. seconds, err := strconv.ParseInt(timeout, 10, 0)
  352. if err != nil {
  353. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  354. }
  355. duration := time.Duration(seconds) * time.Second
  356. // connect_timeout should apply to the entire connection establishment
  357. // procedure, so we both use a timeout for the TCP connection
  358. // establishment and set a deadline for doing the initial handshake.
  359. // The deadline is then reset after startup() is done.
  360. deadline := time.Now().Add(duration)
  361. var conn net.Conn
  362. if dctx, ok := d.(DialerContext); ok {
  363. ctx, cancel := context.WithTimeout(ctx, duration)
  364. defer cancel()
  365. conn, err = dctx.DialContext(ctx, network, address)
  366. } else {
  367. conn, err = d.DialTimeout(network, address, duration)
  368. }
  369. if err != nil {
  370. return nil, err
  371. }
  372. err = conn.SetDeadline(deadline)
  373. return conn, err
  374. }
  375. if dctx, ok := d.(DialerContext); ok {
  376. return dctx.DialContext(ctx, network, address)
  377. }
  378. return d.Dial(network, address)
  379. }
  // network returns the network and address to dial for the options in o. A
// host that starts with "/" names a directory holding a Unix-domain socket
// called ".s.PGSQL.<port>"; any other host (including an empty one) is a TCP
// host name or IP address paired with the port.
func network(o values) (string, string) {
	host, port := o["host"], o["port"]
	if !strings.HasPrefix(host, "/") {
		return "tcp", net.JoinHostPort(host, port)
	}
	return "unix", path.Join(host, ".s.PGSQL."+port)
}

// values holds connection options keyed by parameter name.
type values map[string]string
  // scanner implements a tokenizer for libpq-style option strings. It walks
// the input one rune at a time.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a scanner positioned at the start of s.
func newScanner(s string) *scanner {
	return &scanner{s: []rune(s), i: 0}
}

// Next consumes and returns the next rune. At the end of the input it
// returns 0, false.
func (s *scanner) Next() (rune, bool) {
	if s.i < len(s.s) {
		r := s.s[s.i]
		s.i++
		return r, true
	}
	return 0, false
}

// SkipSpaces consumes whitespace and returns the first non-whitespace rune,
// which is also consumed. If only whitespace remains it returns 0, false.
func (s *scanner) SkipSpaces() (rune, bool) {
	for {
		r, ok := s.Next()
		if !ok || !unicode.IsSpace(r) {
			return r, ok
		}
	}
}
  417. // parseOpts parses the options from name and adds them to the values.
  418. //
  419. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  420. func parseOpts(name string, o values) error {
  421. s := newScanner(name)
  422. for {
  423. var (
  424. keyRunes, valRunes []rune
  425. r rune
  426. ok bool
  427. )
  428. if r, ok = s.SkipSpaces(); !ok {
  429. break
  430. }
  431. // Scan the key
  432. for !unicode.IsSpace(r) && r != '=' {
  433. keyRunes = append(keyRunes, r)
  434. if r, ok = s.Next(); !ok {
  435. break
  436. }
  437. }
  438. // Skip any whitespace if we're not at the = yet
  439. if r != '=' {
  440. r, ok = s.SkipSpaces()
  441. }
  442. // The current character should be =
  443. if r != '=' || !ok {
  444. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  445. }
  446. // Skip any whitespace after the =
  447. if r, ok = s.SkipSpaces(); !ok {
  448. // If we reach the end here, the last value is just an empty string as per libpq.
  449. o[string(keyRunes)] = ""
  450. break
  451. }
  452. if r != '\'' {
  453. for !unicode.IsSpace(r) {
  454. if r == '\\' {
  455. if r, ok = s.Next(); !ok {
  456. return fmt.Errorf(`missing character after backslash`)
  457. }
  458. }
  459. valRunes = append(valRunes, r)
  460. if r, ok = s.Next(); !ok {
  461. break
  462. }
  463. }
  464. } else {
  465. quote:
  466. for {
  467. if r, ok = s.Next(); !ok {
  468. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  469. }
  470. switch r {
  471. case '\'':
  472. break quote
  473. case '\\':
  474. r, _ = s.Next()
  475. fallthrough
  476. default:
  477. valRunes = append(valRunes, r)
  478. }
  479. }
  480. }
  481. o[string(keyRunes)] = string(valRunes)
  482. }
  483. return nil
  484. }
  485. func (cn *conn) isInTransaction() bool {
  486. return cn.txnStatus == txnStatusIdleInTransaction ||
  487. cn.txnStatus == txnStatusInFailedTransaction
  488. }
  489. func (cn *conn) checkIsInTransaction(intxn bool) {
  490. if cn.isInTransaction() != intxn {
  491. cn.err.set(driver.ErrBadConn)
  492. errorf("unexpected transaction status %v", cn.txnStatus)
  493. }
  494. }
  495. func (cn *conn) Begin() (_ driver.Tx, err error) {
  496. return cn.begin("")
  497. }
  498. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  499. if err := cn.err.get(); err != nil {
  500. return nil, err
  501. }
  502. defer cn.errRecover(&err)
  503. cn.checkIsInTransaction(false)
  504. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  505. if err != nil {
  506. return nil, err
  507. }
  508. if commandTag != "BEGIN" {
  509. cn.err.set(driver.ErrBadConn)
  510. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  511. }
  512. if cn.txnStatus != txnStatusIdleInTransaction {
  513. cn.err.set(driver.ErrBadConn)
  514. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  515. }
  516. return cn, nil
  517. }
  518. func (cn *conn) closeTxn() {
  519. if finish := cn.txnFinish; finish != nil {
  520. finish()
  521. }
  522. }
  // Commit commits the current transaction. closeTxn is deferred first, so the
// txnFinish hook runs on every exit, including the early bad-connection one.
//
// If the transaction had already failed, Commit rolls it back and returns
// ErrInFailedTransaction (or the rollback's own error). A COMMIT that fails
// while the session is still in a transaction, or that returns an unexpected
// command tag, marks the connection bad.
func (cn *conn) Commit() (err error) {
	defer cn.closeTxn()
	if err := cn.err.get(); err != nil {
		return err
	}
	// errRecover (defined elsewhere) presumably turns panics raised by the
	// protocol helpers below into err.
	defer cn.errRecover(&err)
	cn.checkIsInTransaction(true)
	// We don't want the client to think that everything is okay if it tries
	// to commit a failed transaction. However, no matter what we return,
	// database/sql will release this connection back into the free connection
	// pool so we have to abort the current transaction here. Note that you
	// would get the same behaviour if you issued a COMMIT in a failed
	// transaction, so it's also the least surprising thing to do here.
	if cn.txnStatus == txnStatusInFailedTransaction {
		if err := cn.rollback(); err != nil {
			return err
		}
		return ErrInFailedTransaction
	}
	_, commandTag, err := cn.simpleExec("COMMIT")
	if err != nil {
		// If we are still inside the transaction after a failed COMMIT, the
		// session state is unusable for the pool.
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return err
	}
	if commandTag != "COMMIT" {
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	// A successful COMMIT must leave the session outside any transaction.
	cn.checkIsInTransaction(false)
	return nil
}
  556. func (cn *conn) Rollback() (err error) {
  557. defer cn.closeTxn()
  558. if err := cn.err.get(); err != nil {
  559. return err
  560. }
  561. defer cn.errRecover(&err)
  562. return cn.rollback()
  563. }
  564. func (cn *conn) rollback() (err error) {
  565. cn.checkIsInTransaction(true)
  566. _, commandTag, err := cn.simpleExec("ROLLBACK")
  567. if err != nil {
  568. if cn.isInTransaction() {
  569. cn.err.set(driver.ErrBadConn)
  570. }
  571. return err
  572. }
  573. if commandTag != "ROLLBACK" {
  574. return fmt.Errorf("unexpected command tag %s", commandTag)
  575. }
  576. cn.checkIsInTransaction(false)
  577. return nil
  578. }
  579. func (cn *conn) gname() string {
  580. cn.namei++
  581. return strconv.FormatInt(int64(cn.namei), 10)
  582. }
  583. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  584. b := cn.writeBuf('Q')
  585. b.string(q)
  586. cn.send(b)
  587. for {
  588. t, r := cn.recv1()
  589. switch t {
  590. case 'C':
  591. res, commandTag = cn.parseComplete(r.string())
  592. case 'Z':
  593. cn.processReadyForQuery(r)
  594. if res == nil && err == nil {
  595. err = errUnexpectedReady
  596. }
  597. // done
  598. return
  599. case 'E':
  600. err = parseError(r)
  601. case 'I':
  602. res = emptyRows
  603. case 'T', 'D':
  604. // ignore any results
  605. default:
  606. cn.err.set(driver.ErrBadConn)
  607. errorf("unknown response for simple query: %q", t)
  608. }
  609. }
  610. }
  // simpleQuery runs q with the simple query protocol and returns a *rows for
// reading its result. It reads messages until one of these happens:
//   - the first DataRow arrives; it is stashed with saveMessage so that rows
//     can consume it, and the function returns;
//   - a CommandComplete arrives after a RowDescription (res.colNames is set),
//     and the function returns with the result and tag filled in;
//   - ReadyForQuery ends the exchange.
//
// Statements that produce no rows still yield a rows value with done set, so
// database/sql has something to close. An ErrorResponse discards any pending
// result and is returned as err once ReadyForQuery arrives. Protocol
// violations mark the connection bad and are reported through errorf.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)
	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			// A completion after an ErrorResponse is a protocol violation.
			if err != nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' {
				res.result, res.tag = cn.parseComplete(r.string())
				if res.colNames != nil {
					return
				}
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			// Drop any partial result; the error is returned once
			// ReadyForQuery arrives.
			res = nil
			err = parseError(r)
		case 'D':
			if res == nil {
				cn.err.set(driver.ErrBadConn)
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.rowsHeader = parsePortalRowDescribe(r)
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.err.set(driver.ErrBadConn)
			errorf("unknown response for simple query: %q", t)
		}
	}
}
  671. type noRows struct{}
  672. var emptyRows noRows
  673. var _ driver.Result = noRows{}
  674. func (noRows) LastInsertId() (int64, error) {
  675. return 0, errNoLastInsertID
  676. }
  677. func (noRows) RowsAffected() (int64, error) {
  678. return 0, errNoRowsAffected
  679. }
  680. // Decides which column formats to use for a prepared statement. The input is
  681. // an array of type oids, one element per result column.
  682. func decideColumnFormats(
  683. colTyps []fieldDesc, forceText bool,
  684. ) (colFmts []format, colFmtData []byte) {
  685. if len(colTyps) == 0 {
  686. return nil, colFmtDataAllText
  687. }
  688. colFmts = make([]format, len(colTyps))
  689. if forceText {
  690. return colFmts, colFmtDataAllText
  691. }
  692. allBinary := true
  693. allText := true
  694. for i, t := range colTyps {
  695. switch t.OID {
  696. // This is the list of types to use binary mode for when receiving them
  697. // through a prepared statement. If a type appears in this list, it
  698. // must also be implemented in binaryDecode in encode.go.
  699. case oid.T_bytea:
  700. fallthrough
  701. case oid.T_int8:
  702. fallthrough
  703. case oid.T_int4:
  704. fallthrough
  705. case oid.T_int2:
  706. fallthrough
  707. case oid.T_uuid:
  708. colFmts[i] = formatBinary
  709. allText = false
  710. default:
  711. allBinary = false
  712. }
  713. }
  714. if allBinary {
  715. return colFmts, colFmtDataAllBinary
  716. } else if allText {
  717. return colFmts, colFmtDataAllText
  718. } else {
  719. colFmtData = make([]byte, 2+len(colFmts)*2)
  720. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  721. for i, v := range colFmts {
  722. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  723. }
  724. return colFmts, colFmtData
  725. }
  726. }
  // prepareTo prepares q on the server as the statement named stmtName (query
// passes "" for an unnamed statement). It returns the statement with its
// parameter types, result columns and chosen result formats filled in.
//
// Parse, Describe and Sync are written into one buffer and sent together, and
// the replies are read back in the same order. The sequence below mirrors the
// wire protocol and must not be reordered. The read helpers are defined
// elsewhere and presumably panic on protocol errors.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}
	// Parse: statement name, query text, and a count of zero pre-specified
	// parameter types.
	b := cn.writeBuf('P')
	b.string(st.name)
	b.string(q)
	b.int16(0)
	// Describe the statement ('S') by name.
	b.next('D')
	b.byte('S')
	b.string(st.name)
	// Sync ends the batch.
	b.next('S')
	cn.send(b)
	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
  744. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  745. if err := cn.err.get(); err != nil {
  746. return nil, err
  747. }
  748. defer cn.errRecover(&err)
  749. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  750. s, err := cn.prepareCopyIn(q)
  751. if err == nil {
  752. cn.inCopy = true
  753. }
  754. return s, err
  755. }
  756. return cn.prepareTo(q, cn.gname()), nil
  757. }
  758. func (cn *conn) Close() (err error) {
  759. // Skip cn.bad return here because we always want to close a connection.
  760. defer cn.errRecover(&err)
  761. // Ensure that cn.c.Close is always run. Since error handling is done with
  762. // panics and cn.errRecover, the Close must be in a defer.
  763. defer func() {
  764. cerr := cn.c.Close()
  765. if err == nil {
  766. err = cerr
  767. }
  768. }()
  769. // Don't go through send(); ListenerConn relies on us not scribbling on the
  770. // scratch buffer of this connection.
  771. return cn.sendSimpleMessage('X')
  772. }
  773. // Implement the "Queryer" interface
  774. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  775. return cn.query(query, args)
  776. }
  777. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  778. if err := cn.err.get(); err != nil {
  779. return nil, err
  780. }
  781. if cn.inCopy {
  782. return nil, errCopyInProgress
  783. }
  784. defer cn.errRecover(&err)
  785. // Check to see if we can use the "simpleQuery" interface, which is
  786. // *much* faster than going through prepare/exec
  787. if len(args) == 0 {
  788. return cn.simpleQuery(query)
  789. }
  790. if cn.binaryParameters {
  791. cn.sendBinaryModeQuery(query, args)
  792. cn.readParseResponse()
  793. cn.readBindResponse()
  794. rows := &rows{cn: cn}
  795. rows.rowsHeader = cn.readPortalDescribeResponse()
  796. cn.postExecuteWorkaround()
  797. return rows, nil
  798. }
  799. st := cn.prepareTo(query, "")
  800. st.exec(args)
  801. return &rows{
  802. cn: cn,
  803. rowsHeader: st.rowsHeader,
  804. }, nil
  805. }
  806. // Implement the optional "Execer" interface for one-shot queries
  807. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  808. if err := cn.err.get(); err != nil {
  809. return nil, err
  810. }
  811. defer cn.errRecover(&err)
  812. // Check to see if we can use the "simpleExec" interface, which is
  813. // *much* faster than going through prepare/exec
  814. if len(args) == 0 {
  815. // ignore commandTag, our caller doesn't care
  816. r, _, err := cn.simpleExec(query)
  817. return r, err
  818. }
  819. if cn.binaryParameters {
  820. cn.sendBinaryModeQuery(query, args)
  821. cn.readParseResponse()
  822. cn.readBindResponse()
  823. cn.readPortalDescribeResponse()
  824. cn.postExecuteWorkaround()
  825. res, _, err = cn.readExecuteResponse("Execute")
  826. return res, err
  827. }
  828. // Use the unnamed statement to defer planning until bind
  829. // time, or else value-based selectivity estimates cannot be
  830. // used.
  831. st := cn.prepareTo(query, "")
  832. r, err := st.Exec(args)
  833. if err != nil {
  834. panic(err)
  835. }
  836. return r, err
  837. }
  838. type safeRetryError struct {
  839. Err error
  840. }
  841. func (se *safeRetryError) Error() string {
  842. return se.Err.Error()
  843. }
  844. func (cn *conn) send(m *writeBuf) {
  845. n, err := cn.c.Write(m.wrap())
  846. if err != nil {
  847. if n == 0 {
  848. err = &safeRetryError{Err: err}
  849. }
  850. panic(err)
  851. }
  852. }
  853. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  854. _, err := cn.c.Write((m.wrap())[1:])
  855. return err
  856. }
  857. // Send a message of type typ to the server on the other end of cn. The
  858. // message should have no payload. This method does not use the scratch
  859. // buffer.
  860. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  861. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  862. return err
  863. }
  864. // saveMessage memorizes a message and its buffer in the conn struct.
  865. // recvMessage will then return these values on the next call to it. This
  866. // method is useful in cases where you have to see what the next message is
  867. // going to be (e.g. to see whether it's an error or not) but you can't handle
  868. // the message yourself.
  869. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  870. if cn.saveMessageType != 0 {
  871. cn.err.set(driver.ErrBadConn)
  872. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  873. }
  874. cn.saveMessageType = typ
  875. cn.saveMessageBuffer = *buf
  876. }
  877. // recvMessage receives any message from the backend, or returns an error if
  878. // a problem occurred while reading the message.
  879. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  880. // workaround for a QueryRow bug, see exec
  881. if cn.saveMessageType != 0 {
  882. t := cn.saveMessageType
  883. *r = cn.saveMessageBuffer
  884. cn.saveMessageType = 0
  885. cn.saveMessageBuffer = nil
  886. return t, nil
  887. }
  888. x := cn.scratch[:5]
  889. _, err := io.ReadFull(cn.buf, x)
  890. if err != nil {
  891. return 0, err
  892. }
  893. // read the type and length of the message that follows
  894. t := x[0]
  895. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  896. var y []byte
  897. if n <= len(cn.scratch) {
  898. y = cn.scratch[:n]
  899. } else {
  900. y = make([]byte, n)
  901. }
  902. _, err = io.ReadFull(cn.buf, y)
  903. if err != nil {
  904. return 0, err
  905. }
  906. *r = y
  907. return t, nil
  908. }
  909. // recv receives a message from the backend, but if an error happened while
  910. // reading the message or the received message was an ErrorResponse, it panics.
  911. // NoticeResponses are ignored. This function should generally be used only
  912. // during the startup sequence.
  913. func (cn *conn) recv() (t byte, r *readBuf) {
  914. for {
  915. var err error
  916. r = &readBuf{}
  917. t, err = cn.recvMessage(r)
  918. if err != nil {
  919. panic(err)
  920. }
  921. switch t {
  922. case 'E':
  923. panic(parseError(r))
  924. case 'N':
  925. if n := cn.noticeHandler; n != nil {
  926. n(parseError(r))
  927. }
  928. case 'A':
  929. if n := cn.notificationHandler; n != nil {
  930. n(recvNotification(r))
  931. }
  932. default:
  933. return
  934. }
  935. }
  936. }
  937. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  938. // the caller to avoid an allocation.
  939. func (cn *conn) recv1Buf(r *readBuf) byte {
  940. for {
  941. t, err := cn.recvMessage(r)
  942. if err != nil {
  943. panic(err)
  944. }
  945. switch t {
  946. case 'A':
  947. if n := cn.notificationHandler; n != nil {
  948. n(recvNotification(r))
  949. }
  950. case 'N':
  951. if n := cn.noticeHandler; n != nil {
  952. n(parseError(r))
  953. }
  954. case 'S':
  955. cn.processParameterStatus(r)
  956. default:
  957. return t
  958. }
  959. }
  960. }
  961. // recv1 receives a message from the backend, panicking if an error occurs
  962. // while attempting to read it. All asynchronous messages are ignored, with
  963. // the exception of ErrorResponse.
  964. func (cn *conn) recv1() (t byte, r *readBuf) {
  965. r = &readBuf{}
  966. t = cn.recv1Buf(r)
  967. return t, r
  968. }
  969. func (cn *conn) ssl(o values) error {
  970. upgrade, err := ssl(o)
  971. if err != nil {
  972. return err
  973. }
  974. if upgrade == nil {
  975. // Nothing to do
  976. return nil
  977. }
  978. w := cn.writeBuf(0)
  979. w.int32(80877103)
  980. if err = cn.sendStartupPacket(w); err != nil {
  981. return err
  982. }
  983. b := cn.scratch[:1]
  984. _, err = io.ReadFull(cn.c, b)
  985. if err != nil {
  986. return err
  987. }
  988. if b[0] != 'S' {
  989. return ErrSSLNotSupported
  990. }
  991. cn.c, err = upgrade(cn.c)
  992. return err
  993. }
  994. // isDriverSetting returns true iff a setting is purely for configuring the
  995. // driver's options and should not be sent to the server in the connection
  996. // startup packet.
  997. func isDriverSetting(key string) bool {
  998. switch key {
  999. case "host", "port":
  1000. return true
  1001. case "password":
  1002. return true
  1003. case "sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline", "sslsni":
  1004. return true
  1005. case "fallback_application_name":
  1006. return true
  1007. case "connect_timeout":
  1008. return true
  1009. case "disable_prepared_binary_result":
  1010. return true
  1011. case "binary_parameters":
  1012. return true
  1013. case "krbsrvname":
  1014. return true
  1015. case "krbspn":
  1016. return true
  1017. default:
  1018. return false
  1019. }
  1020. }
  1021. func (cn *conn) startup(o values) {
  1022. w := cn.writeBuf(0)
  1023. w.int32(196608)
  1024. // Send the backend the name of the database we want to connect to, and the
  1025. // user we want to connect as. Additionally, we send over any run-time
  1026. // parameters potentially included in the connection string. If the server
  1027. // doesn't recognize any of them, it will reply with an error.
  1028. for k, v := range o {
  1029. if isDriverSetting(k) {
  1030. // skip options which can't be run-time parameters
  1031. continue
  1032. }
  1033. // The protocol requires us to supply the database name as "database"
  1034. // instead of "dbname".
  1035. if k == "dbname" {
  1036. k = "database"
  1037. }
  1038. w.string(k)
  1039. w.string(v)
  1040. }
  1041. w.string("")
  1042. if err := cn.sendStartupPacket(w); err != nil {
  1043. panic(err)
  1044. }
  1045. for {
  1046. t, r := cn.recv()
  1047. switch t {
  1048. case 'K':
  1049. cn.processBackendKeyData(r)
  1050. case 'S':
  1051. cn.processParameterStatus(r)
  1052. case 'R':
  1053. cn.auth(r, o)
  1054. case 'Z':
  1055. cn.processReadyForQuery(r)
  1056. return
  1057. default:
  1058. errorf("unknown response for startup: %q", t)
  1059. }
  1060. }
  1061. }
  1062. func (cn *conn) auth(r *readBuf, o values) {
  1063. switch code := r.int32(); code {
  1064. case 0:
  1065. // OK
  1066. case 3:
  1067. w := cn.writeBuf('p')
  1068. w.string(o["password"])
  1069. cn.send(w)
  1070. t, r := cn.recv()
  1071. if t != 'R' {
  1072. errorf("unexpected password response: %q", t)
  1073. }
  1074. if r.int32() != 0 {
  1075. errorf("unexpected authentication response: %q", t)
  1076. }
  1077. case 5:
  1078. s := string(r.next(4))
  1079. w := cn.writeBuf('p')
  1080. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  1081. cn.send(w)
  1082. t, r := cn.recv()
  1083. if t != 'R' {
  1084. errorf("unexpected password response: %q", t)
  1085. }
  1086. if r.int32() != 0 {
  1087. errorf("unexpected authentication response: %q", t)
  1088. }
  1089. case 7: // GSSAPI, startup
  1090. if newGss == nil {
  1091. errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
  1092. }
  1093. cli, err := newGss()
  1094. if err != nil {
  1095. errorf("kerberos error: %s", err.Error())
  1096. }
  1097. var token []byte
  1098. if spn, ok := o["krbspn"]; ok {
  1099. // Use the supplied SPN if provided..
  1100. token, err = cli.GetInitTokenFromSpn(spn)
  1101. } else {
  1102. // Allow the kerberos service name to be overridden
  1103. service := "postgres"
  1104. if val, ok := o["krbsrvname"]; ok {
  1105. service = val
  1106. }
  1107. token, err = cli.GetInitToken(o["host"], service)
  1108. }
  1109. if err != nil {
  1110. errorf("failed to get Kerberos ticket: %q", err)
  1111. }
  1112. w := cn.writeBuf('p')
  1113. w.bytes(token)
  1114. cn.send(w)
  1115. // Store for GSSAPI continue message
  1116. cn.gss = cli
  1117. case 8: // GSSAPI continue
  1118. if cn.gss == nil {
  1119. errorf("GSSAPI protocol error")
  1120. }
  1121. b := []byte(*r)
  1122. done, tokOut, err := cn.gss.Continue(b)
  1123. if err == nil && !done {
  1124. w := cn.writeBuf('p')
  1125. w.bytes(tokOut)
  1126. cn.send(w)
  1127. }
  1128. // Errors fall through and read the more detailed message
  1129. // from the server..
  1130. case 10:
  1131. sc := scram.NewClient(sha256.New, o["user"], o["password"])
  1132. sc.Step(nil)
  1133. if sc.Err() != nil {
  1134. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1135. }
  1136. scOut := sc.Out()
  1137. w := cn.writeBuf('p')
  1138. w.string("SCRAM-SHA-256")
  1139. w.int32(len(scOut))
  1140. w.bytes(scOut)
  1141. cn.send(w)
  1142. t, r := cn.recv()
  1143. if t != 'R' {
  1144. errorf("unexpected password response: %q", t)
  1145. }
  1146. if r.int32() != 11 {
  1147. errorf("unexpected authentication response: %q", t)
  1148. }
  1149. nextStep := r.next(len(*r))
  1150. sc.Step(nextStep)
  1151. if sc.Err() != nil {
  1152. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1153. }
  1154. scOut = sc.Out()
  1155. w = cn.writeBuf('p')
  1156. w.bytes(scOut)
  1157. cn.send(w)
  1158. t, r = cn.recv()
  1159. if t != 'R' {
  1160. errorf("unexpected password response: %q", t)
  1161. }
  1162. if r.int32() != 12 {
  1163. errorf("unexpected authentication response: %q", t)
  1164. }
  1165. nextStep = r.next(len(*r))
  1166. sc.Step(nextStep)
  1167. if sc.Err() != nil {
  1168. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1169. }
  1170. default:
  1171. errorf("unknown authentication response: %d", code)
  1172. }
  1173. }
  1174. type format int
  1175. const formatText format = 0
  1176. const formatBinary format = 1
  1177. // One result-column format code with the value 1 (i.e. all binary).
  1178. var colFmtDataAllBinary = []byte{0, 1, 0, 1}
  1179. // No result-column format codes (i.e. all text).
  1180. var colFmtDataAllText = []byte{0, 0}
// stmt is a statement prepared on cn (see prepareTo). An empty name denotes
// the unnamed statement.
type stmt struct {
	cn   *conn
	name string

	// rowsHeader holds the result column names and types reported by the
	// server, plus the per-column formats chosen by decideColumnFormats.
	rowsHeader

	// colFmtData is the encoded result-column format section appended to
	// every Bind message for this statement.
	colFmtData []byte

	// paramTyps are the parameter type OIDs reported by the server.
	paramTyps []oid.Oid

	// closed is set by Close once the server has acknowledged the close.
	closed bool
}
  1189. func (st *stmt) Close() (err error) {
  1190. if st.closed {
  1191. return nil
  1192. }
  1193. if err := st.cn.err.get(); err != nil {
  1194. return err
  1195. }
  1196. defer st.cn.errRecover(&err)
  1197. w := st.cn.writeBuf('C')
  1198. w.byte('S')
  1199. w.string(st.name)
  1200. st.cn.send(w)
  1201. st.cn.send(st.cn.writeBuf('S'))
  1202. t, _ := st.cn.recv1()
  1203. if t != '3' {
  1204. st.cn.err.set(driver.ErrBadConn)
  1205. errorf("unexpected close response: %q", t)
  1206. }
  1207. st.closed = true
  1208. t, r := st.cn.recv1()
  1209. if t != 'Z' {
  1210. st.cn.err.set(driver.ErrBadConn)
  1211. errorf("expected ready for query, but got: %q", t)
  1212. }
  1213. st.cn.processReadyForQuery(r)
  1214. return nil
  1215. }
  1216. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1217. return st.query(v)
  1218. }
  1219. func (st *stmt) query(v []driver.Value) (r *rows, err error) {
  1220. if err := st.cn.err.get(); err != nil {
  1221. return nil, err
  1222. }
  1223. defer st.cn.errRecover(&err)
  1224. st.exec(v)
  1225. return &rows{
  1226. cn: st.cn,
  1227. rowsHeader: st.rowsHeader,
  1228. }, nil
  1229. }
  1230. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1231. if err := st.cn.err.get(); err != nil {
  1232. return nil, err
  1233. }
  1234. defer st.cn.errRecover(&err)
  1235. st.exec(v)
  1236. res, _, err = st.cn.readExecuteResponse("simple query")
  1237. return res, err
  1238. }
  1239. func (st *stmt) exec(v []driver.Value) {
  1240. if len(v) >= 65536 {
  1241. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1242. }
  1243. if len(v) != len(st.paramTyps) {
  1244. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1245. }
  1246. cn := st.cn
  1247. w := cn.writeBuf('B')
  1248. w.byte(0) // unnamed portal
  1249. w.string(st.name)
  1250. if cn.binaryParameters {
  1251. cn.sendBinaryParameters(w, v)
  1252. } else {
  1253. w.int16(0)
  1254. w.int16(len(v))
  1255. for i, x := range v {
  1256. if x == nil {
  1257. w.int32(-1)
  1258. } else {
  1259. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1260. w.int32(len(b))
  1261. w.bytes(b)
  1262. }
  1263. }
  1264. }
  1265. w.bytes(st.colFmtData)
  1266. w.next('E')
  1267. w.byte(0)
  1268. w.int32(0)
  1269. w.next('S')
  1270. cn.send(w)
  1271. cn.readBindResponse()
  1272. cn.postExecuteWorkaround()
  1273. }
  1274. func (st *stmt) NumInput() int {
  1275. return len(st.paramTyps)
  1276. }
  1277. // parseComplete parses the "command tag" from a CommandComplete message, and
  1278. // returns the number of rows affected (if applicable) and a string
  1279. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1280. // command tag could not be parsed, parseComplete panics.
  1281. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1282. commandsWithAffectedRows := []string{
  1283. "SELECT ",
  1284. // INSERT is handled below
  1285. "UPDATE ",
  1286. "DELETE ",
  1287. "FETCH ",
  1288. "MOVE ",
  1289. "COPY ",
  1290. }
  1291. var affectedRows *string
  1292. for _, tag := range commandsWithAffectedRows {
  1293. if strings.HasPrefix(commandTag, tag) {
  1294. t := commandTag[len(tag):]
  1295. affectedRows = &t
  1296. commandTag = tag[:len(tag)-1]
  1297. break
  1298. }
  1299. }
  1300. // INSERT also includes the oid of the inserted row in its command tag.
  1301. // Oids in user tables are deprecated, and the oid is only returned when
  1302. // exactly one row is inserted, so it's unlikely to be of value to any
  1303. // real-world application and we can ignore it.
  1304. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1305. parts := strings.Split(commandTag, " ")
  1306. if len(parts) != 3 {
  1307. cn.err.set(driver.ErrBadConn)
  1308. errorf("unexpected INSERT command tag %s", commandTag)
  1309. }
  1310. affectedRows = &parts[len(parts)-1]
  1311. commandTag = "INSERT"
  1312. }
  1313. // There should be no affected rows attached to the tag, just return it
  1314. if affectedRows == nil {
  1315. return driver.RowsAffected(0), commandTag
  1316. }
  1317. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1318. if err != nil {
  1319. cn.err.set(driver.ErrBadConn)
  1320. errorf("could not parse commandTag: %s", err)
  1321. }
  1322. return driver.RowsAffected(n), commandTag
  1323. }
// rowsHeader describes the columns of a result set as parallel slices, one
// entry per column: name, type description, and the format (text or binary)
// in which values arrive.
type rowsHeader struct {
	colNames []string
	colTyps  []fieldDesc
	colFmts  []format
}
// rows reads the results of a query from cn one message at a time (see Next).
type rows struct {
	// cn is the connection the results are read from.
	cn *conn

	// finish, when non-nil, is deferred by Close, so it runs as Close returns.
	finish func()

	// rowsHeader describes the columns of the current result set.
	rowsHeader

	// done is set once Next has processed ReadyForQuery.
	done bool

	// rb is the buffer Next reuses for incoming messages.
	rb readBuf

	// result and tag come from the most recent CommandComplete.
	result driver.Result
	tag    string

	// next is the header of the following result set, recorded when Next
	// meets a RowDescription; NextResultSet consumes it.
	next *rowsHeader
}
  1339. func (rs *rows) Close() error {
  1340. if finish := rs.finish; finish != nil {
  1341. defer finish()
  1342. }
  1343. // no need to look at cn.bad as Next() will
  1344. for {
  1345. err := rs.Next(nil)
  1346. switch err {
  1347. case nil:
  1348. case io.EOF:
  1349. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1350. // description, used with HasNextResultSet). We need to fetch messages until
  1351. // we hit a 'Z', which is done by waiting for done to be set.
  1352. if rs.done {
  1353. return nil
  1354. }
  1355. default:
  1356. return err
  1357. }
  1358. }
  1359. }
  1360. func (rs *rows) Columns() []string {
  1361. return rs.colNames
  1362. }
  1363. func (rs *rows) Result() driver.Result {
  1364. if rs.result == nil {
  1365. return emptyRows
  1366. }
  1367. return rs.result
  1368. }
  1369. func (rs *rows) Tag() string {
  1370. return rs.tag
  1371. }
  1372. func (rs *rows) Next(dest []driver.Value) (err error) {
  1373. if rs.done {
  1374. return io.EOF
  1375. }
  1376. conn := rs.cn
  1377. if err := conn.err.getForNext(); err != nil {
  1378. return err
  1379. }
  1380. defer conn.errRecover(&err)
  1381. for {
  1382. t := conn.recv1Buf(&rs.rb)
  1383. switch t {
  1384. case 'E':
  1385. err = parseError(&rs.rb)
  1386. case 'C', 'I':
  1387. if t == 'C' {
  1388. rs.result, rs.tag = conn.parseComplete(rs.rb.string())
  1389. }
  1390. continue
  1391. case 'Z':
  1392. conn.processReadyForQuery(&rs.rb)
  1393. rs.done = true
  1394. if err != nil {
  1395. return err
  1396. }
  1397. return io.EOF
  1398. case 'D':
  1399. n := rs.rb.int16()
  1400. if err != nil {
  1401. conn.err.set(driver.ErrBadConn)
  1402. errorf("unexpected DataRow after error %s", err)
  1403. }
  1404. if n < len(dest) {
  1405. dest = dest[:n]
  1406. }
  1407. for i := range dest {
  1408. l := rs.rb.int32()
  1409. if l == -1 {
  1410. dest[i] = nil
  1411. continue
  1412. }
  1413. dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
  1414. }
  1415. return
  1416. case 'T':
  1417. next := parsePortalRowDescribe(&rs.rb)
  1418. rs.next = &next
  1419. return io.EOF
  1420. default:
  1421. errorf("unexpected message after execute: %q", t)
  1422. }
  1423. }
  1424. }
  1425. func (rs *rows) HasNextResultSet() bool {
  1426. hasNext := rs.next != nil && !rs.done
  1427. return hasNext
  1428. }
  1429. func (rs *rows) NextResultSet() error {
  1430. if rs.next == nil {
  1431. return io.EOF
  1432. }
  1433. rs.rowsHeader = *rs.next
  1434. rs.next = nil
  1435. return nil
  1436. }
  1437. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1438. // used as part of an SQL statement. For example:
  1439. //
  1440. // tblname := "my_table"
  1441. // data := "my_data"
  1442. // quoted := pq.QuoteIdentifier(tblname)
  1443. // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  1444. //
  1445. // Any double quotes in name will be escaped. The quoted identifier will be
  1446. // case sensitive when used in a query. If the input string contains a zero
  1447. // byte, the result will be truncated immediately before it.
  1448. func QuoteIdentifier(name string) string {
  1449. end := strings.IndexRune(name, 0)
  1450. if end > -1 {
  1451. name = name[:end]
  1452. }
  1453. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1454. }
  1455. // BufferQuoteIdentifier satisfies the same purpose as QuoteIdentifier, but backed by a
  1456. // byte buffer.
  1457. func BufferQuoteIdentifier(name string, buffer *bytes.Buffer) {
  1458. end := strings.IndexRune(name, 0)
  1459. if end > -1 {
  1460. name = name[:end]
  1461. }
  1462. buffer.WriteRune('"')
  1463. buffer.WriteString(strings.Replace(name, `"`, `""`, -1))
  1464. buffer.WriteRune('"')
  1465. }
  1466. // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
  1467. // to DDL and other statements that do not accept parameters) to be used as part
  1468. // of an SQL statement. For example:
  1469. //
  1470. // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
  1471. // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
  1472. //
  1473. // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
  1474. // replaced by two backslashes (i.e. "\\") and the C-style escape identifier
  1475. // that PostgreSQL provides ('E') will be prepended to the string.
  1476. func QuoteLiteral(literal string) string {
  1477. // This follows the PostgreSQL internal algorithm for handling quoted literals
  1478. // from libpq, which can be found in the "PQEscapeStringInternal" function,
  1479. // which is found in the libpq/fe-exec.c source file:
  1480. // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
  1481. //
  1482. // substitute any single-quotes (') with two single-quotes ('')
  1483. literal = strings.Replace(literal, `'`, `''`, -1)
  1484. // determine if the string has any backslashes (\) in it.
  1485. // if it does, replace any backslashes (\) with two backslashes (\\)
  1486. // then, we need to wrap the entire string with a PostgreSQL
  1487. // C-style escape. Per how "PQEscapeStringInternal" handles this case, we
  1488. // also add a space before the "E"
  1489. if strings.Contains(literal, `\`) {
  1490. literal = strings.Replace(literal, `\`, `\\`, -1)
  1491. literal = ` E'` + literal + `'`
  1492. } else {
  1493. // otherwise, we can just wrap the literal with a pair of single quotes
  1494. literal = `'` + literal + `'`
  1495. }
  1496. return literal
  1497. }
  1498. func md5s(s string) string {
  1499. h := md5.New()
  1500. h.Write([]byte(s))
  1501. return fmt.Sprintf("%x", h.Sum(nil))
  1502. }
  1503. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1504. // Do one pass over the parameters to see if we're going to send any of
  1505. // them over in binary. If we are, create a paramFormats array at the
  1506. // same time.
  1507. var paramFormats []int
  1508. for i, x := range args {
  1509. _, ok := x.([]byte)
  1510. if ok {
  1511. if paramFormats == nil {
  1512. paramFormats = make([]int, len(args))
  1513. }
  1514. paramFormats[i] = 1
  1515. }
  1516. }
  1517. if paramFormats == nil {
  1518. b.int16(0)
  1519. } else {
  1520. b.int16(len(paramFormats))
  1521. for _, x := range paramFormats {
  1522. b.int16(x)
  1523. }
  1524. }
  1525. b.int16(len(args))
  1526. for _, x := range args {
  1527. if x == nil {
  1528. b.int32(-1)
  1529. } else {
  1530. datum := binaryEncode(&cn.parameterStatus, x)
  1531. b.int32(len(datum))
  1532. b.bytes(datum)
  1533. }
  1534. }
  1535. }
  1536. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1537. if len(args) >= 65536 {
  1538. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1539. }
  1540. b := cn.writeBuf('P')
  1541. b.byte(0) // unnamed statement
  1542. b.string(query)
  1543. b.int16(0)
  1544. b.next('B')
  1545. b.int16(0) // unnamed portal and statement
  1546. cn.sendBinaryParameters(b, args)
  1547. b.bytes(colFmtDataAllText)
  1548. b.next('D')
  1549. b.byte('P')
  1550. b.byte(0) // unnamed portal
  1551. b.next('E')
  1552. b.byte(0)
  1553. b.int32(0)
  1554. b.next('S')
  1555. cn.send(b)
  1556. }
  1557. func (cn *conn) processParameterStatus(r *readBuf) {
  1558. var err error
  1559. param := r.string()
  1560. switch param {
  1561. case "server_version":
  1562. var major1 int
  1563. var major2 int
  1564. _, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2)
  1565. if err == nil {
  1566. cn.parameterStatus.serverVersion = major1*10000 + major2*100
  1567. }
  1568. case "TimeZone":
  1569. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1570. if err != nil {
  1571. cn.parameterStatus.currentLocation = nil
  1572. }
  1573. default:
  1574. // ignore
  1575. }
  1576. }
  1577. func (cn *conn) processReadyForQuery(r *readBuf) {
  1578. cn.txnStatus = transactionStatus(r.byte())
  1579. }
  1580. func (cn *conn) readReadyForQuery() {
  1581. t, r := cn.recv1()
  1582. switch t {
  1583. case 'Z':
  1584. cn.processReadyForQuery(r)
  1585. return
  1586. default:
  1587. cn.err.set(driver.ErrBadConn)
  1588. errorf("unexpected message %q; expected ReadyForQuery", t)
  1589. }
  1590. }
  1591. func (cn *conn) processBackendKeyData(r *readBuf) {
  1592. cn.processID = r.int32()
  1593. cn.secretKey = r.int32()
  1594. }
  1595. func (cn *conn) readParseResponse() {
  1596. t, r := cn.recv1()
  1597. switch t {
  1598. case '1':
  1599. return
  1600. case 'E':
  1601. err := parseError(r)
  1602. cn.readReadyForQuery()
  1603. panic(err)
  1604. default:
  1605. cn.err.set(driver.ErrBadConn)
  1606. errorf("unexpected Parse response %q", t)
  1607. }
  1608. }
  1609. func (cn *conn) readStatementDescribeResponse() (
  1610. paramTyps []oid.Oid,
  1611. colNames []string,
  1612. colTyps []fieldDesc,
  1613. ) {
  1614. for {
  1615. t, r := cn.recv1()
  1616. switch t {
  1617. case 't':
  1618. nparams := r.int16()
  1619. paramTyps = make([]oid.Oid, nparams)
  1620. for i := range paramTyps {
  1621. paramTyps[i] = r.oid()
  1622. }
  1623. case 'n':
  1624. return paramTyps, nil, nil
  1625. case 'T':
  1626. colNames, colTyps = parseStatementRowDescribe(r)
  1627. return paramTyps, colNames, colTyps
  1628. case 'E':
  1629. err := parseError(r)
  1630. cn.readReadyForQuery()
  1631. panic(err)
  1632. default:
  1633. cn.err.set(driver.ErrBadConn)
  1634. errorf("unexpected Describe statement response %q", t)
  1635. }
  1636. }
  1637. }
  1638. func (cn *conn) readPortalDescribeResponse() rowsHeader {
  1639. t, r := cn.recv1()
  1640. switch t {
  1641. case 'T':
  1642. return parsePortalRowDescribe(r)
  1643. case 'n':
  1644. return rowsHeader{}
  1645. case 'E':
  1646. err := parseError(r)
  1647. cn.readReadyForQuery()
  1648. panic(err)
  1649. default:
  1650. cn.err.set(driver.ErrBadConn)
  1651. errorf("unexpected Describe response %q", t)
  1652. }
  1653. panic("not reached")
  1654. }
  1655. func (cn *conn) readBindResponse() {
  1656. t, r := cn.recv1()
  1657. switch t {
  1658. case '2':
  1659. return
  1660. case 'E':
  1661. err := parseError(r)
  1662. cn.readReadyForQuery()
  1663. panic(err)
  1664. default:
  1665. cn.err.set(driver.ErrBadConn)
  1666. errorf("unexpected Bind response %q", t)
  1667. }
  1668. }
  1669. func (cn *conn) postExecuteWorkaround() {
  1670. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1671. // any errors from rows.Next, which masks errors that happened during the
  1672. // execution of the query. To avoid the problem in common cases, we wait
  1673. // here for one more message from the database. If it's not an error the
  1674. // query will likely succeed (or perhaps has already, if it's a
  1675. // CommandComplete), so we push the message into the conn struct; recv1
  1676. // will return it as the next message for rows.Next or rows.Close.
  1677. // However, if it's an error, we wait until ReadyForQuery and then return
  1678. // the error to our caller.
  1679. for {
  1680. t, r := cn.recv1()
  1681. switch t {
  1682. case 'E':
  1683. err := parseError(r)
  1684. cn.readReadyForQuery()
  1685. panic(err)
  1686. case 'C', 'D', 'I':
  1687. // the query didn't fail, but we can't process this message
  1688. cn.saveMessage(t, r)
  1689. return
  1690. default:
  1691. cn.err.set(driver.ErrBadConn)
  1692. errorf("unexpected message during extended query execution: %q", t)
  1693. }
  1694. }
  1695. }
  1696. // Only for Exec(), since we ignore the returned data
  1697. func (cn *conn) readExecuteResponse(
  1698. protocolState string,
  1699. ) (res driver.Result, commandTag string, err error) {
  1700. for {
  1701. t, r := cn.recv1()
  1702. switch t {
  1703. case 'C':
  1704. if err != nil {
  1705. cn.err.set(driver.ErrBadConn)
  1706. errorf("unexpected CommandComplete after error %s", err)
  1707. }
  1708. res, commandTag = cn.parseComplete(r.string())
  1709. case 'Z':
  1710. cn.processReadyForQuery(r)
  1711. if res == nil && err == nil {
  1712. err = errUnexpectedReady
  1713. }
  1714. return res, commandTag, err
  1715. case 'E':
  1716. err = parseError(r)
  1717. case 'T', 'D', 'I':
  1718. if err != nil {
  1719. cn.err.set(driver.ErrBadConn)
  1720. errorf("unexpected %q after error %s", t, err)
  1721. }
  1722. if t == 'I' {
  1723. res = emptyRows
  1724. }
  1725. // ignore any results
  1726. default:
  1727. cn.err.set(driver.ErrBadConn)
  1728. errorf("unknown %s response: %q", protocolState, t)
  1729. }
  1730. }
  1731. }
  1732. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1733. n := r.int16()
  1734. colNames = make([]string, n)
  1735. colTyps = make([]fieldDesc, n)
  1736. for i := range colNames {
  1737. colNames[i] = r.string()
  1738. r.next(6)
  1739. colTyps[i].OID = r.oid()
  1740. colTyps[i].Len = r.int16()
  1741. colTyps[i].Mod = r.int32()
  1742. // format code not known when describing a statement; always 0
  1743. r.next(2)
  1744. }
  1745. return
  1746. }
  1747. func parsePortalRowDescribe(r *readBuf) rowsHeader {
  1748. n := r.int16()
  1749. colNames := make([]string, n)
  1750. colFmts := make([]format, n)
  1751. colTyps := make([]fieldDesc, n)
  1752. for i := range colNames {
  1753. colNames[i] = r.string()
  1754. r.next(6)
  1755. colTyps[i].OID = r.oid()
  1756. colTyps[i].Len = r.int16()
  1757. colTyps[i].Mod = r.int32()
  1758. colFmts[i] = format(r.int16())
  1759. }
  1760. return rowsHeader{
  1761. colNames: colNames,
  1762. colFmts: colFmts,
  1763. colTyps: colTyps,
  1764. }
  1765. }
  1766. // parseEnviron tries to mimic some of libpq's environment handling
  1767. //
  1768. // To ease testing, it does not directly reference os.Environ, but is
  1769. // designed to accept its output.
  1770. //
  1771. // Environment-set connection information is intended to have a higher
  1772. // precedence than a library default but lower than any explicitly
  1773. // passed information (such as in the URL or connection string).
  1774. func parseEnviron(env []string) (out map[string]string) {
  1775. out = make(map[string]string)
  1776. for _, v := range env {
  1777. parts := strings.SplitN(v, "=", 2)
  1778. accrue := func(keyname string) {
  1779. out[keyname] = parts[1]
  1780. }
  1781. unsupported := func() {
  1782. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1783. }
  1784. // The order of these is the same as is seen in the
  1785. // PostgreSQL 9.1 manual. Unsupported but well-defined
  1786. // keys cause a panic; these should be unset prior to
  1787. // execution. Options which pq expects to be set to a
  1788. // certain value are allowed, but must be set to that
  1789. // value if present (they can, of course, be absent).
  1790. switch parts[0] {
  1791. case "PGHOST":
  1792. accrue("host")
  1793. case "PGHOSTADDR":
  1794. unsupported()
  1795. case "PGPORT":
  1796. accrue("port")
  1797. case "PGDATABASE":
  1798. accrue("dbname")
  1799. case "PGUSER":
  1800. accrue("user")
  1801. case "PGPASSWORD":
  1802. accrue("password")
  1803. case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1804. unsupported()
  1805. case "PGOPTIONS":
  1806. accrue("options")
  1807. case "PGAPPNAME":
  1808. accrue("application_name")
  1809. case "PGSSLMODE":
  1810. accrue("sslmode")
  1811. case "PGSSLCERT":
  1812. accrue("sslcert")
  1813. case "PGSSLKEY":
  1814. accrue("sslkey")
  1815. case "PGSSLROOTCERT":
  1816. accrue("sslrootcert")
  1817. case "PGSSLSNI":
  1818. accrue("sslsni")
  1819. case "PGREQUIRESSL", "PGSSLCRL":
  1820. unsupported()
  1821. case "PGREQUIREPEER":
  1822. unsupported()
  1823. case "PGKRBSRVNAME", "PGGSSLIB":
  1824. unsupported()
  1825. case "PGCONNECT_TIMEOUT":
  1826. accrue("connect_timeout")
  1827. case "PGCLIENTENCODING":
  1828. accrue("client_encoding")
  1829. case "PGDATESTYLE":
  1830. accrue("datestyle")
  1831. case "PGTZ":
  1832. accrue("timezone")
  1833. case "PGGEQO":
  1834. accrue("geqo")
  1835. case "PGSYSCONFDIR", "PGLOCALEDIR":
  1836. unsupported()
  1837. }
  1838. }
  1839. return out
  1840. }
  1841. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1842. func isUTF8(name string) bool {
  1843. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1844. s := strings.Map(alnumLowerASCII, name)
  1845. return s == "utf8" || s == "unicode"
  1846. }
  1847. func alnumLowerASCII(ch rune) rune {
  1848. if 'A' <= ch && ch <= 'Z' {
  1849. return ch + ('a' - 'A')
  1850. }
  1851. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1852. return ch
  1853. }
  1854. return -1 // discard
  1855. }
  1856. // The database/sql/driver package says:
  1857. // All Conn implementations should implement the following interfaces: Pinger, SessionResetter, and Validator.
  1858. var _ driver.Pinger = &conn{}
  1859. var _ driver.SessionResetter = &conn{}
// ResetSession implements driver.SessionResetter. It returns whatever error
// has been recorded on the connection through cn.err (for example
// driver.ErrBadConn, set after a protocol violation). This lets
// database/sql avoid reusing a connection known to be broken. ctx is not
// used.
func (cn *conn) ResetSession(ctx context.Context) error {
	// Ensure bad connections are reported: From database/sql/driver:
	// If a connection is never returned to the connection pool but immediately reused, then
	// ResetSession is called prior to reuse but IsValid is not called.
	return cn.err.get()
}
  1866. func (cn *conn) IsValid() bool {
  1867. return cn.err.get() == nil
  1868. }