limiter.go 2.2 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. package handlers
  import (
	"errors"
	"log"
	"math"
	"net"
	"net/http"
	"strings"
	"sync"

	"golang.org/x/time/rate"

	"codeberg.org/vnpower/pixivfe/v2/config"
	"codeberg.org/vnpower/pixivfe/v2/routes"
)
  // CanRequestSkipLimiter reports whether r should bypass the rate limiter.
// It returns true when the URL path starts with /img/, /css/, /js/,
// /proxy/s.pximg.net/ or /favicon.ico. All checks are plain prefix matches,
// so e.g. "/favicon.ico.bak" also qualifies.
func CanRequestSkipLimiter(r *http.Request) bool {
	exemptPrefixes := [...]string{
		"/img/",
		"/css/",
		"/js/",
		"/proxy/s.pximg.net/",
		"/favicon.ico",
	}
	for _, prefix := range exemptPrefixes {
		if strings.HasPrefix(r.URL.Path, prefix) {
			return true
		}
	}
	return false
}
  // TODO: Should the middlewares move to a separate file?
  22. // IPRateLimiter represents an IP rate limiter.
  23. type IPRateLimiter struct {
  24. ips map[string]*rate.Limiter
  25. mu *sync.RWMutex
  26. limiter *rate.Limiter
  27. }
  28. // NewIPRateLimiter creates a new instance of IPRateLimiter with the given rate limit.
  29. func NewIPRateLimiter(r rate.Limit, burst int) *IPRateLimiter {
  30. return &IPRateLimiter{
  31. ips: make(map[string]*rate.Limiter),
  32. mu: &sync.RWMutex{},
  33. limiter: rate.NewLimiter(r, burst),
  34. }
  35. }
  36. // Allow checks if the request from the given IP is allowed.
  37. func (lim *IPRateLimiter) Allow(ip string) bool {
  38. lim.mu.RLock()
  39. rl, exists := lim.ips[ip]
  40. lim.mu.RUnlock()
  41. if !exists {
  42. lim.mu.Lock()
  43. rl, exists = lim.ips[ip]
  44. if !exists {
  45. rl = rate.NewLimiter(lim.limiter.Limit(), lim.limiter.Burst())
  46. lim.ips[ip] = rl
  47. }
  48. lim.mu.Unlock()
  49. }
  50. return rl.Allow()
  51. }
  52. var limiter *IPRateLimiter
  53. func InitializeRateLimiter() {
  54. r := float64(config.GlobalServerConfig.RequestLimit) / 30.0
  55. if config.GlobalServerConfig.RequestLimit < 1 {
  56. r = math.Inf(1)
  57. }
  58. limiter = NewIPRateLimiter(rate.Limit(r), 3)
  59. }
  60. func RateLimitRequest(handler http.HandlerFunc) http.HandlerFunc {
  61. return func(w http.ResponseWriter, r *http.Request) {
  62. ip, _, _ := net.SplitHostPort(r.RemoteAddr)
  63. if CanRequestSkipLimiter(r) {
  64. handler(w, r)
  65. return
  66. }
  67. if !limiter.Allow(ip) {
  68. CatchError(func(w http.ResponseWriter, r *http.Request) error {
  69. err := errors.New("Too many requests")
  70. GetUserContext(r).Err = err
  71. GetUserContext(r).ErrorStatusCodeOverride = http.StatusTooManyRequests
  72. err = routes.ErrorPage(w, r, err)
  73. if err != nil {
  74. println("Error rendering error route: %s", err)
  75. }
  76. return err
  77. })(w, r)
  78. } else {
  79. handler(w, r)
  80. }
  81. }
  82. }