stream.go

package gfsmux // import "go.gridfinity.dev/gfsmux"

import (
	"encoding/binary"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

// Stream implements net.Conn.
type Stream struct {
	id   uint32
	sess *Session

	buffers [][]byte
	heads   [][]byte // slice heads kept for recycling

	bufferLock sync.Mutex
	frameSize  int

	// notify a read event
	chReadEvent chan struct{}

	// flag that the stream has closed
	die     chan struct{}
	dieOnce sync.Once

	// FIN command
	chFinEvent   chan struct{}
	finEventOnce sync.Once

	// deadlines
	readDeadline  atomic.Value
	writeDeadline atomic.Value

	// per-stream sliding window control
	numRead    uint32 // number of bytes consumed
	numWritten uint32 // number of bytes written
	incr       uint32 // bytes consumed since the last window update was sent

	// UPD command
	peerConsumed uint32        // number of bytes the peer has consumed
	peerWindow   uint32        // peer window, initialized to 256KB, updated by the peer
	chUpdate     chan struct{} // notifies of remote data consumption and window updates
}

// newStream initializes and returns a new Stream.
func newStream(id uint32, frameSize int, sess *Session) *Stream {
	s := new(Stream)
	s.id = id
	s.chReadEvent = make(chan struct{}, 1)
	s.chUpdate = make(chan struct{}, 1)
	s.frameSize = frameSize
	s.sess = sess
	s.die = make(chan struct{})
	s.chFinEvent = make(chan struct{})
	s.peerWindow = initialPeerWindow // set to the initial window size
	return s
}

// ID returns the unique stream ID.
func (s *Stream) ID() uint32 {
	return s.id
}

// Read implements net.Conn.
func (s *Stream) Read(b []byte) (n int, err error) {
	for {
		n, err = s.tryRead(b)
		if err == ErrWouldBlock {
			if ew := s.waitRead(); ew != nil {
				return 0, ew
			}
		} else {
			return n, err
		}
	}
}

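// Illustrative caller-side sketch (not part of this file): draining a stream
// until EOF, assuming the Session was obtained via this package's
// Client/Server constructors as in upstream smux, and `handle` is a
// hypothetical consumer.
//
//	buf := make([]byte, 4096)
//	for {
//		n, err := stream.Read(buf)
//		if n > 0 {
//			handle(buf[:n])
//		}
//		if err != nil {
//			break // io.EOF once the peer sent FIN and the buffers drained
//		}
//	}
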
// tryRead is the nonblocking version of Read.
func (s *Stream) tryRead(b []byte) (n int, err error) {
	if s.sess.Config.Version == 2 {
		return s.tryReadv2(b)
	}

	if len(b) == 0 {
		return 0, nil
	}

	s.bufferLock.Lock()
	if len(s.buffers) > 0 {
		n = copy(b, s.buffers[0])
		s.buffers[0] = s.buffers[0][n:]
		if len(s.buffers[0]) == 0 {
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			// full recycle
			defaultAllocator.Put(s.heads[0])
			s.heads = s.heads[1:]
		}
	}
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}

func (s *Stream) tryReadv2(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}

	var notifyConsumed uint32
	s.bufferLock.Lock()
	if len(s.buffers) > 0 {
		n = copy(b, s.buffers[0])
		s.buffers[0] = s.buffers[0][n:]
		if len(s.buffers[0]) == 0 {
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			// full recycle
			defaultAllocator.Put(s.heads[0])
			s.heads = s.heads[1:]
		}
	}

	// In an ideal environment:
	// if more than half of the buffer has been consumed, send a read ACK to
	// the peer. Based on the round-trip time of the ACK, continuously flowing
	// data won't slow down waiting for an ACK, as long as the consumer keeps
	// reading. s.numRead == n also notifies the window on the very first read
	// (see the worked example after this function).
	s.numRead += uint32(n)
	s.incr += uint32(n)
	if s.incr >= uint32(s.sess.Config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
		notifyConsumed = s.numRead
		s.incr = 0
	}
	s.bufferLock.Unlock()

	if n > 0 {
		s.sess.returnTokens(n)
		if notifyConsumed > 0 {
			err := s.sendWindowUpdate(notifyConsumed)
			return n, err
		}
		return n, nil
	}

	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}

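// Worked example of the ACK policy above (illustrative only; assumes
// MaxStreamBuffer is configured to 65536 bytes):
//
//	read #1: n = 1000 -> numRead == n, so an UPD with consumed=1000 is sent
//	         immediately to open the peer's window early, and incr resets.
//	later 1000-byte reads accumulate in incr; once incr reaches 32768
//	         (half the buffer), another UPD is sent and incr resets to 0.
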
// WriteTo implements io.WriterTo.
func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
	if s.sess.Config.Version == 2 {
		return s.writeTov2(w)
	}

	for {
		var buf []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf = s.buffers[0]
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		s.bufferLock.Unlock()

		if buf != nil {
			nw, ew := w.Write(buf)
			s.sess.returnTokens(len(buf))
			defaultAllocator.Put(buf)
			if nw > 0 {
				n += int64(nw)
			}
			if ew != nil {
				return n, ew
			}
		} else if ew := s.waitRead(); ew != nil {
			return n, ew
		}
	}
}

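// Side note (not from the original source): because Stream provides WriteTo,
// io.Copy(dst, stream) takes this path and keeps pumping buffered frames into
// dst until waitRead reports io.EOF or another error, e.g.:
//
//	n, err := io.Copy(dstConn, stream)
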
func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
	for {
		var notifyConsumed uint32
		var buf []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf = s.buffers[0]
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		s.numRead += uint32(len(buf))
		s.incr += uint32(len(buf))
		if s.incr >= uint32(s.sess.Config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
			notifyConsumed = s.numRead
			s.incr = 0
		}
		s.bufferLock.Unlock()

		if buf != nil {
			nw, ew := w.Write(buf)
			s.sess.returnTokens(len(buf))
			defaultAllocator.Put(buf)
			if nw > 0 {
				n += int64(nw)
			}
			if ew != nil {
				return n, ew
			}
			if notifyConsumed > 0 {
				if err := s.sendWindowUpdate(notifyConsumed); err != nil {
					return n, err
				}
			}
		} else if ew := s.waitRead(); ew != nil {
			return n, ew
		}
	}
}

// sendWindowUpdate sends an UPD frame carrying the consumed byte count and
// the local receive window (payload layout sketched after this function).
func (s *Stream) sendWindowUpdate(consumed uint32) error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	frame := NewFrame(byte(s.sess.Config.Version), CmdUpd, s.id)
	var hdr updHeader
	binary.LittleEndian.PutUint32(hdr[:], consumed)
	binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.Config.MaxStreamBuffer))
	frame.Data = hdr[:]
	_, err := s.sess.WriteFrameInternal(frame, deadline, 0)
	return err
}

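// Payload layout of the UPD frame built above (little-endian, 8 bytes):
//
//	hdr[0:4] consumed - cumulative bytes this side has read from the stream
//	hdr[4:8] window   - the local MaxStreamBuffer, advertised to the peer as
//	                    the receive window
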
func (s *Stream) waitRead() error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	select {
	case <-s.chReadEvent:
		return nil
	case <-s.chFinEvent:
		// BUG(xtaci): Fix for https://github.com/xtaci/smux/issues/82
		s.bufferLock.Lock()
		defer s.bufferLock.Unlock()
		if len(s.buffers) > 0 {
			return nil
		}
		return io.EOF
	case <-s.sess.chSocketReadError:
		return s.sess.socketReadError.Load().(error)
	case <-s.sess.chProtoError:
		return s.sess.protoError.Load().(error)
	case <-deadline:
		return ErrTimeout
	case <-s.die:
		return io.ErrClosedPipe
	}
}

// Write implements net.Conn.
//
// Note that the behavior when multiple goroutines write concurrently is not
// deterministic; frames may interleave arbitrarily.
func (s *Stream) Write(b []byte) (n int, err error) {
	if s.sess.Config.Version == 2 {
		return s.writeV2(b)
	}

	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// check if the stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// frame split and transmit (see the worked example after this function)
	sent := 0
	frame := NewFrame(byte(s.sess.Config.Version), CmdPsh, s.id)
	bts := b
	for len(bts) > 0 {
		sz := len(bts)
		if sz > s.frameSize {
			sz = s.frameSize
		}
		frame.Data = bts[:sz]
		bts = bts[sz:]
		n, err := s.sess.WriteFrameInternal(frame, deadline, uint64(s.numWritten))
		s.numWritten++
		sent += n
		if err != nil {
			return sent, err
		}
	}
	return sent, nil
}

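// Worked example of the split loop above (illustrative; assumes
// frameSize = 4096): a single 10000-byte Write produces three PSH frames of
// 4096, 4096 and 1808 bytes, and `sent` accumulates whatever
// WriteFrameInternal reports for each frame.
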
func (s *Stream) writeV2(b []byte) (n int, err error) {
	// check for empty input
	if len(b) == 0 {
		return 0, nil
	}

	// check if the stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}

	// create the write deadline timer
	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	// frame split and transmit process
	sent := 0
	frame := NewFrame(byte(s.sess.Config.Version), CmdPsh, s.id)

	for {
		// per-stream sliding window control:
		// [.... [consumed... numWritten] ... win... ]
		// [.... [consumed...................+rmtwnd]]
		var bts []byte

		// note:
		// even if uint32 overflows, this arithmetic still works:
		// eg1: uint32(0) - uint32(math.MaxUint32) = 1
		// eg2: int32(uint32(0) - uint32(1)) = -1
		// security check for misbehavior (see the worked example after this function)
		inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
		if inflight < 0 {
			return 0, ErrConsumed
		}

		win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
		if win > 0 {
			if win > int32(len(b)) {
				bts = b
				b = nil
			} else {
				bts = b[:win]
				b = b[win:]
			}

			for len(bts) > 0 {
				sz := len(bts)
				if sz > s.frameSize {
					sz = s.frameSize
				}
				frame.Data = bts[:sz]
				bts = bts[sz:]
				n, err := s.sess.WriteFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
				atomic.AddUint32(&s.numWritten, uint32(sz))
				sent += n
				if err != nil {
					return sent, err
				}
			}
		}

		// If there is any data remaining to be sent, wait until the stream
		// closes, the window changes, or the deadline is reached. This
		// blocking behavior informs the upper layer to apply flow control.
		if len(b) > 0 {
			select {
			case <-s.chFinEvent: // if FIN arrived, a future window update is impossible
				return 0, io.EOF
			case <-s.die:
				return sent, io.ErrClosedPipe
			case <-deadline:
				return sent, ErrTimeout
			case <-s.sess.chSocketWriteError:
				return sent, s.sess.socketWriteError.Load().(error)
			case <-s.chUpdate:
				continue
			}
		} else {
			return sent, nil
		}
	}
}

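// Worked example of the window math above (illustrative; assumes the peer
// advertised peerWindow = 262144 and has acknowledged peerConsumed = 100000):
//
//	numWritten = 180000
//	inflight   = 180000 - 100000 = 80000   (bytes sent but not yet consumed)
//	win        = 262144 - 80000  = 182144  (bytes we may still send)
//
// If win were <= 0, writeV2 would block on chUpdate until the peer's next
// UPD frame raises peerConsumed (and possibly peerWindow).
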
// Close implements net.Conn.
func (s *Stream) Close() error {
	var once bool
	var err error
	s.dieOnce.Do(func() {
		close(s.die)
		once = true
	})

	if once {
		_, err = s.sess.WriteFrame(NewFrame(byte(s.sess.Config.Version), CmdFin, s.id))
		s.sess.streamClosed(s.id)
		return err
	}
	return io.ErrClosedPipe
}

// GetDieCh returns a read-only channel which becomes readable
// when the stream is about to be closed.
func (s *Stream) GetDieCh() <-chan struct{} {
	return s.die
}

// SetReadDeadline sets the read deadline as defined by
// net.Conn.SetReadDeadline.
// A zero time value disables the deadline.
func (s *Stream) SetReadDeadline(t time.Time) error {
	s.readDeadline.Store(t)
	s.notifyReadEvent()
	return nil
}

// SetWriteDeadline sets the write deadline as defined by
// net.Conn.SetWriteDeadline.
// A zero time value disables the deadline.
func (s *Stream) SetWriteDeadline(t time.Time) error {
	s.writeDeadline.Store(t)
	return nil
}

// SetDeadline sets both read and write deadlines as defined by
// net.Conn.SetDeadline.
// A zero time value disables the deadlines.
func (s *Stream) SetDeadline(t time.Time) error {
	if err := s.SetReadDeadline(t); err != nil {
		return err
	}
	if err := s.SetWriteDeadline(t); err != nil {
		return err
	}
	return nil
}

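// Illustrative caller-side sketch (not from the original source): bounding a
// single Read with a deadline, the standard net.Conn idiom.
//
//	_ = stream.SetReadDeadline(time.Now().Add(5 * time.Second))
//	n, err := stream.Read(buf)              // returns ErrTimeout if no data arrives in time
//	_ = stream.SetReadDeadline(time.Time{}) // clear the deadline again
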
// sessionClose marks the stream as dead when the underlying session closes.
func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }

// LocalAddr satisfies the net.Conn interface.
func (s *Stream) LocalAddr() net.Addr {
	if ts, ok := s.sess.Conn.(interface {
		LocalAddr() net.Addr
	}); ok {
		return ts.LocalAddr()
	}
	return nil
}

// RemoteAddr satisfies the net.Conn interface.
func (s *Stream) RemoteAddr() net.Addr {
	if ts, ok := s.sess.Conn.(interface {
		RemoteAddr() net.Addr
	}); ok {
		return ts.RemoteAddr()
	}
	return nil
}

// pushBytes appends buf to buffers.
func (s *Stream) pushBytes(buf []byte) (written int, err error) {
	s.bufferLock.Lock()
	s.buffers = append(s.buffers, buf)
	s.heads = append(s.heads, buf)
	s.bufferLock.Unlock()
	return
}

// recycleTokens transforms any remaining bytes into tokens (truncates the buffer).
func (s *Stream) recycleTokens() (n int) {
	s.bufferLock.Lock()
	for k := range s.buffers {
		n += len(s.buffers[k])
		defaultAllocator.Put(s.heads[k])
	}
	s.buffers = nil
	s.heads = nil
	s.bufferLock.Unlock()
	return
}

// notifyReadEvent signals a waiting reader that new data is available.
func (s *Stream) notifyReadEvent() {
	select {
	case s.chReadEvent <- struct{}{}:
	default:
	}
}

// update records a peer UPD command: the consumed byte count and the
// peer's advertised window.
func (s *Stream) update(consumed, window uint32) {
	atomic.StoreUint32(&s.peerConsumed, consumed)
	atomic.StoreUint32(&s.peerWindow, window)
	select {
	case s.chUpdate <- struct{}{}:
	default:
	}
}

// fin marks this stream as closed at the protocol level.
func (s *Stream) fin() {
	s.finEventOnce.Do(func() {
		close(s.chFinEvent)
	})
}