base_provider.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521
  1. from __future__ import annotations
  2. import asyncio
  3. from asyncio import AbstractEventLoop
  4. from concurrent.futures import ThreadPoolExecutor
  5. from abc import abstractmethod
  6. import json
  7. from inspect import signature, Parameter
  8. from typing import Optional, _GenericAlias
  9. from pathlib import Path
  10. try:
  11. from types import NoneType
  12. except ImportError:
  13. NoneType = type(None)
  14. from ..typing import CreateResult, AsyncResult, Messages
  15. from .types import BaseProvider
  16. from .asyncio import get_running_loop, to_sync_generator, to_async_iterator
  17. from .response import BaseConversation, AuthResult
  18. from .helper import concat_chunks
  19. from ..cookies import get_cookies_dir
  20. from ..errors import ModelNotFoundError, ResponseError, MissingAuthError, NoValidHarFileError, PaymentRequiredError
  21. SAFE_PARAMETERS = [
  22. "model", "messages", "stream", "timeout",
  23. "proxy", "media", "response_format",
  24. "prompt", "negative_prompt", "tools", "conversation",
  25. "history_disabled",
  26. "temperature", "top_k", "top_p",
  27. "frequency_penalty", "presence_penalty",
  28. "max_tokens", "stop",
  29. "api_key", "api_base", "seed", "width", "height",
  30. "max_retries", "web_search",
  31. "guidance_scale", "num_inference_steps", "randomize_seed",
  32. "safe", "enhance", "private", "aspect_ratio", "n", "transparent"
  33. ]
  34. BASIC_PARAMETERS = {
  35. "provider": None,
  36. "model": "",
  37. "messages": [],
  38. "stream": False,
  39. "timeout": 0,
  40. "response_format": None,
  41. "max_tokens": 4096,
  42. "stop": ["stop1", "stop2"],
  43. }
  44. PARAMETER_EXAMPLES = {
  45. "proxy": "http://user:password@127.0.0.1:3128",
  46. "temperature": 1,
  47. "top_k": 1,
  48. "top_p": 1,
  49. "frequency_penalty": 1,
  50. "presence_penalty": 1,
  51. "messages": [{"role": "system", "content": ""}, {"role": "user", "content": ""}],
  52. "media": [["data:image/jpeg;base64,...", "filename.jpg"]],
  53. "response_format": {"type": "json_object"},
  54. "conversation": {"conversation_id": "550e8400-e29b-11d4-a716-...", "message_id": "550e8400-e29b-11d4-a716-..."},
  55. "seed": 42,
  56. "tools": [],
  57. "width": 1024,
  58. "height": 1024,
  59. }
  60. class AbstractProvider(BaseProvider):
  61. @classmethod
  62. @abstractmethod
  63. def create_completion(
  64. cls,
  65. model: str,
  66. messages: Messages,
  67. stream: bool,
  68. **kwargs
  69. ) -> CreateResult:
  70. """
  71. Create a completion with the given parameters.
  72. Args:
  73. model (str): The model to use.
  74. messages (Messages): The messages to process.
  75. stream (bool): Whether to use streaming.
  76. **kwargs: Additional keyword arguments.
  77. Returns:
  78. CreateResult: The result of the creation process.
  79. """
  80. raise NotImplementedError()
  81. @classmethod
  82. async def create_async(
  83. cls,
  84. model: str,
  85. messages: Messages,
  86. *,
  87. timeout: int = None,
  88. loop: AbstractEventLoop = None,
  89. executor: ThreadPoolExecutor = None,
  90. **kwargs
  91. ) -> str:
  92. """
  93. Asynchronously creates a result based on the given model and messages.
  94. Args:
  95. cls (type): The class on which this method is called.
  96. model (str): The model to use for creation.
  97. messages (Messages): The messages to process.
  98. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  99. executor (ThreadPoolExecutor, optional): The executor for running async tasks. Defaults to None.
  100. **kwargs: Additional keyword arguments.
  101. Returns:
  102. str: The created result as a string.
  103. """
  104. loop = asyncio.get_running_loop() if loop is None else loop
  105. def create_func() -> str:
  106. return concat_chunks(cls.create_completion(model, messages, **kwargs))
  107. return await asyncio.wait_for(
  108. loop.run_in_executor(executor, create_func),
  109. timeout=timeout
  110. )
  111. @classmethod
  112. def create_function(cls, *args, **kwargs) -> CreateResult:
  113. """
  114. Creates a completion using the synchronous method.
  115. Args:
  116. **kwargs: Additional keyword arguments.
  117. Returns:
  118. CreateResult: The result of the completion creation.
  119. """
  120. return cls.create_completion(*args, **kwargs)
  121. @classmethod
  122. def async_create_function(cls, *args, **kwargs) -> AsyncResult:
  123. """
  124. Creates a completion using the synchronous method.
  125. Args:
  126. **kwargs: Additional keyword arguments.
  127. Returns:
  128. CreateResult: The result of the completion creation.
  129. """
  130. return cls.create_async(*args, **kwargs)
  131. @classmethod
  132. def get_parameters(cls, as_json: bool = False) -> dict[str, Parameter]:
  133. params = {name: parameter for name, parameter in signature(
  134. cls.create_async_generator if issubclass(cls, AsyncGeneratorProvider) else
  135. cls.create_async if issubclass(cls, AsyncProvider) else
  136. cls.create_completion
  137. ).parameters.items() if name in SAFE_PARAMETERS
  138. and (name != "stream" or cls.supports_stream)}
  139. if as_json:
  140. def get_type_as_var(annotation: type, key: str, default):
  141. if key in PARAMETER_EXAMPLES:
  142. if key == "messages" and not cls.supports_system_message:
  143. return [PARAMETER_EXAMPLES[key][-1]]
  144. return PARAMETER_EXAMPLES[key]
  145. if isinstance(annotation, type):
  146. if issubclass(annotation, int):
  147. return 0
  148. elif issubclass(annotation, float):
  149. return 0.0
  150. elif issubclass(annotation, bool):
  151. return False
  152. elif issubclass(annotation, str):
  153. return ""
  154. elif issubclass(annotation, dict):
  155. return {}
  156. elif issubclass(annotation, list):
  157. return []
  158. elif issubclass(annotation, BaseConversation):
  159. return {}
  160. elif issubclass(annotation, NoneType):
  161. return {}
  162. elif annotation is None:
  163. return None
  164. elif annotation == "str" or annotation == "list[str]":
  165. return default
  166. elif isinstance(annotation, _GenericAlias):
  167. if annotation.__origin__ is Optional:
  168. return get_type_as_var(annotation.__args__[0])
  169. else:
  170. return str(annotation)
  171. return { name: (
  172. param.default
  173. if isinstance(param, Parameter) and param.default is not Parameter.empty and param.default is not None
  174. else get_type_as_var(param.annotation, name, param.default) if isinstance(param, Parameter) else param
  175. ) for name, param in {
  176. **BASIC_PARAMETERS,
  177. **params,
  178. **{"provider": cls.__name__, "model": getattr(cls, "default_model", ""), "stream": cls.supports_stream},
  179. }.items()}
  180. return params
  181. @classmethod
  182. @property
  183. def params(cls) -> str:
  184. """
  185. Returns the parameters supported by the provider.
  186. Args:
  187. cls (type): The class on which this property is called.
  188. Returns:
  189. str: A string listing the supported parameters.
  190. """
  191. def get_type_name(annotation: type) -> str:
  192. return getattr(annotation, "__name__", str(annotation)) if annotation is not Parameter.empty else ""
  193. args = ""
  194. for name, param in cls.get_parameters().items():
  195. args += f"\n {name}"
  196. args += f": {get_type_name(param.annotation)}"
  197. default_value = getattr(cls, "default_model", "") if name == "model" else param.default
  198. default_value = f'"{default_value}"' if isinstance(default_value, str) else default_value
  199. args += f" = {default_value}" if param.default is not Parameter.empty else ""
  200. args += ","
  201. return f"g4f.Provider.{cls.__name__} supports: ({args}\n)"
  202. class AsyncProvider(AbstractProvider):
  203. """
  204. Provides asynchronous functionality for creating completions.
  205. """
  206. @classmethod
  207. def create_completion(
  208. cls,
  209. model: str,
  210. messages: Messages,
  211. stream: bool = False,
  212. **kwargs
  213. ) -> CreateResult:
  214. """
  215. Creates a completion result synchronously.
  216. Args:
  217. cls (type): The class on which this method is called.
  218. model (str): The model to use for creation.
  219. messages (Messages): The messages to process.
  220. stream (bool): Indicates whether to stream the results. Defaults to False.
  221. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  222. **kwargs: Additional keyword arguments.
  223. Returns:
  224. CreateResult: The result of the completion creation.
  225. """
  226. get_running_loop(check_nested=False)
  227. yield asyncio.run(cls.create_async(model, messages, **kwargs))
  228. @staticmethod
  229. @abstractmethod
  230. async def create_async(
  231. model: str,
  232. messages: Messages,
  233. **kwargs
  234. ) -> str:
  235. """
  236. Abstract method for creating asynchronous results.
  237. Args:
  238. model (str): The model to use for creation.
  239. messages (Messages): The messages to process.
  240. **kwargs: Additional keyword arguments.
  241. Raises:
  242. NotImplementedError: If this method is not overridden in derived classes.
  243. Returns:
  244. str: The created result as a string.
  245. """
  246. raise NotImplementedError()
  247. class AsyncGeneratorProvider(AbstractProvider):
  248. """
  249. Provides asynchronous generator functionality for streaming results.
  250. """
  251. supports_stream = True
  252. @classmethod
  253. def create_completion(
  254. cls,
  255. model: str,
  256. messages: Messages,
  257. stream: bool = True,
  258. timeout: int = None,
  259. **kwargs
  260. ) -> CreateResult:
  261. """
  262. Creates a streaming completion result synchronously.
  263. Args:
  264. cls (type): The class on which this method is called.
  265. model (str): The model to use for creation.
  266. messages (Messages): The messages to process.
  267. stream (bool): Indicates whether to stream the results. Defaults to True.
  268. loop (AbstractEventLoop, optional): The event loop to use. Defaults to None.
  269. **kwargs: Additional keyword arguments.
  270. Returns:
  271. CreateResult: The result of the streaming completion creation.
  272. """
  273. return to_sync_generator(
  274. cls.create_async_generator(model, messages, stream=stream, **kwargs),
  275. stream=stream,
  276. timeout=timeout
  277. )
  278. @staticmethod
  279. @abstractmethod
  280. async def create_async_generator(
  281. model: str,
  282. messages: Messages,
  283. stream: bool = True,
  284. **kwargs
  285. ) -> AsyncResult:
  286. """
  287. Abstract method for creating an asynchronous generator.
  288. Args:
  289. model (str): The model to use for creation.
  290. messages (Messages): The messages to process.
  291. stream (bool): Indicates whether to stream the results. Defaults to True.
  292. **kwargs: Additional keyword arguments.
  293. Raises:
  294. NotImplementedError: If this method is not overridden in derived classes.
  295. Returns:
  296. AsyncResult: An asynchronous generator yielding results.
  297. """
  298. raise NotImplementedError()
  299. @classmethod
  300. def async_create_function(cls, *args, **kwargs) -> AsyncResult:
  301. """
  302. Creates a completion using the synchronous method.
  303. Args:
  304. **kwargs: Additional keyword arguments.
  305. Returns:
  306. CreateResult: The result of the completion creation.
  307. """
  308. return cls.create_async_generator(*args, **kwargs)
  309. class ProviderModelMixin:
  310. default_model: str = None
  311. models: list[str] = []
  312. model_aliases: dict[str, str] = {}
  313. models_count: dict = {}
  314. image_models: list = []
  315. vision_models: list = []
  316. video_models: list = []
  317. audio_models: dict = {}
  318. last_model: str = None
  319. @classmethod
  320. def get_models(cls, **kwargs) -> list[str]:
  321. if not cls.models and cls.default_model is not None:
  322. cls.models = [cls.default_model]
  323. return cls.models
  324. @classmethod
  325. def get_model(cls, model: str, **kwargs) -> str:
  326. if not model and cls.default_model is not None:
  327. model = cls.default_model
  328. if model in cls.model_aliases:
  329. model = cls.model_aliases[model]
  330. if model not in cls.model_aliases.values():
  331. if model not in cls.get_models(**kwargs) and cls.models:
  332. raise ModelNotFoundError(f"Model not found: {model} in: {cls.__name__} Valid models: {cls.models}")
  333. cls.last_model = model
  334. return model
  335. class RaiseErrorMixin():
  336. @staticmethod
  337. def raise_error(data: dict, status: int = None):
  338. if "error_message" in data:
  339. raise ResponseError(data["error_message"])
  340. elif "error" in data:
  341. if isinstance(data["error"], str):
  342. if status is not None:
  343. if status == 401:
  344. raise MissingAuthError(f"Error {status}: {data['error']}")
  345. elif status == 402:
  346. raise PaymentRequiredError(f"Error {status}: {data['error']}")
  347. raise ResponseError(f"Error {status}: {data['error']}")
  348. raise ResponseError(data["error"])
  349. elif isinstance(data["error"], bool):
  350. raise ResponseError(data)
  351. elif "code" in data["error"]:
  352. raise ResponseError("\n".join(
  353. [e for e in [f'Error {data["error"]["code"]}: {data["error"]["message"]}', data["error"].get("failed_generation")] if e is not None]
  354. ))
  355. elif "message" in data["error"]:
  356. raise ResponseError(data["error"]["message"])
  357. else:
  358. raise ResponseError(data["error"])
  359. elif ("choices" not in data or not data["choices"]) and "data" not in data:
  360. raise ResponseError(f"Invalid response: {json.dumps(data)}")
  361. class AuthFileMixin():
  362. @classmethod
  363. def get_cache_file(cls) -> Path:
  364. return Path(get_cookies_dir()) / f"auth_{cls.parent if hasattr(cls, 'parent') else cls.__name__}.json"
  365. class AsyncAuthedProvider(AsyncGeneratorProvider, AuthFileMixin):
  366. @classmethod
  367. async def on_auth_async(cls, **kwargs) -> AuthResult:
  368. if "api_key" not in kwargs:
  369. raise MissingAuthError(f"API key is required for {cls.__name__}")
  370. return AuthResult()
  371. @classmethod
  372. def on_auth(cls, **kwargs) -> AuthResult:
  373. auth_result = cls.on_auth_async(**kwargs)
  374. if hasattr(auth_result, "__aiter__"):
  375. return to_sync_generator(auth_result)
  376. return asyncio.run(auth_result)
  377. @classmethod
  378. def write_cache_file(cls, cache_file: Path, auth_result: AuthResult = None):
  379. if auth_result is not None:
  380. cache_file.parent.mkdir(parents=True, exist_ok=True)
  381. try:
  382. def toJSON(obj):
  383. if hasattr(obj, "get_dict"):
  384. return obj.get_dict()
  385. return str(obj)
  386. with cache_file.open("w") as cache_file:
  387. json.dump(auth_result, cache_file, default=toJSON)
  388. except TypeError as e:
  389. raise RuntimeError(f"Failed to save: {auth_result.get_dict()}\n{type(e).__name__}: {e}")
  390. elif cache_file.exists():
  391. cache_file.unlink()
  392. @classmethod
  393. def get_auth_result(cls) -> AuthResult:
  394. """
  395. Retrieves the authentication result from cache.
  396. """
  397. cache_file = cls.get_cache_file()
  398. if cache_file.exists():
  399. try:
  400. with cache_file.open("r") as f:
  401. return AuthResult(**json.load(f))
  402. except json.JSONDecodeError:
  403. cache_file.unlink()
  404. raise MissingAuthError(f"Invalid auth file: {cache_file}")
  405. else:
  406. raise MissingAuthError
  407. @classmethod
  408. def create_completion(
  409. cls,
  410. model: str,
  411. messages: Messages,
  412. **kwargs
  413. ) -> CreateResult:
  414. auth_result: AuthResult = None
  415. cache_file = cls.get_cache_file()
  416. try:
  417. auth_result = cls.get_auth_result()
  418. yield from to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs))
  419. except (MissingAuthError, NoValidHarFileError):
  420. response = cls.on_auth(**kwargs)
  421. for chunk in response:
  422. if isinstance(chunk, AuthResult):
  423. auth_result = chunk
  424. else:
  425. yield chunk
  426. for chunk in to_sync_generator(cls.create_authed(model, messages, auth_result, **kwargs)):
  427. if cache_file is not None:
  428. cls.write_cache_file(cache_file, auth_result)
  429. cache_file = None
  430. yield chunk
  431. @classmethod
  432. async def create_async_generator(
  433. cls,
  434. model: str,
  435. messages: Messages,
  436. **kwargs
  437. ) -> AsyncResult:
  438. auth_result: AuthResult = None
  439. cache_file = cls.get_cache_file()
  440. try:
  441. auth_result = cls.get_auth_result()
  442. response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
  443. async for chunk in response:
  444. yield chunk
  445. except (MissingAuthError, NoValidHarFileError):
  446. if cache_file.exists():
  447. cache_file.unlink()
  448. response = cls.on_auth_async(**kwargs)
  449. async for chunk in response:
  450. if isinstance(chunk, AuthResult):
  451. auth_result = chunk
  452. else:
  453. yield chunk
  454. response = to_async_iterator(cls.create_authed(model, messages, **kwargs, auth_result=auth_result))
  455. async for chunk in response:
  456. if cache_file is not None:
  457. cls.write_cache_file(cache_file, auth_result)
  458. cache_file = None
  459. yield chunk