tds.go 31 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326
  1. package mssql
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net"
  11. "net/url"
  12. "os"
  13. "sort"
  14. "strconv"
  15. "strings"
  16. "time"
  17. "unicode"
  18. "unicode/utf16"
  19. "unicode/utf8"
  20. "golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
  21. )
  22. func parseInstances(msg []byte) map[string]map[string]string {
  23. results := map[string]map[string]string{}
  24. if len(msg) > 3 && msg[0] == 5 {
  25. out_s := string(msg[3:])
  26. tokens := strings.Split(out_s, ";")
  27. instdict := map[string]string{}
  28. got_name := false
  29. var name string
  30. for _, token := range tokens {
  31. if got_name {
  32. instdict[name] = token
  33. got_name = false
  34. } else {
  35. name = token
  36. if len(name) == 0 {
  37. if len(instdict) == 0 {
  38. break
  39. }
  40. results[strings.ToUpper(instdict["InstanceName"])] = instdict
  41. instdict = map[string]string{}
  42. continue
  43. }
  44. got_name = true
  45. }
  46. }
  47. }
  48. return results
  49. }
  50. func getInstances(address string) (map[string]map[string]string, error) {
  51. conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
  52. if err != nil {
  53. return nil, err
  54. }
  55. defer conn.Close()
  56. conn.SetDeadline(time.Now().Add(5 * time.Second))
  57. _, err = conn.Write([]byte{3})
  58. if err != nil {
  59. return nil, err
  60. }
  61. var resp = make([]byte, 16*1024-1)
  62. read, err := conn.Read(resp)
  63. if err != nil {
  64. return nil, err
  65. }
  66. return parseInstances(resp[:read]), nil
  67. }
// tds versions
//
// TDS protocol version identifiers as 32-bit values. They are presumably
// what gets placed in login.TDSVersion at login time; confirm at the call
// site. verTDS73 is an alias for the revision-A encoding of TDS 7.3
// (verTDS73A); verTDS73B is the revision-B encoding.
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002
	verTDS73A    = 0x730A0003
	verTDS73     = verTDS73A
	verTDS73B    = 0x730B0003
	verTDS74     = 0x74000004
)
// packet types
// https://msdn.microsoft.com/en-us/library/dd304214.aspx
//
// Only packSQLBatch is declared with type packetType. Every other entry has
// an explicit "= value" and no type, so Go makes it an untyped integer
// constant. Such constants still convert implicitly wherever a packetType
// is expected.
const (
	packSQLBatch   packetType = 1
	packRPCRequest            = 3
	packReply                 = 4 // server response; readPrelogin checks for this value as the literal 4
	// 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
	// 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
	packAttention   = 6
	packBulkLoadBCP = 7
	packTransMgrReq = 14
	packNormal      = 15
	packLogin7      = 16 // used by sendLogin
	packSSPIMessage = 17
	packPrelogin    = 18 // used by writePrelogin
)
  95. // prelogin fields
  96. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  97. const (
  98. preloginVERSION = 0
  99. preloginENCRYPTION = 1
  100. preloginINSTOPT = 2
  101. preloginTHREADID = 3
  102. preloginMARS = 4
  103. preloginTRACEID = 5
  104. preloginTERMINATOR = 0xff
  105. )
  106. const (
  107. encryptOff = 0 // Encryption is available but off.
  108. encryptOn = 1 // Encryption is available and on.
  109. encryptNotSup = 2 // Encryption is not available.
  110. encryptReq = 3 // Encryption is required.
  111. )
// tdsSession holds the state of one TDS connection: the packet buffer used
// for both sending and receiving, plus values learned from the server while
// the session runs.
// NOTE(review): the meanings of most fields below are inferred from their
// names and types; confirm them against the code that populates them.
type tdsSession struct {
	buf          *tdsBuffer     // packet reader/writer for this connection
	loginAck     loginAckStruct // presumably the server's LOGINACK token contents
	database     string         // presumably the current database as reported by the server
	partner      string         // presumably the failover partner reported by the server
	columns      []columnStruct // presumably column metadata of the current result set
	tranid       uint64         // presumably the current transaction descriptor
	logFlags     uint64         // presumably a bitmask of the log* constants
	log          optionalLogger // logging destination
	routedServer string         // presumably the server a routing response redirects to
	routedPort   uint16         // port paired with routedServer
}
  124. const (
  125. logErrors = 1
  126. logMessages = 2
  127. logRows = 4
  128. logSQL = 8
  129. logParams = 16
  130. logTransaction = 32
  131. logDebug = 64
  132. )
