# api.py (12 KB)


  1. from __future__ import annotations
  2. import logging
  3. import os
  4. import asyncio
  5. from typing import Iterator
  6. from flask import send_from_directory, request
  7. from inspect import signature
  8. from ...errors import VersionNotFoundError, MissingAuthError
  9. from ...image.copy_images import copy_media, ensure_media_dir, get_media_dir
  10. from ...tools.run_tools import iter_run_tools
  11. from ... import Provider
  12. from ...providers.base_provider import ProviderModelMixin
  13. from ...providers.retry_provider import BaseRetryProvider
  14. from ...providers.helper import format_image_prompt
  15. from ...providers.response import *
  16. from ... import version, models
  17. from ... import ChatCompletion, get_model_and_provider
  18. from ... import debug
  19. logger = logging.getLogger(__name__)
  20. class Api:
  21. @staticmethod
  22. def get_models():
  23. return [{
  24. "name": model.name,
  25. "image": isinstance(model, models.ImageModel),
  26. "vision": isinstance(model, models.VisionModel),
  27. "audio": isinstance(model, models.AudioModel),
  28. "video": isinstance(model, models.VideoModel),
  29. "providers": [
  30. getattr(provider, "parent", provider.__name__)
  31. for provider in providers
  32. if provider.working
  33. ]
  34. }
  35. for model, providers in models.__models__.values()]
  36. @staticmethod
  37. def get_provider_models(provider: str, api_key: str = None, api_base: str = None, ignored: list = None):
  38. def get_model_data(provider: ProviderModelMixin, model: str):
  39. return {
  40. "model": model,
  41. "label": model.split(":")[-1] if provider.__name__ == "AnyProvider" else model,
  42. "default": model == provider.default_model,
  43. "vision": model in provider.vision_models,
  44. "audio": model in provider.audio_models,
  45. "video": model in provider.video_models,
  46. "image": model in provider.image_models,
  47. "count": provider.models_count.get(model),
  48. }
  49. if provider in Provider.__map__:
  50. provider = Provider.__map__[provider]
  51. if issubclass(provider, ProviderModelMixin):
  52. has_grouped_models = hasattr(provider, "get_grouped_models")
  53. method = provider.get_grouped_models if has_grouped_models else provider.get_models
  54. if "api_key" in signature(provider.get_models).parameters:
  55. models = method(api_key=api_key, api_base=api_base)
  56. elif "ignored" in signature(provider.get_models).parameters:
  57. models = method(ignored=ignored)
  58. else:
  59. models = method()
  60. if has_grouped_models:
  61. return [{
  62. "group": model["group"],
  63. "models": [get_model_data(provider, name) for name in model["models"]]
  64. } for model in models]
  65. return [
  66. get_model_data(provider, model)
  67. for model in models
  68. ]
  69. return []
  70. @staticmethod
  71. def get_providers() -> dict[str, str]:
  72. def safe_get_models(provider: ProviderModelMixin):
  73. if not isinstance(provider, ProviderModelMixin):
  74. return True
  75. try:
  76. return provider.get_models()
  77. except Exception as e:
  78. logger.exception(e)
  79. return True
  80. return [{
  81. "name": provider.__name__,
  82. "label": provider.label if hasattr(provider, "label") else provider.__name__,
  83. "parent": getattr(provider, "parent", None),
  84. "image": len(getattr(provider, "image_models", [])),
  85. "audio": len(getattr(provider, "audio_models", [])),
  86. "video": len(getattr(provider, "video_models", [])),
  87. "vision": getattr(provider, "default_vision_model", None) is not None,
  88. "nodriver": getattr(provider, "use_nodriver", False),
  89. "hf_space": getattr(provider, "hf_space", False),
  90. "auth": provider.needs_auth,
  91. "login_url": getattr(provider, "login_url", None),
  92. } for provider in Provider.__providers__ if provider.working and safe_get_models(provider)]
  93. @staticmethod
  94. def get_version() -> dict:
  95. current_version = None
  96. latest_version = None
  97. try:
  98. current_version = version.utils.current_version
  99. try:
  100. if request.args.get("cache"):
  101. latest_version = version.utils.latest_version_cached
  102. except RuntimeError:
  103. pass
  104. if latest_version is None:
  105. latest_version = version.utils.latest_version
  106. except VersionNotFoundError:
  107. pass
  108. return {
  109. "version": current_version,
  110. "latest_version": latest_version,
  111. }
  112. def serve_images(self, name):
  113. ensure_media_dir()
  114. return send_from_directory(os.path.abspath(get_media_dir()), name)
  115. def _prepare_conversation_kwargs(self, json_data: dict):
  116. kwargs = {**json_data}
  117. model = json_data.get('model')
  118. provider = json_data.get('provider')
  119. messages = json_data.get('messages')
  120. action = json_data.get('action')
  121. if action == "continue":
  122. kwargs["tool_calls"].append({
  123. "function": {
  124. "name": "continue_tool"
  125. },
  126. "type": "function"
  127. })
  128. conversation = json_data.get("conversation")
  129. if isinstance(conversation, dict):
  130. kwargs["conversation"] = JsonConversation(**conversation)
  131. return {
  132. "model": model,
  133. "provider": provider,
  134. "messages": messages,
  135. "stream": True,
  136. "ignore_stream": True,
  137. **kwargs
  138. }
  139. def _create_response_stream(self, kwargs: dict, provider: str, download_media: bool = True) -> Iterator:
  140. def decorated_log(text: str, file = None):
  141. debug.logs.append(text)
  142. if debug.logging:
  143. debug.log_handler(text, file=file)
  144. debug.log = decorated_log
  145. proxy = os.environ.get("G4F_PROXY")
  146. provider = kwargs.get("provider")
  147. try:
  148. model, provider_handler = get_model_and_provider(
  149. kwargs.get("model"), provider,
  150. stream=True,
  151. ignore_stream=True,
  152. logging=False,
  153. has_images="media" in kwargs,
  154. )
  155. except Exception as e:
  156. debug.error(e)
  157. yield self._format_json('error', type(e).__name__, message=get_error_message(e))
  158. return
  159. if not isinstance(provider_handler, BaseRetryProvider):
  160. if not provider:
  161. provider = provider_handler.__name__
  162. yield self.handle_provider(provider_handler, model)
  163. if hasattr(provider_handler, "get_parameters"):
  164. yield self._format_json("parameters", provider_handler.get_parameters(as_json=True))
  165. try:
  166. result = iter_run_tools(ChatCompletion.create, **{**kwargs, "model": model, "provider": provider_handler, "download_media": download_media})
  167. for chunk in result:
  168. if isinstance(chunk, ProviderInfo):
  169. yield self.handle_provider(chunk, model)
  170. elif isinstance(chunk, JsonConversation):
  171. if provider is not None:
  172. yield self._format_json("conversation", chunk.get_dict() if provider == "AnyProvider" else {
  173. provider: chunk.get_dict()
  174. })
  175. elif isinstance(chunk, Exception):
  176. logger.exception(chunk)
  177. yield self._format_json('message', get_error_message(chunk), error=type(chunk).__name__)
  178. elif isinstance(chunk, RequestLogin):
  179. yield self._format_json("preview", chunk.to_string())
  180. elif isinstance(chunk, PreviewResponse):
  181. yield self._format_json("preview", chunk.to_string())
  182. elif isinstance(chunk, ImagePreview):
  183. yield self._format_json("preview", chunk.to_string(), urls=chunk.urls, alt=chunk.alt)
  184. elif isinstance(chunk, MediaResponse):
  185. media = chunk
  186. if download_media or chunk.get("cookies"):
  187. chunk.alt = format_image_prompt(kwargs.get("messages"), chunk.alt)
  188. tags = [model, kwargs.get("aspect_ratio"), kwargs.get("resolution"), kwargs.get("width"), kwargs.get("height")]
  189. media = asyncio.run(copy_media(chunk.get_list(), chunk.get("cookies"), chunk.get("headers"), proxy=proxy, alt=chunk.alt, tags=tags))
  190. media = ImageResponse(media, chunk.alt) if isinstance(chunk, ImageResponse) else VideoResponse(media, chunk.alt)
  191. yield self._format_json("content", str(media), urls=chunk.urls, alt=chunk.alt)
  192. elif isinstance(chunk, SynthesizeData):
  193. yield self._format_json("synthesize", chunk.get_dict())
  194. elif isinstance(chunk, TitleGeneration):
  195. yield self._format_json("title", chunk.title)
  196. elif isinstance(chunk, RequestLogin):
  197. yield self._format_json("login", str(chunk))
  198. elif isinstance(chunk, Parameters):
  199. yield self._format_json("parameters", chunk.get_dict())
  200. elif isinstance(chunk, FinishReason):
  201. yield self._format_json("finish", chunk.get_dict())
  202. elif isinstance(chunk, Usage):
  203. yield self._format_json("usage", chunk.get_dict())
  204. elif isinstance(chunk, Reasoning):
  205. yield self._format_json("reasoning", **chunk.get_dict())
  206. elif isinstance(chunk, YouTube):
  207. yield self._format_json("content", chunk.to_string())
  208. elif isinstance(chunk, AudioResponse):
  209. yield self._format_json("content", str(chunk))
  210. elif isinstance(chunk, SuggestedFollowups):
  211. yield self._format_json("suggestions", chunk.suggestions)
  212. elif isinstance(chunk, DebugResponse):
  213. yield self._format_json("log", chunk.log)
  214. elif isinstance(chunk, RawResponse):
  215. yield self._format_json(chunk.type, **chunk.get_dict())
  216. else:
  217. yield self._format_json("content", str(chunk))
  218. except MissingAuthError as e:
  219. yield self._format_json('auth', type(e).__name__, message=get_error_message(e))
  220. except Exception as e:
  221. logger.exception(e)
  222. yield self._format_json('error', type(e).__name__, message=get_error_message(e))
  223. finally:
  224. yield from self._yield_logs()
  225. def _yield_logs(self):
  226. if debug.logs:
  227. for log in debug.logs:
  228. yield self._format_json("log", log)
  229. debug.logs = []
  230. def _format_json(self, response_type: str, content = None, **kwargs):
  231. if content is not None and isinstance(response_type, str):
  232. return {
  233. 'type': response_type,
  234. response_type: content,
  235. **kwargs
  236. }
  237. return {
  238. 'type': response_type,
  239. **kwargs
  240. }
  241. def handle_provider(self, provider_handler, model):
  242. if isinstance(provider_handler, BaseRetryProvider) and provider_handler.last_provider is not None:
  243. provider_handler = provider_handler.last_provider
  244. if model:
  245. return self._format_json("provider", {**provider_handler.get_dict(), "model": model})
  246. return self._format_json("provider", provider_handler.get_dict())
  247. def get_error_message(exception: Exception) -> str:
  248. return f"{type(exception).__name__}: {exception}"