Bard.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import annotations
  2. import time
  3. import os
  4. from selenium.webdriver.common.by import By
  5. from selenium.webdriver.support.ui import WebDriverWait
  6. from selenium.webdriver.support import expected_conditions as EC
  7. from selenium.webdriver.common.keys import Keys
  8. from ...typing import CreateResult, Messages
  9. from ..base_provider import AbstractProvider
  10. from ..helper import format_prompt
  11. from ...webdriver import WebDriver, WebDriverSession
  class Bard(AbstractProvider):
      """Google Bard provider that drives the Bard web chat through Selenium.

      The prompt is typed into the chat editor; the answer is captured by a
      JavaScript hook on ``XMLHttpRequest`` that parses the ``StreamGenerate``
      response inside the page and stores the answer text on ``window``.
      """

      url = "https://bard.google.com"
      working = True
      needs_auth = True

      # CSS selector of the chat prompt editor. Its visibility is used as the
      # signal that the chat page is usable; if it never appears the code
      # assumes the user is not logged in.
      _textarea_selector = "div.ql-editor.textarea"

      # Injected before the prompt is sent. It wraps XMLHttpRequest.open so
      # that, for StreamGenerate requests, the answer extracted from the
      # response body is stored in window._message. A parse failure is stored
      # in window._error instead of being lost inside the event listener
      # (which would otherwise leave the Python side polling forever).
      # `apply(this, arguments)` forwards every argument of open() (async flag,
      # user, password), not only method and url.
      _xhr_hook_script = """
  const _http_request_open = XMLHttpRequest.prototype.open;
  window._message = "";
  window._error = null;
  XMLHttpRequest.prototype.open = function(method, url) {
      if (String(url).includes("/assistant.lamda.BardFrontendService/StreamGenerate")) {
          this.addEventListener("load", (event) => {
              try {
                  window._message = JSON.parse(JSON.parse(this.responseText.split("\\n")[3])[0][2])[4][0][1][0];
              } catch (e) {
                  window._error = String(e);
              }
          });
      }
      return _http_request_open.apply(this, arguments);
  }
  """

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool,
          proxy: str | None = None,
          webdriver: WebDriver | None = None,
          user_data_dir: str | None = None,
          headless: bool = True,
          timeout: float | None = 120,
          **kwargs,
      ) -> CreateResult:
          """Send *messages* to Bard and yield the answer.

          Args:
              model: Unused; accepted for interface compatibility.
              messages: Conversation, flattened into one prompt by ``format_prompt``.
              stream: Unused; the answer is always yielded as a single chunk.
              proxy, webdriver, user_data_dir, headless: Forwarded to
                  ``WebDriverSession``. ``headless`` also selects how long to wait
                  for the chat editor on first load (10 s headless, 240 s
                  otherwise). When a ``webdriver`` is supplied, no login retry
                  is attempted.
              timeout: Seconds to wait for Bard's answer after the prompt is
                  sent; ``None`` waits indefinitely.
              **kwargs: Ignored.

          Yields:
              A login hint first, only when the editor was not found and the
              ``G4F_LOGIN_URL`` environment variable is set; then the answer text.

          Raises:
              RuntimeError: if the editor is not found with a caller-supplied
                  ``webdriver``, if the in-page hook fails to parse the response,
                  or if no answer arrives within *timeout* seconds.
          """
          from selenium.common.exceptions import WebDriverException

          prompt = format_prompt(messages)
          editor = (By.CSS_SELECTOR, cls._textarea_selector)
          session = WebDriverSession(webdriver, user_data_dir, headless, proxy=proxy)
          with session as driver:
              try:
                  driver.get(f"{cls.url}/chat")
                  wait = WebDriverWait(driver, 10 if headless else 240)
                  wait.until(EC.visibility_of_element_located(editor))
              except WebDriverException as error:
                  # The editor did not become visible (or the page failed to
                  # load); the code treats this as "not logged in".
                  if webdriver:
                      raise RuntimeError(
                          "Prompt textarea not found. You may not be logged in."
                      ) from error
                  # Reopen browser for login
                  driver = session.reopen()
                  driver.get(f"{cls.url}/chat")
                  login_url = os.environ.get("G4F_LOGIN_URL")
                  if login_url:
                      yield f"Please login: [Google Bard]({login_url})\n\n"
                  WebDriverWait(driver, 240).until(
                      EC.visibility_of_element_located(editor)
                  )

              driver.execute_script(cls._xhr_hook_script)
              textarea = driver.find_element(*editor)
              cls._type_prompt(textarea, prompt)
              yield cls._poll_message(driver, timeout)

      @staticmethod
      def _type_prompt(textarea, prompt: str) -> None:
          """Type *prompt* into *textarea* line by line, then press Enter.

          Lines are separated with Shift+newline; a single Enter is sent only
          after the last line.
          """
          lines = prompt.splitlines()
          last = len(lines) - 1
          for idx, line in enumerate(lines):
              textarea.send_keys(line)
              if idx < last:
                  textarea.send_keys(Keys.SHIFT + "\n")
          textarea.send_keys(Keys.ENTER)

      @staticmethod
      def _poll_message(driver, timeout: float | None) -> str:
          """Poll the page until the XHR hook has stored an answer; return it.

          Checks every 0.1 s.

          Raises:
              RuntimeError: if the hook recorded a parse error, or if no answer
                  appeared within *timeout* seconds (``None`` = no limit).
          """
          deadline = None if timeout is None else time.monotonic() + timeout
          while True:
              message, error = driver.execute_script(
                  "return [window._message, window._error];"
              )
              if error:
                  raise RuntimeError(f"Could not parse Bard response: {error}")
              if message:
                  return message
              if deadline is not None and time.monotonic() >= deadline:
                  raise RuntimeError(
                      f"No response from Bard within {timeout} seconds."
                  )
              time.sleep(0.1)