session.go 13 KB


  1. package gfsmux
  2. import (
  3. "container/heap"
  4. "encoding/binary"
  5. "errors"
  6. "io"
  7. "net"
  8. "sync"
  9. "sync/atomic"
  10. "time"
  11. )
  // defaultAcceptBacklog is the capacity of the queue holding inbound streams
// that have not yet been picked up by AcceptStream. When the queue is full,
// recvLoop blocks while handing off a new stream (until the session dies).
const defaultAcceptBacklog = 2048

// Errors reported by the session layer.
var (
	// ErrInvalidProtocol is reported when a received frame carries an
	// unexpected protocol version or an unknown command (bad negotiation).
	ErrInvalidProtocol = errors.New("invalid protocol")

	// ErrConsumed is a protocol error: the peer reports having consumed more
	// than was sent, meaning the two ends are out of sync.
	ErrConsumed = errors.New("peer consumed more than sent")

	// ErrGoAway is returned by OpenStream once the stream-id space has been
	// exhausted; a new connection is required to open further streams.
	ErrGoAway = errors.New("stream id overflows, should start a new Connection")

	// ErrTimeout is returned when a deadline expires. It implements net.Error
	// with Timeout() and Temporary() both reporting true.
	ErrTimeout = &timeoutError{}

	// ErrWouldBlock signals that an operation could not complete without
	// blocking on I/O.
	ErrWouldBlock = errors.New("operation would block on IO")
)

// Compile-time check that *timeoutError satisfies net.Error.
var _ net.Error = (*timeoutError)(nil)

// timeoutError is the concrete type behind ErrTimeout.
type timeoutError struct{}

// Error implements the error interface.
func (*timeoutError) Error() string { return "timeout" }

// Timeout reports that this error is a timeout (always true).
func (*timeoutError) Timeout() bool { return true }