// columnStruct describes one result-set column.
// NOTE(review): field meanings are inferred from names; presumably they are
// decoded from the server's column metadata — confirm at the parser.
type columnStruct struct {
	UserType uint32   // presumably the user-defined type id sent by the server
	Flags    uint16   // presumably column flag bits (nullability etc.)
	ColName  string   // column name
	ti       typeInfo // type information, presumably used to decode the column's values
}
  139. type KeySlice []uint8
  140. func (p KeySlice) Len() int { return len(p) }
  141. func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] }
  142. func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
  143. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  144. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  145. var err error
  146. w.BeginPacket(packPrelogin)
  147. offset := uint16(5*len(fields) + 1)
  148. keys := make(KeySlice, 0, len(fields))
  149. for k, _ := range fields {
  150. keys = append(keys, k)
  151. }
  152. sort.Sort(keys)
  153. // writing header
  154. for _, k := range keys {
  155. err = w.WriteByte(k)
  156. if err != nil {
  157. return err
  158. }
  159. err = binary.Write(w, binary.BigEndian, offset)
  160. if err != nil {
  161. return err
  162. }
  163. v := fields[k]
  164. size := uint16(len(v))
  165. err = binary.Write(w, binary.BigEndian, size)
  166. if err != nil {
  167. return err
  168. }
  169. offset += size
  170. }
  171. err = w.WriteByte(preloginTERMINATOR)
  172. if err != nil {
  173. return err
  174. }
  175. // writing values
  176. for _, k := range keys {
  177. v := fields[k]
  178. written, err := w.Write(v)
  179. if err != nil {
  180. return err
  181. }
  182. if written != len(v) {
  183. return errors.New("Write method didn't write the whole value")
  184. }
  185. }
  186. return w.FinishPacket()
  187. }
  188. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  189. packet_type, err := r.BeginRead()
  190. if err != nil {
  191. return nil, err
  192. }
  193. struct_buf, err := ioutil.ReadAll(r)
  194. if err != nil {
  195. return nil, err
  196. }
  197. if packet_type != 4 {
  198. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  199. }
  200. offset := 0
  201. results := map[uint8][]byte{}
  202. for true {
  203. rec_type := struct_buf[offset]
  204. if rec_type == preloginTERMINATOR {
  205. break
  206. }
  207. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  208. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  209. value := struct_buf[rec_offset : rec_offset+rec_len]
  210. results[rec_type] = value
  211. offset += 5
  212. }
  213. return results, nil
  214. }
  215. // OptionFlags2
  216. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  217. const (
  218. fLanguageFatal = 1
  219. fODBC = 2
  220. fTransBoundary = 4
  221. fCacheConnect = 8
  222. fIntSecurity = 0x80
  223. )
  224. // TypeFlags
  225. const (
  226. // 4 bits for fSQLType
  227. // 1 bit for fOLEDB
  228. fReadOnlyIntent = 32
  229. )
