session.go 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. // Copyright (C) 2015 Audrius Butkevicius and Contributors.
  2. package main
  3. import (
  4. "crypto/rand"
  5. "encoding/hex"
  6. "fmt"
  7. "log"
  8. "math"
  9. "net"
  10. "sync"
  11. "sync/atomic"
  12. "time"
  13. "golang.org/x/time/rate"
  14. syncthingprotocol "github.com/syncthing/syncthing/lib/protocol"
  15. "github.com/syncthing/syncthing/lib/relay/protocol"
  16. )
  17. var (
  18. sessionMut = sync.RWMutex{}
  19. activeSessions = make([]*session, 0)
  20. pendingSessions = make(map[string]*session)
  21. numProxies int64
  22. bytesProxied int64
  23. )
  24. func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *rate.Limiter) *session {
  25. serverkey := make([]byte, 32)
  26. _, err := rand.Read(serverkey)
  27. if err != nil {
  28. return nil
  29. }
  30. clientkey := make([]byte, 32)
  31. _, err = rand.Read(clientkey)
  32. if err != nil {
  33. return nil
  34. }
  35. ses := &session{
  36. serverkey: serverkey,
  37. serverid: serverid,
  38. clientkey: clientkey,
  39. clientid: clientid,
  40. rateLimit: makeRateLimitFunc(sessionRateLimit, globalRateLimit),
  41. connsChan: make(chan net.Conn),
  42. conns: make([]net.Conn, 0, 2),
  43. }
  44. if debug {
  45. log.Println("New session", ses)
  46. }
  47. sessionMut.Lock()
  48. pendingSessions[string(ses.serverkey)] = ses
  49. pendingSessions[string(ses.clientkey)] = ses
  50. sessionMut.Unlock()
  51. return ses
  52. }
  53. func findSession(key string) *session {
  54. sessionMut.Lock()
  55. defer sessionMut.Unlock()
  56. ses, ok := pendingSessions[key]
  57. if !ok {
  58. return nil
  59. }
  60. delete(pendingSessions, key)
  61. return ses
  62. }
  63. func dropSessions(id syncthingprotocol.DeviceID) {
  64. sessionMut.RLock()
  65. for _, session := range activeSessions {
  66. if session.HasParticipant(id) {
  67. if debug {
  68. log.Println("Dropping session", session, "involving", id)
  69. }
  70. session.CloseConns()
  71. }
  72. }
  73. sessionMut.RUnlock()
  74. }
  75. func hasSessions(id syncthingprotocol.DeviceID) bool {
  76. sessionMut.RLock()
  77. has := false
  78. for _, session := range activeSessions {
  79. if session.HasParticipant(id) {
  80. has = true
  81. break
  82. }
  83. }
  84. sessionMut.RUnlock()
  85. return has
  86. }
