GigaChat.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. from __future__ import annotations
  2. import os
  3. import ssl
  4. import time
  5. import uuid
  6. from pathlib import Path
  7. import json
  8. from aiohttp import ClientSession, TCPConnector, BaseConnector
  9. from ...requests import raise_for_status
  10. from ...typing import AsyncResult, Messages
  11. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  12. from ...errors import MissingAuthError
  13. from ..helper import get_connector
  14. from ...cookies import get_cookies_dir
  15. access_token = ""
  16. token_expires_at = 0
# PEM-encoded "Russian Trusted Root CA" certificate (issuer: The Ministry of
# Digital Development and Communications). GigaChat.create_async_generator
# writes it to `russian_trusted_root_ca.crt` in the cookies directory and uses
# that file as the `cafile` of the SSL context for the Sber endpoints.
# NOTE(review): presumably needed because this root is not in common system
# trust stores — confirm. Must stay byte-identical; do not reformat.
RUSSIAN_CA_CERT = """-----BEGIN CERTIFICATE-----
MIIFwjCCA6qgAwIBAgICEAAwDQYJKoZIhvcNAQELBQAwcDELMAkGA1UEBhMCUlUx
PzA9BgNVBAoMNlRoZSBNaW5pc3RyeSBvZiBEaWdpdGFsIERldmVsb3BtZW50IGFu
ZCBDb21tdW5pY2F0aW9uczEgMB4GA1UEAwwXUnVzc2lhbiBUcnVzdGVkIFJvb3Qg
Q0EwHhcNMjIwMzAxMjEwNDE1WhcNMzIwMjI3MjEwNDE1WjBwMQswCQYDVQQGEwJS
VTE/MD0GA1UECgw2VGhlIE1pbmlzdHJ5IG9mIERpZ2l0YWwgRGV2ZWxvcG1lbnQg
YW5kIENvbW11bmljYXRpb25zMSAwHgYDVQQDDBdSdXNzaWFuIFRydXN0ZWQgUm9v
dCBDQTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAMfFOZ8pUAL3+r2n
qqE0Zp52selXsKGFYoG0GM5bwz1bSFtCt+AZQMhkWQheI3poZAToYJu69pHLKS6Q
XBiwBC1cvzYmUYKMYZC7jE5YhEU2bSL0mX7NaMxMDmH2/NwuOVRj8OImVa5s1F4U
zn4Kv3PFlDBjjSjXKVY9kmjUBsXQrIHeaqmUIsPIlNWUnimXS0I0abExqkbdrXbX
YwCOXhOO2pDUx3ckmJlCMUGacUTnylyQW2VsJIyIGA8V0xzdaeUXg0VZ6ZmNUr5Y
Ber/EAOLPb8NYpsAhJe2mXjMB/J9HNsoFMBFJ0lLOT/+dQvjbdRZoOT8eqJpWnVD
U+QL/qEZnz57N88OWM3rabJkRNdU/Z7x5SFIM9FrqtN8xewsiBWBI0K6XFuOBOTD
4V08o4TzJ8+Ccq5XlCUW2L48pZNCYuBDfBh7FxkB7qDgGDiaftEkZZfApRg2E+M9
G8wkNKTPLDc4wH0FDTijhgxR3Y4PiS1HL2Zhw7bD3CbslmEGgfnnZojNkJtcLeBH
BLa52/dSwNU4WWLubaYSiAmA9IUMX1/RpfpxOxd4Ykmhz97oFbUaDJFipIggx5sX
ePAlkTdWnv+RWBxlJwMQ25oEHmRguNYf4Zr/Rxr9cS93Y+mdXIZaBEE0KS2iLRqa
OiWBki9IMQU4phqPOBAaG7A+eP8PAgMBAAGjZjBkMB0GA1UdDgQWBBTh0YHlzlpf
BKrS6badZrHF+qwshzAfBgNVHSMEGDAWgBTh0YHlzlpfBKrS6badZrHF+qwshzAS
BgNVHRMBAf8ECDAGAQH/AgEEMA4GA1UdDwEB/wQEAwIBhjANBgkqhkiG9w0BAQsF
AAOCAgEAALIY1wkilt/urfEVM5vKzr6utOeDWCUczmWX/RX4ljpRdgF+5fAIS4vH
tmXkqpSCOVeWUrJV9QvZn6L227ZwuE15cWi8DCDal3Ue90WgAJJZMfTshN4OI8cq
W9E4EG9wglbEtMnObHlms8F3CHmrw3k6KmUkWGoa+/ENmcVl68u/cMRl1JbW2bM+
/3A+SAg2c6iPDlehczKx2oa95QW0SkPPWGuNA/CE8CpyANIhu9XFrj3RQ3EqeRcS
AQQod1RNuHpfETLU/A2gMmvn/w/sx7TB3W5BPs6rprOA37tutPq9u6FTZOcG1Oqj
C/B7yTqgI7rbyvox7DEXoX7rIiEqyNNUguTk/u3SZ4VXE2kmxdmSh3TQvybfbnXV
4JbCZVaqiZraqc7oZMnRoWrXRG3ztbnbes/9qhRGI7PqXqeKJBztxRTEVj8ONs1d
WN5szTwaPIvhkhO3CO5ErU2rVdUr89wKpNXbBODFKRtgxUT70YpmJ46VVaqdAhOZ
D9EUUn4YaeLaS8AjSF/h7UkjOibNc4qVDiPP+rkehFWM66PVnP1Msh93tc+taIfC
EYVMxjh8zNbFuoc7fzvvrFILLe7ifvEIUqSVIC/AzplM/Jxw7buXFeGP1qVCBEHq
391d/9RAfaZ12zkwFsl+IKwE/OZxW8AHa9i1p4GO0YSNuczzEm4=
-----END CERTIFICATE-----"""
  50. class GigaChat(AsyncGeneratorProvider, ProviderModelMixin):
  51. url = "https://developers.sber.ru/gigachat"
  52. working = True
  53. supports_message_history = True
  54. supports_system_message = True
  55. supports_stream = True
  56. needs_auth = True
  57. default_model = "GigaChat:latest"
  58. models = [default_model, "GigaChat-Plus", "GigaChat-Pro"]
  59. @classmethod
  60. async def create_async_generator(
  61. cls,
  62. model: str,
  63. messages: Messages,
  64. stream: bool = True,
  65. proxy: str = None,
  66. api_key: str = None,
  67. connector: BaseConnector = None,
  68. scope: str = "GIGACHAT_API_PERS",
  69. update_interval: float = 0,
  70. **kwargs
  71. ) -> AsyncResult:
  72. global access_token, token_expires_at
  73. model = cls.get_model(model)
  74. if not api_key:
  75. raise MissingAuthError('Missing "api_key"')
  76. # Create certificate file in cookies directory
  77. cookies_dir = Path(get_cookies_dir())
  78. cert_file = cookies_dir / 'russian_trusted_root_ca.crt'
  79. # Write certificate if it doesn't exist
  80. if not cert_file.exists():
  81. cert_file.write_text(RUSSIAN_CA_CERT)
  82. ssl_context = ssl.create_default_context(cafile=str(cert_file))
  83. if connector is None:
  84. connector = TCPConnector(ssl_context=ssl_context)
  85. async with ClientSession(connector=get_connector(connector, proxy)) as session:
  86. if token_expires_at - int(time.time() * 1000) < 60000:
  87. async with session.post(url="https://ngw.devices.sberbank.ru:9443/api/v2/oauth",
  88. headers={"Authorization": f"Bearer {api_key}",
  89. "RqUID": str(uuid.uuid4()),
  90. "Content-Type": "application/x-www-form-urlencoded"},
  91. data={"scope": scope}) as response:
  92. await raise_for_status(response)
  93. data = await response.json()
  94. access_token = data['access_token']
  95. token_expires_at = data['expires_at']
  96. async with session.post(url="https://gigachat.devices.sberbank.ru/api/v1/chat/completions",
  97. headers={"Authorization": f"Bearer {access_token}"},
  98. json={
  99. "model": model,
  100. "messages": messages,
  101. "stream": stream,
  102. "update_interval": update_interval,
  103. **kwargs
  104. }) as response:
  105. await raise_for_status(response)
  106. async for line in response.content:
  107. if not stream:
  108. yield json.loads(line.decode("utf-8"))['choices'][0]['message']['content']
  109. return
  110. if line and line.startswith(b"data:"):
  111. line = line[6:-1] # remove "data: " prefix and "\n" suffix
  112. if line.strip() == b"[DONE]":
  113. return
  114. else:
  115. msg = json.loads(line.decode("utf-8"))['choices'][0]
  116. content = msg['delta']['content']
  117. if content:
  118. yield content
  119. if 'finish_reason' in msg:
  120. return