base_provider.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from __future__ import annotations
  2. import sys
  3. import asyncio
  4. from asyncio import AbstractEventLoop
  5. from concurrent.futures import ThreadPoolExecutor
  6. from abc import abstractmethod
  7. from inspect import signature, Parameter
  8. from ..typing import CreateResult, AsyncResult, Messages
  9. from .types import BaseProvider
  10. from .asyncio import get_running_loop, to_sync_generator
  11. from .response import FinishReason, BaseConversation, SynthesizeData
  12. from ..errors import ModelNotSupportedError
  13. from .. import debug
  # Windows only: switch from the Proactor to the Selector event loop policy,
  # for compatibility between asyncio and curl_cffi. This only happens when the
  # installed curl_cffi ``aio`` module has no ``_get_selector`` attribute
  # (presumably older releases without their own workaround; verify).
  if sys.platform == 'win32':
      try:
          from curl_cffi import aio
          if not hasattr(aio, "_get_selector") and isinstance(
              asyncio.get_event_loop_policy(),
              asyncio.WindowsProactorEventLoopPolicy,
          ):
              asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
      except ImportError:
          # curl_cffi is optional; when it is missing, the current policy is kept.
          pass
  class _classproperty:
      """Read-only descriptor that returns ``fget(owner_class)`` on attribute access.

      Stacking ``@classmethod`` on ``@property`` was deprecated in Python 3.11.
      It stopped working in Python 3.13, where the attribute evaluates to a
      bound method instead of the computed value. This descriptor provides
      the same class-level property behaviour on every Python version.
      """

      def __init__(self, fget):
          self.fget = fget
          self.__doc__ = getattr(fget, "__doc__", None)

      def __get__(self, instance, owner=None):
          # When looked up on an instance without an explicit owner, use the instance's class.
          if owner is None:
              owner = type(instance)
          return self.fget(owner)


  class AbstractProvider(BaseProvider):
      """
      Abstract provider base that adds an async wrapper around the synchronous
      ``create_completion`` and introspection of the supported parameters.
      """

      @classmethod
      async def create_async(
          cls,
          model: str,
          messages: Messages,
          *,
          loop: AbstractEventLoop = None,
          executor: ThreadPoolExecutor = None,
          **kwargs
      ) -> str:
          """
          Asynchronously create a result by running ``create_completion`` in an executor.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              loop (AbstractEventLoop, optional): Event loop used to schedule the
                  executor job. Defaults to the currently running loop.
              executor (ThreadPoolExecutor, optional): Executor that runs the
                  synchronous call. ``None`` selects the loop's default executor.
              **kwargs: Forwarded to ``create_completion``. A ``timeout`` entry,
                  if present, also bounds the total wait.

          Returns:
              str: All chunks of the non-streamed completion joined into one string.

          Raises:
              TimeoutError: If a ``timeout`` was given and it elapses first.
          """
          loop = loop or asyncio.get_running_loop()

          def create_func() -> str:
              # stream=False is passed positionally, matching create_completion's signature.
              return "".join(cls.create_completion(model, messages, False, **kwargs))

          return await asyncio.wait_for(
              loop.run_in_executor(executor, create_func),
              timeout=kwargs.get("timeout")
          )

      @classmethod
      def get_parameters(cls) -> dict[str, Parameter]:
          """
          Return the keyword parameters a caller may pass to this provider.

          The inspected signature depends on the provider type:
          ``create_async_generator`` for ``AsyncGeneratorProvider`` subclasses,
          ``create_async`` for other ``AsyncProvider`` subclasses, and
          ``create_completion`` otherwise. ``kwargs``, ``model`` and
          ``messages`` are omitted. ``stream`` is omitted unless the provider's
          ``supports_stream`` is true.

          Returns:
              dict[str, Parameter]: Map of parameter name to ``inspect.Parameter``,
              in signature order.
          """
          if issubclass(cls, AsyncGeneratorProvider):
              method = cls.create_async_generator
          elif issubclass(cls, AsyncProvider):
              method = cls.create_async
          else:
              method = cls.create_completion
          return {
              name: parameter
              for name, parameter in signature(method).parameters.items()
              if name not in ("kwargs", "model", "messages")
              and (name != "stream" or cls.supports_stream)
          }

      @_classproperty
      def params(cls) -> str:
          """
          Describe the parameters supported by the provider.

          Read as a class attribute (``Provider.params``); it is not called.
          The output lists one parameter per line as ``name: type = default,``.
          The type and default parts appear only when the signature declares
          them. String defaults are wrapped in double quotes.

          Returns:
              str: ``g4f.Provider.<Name> supports: (...)`` listing the parameters.
          """
          def get_type_name(annotation: type) -> str:
              # Types expose __name__; string annotations (lazy, due to
              # ``from __future__ import annotations``) fall back to str().
              return annotation.__name__ if hasattr(annotation, "__name__") else str(annotation)

          entries = []
          for name, param in cls.get_parameters().items():
              entry = f"\n {name}"
              if param.annotation is not Parameter.empty:
                  entry += f": {get_type_name(param.annotation)}"
              if param.default is not Parameter.empty:
                  default_value = f'"{param.default}"' if isinstance(param.default, str) else param.default
                  entry += f" = {default_value}"
              entries.append(entry + ",")
          args = "".join(entries)
          return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  class AsyncProvider(AbstractProvider):
      """
      Provider base built around an async ``create_async`` coroutine, also
      exposed through a synchronous ``create_completion``.
      """

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool = False,
          **kwargs
      ) -> CreateResult:
          """
          Synchronously produce the completion by running ``create_async`` to completion.

          This is a generator function. Nothing runs until it is first
          iterated; it then yields exactly one item, the string returned by
          ``create_async``.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              stream (bool): Accepted for interface compatibility; not used here.
                  Defaults to False.
              **kwargs: Forwarded unchanged to ``create_async``.

          Returns:
              CreateResult: A generator yielding the single result string.
          """
          # NOTE(review): presumably this guards against being invoked from
          # inside an already-running event loop (asyncio.run cannot nest).
          # Confirm against get_running_loop in .asyncio.
          get_running_loop(check_nested=False)
          result = asyncio.run(cls.create_async(model, messages, **kwargs))
          yield result

      @staticmethod
      @abstractmethod
      async def create_async(
          model: str,
          messages: Messages,
          **kwargs
      ) -> str:
          """
          Produce the complete result as a string; subclasses must override this.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              **kwargs: Provider-specific options.

          Raises:
              NotImplementedError: Always, in this base implementation.

          Returns:
              str: The created result.
          """
          raise NotImplementedError()
  class AsyncGeneratorProvider(AsyncProvider):
      """
      Provider base built around an async generator, ``create_async_generator``,
      which yields result chunks.
      """

      supports_stream = True

      @classmethod
      def create_completion(
          cls,
          model: str,
          messages: Messages,
          stream: bool = True,
          **kwargs
      ) -> CreateResult:
          """
          Expose ``create_async_generator`` as a synchronous iterator.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              stream (bool): Passed through as ``stream=`` to
                  ``create_async_generator``. Defaults to True.
              **kwargs: Forwarded unchanged to ``create_async_generator``.

          Returns:
              CreateResult: Whatever ``to_sync_generator`` produces for the async generator.
          """
          async_chunks = cls.create_async_generator(model, messages, stream=stream, **kwargs)
          return to_sync_generator(async_chunks)

      @classmethod
      async def create_async(
          cls,
          model: str,
          messages: Messages,
          **kwargs
      ) -> str:
          """
          Drain ``create_async_generator`` (with ``stream=False``) into one string.

          Chunks that are ``Exception``, ``FinishReason``, ``BaseConversation``
          or ``SynthesizeData`` instances are skipped. Every other chunk is
          converted with ``str()`` and concatenated in arrival order.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              **kwargs: Forwarded unchanged to ``create_async_generator``.

          Returns:
              str: The concatenated text chunks.
          """
          skipped_types = (Exception, FinishReason, BaseConversation, SynthesizeData)
          pieces = []
          async for chunk in cls.create_async_generator(model, messages, stream=False, **kwargs):
              if isinstance(chunk, skipped_types):
                  continue
              pieces.append(str(chunk))
          return "".join(pieces)

      @staticmethod
      @abstractmethod
      async def create_async_generator(
          model: str,
          messages: Messages,
          stream: bool = True,
          **kwargs
      ) -> AsyncResult:
          """
          Yield result chunks asynchronously; subclasses must override this.

          Args:
              model (str): The model to use for creation.
              messages (Messages): The messages to process.
              stream (bool): Whether the caller wants streamed output. Defaults to True.
              **kwargs: Provider-specific options.

          Raises:
              NotImplementedError: Always, in this base implementation.

          Returns:
              AsyncResult: An asynchronous generator of result chunks.
          """
          raise NotImplementedError()
  class ProviderModelMixin:
      """Mixin holding a provider's model names and resolving a requested model."""

      # Model used when the caller does not name one.
      default_model: str = None
      # Supported model names; when empty, get_model accepts any name.
      models: list[str] = []
      # Alternative name -> model name used by this provider.
      model_aliases: dict[str, str] = {}
      # Not referenced in this class; presumably image-capable models (confirm).
      image_models: list = None

      @classmethod
      def get_models(cls) -> list[str]:
          """
          Return the provider's model list.

          If ``models`` is empty and ``default_model`` is set, a new
          one-element list holding ``default_model`` is returned. Otherwise
          the ``models`` attribute itself is returned.
          """
          if cls.models or cls.default_model is None:
              return cls.models
          return [cls.default_model]

      @classmethod
      def get_model(cls, model: str) -> str:
          """
          Resolve a requested model name and record it in ``debug.last_model``.

          Resolution order:
          1. A falsy ``model`` becomes ``default_model`` when one is set.
          2. A name found in ``model_aliases`` is replaced by its target.
          3. Otherwise, if ``models`` is non-empty and the name is not in
             ``get_models()``, the request is rejected.
          Names that pass these checks are returned unchanged.

          Raises:
              ModelNotSupportedError: For an unknown name when ``models`` is non-empty.
          """
          if not model and cls.default_model is not None:
              resolved = cls.default_model
          elif model in cls.model_aliases:
              resolved = cls.model_aliases[model]
          # get_models() is evaluated before cls.models on purpose, preserving
          # the original call order (subclasses may override get_models).
          elif model not in cls.get_models() and cls.models:
              raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
          else:
              resolved = model
          debug.last_model = resolved
          return resolved