buf.go 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. package mssql
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "io"
  6. )
  7. type packetType uint8
  8. type header struct {
  9. PacketType packetType
  10. Status uint8
  11. Size uint16
  12. Spid uint16
  13. PacketNo uint8
  14. Pad uint8
  15. }
  16. // tdsBuffer reads and writes TDS packets of data to the transport.
  17. // The write and read buffers are spearate to make sending attn signals
  18. // possible without locks. Currently attn signals are only sent during
  19. // reads, not writes.
  20. type tdsBuffer struct {
  21. transport io.ReadWriteCloser
  22. // Write fields.
  23. wbuf []byte
  24. wpos uint16
  25. // Read fields.
  26. rbuf []byte
  27. rpos uint16
  28. rsize uint16
  29. final bool
  30. packet_type packetType
  31. // afterFirst is assigned to right after tdsBuffer is created and
  32. // before the first use. It is executed after the first packet is
  33. // writen and then removed.
  34. afterFirst func()
  35. }
  36. func newTdsBuffer(bufsize int, transport io.ReadWriteCloser) *tdsBuffer {
  37. w := new(tdsBuffer)
  38. w.wbuf = make([]byte, bufsize)
  39. w.rbuf = make([]byte, bufsize)
  40. w.wpos = 0
  41. w.rpos = 8
  42. w.transport = transport
  43. return w
  44. }
  45. func (rw *tdsBuffer) ResizeBuffer(packetsizei int) {
  46. if len(rw.rbuf) != packetsizei {
  47. newbuf := make([]byte, packetsizei)
  48. copy(newbuf, rw.rbuf)
  49. rw.rbuf = newbuf
  50. }
  51. if len(rw.wbuf) != packetsizei {
  52. newbuf := make([]byte, packetsizei)
  53. copy(newbuf, rw.wbuf)
  54. rw.wbuf = newbuf
  55. }
  56. }
  57. func (w *tdsBuffer) PackageSize() uint32 {
  58. return uint32(len(w.wbuf))
  59. }
  60. func (w *tdsBuffer) flush() (err error) {
  61. // writing packet size
  62. binary.BigEndian.PutUint16(w.wbuf[2:], w.wpos)
  63. // writing packet into underlying transport
  64. if _, err = w.transport.Write(w.wbuf[:w.wpos]); err != nil {
  65. return err
  66. }
  67. // execute afterFirst hook if it is set
  68. if w.afterFirst != nil {
  69. w.afterFirst()
  70. w.afterFirst = nil
  71. }
  72. w.wpos = 8
  73. // packet number
  74. w.wbuf[6] += 1
  75. return nil
  76. }
  77. func (w *tdsBuffer) Write(p []byte) (total int, err error) {
  78. total = 0
  79. for {
  80. copied := copy(w.wbuf[w.wpos:], p)
  81. w.wpos += uint16(copied)
  82. total += copied
  83. if copied == len(p) {
  84. break
  85. }
  86. if err = w.flush(); err != nil {
  87. return
  88. }
  89. p = p[copied:]
  90. }
  91. return
  92. }
  93. func (w *tdsBuffer) WriteByte(b byte) error {
  94. if int(w.wpos) == len(w.wbuf) {
  95. if err := w.flush(); err != nil {
  96. return err
  97. }
  98. }
  99. w.wbuf[w.wpos] = b
  100. w.wpos += 1
  101. return nil
  102. }
  103. func (w *tdsBuffer) BeginPacket(packet_type packetType) {
  104. w.wbuf[0] = byte(packet_type)
  105. w.wbuf[1] = 0 // packet is incomplete
  106. w.wbuf[4] = 0 // spid
  107. w.wbuf[5] = 0
  108. w.wbuf[6] = 1 // packet id
  109. w.wbuf[7] = 0 // window
  110. w.wpos = 8
  111. }
  112. func (w *tdsBuffer) FinishPacket() error {
  113. w.wbuf[1] = 1 // this is last packet
  114. return w.flush()
  115. }
  116. func (r *tdsBuffer) readNextPacket() error {
  117. header := header{}
  118. var err error
  119. err = binary.Read(r.transport, binary.BigEndian, &header)
  120. if err != nil {
  121. return err
  122. }
  123. offset := uint16(binary.Size(header))
  124. if int(header.Size) > len(r.rbuf) {
  125. return errors.New("Invalid packet size, it is longer than buffer size")
  126. }
  127. if int(offset) > int(header.Size) {
  128. return errors.New("Invalid packet size, it is shorter than header size")
  129. }
  130. _, err = io.ReadFull(r.transport, r.rbuf[offset:header.Size])
  131. if err != nil {
  132. return err
  133. }
  134. r.rpos = offset
  135. r.rsize = header.Size
  136. r.final = header.Status != 0
  137. r.packet_type = header.PacketType
  138. return nil
  139. }
  140. func (r *tdsBuffer) BeginRead() (packetType, error) {
  141. err := r.readNextPacket()
  142. if err != nil {
  143. return 0, err
  144. }
  145. return r.packet_type, nil
  146. }
  147. func (r *tdsBuffer) ReadByte() (res byte, err error) {
  148. if r.rpos == r.rsize {
  149. if r.final {
  150. return 0, io.EOF
  151. }
  152. err = r.readNextPacket()
  153. if err != nil {
  154. return 0, err
  155. }
  156. }
  157. res = r.rbuf[r.rpos]
  158. r.rpos++
  159. return res, nil
  160. }
  161. func (r *tdsBuffer) byte() byte {
  162. b, err := r.ReadByte()
  163. if err != nil {
  164. badStreamPanic(err)
  165. }
  166. return b
  167. }
  168. func (r *tdsBuffer) ReadFull(buf []byte) {
  169. _, err := io.ReadFull(r, buf[:])
  170. if err != nil {
  171. badStreamPanic(err)
  172. }
  173. }
  174. func (r *tdsBuffer) uint64() uint64 {
  175. var buf [8]byte
  176. r.ReadFull(buf[:])
  177. return binary.LittleEndian.Uint64(buf[:])
  178. }
  179. func (r *tdsBuffer) int32() int32 {
  180. return int32(r.uint32())
  181. }
  182. func (r *tdsBuffer) uint32() uint32 {
  183. var buf [4]byte
  184. r.ReadFull(buf[:])
  185. return binary.LittleEndian.Uint32(buf[:])
  186. }
  187. func (r *tdsBuffer) uint16() uint16 {
  188. var buf [2]byte
  189. r.ReadFull(buf[:])
  190. return binary.LittleEndian.Uint16(buf[:])
  191. }
  192. func (r *tdsBuffer) BVarChar() string {
  193. l := int(r.byte())
  194. return r.readUcs2(l)
  195. }
  196. func (r *tdsBuffer) UsVarChar() string {
  197. l := int(r.uint16())
  198. return r.readUcs2(l)
  199. }
  200. func (r *tdsBuffer) readUcs2(numchars int) string {
  201. b := make([]byte, numchars*2)
  202. r.ReadFull(b)
  203. res, err := ucs22str(b)
  204. if err != nil {
  205. badStreamPanic(err)
  206. }
  207. return res
  208. }
  209. func (r *tdsBuffer) Read(buf []byte) (copied int, err error) {
  210. copied = 0
  211. err = nil
  212. if r.rpos == r.rsize {
  213. if r.final {
  214. return 0, io.EOF
  215. }
  216. err = r.readNextPacket()
  217. if err != nil {
  218. return
  219. }
  220. }
  221. copied = copy(buf, r.rbuf[r.rpos:r.rsize])
  222. r.rpos += uint16(copied)
  223. return
  224. }