Poe.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from __future__ import annotations
  2. import time
  3. from ...typing import CreateResult, Messages
  4. from ..base_provider import AbstractProvider
  5. from ..helper import format_prompt
  6. from ...webdriver import WebDriver, WebDriverSession, element_send_text
  7. models = {
  8. "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
  9. "meta-llama/Llama-2-13b-chat-hf": {"name": "Llama-2-13b"},
  10. "meta-llama/Llama-2-70b-chat-hf": {"name": "Llama-2-70b"},
  11. "codellama/CodeLlama-7b-Instruct-hf": {"name": "Code-Llama-7b"},
  12. "codellama/CodeLlama-13b-Instruct-hf": {"name": "Code-Llama-13b"},
  13. "codellama/CodeLlama-34b-Instruct-hf": {"name": "Code-Llama-34b"},
  14. "gpt-3.5-turbo": {"name": "GPT-3.5-Turbo"},
  15. "gpt-3.5-turbo-instruct": {"name": "GPT-3.5-Turbo-Instruct"},
  16. "gpt-4": {"name": "GPT-4"},
  17. "palm": {"name": "Google-PaLM"},
  18. }
  19. class Poe(AbstractProvider):
  20. url = "https://poe.com"
  21. working = True
  22. needs_auth = True
  23. supports_stream = True
  24. models = models.keys()
  25. @classmethod
  26. def create_completion(
  27. cls,
  28. model: str,
  29. messages: Messages,
  30. stream: bool,
  31. proxy: str = None,
  32. webdriver: WebDriver = None,
  33. user_data_dir: str = None,
  34. headless: bool = True,
  35. **kwargs
  36. ) -> CreateResult:
  37. if not model:
  38. model = "gpt-3.5-turbo"
  39. elif model not in models:
  40. raise ValueError(f"Model are not supported: {model}")
  41. prompt = format_prompt(messages)
  42. session = WebDriverSession(webdriver, user_data_dir, headless, proxy=proxy)
  43. with session as driver:
  44. from selenium.webdriver.common.by import By
  45. from selenium.webdriver.support.ui import WebDriverWait
  46. from selenium.webdriver.support import expected_conditions as EC
  47. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  48. "source": """
  49. window._message = window._last_message = "";
  50. window._message_finished = false;
  51. class ProxiedWebSocket extends WebSocket {
  52. constructor(url, options) {
  53. super(url, options);
  54. this.addEventListener("message", (e) => {
  55. const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"];
  56. if ("messageAdded" in data) {
  57. if (data["messageAdded"]["author"] != "human") {
  58. window._message = data["messageAdded"]["text"];
  59. if (data["messageAdded"]["state"] == "complete") {
  60. window._message_finished = true;
  61. }
  62. }
  63. }
  64. });
  65. }
  66. }
  67. window.WebSocket = ProxiedWebSocket;
  68. """
  69. })
  70. try:
  71. driver.get(f"{cls.url}/{models[model]['name']}")
  72. wait = WebDriverWait(driver, 10 if headless else 240)
  73. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
  74. except:
  75. # Reopen browser for login
  76. if not webdriver:
  77. driver = session.reopen()
  78. driver.get(f"{cls.url}/{models[model]['name']}")
  79. wait = WebDriverWait(driver, 240)
  80. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
  81. else:
  82. raise RuntimeError("Prompt textarea not found. You may not be logged in.")
  83. element_send_text(driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']"), prompt)
  84. driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
  85. script = """
  86. if(window._message && window._message != window._last_message) {
  87. try {
  88. return window._message.substring(window._last_message.length);
  89. } finally {
  90. window._last_message = window._message;
  91. }
  92. } else if(window._message_finished) {
  93. return null;
  94. } else {
  95. return '';
  96. }
  97. """
  98. while True:
  99. chunk = driver.execute_script(script)
  100. if chunk:
  101. yield chunk
  102. elif chunk != "":
  103. break
  104. else:
  105. time.sleep(0.1)