// Temporary reports that the condition is transient (always true).
func (*timeoutError) Temporary() bool { return true }
  52. // WriteRequest ...
  53. type WriteRequest struct {
  54. Prio uint64
  55. frame Frame
  56. result chan writeResult
  57. }
  58. type writeResult struct {
  59. n int
  60. err error
  61. }
  62. type buffersWriter interface {
  63. WriteBuffers(
  64. v [][]byte,
  65. ) (
  66. n int,
  67. err error,
  68. )
  69. }
  70. // Session defines a multiplexed Connection for streams
  71. type Session struct {
  72. Conn io.ReadWriteCloser
  73. Config *Config
  74. nextStreamID uint32 // next stream identifier
  75. nextStreamIDLock sync.Mutex
  76. bucket int32 // token bucket
  77. bucketNotify chan struct{} // used for waiting for tokens
  78. streams map[uint32]*Stream // all streams in this session
  79. streamLock sync.Mutex // locks streams
  80. die chan struct{} // flag session has died
  81. dieOnce sync.Once
  82. // socket error handling
  83. socketReadError atomic.Value
  84. socketWriteError atomic.Value
  85. chSocketReadError chan struct{}
  86. chSocketWriteError chan struct{}
  87. socketReadErrorOnce sync.Once
  88. socketWriteErrorOnce sync.Once
  89. // smux protocol errors
  90. protoError atomic.Value
  91. chProtoError chan struct{}
  92. protoErrorOnce sync.Once
  93. chAccepts chan *Stream
  94. dataReady int32 // flag data has arrived
  95. goAway int32 // flag id exhausted
  96. deadline atomic.Value
  97. shaper chan WriteRequest // a shaper for writing
  98. writes chan WriteRequest
  99. }
  100. func newSession(
  101. Config *Config,
  102. Conn io.ReadWriteCloser,
  103. client bool,
  104. ) *Session {
  105. s := new(
  106. Session,
  107. )
  108. s.die = make(
  109. chan struct{},
  110. )
  111. s.Conn = Conn
  112. s.Config = Config
  113. s.streams = make(
  114. map[uint32]*Stream,
  115. )
  116. s.chAccepts = make(
  117. chan *Stream,
  118. defaultAcceptBacklog,
  119. )
  120. s.bucket = int32(
  121. Config.MaxReceiveBuffer,
  122. )
  123. s.bucketNotify = make(
  124. chan struct{},
  125. 1,
  126. )
  127. s.shaper = make(
  128. chan WriteRequest,
  129. )
  130. s.writes = make(
  131. chan WriteRequest,
  132. )
  133. s.chSocketReadError = make(
  134. chan struct{},
  135. )
  136. s.chSocketWriteError = make(
  137. chan struct{},
  138. )
  139. s.chProtoError = make(
  140. chan struct{},
  141. )
  142. if client {
  143. s.nextStreamID = 1
  144. } else {
  145. s.nextStreamID = 0
  146. }
  147. go s.shaperLoop()
  148. go s.recvLoop()
  149. go s.sendLoop()
  150. if !Config.KeepAliveDisabled {
  151. go s.keepalive()
  152. }
  153. return s
  154. }
  155. // OpenStream is used to create a new stream
  156. func (
  157. s *Session,
  158. ) OpenStream() (
  159. *Stream,
  160. error,
  161. ) {
  162. if s.IsClosed() {
  163. return nil, io.ErrClosedPipe
  164. }
  165. // generate stream id
  166. s.nextStreamIDLock.Lock()
  167. if s.goAway > 0 {
  168. s.nextStreamIDLock.Unlock()
  169. return nil, ErrGoAway
  170. }
  171. s.nextStreamID += 2
  172. Sid := s.nextStreamID
  173. if Sid == Sid%2 { // stream-id overflows
  174. s.goAway = 1
  175. s.nextStreamIDLock.Unlock()
  176. return nil, ErrGoAway
  177. }
  178. s.nextStreamIDLock.Unlock()
  179. stream := newStream(
  180. Sid,
  181. s.Config.MaxFrameSize,
  182. s,
  183. )
  184. if _, err := s.WriteFrame(
  185. NewFrame(
  186. byte(s.Config.Version),
  187. CmdSyn,
  188. Sid,
  189. ),
  190. ); err != nil {
  191. return nil, err
  192. }
  193. s.streamLock.Lock()
  194. defer s.streamLock.Unlock()
  195. select {
  196. case <-s.chSocketReadError:
  197. return nil, s.socketReadError.Load().(error)
  198. case <-s.chSocketWriteError:
  199. return nil, s.socketWriteError.Load().(error)
  200. case <-s.die:
  201. return nil, io.ErrClosedPipe
  202. default:
  203. s.streams[Sid] = stream
  204. return stream, nil
  205. }
  206. }
  207. // Open returns a generic ReadWriteCloser
  208. func (
  209. s *Session,
  210. ) Open() (
  211. io.ReadWriteCloser,
  212. error,
  213. ) {
  214. return s.OpenStream()
  215. }
  216. // AcceptStream is used to block until the next available stream
  217. // is ready to be accepted.
  218. func (
  219. s *Session,
  220. ) AcceptStream() (
  221. *Stream,
  222. error,
  223. ) {
  224. var deadline <-chan time.Time
  225. if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
  226. timer := time.NewTimer(
  227. time.Until(
  228. d,
  229. ),
  230. )
  231. defer timer.Stop()
  232. deadline = timer.C
  233. }
  234. select {
  235. case stream := <-s.chAccepts:
  236. return stream, nil
  237. case <-deadline:
  238. return nil, ErrTimeout
  239. case <-s.chSocketReadError:
  240. return nil, s.socketReadError.Load().(error)
  241. case <-s.chProtoError:
  242. return nil, s.protoError.Load().(error)
  243. case <-s.die:
  244. return nil, io.ErrClosedPipe
  245. }
  246. }
  247. // Accept Returns a generic ReadWriteCloser instead of smux.Stream
  248. func (
  249. s *Session,
  250. ) Accept() (
  251. io.ReadWriteCloser,
  252. error,
  253. ) {
  254. return s.AcceptStream()
  255. }
  256. // Close is used to close the session and all streams.
  257. func (
  258. s *Session,
  259. ) Close() error {
  260. var once bool
  261. s.dieOnce.Do(func() {
  262. close(
  263. s.die,
  264. )
  265. once = true
  266. })
  267. if once {
  268. s.streamLock.Lock()
  269. for k := range s.streams {
  270. s.streams[k].sessionClose()
  271. }
  272. s.streamLock.Unlock()
  273. return s.Conn.Close()
  274. }
  275. return io.ErrClosedPipe
  276. }
  277. // notifyBucket notifies recvLoop that bucket is available
  278. func (
  279. s *Session,
  280. ) notifyBucket() {
  281. select {
  282. case s.bucketNotify <- struct{}{}:
  283. default:
  284. }
  285. }
  286. func (
  287. s *Session,
  288. ) notifyReadError(
  289. err error,
  290. ) {
  291. s.socketReadErrorOnce.Do(func() {
  292. s.socketReadError.Store(
  293. err,
  294. )
  295. close(
  296. s.chSocketReadError,
  297. )
  298. })
  299. }
  300. func (
  301. s *Session,
  302. ) notifyWriteError(
  303. err error,
  304. ) {
  305. s.socketWriteErrorOnce.Do(func() {
  306. s.socketWriteError.Store(
  307. err,
  308. )
  309. close(
  310. s.chSocketWriteError,
  311. )
  312. })
  313. }
  314. func (
  315. s *Session,
  316. ) notifyProtoError(
  317. err error,
  318. ) {
  319. s.protoErrorOnce.Do(func() {
  320. s.protoError.Store(
  321. err,
  322. )
  323. close(
  324. s.chProtoError,
  325. )
  326. })
  327. }
  328. // IsClosed does a safe check to see if we have shutdown
  329. func (
  330. s *Session,
  331. ) IsClosed() bool {
  332. select {
  333. case <-s.die:
  334. return true
  335. default:
  336. return false
  337. }
  338. }
  339. // NumStreams returns the number of currently open streams
  340. func (
  341. s *Session,
  342. ) NumStreams() int {
  343. if s.IsClosed() {
  344. return 0
  345. }
  346. s.streamLock.Lock()
  347. defer s.streamLock.Unlock()
  348. return len(
  349. s.streams,
  350. )
  351. }
  352. // SetDeadline sets a deadline used by Accept* calls.
  353. // A zero time value disables the deadline.
  354. func (
  355. s *Session,
  356. ) SetDeadline(
  357. t time.Time,
  358. ) error {
  359. s.deadline.Store(
  360. t,
  361. )
  362. return nil
  363. }
  364. // LocalAddr satisfies net.Conn interface
  365. func (
  366. s *Session,
  367. ) LocalAddr() net.Addr {
  368. if ts, ok := s.Conn.(interface {
  369. LocalAddr() net.Addr
  370. }); ok {
  371. return ts.LocalAddr()
  372. }
  373. return nil
  374. }
  375. // RemoteAddr satisfies net.Conn interface
  376. func (
  377. s *Session,
  378. ) RemoteAddr() net.Addr {
  379. if ts, ok := s.Conn.(interface {
  380. RemoteAddr() net.Addr
  381. }); ok {
  382. return ts.RemoteAddr()
  383. }
  384. return nil
  385. }
  386. // notify the session that a stream has closed
  387. func (
  388. s *Session,
  389. ) streamClosed(
  390. Sid uint32,
  391. ) {
  392. s.streamLock.Lock()
  393. // return remaining tokens to the bucket
  394. if n := s.streams[Sid].recycleTokens(); n > 0 {
  395. if atomic.AddInt32(
  396. &s.bucket,
  397. int32(n),
  398. ) > 0 {
  399. s.notifyBucket()
  400. }
  401. }
  402. delete(
  403. s.streams,
  404. Sid,
  405. )
  406. s.streamLock.Unlock()
  407. }
  408. // returnTokens is called by stream to return token after read
  409. func (
  410. s *Session,
  411. ) returnTokens(
  412. n int,
  413. ) {
  414. if atomic.AddInt32(
  415. &s.bucket,
  416. int32(n),
  417. ) > 0 {
  418. s.notifyBucket()
  419. }
  420. }
  421. // recvLoop keeps on reading from underlying Connection if tokens are available
  422. func (
  423. s *Session,
  424. ) recvLoop() {
  425. var hdr rawHeader
  426. var updHdr updHeader
  427. for {
  428. for atomic.LoadInt32(
  429. &s.bucket,
  430. ) <= 0 && !s.IsClosed() {
  431. select {
  432. case <-s.bucketNotify:
  433. case <-s.die:
  434. return
  435. }
  436. }
  437. // read header first
  438. if _, err := io.ReadFull(
  439. s.Conn,
  440. hdr[:],
  441. ); err == nil {
  442. atomic.StoreInt32(
  443. &s.dataReady,
  444. 1,
  445. )
  446. if hdr.Version() != byte(
  447. s.Config.Version,
  448. ) {
  449. s.notifyProtoError(
  450. ErrInvalidProtocol,
  451. )
  452. return
  453. }
  454. Sid := hdr.StreamID()
  455. switch hdr.Cmd() {
  456. case CmdNop:
  457. case CmdSyn:
  458. s.streamLock.Lock()
  459. if _, ok := s.streams[Sid]; !ok {
  460. stream := newStream(
  461. Sid,
  462. s.Config.MaxFrameSize,
  463. s,
  464. )
  465. s.streams[Sid] = stream
  466. select {
  467. case s.chAccepts <- stream:
  468. case <-s.die:
  469. }
  470. }
  471. s.streamLock.Unlock()
  472. case CmdFin:
  473. s.streamLock.Lock()
  474. if stream, ok := s.streams[Sid]; ok {
  475. stream.fin()
  476. stream.notifyReadEvent()
  477. }
  478. s.streamLock.Unlock()
  479. case CmdPsh:
  480. if hdr.Length() > 0 {
  481. newbuf := defaultAllocator.Get(
  482. int(hdr.Length()),
  483. )
  484. if written, err := io.ReadFull(
  485. s.Conn,
  486. newbuf,
  487. ); err == nil {
  488. s.streamLock.Lock()
  489. if stream, ok := s.streams[Sid]; ok {
  490. stream.pushBytes(
  491. newbuf,
  492. )
  493. atomic.AddInt32(
  494. &s.bucket,
  495. -int32(written),
  496. )
  497. stream.notifyReadEvent()
  498. }
  499. s.streamLock.Unlock()
  500. } else {
  501. s.notifyReadError(
  502. err,
  503. )
  504. return
  505. }
  506. }
  507. case CmdUpd:
  508. if _, err := io.ReadFull(
  509. s.Conn,
  510. updHdr[:],
  511. ); err == nil {
  512. s.streamLock.Lock()
  513. if stream, ok := s.streams[Sid]; ok {
  514. stream.update(
  515. updHdr.Consumed(),
  516. updHdr.Window(),
  517. )
  518. }
  519. s.streamLock.Unlock()
  520. } else {
  521. s.notifyReadError(
  522. err,
  523. )
  524. return
  525. }
  526. default:
  527. s.notifyProtoError(
  528. ErrInvalidProtocol,
  529. )
  530. return
  531. }
  532. } else {
  533. s.notifyReadError(
  534. err,
  535. )
  536. return
  537. }
  538. }
  539. }
  540. func (
  541. s *Session,
  542. ) keepalive() {
  543. tickerPing := time.NewTicker(
  544. s.Config.KeepAliveInterval,
  545. )
  546. tickerTimeout := time.NewTicker(
  547. s.Config.KeepAliveTimeout,
  548. )
  549. defer tickerPing.Stop()
  550. defer tickerTimeout.Stop()
  551. for {
  552. select {
  553. case <-tickerPing.C:
  554. s.WriteFrameInternal(
  555. NewFrame(
  556. byte(s.Config.Version),
  557. CmdNop,
  558. 0,
  559. ),
  560. tickerPing.C,
  561. 0,
  562. )
  563. s.notifyBucket() // force a signal to the recvLoop
  564. case <-tickerTimeout.C:
  565. if !atomic.CompareAndSwapInt32(
  566. &s.dataReady,
  567. 1,
  568. 0,
  569. ) {
  570. // recvLoop may block while bucket is 0, in this case,
  571. // session should not be closed.
  572. if atomic.LoadInt32(
  573. &s.bucket,
  574. ) > 0 {
  575. s.Close()
  576. return
  577. }
  578. }
  579. case <-s.die:
  580. return
  581. }
  582. }
  583. }
  584. // shaper shapes the sending sequence among streams
  585. func (
  586. s *Session,
  587. ) shaperLoop() {
  588. var reqs ShaperHeap
  589. var next WriteRequest
  590. var chWrite chan WriteRequest
  591. for {
  592. if len(
  593. reqs,
  594. ) > 0 {
  595. chWrite = s.writes
  596. next = heap.Pop(&reqs).(WriteRequest)
  597. } else {
  598. chWrite = nil
  599. }
  600. select {
  601. case <-s.die:
  602. return
  603. case r := <-s.shaper:
  604. if chWrite != nil { // next is valid, reshape
  605. heap.Push(
  606. &reqs,
  607. next,
  608. )
  609. }
  610. heap.Push(
  611. &reqs,
  612. r,
  613. )
  614. case chWrite <- next:
  615. }
  616. }
  617. }
  618. func (
  619. s *Session,
  620. ) sendLoop() {
  621. var buf []byte
  622. var n int
  623. var err error
  624. var vec [][]byte // vector for writeBuffers
  625. bw, ok := s.Conn.(buffersWriter)
  626. if ok {
  627. buf = make([]byte, HeaderSize)
  628. vec = make([][]byte, 2)
  629. } else {
  630. buf = make([]byte, (1<<16)+HeaderSize)
  631. }
  632. for {
  633. select {
  634. case <-s.die:
  635. return
  636. case request := <-s.writes:
  637. buf[0] = request.frame.Ver
  638. buf[1] = request.frame.Cmd
  639. binary.LittleEndian.PutUint16(
  640. buf[2:],
  641. uint16(
  642. len(
  643. request.frame.Data,
  644. ),
  645. ),
  646. )
  647. binary.LittleEndian.PutUint32(
  648. buf[4:],
  649. request.frame.Sid,
  650. )
  651. if len(
  652. vec,
  653. ) > 0 {
  654. vec[0] = buf[:HeaderSize]
  655. vec[1] = request.frame.Data
  656. n, err = bw.WriteBuffers(
  657. vec,
  658. )
  659. } else {
  660. copy(
  661. buf[HeaderSize:],
  662. request.frame.Data,
  663. )
  664. n, err = s.Conn.Write(
  665. buf[:HeaderSize+len(request.frame.Data)],
  666. )
  667. }
  668. n -= HeaderSize
  669. if n < 0 {
  670. n = 0
  671. }
  672. result := writeResult{
  673. n: n,
  674. err: err,
  675. }
  676. request.result <- result
  677. close(
  678. request.result,
  679. )
  680. // store Conn error
  681. if err != nil {
  682. s.notifyWriteError(
  683. err,
  684. )
  685. return
  686. }
  687. }
  688. }
  689. }
  690. // WriteFrame writes the frame to the underlying Connection
  691. // and returns the number of bytes written if successful
  692. func (
  693. s *Session,
  694. ) WriteFrame(
  695. f Frame,
  696. ) (
  697. n int,
  698. err error,
  699. ) {
  700. return s.WriteFrameInternal(
  701. f,
  702. nil,
  703. 0,
  704. )
  705. }
  706. // WriteFrameInternal is to support deadline used in keepalive
  707. func (
  708. s *Session,
  709. ) WriteFrameInternal(
  710. f Frame,
  711. deadline <-chan time.Time,
  712. Prio uint64,
  713. ) (
  714. int,
  715. error,
  716. ) {
  717. req := WriteRequest{
  718. Prio: Prio,
  719. frame: f,
  720. result: make(
  721. chan writeResult,
  722. 1,
  723. ),
  724. }
  725. select {
  726. case s.shaper <- req:
  727. case <-s.die:
  728. return 0, io.ErrClosedPipe
  729. case <-s.chSocketWriteError:
  730. return 0, s.socketWriteError.Load().(error)
  731. case <-deadline:
  732. return 0, ErrTimeout
  733. }
  734. select {
  735. case result := <-req.result:
  736. return result.n, result.err
  737. case <-s.die:
  738. return 0, io.ErrClosedPipe
  739. case <-s.chSocketWriteError:
  740. return 0, s.socketWriteError.Load().(error)
  741. case <-deadline:
  742. return 0, ErrTimeout
  743. }
  744. }