// login holds the client-supplied values of a LOGIN7 request. sendLogin
// turns it into a loginHeader followed by variable-length data. String
// fields are sent UTF-16LE encoded (str2ucs2), Password is obfuscated with
// manglePassword, and SSPI is sent as raw bytes.
type login struct {
	TDSVersion     uint32 // protocol version, presumably one of the verTDS* constants
	PacketSize     uint32
	ClientProgVer  uint32
	ClientPID      uint32
	ConnectionID   uint32
	OptionFlags1   uint8
	OptionFlags2   uint8 // bits from the OptionFlags2 constants (fODBC, fIntSecurity, ...)
	TypeFlags      uint8 // bits from the TypeFlags constants (fReadOnlyIntent)
	OptionFlags3   uint8
	ClientTimeZone int32
	ClientLCID     uint32
	HostName       string
	UserName       string
	Password       string
	AppName        string
	ServerName     string
	CtlIntName     string
	Language       string
	Database       string
	ClientID       [6]byte // copied verbatim into the header; presumably a client (MAC) id — confirm
	SSPI           []byte  // integrated-auth blob, written as raw bytes
	AtchDBFile     string
	ChangePassword string
}
// loginHeader is the fixed-length part of a LOGIN7 packet. sendLogin writes
// it with binary.Write in little-endian order, so the field order and sizes
// ARE the wire layout: do not reorder, add or remove fields.
//
// Each *Offset field is a byte offset from the start of this header to the
// matching variable-length data. The *Length fields count characters for
// string fields and bytes for SSPI (see sendLogin). sendLogin leaves
// ExtensionOffset/ExtensionLenght and SSPILongLength at zero.
// NOTE(review): "ExtensionLenght" is misspelled; left as-is to avoid
// renaming a field other code may reference.
type loginHeader struct {
	Length               uint32 // total packet length: header plus all variable data
	TDSVersion           uint32
	PacketSize           uint32
	ClientProgVer        uint32
	ClientPID            uint32
	ConnectionID         uint32
	OptionFlags1         uint8
	OptionFlags2         uint8
	TypeFlags            uint8
	OptionFlags3         uint8
	ClientTimeZone       int32
	ClientLCID           uint32
	HostNameOffset       uint16
	HostNameLength       uint16
	UserNameOffset       uint16
	UserNameLength       uint16
	PasswordOffset       uint16
	PasswordLength       uint16
	AppNameOffset        uint16
	AppNameLength        uint16
	ServerNameOffset     uint16
	ServerNameLength     uint16
	ExtensionOffset      uint16
	ExtensionLenght      uint16
	CtlIntNameOffset     uint16
	CtlIntNameLength     uint16
	LanguageOffset       uint16
	LanguageLength       uint16
	DatabaseOffset       uint16
	DatabaseLength       uint16
	ClientID             [6]byte
	SSPIOffset           uint16
	SSPILength           uint16 // in bytes, not characters
	AtchDBFileOffset     uint16
	AtchDBFileLength     uint16
	ChangePasswordOffset uint16
	ChangePasswordLength uint16
	SSPILongLength       uint32
}
  295. // convert Go string to UTF-16 encoded []byte (littleEndian)
  296. // done manually rather than using bytes and binary packages
  297. // for performance reasons
  298. func str2ucs2(s string) []byte {
  299. res := utf16.Encode([]rune(s))
  300. ucs2 := make([]byte, 2*len(res))
  301. for i := 0; i < len(res); i++ {
  302. ucs2[2*i] = byte(res[i])
  303. ucs2[2*i+1] = byte(res[i] >> 8)
  304. }
  305. return ucs2
  306. }
  307. func ucs22str(s []byte) (string, error) {
  308. if len(s)%2 != 0 {
  309. return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
  310. }
  311. buf := make([]uint16, len(s)/2)
  312. for i := 0; i < len(s); i += 2 {
  313. buf[i/2] = binary.LittleEndian.Uint16(s[i:])
  314. }
  315. return string(utf16.Decode(buf)), nil
  316. }
  317. func manglePassword(password string) []byte {
  318. var ucs2password []byte = str2ucs2(password)
  319. for i, ch := range ucs2password {
  320. ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
  321. }
  322. return ucs2password
  323. }
  324. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  325. func sendLogin(w *tdsBuffer, login login) error {
  326. w.BeginPacket(packLogin7)
  327. hostname := str2ucs2(login.HostName)
  328. username := str2ucs2(login.UserName)
  329. password := manglePassword(login.Password)
  330. appname := str2ucs2(login.AppName)
  331. servername := str2ucs2(login.ServerName)
  332. ctlintname := str2ucs2(login.CtlIntName)
  333. language := str2ucs2(login.Language)
  334. database := str2ucs2(login.Database)
  335. atchdbfile := str2ucs2(login.AtchDBFile)
  336. changepassword := str2ucs2(login.ChangePassword)
  337. hdr := loginHeader{
  338. TDSVersion: login.TDSVersion,
  339. PacketSize: login.PacketSize,
  340. ClientProgVer: login.ClientProgVer,
  341. ClientPID: login.ClientPID,
  342. ConnectionID: login.ConnectionID,
  343. OptionFlags1: login.OptionFlags1,
  344. OptionFlags2: login.OptionFlags2,
  345. TypeFlags: login.TypeFlags,
  346. OptionFlags3: login.OptionFlags3,
  347. ClientTimeZone: login.ClientTimeZone,
  348. ClientLCID: login.ClientLCID,
  349. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  350. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  351. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  352. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  353. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  354. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  355. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  356. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  357. ClientID: login.ClientID,
  358. SSPILength: uint16(len(login.SSPI)),
  359. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  360. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  361. }
  362. offset := uint16(binary.Size(hdr))
  363. hdr.HostNameOffset = offset
  364. offset += uint16(len(hostname))
  365. hdr.UserNameOffset = offset
  366. offset += uint16(len(username))
  367. hdr.PasswordOffset = offset
  368. offset += uint16(len(password))
  369. hdr.AppNameOffset = offset
  370. offset += uint16(len(appname))
  371. hdr.ServerNameOffset = offset
  372. offset += uint16(len(servername))
  373. hdr.CtlIntNameOffset = offset
  374. offset += uint16(len(ctlintname))
  375. hdr.LanguageOffset = offset
  376. offset += uint16(len(language))
  377. hdr.DatabaseOffset = offset
  378. offset += uint16(len(database))
  379. hdr.SSPIOffset = offset
  380. offset += uint16(len(login.SSPI))
  381. hdr.AtchDBFileOffset = offset
  382. offset += uint16(len(atchdbfile))
  383. hdr.ChangePasswordOffset = offset
  384. offset += uint16(len(changepassword))
  385. hdr.Length = uint32(offset)
  386. var err error
  387. err = binary.Write(w, binary.LittleEndian, &hdr)
  388. if err != nil {
  389. return err
  390. }
  391. _, err = w.Write(hostname)
  392. if err != nil {
  393. return err
  394. }
  395. _, err = w.Write(username)
  396. if err != nil {
  397. return err
  398. }
  399. _, err = w.Write(password)
  400. if err != nil {
  401. return err
  402. }
  403. _, err = w.Write(appname)
  404. if err != nil {
  405. return err
  406. }
  407. _, err = w.Write(servername)
  408. if err != nil {
  409. return err
  410. }
  411. _, err = w.Write(ctlintname)
  412. if err != nil {
  413. return err
  414. }
  415. _, err = w.Write(language)
  416. if err != nil {
  417. return err
  418. }
  419. _, err = w.Write(database)
  420. if err != nil {
  421. return err
  422. }
  423. _, err = w.Write(login.SSPI)
  424. if err != nil {
  425. return err
  426. }
  427. _, err = w.Write(atchdbfile)
  428. if err != nil {
  429. return err
  430. }
  431. _, err = w.Write(changepassword)
  432. if err != nil {
  433. return err
  434. }
  435. return w.FinishPacket()
  436. }
  437. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  438. buf := make([]byte, numchars*2)
  439. _, err = io.ReadFull(r, buf)
  440. if err != nil {
  441. return "", err
  442. }
  443. return ucs22str(buf)
  444. }
  445. func readUsVarChar(r io.Reader) (res string, err error) {
  446. var numchars uint16
  447. err = binary.Read(r, binary.LittleEndian, &numchars)
  448. if err != nil {
  449. return "", err
  450. }
  451. return readUcs2(r, int(numchars))
  452. }
  453. func writeUsVarChar(w io.Writer, s string) (err error) {
  454. buf := str2ucs2(s)
  455. var numchars int = len(buf) / 2
  456. if numchars > 0xffff {
  457. panic("invalid size for US_VARCHAR")
  458. }
  459. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  460. if err != nil {
  461. return
  462. }
  463. _, err = w.Write(buf)
  464. return
  465. }
  466. func readBVarChar(r io.Reader) (res string, err error) {
  467. var numchars uint8
  468. err = binary.Read(r, binary.LittleEndian, &numchars)
  469. if err != nil {
  470. return "", err
  471. }
  472. return readUcs2(r, int(numchars))
  473. }
  474. func writeBVarChar(w io.Writer, s string) (err error) {
  475. buf := str2ucs2(s)
  476. var numchars int = len(buf) / 2
  477. if numchars > 0xff {
  478. panic("invalid size for B_VARCHAR")
  479. }
  480. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  481. if err != nil {
  482. return
  483. }
  484. _, err = w.Write(buf)
  485. return
  486. }
  487. func readBVarByte(r io.Reader) (res []byte, err error) {
  488. var length uint8
  489. err = binary.Read(r, binary.LittleEndian, &length)
  490. if err != nil {
  491. return
  492. }
  493. res = make([]byte, length)
  494. _, err = io.ReadFull(r, res)
  495. return
  496. }
  497. func readUshort(r io.Reader) (res uint16, err error) {
  498. err = binary.Read(r, binary.LittleEndian, &res)
  499. return
  500. }
  501. func readByte(r io.Reader) (res byte, err error) {
  502. var b [1]byte
  503. _, err = r.Read(b[:])
  504. res = b[0]
  505. return
  506. }
