selector.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. package features
  2. import (
  3. "context"
  4. "encoding/json"
  5. "fmt"
  6. "hash/fnv"
  7. "net"
  8. "slices"
  9. "sync"
  10. "time"
  11. "github.com/rs/zerolog"
  12. )
  // Parameters for locating and periodically reloading the remote feature record.
  const (
  	// featureSelectorHostname is the DNS name whose TXT record holds the remote feature settings.
  	featureSelectorHostname = "cfd-features.argotunnel.com"
  	// lookupTimeout bounds each individual TXT lookup.
  	lookupTimeout = 10 * time.Second
  	// defaultLookupFreq is the interval between background re-fetches of the record.
  	defaultLookupFreq = 1 * time.Hour
  )
  // featuresRecord is the JSON document stored in the feature-selector TXT record.
  //
  // Keys in the TXT record with no matching field are ignored when unmarshalling,
  // and fields whose key is absent from the record keep their Go zero value.
  type featuresRecord struct {
  	// DatagramV3Percentage is the datagram v3 rollout percentage. An account is opted in
  	// when this value is strictly greater than its bucket in [0, 100) (see accountEnabled
  	// and switchThreshold), so 0 enables no accounts and 100 or more enables all of them.
  	DatagramV3Percentage uint32 `json:"dv3_1"`
  	// DatagramV3Percentage int32 `json:"dv3"` // Removed in TUN-9291
  	// PostQuantumPercentage int32 `json:"pq"` // Removed in TUN-7970
  }
  25. func NewFeatureSelector(ctx context.Context, accountTag string, cliFeatures []string, pq bool, logger *zerolog.Logger) (FeatureSelector, error) {
  26. return newFeatureSelector(ctx, accountTag, logger, newDNSResolver(), cliFeatures, pq, defaultLookupFreq)
  27. }
  28. type FeatureSelector interface {
  29. Snapshot() FeatureSnapshot
  30. }
  31. // FeatureSelector determines if this account will try new features; loaded once during startup.
  32. type featureSelector struct {
  33. accountHash uint32
  34. logger *zerolog.Logger
  35. resolver resolver
  36. staticFeatures staticFeatures
  37. cliFeatures []string
  38. // lock protects concurrent access to dynamic features
  39. lock sync.RWMutex
  40. remoteFeatures featuresRecord
  41. }
  42. func newFeatureSelector(ctx context.Context, accountTag string, logger *zerolog.Logger, resolver resolver, cliFeatures []string, pq bool, refreshFreq time.Duration) (*featureSelector, error) {
  43. // Combine default features and user-provided features
  44. var pqMode *PostQuantumMode
  45. if pq {
  46. mode := PostQuantumStrict
  47. pqMode = &mode
  48. cliFeatures = append(cliFeatures, FeaturePostQuantum)
  49. }
  50. staticFeatures := staticFeatures{
  51. PostQuantumMode: pqMode,
  52. }
  53. selector := &featureSelector{
  54. accountHash: switchThreshold(accountTag),
  55. logger: logger,
  56. resolver: resolver,
  57. staticFeatures: staticFeatures,
  58. cliFeatures: dedupAndRemoveFeatures(cliFeatures),
  59. }
  60. // Load the remote features
  61. if err := selector.refresh(ctx); err != nil {
  62. logger.Err(err).Msg("Failed to fetch features, default to disable")
  63. }
  64. // Spin off reloading routine
  65. go selector.refreshLoop(ctx, refreshFreq)
  66. return selector, nil
  67. }
  68. func (fs *featureSelector) Snapshot() FeatureSnapshot {
  69. fs.lock.RLock()
  70. defer fs.lock.RUnlock()
  71. return FeatureSnapshot{
  72. PostQuantum: fs.postQuantumMode(),
  73. DatagramVersion: fs.datagramVersion(),
  74. FeaturesList: fs.clientFeatures(),
  75. }
  76. }
  77. func (fs *featureSelector) accountEnabled(percentage uint32) bool {
  78. return percentage > fs.accountHash
  79. }
  80. func (fs *featureSelector) postQuantumMode() PostQuantumMode {
  81. if fs.staticFeatures.PostQuantumMode != nil {
  82. return *fs.staticFeatures.PostQuantumMode
  83. }
  84. return PostQuantumPrefer
  85. }
  86. func (fs *featureSelector) datagramVersion() DatagramVersion {
  87. // If user provides the feature via the cli, we take it as priority over remote feature evaluation
  88. if slices.Contains(fs.cliFeatures, FeatureDatagramV3_1) {
  89. return DatagramV3
  90. }
  91. // If the user specifies DatagramV2, we also take that over remote
  92. if slices.Contains(fs.cliFeatures, FeatureDatagramV2) {
  93. return DatagramV2
  94. }
  95. if fs.accountEnabled(fs.remoteFeatures.DatagramV3Percentage) {
  96. return DatagramV3
  97. }
  98. return DatagramV2
  99. }
  100. // clientFeatures will return the list of currently available features that cloudflared should provide to the edge.
  101. func (fs *featureSelector) clientFeatures() []string {
  102. // Evaluate any remote features along with static feature list to construct the list of features
  103. return dedupAndRemoveFeatures(slices.Concat(defaultFeatures, fs.cliFeatures, []string{string(fs.datagramVersion())}))
  104. }
  105. func (fs *featureSelector) refresh(ctx context.Context) error {
  106. record, err := fs.resolver.lookupRecord(ctx)
  107. if err != nil {
  108. return err
  109. }
  110. var features featuresRecord
  111. if err := json.Unmarshal(record, &features); err != nil {
  112. return err
  113. }
  114. fs.lock.Lock()
  115. defer fs.lock.Unlock()
  116. fs.remoteFeatures = features
  117. return nil
  118. }
  119. func (fs *featureSelector) refreshLoop(ctx context.Context, refreshFreq time.Duration) {
  120. ticker := time.NewTicker(refreshFreq)
  121. for {
  122. select {
  123. case <-ctx.Done():
  124. return
  125. case <-ticker.C:
  126. err := fs.refresh(ctx)
  127. if err != nil {
  128. fs.logger.Err(err).Msg("Failed to refresh feature selector")
  129. }
  130. }
  131. }
  132. }
  // resolver fetches the raw, still-encoded remote feature record. It does not decode
  // anything itself: featureSelector.refresh unmarshals the returned bytes as JSON into a
  // featuresRecord.
  type resolver interface {
  	lookupRecord(ctx context.Context) ([]byte, error)
  }
  // dnsResolver is the resolver that NewFeatureSelector wires in; it reads the feature
  // record from a DNS TXT lookup through the wrapped *net.Resolver.
  type dnsResolver struct {
  	resolver *net.Resolver
  }

  // newDNSResolver returns a dnsResolver backed by the process-wide net.DefaultResolver.
  func newDNSResolver() *dnsResolver {
  	dr := new(dnsResolver)
  	dr.resolver = net.DefaultResolver
  	return dr
  }
  145. func (dr *dnsResolver) lookupRecord(ctx context.Context) ([]byte, error) {
  146. ctx, cancel := context.WithTimeout(ctx, lookupTimeout)
  147. defer cancel()
  148. records, err := dr.resolver.LookupTXT(ctx, featureSelectorHostname)
  149. if err != nil {
  150. return nil, err
  151. }
  152. if len(records) == 0 {
  153. return nil, fmt.Errorf("No TXT record found for %s to determine which features to opt-in", featureSelectorHostname)
  154. }
  155. return []byte(records[0]), nil
  156. }
  // switchThreshold deterministically maps accountTag to a bucket in [0, 100) using
  // 32-bit FNV-1a, so a given account always lands in the same rollout bucket.
  func switchThreshold(accountTag string) uint32 {
  	const buckets = 100
  	hasher := fnv.New32a()
  	// hash.Hash's Write never returns an error, so both results are discarded.
  	_, _ = hasher.Write([]byte(accountTag))
  	return hasher.Sum32() % buckets
  }