123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354 |
- // Copyright (C) 2015 Audrius Butkevicius and Contributors.
- package main
- import (
- "crypto/rand"
- "encoding/hex"
- "fmt"
- "log"
- "math"
- "net"
- "sync"
- "sync/atomic"
- "time"
- "golang.org/x/time/rate"
- syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
- "github.com/syncthing/syncthing/lib/relay/protocol"
- )
- var (
- sessionMut = sync.RWMutex{}
- activeSessions = make([]*session, 0)
- pendingSessions = make(map[string]*session)
- numProxies int64
- bytesProxied int64
- )
- func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *rate.Limiter) *session {
- serverkey := make([]byte, 32)
- _, err := rand.Read(serverkey)
- if err != nil {
- return nil
- }
- clientkey := make([]byte, 32)
- _, err = rand.Read(clientkey)
- if err != nil {
- return nil
- }
- ses := &session{
- serverkey: serverkey,
- serverid: serverid,
- clientkey: clientkey,
- clientid: clientid,
- rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
- connsChan: make(chan net.Conn),
- conns: make([]net.Conn, 0, 2),
- }
- if debug {
- log.Println("New session", ses)
- }
- sessionMut.Lock()
- pendingSessions[string(ses.serverkey)] = ses
- pendingSessions[string(ses.clientkey)] = ses
- sessionMut.Unlock()
- return ses
- }
- func findSession(key string) *session {
- sessionMut.Lock()
- defer sessionMut.Unlock()
- ses, ok := pendingSessions[key]
- if !ok {
- return nil
- }
- delete(pendingSessions, key)
- return ses
- }
- func dropSessions(id syncthingprotocol.DeviceID) {
- sessionMut.RLock()
- for _, session := range activeSessions {
- if session.HasParticipant(id) {
- if debug {
- log.Println("Dropping session", session, "involving", id)
- }
- session.CloseConns()
- }
- }
- sessionMut.RUnlock()
- }
- func hasSessions(id syncthingprotocol.DeviceID) bool {
- sessionMut.RLock()
- has := false
- for _, session := range activeSessions {
- if session.HasParticipant(id) {
- has = true
- break
- }
- }
- sessionMut.RUnlock()
- return has
- }
- type session struct {
- mut sync.Mutex
- serverkey []byte
- serverid syncthingprotocol.DeviceID
- clientkey []byte
- clientid syncthingprotocol.DeviceID
- rateLimit func(bytes int)
- connsChan chan net.Conn
- conns []net.Conn
- }
- func (s *session) AddConnection(conn net.Conn) bool {
- if debug {
- log.Println("New connection for", s, "from", conn.RemoteAddr())
- }
- select {
- case s.connsChan <- conn:
- return true
- default:
- }
- return false
- }
- func (s *session) Serve() {
- timedout := time.After(messageTimeout)
- if debug {
- log.Println("Session", s, "serving")
- }
- for {
- select {
- case conn := <-s.connsChan:
- s.mut.Lock()
- s.conns = append(s.conns, conn)
- s.mut.Unlock()
- // We're the only ones mutating s.conns, hence we are free to read it.
- if len(s.conns) < 2 {
- continue
- }
- close(s.connsChan)
- if debug {
- log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr())
- }
- wg := sync.WaitGroup{}
- wg.Add(2)
- var err0 error
- go func() {
- err0 = s.proxy(s.conns[0], s.conns[1])
- wg.Done()
- }()
- var err1 error
- go func() {
- err1 = s.proxy(s.conns[1], s.conns[0])
- wg.Done()
- }()
- sessionMut.Lock()
- activeSessions = append(activeSessions, s)
- sessionMut.Unlock()
- wg.Wait()
- if debug {
- log.Println("Session", s, "ended, outcomes:", err0, "and", err1)
- }
- goto done
- case <-timedout:
- if debug {
- log.Println("Session", s, "timed out")
- }
- goto done
- }
- }
- done:
- // We can end up here in 3 cases:
- // 1. Timeout joining, in which case there are potentially entries in pendingSessions
- // 2. General session end/timeout, in which case there are entries in activeSessions
- // 3. Protocol handler calls dropSession as one of its clients disconnects.
- sessionMut.Lock()
- delete(pendingSessions, string(s.serverkey))
- delete(pendingSessions, string(s.clientkey))
- for i, session := range activeSessions {
- if session == s {
- l := len(activeSessions) - 1
- activeSessions[i] = activeSessions[l]
- activeSessions[l] = nil
- activeSessions = activeSessions[:l]
- }
- }
- sessionMut.Unlock()
- // If we are here because of case 2 or 3, we are potentially closing some or
- // all connections a second time.
- s.CloseConns()
- if debug {
- log.Println("Session", s, "stopping")
- }
- }
- func (s *session) GetClientInvitationMessage() protocol.SessionInvitation {
- return protocol.SessionInvitation{
- From: s.serverid[:],
- Key: s.clientkey,
- Address: sessionAddress,
- Port: sessionPort,
- ServerSocket: false,
- }
- }
- func (s *session) GetServerInvitationMessage() protocol.SessionInvitation {
- return protocol.SessionInvitation{
- From: s.clientid[:],
- Key: s.serverkey,
- Address: sessionAddress,
- Port: sessionPort,
- ServerSocket: true,
- }
- }
- func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool {
- return s.clientid == id || s.serverid == id
- }
- func (s *session) CloseConns() {
- s.mut.Lock()
- for _, conn := range s.conns {
- conn.Close()
- }
- s.mut.Unlock()
- }
- func (s *session) proxy(c1, c2 net.Conn) error {
- if debug {
- log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
- }
- atomic.AddInt64(&numProxies, 1)
- defer atomic.AddInt64(&numProxies, -1)
- buf := make([]byte, networkBufferSize)
- for {
- c1.SetReadDeadline(time.Now().Add(networkTimeout))
- n, err := c1.Read(buf)
- if err != nil {
- return err
- }
- atomic.AddInt64(&bytesProxied, int64(n))
- if debug {
- log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
- }
- if s.rateLimit != nil {
- s.rateLimit(n)
- }
- c2.SetWriteDeadline(time.Now().Add(networkTimeout))
- _, err = c2.Write(buf[:n])
- if err != nil {
- return err
- }
- }
- }
- func (s *session) String() string {
- return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
- }
- func makeRateLimitFunc(sessionRateLimit, globalRateLimit *rate.Limiter) func(int) {
- // This may be a case of super duper premature optimization... We build an
- // optimized function to do the rate limiting here based on what we need
- // to do and then use it in the loop.
- if sessionRateLimit == nil && globalRateLimit == nil {
- // No limiting needed. We could equally well return a func(int64){} and
- // not do a nil check were we use it, but I think the nil check there
- // makes it clear that there will be no limiting if none is
- // configured...
- return nil
- }
- if sessionRateLimit == nil {
- // We only have a global limiter
- return func(bytes int) {
- take(bytes, globalRateLimit)
- }
- }
- if globalRateLimit == nil {
- // We only have a session limiter
- return func(bytes int) {
- take(bytes, sessionRateLimit)
- }
- }
- // We have both. Queue the bytes on both the global and session specific
- // rate limiters.
- return func(bytes int) {
- take(bytes, sessionRateLimit, globalRateLimit)
- }
- }
- // take is a utility function to consume tokens from a set of rate.Limiters.
- // Tokens are consumed in parallel on all limiters, respecting their
- // individual burst sizes.
- func take(tokens int, ls ...*rate.Limiter) {
- // minBurst is the smallest burst size supported by all limiters.
- minBurst := int(math.MaxInt32)
- for _, l := range ls {
- if burst := l.Burst(); burst < minBurst {
- minBurst = burst
- }
- }
- for tokens > 0 {
- // chunk is how many tokens we can consume at a time
- chunk := tokens
- if chunk > minBurst {
- chunk = minBurst
- }
- // maxDelay is the longest delay mandated by any of the limiters for
- // the chosen chunk size.
- var maxDelay time.Duration
- for _, l := range ls {
- res := l.ReserveN(time.Now(), chunk)
- if del := res.Delay(); del > maxDelay {
- maxDelay = del
- }
- }
- time.Sleep(maxDelay)
- tokens -= chunk
- }
- }
|