Poe.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. from __future__ import annotations
  2. import time
  3. from ...typing import CreateResult, Messages
  4. from ..base_provider import AbstractProvider
  5. from ..helper import format_prompt
  6. models = {
  7. "meta-llama/Llama-2-7b-chat-hf": {"name": "Llama-2-7b"},
  8. "meta-llama/Llama-2-13b-chat-hf": {"name": "Llama-2-13b"},
  9. "meta-llama/Llama-2-70b-chat-hf": {"name": "Llama-2-70b"},
  10. "codellama/CodeLlama-7b-Instruct-hf": {"name": "Code-Llama-7b"},
  11. "codellama/CodeLlama-13b-Instruct-hf": {"name": "Code-Llama-13b"},
  12. "codellama/CodeLlama-34b-Instruct-hf": {"name": "Code-Llama-34b"},
  13. "gpt-3.5-turbo": {"name": "GPT-3.5-Turbo"},
  14. "gpt-3.5-turbo-instruct": {"name": "GPT-3.5-Turbo-Instruct"},
  15. "gpt-4": {"name": "GPT-4"},
  16. "palm": {"name": "Google-PaLM"},
  17. }
  18. class Poe(AbstractProvider):
  19. url = "https://poe.com"
  20. working = False
  21. needs_auth = True
  22. supports_stream = True
  23. models = models.keys()
  24. @classmethod
  25. def create_completion(
  26. cls,
  27. model: str,
  28. messages: Messages,
  29. stream: bool,
  30. proxy: str = None,
  31. webdriver: WebDriver = None,
  32. user_data_dir: str = None,
  33. headless: bool = True,
  34. **kwargs
  35. ) -> CreateResult:
  36. if not model:
  37. model = "gpt-3.5-turbo"
  38. elif model not in models:
  39. raise ValueError(f"Model are not supported: {model}")
  40. prompt = format_prompt(messages)
  41. session = WebDriverSession(webdriver, user_data_dir, headless, proxy=proxy)
  42. with session as driver:
  43. from selenium.webdriver.common.by import By
  44. from selenium.webdriver.support.ui import WebDriverWait
  45. from selenium.webdriver.support import expected_conditions as EC
  46. driver.execute_cdp_cmd("Page.addScriptToEvaluateOnNewDocument", {
  47. "source": """
  48. window._message = window._last_message = "";
  49. window._message_finished = false;
  50. class ProxiedWebSocket extends WebSocket {
  51. constructor(url, options) {
  52. super(url, options);
  53. this.addEventListener("message", (e) => {
  54. const data = JSON.parse(JSON.parse(e.data)["messages"][0])["payload"]["data"];
  55. if ("messageAdded" in data) {
  56. if (data["messageAdded"]["author"] != "human") {
  57. window._message = data["messageAdded"]["text"];
  58. if (data["messageAdded"]["state"] == "complete") {
  59. window._message_finished = true;
  60. }
  61. }
  62. }
  63. });
  64. }
  65. }
  66. window.WebSocket = ProxiedWebSocket;
  67. """
  68. })
  69. try:
  70. driver.get(f"{cls.url}/{models[model]['name']}")
  71. wait = WebDriverWait(driver, 10 if headless else 240)
  72. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
  73. except:
  74. # Reopen browser for login
  75. if not webdriver:
  76. driver = session.reopen()
  77. driver.get(f"{cls.url}/{models[model]['name']}")
  78. wait = WebDriverWait(driver, 240)
  79. wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, "textarea[class^='GrowingTextArea']")))
  80. else:
  81. raise RuntimeError("Prompt textarea not found. You may not be logged in.")
  82. element_send_text(driver.find_element(By.CSS_SELECTOR, "footer textarea[class^='GrowingTextArea']"), prompt)
  83. driver.find_element(By.CSS_SELECTOR, "footer button[class*='ChatMessageSendButton']").click()
  84. script = """
  85. if(window._message && window._message != window._last_message) {
  86. try {
  87. return window._message.substring(window._last_message.length);
  88. } finally {
  89. window._last_message = window._message;
  90. }
  91. } else if(window._message_finished) {
  92. return null;
  93. } else {
  94. return '';
  95. }
  96. """
  97. while True:
  98. chunk = driver.execute_script(script)
  99. if chunk:
  100. yield chunk
  101. elif chunk != "":
  102. break
  103. else:
  104. time.sleep(0.1)