// Packet Data Stream Headers
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
//
// headerStruct is one entry of the ALL_HEADERS section written by
// writeAllHeaders. writeAllHeaders adds the 4-byte length prefix itself.
type headerStruct struct {
	hdrtype uint16 // header type, presumably one of the dataStmHdr* constants
	data    []byte // already-serialized payload, e.g. from queryNotifHdr.pack or transDescrHdr.pack
}
  513. const (
  514. dataStmHdrQueryNotif = 1 // query notifications
  515. dataStmHdrTransDescr = 2 // MARS transaction descriptor (required)
  516. dataStmHdrTraceActivity = 3
  517. )
  518. // Query Notifications Header
  519. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  520. type queryNotifHdr struct {
  521. notifyId string
  522. ssbDeployment string
  523. notifyTimeout uint32
  524. }
  525. func (hdr queryNotifHdr) pack() (res []byte) {
  526. notifyId := str2ucs2(hdr.notifyId)
  527. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  528. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  529. b := res
  530. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  531. b = b[2:]
  532. copy(b, notifyId)
  533. b = b[len(notifyId):]
  534. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  535. b = b[2:]
  536. copy(b, ssbDeployment)
  537. b = b[len(ssbDeployment):]
  538. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  539. return res
  540. }
  541. // MARS Transaction Descriptor Header
  542. // http://msdn.microsoft.com/en-us/library/dd340515.aspx
  543. type transDescrHdr struct {
  544. transDescr uint64 // transaction descriptor returned from ENVCHANGE
  545. outstandingReqCnt uint32 // outstanding request count
  546. }
  547. func (hdr transDescrHdr) pack() (res []byte) {
  548. res = make([]byte, 8+4)
  549. binary.LittleEndian.PutUint64(res, hdr.transDescr)
  550. binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)
  551. return res
  552. }
  553. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  554. // calculatint total length
  555. var totallen uint32 = 4
  556. for _, hdr := range headers {
  557. totallen += 4 + 2 + uint32(len(hdr.data))
  558. }
  559. // writing
  560. err = binary.Write(w, binary.LittleEndian, totallen)
  561. if err != nil {
  562. return err
  563. }
  564. for _, hdr := range headers {
  565. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  566. err = binary.Write(w, binary.LittleEndian, headerlen)
  567. if err != nil {
  568. return err
  569. }
  570. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  571. if err != nil {
  572. return err
  573. }
  574. _, err = w.Write(hdr.data)
  575. if err != nil {
  576. return err
  577. }
  578. }
  579. return nil
  580. }
  581. func sendSqlBatch72(buf *tdsBuffer,
  582. sqltext string,
  583. headers []headerStruct) (err error) {
  584. buf.BeginPacket(packSQLBatch)
  585. if err = writeAllHeaders(buf, headers); err != nil {
  586. return
  587. }
  588. _, err = buf.Write(str2ucs2(sqltext))
  589. if err != nil {
  590. return
  591. }
  592. return buf.FinishPacket()
  593. }
  594. // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
  595. // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
  596. func sendAttention(buf *tdsBuffer) error {
  597. buf.BeginPacket(packAttention)
  598. return buf.FinishPacket()
  599. }
