Pi.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from __future__ import annotations
  2. import json
  3. from ..typing import CreateResult, Messages
  4. from .base_provider import AbstractProvider, format_prompt
  5. from ..requests import Session, get_session_from_browser, raise_for_status
  6. class Pi(AbstractProvider):
  7. url = "https://pi.ai/talk"
  8. working = True
  9. supports_stream = True
  10. _session = None
  11. default_model = "pi"
  12. @classmethod
  13. def create_completion(
  14. cls,
  15. model: str,
  16. messages: Messages,
  17. stream: bool,
  18. proxy: str = None,
  19. timeout: int = 180,
  20. conversation_id: str = None,
  21. webdriver: WebDriver = None,
  22. **kwargs
  23. ) -> CreateResult:
  24. if cls._session is None:
  25. cls._session = get_session_from_browser(url=cls.url, proxy=proxy, timeout=timeout)
  26. if not conversation_id:
  27. conversation_id = cls.start_conversation(cls._session)
  28. prompt = format_prompt(messages)
  29. else:
  30. prompt = messages[-1]["content"]
  31. answer = cls.ask(cls._session, prompt, conversation_id)
  32. for line in answer:
  33. if "text" in line:
  34. yield line["text"]
  35. @classmethod
  36. def start_conversation(cls, session: Session) -> str:
  37. response = session.post('https://pi.ai/api/chat/start', data="{}", headers={
  38. 'accept': 'application/json',
  39. 'x-api-version': '3'
  40. })
  41. raise_for_status(response)
  42. return response.json()['conversations'][0]['sid']
  43. def get_chat_history(session: Session, conversation_id: str):
  44. params = {
  45. 'conversation': conversation_id,
  46. }
  47. response = session.get('https://pi.ai/api/chat/history', params=params)
  48. raise_for_status(response)
  49. return response.json()
  50. def ask(session: Session, prompt: str, conversation_id: str):
  51. json_data = {
  52. 'text': prompt,
  53. 'conversation': conversation_id,
  54. 'mode': 'BASE',
  55. }
  56. response = session.post('https://pi.ai/api/chat', json=json_data, stream=True)
  57. raise_for_status(response)
  58. for line in response.iter_lines():
  59. if line.startswith(b'data: {"text":'):
  60. yield json.loads(line.split(b'data: ')[1])
  61. elif line.startswith(b'data: {"title":'):
  62. yield json.loads(line.split(b'data: ')[1])