// session pairs two relay participants, a "server" side and a "client"
// side, each with a device ID and its own random session key. Once both have
// connected, the session proxies traffic between their connections.
type session struct {
	// mut guards conns against concurrent access from Serve and CloseConns.
	mut sync.Mutex

	// serverkey is sent in the server-side invitation and is one of the two
	// keys under which the session is registered in pendingSessions.
	serverkey []byte
	serverid  syncthingprotocol.DeviceID

	// clientkey is sent in the client-side invitation and is the other key
	// under which the session is registered in pendingSessions.
	clientkey []byte
	clientid  syncthingprotocol.DeviceID

	// rateLimit, when non-nil, is called by proxy with each chunk size
	// before that chunk is written to the peer.
	rateLimit func(bytes int)

	// connsChan hands joining connections from AddConnection to Serve.
	// Serve closes it after the second connection arrives.
	connsChan chan net.Conn

	// conns holds the connections that have joined (two at most). Only Serve
	// appends to it.
	conns []net.Conn
}
  97. func (s *session) AddConnection(conn net.Conn) bool {
  98. if debug {
  99. log.Println("New connection for", s, "from", conn.RemoteAddr())
  100. }
  101. select {
  102. case s.connsChan <- conn:
  103. return true
  104. default:
  105. }
  106. return false
  107. }
  108. func (s *session) Serve() {
  109. timedout := time.After(messageTimeout)
  110. if debug {
  111. log.Println("Session", s, "serving")
  112. }
  113. for {
  114. select {
  115. case conn := <-s.connsChan:
  116. s.mut.Lock()
  117. s.conns = append(s.conns, conn)
  118. s.mut.Unlock()
  119. // We're the only ones mutating s.conns, hence we are free to read it.
  120. if len(s.conns) < 2 {
  121. continue
  122. }
  123. close(s.connsChan)
  124. if debug {
  125. log.Println("Session", s, "starting between", s.conns[0].RemoteAddr(), "and", s.conns[1].RemoteAddr())
  126. }
  127. wg := sync.WaitGroup{}
  128. wg.Add(2)
  129. var err0 error
  130. go func() {
  131. err0 = s.proxy(s.conns[0], s.conns[1])
  132. wg.Done()
  133. }()
  134. var err1 error
  135. go func() {
  136. err1 = s.proxy(s.conns[1], s.conns[0])
  137. wg.Done()
  138. }()
  139. sessionMut.Lock()
  140. activeSessions = append(activeSessions, s)
  141. sessionMut.Unlock()
  142. wg.Wait()
  143. if debug {
  144. log.Println("Session", s, "ended, outcomes:", err0, "and", err1)
  145. }
  146. goto done
  147. case <-timedout:
  148. if debug {
  149. log.Println("Session", s, "timed out")
  150. }
  151. goto done
  152. }
  153. }
  154. done:
  155. // We can end up here in 3 cases:
  156. // 1. Timeout joining, in which case there are potentially entries in pendingSessions
  157. // 2. General session end/timeout, in which case there are entries in activeSessions
  158. // 3. Protocol handler calls dropSession as one of its clients disconnects.
  159. sessionMut.Lock()
  160. delete(pendingSessions, string(s.serverkey))
  161. delete(pendingSessions, string(s.clientkey))
  162. for i, session := range activeSessions {
  163. if session == s {
  164. l := len(activeSessions) - 1
  165. activeSessions[i] = activeSessions[l]
  166. activeSessions[l] = nil
  167. activeSessions = activeSessions[:l]
  168. }
  169. }
  170. sessionMut.Unlock()
  171. // If we are here because of case 2 or 3, we are potentially closing some or
  172. // all connections a second time.
  173. s.CloseConns()
  174. if debug {
  175. log.Println("Session", s, "stopping")
  176. }
  177. }
  178. func (s *session) GetClientInvitationMessage() protocol.SessionInvitation {
  179. return protocol.SessionInvitation{
  180. From: s.serverid[:],
  181. Key: s.clientkey,
  182. Address: sessionAddress,
  183. Port: sessionPort,
  184. ServerSocket: false,
  185. }
  186. }
  187. func (s *session) GetServerInvitationMessage() protocol.SessionInvitation {
  188. return protocol.SessionInvitation{
  189. From: s.clientid[:],
  190. Key: s.serverkey,
  191. Address: sessionAddress,
  192. Port: sessionPort,
  193. ServerSocket: true,
  194. }
  195. }
  196. func (s *session) HasParticipant(id syncthingprotocol.DeviceID) bool {
  197. return s.clientid == id || s.serverid == id
  198. }
  199. func (s *session) CloseConns() {
  200. s.mut.Lock()
  201. for _, conn := range s.conns {
  202. conn.Close()
  203. }
  204. s.mut.Unlock()
  205. }
  206. func (s *session) proxy(c1, c2 net.Conn) error {
  207. if debug {
  208. log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
  209. }
  210. atomic.AddInt64(&numProxies, 1)
  211. defer atomic.AddInt64(&numProxies, -1)
  212. buf := make([]byte, networkBufferSize)
  213. for {
  214. c1.SetReadDeadline(time.Now().Add(networkTimeout))
  215. n, err := c1.Read(buf)
  216. if err != nil {
  217. return err
  218. }
  219. atomic.AddInt64(&bytesProxied, int64(n))
  220. if debug {
  221. log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
  222. }
  223. if s.rateLimit != nil {
  224. s.rateLimit(n)
  225. }
  226. c2.SetWriteDeadline(time.Now().Add(networkTimeout))
  227. _, err = c2.Write(buf[:n])
  228. if err != nil {
  229. return err
  230. }
  231. }
  232. }
  233. func (s *session) String() string {
  234. return fmt.Sprintf("<%s/%s>", hex.EncodeToString(s.clientkey)[:5], hex.EncodeToString(s.serverkey)[:5])
  235. }
  236. func makeRateLimitFunc(sessionRateLimit, globalRateLimit *rate.Limiter) func(int) {
  237. // This may be a case of super duper premature optimization... We build an
  238. // optimized function to do the rate limiting here based on what we need
  239. // to do and then use it in the loop.
  240. if sessionRateLimit == nil && globalRateLimit == nil {
  241. // No limiting needed. We could equally well return a func(int64){} and
  242. // not do a nil check were we use it, but I think the nil check there
  243. // makes it clear that there will be no limiting if none is
  244. // configured...
  245. return nil
  246. }
  247. if sessionRateLimit == nil {
  248. // We only have a global limiter
  249. return func(bytes int) {
  250. take(bytes, globalRateLimit)
  251. }
  252. }
  253. if globalRateLimit == nil {
  254. // We only have a session limiter
  255. return func(bytes int) {
  256. take(bytes, sessionRateLimit)
  257. }
  258. }
  259. // We have both. Queue the bytes on both the global and session specific
  260. // rate limiters.
  261. return func(bytes int) {
  262. take(bytes, sessionRateLimit, globalRateLimit)
  263. }
  264. }
  265. // take is a utility function to consume tokens from a set of rate.Limiters.
  266. // Tokens are consumed in parallel on all limiters, respecting their
  267. // individual burst sizes.
  268. func take(tokens int, ls ...*rate.Limiter) {
  269. // minBurst is the smallest burst size supported by all limiters.
  270. minBurst := int(math.MaxInt32)
  271. for _, l := range ls {
  272. if burst := l.Burst(); burst < minBurst {
  273. minBurst = burst
  274. }
  275. }
  276. for tokens > 0 {
  277. // chunk is how many tokens we can consume at a time
  278. chunk := tokens
  279. if chunk > minBurst {
  280. chunk = minBurst
  281. }
  282. // maxDelay is the longest delay mandated by any of the limiters for
  283. // the chosen chunk size.
  284. var maxDelay time.Duration
  285. for _, l := range ls {
  286. res := l.ReserveN(time.Now(), chunk)
  287. if del := res.Delay(); del > maxDelay {
  288. maxDelay = del
  289. }
  290. }
  291. time.Sleep(maxDelay)
  292. tokens -= chunk
  293. }
  294. }