// connectParams is the parsed form of a connection string. parseConnectParams
// fills each field from the connection-string key noted in its comment.
type connectParams struct {
	logFlags               uint64        // "log": bitmask of log* flags
	port                   uint64        // "port"; defaults to 1433
	host                   string        // "server" before the first '\'; ".", "(local)" and "" become "localhost"
	instance               string        // "server" after the first '\', if any
	database               string        // "database"
	user                   string        // "user id"
	password               string        // "password"
	dial_timeout           time.Duration // "dial timeout" in seconds; defaults to 15s
	conn_timeout           time.Duration // "connection timeout" in seconds; defaults to 30s
	keepAlive              time.Duration // "keepalive" in seconds; defaults to 30s
	encrypt                bool          // "encrypt", parsed with strconv.ParseBool
	disableEncryption      bool          // set when "encrypt" is "disable" (case-insensitive)
	trustServerCertificate bool          // "trustservercertificate"; defaults to true only when "encrypt" is absent
	certificate            string        // "certificate" (presumably a CA certificate file path — confirm at use site)
	hostInCertificate      string        // "hostnameincertificate"; defaults to host
	serverSPN              string        // "serverspn"; defaults to "MSSQLSvc/<host>:<port>"
	workstation            string        // "workstation id"; defaults to os.Hostname()
	appname                string        // "app name"; defaults to "go-mssqldb"
	typeFlags              uint8         // LOGIN7 TypeFlags; fReadOnlyIntent when "applicationintent" is "ReadOnly"
	failOverPartner        string        // "failoverpartner"
	failOverPort           uint64        // "failoverport"
}
  623. func splitConnectionString(dsn string) (res map[string]string) {
  624. res = map[string]string{}
  625. parts := strings.Split(dsn, ";")
  626. for _, part := range parts {
  627. if len(part) == 0 {
  628. continue
  629. }
  630. lst := strings.SplitN(part, "=", 2)
  631. name := strings.TrimSpace(strings.ToLower(lst[0]))
  632. if len(name) == 0 {
  633. continue
  634. }
  635. var value string = ""
  636. if len(lst) > 1 {
  637. value = strings.TrimSpace(lst[1])
  638. }
  639. res[name] = value
  640. }
  641. return res
  642. }
