PollinationsAI.py

from __future__ import annotations

import random
import requests
from urllib.parse import quote_plus
from typing import Optional
from aiohttp import ClientSession

from .helper import filter_none, format_image_prompt
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..typing import AsyncResult, Messages, ImagesType
from ..image import to_data_uri
from ..errors import ModelNotFoundError
from ..requests.raise_for_status import raise_for_status
from ..requests.aiohttp import get_connector
from ..providers.response import ImageResponse, ImagePreview, FinishReason, Usage, Audio
from .. import debug

DEFAULT_HEADERS = {
    "accept": "*/*",
    "accept-language": "en-US,en;q=0.9",
    "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/133.0.0.0 Safari/537.36",
    "priority": "u=1, i",
    "sec-ch-ua": "\"Not(A:Brand\";v=\"99\", \"Google Chrome\";v=\"133\", \"Chromium\";v=\"133\"",
    "sec-ch-ua-mobile": "?0",
    "sec-ch-ua-platform": "\"Linux\"",
    "sec-fetch-dest": "empty",
    "sec-fetch-mode": "cors",
    "sec-fetch-site": "same-site",
    "referer": "https://pollinations.ai/",
    "origin": "https://pollinations.ai",
}
class PollinationsAI(AsyncGeneratorProvider, ProviderModelMixin):
    label = "Pollinations AI"
    url = "https://pollinations.ai"

    working = True
    supports_stream = False
    supports_system_message = True
    supports_message_history = True

    # API endpoints
    text_api_endpoint = "https://text.pollinations.ai"
    openai_endpoint = "https://text.pollinations.ai/openai"
    image_api_endpoint = "https://image.pollinations.ai/"

    # Models configuration
    default_model = "openai"
    default_image_model = "flux"
    default_vision_model = "gpt-4o"
    text_models = [default_model]
    image_models = [default_image_model]
    extra_image_models = ["flux-pro", "flux-dev", "flux-schnell", "midjourney", "dall-e-3"]
    vision_models = [default_vision_model, "gpt-4o-mini", "o1-mini"]
    extra_text_models = vision_models
    audio_models = {}  # populated by get_models(); declared here so _generate_text can check it even if loading fails
    _models_loaded = False

    model_aliases = {
        ### Text Models ###
        "gpt-4o-mini": "openai",
        "gpt-4": "openai-large",
        "gpt-4o": "openai-large",
        "o1-mini": "openai-reasoning",
        "qwen-2.5-coder-32b": "qwen-coder",
        "llama-3.3-70b": "llama",
        "mistral-nemo": "mistral",
        # NOTE: "gpt-4o-mini" and "llama-3.3-70b" appear twice in this dict;
        # Python keeps the later value, so the entries below override the ones above.
        "gpt-4o-mini": "searchgpt",
        "llama-3.1-8b": "llamalight",
        "llama-3.3-70b": "llama-scaleway",
        "phi-4": "phi",
        "gemini-2.0": "gemini",
        "gemini-2.0-flash": "gemini",
        "gemini-2.0-flash-thinking": "gemini-thinking",

        ### Image Models ###
        "sdxl-turbo": "turbo",
    }
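
    # Illustrative only, assuming ProviderModelMixin.get_model() resolves entries in
    # model_aliases (as the alias table above and the ModelNotFoundError handling
    # further down suggest):
    #     PollinationsAI.get_model("gpt-4")   # -> "openai-large"
    # while an unknown name raises ModelNotFoundError.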
    @classmethod
    def get_models(cls, **kwargs):
        if not cls._models_loaded:
            try:
                # Refresh image models from the image API
                image_response = requests.get("https://image.pollinations.ai/models")
                if image_response.ok:
                    new_image_models = image_response.json()
                else:
                    new_image_models = []

                # Combine image models without duplicates
                all_image_models = (
                    cls.image_models +  # Already contains the default
                    cls.extra_image_models +
                    new_image_models
                )
                cls.image_models = list(dict.fromkeys(all_image_models))

                # Refresh text models from the text API. The endpoint is expected to
                # return a list of objects like {"name": ..., "type": "chat",
                # "audio": ..., "voices": [...]}; only the keys read below are used.
                text_response = requests.get("https://text.pollinations.ai/models")
                text_response.raise_for_status()
                models = text_response.json()
                original_text_models = [
                    model.get("name")
                    for model in models
                    if model.get("type") == "chat"
                ]
                cls.audio_models = {
                    model.get("name"): model.get("voices")
                    for model in models
                    if model.get("audio")
                }

                # Combine text models without duplicates
                combined_text = (
                    cls.text_models +  # Already contains the default
                    cls.extra_text_models +
                    [
                        model for model in original_text_models
                        if model not in cls.extra_text_models
                    ]
                )
                cls.text_models = list(dict.fromkeys(combined_text))
                cls._models_loaded = True
            except Exception as e:
                # Fall back to the default models in case of an error
                if not cls.text_models:
                    cls.text_models = [cls.default_model]
                if not cls.image_models:
                    cls.image_models = [cls.default_image_model]
                debug.error(f"Failed to fetch models: {e}")
        return cls.text_models + cls.image_models
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        proxy: str = None,
        prompt: str = None,
        width: int = 1024,
        height: int = 1024,
        seed: Optional[int] = None,
        nologo: bool = True,
        private: bool = False,
        enhance: bool = False,
        safe: bool = False,
        images: ImagesType = None,
        temperature: float = None,
        presence_penalty: float = None,
        top_p: float = 1,
        frequency_penalty: float = None,
        response_format: Optional[dict] = None,
        cache: bool = False,
        **kwargs
    ) -> AsyncResult:
        cls.get_models()
        if images is not None and not model:
            model = cls.default_vision_model
        try:
            model = cls.get_model(model)
        except ModelNotFoundError:
            if model not in cls.image_models:
                raise

        if model in cls.image_models:
            async for chunk in cls._generate_image(
                model=model,
                prompt=format_image_prompt(messages, prompt),
                proxy=proxy,
                width=width,
                height=height,
                seed=seed,
                cache=cache,
                nologo=nologo,
                private=private,
                enhance=enhance,
                safe=safe
            ):
                yield chunk
        else:
            async for result in cls._generate_text(
                model=model,
                messages=messages,
                images=images,
                proxy=proxy,
                temperature=temperature,
                presence_penalty=presence_penalty,
                top_p=top_p,
                frequency_penalty=frequency_penalty,
                response_format=response_format,
                seed=seed,
                cache=cache,
            ):
                yield result
    @classmethod
    async def _generate_image(
        cls,
        model: str,
        prompt: str,
        proxy: str,
        width: int,
        height: int,
        seed: Optional[int],
        cache: bool,
        nologo: bool,
        private: bool,
        enhance: bool,
        safe: bool
    ) -> AsyncResult:
        if not cache and seed is None:
            seed = random.randint(9999, 99999999)

        params = {
            "seed": str(seed) if seed is not None else None,
            "width": str(width),
            "height": str(height),
            "model": model,
            "nologo": str(nologo).lower(),
            "private": str(private).lower(),
            "enhance": str(enhance).lower(),
            "safe": str(safe).lower()
        }
        query = "&".join(f"{k}={quote_plus(v)}" for k, v in params.items() if v is not None)
        url = f"{cls.image_api_endpoint}prompt/{quote_plus(prompt)}?{query}"
        yield ImagePreview(url, prompt)

        async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
            async with session.get(url, allow_redirects=True) as response:
                await raise_for_status(response)
                image_url = str(response.url)
                yield ImageResponse(image_url, prompt)
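
    # Illustrative only (prompt and seed are made-up values): with the default
    # arguments above, the URL built by _generate_image has this general shape:
    #   https://image.pollinations.ai/prompt/a+red+fox?seed=123456&width=1024&height=1024&model=flux&nologo=true&private=false&enhance=false&safe=false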
    @classmethod
    async def _generate_text(
        cls,
        model: str,
        messages: Messages,
        images: Optional[ImagesType],
        proxy: str,
        temperature: float,
        presence_penalty: float,
        top_p: float,
        frequency_penalty: float,
        response_format: Optional[dict],
        seed: Optional[int],
        cache: bool
    ) -> AsyncResult:
        if not cache and seed is None:
            seed = random.randint(9999, 99999999)
        json_mode = False
        if response_format and response_format.get("type") == "json_object":
            json_mode = True

        # Attach any images to the last message as OpenAI-style image_url parts
        if images and messages:
            last_message = messages[-1].copy()
            image_content = [
                {
                    "type": "image_url",
                    "image_url": {"url": to_data_uri(image)}
                }
                for image, _ in images
            ]
            last_message["content"] = image_content + [{"type": "text", "text": last_message["content"]}]
            messages[-1] = last_message

        async with ClientSession(headers=DEFAULT_HEADERS, connector=get_connector(proxy=proxy)) as session:
            data = filter_none(**{
                "messages": messages,
                "model": model,
                "temperature": temperature,
                "presence_penalty": presence_penalty,
                "top_p": top_p,
                "frequency_penalty": frequency_penalty,
                "jsonMode": json_mode,
                "stream": False,
                "seed": seed,
                "cache": cache
            })
            if "gemini" in model:
                # Gemini models are called without a seed; drop it if present
                data.pop("seed", None)
            if model in cls.audio_models:
                #data["voice"] = random.choice(cls.audio_models[model])
                url = cls.text_api_endpoint
            else:
                url = cls.openai_endpoint

            async with session.post(url, json=data) as response:
                await raise_for_status(response)
                content_type = response.headers.get("content-type", "")
                if content_type == "audio/mpeg":
                    yield Audio(await response.read())
                    return
                elif content_type.startswith("text/plain"):
                    yield await response.text()
                    return

                result = await response.json()
                choice = result["choices"][0]
                message = choice.get("message", {})
                content = message.get("content", "")

                # Some responses contain only a closing </think> tag; emit the opening
                # tag so the client receives a balanced pair.
                if "</think>" in content and "<think>" not in content:
                    yield "<think>"
                if content:
                    yield content.replace("\\(", "(").replace("\\)", ")")
                if "usage" in result:
                    yield Usage(**result["usage"])
                finish_reason = choice.get("finish_reason")
                if finish_reason:
                    yield FinishReason(finish_reason)
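
# Minimal usage sketch, assuming this file sits inside the g4f provider package so the
# relative imports above resolve (the prompt text is an arbitrary example):
#
#     import asyncio
#
#     async def _demo():
#         async for chunk in PollinationsAI.create_async_generator(
#             model="openai",
#             messages=[{"role": "user", "content": "Hello"}],
#         ):
#             print(chunk)
#
#     asyncio.run(_demo())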