# GeminiPro.py (7.3 KB)
  1. from __future__ import annotations
  2. import base64
  3. import json
  4. import requests
  5. from typing import Optional
  6. from aiohttp import ClientSession, BaseConnector
  7. from ...typing import AsyncResult, Messages, ImagesType
  8. from ...image import to_bytes, is_accepted_format
  9. from ...errors import MissingAuthError
  10. from ...requests.raise_for_status import raise_for_status
  11. from ...providers.response import Usage, FinishReason
  12. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  13. from ..helper import get_connector
  14. from ... import debug
  15. class GeminiPro(AsyncGeneratorProvider, ProviderModelMixin):
  16. label = "Google Gemini API"
  17. url = "https://ai.google.dev"
  18. login_url = "https://aistudio.google.com/u/0/apikey"
  19. api_base = "https://generativelanguage.googleapis.com/v1beta"
  20. working = True
  21. supports_message_history = True
  22. supports_system_message = True
  23. needs_auth = True
  24. default_model = "gemini-1.5-pro"
  25. default_vision_model = default_model
  26. fallback_models = [default_model, "gemini-2.0-flash-exp", "gemini-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b"]
  27. model_aliases = {
  28. "gemini-1.5-flash": "gemini-1.5-flash",
  29. "gemini-1.5-flash": "gemini-1.5-flash-8b",
  30. "gemini-1.5-pro": "gemini-pro",
  31. "gemini-2.0-flash": "gemini-2.0-flash-exp",
  32. }
  33. @classmethod
  34. def get_models(cls, api_key: str = None, api_base: str = api_base) -> list[str]:
  35. if not cls.models:
  36. try:
  37. url = f"{cls.api_base if not api_base else api_base}/models"
  38. response = requests.get(url, params={"key": api_key})
  39. raise_for_status(response)
  40. data = response.json()
  41. cls.models = [
  42. model.get("name").split("/").pop()
  43. for model in data.get("models")
  44. if "generateContent" in model.get("supportedGenerationMethods")
  45. ]
  46. cls.models.sort()
  47. except Exception as e:
  48. debug.error(e)
  49. if api_key is not None:
  50. raise MissingAuthError("Invalid API key")
  51. return cls.fallback_models
  52. return cls.models
  53. @classmethod
  54. async def create_async_generator(
  55. cls,
  56. model: str,
  57. messages: Messages,
  58. stream: bool = False,
  59. proxy: str = None,
  60. api_key: str = None,
  61. api_base: str = api_base,
  62. use_auth_header: bool = False,
  63. images: ImagesType = None,
  64. tools: Optional[list] = None,
  65. connector: BaseConnector = None,
  66. **kwargs
  67. ) -> AsyncResult:
  68. if not api_key:
  69. raise MissingAuthError('Add a "api_key"')
  70. model = cls.get_model(model, api_key=api_key, api_base=api_base)
  71. headers = params = None
  72. if use_auth_header:
  73. headers = {"Authorization": f"Bearer {api_key}"}
  74. else:
  75. params = {"key": api_key}
  76. method = "streamGenerateContent" if stream else "generateContent"
  77. url = f"{api_base.rstrip('/')}/models/{model}:{method}"
  78. async with ClientSession(headers=headers, connector=get_connector(connector, proxy)) as session:
  79. contents = [
  80. {
  81. "role": "model" if message["role"] == "assistant" else "user",
  82. "parts": [{"text": message["content"]}]
  83. }
  84. for message in messages
  85. if message["role"] != "system"
  86. ]
  87. if images is not None:
  88. for image, _ in images:
  89. image = to_bytes(image)
  90. contents[-1]["parts"].append({
  91. "inline_data": {
  92. "mime_type": is_accepted_format(image),
  93. "data": base64.b64encode(image).decode()
  94. }
  95. })
  96. data = {
  97. "contents": contents,
  98. "generationConfig": {
  99. "stopSequences": kwargs.get("stop"),
  100. "temperature": kwargs.get("temperature"),
  101. "maxOutputTokens": kwargs.get("max_tokens"),
  102. "topP": kwargs.get("top_p"),
  103. "topK": kwargs.get("top_k"),
  104. },
  105. "tools": [{
  106. "function_declarations": [{
  107. "name": tool["function"]["name"],
  108. "description": tool["function"]["description"],
  109. "parameters": {
  110. "type": "object",
  111. "properties": {key: {
  112. "type": value["type"],
  113. "description": value["title"]
  114. } for key, value in tool["function"]["parameters"]["properties"].items()}
  115. },
  116. } for tool in tools]
  117. }] if tools else None
  118. }
  119. system_prompt = "\n".join(
  120. message["content"]
  121. for message in messages
  122. if message["role"] == "system"
  123. )
  124. if system_prompt:
  125. data["system_instruction"] = {"parts": {"text": system_prompt}}
  126. async with session.post(url, params=params, json=data) as response:
  127. if not response.ok:
  128. data = await response.json()
  129. data = data[0] if isinstance(data, list) else data
  130. raise RuntimeError(f"Response {response.status}: {data['error']['message']}")
  131. if stream:
  132. lines = []
  133. async for chunk in response.content:
  134. if chunk == b"[{\n":
  135. lines = [b"{\n"]
  136. elif chunk == b",\r\n" or chunk == b"]":
  137. try:
  138. data = b"".join(lines)
  139. data = json.loads(data)
  140. yield data["candidates"][0]["content"]["parts"][0]["text"]
  141. if "finishReason" in data["candidates"][0]:
  142. yield FinishReason(data["candidates"][0]["finishReason"].lower())
  143. usage = data.get("usageMetadata")
  144. if usage:
  145. yield Usage(
  146. prompt_tokens=usage.get("promptTokenCount"),
  147. completion_tokens=usage.get("candidatesTokenCount"),
  148. total_tokens=usage.get("totalTokenCount")
  149. )
  150. except:
  151. data = data.decode(errors="ignore") if isinstance(data, bytes) else data
  152. raise RuntimeError(f"Read chunk failed: {data}")
  153. lines = []
  154. else:
  155. lines.append(chunk)
  156. else:
  157. data = await response.json()
  158. candidate = data["candidates"][0]
  159. if candidate["finishReason"] == "STOP":
  160. yield candidate["content"]["parts"][0]["text"]
  161. else:
  162. yield candidate["finishReason"] + ' ' + candidate["safetyRatings"]