// Splits a URL in the ODBC format
//
// splitConnectionStringOdbc parses an ODBC-style connection string such as
// "key=value;key={braced value}" into a map from normalized key to value.
// The rules implemented by the state machine below are:
//   - Whitespace before a key is skipped, and '=' where a key should start
//     is an error. Keys are normalized with normalizeOdbcKey (trailing
//     whitespace removed, lower-cased).
//   - A key followed by ';' or end of input, or by '=' and then ';' or end
//     of input, maps to "".
//   - A bare value has leading whitespace skipped and trailing whitespace
//     trimmed, and ends at the next ';'.
//   - A braced value is taken verbatim, with "}}" standing for a literal
//     '}'. After its closing brace only whitespace or ';' may follow, and an
//     unterminated braced value is an error.
//   - A later duplicate key overwrites an earlier one.
//
// Error messages report the byte index into dsn of the offending character
// (or len(dsn) for end-of-input errors). On error the partially filled map
// is returned together with the error.
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
	res := map[string]string{}

	type parserState int
	const (
		// Before the start of a key
		parserStateBeforeKey parserState = iota

		// Inside a key
		parserStateKey

		// Beginning of a value. May be bare or braced
		parserStateBeginValue

		// Inside a bare value
		parserStateBareValue

		// Inside a braced value
		parserStateBracedValue

		// A closing brace inside a braced value.
		// May be the end of the value or an escaped closing brace, depending on the next character
		parserStateBracedValueClosingBrace

		// After a value. Next character should be a semi-colon or whitespace.
		parserStateEndValue
	)

	var state = parserStateBeforeKey

	// key and value accumulate the current pair; both are reset to "" once
	// the pair is stored in res.
	var key string
	var value string

	// i is a byte offset into dsn; c is the decoded rune at that offset.
	for i, c := range dsn {
		switch state {
		case parserStateBeforeKey:
			switch {
			case c == '=':
				return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
			case !unicode.IsSpace(c) && c != ';':
				state = parserStateKey
				key += string(c)
			}

		case parserStateKey:
			switch c {
			case '=':
				key = normalizeOdbcKey(key)
				if len(key) == 0 {
					return res, fmt.Errorf("Unexpected end of key at index %d.", i)
				}
				state = parserStateBeginValue

			case ';':
				// Key without value
				key = normalizeOdbcKey(key)
				if len(key) == 0 {
					return res, fmt.Errorf("Unexpected end of key at index %d.", i)
				}
				res[key] = value
				key = ""
				value = ""
				state = parserStateBeforeKey

			default:
				key += string(c)
			}

		case parserStateBeginValue:
			switch {
			case c == '{':
				state = parserStateBracedValue
			case c == ';':
				// Empty value (value is still "" here)
				res[key] = value
				key = ""
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				state = parserStateBareValue
				value += string(c)
			}

		case parserStateBareValue:
			if c == ';' {
				res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
				key = ""
				value = ""
				state = parserStateBeforeKey
			} else {
				value += string(c)
			}

		case parserStateBracedValue:
			if c == '}' {
				state = parserStateBracedValueClosingBrace
			} else {
				value += string(c)
			}

		case parserStateBracedValueClosingBrace:
			if c == '}' {
				// Escaped closing brace
				value += string(c)
				state = parserStateBracedValue
				continue
			}

			// End of braced value
			res[key] = value
			key = ""
			value = ""

			// This character is the first character past the end,
			// so it needs to be parsed like the parserStateEndValue state.
			state = parserStateEndValue
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}

		case parserStateEndValue:
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Ignore whitespace
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}
		}
	}

	// End of input: flush whatever pair was in progress.
	switch state {
	case parserStateBeforeKey: // Okay
	case parserStateKey: // Unfinished key. Treat as key without value.
		key = normalizeOdbcKey(key)
		if len(key) == 0 {
			return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
		}
		res[key] = value
	case parserStateBeginValue: // Empty value
		res[key] = value
	case parserStateBareValue:
		res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
	case parserStateBracedValue:
		return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
	case parserStateBracedValueClosingBrace: // End of braced value
		res[key] = value
	case parserStateEndValue: // Okay
	}

	return res, nil
}
  781. // Normalizes the given string as an ODBC-format key
  782. func normalizeOdbcKey(s string) string {
  783. return strings.ToLower(strings.TrimRightFunc(s, unicode.IsSpace))
  784. }
  785. // Splits a URL of the form sqlserver://username:password@host/instance?param1=value&param2=value
  786. func splitConnectionStringURL(dsn string) (map[string]string, error) {
  787. res := map[string]string{}
  788. u, err := url.Parse(dsn)
  789. if err != nil {
  790. return res, err
  791. }
  792. if u.Scheme != "sqlserver" {
  793. return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
  794. }
  795. if u.User != nil {
  796. res["user id"] = u.User.Username()
  797. p, exists := u.User.Password()
  798. if exists {
  799. res["password"] = p
  800. }
  801. }
  802. host, port, err := net.SplitHostPort(u.Host)
  803. if err != nil {
  804. host = u.Host
  805. }
  806. if len(u.Path) > 0 {
  807. res["server"] = host + "\\" + u.Path[1:]
  808. } else {
  809. res["server"] = host
  810. }
  811. if len(port) > 0 {
  812. res["port"] = port
  813. }
  814. query := u.Query()
  815. for k, v := range query {
  816. if len(v) > 1 {
  817. return res, fmt.Errorf("key %s provided more than once", k)
  818. }
  819. res[k] = v[0]
  820. }
  821. return res, nil
  822. }
  823. func parseConnectParams(dsn string) (connectParams, error) {
  824. var p connectParams
  825. var params map[string]string
  826. if strings.HasPrefix(dsn, "odbc:") {
  827. parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
  828. if err != nil {
  829. return p, err
  830. }
  831. params = parameters
  832. } else if strings.HasPrefix(dsn, "sqlserver://") {
  833. parameters, err := splitConnectionStringURL(dsn)
  834. if err != nil {
  835. return p, err
  836. }
  837. params = parameters
  838. } else {
  839. params = splitConnectionString(dsn)
  840. }
  841. strlog, ok := params["log"]
  842. if ok {
  843. var err error
  844. p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
  845. if err != nil {
  846. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  847. }
  848. }
  849. server := params["server"]
  850. parts := strings.SplitN(server, "\\", 2)
  851. p.host = parts[0]
  852. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  853. p.host = "localhost"
  854. }
  855. if len(parts) > 1 {
  856. p.instance = parts[1]
  857. }
  858. p.database = params["database"]
  859. p.user = params["user id"]
  860. p.password = params["password"]
  861. p.port = 1433
  862. strport, ok := params["port"]
  863. if ok {
  864. var err error
  865. p.port, err = strconv.ParseUint(strport, 0, 16)
  866. if err != nil {
  867. f := "Invalid tcp port '%v': %v"
  868. return p, fmt.Errorf(f, strport, err.Error())
  869. }
  870. }
  871. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  872. p.dial_timeout = 15 * time.Second
  873. p.conn_timeout = 30 * time.Second
  874. strconntimeout, ok := params["connection timeout"]
  875. if ok {
  876. timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
  877. if err != nil {
  878. f := "Invalid connection timeout '%v': %v"
  879. return p, fmt.Errorf(f, strconntimeout, err.Error())
  880. }
  881. p.conn_timeout = time.Duration(timeout) * time.Second
  882. }
  883. strdialtimeout, ok := params["dial timeout"]
  884. if ok {
  885. timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
  886. if err != nil {
  887. f := "Invalid dial timeout '%v': %v"
  888. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  889. }
  890. p.dial_timeout = time.Duration(timeout) * time.Second
  891. }
  892. // default keep alive should be 30 seconds according to spec:
  893. // https://msdn.microsoft.com/en-us/library/dd341108.aspx
  894. p.keepAlive = 30 * time.Second
  895. keepAlive, ok := params["keepalive"]
  896. if ok {
  897. timeout, err := strconv.ParseUint(keepAlive, 0, 16)
  898. if err != nil {
  899. f := "Invalid keepAlive value '%s': %s"
  900. return p, fmt.Errorf(f, keepAlive, err.Error())
  901. }
  902. p.keepAlive = time.Duration(timeout) * time.Second
  903. }
  904. encrypt, ok := params["encrypt"]
  905. if ok {
  906. if strings.ToUpper(encrypt) == "DISABLE" {
  907. p.disableEncryption = true
  908. } else {
  909. var err error
  910. p.encrypt, err = strconv.ParseBool(encrypt)
  911. if err != nil {
  912. f := "Invalid encrypt '%s': %s"
  913. return p, fmt.Errorf(f, encrypt, err.Error())
  914. }
  915. }
  916. } else {
  917. p.trustServerCertificate = true
  918. }
  919. trust, ok := params["trustservercertificate"]
  920. if ok {
  921. var err error
  922. p.trustServerCertificate, err = strconv.ParseBool(trust)
  923. if err != nil {
  924. f := "Invalid trust server certificate '%s': %s"
  925. return p, fmt.Errorf(f, trust, err.Error())
  926. }
  927. }
  928. p.certificate = params["certificate"]
  929. p.hostInCertificate, ok = params["hostnameincertificate"]
  930. if !ok {
  931. p.hostInCertificate = p.host
  932. }
  933. serverSPN, ok := params["serverspn"]
  934. if ok {
  935. p.serverSPN = serverSPN
  936. } else {
  937. p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
  938. }
  939. workstation, ok := params["workstation id"]
  940. if ok {
  941. p.workstation = workstation
  942. } else {
  943. workstation, err := os.Hostname()
  944. if err == nil {
  945. p.workstation = workstation
  946. }
  947. }
  948. appname, ok := params["app name"]
  949. if !ok {
  950. appname = "go-mssqldb"
  951. }
  952. p.appname = appname
  953. appintent, ok := params["applicationintent"]
  954. if ok {
  955. if appintent == "ReadOnly" {
  956. p.typeFlags |= fReadOnlyIntent
  957. }
  958. }
  959. failOverPartner, ok := params["failoverpartner"]
  960. if ok {
  961. p.failOverPartner = failOverPartner
  962. }
  963. failOverPort, ok := params["failoverport"]
  964. if ok {
  965. var err error
  966. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  967. if err != nil {
  968. f := "Invalid tcp port '%v': %v"
  969. return p, fmt.Errorf(f, failOverPort, err.Error())
  970. }
  971. }
  972. return p, nil
  973. }
  // Auth is the contract for an integrated-security (SSPI-style) login
