GeminiPro.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. from __future__ import annotations
  2. import base64
  3. import json
  4. from aiohttp import ClientSession, BaseConnector
  5. from ...typing import AsyncResult, Messages, ImagesType
  6. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  7. from ...image import to_bytes, is_accepted_format
  8. from ...errors import MissingAuthError
  9. from ..helper import get_connector
  10. class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
  11. label = "Google Gemini API"
  12. url = "https://ai.google.dev"
  13. working = True
  14. supports_message_history = True
  15. needs_auth = True
  16. default_model = "gemini-1.5-pro"
  17. default_vision_model = default_model
  18. models = [default_model, "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
  19. model_aliases = {
  20. "gemini-flash": "gemini-1.5-flash",
  21. "gemini-flash": "gemini-1.5-flash-8b",
  22. }
  23. @classmethod
  24. async def create_async_generator(
  25. cls,
  26. model: str,
  27. messages: Messages,
  28. stream: bool = False,
  29. proxy: str = None,
  30. api_key: str = None,
  31. api_base: str = "https://generativelanguage.googleapis.com/v1beta",
  32. use_auth_header: bool = False,
  33. images: ImagesType = None,
  34. connector: BaseConnector = None,
  35. **kwargs
  36. ) -> AsyncResult:
  37. model = cls.get_model(model)
  38. if not api_key:
  39. raise MissingAuthError('Add a "api_key"')
  40. headers = params = None
  41. if use_auth_header:
  42. headers = {"Authorization": f"Bearer {api_key}"}
  43. else:
  44. params = {"key": api_key}
  45. method = "streamGenerateContent" if stream else "generateContent"
  46. url = f"{api_base.rstrip('/')}/models/{model}:{method}"
  47. async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
  48. contents = [
  49. {
  50. "role": "model" if message["role"] == "assistant" else "user",
  51. "parts": [{"text": message["content"]}]
  52. }
  53. for message in messages
  54. if message["role"] != "system"
  55. ]
  56. if images is not None:
  57. for image, _ in images:
  58. image = to_bytes(image)
  59. contents[-1]["parts"].append({
  60. "inline_data": {
  61. "mime_type": is_accepted_format(image),
  62. "data": base64.b64encode(image).decode()
  63. }
  64. })
  65. data = {
  66. "contents": contents,
  67. "generationConfig": {
  68. "stopSequences": kwargs.get("stop"),
  69. "temperature": kwargs.get("temperature"),
  70. "maxOutputTokens": kwargs.get("max_tokens"),
  71. "topP": kwargs.get("top_p"),
  72. "topK": kwargs.get("top_k"),
  73. }
  74. }
  75. system_prompt = "\n".join(
  76. message["content"]
  77. for message in messages
  78. if message["role"] == "system"
  79. )
  80. if system_prompt:
  81. data["system_instruction"] = {"parts": {"text": system_prompt}}
  82. async with session.post(url, params=params, json=data) as response:
  83. if not response.ok:
  84. data = await response.json()
  85. data = data[0] if isinstance(data, list) else data
  86. raise RuntimeError(f"Response {response.status}: {data['error']['message']}")
  87. if stream:
  88. lines = []
  89. async for chunk in response.content:
  90. if chunk == b"[{\n":
  91. lines = [b"{\n"]
  92. elif chunk == b",\r\n" or chunk == b"]":
  93. try:
  94. data = b"".join(lines)
  95. data = json.loads(data)
  96. yield data["candidates"][0]["content"]["parts"][0]["text"]
  97. except:
  98. data = data.decode(errors="ignore") if isinstance(data, bytes) else data
  99. raise RuntimeError(f"Read chunk failed: {data}")
  100. lines = []
  101. else:
  102. lines.append(chunk)
  103. else:
  104. data = await response.json()
  105. candidate = data["candidates"][0]
  106. if candidate["finishReason"] == "STOP":
  107. yield candidate["content"]["parts"][0]["text"]
  108. else:
  109. yield candidate["finishReason"] + ' ' + candidate["safetyRatings"]