# __init__.py
  1. import aio_pika
  2. import aiopg
  3. import asyncio
  4. import json
  5. import os
  6. from aio_pika.pool import Pool
  7. from distutils.util import strtobool
  8. async def consume(loop, sql_template=None, logger=None, config=None, consumer_pool_size=10):
  9. if config is None:
  10. config = {
  11. "mq_host": os.environ.get('MQ_HOST'),
  12. "mq_port": int(os.environ.get('MQ_PORT', '5672')),
  13. "mq_vhost": os.environ.get('MQ_VHOST'),
  14. "mq_user": os.environ.get('MQ_USER'),
  15. "mq_pass": os.environ.get('MQ_PASS'),
  16. "mq_queue": os.environ.get('MQ_QUEUE'),
  17. "mq_queue_durable": bool(strtobool(os.environ.get('MQ_QUEUE_DURABLE', 'True'))),
  18. "mq_exchange": os.environ.get("MQ_EXCHANGE"),
  19. "mq_routing_key": os.environ.get("MQ_ROUTING_KEY"),
  20. "db_host": os.environ.get('DB_HOST'),
  21. "db_port": int(os.environ.get('DB_PORT', '5432')),
  22. "db_user": os.environ.get('DB_USER'),
  23. "db_pass": os.environ.get('DB_PASS'),
  24. "db_database": os.environ.get('DB_DATABASE'),
  25. "consumer_pool_size": os.environ.get("CONSUMER_POOL_SIZE"),
  26. "sql_template": os.environ.get('SQL_TEMPLATE')
  27. }
  28. if sql_template is None:
  29. sql_template = config.get("sql_template")
  30. if "consumer_pool_size" in config:
  31. if config.get("consumer_pool_size"):
  32. try:
  33. consumer_pool_size = int(config.get("consumer_pool_size"))
  34. except TypeError as e:
  35. if logger:
  36. logger.error(f"Invalid pool size: {consumer_pool_size}")
  37. raise e
  38. db_pool = await aiopg.create_pool(
  39. host=config.get("db_host"),
  40. user=config.get("db_user"),
  41. password=config.get("db_pass"),
  42. database=config.get("db_database"),
  43. port=config.get("db_port"),
  44. minsize=consumer_pool_size,
  45. maxsize=consumer_pool_size * 2
  46. )
  47. async def get_connection():
  48. return await aio_pika.connect(
  49. host=config.get("mq_host"),
  50. port=config.get("mq_port"),
  51. login=config.get("mq_user"),
  52. password=config.get("mq_pass"),
  53. virtualhost=config.get("mq_vhost"),
  54. loop=loop
  55. )
  56. connection_pool = Pool(get_connection, max_size=consumer_pool_size, loop=loop)
  57. async def get_channel():
  58. async with connection_pool.acquire() as connection:
  59. return await connection.channel()
  60. channel_pool = Pool(get_channel, max_size=consumer_pool_size, loop=loop)
  61. async def _push_to_dead_letter_queue(message, channel):
  62. exchange = await channel.get_exchange(config.get("mq_exchange"))
  63. await exchange.publish(
  64. message=aio_pika.Message(message.encode("utf-8")),
  65. routing_key=config.get("mq_routing_key")
  66. )
  67. async def _consume():
  68. async with channel_pool.acquire() as channel:
  69. queue = await channel.declare_queue(
  70. config.get("mq_queue"), durable=config.get("mq_queue_durable"), auto_delete=False
  71. )
  72. db_conn = await db_pool.acquire()
  73. cursor = await db_conn.cursor()
  74. while True:
  75. try:
  76. m = await queue.get(timeout=5 * consumer_pool_size)
  77. message = m.body.decode('utf-8')
  78. if logger:
  79. logger.debug(f"Message {message} inserting to db")
  80. try:
  81. await cursor.execute(sql_template, (message,))
  82. except Exception as e:
  83. if logger:
  84. logger.error(f"DB Error: {e}, pushing message to dead letter queue!")
  85. _push_to_dead_letter_queue(message, channel)
  86. finally:
  87. await m.ack()
  88. except aio_pika.exceptions.QueueEmpty:
  89. db_conn.close()
  90. if logger:
  91. logger.info("Queue empty. Stopping.")
  92. break
  93. async with connection_pool, channel_pool:
  94. consumer_pool = []
  95. if logger:
  96. logger.info("Consumers started")
  97. for _ in range(consumer_pool_size):
  98. consumer_pool.append(_consume())
  99. await asyncio.gather(*consumer_pool)