// exchange. connect places the result of InitialBytes in the login packet,
// feeds every SSPI token the server sends back to NextBytes (sending a reply
// packet whenever a non-nil slice is returned), and defers Free once the
// exchange has started.
type Auth interface {
	// InitialBytes produces the first token carried in the login packet.
	InitialBytes() (token []byte, err error)
	// NextBytes consumes one server SSPI token and returns the reply to
	// send; a nil reply means nothing more needs to be sent.
	NextBytes(serverToken []byte) (reply []byte, err error)
	// Free is deferred by connect after InitialBytes succeeds; presumably
	// implementations release their security context here.
	Free()
}
  979. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  980. // list of IP addresses. So if there is more than one, try them all and
  981. // use the first one that allows a connection.
  982. func dialConnection(p connectParams) (conn net.Conn, err error) {
  983. var ips []net.IP
  984. ips, err = net.LookupIP(p.host)
  985. if err != nil {
  986. ip := net.ParseIP(p.host)
  987. if ip == nil {
  988. return nil, err
  989. }
  990. ips = []net.IP{ip}
  991. }
  992. if len(ips) == 1 {
  993. d := createDialer(&p)
  994. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
  995. conn, err = d.Dial(addr)
  996. } else {
  997. //Try Dials in parallel to avoid waiting for timeouts.
  998. connChan := make(chan net.Conn, len(ips))
  999. errChan := make(chan error, len(ips))
  1000. portStr := strconv.Itoa(int(p.port))
  1001. for _, ip := range ips {
  1002. go func(ip net.IP) {
  1003. d := createDialer(&p)
  1004. addr := net.JoinHostPort(ip.String(), portStr)
  1005. conn, err := d.Dial(addr)
  1006. if err == nil {
  1007. connChan <- conn
  1008. } else {
  1009. errChan <- err
  1010. }
  1011. }(ip)
  1012. }
  1013. // Wait for either the *first* successful connection, or all the errors
  1014. wait_loop:
  1015. for i, _ := range ips {
  1016. select {
  1017. case conn = <-connChan:
  1018. // Got a connection to use, close any others
  1019. go func(n int) {
  1020. for i := 0; i < n; i++ {
  1021. select {
  1022. case conn := <-connChan:
  1023. conn.Close()
  1024. case <-errChan:
  1025. }
  1026. }
  1027. }(len(ips) - i - 1)
  1028. // Remove any earlier errors we may have collected
  1029. err = nil
  1030. break wait_loop
  1031. case err = <-errChan:
  1032. }
  1033. }
  1034. }
  1035. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  1036. if conn == nil {
  1037. f := "Unable to open tcp connection with host '%v:%v': %v"
  1038. return nil, fmt.Errorf(f, p.host, p.port, err.Error())
  1039. }
  1040. return conn, err
  1041. }
  1042. func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {
  1043. res = nil
  1044. // if instance is specified use instance resolution service
  1045. if p.instance != "" {
  1046. p.instance = strings.ToUpper(p.instance)
  1047. instances, err := getInstances(p.host)
  1048. if err != nil {
  1049. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  1050. return nil, fmt.Errorf(f, p.host, err.Error())
  1051. }
  1052. strport, ok := instances[p.instance]["tcp"]
  1053. if !ok {
  1054. f := "No instance matching '%v' returned from host '%v'"
  1055. return nil, fmt.Errorf(f, p.instance, p.host)
  1056. }
  1057. p.port, err = strconv.ParseUint(strport, 0, 16)
  1058. if err != nil {
  1059. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  1060. return nil, fmt.Errorf(f, strport, err.Error())
  1061. }
  1062. }
  1063. initiate_connection:
  1064. conn, err := dialConnection(p)
  1065. if err != nil {
  1066. return nil, err
  1067. }
  1068. toconn := NewTimeoutConn(conn, p.conn_timeout)
  1069. outbuf := newTdsBuffer(4096, toconn)
  1070. sess := tdsSession{
  1071. buf: outbuf,
  1072. log: log,
  1073. logFlags: p.logFlags,
  1074. }
  1075. instance_buf := []byte(p.instance)
  1076. instance_buf = append(instance_buf, 0) // zero terminate instance name
  1077. var encrypt byte
  1078. if p.disableEncryption {
  1079. encrypt = encryptNotSup
  1080. } else if p.encrypt {
  1081. encrypt = encryptOn
  1082. } else {
  1083. encrypt = encryptOff
  1084. }
  1085. fields := map[uint8][]byte{
  1086. preloginVERSION: {0, 0, 0, 0, 0, 0},
  1087. preloginENCRYPTION: {encrypt},
  1088. preloginINSTOPT: instance_buf,
  1089. preloginTHREADID: {0, 0, 0, 0},
  1090. preloginMARS: {0}, // MARS disabled
  1091. }
  1092. err = writePrelogin(outbuf, fields)
  1093. if err != nil {
  1094. return nil, err
  1095. }
  1096. fields, err = readPrelogin(outbuf)
  1097. if err != nil {
  1098. return nil, err
  1099. }
  1100. encryptBytes, ok := fields[preloginENCRYPTION]
  1101. if !ok {
  1102. return nil, fmt.Errorf("Encrypt negotiation failed")
  1103. }
  1104. encrypt = encryptBytes[0]
  1105. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  1106. return nil, fmt.Errorf("Server does not support encryption")
  1107. }
  1108. if encrypt != encryptNotSup {
  1109. var config tls.Config
  1110. if p.certificate != "" {
  1111. pem, err := ioutil.ReadFile(p.certificate)
  1112. if err != nil {
  1113. f := "Cannot read certificate '%s': %s"
  1114. return nil, fmt.Errorf(f, p.certificate, err.Error())
  1115. }
  1116. certs := x509.NewCertPool()
  1117. certs.AppendCertsFromPEM(pem)
  1118. config.RootCAs = certs
  1119. }
  1120. if p.trustServerCertificate {
  1121. config.InsecureSkipVerify = true
  1122. }
  1123. config.ServerName = p.hostInCertificate
  1124. outbuf.transport = conn
  1125. toconn.buf = outbuf
  1126. tlsConn := tls.Client(toconn, &config)
  1127. err = tlsConn.Handshake()
  1128. toconn.buf = nil
  1129. outbuf.transport = tlsConn
  1130. if err != nil {
  1131. f := "TLS Handshake failed: %s"
  1132. return nil, fmt.Errorf(f, err.Error())
  1133. }
  1134. if encrypt == encryptOff {
  1135. outbuf.afterFirst = func() {
  1136. outbuf.transport = toconn
  1137. }
  1138. }
  1139. }
  1140. login := login{
  1141. TDSVersion: verTDS74,
  1142. PacketSize: outbuf.PackageSize(),
  1143. Database: p.database,
  1144. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  1145. HostName: p.workstation,
  1146. ServerName: p.host,
  1147. AppName: p.appname,
  1148. TypeFlags: p.typeFlags,
  1149. }
  1150. auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  1151. if auth_ok {
  1152. login.SSPI, err = auth.InitialBytes()
  1153. if err != nil {
  1154. return nil, err
  1155. }
  1156. login.OptionFlags2 |= fIntSecurity
  1157. defer auth.Free()
  1158. } else {
  1159. login.UserName = p.user
  1160. login.Password = p.password
  1161. }
  1162. err = sendLogin(outbuf, login)
  1163. if err != nil {
  1164. return nil, err
  1165. }
  1166. // processing login response
  1167. var sspi_msg []byte
  1168. continue_login:
  1169. tokchan := make(chan tokenStruct, 5)
  1170. go processResponse(context.Background(), &sess, tokchan)
  1171. success := false
  1172. for tok := range tokchan {
  1173. switch token := tok.(type) {
  1174. case sspiMsg:
  1175. sspi_msg, err = auth.NextBytes(token)
  1176. if err != nil {
  1177. return nil, err
  1178. }
  1179. case loginAckStruct:
  1180. success = true
  1181. sess.loginAck = token
  1182. case error:
  1183. return nil, fmt.Errorf("Login error: %s", token.Error())
  1184. }
  1185. }
  1186. if sspi_msg != nil {
  1187. outbuf.BeginPacket(packSSPIMessage)
  1188. _, err = outbuf.Write(sspi_msg)
  1189. if err != nil {
  1190. return nil, err
  1191. }
  1192. err = outbuf.FinishPacket()
  1193. if err != nil {
  1194. return nil, err
  1195. }
  1196. sspi_msg = nil
  1197. goto continue_login
  1198. }
  1199. if !success {
  1200. return nil, fmt.Errorf("Login failed")
  1201. }
  1202. if sess.routedServer != "" {
  1203. toconn.Close()
  1204. p.host = sess.routedServer
  1205. p.port = uint64(sess.routedPort)
  1206. goto initiate_connection
  1207. }
  1208. return &sess, nil
  1209. }