# BlackboxCreateAgent.py
  1. from __future__ import annotations
  2. import random
  3. import asyncio
  4. import re
  5. import json
  6. from pathlib import Path
  7. from aiohttp import ClientSession
  8. from typing import AsyncIterator, Optional
  9. from ..typing import AsyncResult, Messages
  10. from ..image import ImageResponse
  11. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  12. from ..cookies import get_cookies_dir
  13. from .. import debug
  14. class BlackboxCreateAgent(AsyncGeneratorProvider, ProviderModelMixin):
  15. url = "https://www.blackbox.ai"
  16. api_endpoints = {
  17. "llama-3.1-70b": "https://www.blackbox.ai/api/improve-prompt",
  18. "flux": "https://www.blackbox.ai/api/image-generator"
  19. }
  20. working = True
  21. supports_system_message = True
  22. supports_message_history = True
  23. default_model = 'llama-3.1-70b'
  24. chat_models = [default_model]
  25. image_models = ['flux']
  26. models = [*chat_models, *image_models]
  27. @classmethod
  28. def _get_cache_file(cls) -> Path:
  29. """Returns the path to the cache file."""
  30. dir = Path(get_cookies_dir())
  31. dir.mkdir(exist_ok=True)
  32. return dir / 'blackbox_create_agent.json'
  33. @classmethod
  34. def _load_cached_value(cls) -> str | None:
  35. cache_file = cls._get_cache_file()
  36. if cache_file.exists():
  37. try:
  38. with open(cache_file, 'r') as f:
  39. data = json.load(f)
  40. return data.get('validated_value')
  41. except Exception as e:
  42. debug.log(f"Error reading cache file: {e}")
  43. return None
  44. @classmethod
  45. def _save_cached_value(cls, value: str):
  46. cache_file = cls._get_cache_file()
  47. try:
  48. with open(cache_file, 'w') as f:
  49. json.dump({'validated_value': value}, f)
  50. except Exception as e:
  51. debug.log(f"Error writing to cache file: {e}")
  52. @classmethod
  53. async def fetch_validated(cls) -> Optional[str]:
  54. """
  55. Asynchronously retrieves the validated value from cache or website.
  56. :return: The validated value or None if retrieval fails.
  57. """
  58. cached_value = cls._load_cached_value()
  59. if cached_value:
  60. return cached_value
  61. js_file_pattern = r'static/chunks/\d{4}-[a-fA-F0-9]+\.js'
  62. v_pattern = r'L\s*=\s*[\'"]([0-9a-fA-F-]{36})[\'"]'
  63. def is_valid_context(text: str) -> bool:
  64. """Checks if the context is valid."""
  65. return any(char + '=' in text for char in 'abcdefghijklmnopqrstuvwxyz')
  66. async with ClientSession() as session:
  67. try:
  68. async with session.get(cls.url) as response:
  69. if response.status != 200:
  70. debug.log("Failed to download the page.")
  71. return cached_value
  72. page_content = await response.text()
  73. js_files = re.findall(js_file_pattern, page_content)
  74. for js_file in js_files:
  75. js_url = f"{cls.url}/_next/{js_file}"
  76. async with session.get(js_url) as js_response:
  77. if js_response.status == 200:
  78. js_content = await js_response.text()
  79. for match in re.finditer(v_pattern, js_content):
  80. start = max(0, match.start() - 50)
  81. end = min(len(js_content), match.end() + 50)
  82. context = js_content[start:end]
  83. if is_valid_context(context):
  84. validated_value = match.group(1)
  85. cls._save_cached_value(validated_value)
  86. return validated_value
  87. except Exception as e:
  88. debug.log(f"Error while retrieving validated_value: {e}")
  89. return cached_value
  90. @classmethod
  91. async def create_async_generator(
  92. cls,
  93. model: str,
  94. messages: Messages,
  95. proxy: str = None,
  96. prompt: str = None,
  97. **kwargs
  98. ) -> AsyncIterator[str | ImageResponse]:
  99. """
  100. Creates an async generator for text or image generation.
  101. """
  102. if model in cls.chat_models:
  103. async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
  104. yield text
  105. elif model in cls.image_models:
  106. prompt = messages[-1]['content']
  107. async for image in cls._generate_image(model, prompt, proxy=proxy, **kwargs):
  108. yield image
  109. else:
  110. raise ValueError(f"Model {model} not supported")
  111. @classmethod
  112. async def _generate_text(
  113. cls,
  114. model: str,
  115. messages: Messages,
  116. proxy: str = None,
  117. max_retries: int = 3,
  118. delay: int = 1,
  119. max_tokens: int = None,
  120. **kwargs
  121. ) -> AsyncIterator[str]:
  122. headers = cls._get_headers()
  123. for outer_attempt in range(2): # Add outer loop for retrying with a new key
  124. validated_value = await cls.fetch_validated()
  125. if not validated_value:
  126. raise RuntimeError("Failed to get validated value")
  127. async with ClientSession(headers=headers) as session:
  128. api_endpoint = cls.api_endpoints[model]
  129. data = {
  130. "messages": messages,
  131. "max_tokens": max_tokens,
  132. "validated": validated_value
  133. }
  134. for attempt in range(max_retries):
  135. try:
  136. async with session.post(api_endpoint, json=data, proxy=proxy) as response:
  137. response.raise_for_status()
  138. response_data = await response.json()
  139. if response_data.get('status') == 200 and 'prompt' in response_data:
  140. yield response_data['prompt']
  141. return # Successful execution
  142. else:
  143. raise KeyError("Invalid response format or missing 'prompt' key")
  144. except Exception as e:
  145. if attempt == max_retries - 1:
  146. if outer_attempt == 0: # If this is the first attempt with this key
  147. # Remove the cached key and try to get a new one
  148. cls._save_cached_value("")
  149. debug.log("Invalid key, trying to get a new one...")
  150. break # Exit the inner loop to get a new key
  151. else:
  152. raise RuntimeError(f"Error after all attempts: {str(e)}")
  153. else:
  154. wait_time = delay * (2 ** attempt) + random.uniform(0, 1)
  155. debug.log(f"Attempt {attempt + 1} failed. Retrying in {wait_time:.2f} seconds...")
  156. await asyncio.sleep(wait_time)
  157. @classmethod
  158. async def _generate_image(
  159. cls,
  160. model: str,
  161. prompt: str,
  162. proxy: str = None,
  163. **kwargs
  164. ) -> AsyncIterator[ImageResponse]:
  165. headers = {
  166. **cls._get_headers()
  167. }
  168. api_endpoint = cls.api_endpoints[model]
  169. async with ClientSession(headers=headers) as session:
  170. data = {
  171. "query": prompt
  172. }
  173. async with session.post(api_endpoint, json=data, proxy=proxy) as response:
  174. response.raise_for_status()
  175. response_data = await response.json()
  176. if 'markdown' in response_data:
  177. # Extract URL from markdown format: ![](url)
  178. image_url = re.search(r'\!\[\]\((.*?)\)', response_data['markdown'])
  179. if image_url:
  180. yield ImageResponse(images=[image_url.group(1)], alt=prompt)
  181. else:
  182. raise ValueError("Could not extract image URL from markdown")
  183. else:
  184. raise KeyError("'markdown' key not found in response")
  185. @staticmethod
  186. def _get_headers() -> dict:
  187. return {
  188. 'accept': '*/*',
  189. 'accept-language': 'en-US,en;q=0.9',
  190. 'authorization': f'Bearer 56c8eeff9971269d7a7e625ff88e8a83a34a556003a5c87c289ebe9a3d8a3d2c',
  191. 'content-type': 'application/json',
  192. 'origin': 'https://www.blackbox.ai',
  193. 'referer': 'https://www.blackbox.ai',
  194. 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
  195. }
  196. @classmethod
  197. async def create_async(
  198. cls,
  199. model: str,
  200. messages: Messages,
  201. proxy: str = None,
  202. **kwargs
  203. ) -> AsyncResult:
  204. """
  205. Creates an async response for the provider.
  206. Args:
  207. model: The model to use
  208. messages: The messages to process
  209. proxy: Optional proxy to use
  210. **kwargs: Additional arguments
  211. Returns:
  212. AsyncResult: The response from the provider
  213. """
  214. if not model:
  215. model = cls.default_model
  216. if model in cls.chat_models:
  217. async for text in cls._generate_text(model, messages, proxy=proxy, **kwargs):
  218. return text
  219. elif model in cls.image_models:
  220. prompt = messages[-1]['content']
  221. async for image in cls._generate_image(model, prompt, proxy=proxy, **kwargs):
  222. return image
  223. else:
  224. raise ValueError(f"Model {model} not supported")