Theb.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. from __future__ import annotations
  2. import time
  3. from ...typing import CreateResult, Messages
  4. from ..base_provider import AbstractProvider
  5. from ..helper import format_prompt
  6. from ...webdriver import WebDriver, WebDriverSession, element_send_text
  7. models = {
  8. "theb-ai": "TheB.AI",
  9. "theb-ai-free": "TheB.AI Free",
  10. "gpt-3.5-turbo": "GPT-3.5 Turbo (New)",
  11. "gpt-3.5-turbo-16k": "GPT-3.5-16K",
  12. "gpt-4-turbo": "GPT-4 Turbo",
  13. "gpt-4": "GPT-4",
  14. "gpt-4-32k": "GPT-4 32K",
  15. "claude-2": "Claude 2",
  16. "claude-instant-1": "Claude Instant 1.2",
  17. "palm-2": "PaLM 2",
  18. "palm-2-32k": "PaLM 2 32K",
  19. "palm-2-codey": "Codey",
  20. "palm-2-codey-32k": "Codey 32K",
  21. "vicuna-13b-v1.5": "Vicuna v1.5 13B",
  22. "llama-2-7b-chat": "Llama 2 7B",
  23. "llama-2-13b-chat": "Llama 2 13B",
  24. "llama-2-70b-chat": "Llama 2 70B",
  25. "code-llama-7b": "Code Llama 7B",
  26. "code-llama-13b": "Code Llama 13B",
  27. "code-llama-34b": "Code Llama 34B",
  28. "qwen-7b-chat": "Qwen 7B"
  29. }
  30. class Theb(AbstractProvider):
  31. label = "TheB.AI"
  32. url = "https://beta.theb.ai"
  33. working = True
  34. supports_stream = True
  35. models = models.keys()
  36. @classmethod
  37. def create_completion(
  38. cls,
  39. model: str,
  40. messages: Messages,
  41. stream: bool,
  42. proxy: str = None,
  43. webdriver: WebDriver = None,
  44. virtual_display: bool = True,
  45. **kwargs
  46. ) -> CreateResult:
  47. if model in models:
  48. model = models[model]
  49. prompt = format_prompt(messages)
  50. web_session = WebDriverSession(webdriver, virtual_display=virtual_display, proxy=proxy)
  51. with web_session as driver:
  52. from selenium.webdriver.common.by import By
  53. from selenium.webdriver.support.ui import WebDriverWait
  54. from selenium.webdriver.support import expected_conditions as EC
  55. from selenium.webdriver.common.keys import Keys
  56. # Register fetch hook
  57. script = """
  58. window._fetch = window.fetch;
  59. window.fetch = async (url, options) => {
  60. // Call parent fetch method
  61. const response = await window._fetch(url, options);
  62. if (!url.startsWith("/api/conversation")) {
  63. return result;
  64. }
  65. // Copy response
  66. copy = response.clone();
  67. window._reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
  68. return copy;
  69. }
  70. window._last_message = "";
  71. """
  72. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  73. "source": script
  74. })
  75. try:
  76. driver.get(f"{cls.url}/home")
  77. wait = WebDriverWait(driver, 5)
  78. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  79. except:
  80. driver = web_session.reopen()
  81. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  82. "source": script
  83. })
  84. driver.get(f"{cls.url}/home")
  85. wait = WebDriverWait(driver, 240)
  86. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  87. try:
  88. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  89. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  90. except:
  91. pass
  92. if model:
  93. # Load model panel
  94. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "#SelectModel svg")))
  95. time.sleep(0.1)
  96. driver.find_element(By.CSS_SELECTOR, "#SelectModel svg").click()
  97. try:
  98. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  99. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  100. except:
  101. pass
  102. # Select model
  103. selector = f"div.flex-col div.items-center span[title='{model}']"
  104. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, selector)))
  105. span = driver.find_element(By.CSS_SELECTOR, selector)
  106. container = span.find_element(By.XPATH, "//div/../..")
  107. button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border")
  108. button.click()
  109. # Submit prompt
  110. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  111. element_send_text(driver.find_element(By.ID, "textareaAutosize"), prompt)
  112. # Read response with reader
  113. script = """
  114. if(window._reader) {
  115. chunk = await window._reader.read();
  116. if (chunk['done']) {
  117. return null;
  118. }
  119. message = '';
  120. chunk['value'].split('\\r\\n').forEach((line, index) => {
  121. if (line.startsWith('data: ')) {
  122. try {
  123. line = JSON.parse(line.substring('data: '.length));
  124. message = line["args"]["content"];
  125. } catch(e) { }
  126. }
  127. });
  128. if (message) {
  129. try {
  130. return message.substring(window._last_message.length);
  131. } finally {
  132. window._last_message = message;
  133. }
  134. }
  135. }
  136. return '';
  137. """
  138. while True:
  139. chunk = driver.execute_script(script)
  140. if chunk:
  141. yield chunk
  142. elif chunk != "":
  143. break
  144. else:
  145. time.sleep(0.1)