throttling.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. import asyncio
  2. from aiogram import types, Dispatcher
  3. from aiogram.dispatcher import DEFAULT_RATE_LIMIT
  4. from aiogram.dispatcher.handler import CancelHandler, current_handler
  5. from aiogram.dispatcher.middlewares import BaseMiddleware
  6. from aiogram.utils.exceptions import Throttled
  7. class ThrottlingMiddleware(BaseMiddleware):
  8. """
  9. Simple middleware
  10. """
  11. def __init__(self, limit=DEFAULT_RATE_LIMIT, key_prefix='antiflood_'):
  12. self.rate_limit = limit
  13. self.prefix = key_prefix
  14. super(ThrottlingMiddleware, self).__init__()
  15. async def on_process_message(self, message: types.Message, data: dict):
  16. handler = current_handler.get()
  17. dispatcher = Dispatcher.get_current()
  18. if handler:
  19. limit = getattr(handler, "throttling_rate_limit", self.rate_limit)
  20. key = getattr(handler, "throttling_key", f"{self.prefix}_{handler.__name__}")
  21. else:
  22. limit = self.rate_limit
  23. key = f"{self.prefix}_message"
  24. try:
  25. await dispatcher.throttle(key, rate=limit)
  26. except Throttled as t:
  27. await self.message_throttled(message, t)
  28. raise CancelHandler()
  29. async def message_throttled(self, message: types.Message, throttled: Throttled):
  30. if throttled.exceeded_count <= 2:
  31. await message.reply("Too many requests!")