RubiksAI.py

from __future__ import annotations

import asyncio
import aiohttp
import random
import string
import json
from urllib.parse import urlencode
from aiohttp import ClientSession

from ..typing import AsyncResult, Messages
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .helper import format_prompt


class RubiksAI(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Rubiks AI"
    url = "https://rubiks.ai"
    api_endpoint = "https://rubiks.ai/search/api.php"

    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = 'llama-3.1-70b-versatile'
    models = [default_model, 'gpt-4o-mini']

    model_aliases = {
        "llama-3.1-70b": "llama-3.1-70b-versatile",
    }

    @classmethod
    def get_model(cls, model: str) -> str:
        if model in cls.models:
            return model
        elif model in cls.model_aliases:
            return cls.model_aliases[model]
        else:
            return cls.default_model

    @staticmethod
    def generate_mid() -> str:
        """
        Generates a 'mid' string following the pattern:
        6 characters - 4 characters - 4 characters - 4 characters - 12 characters

        Example: 0r7v7b-quw4-kdy3-rvdu-ekief6xbuuq4
        """
        parts = [
            ''.join(random.choices(string.ascii_lowercase + string.digits, k=6)),
            ''.join(random.choices(string.ascii_lowercase + string.digits, k=4)),
            ''.join(random.choices(string.ascii_lowercase + string.digits, k=4)),
            ''.join(random.choices(string.ascii_lowercase + string.digits, k=4)),
            ''.join(random.choices(string.ascii_lowercase + string.digits, k=12))
        ]
        return '-'.join(parts)

    @staticmethod
    def create_referer(q: str, mid: str, model: str = '') -> str:
        """
        Creates a Referer URL with dynamic 'q' and 'mid' values, using urlencode
        for safe parameter encoding.
        """
        params = {'q': q, 'model': model, 'mid': mid}
        encoded_params = urlencode(params)
        return f'https://rubiks.ai/search/?{encoded_params}'

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        websearch: bool = False,
        **kwargs
    ) -> AsyncResult:
        """
        Creates an asynchronous generator that sends requests to the Rubiks AI API and yields the response.

        Parameters:
        - model (str): The model to use in the request.
        - messages (Messages): The messages to send as a prompt.
        - proxy (str, optional): Proxy URL, if needed.
        - websearch (bool, optional): Indicates whether to include search sources in the response. Defaults to False.
        """
        model = cls.get_model(model)
        prompt = format_prompt(messages)

        q_value = prompt
        mid_value = cls.generate_mid()
        referer = cls.create_referer(q=q_value, mid=mid_value, model=model)

        url = cls.api_endpoint
        params = {
            'q': q_value,
            'model': model,
            'id': '',
            'mid': mid_value
        }
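
        # Browser-like request headers; 'Accept: text/event-stream' asks the endpoint
        # to stream its answer as Server-Sent Events.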
        headers = {
            'Accept': 'text/event-stream',
            'Accept-Language': 'en-US,en;q=0.9',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'Pragma': 'no-cache',
            'Referer': referer,
            'Sec-Fetch-Dest': 'empty',
            'Sec-Fetch-Mode': 'cors',
            'Sec-Fetch-Site': 'same-origin',
            'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/129.0.0.0 Safari/537.36',
            'sec-ch-ua': '"Chromium";v="129", "Not=A?Brand";v="8"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"Linux"'
        }
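
        # Each useful line of the SSE stream is prefixed with "data: " and carries either
        # a search source ({'url', 'title'}) or an OpenAI-style chunk with 'choices' deltas.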
        try:
            timeout = aiohttp.ClientTimeout(total=None)
            async with ClientSession(timeout=timeout) as session:
                async with session.get(url, headers=headers, params=params, proxy=proxy) as response:
                    if response.status != 200:
                        yield f"Request ended with status code {response.status}"
                        return

                    assistant_text = ''
                    sources = []

                    async for line in response.content:
                        decoded_line = line.decode('utf-8').strip()
                        if not decoded_line.startswith('data: '):
                            continue
                        data = decoded_line[6:]
                        if data in ('[DONE]', '{"done": ""}'):
                            break
                        try:
                            json_data = json.loads(data)
                        except json.JSONDecodeError:
                            continue

                        if 'url' in json_data and 'title' in json_data:
                            if websearch:
                                sources.append({'title': json_data['title'], 'url': json_data['url']})
                        elif 'choices' in json_data:
                            for choice in json_data['choices']:
                                delta = choice.get('delta', {})
                                content = delta.get('content', '')
                                role = delta.get('role', '')
                                if role == 'assistant':
                                    continue
                                assistant_text += content

                    if websearch and sources:
                        sources_text = '\n'.join([f"{i+1}. [{s['title']}]: {s['url']}" for i, s in enumerate(sources)])
                        assistant_text += f"\n\n**Source:**\n{sources_text}"

                    yield assistant_text

        except asyncio.CancelledError:
            yield "The request was cancelled."
        except aiohttp.ClientError as e:
            yield f"An error occurred during the request: {e}"
        except Exception as e:
            yield f"An unexpected error occurred: {e}"
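
# Usage sketch (illustrative only): this provider is normally driven by the surrounding
# framework, and the relative imports above mean the module cannot run standalone.
# Assuming the package is importable, the generator could be consumed roughly like this:
#
#     import asyncio
#
#     async def demo():
#         messages = [{"role": "user", "content": "What is a Rubik's Cube?"}]
#         async for chunk in RubiksAI.create_async_generator(
#             model="gpt-4o-mini", messages=messages, websearch=True
#         ):
#             print(chunk)
#
#     asyncio.run(demo())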