reconnect.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. package origin
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "sync"
  7. "time"
  8. "github.com/cloudflare/cloudflared/h2mux"
  9. "github.com/cloudflare/cloudflared/logger"
  10. "github.com/cloudflare/cloudflared/tunnelrpc"
  11. tunnelpogs "github.com/cloudflare/cloudflared/tunnelrpc/pogs"
  12. "github.com/google/uuid"
  13. "github.com/prometheus/client_golang/prometheus"
  14. )
  // errJWTUnset is returned by ReconnectToken when no reconnect JWT has been
// stored yet through SetReconnectToken.
var errJWTUnset = errors.New("JWT unset")
  18. // reconnectTunnelCredentialManager is invoked by functions in tunnel.go to
  19. // get/set parameters for ReconnectTunnel RPC calls.
  20. type reconnectCredentialManager struct {
  21. mu sync.RWMutex
  22. jwt []byte
  23. eventDigest map[uint8][]byte
  24. connDigest map[uint8][]byte
  25. authSuccess prometheus.Counter
  26. authFail *prometheus.CounterVec
  27. }
  28. func newReconnectCredentialManager(namespace, subsystem string, haConnections int) *reconnectCredentialManager {
  29. authSuccess := prometheus.NewCounter(
  30. prometheus.CounterOpts{
  31. Namespace: namespace,
  32. Subsystem: subsystem,
  33. Name: "tunnel_authenticate_success",
  34. Help: "Count of successful tunnel authenticate",
  35. },
  36. )
  37. authFail := prometheus.NewCounterVec(
  38. prometheus.CounterOpts{
  39. Namespace: namespace,
  40. Subsystem: subsystem,
  41. Name: "tunnel_authenticate_fail",
  42. Help: "Count of tunnel authenticate errors by type",
  43. },
  44. []string{"error"},
  45. )
  46. prometheus.MustRegister(authSuccess, authFail)
  47. return &reconnectCredentialManager{
  48. eventDigest: make(map[uint8][]byte, haConnections),
  49. connDigest: make(map[uint8][]byte, haConnections),
  50. authSuccess: authSuccess,
  51. authFail: authFail,
  52. }
  53. }
  54. func (cm *reconnectCredentialManager) ReconnectToken() ([]byte, error) {
  55. cm.mu.RLock()
  56. defer cm.mu.RUnlock()
  57. if cm.jwt == nil {
  58. return nil, errJWTUnset
  59. }
  60. return cm.jwt, nil
  61. }
  62. func (cm *reconnectCredentialManager) SetReconnectToken(jwt []byte) {
  63. cm.mu.Lock()
  64. defer cm.mu.Unlock()
  65. cm.jwt = jwt
  66. }
  67. func (cm *reconnectCredentialManager) EventDigest(connID uint8) ([]byte, error) {
  68. cm.mu.RLock()
  69. defer cm.mu.RUnlock()
  70. digest, ok := cm.eventDigest[connID]
  71. if !ok {
  72. return nil, fmt.Errorf("no event digest for connection %v", connID)
  73. }
  74. return digest, nil
  75. }
  76. func (cm *reconnectCredentialManager) SetEventDigest(connID uint8, digest []byte) {
  77. cm.mu.Lock()
  78. defer cm.mu.Unlock()
  79. cm.eventDigest[connID] = digest
  80. }
  81. func (cm *reconnectCredentialManager) ConnDigest(connID uint8) ([]byte, error) {
  82. cm.mu.RLock()
  83. defer cm.mu.RUnlock()
  84. digest, ok := cm.connDigest[connID]
  85. if !ok {
  86. return nil, fmt.Errorf("no conneciton digest for connection %v", connID)
  87. }
  88. return digest, nil
  89. }
  90. func (cm *reconnectCredentialManager) SetConnDigest(connID uint8, digest []byte) {
  91. cm.mu.Lock()
  92. defer cm.mu.Unlock()
  93. cm.connDigest[connID] = digest
  94. }
  95. func (cm *reconnectCredentialManager) RefreshAuth(
  96. ctx context.Context,
  97. backoff *BackoffHandler,
  98. authenticate func(ctx context.Context, numPreviousAttempts int) (tunnelpogs.AuthOutcome, error),
  99. ) (retryTimer <-chan time.Time, err error) {
  100. authOutcome, err := authenticate(ctx, backoff.Retries())
  101. if err != nil {
  102. cm.authFail.WithLabelValues(err.Error()).Inc()
  103. if _, ok := backoff.GetBackoffDuration(ctx); ok {
  104. return backoff.BackoffTimer(), nil
  105. }
  106. return nil, err
  107. }
  108. // clear backoff timer
  109. backoff.SetGracePeriod()
  110. switch outcome := authOutcome.(type) {
  111. case tunnelpogs.AuthSuccess:
  112. cm.SetReconnectToken(outcome.JWT())
  113. cm.authSuccess.Inc()
  114. return timeAfter(outcome.RefreshAfter()), nil
  115. case tunnelpogs.AuthUnknown:
  116. duration := outcome.RefreshAfter()
  117. cm.authFail.WithLabelValues(outcome.Error()).Inc()
  118. return timeAfter(duration), nil
  119. case tunnelpogs.AuthFail:
  120. cm.authFail.WithLabelValues(outcome.Error()).Inc()
  121. return nil, outcome
  122. default:
  123. err := fmt.Errorf("refresh_auth: Unexpected outcome type %T", authOutcome)
  124. cm.authFail.WithLabelValues(err.Error()).Inc()
  125. return nil, err
  126. }
  127. }
  128. func ReconnectTunnel(
  129. ctx context.Context,
  130. muxer *h2mux.Muxer,
  131. config *TunnelConfig,
  132. logger logger.Service,
  133. connectionID uint8,
  134. originLocalAddr string,
  135. uuid uuid.UUID,
  136. credentialManager *reconnectCredentialManager,
  137. ) error {
  138. token, err := credentialManager.ReconnectToken()
  139. if err != nil {
  140. return err
  141. }
  142. eventDigest, err := credentialManager.EventDigest(connectionID)
  143. if err != nil {
  144. return err
  145. }
  146. connDigest, err := credentialManager.ConnDigest(connectionID)
  147. if err != nil {
  148. return err
  149. }
  150. config.TransportLogger.Debug("initiating RPC stream to reconnect")
  151. rpcClient, err := newTunnelRPCClient(ctx, muxer, config, reconnect)
  152. if err != nil {
  153. return err
  154. }
  155. defer rpcClient.Close()
  156. // Request server info without blocking tunnel registration; must use capnp library directly.
  157. serverInfoPromise := tunnelrpc.TunnelServer{Client: rpcClient.Client}.GetServerInfo(ctx, func(tunnelrpc.TunnelServer_getServerInfo_Params) error {
  158. return nil
  159. })
  160. LogServerInfo(serverInfoPromise.Result(), connectionID, config.Metrics, logger, config.TunnelEventChan)
  161. registration := rpcClient.ReconnectTunnel(
  162. ctx,
  163. token,
  164. eventDigest,
  165. connDigest,
  166. config.Hostname,
  167. config.RegistrationOptions(connectionID, originLocalAddr, uuid),
  168. )
  169. if registrationErr := registration.DeserializeError(); registrationErr != nil {
  170. // ReconnectTunnel RPC failure
  171. return processRegisterTunnelError(registrationErr, config.Metrics, reconnect)
  172. }
  173. return processRegistrationSuccess(config, logger, connectionID, registration, reconnect, credentialManager)
  174. }