12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812 |
- package gfsmux
- import (
- "container/heap"
- "encoding/binary"
- "errors"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "time"
- )
// defaultAcceptBacklog is the capacity of the channel in which a Session
// queues incoming streams until AcceptStream picks them up.
const defaultAcceptBacklog = 2048
- var (
- // ErrInvalidProtocol version or bag negotiation.
- ErrInvalidProtocol = errors.New(
- "invalid protocol",
- )
- // ErrConsumed protocol error, indicates desync
- ErrConsumed = errors.New(
- "peer consumed more than sent",
- )
- // ErrGoAway overflow condition, restart it all.
- ErrGoAway = errors.New(
- "stream id overflows, should start a new Connection",
- )
- // ErrTimeout ...
- ErrTimeout = &timeoutError{}
- // ErrWouldBlock error for invalid blocking I/O operating
- ErrWouldBlock = errors.New(
- "operation would block on IO",
- )
- )
// timeoutError is the concrete error type behind ErrTimeout. It reports
// itself as both a timeout and a temporary condition.
type timeoutError struct{}

// Compile-time check that *timeoutError implements net.Error.
var _ net.Error = (*timeoutError)(nil)

// Error returns the fixed message "timeout".
func (*timeoutError) Error() string { return "timeout" }

// Timeout always reports true.
func (*timeoutError) Timeout() bool { return true }

