PollinationsAI.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. from __future__ import annotations
  2. import json
  3. import random
  4. import requests
  5. from urllib.parse import quote
  6. from typing import Optional
  7. from aiohttp import ClientSession
  8. from ..requests.raise_for_status import raise_for_status
  9. from ..typing import AsyncResult, Messages
  10. from ..image import ImageResponse
  11. from .needs_auth.OpenaiAPI import OpenaiAPI
  12. class PollinationsAI(OpenaiAPI):
  13. label = "Pollinations AI"
  14. url = "https://pollinations.ai"
  15. working = True
  16. needs_auth = True
  17. supports_stream = True
  18. supports_system_message = True
  19. supports_message_history = True
  20. # API endpoints base
  21. api_base = "https://text.pollinations.ai/openai"
  22. # API endpoints
  23. text_api_endpoint = "https://text.pollinations.ai"
  24. image_api_endpoint = "https://image.pollinations.ai"
  25. # Models configuration
  26. default_model = "openai"
  27. default_image_model = "flux"
  28. image_models = []
  29. models = []
  30. additional_models_image = ["midjourney", "dall-e-3"]
  31. additional_models_text = ["sur", "sur-mistral", "claude"]
  32. model_aliases = {
  33. "gpt-4o": "openai",
  34. "mistral-nemo": "mistral",
  35. "llama-3.1-70b": "llama",
  36. "gpt-4": "searchgpt",
  37. "gpt-4": "claude",
  38. "qwen-2.5-coder-32b": "qwen-coder",
  39. "claude-3.5-sonnet": "sur",
  40. }
  41. @classmethod
  42. def get_models(cls, **kwargs):
  43. # Initialize model lists if not exists
  44. if not hasattr(cls, 'image_models'):
  45. cls.image_models = []
  46. if not hasattr(cls, 'text_models'):
  47. cls.text_models = []
  48. # Fetch image models if not cached
  49. if not cls.image_models:
  50. url = "https://image.pollinations.ai/models"
  51. response = requests.get(url)
  52. raise_for_status(response)
  53. cls.image_models = response.json()
  54. cls.image_models.extend(cls.additional_models_image)
  55. # Fetch text models if not cached
  56. if not cls.text_models:
  57. url = "https://text.pollinations.ai/models"
  58. response = requests.get(url)
  59. raise_for_status(response)
  60. cls.text_models = [model.get("name") for model in response.json()]
  61. cls.text_models.extend(cls.additional_models_text)
  62. # Return combined models
  63. return cls.text_models + cls.image_models
  64. @classmethod
  65. async def create_async_generator(
  66. cls,
  67. model: str,
  68. messages: Messages,
  69. proxy: str = None,
  70. # Image specific parameters
  71. prompt: str = None,
  72. width: int = 1024,
  73. height: int = 1024,
  74. seed: Optional[int] = None,
  75. nologo: bool = True,
  76. private: bool = False,
  77. enhance: bool = False,
  78. safe: bool = False,
  79. # Text specific parameters
  80. api_key: str = None,
  81. temperature: float = 0.5,
  82. presence_penalty: float = 0,
  83. top_p: float = 1,
  84. frequency_penalty: float = 0,
  85. stream: bool = True,
  86. **kwargs
  87. ) -> AsyncResult:
  88. model = cls.get_model(model)
  89. # Check if models
  90. # Image generation
  91. if model in cls.image_models:
  92. async for result in cls._generate_image(
  93. model=model,
  94. messages=messages,
  95. prompt=prompt,
  96. proxy=proxy,
  97. width=width,
  98. height=height,
  99. seed=seed,
  100. nologo=nologo,
  101. private=private,
  102. enhance=enhance,
  103. safe=safe
  104. ):
  105. yield result
  106. else:
  107. # Text generation
  108. async for result in cls._generate_text(
  109. model=model,
  110. messages=messages,
  111. proxy=proxy,
  112. api_key=api_key,
  113. temperature=temperature,
  114. presence_penalty=presence_penalty,
  115. top_p=top_p,
  116. frequency_penalty=frequency_penalty,
  117. stream=stream
  118. ):
  119. yield result
  120. @classmethod
  121. async def _generate_image(
  122. cls,
  123. model: str,
  124. messages: Messages,
  125. prompt: str,
  126. proxy: str,
  127. width: int,
  128. height: int,
  129. seed: Optional[int],
  130. nologo: bool,
  131. private: bool,
  132. enhance: bool,
  133. safe: bool
  134. ) -> AsyncResult:
  135. if seed is None:
  136. seed = random.randint(0, 10000)
  137. headers = {
  138. 'Accept': '*/*',
  139. 'Accept-Language': 'en-US,en;q=0.9',
  140. 'User-Agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/130.0.0.0 Safari/537.36',
  141. }
  142. params = {
  143. "seed": seed,
  144. "width": width,
  145. "height": height,
  146. "model": model,
  147. "nologo": nologo,
  148. "private": private,
  149. "enhance": enhance,
  150. "safe": safe
  151. }
  152. params = {k: v for k, v in params.items() if v is not None}
  153. async with ClientSession(headers=headers) as session:
  154. prompt = quote(messages[-1]["content"] if prompt is None else prompt)
  155. param_string = "&".join(f"{k}={v}" for k, v in params.items())
  156. url = f"{cls.image_api_endpoint}/prompt/{prompt}?{param_string}"
  157. async with session.head(url, proxy=proxy) as response:
  158. if response.status == 200:
  159. image_response = ImageResponse(images=url, alt=messages[-1]["content"])
  160. yield image_response
  161. @classmethod
  162. async def _generate_text(
  163. cls,
  164. model: str,
  165. messages: Messages,
  166. proxy: str,
  167. api_key: str,
  168. temperature: float,
  169. presence_penalty: float,
  170. top_p: float,
  171. frequency_penalty: float,
  172. stream: bool
  173. ) -> AsyncResult:
  174. if api_key is None:
  175. api_key = "dummy" # Default value if api_key is not provided
  176. headers = {
  177. "accept": "*/*",
  178. "accept-language": "en-US,en;q=0.9",
  179. "authorization": f"Bearer {api_key}",
  180. "content-type": "application/json",
  181. "user-agent": "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
  182. }
  183. async with ClientSession(headers=headers) as session:
  184. data = {
  185. "messages": messages,
  186. "model": model,
  187. "temperature": temperature,
  188. "presence_penalty": presence_penalty,
  189. "top_p": top_p,
  190. "frequency_penalty": frequency_penalty,
  191. "jsonMode": False,
  192. "stream": stream
  193. }
  194. async with session.post(cls.text_api_endpoint, json=data, proxy=proxy) as response:
  195. response.raise_for_status()
  196. async for chunk in response.content:
  197. if chunk:
  198. decoded_chunk = chunk.decode()
  199. try:
  200. json_response = json.loads(decoded_chunk)
  201. content = json_response['choices'][0]['message']['content']
  202. yield content
  203. except json.JSONDecodeError:
  204. yield decoded_chunk