123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547 |
- package gfsmux // import "go.gridfinity.dev/gfsmux"
- import (
- "encoding/binary"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "time"
- )
// Stream implements net.Conn
type Stream struct {
	id   uint32   // unique stream identifier within the owning session
	sess *Session // owning multiplexed session

	buffers [][]byte // received, not-yet-consumed payload chunks (FIFO)
	heads   [][]byte // slice heads kept for recycle

	bufferLock sync.Mutex // guards buffers/heads (and numRead/incr in the v2 read path)
	frameSize  int        // maximum payload size used when splitting writes into frames

	// notify a read event
	chReadEvent chan struct{}

	// flag the stream has closed
	die     chan struct{}
	dieOnce sync.Once

	// FIN command
	chFinEvent   chan struct{}
	finEventOnce sync.Once

	// deadlines
	readDeadline  atomic.Value
	writeDeadline atomic.Value

	// per stream sliding window control
	numRead    uint32 // number of consumed bytes
	numWritten uint32 // count num of bytes written
	incr       uint32 // counting for sending

	// UPD command
	peerConsumed uint32        // num of bytes the peer has consumed
	peerWindow   uint32        // peer window, initialized to 256KB, updated by peer
	chUpdate     chan struct{} // notify of remote data consuming and window update
}
- // newStream initiates a Stream struct
- func newStream(id uint32, frameSize int, sess *Session) *Stream {
- s := new(Stream)
- s.id = id
- s.chReadEvent = make(chan struct{}, 1)
- s.chUpdate = make(chan struct{}, 1)
- s.frameSize = frameSize
- s.sess = sess
- s.die = make(chan struct{})
- s.chFinEvent = make(chan struct{})
- s.peerWindow = initialPeerWindow // set to initial window size
- return s
- }
- // ID returns the unique stream ID.
- func (s *Stream) ID() uint32 {
- return s.id
- }
- // Read implements net.Conn
- func (s *Stream) Read(b []byte) (n int, err error) {
- for {
- n, err = s.tryRead(b)
- if err == ErrWouldBlock {
- if ew := s.waitRead(); ew != nil {
- return 0, ew
- }
- } else {
- return n, err
- }
- }
- }
- // tryRead is the nonblocking version of Read
- func (s *Stream) tryRead(b []byte) (n int, err error) {
- if s.sess.Config.Version == 2 {
- return s.tryReadv2(b)
- }
- if len(b) == 0 {
- return 0, nil
- }
- s.bufferLock.Lock()
- if len(s.buffers) > 0 {
- n = copy(b, s.buffers[0])
- s.buffers[0] = s.buffers[0][n:]
- if len(s.buffers[0]) == 0 {
- s.buffers[0] = nil
- s.buffers = s.buffers[1:]
- // full recycle
- defaultAllocator.Put(s.heads[0])
- s.heads = s.heads[1:]
- }
- }
- s.bufferLock.Unlock()
- if n > 0 {
- s.sess.returnTokens(n)
- return n, nil
- }
- select {
- case <-s.die:
- return 0, io.EOF
- default:
- return 0, ErrWouldBlock
- }
- }
// tryReadv2 is the nonblocking read path for protocol version 2. Besides
// copying data, it maintains the per-stream sliding window: once enough
// bytes have been consumed it reports the running total back to the peer
// via an UPD frame (sendWindowUpdate).
func (s *Stream) tryReadv2(b []byte) (n int, err error) {
	if len(b) == 0 {
		return 0, nil
	}
	var notifyConsumed uint32
	s.bufferLock.Lock()
	if len(s.buffers) > 0 {
		n = copy(b, s.buffers[0])
		s.buffers[0] = s.buffers[0][n:]
		if len(s.buffers[0]) == 0 {
			s.buffers[0] = nil
			s.buffers = s.buffers[1:]
			// full recycle
			defaultAllocator.Put(s.heads[0])
			s.heads = s.heads[1:]
		}
	}
	// in an ideal environment:
	// if more than half of buffer has consumed, send read ack to peer
	// based on round-trip time of ACK, continuous flowing data
	// won't slow down because of waiting for ACK, as long as the
	// consumer keeps on reading data
	// s.numRead == n also notify window at the first read
	s.numRead += uint32(n)
	s.incr += uint32(n)
	if s.incr >= uint32(s.sess.Config.MaxStreamBuffer/2) || s.numRead == uint32(n) {
		notifyConsumed = s.numRead
		s.incr = 0 // restart the half-buffer counter after each ack
	}
	s.bufferLock.Unlock()
	if n > 0 {
		s.sess.returnTokens(n)
		if notifyConsumed > 0 {
			// acknowledge consumption so the peer can slide its send window
			err := s.sendWindowUpdate(notifyConsumed)
			return n, err
		}
		return n, nil
	}
	select {
	case <-s.die:
		return 0, io.EOF
	default:
		return 0, ErrWouldBlock
	}
}
- // WriteTo implements io.WriteTo
- func (s *Stream) WriteTo(w io.Writer) (n int64, err error) {
- if s.sess.Config.Version == 2 {
- return s.writeTov2(w)
- }
- for {
- var buf []byte
- s.bufferLock.Lock()
- if len(s.buffers) > 0 {
- buf = s.buffers[0]
- s.buffers = s.buffers[1:]
- s.heads = s.heads[1:]
- }
- s.bufferLock.Unlock()
- if buf != nil {
- nw, ew := w.Write(buf)
- s.sess.returnTokens(len(buf))
- defaultAllocator.Put(buf)
- if nw > 0 {
- n += int64(nw)
- }
- if ew != nil {
- return n, ew
- }
- } else if ew := s.waitRead(); ew != nil {
- return n, ew
- }
- }
- }
// writeTov2 is the WriteTo implementation for protocol version 2. In
// addition to draining buffered data into w, it performs the same
// sliding-window bookkeeping as tryReadv2 and acknowledges consumption
// to the peer with UPD frames.
func (s *Stream) writeTov2(w io.Writer) (n int64, err error) {
	for {
		var notifyConsumed uint32
		var buf []byte
		s.bufferLock.Lock()
		if len(s.buffers) > 0 {
			buf = s.buffers[0]
			s.buffers = s.buffers[1:]
			s.heads = s.heads[1:]
		}
		// mirror of tryReadv2: ack once half of MaxStreamBuffer has been
		// consumed, or on the very first read (numRead == len(buf))
		s.numRead += uint32(len(buf))
		s.incr += uint32(len(buf))
		if s.incr >= uint32(s.sess.Config.MaxStreamBuffer/2) || s.numRead == uint32(len(buf)) {
			notifyConsumed = s.numRead
			s.incr = 0
		}
		s.bufferLock.Unlock()
		if buf != nil {
			nw, ew := w.Write(buf)
			s.sess.returnTokens(len(buf))
			defaultAllocator.Put(buf)
			if nw > 0 {
				n += int64(nw)
			}
			if ew != nil {
				return n, ew
			}
			if notifyConsumed > 0 {
				if err := s.sendWindowUpdate(notifyConsumed); err != nil {
					return n, err
				}
			}
		} else if ew := s.waitRead(); ew != nil {
			return n, ew
		}
	}
}
- func (s *Stream) sendWindowUpdate(consumed uint32) error {
- var timer *time.Timer
- var deadline <-chan time.Time
- if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
- timer = time.NewTimer(time.Until(d))
- defer timer.Stop()
- deadline = timer.C
- }
- frame := NewFrame(byte(s.sess.Config.Version), CmdUpd, s.id)
- var hdr updHeader
- binary.LittleEndian.PutUint32(hdr[:], consumed)
- binary.LittleEndian.PutUint32(hdr[4:], uint32(s.sess.Config.MaxStreamBuffer))
- frame.Data = hdr[:]
- _, err := s.sess.WriteFrameInternal(frame, deadline, 0)
- return err
- }
// waitRead blocks until one of: data becomes readable (nil), the peer
// sends FIN and all buffered data is drained (io.EOF), a session-level
// read/protocol error occurs, the read deadline expires (ErrTimeout),
// or the stream is closed locally (io.ErrClosedPipe).
func (s *Stream) waitRead() error {
	var timer *time.Timer
	var deadline <-chan time.Time
	if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer = time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}
	select {
	case <-s.chReadEvent:
		return nil
	case <-s.chFinEvent:
		// BUG(xtaci): Fix for https://github.com/xtaci/smux/issues/82
		// FIN can race with data still queued: only report EOF once every
		// buffered byte has been handed to the reader.
		s.bufferLock.Lock()
		defer s.bufferLock.Unlock()
		if len(s.buffers) > 0 {
			return nil
		}
		return io.EOF
	case <-s.sess.chSocketReadError:
		return s.sess.socketReadError.Load().(error)
	case <-s.sess.chProtoError:
		return s.sess.protoError.Load().(error)
	case <-deadline:
		return ErrTimeout
	case <-s.die:
		return io.ErrClosedPipe
	}
}
- // Write implements net.Conn
- //
- // Note that the behavior when multiple goroutines write concurrently is not deterministic,
- // frames may interleave in random way.
- func (s *Stream) Write(b []byte) (n int, err error) {
- if s.sess.Config.Version == 2 {
- return s.writeV2(b)
- }
- var deadline <-chan time.Time
- if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
- timer := time.NewTimer(time.Until(d))
- defer timer.Stop()
- deadline = timer.C
- }
- // check if stream has closed
- select {
- case <-s.die:
- return 0, io.ErrClosedPipe
- default:
- }
- // frame split and transmit
- sent := 0
- frame := NewFrame(byte(s.sess.Config.Version), CmdPsh, s.id)
- bts := b
- for len(bts) > 0 {
- sz := len(bts)
- if sz > s.frameSize {
- sz = s.frameSize
- }
- frame.Data = bts[:sz]
- bts = bts[sz:]
- n, err := s.sess.WriteFrameInternal(frame, deadline, uint64(s.numWritten))
- s.numWritten++
- sent += n
- if err != nil {
- return sent, err
- }
- }
- return sent, nil
- }
// writeV2 is the Write implementation for protocol version 2: it applies
// per-stream sliding-window flow control, sending only as much as the
// peer's advertised window allows and blocking for window updates
// (chUpdate) when the window is exhausted.
func (s *Stream) writeV2(b []byte) (n int, err error) {
	// check empty input
	if len(b) == 0 {
		return 0, nil
	}
	// check if stream has closed
	select {
	case <-s.die:
		return 0, io.ErrClosedPipe
	default:
	}
	// create write deadline timer
	var deadline <-chan time.Time
	if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(time.Until(d))
		defer timer.Stop()
		deadline = timer.C
	}
	// frame split and transmit process
	sent := 0
	frame := NewFrame(byte(s.sess.Config.Version), CmdPsh, s.id)
	for {
		// per stream sliding window control
		// [.... [consumed... numWritten] ... win... ]
		// [.... [consumed...................+rmtwnd]]
		var bts []byte
		// note:
		// even if uint32 overflow, this math still works:
		// eg1: uint32(0) - uint32(math.MaxUint32) = 1
		// eg2: int32(uint32(0) - uint32(1)) = -1
		// security check for misbehavior
		inflight := int32(atomic.LoadUint32(&s.numWritten) - atomic.LoadUint32(&s.peerConsumed))
		if inflight < 0 {
			// peer claims to have consumed more than we ever sent
			return 0, ErrConsumed
		}
		win := int32(atomic.LoadUint32(&s.peerWindow)) - inflight
		if win > 0 {
			// take at most `win` bytes from b for this round
			if win > int32(len(b)) {
				bts = b
				b = nil
			} else {
				bts = b[:win]
				b = b[win:]
			}
			// split the window's worth of data into frameSize-bounded frames
			for len(bts) > 0 {
				sz := len(bts)
				if sz > s.frameSize {
					sz = s.frameSize
				}
				frame.Data = bts[:sz]
				bts = bts[sz:]
				n, err := s.sess.WriteFrameInternal(frame, deadline, uint64(atomic.LoadUint32(&s.numWritten)))
				atomic.AddUint32(&s.numWritten, uint32(sz))
				sent += n
				if err != nil {
					return sent, err
				}
			}
		}
		// if there is any data remaining to be sent
		// wait until stream closes, window changes or deadline reached
		// this blocking behavior will inform upper layer to do flow control
		if len(b) > 0 {
			select {
			case <-s.chFinEvent: // if fin arrived, future window update is impossible
				return 0, io.EOF
			case <-s.die:
				return sent, io.ErrClosedPipe
			case <-deadline:
				return sent, ErrTimeout
			case <-s.sess.chSocketWriteError:
				return sent, s.sess.socketWriteError.Load().(error)
			case <-s.chUpdate:
				continue
			}
		} else {
			return sent, nil
		}
	}
}
- // Close implements net.Conn
- func (s *Stream) Close() error {
- var once bool
- var err error
- s.dieOnce.Do(func() {
- close(s.die)
- once = true
- })
- if once {
- _, err = s.sess.WriteFrame(NewFrame(byte(s.sess.Config.Version), CmdFin, s.id))
- s.sess.streamClosed(s.id)
- return err
- }
- return io.ErrClosedPipe
- }
- // GetDieCh returns a readonly chan which can be readable
- // when the stream is to be closed.
- func (s *Stream) GetDieCh() <-chan struct{} {
- return s.die
- }
- // SetReadDeadline sets the read deadline as defined by
- // net.Conn.SetReadDeadline.
- // A zero time value disables the deadline.
- func (s *Stream) SetReadDeadline(t time.Time) error {
- s.readDeadline.Store(t)
- s.notifyReadEvent()
- return nil
- }
- // SetWriteDeadline sets the write deadline as defined by
- // net.Conn.SetWriteDeadline.
- // A zero time value disables the deadline.
- func (s *Stream) SetWriteDeadline(t time.Time) error {
- s.writeDeadline.Store(t)
- return nil
- }
- // SetDeadline sets both read and write deadlines as defined by
- // net.Conn.SetDeadline.
- // A zero time value disables the deadlines.
- func (s *Stream) SetDeadline(t time.Time) error {
- if err := s.SetReadDeadline(t); err != nil {
- return err
- }
- if err := s.SetWriteDeadline(t); err != nil {
- return err
- }
- return nil
- }
- // session closes
- func (s *Stream) sessionClose() { s.dieOnce.Do(func() { close(s.die) }) }
- // LocalAddr satisfies net.Conn interface
- func (s *Stream) LocalAddr() net.Addr {
- if ts, ok := s.sess.Conn.(interface {
- LocalAddr() net.Addr
- }); ok {
- return ts.LocalAddr()
- }
- return nil
- }
- // RemoteAddr satisfies net.Conn interface
- func (s *Stream) RemoteAddr() net.Addr {
- if ts, ok := s.sess.Conn.(interface {
- RemoteAddr() net.Addr
- }); ok {
- return ts.RemoteAddr()
- }
- return nil
- }
- // pushBytes append buf to buffers
- func (s *Stream) pushBytes(buf []byte) (written int, err error) {
- s.bufferLock.Lock()
- s.buffers = append(s.buffers, buf)
- s.heads = append(s.heads, buf)
- s.bufferLock.Unlock()
- return
- }
- // recycleTokens transform remaining bytes to tokens(will truncate buffer)
- func (s *Stream) recycleTokens() (n int) {
- s.bufferLock.Lock()
- for k := range s.buffers {
- n += len(s.buffers[k])
- defaultAllocator.Put(s.heads[k])
- }
- s.buffers = nil
- s.heads = nil
- s.bufferLock.Unlock()
- return
- }
- // notify read event
- func (s *Stream) notifyReadEvent() {
- select {
- case s.chReadEvent <- struct{}{}:
- default:
- }
- }
- // update command
- func (s *Stream) update(consumed, window uint32) {
- atomic.StoreUint32(&s.peerConsumed, consumed)
- atomic.StoreUint32(&s.peerWindow, window)
- select {
- case s.chUpdate <- struct{}{}:
- default:
- }
- }
- // mark this stream has been closed in protocol
- func (s *Stream) fin() {
- s.finEventOnce.Do(func() {
- close(s.chFinEvent)
- })
- }
|