session.go

package gfsmux

import (
	"container/heap"
	"encoding/binary"
	"errors"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"
)

const (
	defaultAcceptBacklog = 1024
)

var (
	ErrInvalidProtocol = errors.New("invalid protocol")
	ErrConsumed        = errors.New("peer consumed more than sent")
	ErrGoAway          = errors.New("stream id overflows, should start a new Connection")
	ErrTimeout         = &timeoutError{}
	ErrWouldBlock      = errors.New("operation would block on IO")
)

var _ net.Error = &timeoutError{}

type timeoutError struct{}

func (e *timeoutError) Error() string   { return "timeout" }
func (e *timeoutError) Timeout() bool   { return true }
func (e *timeoutError) Temporary() bool { return true }
// WriteRequest is a prioritized frame queued for writing by the shaper.
type WriteRequest struct {
	Prio   uint64
	frame  Frame
	result chan writeResult
}

// writeResult carries the outcome of a single frame write back to the caller.
type writeResult struct {
	n   int
	err error
}

// buffersWriter is implemented by connections that support vectored writes.
type buffersWriter interface {
	WriteBuffers(v [][]byte) (n int, err error)
}
// Session defines a multiplexed Connection for streams
type Session struct {
	Conn io.ReadWriteCloser

	Config *Config

	nextStreamID     uint32 // next stream identifier
	nextStreamIDLock sync.Mutex

	bucket       int32         // token bucket
	bucketNotify chan struct{} // used for waiting for tokens

	streams    map[uint32]*Stream // all streams in this session
	streamLock sync.Mutex         // locks streams

	die     chan struct{} // flag session has died
	dieOnce sync.Once

	// socket error handling
	socketReadError      atomic.Value
	socketWriteError     atomic.Value
	chSocketReadError    chan struct{}
	chSocketWriteError   chan struct{}
	socketReadErrorOnce  sync.Once
	socketWriteErrorOnce sync.Once

	// smux protocol errors
	protoError     atomic.Value
	chProtoError   chan struct{}
	protoErrorOnce sync.Once

	chAccepts chan *Stream

	dataReady int32 // flag data has arrived

	goAway int32 // flag id exhausted

	deadline atomic.Value

	shaper chan WriteRequest // a shaper for writing
	writes chan WriteRequest
}
// newSession builds a Session on top of Conn and starts its service
// goroutines; client selects the starting stream-ID parity so both ends
// allocate stream IDs from disjoint spaces.
func newSession(Config *Config, Conn io.ReadWriteCloser, client bool) *Session {
	s := new(Session)
	s.die = make(chan struct{})
	s.Conn = Conn
	s.Config = Config
	s.streams = make(map[uint32]*Stream)
	s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
	s.bucket = int32(Config.MaxReceiveBuffer)
	s.bucketNotify = make(chan struct{}, 1)
	s.shaper = make(chan WriteRequest)
	s.writes = make(chan WriteRequest)
	s.chSocketReadError = make(chan struct{})
	s.chSocketWriteError = make(chan struct{})
	s.chProtoError = make(chan struct{})
	if client {
		s.nextStreamID = 1
	} else {
		s.nextStreamID = 0
	}

	go s.shaperLoop()
	go s.recvLoop()
	go s.sendLoop()
	if !Config.KeepAliveDisabled {
		go s.keepalive()
	}
	return s
}
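
// Illustrative use of a Session (a sketch only; the exported constructors
// that would normally wrap newSession, and the cfg/conn values used here,
// are assumed to exist elsewhere and are not defined in this file):
//
//	// client side: open a logical stream over an established conn
//	sess := newSession(cfg, conn, true)
//	stream, err := sess.OpenStream()
//
//	// server side: accept streams opened by the peer
//	sess := newSession(cfg, conn, false)
//	stream, err := sess.AcceptStream()
//
//	// each stream is an io.ReadWriteCloser; closing the session
//	// tears down every stream at once.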
// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
	if s.IsClosed() {
		return nil, io.ErrClosedPipe
	}

	// generate stream id
	s.nextStreamIDLock.Lock()
	if s.goAway > 0 {
		s.nextStreamIDLock.Unlock()
		return nil, ErrGoAway
	}

	s.nextStreamID += 2
	Sid := s.nextStreamID
	if Sid == Sid%2 { // stream-id overflows
		s.goAway = 1
		s.nextStreamIDLock.Unlock()
		return nil, ErrGoAway
	}
	s.nextStreamIDLock.Unlock()

	stream := newStream(Sid, s.Config.MaxFrameSize, s)

	if _, err := s.WriteFrame(NewFrame(byte(s.Config.Version), CmdSyn, Sid)); err != nil {
		return nil, err
	}

	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	select {
	case <-s.chSocketReadError:
		return nil, s.socketReadError.Load().(error)
	case <-s.chSocketWriteError:
		return nil, s.socketWriteError.Load().(error)
	case <-s.die:
		return nil, io.ErrClosedPipe
	default:
		s.streams[Sid] = stream
		return stream, nil
	}
}

// Open returns a generic ReadWriteCloser
func (s *Session) Open() (io.ReadWriteCloser, error) {
	return s.OpenStream()
}
// AcceptStream is used to block until the next available stream
// is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) {
	var deadline <-chan time.Time
	if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}

	select {
	case stream := <-s.chAccepts:
		return stream, nil
	case <-deadline:
		return nil, ErrTimeout
	case <-s.chSocketReadError:
		return nil, s.socketReadError.Load().(error)
	case <-s.chProtoError:
		return nil, s.protoError.Load().(error)
	case <-s.die:
		return nil, io.ErrClosedPipe
	}
}

// Accept returns a generic ReadWriteCloser instead of a *Stream.
func (s *Session) Accept() (io.ReadWriteCloser, error) {
	return s.AcceptStream()
}
// Close is used to close the session and all streams.
func (s *Session) Close() error {
	var once bool
	s.dieOnce.Do(func() {
		close(s.die)
		once = true
	})

	if once {
		s.streamLock.Lock()
		for k := range s.streams {
			s.streams[k].sessionClose()
		}
		s.streamLock.Unlock()
		return s.Conn.Close()
	}
	return io.ErrClosedPipe
}

// notifyBucket notifies recvLoop that bucket is available
func (s *Session) notifyBucket() {
	select {
	case s.bucketNotify <- struct{}{}:
	default:
	}
}

// notifyReadError records the first socket read error and unblocks waiters.
func (s *Session) notifyReadError(err error) {
	s.socketReadErrorOnce.Do(func() {
		s.socketReadError.Store(err)
		close(s.chSocketReadError)
	})
}

// notifyWriteError records the first socket write error and unblocks waiters.
func (s *Session) notifyWriteError(err error) {
	s.socketWriteErrorOnce.Do(func() {
		s.socketWriteError.Store(err)
		close(s.chSocketWriteError)
	})
}

// notifyProtoError records the first protocol error and unblocks waiters.
func (s *Session) notifyProtoError(err error) {
	s.protoErrorOnce.Do(func() {
		s.protoError.Store(err)
		close(s.chProtoError)
	})
}
// IsClosed does a safe check to see if we have shutdown
func (s *Session) IsClosed() bool {
	select {
	case <-s.die:
		return true
	default:
		return false
	}
}

// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
	if s.IsClosed() {
		return 0
	}
	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	return len(s.streams)
}

// SetDeadline sets a deadline used by Accept* calls.
// A zero time value disables the deadline.
func (s *Session) SetDeadline(t time.Time) error {
	s.deadline.Store(t)
	return nil
}

// LocalAddr satisfies net.Conn interface
func (s *Session) LocalAddr() net.Addr {
	if ts, ok := s.Conn.(interface {
		LocalAddr() net.Addr
	}); ok {
		return ts.LocalAddr()
	}
	return nil
}

// RemoteAddr satisfies net.Conn interface
func (s *Session) RemoteAddr() net.Addr {
	if ts, ok := s.Conn.(interface {
		RemoteAddr() net.Addr
	}); ok {
		return ts.RemoteAddr()
	}
	return nil
}
// streamClosed is called by a stream to notify the session it has closed.
func (s *Session) streamClosed(Sid uint32) {
	s.streamLock.Lock()
	if n := s.streams[Sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
		if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
			s.notifyBucket()
		}
	}
	delete(s.streams, Sid)
	s.streamLock.Unlock()
}

// returnTokens is called by a stream to return tokens after a read.
func (s *Session) returnTokens(n int) {
	if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
		s.notifyBucket()
	}
}
// recvLoop keeps reading from the underlying Connection as long as tokens
// are available in the receive bucket.
func (s *Session) recvLoop() {
	var hdr rawHeader
	var updHdr updHeader

	for {
		// wait for receive tokens before reading more data
		for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
			select {
			case <-s.bucketNotify:
			case <-s.die:
				return
			}
		}

		// read header first
		if _, err := io.ReadFull(s.Conn, hdr[:]); err == nil {
			atomic.StoreInt32(&s.dataReady, 1)
			if hdr.Version() != byte(s.Config.Version) {
				s.notifyProtoError(ErrInvalidProtocol)
				return
			}
			Sid := hdr.StreamID()
			switch hdr.Cmd() {
			case CmdNop: // keepalive, nothing to do
			case CmdSyn: // peer opened a new stream
				s.streamLock.Lock()
				if _, ok := s.streams[Sid]; !ok {
					stream := newStream(Sid, s.Config.MaxFrameSize, s)
					s.streams[Sid] = stream
					select {
					case s.chAccepts <- stream:
					case <-s.die:
					}
				}
				s.streamLock.Unlock()
			case CmdFin: // peer closed its end of the stream
				s.streamLock.Lock()
				if stream, ok := s.streams[Sid]; ok {
					stream.fin()
					stream.notifyReadEvent()
				}
				s.streamLock.Unlock()
			case CmdPsh: // data for a stream
				if hdr.Length() > 0 {
					newbuf := defaultAllocator.Get(int(hdr.Length()))
					if written, err := io.ReadFull(s.Conn, newbuf); err == nil {
						s.streamLock.Lock()
						if stream, ok := s.streams[Sid]; ok {
							stream.pushBytes(newbuf)
							atomic.AddInt32(&s.bucket, -int32(written))
							stream.notifyReadEvent()
						}
						s.streamLock.Unlock()
					} else {
						s.notifyReadError(err)
						return
					}
				}
			case CmdUpd: // peer window update
				if _, err := io.ReadFull(s.Conn, updHdr[:]); err == nil {
					s.streamLock.Lock()
					if stream, ok := s.streams[Sid]; ok {
						stream.update(updHdr.Consumed(), updHdr.Window())
					}
					s.streamLock.Unlock()
				} else {
					s.notifyReadError(err)
					return
				}
			default:
				s.notifyProtoError(ErrInvalidProtocol)
				return
			}
		} else {
			s.notifyReadError(err)
			return
		}
	}
}
// keepalive sends periodic NOP frames and closes the session when no data
// has arrived within KeepAliveTimeout (unless recvLoop is blocked waiting
// for tokens).
func (s *Session) keepalive() {
	tickerPing := time.NewTicker(s.Config.KeepAliveInterval)
	tickerTimeout := time.NewTicker(s.Config.KeepAliveTimeout)
	defer tickerPing.Stop()
	defer tickerTimeout.Stop()
	for {
		select {
		case <-tickerPing.C:
			s.WriteFrameInternal(NewFrame(byte(s.Config.Version), CmdNop, 0), tickerPing.C, 0)
			s.notifyBucket() // force a signal to the recvLoop
		case <-tickerTimeout.C:
			if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
				// recvLoop may block while the bucket is 0; in that case
				// the session should not be closed.
				if atomic.LoadInt32(&s.bucket) > 0 {
					s.Close()
					return
				}
			}
		case <-s.die:
			return
		}
	}
}
// shaperLoop shapes the sending sequence among streams.
func (s *Session) shaperLoop() {
	var reqs ShaperHeap
	var next WriteRequest
	var chWrite chan WriteRequest

	for {
		if len(reqs) > 0 { // a request is pending, offer it to sendLoop
			chWrite = s.writes
			next = heap.Pop(&reqs).(WriteRequest)
		} else {
			// disable the send case below: a send on a nil channel blocks forever
			chWrite = nil
		}

		select {
		case <-s.die:
			return
		case r := <-s.shaper:
			if chWrite != nil { // next is valid, reshape
				heap.Push(&reqs, next)
			}
			heap.Push(&reqs, r)
		case chWrite <- next:
		}
	}
}
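
// Note: ShaperHeap is defined elsewhere in the package; from its use above it
// appears to be a heap.Interface over WriteRequest values, presumably ordered
// by Prio, so the request popped in shaperLoop is the most urgent one pending.
// This is an inference from this file, not a definition of the actual ordering.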
// sendLoop consumes write requests from the shaper and writes frames to the
// underlying Connection, using vectored writes when supported.
func (s *Session) sendLoop() {
	var buf []byte
	var n int
	var err error
	var vec [][]byte // vector for writeBuffers

	bw, ok := s.Conn.(buffersWriter)
	if ok {
		buf = make([]byte, HeaderSize)
		vec = make([][]byte, 2)
	} else {
		buf = make([]byte, (1<<16)+HeaderSize)
	}

	for {
		select {
		case <-s.die:
			return
		case request := <-s.writes:
			buf[0] = request.frame.Ver
			buf[1] = request.frame.Cmd
			binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.Data)))
			binary.LittleEndian.PutUint32(buf[4:], request.frame.Sid)

			if len(vec) > 0 { // vectored write: header and data as separate buffers
				vec[0] = buf[:HeaderSize]
				vec[1] = request.frame.Data
				n, err = bw.WriteBuffers(vec)
			} else { // plain write: copy data after the header
				copy(buf[HeaderSize:], request.frame.Data)
				n, err = s.Conn.Write(buf[:HeaderSize+len(request.frame.Data)])
			}

			n -= HeaderSize
			if n < 0 {
				n = 0
			}

			result := writeResult{
				n:   n,
				err: err,
			}

			request.result <- result
			close(request.result)

			// store Conn error
			if err != nil {
				s.notifyWriteError(err)
				return
			}
		}
	}
}
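
// Frame wire layout as written by sendLoop above (a sketch inferred from the
// code; it assumes HeaderSize is 8 bytes):
//
//	offset 0: Ver, 1 byte
//	offset 1: Cmd, 1 byte
//	offset 2: Length, little-endian uint16 = len(Data)
//	offset 4: Sid, little-endian uint32
//	offset 8: Data, Length bytes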
// WriteFrame writes the frame to the underlying Connection
// and returns the number of bytes written if successful
func (s *Session) WriteFrame(f Frame) (n int, err error) {
	return s.WriteFrameInternal(f, nil, 0)
}

// WriteFrameInternal supports the deadline used in keepalive.
func (s *Session) WriteFrameInternal(f Frame, deadline <-chan time.Time, Prio uint64) (int, error) {
	req := WriteRequest{
		Prio:   Prio,
		frame:  f,
		result: make(chan writeResult, 1),
	}
	select {
	case s.shaper <- req:
	case <-s.die:
		return 0, io.ErrClosedPipe
	case <-s.chSocketWriteError:
		return 0, s.socketWriteError.Load().(error)
	case <-deadline:
		return 0, ErrTimeout
	}

	select {
	case result := <-req.result:
		return result.n, result.err
	case <-s.die:
		return 0, io.ErrClosedPipe
	case <-s.chSocketWriteError:
		return 0, s.socketWriteError.Load().(error)
	case <-deadline:
		return 0, ErrTimeout
	}
}