limiter.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package handlers
  import (
	"errors"
	"math"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"golang.org/x/time/rate"

	"codeberg.org/vnpower/pixivfe/v2/config"
	"codeberg.org/vnpower/pixivfe/v2/routes"
)
  // CanRequestSkipLimiter reports whether r is exempt from rate limiting.
// Exempt requests are those whose URL path starts with /img/, /css/, /js/
// (presumably static assets) or /proxy/s.pximg.net/.
func CanRequestSkipLimiter(r *http.Request) bool {
	exemptPrefixes := [...]string{"/img/", "/css/", "/js/", "/proxy/s.pximg.net/"}
	for _, prefix := range exemptPrefixes {
		if strings.HasPrefix(r.URL.Path, prefix) {
			return true
		}
	}
	return false
}
  20. // Todo: Should we put middlewares in a separate file?
  21. // IPRateLimiter represents an IP rate limiter.
  22. type IPRateLimiter struct {
  23. ips map[string]*rate.Limiter
  24. mu *sync.RWMutex
  25. limiter *rate.Limiter
  26. }
  27. // NewIPRateLimiter creates a new instance of IPRateLimiter with the given rate limit.
  28. func NewIPRateLimiter(r rate.Limit, burst int) *IPRateLimiter {
  29. return &IPRateLimiter{
  30. ips: make(map[string]*rate.Limiter),
  31. mu: &sync.RWMutex{},
  32. limiter: rate.NewLimiter(r, burst),
  33. }
  34. }
  35. // Allow checks if the request from the given IP is allowed.
  36. func (lim *IPRateLimiter) Allow(ip string) bool {
  37. lim.mu.RLock()
  38. rl, exists := lim.ips[ip]
  39. lim.mu.RUnlock()
  40. if !exists {
  41. lim.mu.Lock()
  42. rl, exists = lim.ips[ip]
  43. if !exists {
  44. rl = rate.NewLimiter(lim.limiter.Limit(), lim.limiter.Burst())
  45. lim.ips[ip] = rl
  46. }
  47. lim.mu.Unlock()
  48. }
  49. return rl.Allow()
  50. }
  51. var limiter *IPRateLimiter
  52. func InitializeRateLimiter() {
  53. r := float64(config.GlobalConfig.RequestLimit) / 30.0
  54. if config.GlobalConfig.RequestLimit < 1 {
  55. r = math.Inf(1)
  56. }
  57. limiter = NewIPRateLimiter(rate.Limit(r), 3)
  58. }
  59. func RateLimitRequest(h http.Handler) http.Handler {
  60. return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  61. ip, _, _ := net.SplitHostPort(r.RemoteAddr)
  62. if CanRequestSkipLimiter(r) {
  63. h.ServeHTTP(w, r)
  64. return
  65. }
  66. if !limiter.Allow(ip) {
  67. routes.ErrorPage(w, r, errors.New("Too many requests"), http.StatusTooManyRequests)
  68. } else {
  69. h.ServeHTTP(w, r)
  70. }
  71. })
  72. }