base_provider.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. from __future__ import annotations
  2. import asyncio
  3. from asyncio import AbstractEventLoop
  4. from concurrent.futures import ThreadPoolExecutor
  5. from abc import abstractmethod
  6. import json
  7. from inspect import signature, Parameter
  8. from typing import Optional, _GenericAlias
  9. from pathlib import Path
  10. try:
  11. from types import NoneType
  12. except ImportError:
  13. NoneType = type(None)
  14. from ..typing import CreateResult, AsyncResult, Messages
  15. from .types import BaseProvider
  16. from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
  17. from .response import BaseConversation, AuthResult
  18. from .helper import concat_chunks
  19. from ..cookies import get_cookies_dir
  20. from ..errors import ModelNotSupportedError, ResponseError, MissingAuthError, NoValidHarFileError
  21. from .. import debug
  22. SAFE_PARAMETERS = [
  23. "model", "messages", "stream", "timeout",
  24. "proxy", "media", "response_format",
  25. "prompt", "negative_prompt", "tools", "conversation",
  26. "history_disabled",
  27. "temperature", "top_k", "top_p",
  28. "frequency_penalty", "presence_penalty",
  29. "max_tokens", "stop",
  30. "api_key", "api_base", "seed", "width", "height",
  31. "max_retries", "web_search",
  32. "guidance_scale", "num_inference_steps", "randomize_seed",
  33. "safe", "enhance", "private", "aspect_ratio", "n",
  34. ]
  35. BASIC_PARAMETERS = {
  36. "provider": None,
  37. "model": "",
  38. "messages": [],
  39. "stream": False,
  40. "timeout": 0,
  41. "response_format": None,
  42. "max_tokens": 4096,
  43. "stop": ["stop1", "stop2"],
  44. }
  45. PARAMETER_EXAMPLES = {
  46. "proxy": "http://user:password@127.0.0.1:3128",
  47. "temperature": 1,
  48. "top_k": 1,
  49. "top_p": 1,
  50. "frequency_penalty": 1,
  51. "presence_penalty": 1,
  52. "messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
  53. "media": [["data:image/jpeg;base64,...", "filename.jpg"]],
  54. "response_format": {"type": "json_object"},
  55. "conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
  56. "seed": 42,
  57. "tools": [],
  58. }
  59. class AbstractProvider(BaseProvider):
  60. @classmethod
  61. @abstractmethod
  62. def create_completion(
  63. cls,
  64. model: str,
  65. messages: Messages,
  66. stream: bool,
  67. **kwargs
  68. ) -> CreateResult:
  69. """
  70. Create a completion with the given parameters.
  71. Args:
  72. model (str): The model to use.
  73. messages (Messages): The messages to process.
  74. stream (bool): Whether to use streaming.
  75. **kwargs: Additional keyword arguments.
  76. Returns:
  77. CreateResult: The result of the creation process.
  78. """
  79. raise NotImplementedError()
  80. @classmethod
  81. async def create_async(
  82. cls,
  83. model: str,
  84. messages: Messages,
  85. *,
  86. timeout: int = None,
  87. loop: AbstractEventLoop = None,
  88. executor: ThreadPoolExecutor = None,
  89. **kwargs
  90. ) -> str:
  91. """
  92. Asynchronously creates a result based on the given model and messages.
  93. Args:
  94. cls (type): The class on which this method is called.
  95. model (str): The model to use for creation.
  96. messages (Messages): The messages to process.
  97. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  98. executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
  99. **kwargs: Additional keyword arguments.
  100. Returns:
  101. str: The created result as a string.
  102. """
  103. loop = asyncio.get_running_loop() if loop is None else loop
  104. def create_func() -> str:
  105. return concat_chunks(cls.create_completion(model, messages, **kwargs))
  106. return await asyncio.wait_for(
  107. loop.run_in_executor(executor, create_func),
  108. timeout=timeout
  109. )
  110. @classmethod
  111. def get_create_function(cls) -> callable:
  112. return cls.create_completion
  113. @classmethod
  114. def get_async_create_function(cls) -> callable:
  115. return cls.create_async
  116. @classmethod
  117. def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
  118. params = {name: parameter for name, parameter in signature(
  119. cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
  120. cls.create_async if issubclass(cls, AsyncProvider) else
  121. cls.create_completion
  122. ).parameters.items() if name in SAFE_PARAMETERS
  123. and (name != "stream" or cls.supports_stream)}
  124. if as_json:
  125. def get_type_as_var(annotation: type, key: str, default):
  126. if key in PARAMETER_EXAMPLES:
  127. if key == "messages" and not cls.supports_system_message:
  128. return [PARAMETER_EXAMPLES[key][-1]]
  129. return PARAMETER_EXAMPLES[key]
  130. if isinstance(annotation, type):
  131. if issubclass(annotation, int):
  132. return 0
  133. elif issubclass(annotation, float):
  134. return 0.0
  135. elif issubclass(annotation, bool):
  136. return False
  137. elif issubclass(annotation, str):
  138. return ""
  139. elif issubclass(annotation, dict):
  140. return {}
  141. elif issubclass(annotation, list):
  142. return []
  143. elif issubclass(annotation, BaseConversation):
  144. return {}
  145. elif issubclass(annotation, NoneType):
  146. return {}
  147. elif annotation is None:
  148. return None
  149. elif annotation == "str" or annotation == "list[str]":
  150. return default
  151. elif isinstance(annotation, _GenericAlias):
  152. if annotation.__origin__ is Optional:
  153. return get_type_as_var(annotation.__args__[0])
  154. else:
  155. return str(annotation)
  156. return { name: (
  157. param.default
  158. if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
  159. else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
  160. ) for name, param in {
  161. **BASIC_PARAMETERS,
  162. **params,
  163. **{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
  164. }.items()}
  165. return params
  166. @classmethod
  167. @property
  168. def params(cls) -> str:
  169. """
  170. Returns the parameters supported by the provider.
  171. Args:
  172. cls (type): The class on which this property is called.
  173. Returns:
  174. str: A string listing the supported parameters.
  175. """
  176. def get_type_name(annotation: type) -> str:
  177. return getattr(annotation, "__name__", str(annotation)) if annotation is not Parameter.empty else ""
  178. args = ""
  179. for name, param in cls.get_parameters().items():
  180. args += f"\n {name}"
  181. args += f": {get_type_name(param.annotation)}"
  182. default_value = getattr(cls, "default_model", "") if name == "model" else param.default
  183. default_value = f'"{default_value}"' if isinstance(default_value, str) else default_value
  184. args += f" = {default_value}" if param.default is not Parameter.empty else ""
  185. args += ","
  186. return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  187. class AsyncProvider(AbstractProvider):
  188. """
  189. Provides asynchronous functionality for creating completions.
  190. """
  191. @classmethod
  192. def create_completion(
  193. cls,
  194. model: str,
  195. messages: Messages,
  196. stream: bool = False,
  197. **kwargs
  198. ) -> CreateResult:
  199. """
  200. Creates a completion result synchronously.
  201. Args:
  202. cls (type): The class on which this method is called.
  203. model (str): The model to use for creation.
  204. messages (Messages): The messages to process.
  205. stream (bool): Indicates whether to stream the results. Defaults to False.
  206. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  207. **kwargs: Additional keyword arguments.
  208. Returns:
  209. CreateResult: The result of the completion creation.
  210. """
  211. get_running_loop(check_nested=False)
  212. yield asyncio.run(cls.create_async(model, messages, **kwargs))
  213. @staticmethod
  214. @abstractmethod
  215. async def create_async(
  216. model: str,
  217. messages: Messages,
  218. **kwargs
  219. ) -> str:
  220. """
  221. Abstract method for creating asynchronous results.
  222. Args:
  223. model (str): The model to use for creation.
  224. messages (Messages): The messages to process.
  225. **kwargs: Additional keyword arguments.
  226. Raises:
  227. NotImplementedError: If this method is not overridden in derived classes.
  228. Returns:
  229. str: The created result as a string.
  230. """
  231. raise NotImplementedError()
  232. @classmethod
  233. def get_create_function(cls) -> callable:
  234. return cls.create_completion
  235. @classmethod
  236. def get_async_create_function(cls) -> callable:
  237. return cls.create_async
  238. class AsyncGeneratorProvider(AbstractProvider):
  239. """
  240. Provides asynchronous generator functionality for streaming results.
  241. """
  242. supports_stream = True
  243. @classmethod
  244. def create_completion(
  245. cls,
  246. model: str,
  247. messages: Messages,
  248. stream: bool = True,
  249. **kwargs
  250. ) -> CreateResult:
  251. """
  252. Creates a streaming completion result synchronously.
  253. Args:
  254. cls (type): The class on which this method is called.
  255. model (str): The model to use for creation.
  256. messages (Messages): The messages to process.
  257. stream (bool): Indicates whether to stream the results. Defaults to True.
  258. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  259. **kwargs: Additional keyword arguments.
  260. Returns:
  261. CreateResult: The result of the streaming completion creation.
  262. """
  263. return to_sync_generator(
  264. cls.create_async_generator(model, messages, stream=stream, **kwargs),
  265. stream=stream
  266. )
  267. @staticmethod
  268. @abstractmethod
  269. async def create_async_generator(
  270. model: str,
  271. messages: Messages,
  272. stream: bool = True,
  273. **kwargs
  274. ) -> AsyncResult:
  275. """
  276. Abstract method for creating an asynchronous generator.
  277. Args:
  278. model (str): The model to use for creation.
  279. messages (Messages): The messages to process.
  280. stream (bool): Indicates whether to stream the results. Defaults to True.
  281. **kwargs: Additional keyword arguments.
  282. Raises:
  283. NotImplementedError: If this method is not overridden in derived classes.
  284. Returns:
  285. AsyncResult: An asynchronous generator yielding results.
  286. """
  287. raise NotImplementedError()
  288. @classmethod
  289. def get_create_function(cls) -> callable:
  290. return cls.create_completion
  291. @classmethod
  292. def get_async_create_function(cls) -> callable:
  293. return cls.create_async_generator
  294. class ProviderModelMixin:
  295. default_model: str = None
  296. models: list[str] = []
  297. model_aliases: dict[str, str] = {}
  298. image_models: list = []
  299. vision_models: list = []
  300. last_model: str = None
  301. @classmethod
  302. def get_models(cls, **kwargs) -> list[str]:
  303. if not cls.models and cls.default_model is not None:
  304. return [cls.default_model]
  305. return cls.models
  306. @classmethod
  307. def get_model(cls, model: str, **kwargs) -> str:
  308. if not model and cls.default_model is not None:
  309. model = cls.default_model
  310. elif model in cls.model_aliases:
  311. model = cls.model_aliases[model]
  312. else:
  313. if model not in cls.get_models(**kwargs) and cls.models:
  314. raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} Valid models: {cls.models}")
  315. cls.last_model = model
  316. debug.last_model = model
  317. return model
  318. class RaiseErrorMixin():
  319. @staticmethod
  320. def raise_error(data: dict):
  321. if "error_message" in data:
  322. raise ResponseError(data["error_message"])
  323. elif "error" in data:
  324. if isinstance(data["error"], str):
  325. raise ResponseError(data["error"])
  326. elif "code" in data["error"]:
  327. raise ResponseError("\n".join(
  328. [e for e in [f'Error {data["error"]["code"]}: {data["error"]["message"]}', data["error"].get("failed_generation")] if e is not None]
  329. ))
  330. elif "message" in data["error"]:
  331. raise ResponseError(data["error"]["message"])
  332. else:
  333. raise ResponseError(data["error"])
  334. elif ("choices" not in data or not data["choices"]) and "data" not in data:
  335. raise ResponseError(f"Invalid response: {json.dumps(data)}")
  336. class AuthFileMixin():
  337. @classmethod
  338. def get_cache_file(cls) -> Path:
  339. return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
  340. class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
  341. @classmethod
  342. async def on_auth_async(cls, **kwargs) -> AuthResult:
  343. if "api_key" not in kwargs:
  344. raise MissingAuthError(f"API key is required for {cls.__name__}")
  345. return AuthResult()
  346. @classmethod
  347. def on_auth(cls, **kwargs) -> AuthResult:
  348. auth_result = cls.on_auth_async(**kwargs)
  349. if hasattr(auth_result, "__aiter__"):
  350. return to_sync_generator(auth_result)
  351. return asyncio.run(auth_result)
  352. @classmethod
  353. def get_create_function(cls) -> callable:
  354. return cls.create_completion
  355. @classmethod
  356. def get_async_create_function(cls) -> callable:
  357. return cls.create_async_generator
  358. @classmethod
  359. def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
  360. if auth_result is not None:
  361. cache_file.parent.mkdir(parents=True, exist_ok=True)
  362. cache_file.write_text(json.dumps(auth_result.get_dict()))
  363. elif cache_file.exists():
  364. cache_file.unlink()
  365. @classmethod
  366. def create_completion(
  367. cls,
  368. model: str,
  369. messages: Messages,
  370. **kwargs
  371. ) -> CreateResult:
  372. auth_result: AuthResult = None
  373. cache_file = cls.get_cache_file()
  374. try:
  375. if cache_file.exists():
  376. with cache_file.open("r") as f:
  377. auth_result = AuthResult(**json.load(f))
  378. else:
  379. raise MissingAuthError
  380. yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
  381. except (MissingAuthError, NoValidHarFileError):
  382. response = cls.on_auth(**kwargs)
  383. for chunk in response:
  384. if isinstance(chunk, AuthResult):
  385. auth_result = chunk
  386. else:
  387. yield chunk
  388. yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
  389. finally:
  390. cls.write_cache_file(cache_file, auth_result)
  391. @classmethod
  392. async def create_async_generator(
  393. cls,
  394. model: str,
  395. messages: Messages,
  396. **kwargs
  397. ) -> AsyncResult:
  398. auth_result: AuthResult = None
  399. cache_file = cls.get_cache_file()
  400. try:
  401. if cache_file.exists():
  402. with cache_file.open("r") as f:
  403. auth_result = AuthResult(**json.load(f))
  404. else:
  405. raise MissingAuthError
  406. response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
  407. async for chunk in response:
  408. yield chunk
  409. except (MissingAuthError, NoValidHarFileError):
  410. if cache_file.exists():
  411. cache_file.unlink()
  412. response = cls.on_auth_async(**kwargs)
  413. async for chunk in response:
  414. if isinstance(chunk, AuthResult):
  415. auth_result = chunk
  416. else:
  417. yield chunk
  418. response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
  419. async for chunk in response:
  420. if cache_file is not None:
  421. cls.write_cache_file(cache_file, auth_result)
  422. cache_file = None
  423. yield chunk
  424. finally:
  425. if cache_file is not None:
  426. cls.write_cache_file(cache_file, auth_result)