# Blackbox2.py
from __future__ import annotations

import asyncio
import random
from typing import AsyncGenerator

from aiohttp import ClientSession

from .. import debug
from ..image import ImageResponse
from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  10. class Blackbox2(AsyncGeneratorProvider, ProviderModelMixin):
  11. url = "https://www.blackbox.ai"
  12. api_endpoints = {
  13. "llama-3.1-70b": "https://www.blackbox.ai/api/improve-prompt",
  14. "flux": "https://www.blackbox.ai/api/image-generator"
  15. }
  16. working = True
  17. supports_system_message = True
  18. supports_message_history = True
  19. supports_stream = False
  20. default_model = 'llama-3.1-70b'
  21. chat_models = ['llama-3.1-70b']
  22. image_models = ['flux']
  23. models = [*chat_models, *image_models]
  24. @classmethod
  25. async def create_async_generator(
  26. cls,
  27. model: str,
  28. messages: Messages,
  29. proxy: str = None,
  30. max_retries: int = 3,
  31. delay: int = 1,
  32. **kwargs
  33. ) -> AsyncResult:
  34. if not model:
  35. model = cls.default_model
  36. if model in cls.chat_models:
  37. async for result in cls._generate_text(model, messages, proxy, max_retries, delay):
  38. yield result
  39. elif model in cls.image_models:
  40. prompt = messages[-1]["content"] if prompt is None else prompt
  41. async for result in cls._generate_image(model, prompt, proxy):
  42. yield result
  43. else:
  44. raise ValueError(f"Unsupported model: {model}")
  45. @classmethod
  46. async def _generate_text(
  47. cls,
  48. model: str,
  49. messages: Messages,
  50. proxy: str = None,
  51. max_retries: int = 3,
  52. delay: int = 1
  53. ) -> AsyncGenerator:
  54. headers = cls._get_headers()
  55. api_endpoint = cls.api_endpoints[model]
  56. data = {
  57. "messages": messages,
  58. "max_tokens": None
  59. }
  60. async with ClientSession(headers=headers) as session:
  61. for attempt in range(max_retries):
  62. try:
  63. async with session.post(api_endpoint, json=data, proxy=proxy) as response:
  64. response.raise_for_status()
  65. response_data = await response.json()
  66. if 'prompt' in response_data:
  67. yield response_data['prompt']
  68. return
  69. else:
  70. raise KeyError("'prompt' key not found in the response")
  71. except Exception as e:
  72. if attempt == max_retries - 1:
  73. raise RuntimeError(f"Error after {max_retries} attempts: {str(e)}")
  74. else:
  75. wait_time = delay * (2 ** attempt) + random.uniform(0, 1)
  76. debug.log(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...")
  77. await asyncio.sleep(wait_time)
  78. @classmethod
  79. async def _generate_image(
  80. cls,
  81. model: str,
  82. prompt: str,
  83. proxy: str = None
  84. ) -> AsyncGenerator:
  85. headers = cls._get_headers()
  86. api_endpoint = cls.api_endpoints[model]
  87. async with ClientSession(headers=headers) as session:
  88. data = {
  89. "query": prompt
  90. }
  91. async with session.post(api_endpoint, headers=headers, json=data, proxy=proxy) as response:
  92. response.raise_for_status()
  93. response_data = await response.json()
  94. if 'markdown' in response_data:
  95. image_url = response_data['markdown'].split('(')[1].split(')')[0]
  96. yield ImageResponse(images=image_url, alt=prompt)
  97. @staticmethod
  98. def _get_headers() -> dict:
  99. return {
  100. 'accept': '*/*',
  101. 'accept-language': 'en-US,en;q=0.9',
  102. 'content-type': 'text/plain;charset=UTF-8',
  103. 'origin': 'https://www.blackbox.ai',
  104. 'priority': 'u=1, i',
  105. 'referer': 'https://www.blackbox.ai',
  106. 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
  107. }