api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import asyncio
  5. from typing import Iterator
  6. from flask import send_from_directory
  7. from inspect import signature
  8. from ...errors import VersionNotFoundError
  9. from ...image.copy_images import copy_images, ensure_images_dir, images_dir
  10. from ...tools.run_tools import iter_run_tools
  11. from ...Provider import ProviderUtils, __providers__
  12. from ...providers.base_provider import ProviderModelMixin
  13. from ...providers.retry_provider import BaseRetryProvider
  14. from ...providers.helper import format_image_prompt
  15. from ...providers.response import *
  16. from ... import version, models
  17. from ... import ChatCompletion, get_model_and_provider
  18. from ... import debug
  # Logger for this module.
  logger = logging.getLogger(__name__)
  # Cached conversation state, keyed by provider name and then by conversation id;
  # filled by Api._create_response_stream and read by _prepare_conversation_kwargs.
  # NOTE(review): module-global and nothing in this file evicts entries.
  conversations: dict[str, dict[str, BaseConversation]] = {}
  class Api:
      """GUI backend helpers: model/provider listings and chat-completion streaming.

      ``_create_response_stream`` turns the chunks produced by
      ``iter_run_tools(ChatCompletion.create, ...)`` into plain dicts shaped
      like ``{"type": <kind>, <kind>: <payload>, ...}`` (see ``_format_json``).
      """

      @staticmethod
      def get_models():
          """Return every registered model with its capabilities.

          Each entry is ``{"name", "image", "vision", "providers"}``; ``providers``
          lists, for each provider whose ``working`` flag is set, its ``parent``
          attribute or, if absent, its class name.
          """
          return [{
              "name": model.name,
              "image": isinstance(model, models.ImageModel),
              "vision": isinstance(model, models.VisionModel),
              "providers": [
                  getattr(provider, "parent", provider.__name__)
                  for provider in providers
                  if provider.working
              ]
          }
          for model, providers in models.__models__.values()]

      @staticmethod
      def get_provider_models(provider: str, api_key: str = None, api_base: str = None):
          """List the models offered by a single provider.

          Args:
              provider: Provider name, resolved through ``ProviderUtils.convert``.
              api_key: Passed to ``provider.get_models`` when it accepts ``api_key``.
              api_base: Passed alongside ``api_key`` when ``get_models`` accepts
                  ``api_base`` (or arbitrary keyword arguments).

          Returns:
              A list of ``{"model", "default", "vision", "image"}`` dicts, or an
              empty list when the name is unknown or the provider is not a
              ``ProviderModelMixin`` subclass.
          """
          if provider in ProviderUtils.convert:
              provider = ProviderUtils.convert[provider]
              if issubclass(provider, ProviderModelMixin):
                  params = signature(provider.get_models).parameters
                  if "api_key" in params:
                      extra = {"api_key": api_key}
                      # Forward api_base only if get_models can receive it; a
                      # get_models(api_key=...) without api_base would otherwise
                      # fail with TypeError.
                      accepts_any_kwarg = any(p.kind is p.VAR_KEYWORD for p in params.values())
                      if "api_base" in params or accepts_any_kwarg:
                          extra["api_base"] = api_base
                      model_names = provider.get_models(**extra)
                  else:
                      model_names = provider.get_models()
                  # Named model_names so the module-level ``models`` import is not shadowed.
                  default_vision_model = getattr(provider, "default_vision_model", None)
                  vision_models = getattr(provider, "vision_models", [])
                  image_models = getattr(provider, "image_models", None)
                  return [
                      {
                          "model": model,
                          "default": model == provider.default_model,
                          "vision": default_vision_model == model or model in vision_models,
                          "image": False if image_models is None else model in image_models,
                      }
                      for model in model_names
                  ]
          return []

      @staticmethod
      def get_providers() -> list[dict]:
          """Describe every provider whose ``working`` flag is set.

          Returns:
              A list of dicts with ``name``, ``label``, ``parent``, ``image``,
              ``vision``, ``nodriver``, ``auth`` and ``login_url`` keys.
          """
          return [{
              "name": provider.__name__,
              "label": provider.label if hasattr(provider, "label") else provider.__name__,
              "parent": getattr(provider, "parent", None),
              "image": bool(getattr(provider, "image_models", False)),
              "vision": getattr(provider, "default_vision_model", None) is not None,
              "nodriver": getattr(provider, "use_nodriver", False),
              "auth": provider.needs_auth,
              "login_url": getattr(provider, "login_url", None),
          } for provider in __providers__ if provider.working]

      @staticmethod
      def get_version() -> dict:
          """Return ``{"version", "latest_version"}``.

          A value is ``None`` if it was not read before ``VersionNotFoundError``
          was raised.
          """
          current_version = None
          latest_version = None
          try:
              current_version = version.utils.current_version
              latest_version = version.utils.latest_version
          except VersionNotFoundError:
              pass
          return {
              "version": current_version,
              "latest_version": latest_version,
          }

      def serve_images(self, name):
          """Send the file *name* from the local images directory via Flask.

          Calls ``ensure_images_dir()`` first (presumably creating the directory
          when missing — confirm in ``image.copy_images``).
          """
          ensure_images_dir()
          return send_from_directory(os.path.abspath(images_dir), name)

      def _prepare_conversation_kwargs(self, json_data: dict, kwargs: dict):
          """Translate a GUI request body into ``ChatCompletion.create`` arguments.

          Request options (api_key, api_base, web_search, conversation, ignored,
          action) are written into *kwargs* in place. The returned dict starts
          with model/provider/messages and the fixed streaming flags, then
          spreads *kwargs* over them, so entries in *kwargs* win.
          """
          model = json_data.get('model')
          provider = json_data.get('provider')
          messages = json_data.get('messages')
          action = json_data.get('action')
          api_key = json_data.get("api_key")
          if api_key:
              kwargs["api_key"] = api_key
          api_base = json_data.get("api_base")
          if api_base:
              kwargs["api_base"] = api_base
          # NOTE(review): replaces any tool_calls the caller already put in kwargs.
          kwargs["tool_calls"] = [{
              "function": {
                  "name": "bucket_tool"
              },
              "type": "function"
          }]
          web_search = json_data.get('web_search')
          if web_search:
              kwargs["web_search"] = web_search
          if action == "continue":
              kwargs["tool_calls"].append({
                  "function": {
                      "name": "continue_tool"
                  },
                  "type": "function"
              })
          conversation = json_data.get("conversation")
          if conversation is not None:
              # Client sent the serialized conversation state itself.
              kwargs["conversation"] = JsonConversation(**conversation)
          else:
              # Otherwise fall back to state cached by an earlier response stream.
              conversation_id = json_data.get("conversation_id")
              if conversation_id and provider:
                  if provider in conversations and conversation_id in conversations[provider]:
                      kwargs["conversation"] = conversations[provider][conversation_id]
          if json_data.get("ignored"):
              kwargs["ignored"] = json_data["ignored"]
          if action:
              kwargs["action"] = action
          return {
              "model": model,
              "provider": provider,
              "messages": messages,
              "stream": True,
              "ignore_stream": True,
              "return_conversation": True,
              **kwargs
          }

      def _create_response_stream(self, kwargs: dict, conversation_id: str, provider: str, download_images: bool = True) -> Iterator:
          """Run a chat completion and translate its output into GUI events.

          Args:
              kwargs: Arguments for ``ChatCompletion.create``; ``model``,
                  ``provider`` and ``messages`` are read from it.
              conversation_id: Key under which returned conversation state is
                  cached in the module-level ``conversations`` dict.
              provider: Not read — it is overwritten by ``kwargs.get("provider")``.
              download_images: Copy generated images via ``copy_images`` before
                  reporting them (also done whenever the image chunk has cookies).

          Yields:
              Event dicts from ``_format_json``. Failures are reported as
              ``error`` events instead of being raised.
          """
          def decorated_log(text: str, file = None):
              # Buffer every log line so _yield_logs can forward it to the client.
              debug.logs.append(text)
              if debug.logging:
                  debug.log_handler(text, file=file)
          # NOTE(review): replaces the module-wide debug.log hook on every call.
          debug.log = decorated_log
          proxy = os.environ.get("G4F_PROXY")
          # NOTE(review): overrides the ``provider`` argument; confirm callers
          # always pass the same value inside kwargs.
          provider = kwargs.get("provider")
          try:
              model, provider_handler = get_model_and_provider(
                  kwargs.get("model"), provider,
                  stream=True,
                  ignore_stream=True,
                  logging=False,
                  has_images="images" in kwargs,
              )
          except Exception as e:
              debug.error(e)
              yield self._format_json('error', type(e).__name__, message=get_error_message(e))
              return
          if not isinstance(provider_handler, BaseRetryProvider):
              # For retry providers the concrete provider is announced later,
              # through ProviderInfo chunks.
              if not provider:
                  provider = provider_handler.__name__
              yield self.handle_provider(provider_handler, model)
              if hasattr(provider_handler, "get_parameters"):
                  yield self._format_json("parameters", provider_handler.get_parameters(as_json=True))
          try:
              result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler})
              for chunk in result:
                  if isinstance(chunk, ProviderInfo):
                      yield self.handle_provider(chunk, model)
                      provider = chunk.name
                  elif isinstance(chunk, BaseConversation):
                      if provider is not None:
                          if hasattr(provider, "__name__"):
                              provider = provider.__name__
                          if provider not in conversations:
                              conversations[provider] = {}
                          conversations[provider][conversation_id] = chunk
                          if isinstance(chunk, JsonConversation):
                              # Serializable state goes back to the client directly.
                              yield self._format_json("conversation", {
                                  provider: chunk.get_dict()
                              })
                          else:
                              yield self._format_json("conversation_id", conversation_id)
                  elif isinstance(chunk, Exception):
                      logger.exception(chunk)
                      debug.error(chunk)
                      yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
                  elif isinstance(chunk, RequestLogin):
                      # A second RequestLogin branch emitting a "login" event used to
                      # follow below; it was unreachable because this one matches first.
                      yield self._format_json("preview", chunk.to_string())
                  elif isinstance(chunk, PreviewResponse):
                      yield self._format_json("preview", chunk.to_string())
                  elif isinstance(chunk, ImagePreview):
                      yield self._format_json("preview", chunk.to_string(), images=chunk.images, alt=chunk.alt)
                  elif isinstance(chunk, ImageResponse):
                      images = chunk
                      if download_images or chunk.get("cookies"):
                          chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
                          # NOTE(review): asyncio.run fails if an event loop is already
                          # running in this thread.
                          images = asyncio.run(copy_images(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt))
                          images = ImageResponse(images, chunk.alt)
                      # NOTE(review): ``images=`` carries chunk.get_list(), not the copied
                      # list used for the content string — confirm this is intended.
                      yield self._format_json("content", str(images), images=chunk.get_list(), alt=chunk.alt)
                  elif isinstance(chunk, SynthesizeData):
                      yield self._format_json("synthesize", chunk.get_dict())
                  elif isinstance(chunk, TitleGeneration):
                      yield self._format_json("title", chunk.title)
                  elif isinstance(chunk, Parameters):
                      yield self._format_json("parameters", chunk.get_dict())
                  elif isinstance(chunk, FinishReason):
                      yield self._format_json("finish", chunk.get_dict())
                  elif isinstance(chunk, Usage):
                      yield self._format_json("usage", chunk.get_dict())
                  elif isinstance(chunk, Reasoning):
                      yield self._format_json("reasoning", **chunk.get_dict())
                  elif isinstance(chunk, YouTube):
                      yield self._format_json("content", chunk.to_string())
                  elif isinstance(chunk, Audio):
                      yield self._format_json("audio", str(chunk))
                  elif isinstance(chunk, DebugResponse):
                      yield self._format_json("log", chunk.log)
                  elif isinstance(chunk, RawResponse):
                      yield self._format_json(chunk.type, **chunk.get_dict())
                  else:
                      yield self._format_json("content", str(chunk))
                  # Flush logs produced while handling this chunk.
                  yield from self._yield_logs()
          except Exception as e:
              logger.exception(e)
              debug.error(e)
              yield from self._yield_logs()
              yield self._format_json('error', type(e).__name__, message=get_error_message(e))

      def _yield_logs(self):
          """Yield buffered ``debug.logs`` entries as ``log`` events and empty the buffer."""
          if not debug.logs:
              return
          # Swap the buffer out before yielding. Lines appended while this
          # generator is suspended land in the fresh list instead of being
          # discarded by a reset performed after the loop.
          pending, debug.logs = debug.logs, []
          for log in pending:
              yield self._format_json("log", log)

      def _format_json(self, response_type: str, content = None, **kwargs):
          """Build an event dict ``{"type": response_type, ...}``.

          When *content* is not None (and *response_type* is a str), it is stored
          under the key *response_type*; extra keyword arguments are merged in.
          """
          if content is not None and isinstance(response_type, str):
              return {
                  'type': response_type,
                  response_type: content,
                  **kwargs
              }
          return {
              'type': response_type,
              **kwargs
          }

      def handle_provider(self, provider_handler, model):
          """Build a ``provider`` event from ``provider_handler.get_dict()``.

          For a retry provider, its ``last_provider`` is reported when set; a
          truthy *model* is added under the ``model`` key.
          """
          if isinstance(provider_handler, BaseRetryProvider) and provider_handler.last_provider is not None:
              provider_handler = provider_handler.last_provider
          if model:
              return self._format_json("provider", {**provider_handler.get_dict(), "model": model})
          return self._format_json("provider", provider_handler.get_dict())
  def get_error_message(exception: Exception) -> str:
      """Render *exception* as ``"<ExceptionClassName>: <exception text>"``."""
      class_name = type(exception).__name__
      return "{}: {}".format(class_name, exception)