GeminiPro.py

from __future__ import annotations

import base64
import json
from typing import Optional

import requests
from aiohttp import ClientSession, BaseConnector

from ...typing import AsyncResult, Messages, ImagesType
from ...image import to_bytes, is_accepted_format
from ...errors import MissingAuthError
from ...requests.raise_for_status import raise_for_status
from ...providers.response import Usage, FinishReason
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import get_connector
from ... import debug
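

# Provider for Google's Gemini models, served through the public
# Generative Language API (https://ai.google.dev).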
class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Google Gemini API"
    url = "https://ai.google.dev"
    login_url = "https://aistudio.google.com/u/0/apikey"
    api_base = "https://generativelanguage.googleapis.com/v1beta"

    working = True
    supports_message_history = True
    needs_auth = True

    default_model = "gemini-1.5-pro"
    default_vision_model = default_model
    fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
    model_aliases = {
        "gemini-1.5-flash": "gemini-1.5-flash-8b",
        "gemini-1.5-pro": "gemini-pro",
        "gemini-2.0-flash": "gemini-2.0-flash-exp",
    }
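
    # Ask the API which models support generateContent; fall back to the
    # static list if the request fails.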
    @classmethod
    def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
        if not cls.models:
            try:
                response = requests.get(f"{api_base}/models?key={api_key}")
                raise_for_status(response)
                data = response.json()
                cls.models = [
                    model.get("name").split("/").pop()
                    for model in data.get("models")
                    if "generateContent" in model.get("supportedGenerationMethods")
                ]
                cls.models.sort()
            except Exception as e:
                debug.log(e)
                cls.models = cls.fallback_models
        return cls.models
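
    # Core entry point: build a Gemini request from the chat history and
    # yield text chunks (plus FinishReason and Usage markers) as they arrive.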
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = False,
        proxy: str = None,
        api_key: str = None,
        api_base: str = api_base,
        use_auth_header: bool = False,
        images: ImagesType = None,
        tools: Optional[list] = None,
        connector: BaseConnector = None,
        **kwargs
    ) -> AsyncResult:
        if not api_key:
            raise MissingAuthError('Add an "api_key"')
        model = cls.get_model(model, api_key=api_key, api_base=api_base)

        headers = params = None
        if use_auth_header:
            headers = {"Authorization": f"Bearer {api_key}"}
        else:
            params = {"key": api_key}

        method = "streamGenerateContent" if stream else "generateContent"
        url = f"{api_base.rstrip('/')}/models/{model}:{method}"

        async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
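            # Gemini has no "assistant" role ("model" instead), and system
            # messages are sent separately via "system_instruction" below.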
            contents = [
                {
                    "role": "model" if message["role"] == "assistant" else "user",
                    "parts": [{"text": message["content"]}]
                }
                for message in messages
                if message["role"] != "system"
            ]
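            # Attach any images to the last message as inline base64 data.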
            if images is not None:
                for image, _ in images:
                    image = to_bytes(image)
                    contents[-1]["parts"].append({
                        "inline_data": {
                            "mime_type": is_accepted_format(image),
                            "data": base64.b64encode(image).decode()
                        }
                    })
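            # Request body; generation parameters are passed through from
            # kwargs and may be None.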
            data = {
                "contents": contents,
                "generationConfig": {
                    "stopSequences": kwargs.get("stop"),
                    "temperature": kwargs.get("temperature"),
                    "maxOutputTokens": kwargs.get("max_tokens"),
                    "topP": kwargs.get("top_p"),
                    "topK": kwargs.get("top_k"),
                },
                "tools": [{
                    "functionDeclarations": tools
                }] if tools else None
            }
            system_prompt = "\n".join(
                message["content"]
                for message in messages
                if message["role"] == "system"
            )
            if system_prompt:
                data["system_instruction"] = {"parts": {"text": system_prompt}}
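            # POST the request; surface API errors with Google's message attached.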
            async with session.post(url, params=params, json=data) as response:
                if not response.ok:
                    data = await response.json()
                    data = data[0] if isinstance(data, list) else data
                    raise RuntimeError(f"Response {response.status}: {data['error']['message']}")
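                # The streaming endpoint returns a JSON array: "[{...},\r\n{...}]".
                # Buffer raw chunks until a separator (",\r\n") or the closing
                # "]" ends one object, then parse and yield its text.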
                if stream:
                    lines = []
                    async for chunk in response.content:
                        if chunk == b"[{\n":
                            lines = [b"{\n"]
                        elif chunk == b",\r\n" or chunk == b"]":
                            try:
                                data = b"".join(lines)
                                data = json.loads(data)
                                yield data["candidates"][0]["content"]["parts"][0]["text"]
                                if "finishReason" in data["candidates"][0]:
                                    yield FinishReason(data["candidates"][0]["finishReason"].lower())
                                usage = data.get("usageMetadata")
                                if usage:
                                    yield Usage(
                                        prompt_tokens=usage.get("promptTokenCount"),
                                        completion_tokens=usage.get("candidatesTokenCount"),
                                        total_tokens=usage.get("totalTokenCount")
                                    )
                            except Exception:
                                data = data.decode(errors="ignore") if isinstance(data, bytes) else data
                                raise RuntimeError(f"Read chunk failed: {data}")
                            lines = []
                        else:
                            lines.append(chunk)
                else:
                    data = await response.json()
                    candidate = data["candidates"][0]
                    if candidate["finishReason"] == "STOP":
                        yield candidate["content"]["parts"][0]["text"]
                    else:
                        yield f"{candidate['finishReason']} {candidate['safetyRatings']}"
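

# A minimal usage sketch (an illustration, not part of the original file):
# the surrounding package normally drives this provider, and the relative
# imports above prevent running this module directly, so the sketch is kept
# as a comment; "YOUR_API_KEY" is a placeholder.
#
#     import asyncio
#
#     async def main() -> None:
#         async for chunk in GeminiPro.create_async_generator(
#             model="gemini-1.5-pro",
#             messages=[{"role": "user", "content": "Hello"}],
#             api_key="YOUR_API_KEY",
#             stream=True,
#         ):
#             if isinstance(chunk, str):
#                 # FinishReason and Usage objects are also yielded; print text only.
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(main())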