PerplexityAi.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from __future__ import annotations
  2. import time
  3. try:
  4. from selenium.webdriver.common.by import By
  5. from selenium.webdriver.support.ui import WebDriverWait
  6. from selenium.webdriver.support import expected_conditions as EC
  7. except ImportError:
  8. pass
  9. from ...typing import CreateResult, Messages
  10. from ..base_provider import AbstractProvider
  11. from ..helper import format_prompt
  12. from ...webdriver import WebDriver, WebDriverSession, element_send_text
  class PerplexityAi(AbstractProvider):
      """Provider that drives the perplexity.ai web UI through Selenium.

      The answer text is captured in the page itself: a script patches
      ``WebSocket.prototype.send`` so that incoming frames update
      ``window._message``. Python then polls that variable and yields only the
      newly appended text.
      """
      url = "https://www.perplexity.ai"
      working = False
      supports_gpt_35_turbo = True
      supports_stream = True

      # CSS selector of the prompt box; used both to detect page load and to submit.
      _PROMPT_SELECTOR = "textarea[placeholder='Ask anything...']"

      # Injected after page load. On the first WebSocket.send() it attaches a
      # "message" listener that parses "42"-prefixed frames, copies the answer of
      # "query_progress"/"query_answered" events into window._message and sets
      # window._message_finished once "query_answered" is seen.
      _HOOK_SCRIPT = """
  window._message = window._last_message = "";
  window._message_finished = false;
  const _socket_send = WebSocket.prototype.send;
  WebSocket.prototype.send = function(...args) {
      // Must test the same property that is assigned below; otherwise a new
      // listener is attached on every send() call.
      if (!window._socket_onmessage) {
          window._socket_onmessage = this;
          this.addEventListener("message", (event) => {
              if (event.data.startsWith("42")) {
                  let data = JSON.parse(event.data.substring(2));
                  if (data[0] =="query_progress" || data[0] == "query_answered") {
                      let content = JSON.parse(data[1]["text"]);
                      if (data[1]["mode"] == "copilot") {
                          content = content[content.length-1]["content"]["answer"];
                          content = JSON.parse(content);
                      }
                      window._message = content["answer"];
                      if (!window._message_finished) {
                          window._message_finished = data[0] == "query_answered";
                      }
                  }
              }
          });
      }
      return _socket_send.call(this, ...args);
  };
  """

      # Polled repeatedly. Returns the not-yet-seen suffix of window._message,
      # null once the answer is finished and fully consumed, or '' while waiting.
      _POLL_SCRIPT = """
  if (window._message && window._message != window._last_message) {
      try {
          return window._message.substring(window._last_message.length);
      } finally {
          window._last_message = window._message;
      }
  } else if (window._message_finished) {
      return null;
  } else {
      return '';
  }
  """

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool,
          proxy: str = None,
          timeout: int = 120,
          webdriver: WebDriver = None,
          virtual_display: bool = True,
          copilot: bool = False,
          **kwargs
      ) -> CreateResult:
          """Submit ``messages`` to perplexity.ai and yield the answer incrementally.

          Args:
              model: Not used by this provider (kept for the common interface).
              messages: Conversation, flattened into one prompt via ``format_prompt``.
              stream: Not used; output is always yielded in chunks.
              proxy: Passed through to ``WebDriverSession``.
              timeout: Seconds to wait for the page to load, and the maximum
                time to wait without receiving new answer text while streaming.
              webdriver: Optional existing driver passed to ``WebDriverSession``.
              virtual_display: Passed through to ``WebDriverSession``.
              copilot: Enable copilot mode; requires a logged-in account.
              **kwargs: Ignored.

          Yields:
              Successive pieces of the answer text.

          Raises:
              RuntimeError: ``copilot`` is set but the user avatar or the copilot
                toggle cannot be found or clicked.
              TimeoutError: No new answer text arrived for ``timeout`` seconds
                before the answer was marked finished.
          """
          with WebDriverSession(webdriver, "", virtual_display=virtual_display, proxy=proxy) as driver:
              prompt = format_prompt(messages)
              driver.get(f"{cls.url}/")
              wait = WebDriverWait(driver, timeout)
              # Page is ready once the prompt box is visible.
              wait.until(EC.visibility_of_element_located((By.CSS_SELECTOR, cls._PROMPT_SELECTOR)))
              # Install the WebSocket hook before anything is sent.
              driver.execute_script(cls._HOOK_SCRIPT)
              if copilot:
                  from selenium.common.exceptions import WebDriverException
                  try:
                      # Avatar is only present when logged in.
                      driver.find_element(By.CSS_SELECTOR, "img[alt='User avatar']")
                      driver.find_element(By.CSS_SELECTOR, "button[data-testid='copilot-toggle']").click()
                  except WebDriverException as e:
                      raise RuntimeError("You need a account for copilot") from e
              element_send_text(driver.find_element(By.CSS_SELECTOR, cls._PROMPT_SELECTOR), prompt)
              # Poll the page for new text until the hook reports completion.
              last_progress = time.monotonic()
              while True:
                  chunk = driver.execute_script(cls._POLL_SCRIPT)
                  if chunk is None:
                      # JS returned null: answer finished and fully delivered.
                      break
                  if chunk:
                      yield chunk
                      # Reset after the consumer resumes us, so its processing
                      # time does not count against the idle limit.
                      last_progress = time.monotonic()
                      continue
                  if time.monotonic() - last_progress > timeout:
                      raise TimeoutError(f"No response from {cls.url} within {timeout} seconds")
                  time.sleep(0.1)