bucket_manager.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. // bucket_manager.go
  2. package golimiter
import (
	"context"
	"fmt"
	"sync"
	"time"
)
  8. // TokenBucketManager manages multiple token buckets.
  9. type TokenBucketManager struct {
  10. tokenBuckets map[string]*TokenBucket // Map of IP addresses to their corresponding token buckets
  11. mu sync.Mutex // Mutex for synchronization
  12. }
  13. // NewTokenBucketManager creates a new TokenBucketManager instance.
  14. func NewTokenBucketManager() *TokenBucketManager {
  15. return &TokenBucketManager{
  16. tokenBuckets: make(map[string]*TokenBucket),
  17. }
  18. }
  19. // GetTokenBucket checks whether a TokenBucket exists for the specified IP address.
  20. // If a bucket exists and is still valid, it returns the corresponding TokenBucket; otherwise, it returns nil.
  21. func (tm *TokenBucketManager) GetTokenBucket(ip string) *TokenBucket {
  22. tm.mu.Lock()
  23. defer tm.mu.Unlock()
  24. tb, exists := tm.tokenBuckets[ip]
  25. if exists {
  26. if tb.expirationTime.Before(time.Now()) {
  27. delete(tm.tokenBuckets, ip) // Remove expired bucket from the map
  28. return nil
  29. }
  30. return tb
  31. }
  32. return nil
  33. }
  34. // CreateTokenBucket creates a new TokenBucket for the given IP address and stores it in the manager.
  35. // If a bucket already exists, it returns an error.
  36. func (tm *TokenBucketManager) CreateTokenBucket(ip string, capacity, fillRate int, bucketLifetime time.Duration) (*TokenBucket, error) {
  37. tm.mu.Lock()
  38. defer tm.mu.Unlock()
  39. if _, exists := tm.tokenBuckets[ip]; exists {
  40. return nil, fmt.Errorf("token bucket for IP %s already exists", ip)
  41. }
  42. tb := NewTokenBucket(capacity, fillRate, bucketLifetime)
  43. tm.tokenBuckets[ip] = tb
  44. return tb, nil
  45. }
  46. // GetRemainingTokens returns the remaining token count for the specified IP address.
  47. // Returns -1 if the bucket does not exist.
  48. func (tm *TokenBucketManager) GetRemainingTokens(ip string) int {
  49. tb := tm.GetTokenBucket(ip)
  50. if tb == nil {
  51. return -1
  52. }
  53. tb.mu.Lock()
  54. defer tb.mu.Unlock()
  55. tb.AddTokensToBucket()
  56. return tb.tokens
  57. }
  58. // cleanupExpiredTokenBuckets removes expired token buckets from the manager.
  59. func (tm *TokenBucketManager) cleanupExpiredTokenBuckets() {
  60. tm.mu.Lock()
  61. defer tm.mu.Unlock()
  62. currentTime := time.Now()
  63. for ip, tb := range tm.tokenBuckets {
  64. if tb.expirationTime.Before(currentTime) {
  65. delete(tm.tokenBuckets, ip)
  66. }
  67. }
  68. }
  69. // StartCleanupTask starts a periodic task to clean up expired token buckets at a specified interval.
  70. func (tm *TokenBucketManager) StartCleanupTask(cleanupInterval time.Duration) {
  71. ticker := time.NewTicker(cleanupInterval)
  72. defer ticker.Stop()
  73. for {
  74. select {
  75. case <-ticker.C:
  76. tm.cleanupExpiredTokenBuckets()
  77. }
  78. }
  79. }