// Temporary always reports true.
func (*timeoutError) Temporary() bool { return true }
- // WriteRequest ...
- type WriteRequest struct {
- Prio uint64
- frame Frame
- result chan writeResult
- }
// writeResult is the outcome of a single frame write as reported by sendLoop.
type writeResult struct {
	n   int   // payload bytes written (frame header excluded)
	err error // error returned by the underlying write, if any
}
// buffersWriter is implemented by connections that can write several byte
// slices in one call. sendLoop uses it, when available, to send the frame
// header and payload without first copying them into one buffer.
type buffersWriter interface {
	WriteBuffers(v [][]byte) (n int, err error)
}
- // Session defines a multiplexed Connection for streams
- type Session struct {
- Conn io.ReadWriteCloser
- Config *Config
- nextStreamID uint32 // next stream identifier
- nextStreamIDLock sync.Mutex
- bucket int32 // token bucket
- bucketNotify chan struct{} // used for waiting for tokens
- streams map[uint32]*Stream // all streams in this session
- streamLock sync.Mutex // locks streams
- die chan struct{} // flag session has died
- dieOnce sync.Once
- // socket error handling
- socketReadError atomic.Value
- socketWriteError atomic.Value
- chSocketReadError chan struct{}
- chSocketWriteError chan struct{}
- socketReadErrorOnce sync.Once
- socketWriteErrorOnce sync.Once
- // smux protocol errors
- protoError atomic.Value
- chProtoError chan struct{}
- protoErrorOnce sync.Once
- chAccepts chan *Stream
- dataReady int32 // flag data has arrived
- goAway int32 // flag id exhausted
- deadline atomic.Value
- shaper chan WriteRequest // a shaper for writing
- writes chan WriteRequest
- }
- func newSession(
- Config *Config,
- Conn io.ReadWriteCloser,
- client bool,
- ) *Session {
- s := new(
- Session,
- )
- s.die = make(
- chan struct{},
- )
- s.Conn = Conn
- s.Config = Config
- s.streams = make(
- map[uint32]*Stream,
- )
- s.chAccepts = make(
- chan *Stream,
- defaultAcceptBacklog,
- )
- s.bucket = int32(
- Config.MaxReceiveBuffer,
- )
- s.bucketNotify = make(
- chan struct{},
- 1,
- )
- s.shaper = make(
- chan WriteRequest,
- )
- s.writes = make(
- chan WriteRequest,
- )
- s.chSocketReadError = make(
- chan struct{},
- )
- s.chSocketWriteError = make(
- chan struct{},
- )
- s.chProtoError = make(
- chan struct{},
- )
- if client {
- s.nextStreamID = 1
- } else {
- s.nextStreamID = 0
- }
- go s.shaperLoop()
- go s.recvLoop()
- go s.sendLoop()
- if !Config.KeepAliveDisabled {
- go s.keepalive()
- }
- return s
- }
- // OpenStream is used to create a new stream
- func (
- s *Session,
- ) OpenStream() (
- *Stream,
- error,
- ) {
- if s.IsClosed() {
- return nil, io.ErrClosedPipe
- }
- // generate stream id
- s.nextStreamIDLock.Lock()
- if s.goAway > 0 {
- s.nextStreamIDLock.Unlock()
- return nil, ErrGoAway
- }
- s.nextStreamID += 2
- Sid := s.nextStreamID
- if Sid == Sid%2 { // stream-id overflows
- s.goAway = 1
- s.nextStreamIDLock.Unlock()
- return nil, ErrGoAway
- }
- s.nextStreamIDLock.Unlock()
- stream := newStream(
- Sid,
- s.Config.MaxFrameSize,
- s,
- )
- if _, err := s.WriteFrame(
- NewFrame(
- byte(s.Config.Version),
- CmdSyn,
- Sid,
- ),
- ); err != nil {
- return nil, err
- }
- s.streamLock.Lock()
- defer s.streamLock.Unlock()
- select {
- case <-s.chSocketReadError:
- return nil, s.socketReadError.Load().(error)
- case <-s.chSocketWriteError:
- return nil, s.socketWriteError.Load().(error)
- case <-s.die:
- return nil, io.ErrClosedPipe
- default:
- s.streams[Sid] = stream
- return stream, nil
- }
- }
- // Open returns a generic ReadWriteCloser
- func (
- s *Session,
- ) Open() (
- io.ReadWriteCloser,
- error,
- ) {
- return s.OpenStream()
- }
- // AcceptStream is used to block until the next available stream
- // is ready to be accepted.
- func (
- s *Session,
- ) AcceptStream() (
- *Stream,
- error,
- ) {
- var deadline <-chan time.Time
- if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
- timer := time.NewTimer(
- time.Until(
- d,
- ),
- )
- defer timer.Stop()
- deadline = timer.C
- }
- select {
- case stream := <-s.chAccepts:
- return stream, nil
- case <-deadline:
- return nil, ErrTimeout
- case <-s.chSocketReadError:
- return nil, s.socketReadError.Load().(error)
- case <-s.chProtoError:
- return nil, s.protoError.Load().(error)
- case <-s.die:
- return nil, io.ErrClosedPipe
- }
- }
- // Accept Returns a generic ReadWriteCloser instead of smux.Stream
- func (
- s *Session,
- ) Accept() (
- io.ReadWriteCloser,
- error,
- ) {
- return s.AcceptStream()
- }
- // Close is used to close the session and all streams.
- func (
- s *Session,
- ) Close() error {
- var once bool
- s.dieOnce.Do(func() {
- close(
- s.die,
- )
- once = true
- })
- if once {
- s.streamLock.Lock()
- for k := range s.streams {
- s.streams[k].sessionClose()
- }
- s.streamLock.Unlock()
- return s.Conn.Close()
- }
- return io.ErrClosedPipe
- }
- // notifyBucket notifies recvLoop that bucket is available
- func (
- s *Session,
- ) notifyBucket() {
- select {
- case s.bucketNotify <- struct{}{}:
- default:
- }
- }
- func (
- s *Session,
- ) notifyReadError(
- err error,
- ) {
- s.socketReadErrorOnce.Do(func() {
- s.socketReadError.Store(
- err,
- )
- close(
- s.chSocketReadError,
- )
- })
- }
- func (
- s *Session,
- ) notifyWriteError(
- err error,
- ) {
- s.socketWriteErrorOnce.Do(func() {
- s.socketWriteError.Store(
- err,
- )
- close(
- s.chSocketWriteError,
- )
- })
- }
- func (
- s *Session,
- ) notifyProtoError(
- err error,
- ) {
- s.protoErrorOnce.Do(func() {
- s.protoError.Store(
- err,
- )
- close(
- s.chProtoError,
- )
- })
- }
- // IsClosed does a safe check to see if we have shutdown
- func (
- s *Session,
- ) IsClosed() bool {
- select {
- case <-s.die:
- return true
- default:
- return false
- }
- }
- // NumStreams returns the number of currently open streams
- func (
- s *Session,
- ) NumStreams() int {
- if s.IsClosed() {
- return 0
- }
- s.streamLock.Lock()
- defer s.streamLock.Unlock()
- return len(
- s.streams,
- )
- }
- // SetDeadline sets a deadline used by Accept* calls.
- // A zero time value disables the deadline.
- func (
- s *Session,
- ) SetDeadline(
- t time.Time,
- ) error {
- s.deadline.Store(
- t,
- )
- return nil
- }
- // LocalAddr satisfies net.Conn interface
- func (
- s *Session,
- ) LocalAddr() net.Addr {
- if ts, ok := s.Conn.(interface {
- LocalAddr() net.Addr
- }); ok {
- return ts.LocalAddr()
- }
- return nil
- }
- // RemoteAddr satisfies net.Conn interface
- func (
- s *Session,
- ) RemoteAddr() net.Addr {
- if ts, ok := s.Conn.(interface {
- RemoteAddr() net.Addr
- }); ok {
- return ts.RemoteAddr()
- }
- return nil
- }
- // notify the session that a stream has closed
- func (
- s *Session,
- ) streamClosed(
- Sid uint32,
- ) {
- s.streamLock.Lock()
- // return remaining tokens to the bucket
- if n := s.streams[Sid].recycleTokens(); n > 0 {
- if atomic.AddInt32(
- &s.bucket,
- int32(n),
- ) > 0 {
- s.notifyBucket()
- }
- }
- delete(
- s.streams,
- Sid,
- )
- s.streamLock.Unlock()
- }
- // returnTokens is called by stream to return token after read
- func (
- s *Session,
- ) returnTokens(
- n int,
- ) {
- if atomic.AddInt32(
- &s.bucket,
- int32(n),
- ) > 0 {
- s.notifyBucket()
- }
- }
- // recvLoop keeps on reading from underlying Connection if tokens are available
- func (
- s *Session,
- ) recvLoop() {
- var hdr rawHeader
- var updHdr updHeader
- for {
- for atomic.LoadInt32(
- &s.bucket,
- ) <= 0 && !s.IsClosed() {
- select {
- case <-s.bucketNotify:
- case <-s.die:
- return
- }
- }
- // read header first
- if _, err := io.ReadFull(
- s.Conn,
- hdr[:],
- ); err == nil {
- atomic.StoreInt32(
- &s.dataReady,
- 1,
- )
- if hdr.Version() != byte(
- s.Config.Version,
- ) {
- s.notifyProtoError(
- ErrInvalidProtocol,
- )
- return
- }
- Sid := hdr.StreamID()
- switch hdr.Cmd() {
- case CmdNop:
- case CmdSyn:
- s.streamLock.Lock()
- if _, ok := s.streams[Sid]; !ok {
- stream := newStream(
- Sid,
- s.Config.MaxFrameSize,
- s,
- )
- s.streams[Sid] = stream
- select {
- case s.chAccepts <- stream:
- case <-s.die:
- }
- }
- s.streamLock.Unlock()
- case CmdFin:
- s.streamLock.Lock()
- if stream, ok := s.streams[Sid]; ok {
- stream.fin()
- stream.notifyReadEvent()
- }
- s.streamLock.Unlock()
- case CmdPsh:
- if hdr.Length() > 0 {
- newbuf := defaultAllocator.Get(
- int(hdr.Length()),
- )
- if written, err := io.ReadFull(
- s.Conn,
- newbuf,
- ); err == nil {
- s.streamLock.Lock()
- if stream, ok := s.streams[Sid]; ok {
- stream.pushBytes(
- newbuf,
- )
- atomic.AddInt32(
- &s.bucket,
- -int32(written),
- )
- stream.notifyReadEvent()
- }
- s.streamLock.Unlock()
- } else {
- s.notifyReadError(
- err,
- )
- return
- }
- }
- case CmdUpd:
- if _, err := io.ReadFull(
- s.Conn,
- updHdr[:],
- ); err == nil {
- s.streamLock.Lock()
- if stream, ok := s.streams[Sid]; ok {
- stream.update(
- updHdr.Consumed(),
- updHdr.Window(),
- )
- }
- s.streamLock.Unlock()
- } else {
- s.notifyReadError(
- err,
- )
- return
- }
- default:
- s.notifyProtoError(
- ErrInvalidProtocol,
- )
- return
- }
- } else {
- s.notifyReadError(
- err,
- )
- return
- }
- }
- }
- func (
- s *Session,
- ) keepalive() {
- tickerPing := time.NewTicker(
- s.Config.KeepAliveInterval,
- )
- tickerTimeout := time.NewTicker(
- s.Config.KeepAliveTimeout,
- )
- defer tickerPing.Stop()
- defer tickerTimeout.Stop()
- for {
- select {
- case <-tickerPing.C:
- s.WriteFrameInternal(
- NewFrame(
- byte(s.Config.Version),
- CmdNop,
- 0,
- ),
- tickerPing.C,
- 0,
- )
- s.notifyBucket() // force a signal to the recvLoop
- case <-tickerTimeout.C:
- if !atomic.CompareAndSwapInt32(
- &s.dataReady,
- 1,
- 0,
- ) {
- // recvLoop may block while bucket is 0, in this case,
- // session should not be closed.
- if atomic.LoadInt32(
- &s.bucket,
- ) > 0 {
- s.Close()
- return
- }
- }
- case <-s.die:
- return
- }
- }
- }
- // shaper shapes the sending sequence among streams
- func (
- s *Session,
- ) shaperLoop() {
- var reqs ShaperHeap
- var next WriteRequest
- var chWrite chan WriteRequest
- for {
- if len(
- reqs,
- ) > 0 {
- chWrite = s.writes
- next = heap.Pop(&reqs).(WriteRequest)
- } else {
- chWrite = nil
- }
- select {
- case <-s.die:
- return
- case r := <-s.shaper:
- if chWrite != nil { // next is valid, reshape
- heap.Push(
- &reqs,
- next,
- )
- }
- heap.Push(
- &reqs,
- r,
- )
- case chWrite <- next:
- }
- }
- }
- func (
- s *Session,
- ) sendLoop() {
- var buf []byte
- var n int
- var err error
- var vec [][]byte // vector for writeBuffers
- bw, ok := s.Conn.(buffersWriter)
- if ok {
- buf = make([]byte, HeaderSize)
- vec = make([][]byte, 2)
- } else {
- buf = make([]byte, (1<<16)+HeaderSize)
- }
- for {
- select {
- case <-s.die:
- return
- case request := <-s.writes:
- buf[0] = request.frame.Ver
- buf[1] = request.frame.Cmd
- binary.LittleEndian.PutUint16(
- buf[2:],
- uint16(
- len(
- request.frame.Data,
- ),
- ),
- )
- binary.LittleEndian.PutUint32(
- buf[4:],
- request.frame.Sid,
- )
- if len(
- vec,
- ) > 0 {
- vec[0] = buf[:HeaderSize]
- vec[1] = request.frame.Data
- n, err = bw.WriteBuffers(
- vec,
- )
- } else {
- copy(
- buf[HeaderSize:],
- request.frame.Data,
- )
- n, err = s.Conn.Write(
- buf[:HeaderSize+len(request.frame.Data)],
- )
- }
- n -= HeaderSize
- if n < 0 {
- n = 0
- }
- result := writeResult{
- n: n,
- err: err,
- }
- request.result <- result
- close(
- request.result,
- )
- // store Conn error
- if err != nil {
- s.notifyWriteError(
- err,
- )
- return
- }
- }
- }
- }
- // WriteFrame writes the frame to the underlying Connection
- // and returns the number of bytes written if successful
- func (
- s *Session,
- ) WriteFrame(
- f Frame,
- ) (
- n int,
- err error,
- ) {
- return s.WriteFrameInternal(
- f,
- nil,
- 0,
- )
- }
- // WriteFrameInternal is to support deadline used in keepalive
- func (
- s *Session,
- ) WriteFrameInternal(
- f Frame,
- deadline <-chan time.Time,
- Prio uint64,
- ) (
- int,
- error,
- ) {
- req := WriteRequest{
- Prio: Prio,
- frame: f,
- result: make(
- chan writeResult,
- 1,
- ),
- }
- select {
- case s.shaper <- req:
- case <-s.die:
- return 0, io.ErrClosedPipe
- case <-s.chSocketWriteError:
- return 0, s.socketWriteError.Load().(error)
- case <-deadline:
- return 0, ErrTimeout
- }
- select {
- case result := <-req.result:
- return result.n, result.err
- case <-s.die:
- return 0, io.ErrClosedPipe
- case <-s.chSocketWriteError:
- return 0, s.socketWriteError.Load().(error)
- case <-deadline:
- return 0, ErrTimeout
- }
- }
|