Theb.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. from __future__ import annotations
  2. import time
  3. from ...typing import CreateResult, Messages
  4. from ..base_provider import AbstractProvider
  5. from ..helper import format_prompt
  6. models = {
  7. "theb-ai": "TheB.AI",
  8. "theb-ai-free": "TheB.AI Free",
  9. "gpt-3.5-turbo": "GPT-3.5 Turbo (New)",
  10. "gpt-3.5-turbo-16k": "GPT-3.5-16K",
  11. "gpt-4-turbo": "GPT-4 Turbo",
  12. "gpt-4": "GPT-4",
  13. "gpt-4-32k": "GPT-4 32K",
  14. "claude-2": "Claude 2",
  15. "claude-instant-1": "Claude Instant 1.2",
  16. "palm-2": "PaLM 2",
  17. "palm-2-32k": "PaLM 2 32K",
  18. "palm-2-codey": "Codey",
  19. "palm-2-codey-32k": "Codey 32K",
  20. "vicuna-13b-v1.5": "Vicuna v1.5 13B",
  21. "llama-2-7b-chat": "Llama 2 7B",
  22. "llama-2-13b-chat": "Llama 2 13B",
  23. "llama-2-70b-chat": "Llama 2 70B",
  24. "code-llama-7b": "Code Llama 7B",
  25. "code-llama-13b": "Code Llama 13B",
  26. "code-llama-34b": "Code Llama 34B",
  27. "qwen-7b-chat": "Qwen 7B"
  28. }
  29. class Theb(AbstractProvider):
  30. label = "TheB.AI"
  31. url = "https://beta.theb.ai"
  32. working = False
  33. supports_stream = True
  34. models = models.keys()
  35. @classmethod
  36. def create_completion(
  37. cls,
  38. model: str,
  39. messages: Messages,
  40. stream: bool,
  41. proxy: str = None,
  42. webdriver: WebDriver = None,
  43. virtual_display: bool = True,
  44. **kwargs
  45. ) -> CreateResult:
  46. if model in models:
  47. model = models[model]
  48. prompt = format_prompt(messages)
  49. web_session = WebDriverSession(webdriver, virtual_display=virtual_display, proxy=proxy)
  50. with web_session as driver:
  51. from selenium.webdriver.common.by import By
  52. from selenium.webdriver.support.ui import WebDriverWait
  53. from selenium.webdriver.support import expected_conditions as EC
  54. from selenium.webdriver.common.keys import Keys
  55. # Register fetch hook
  56. script = """
  57. window._fetch = window.fetch;
  58. window.fetch = async (url, options) => {
  59. // Call parent fetch method
  60. const response = await window._fetch(url, options);
  61. if (!url.startsWith("/api/conversation")) {
  62. return result;
  63. }
  64. // Copy response
  65. copy = response.clone();
  66. window._reader = response.body.pipeThrough(new TextDecoderStream()).getReader();
  67. return copy;
  68. }
  69. window._last_message = "";
  70. """
  71. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  72. "source": script
  73. })
  74. try:
  75. driver.get(f"{cls.url}/home")
  76. wait = WebDriverWait(driver, 5)
  77. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  78. except:
  79. driver = web_session.reopen()
  80. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  81. "source": script
  82. })
  83. driver.get(f"{cls.url}/home")
  84. wait = WebDriverWait(driver, 240)
  85. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  86. try:
  87. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  88. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  89. except:
  90. pass
  91. if model:
  92. # Load model panel
  93. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "#SelectModel svg")))
  94. time.sleep(0.1)
  95. driver.find_element(By.CSS_SELECTOR, "#SelectModel svg").click()
  96. try:
  97. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  98. driver.find_element(By.CSS_SELECTOR, ".driver-overlay").click()
  99. except:
  100. pass
  101. # Select model
  102. selector = f"div.flex-col div.items-center span[title='{model}']"
  103. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, selector)))
  104. span = driver.find_element(By.CSS_SELECTOR, selector)
  105. container = span.find_element(By.XPATH, "//div/../..")
  106. button = container.find_element(By.CSS_SELECTOR, "button.btn-blue.btn-small.border")
  107. button.click()
  108. # Submit prompt
  109. wait.until(EC.visibility_of_element_located((By.ID, "textareaAutosize")))
  110. element_send_text(driver.find_element(By.ID, "textareaAutosize"), prompt)
  111. # Read response with reader
  112. script = """
  113. if(window._reader) {
  114. chunk = await window._reader.read();
  115. if (chunk['done']) {
  116. return null;
  117. }
  118. message = '';
  119. chunk['value'].split('\\r\\n').forEach((line, index) => {
  120. if (line.startsWith('data: ')) {
  121. try {
  122. line = JSON.parse(line.substring('data: '.length));
  123. message = line["args"]["content"];
  124. } catch(e) { }
  125. }
  126. });
  127. if (message) {
  128. try {
  129. return message.substring(window._last_message.length);
  130. } finally {
  131. window._last_message = message;
  132. }
  133. }
  134. }
  135. return '';
  136. """
  137. while True:
  138. chunk = driver.execute_script(script)
  139. if chunk:
  140. yield chunk
  141. elif chunk != "":
  142. break
  143. else:
  144. time.sleep(0.1)