ratelimit.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
  2. // See LICENSE.txt for license information.
  3. package app
import (
	"crypto/sha256"
	"encoding/hex"
	"math"
	"net/http"
	"strconv"
	"strings"

	"github.com/mattermost/mattermost-server/v5/mlog"
	"github.com/mattermost/mattermost-server/v5/model"
	"github.com/mattermost/mattermost-server/v5/utils"
	"github.com/pkg/errors"
	"github.com/throttled/throttled"
	"github.com/throttled/throttled/store/memstore"
)
// RateLimiter throttles HTTP requests with a throttled GCRA rate limiter,
// charging each request against a bucket whose key is derived from the
// request itself (see GenerateKey).
type RateLimiter struct {
	// throttledRateLimiter enforces the configured quota for each key.
	throttledRateLimiter *throttled.GCRARateLimiter
	// useAuth keys requests by their auth token (from settings.VaryByUser).
	useAuth bool
	// useIP keys requests by client IP address (from settings.VaryByRemoteAddr).
	useIP bool
	// header, when non-empty, names a request header whose lower-cased value
	// is appended to the key (from settings.VaryByHeader).
	header string
	// trustedProxyIPHeader is handed to utils.GetIpAddress when resolving the
	// client IP for IP-based keys.
	trustedProxyIPHeader []string
}
  23. func NewRateLimiter(settings *model.RateLimitSettings, trustedProxyIPHeader []string) (*RateLimiter, error) {
  24. store, err := memstore.New(*settings.MemoryStoreSize)
  25. if err != nil {
  26. return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_memory_store"))
  27. }
  28. quota := throttled.RateQuota{
  29. MaxRate: throttled.PerSec(*settings.PerSec),
  30. MaxBurst: *settings.MaxBurst,
  31. }
  32. throttledRateLimiter, err := throttled.NewGCRARateLimiter(store, quota)
  33. if err != nil {
  34. return nil, errors.Wrap(err, utils.T("api.server.start_server.rate_limiting_rate_limiter"))
  35. }
  36. return &RateLimiter{
  37. throttledRateLimiter: throttledRateLimiter,
  38. useAuth: *settings.VaryByUser,
  39. useIP: *settings.VaryByRemoteAddr,
  40. header: settings.VaryByHeader,
  41. trustedProxyIPHeader: trustedProxyIPHeader,
  42. }, nil
  43. }
  44. func (rl *RateLimiter) GenerateKey(r *http.Request) string {
  45. key := ""
  46. if rl.useAuth {
  47. token, tokenLocation := ParseAuthTokenFromRequest(r)
  48. if tokenLocation != TokenLocationNotFound {
  49. key += token
  50. } else if rl.useIP { // If we don't find an authentication token and IP based is enabled, fall back to IP
  51. key += utils.GetIpAddress(r, rl.trustedProxyIPHeader)
  52. }
  53. } else if rl.useIP { // Only if Auth based is not enabed do we use a plain IP based
  54. key += utils.GetIpAddress(r, rl.trustedProxyIPHeader)
  55. }
  56. // Note that most of the time the user won't have to set this because the utils.GetIpAddress above tries the
  57. // most common headers anyway.
  58. if rl.header != "" {
  59. key += strings.ToLower(r.Header.Get(rl.header))
  60. }
  61. return key
  62. }
  63. func (rl *RateLimiter) RateLimitWriter(key string, w http.ResponseWriter) bool {
  64. limited, context, err := rl.throttledRateLimiter.RateLimit(key, 1)
  65. if err != nil {
  66. mlog.Critical("Internal server error when rate limiting. Rate Limiting broken.", mlog.Err(err))
  67. return false
  68. }
  69. setRateLimitHeaders(w, context)
  70. if limited {
  71. mlog.Error("Denied due to throttling settings code=429", mlog.String("key", key))
  72. http.Error(w, "limit exceeded", 429)
  73. }
  74. return limited
  75. }
  76. func (rl *RateLimiter) UserIdRateLimit(userId string, w http.ResponseWriter) bool {
  77. if rl.useAuth {
  78. if rl.RateLimitWriter(userId, w) {
  79. return true
  80. }
  81. }
  82. return false
  83. }
  84. func (rl *RateLimiter) RateLimitHandler(wrappedHandler http.Handler) http.Handler {
  85. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  86. key := rl.GenerateKey(r)
  87. limited := rl.RateLimitWriter(key, w)
  88. if !limited {
  89. wrappedHandler.ServeHTTP(w, r)
  90. }
  91. })
  92. }
  93. // Copied from https://github.com/throttled/throttled http.go
  94. func setRateLimitHeaders(w http.ResponseWriter, context throttled.RateLimitResult) {
  95. if v := context.Limit; v >= 0 {
  96. w.Header().Add("X-RateLimit-Limit", strconv.Itoa(v))
  97. }
  98. if v := context.Remaining; v >= 0 {
  99. w.Header().Add("X-RateLimit-Remaining", strconv.Itoa(v))
  100. }
  101. if v := context.ResetAfter; v >= 0 {
  102. vi := int(math.Ceil(v.Seconds()))
  103. w.Header().Add("X-RateLimit-Reset", strconv.Itoa(vi))
  104. }
  105. if v := context.RetryAfter; v >= 0 {
  106. vi := int(math.Ceil(v.Seconds()))
  107. w.Header().Add("Retry-After", strconv.Itoa(vi))
  108. }
  109. }