Airforce.py

import json
import random
import re
import requests
from aiohttp import ClientSession
from typing import List

from ..typing import AsyncResult, Messages
from ..image import ImageResponse
from ..providers.response import FinishReason, Usage
from ..requests.raise_for_status import raise_for_status
from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
from .. import debug

def split_message(message: str, max_length: int = 1000) -> List[str]:
    """Split a message into chunks of at most ``max_length`` characters."""
    chunks = []
    while len(message) > max_length:
        # Prefer to split on the last space before the limit.
        split_point = message.rfind(' ', 0, max_length)
        if split_point == -1:
            split_point = max_length
        chunks.append(message[:split_point])
        message = message[split_point:].strip()
    if message:
        chunks.append(message)
    return chunks
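
# Illustrative behaviour of split_message (assumed example values, not executed here):
#   split_message("word " * 400, max_length=1000)
#   -> a list of chunks, each at most 1000 characters, split on spaces where possible.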

class Airforce(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://api.airforce"
    api_endpoint_completions = "https://api.airforce/chat/completions"
    api_endpoint_imagine2 = "https://api.airforce/imagine2"

    working = True
    supports_stream = True
    supports_system_message = True
    supports_message_history = True

    default_model = "llama-3.1-70b-chat"
    default_image_model = "flux"
    models = []
    image_models = []
    hidden_models = {"Flux-1.1-Pro"}
    additional_models_imagine = ["flux-1.1-pro", "midjourney", "dall-e-3"]

    model_aliases = {
        # Alias mappings for models
        "openchat-3.5": "openchat-3.5-0106",
        "deepseek-coder": "deepseek-coder-6.7b-instruct",
        "hermes-2-dpo": "Nous-Hermes-2-Mixtral-8x7B-DPO",
        "hermes-2-pro": "hermes-2-pro-mistral-7b",
        "openhermes-2.5": "openhermes-2.5-mistral-7b",
        "lfm-40b": "lfm-40b-moe",
        "german-7b": "discolm-german-7b-v1",
        "llama-2-7b": "llama-2-7b-chat-int8",
        "llama-3.1-70b": "llama-3.1-70b-turbo",
        "llama-3.1-8b": "llama-3.1-8b-chat",
        "neural-7b": "neural-chat-7b-v3-1",
        "zephyr-7b": "zephyr-7b-beta",
        "evil": "any-uncensored",
        "sdxl": "stable-diffusion-xl-base",
        "flux-pro": "flux-1.1-pro",
    }

    @classmethod
    def get_models(cls):
        """Get available models with error handling"""
        if not cls.image_models:
            try:
                response = requests.get(
                    f"{cls.url}/imagine2/models",
                    headers={
                        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
                    }
                )
                response.raise_for_status()
                cls.image_models = response.json()
                if isinstance(cls.image_models, list):
                    cls.image_models.extend(cls.additional_models_imagine)
                else:
                    cls.image_models = cls.additional_models_imagine.copy()
            except Exception as e:
                debug.log(f"Error fetching image models: {e}")
                cls.image_models = cls.additional_models_imagine.copy()

        if not cls.models:
            try:
                response = requests.get(
                    f"{cls.url}/models",
                    headers={
                        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
                    }
                )
                response.raise_for_status()
                data = response.json()
                if isinstance(data, dict) and 'data' in data:
                    cls.models = [model['id'] for model in data['data']]
                    cls.models.extend(cls.image_models)
                    cls.models = [model for model in cls.models if model not in cls.hidden_models]
                else:
                    cls.models = list(cls.model_aliases.keys())
            except Exception as e:
                debug.log(f"Error fetching text models: {e}")
                cls.models = list(cls.model_aliases.keys())

        return cls.models or list(cls.model_aliases.keys())

    @classmethod
    def get_model(cls, model: str) -> str:
        """Get the actual model name from alias"""
        return cls.model_aliases.get(model, model or cls.default_model)

    @classmethod
    def _filter_content(cls, part_response: str) -> str:
        """
        Filters out unwanted content from the partial response.
        """
        part_response = re.sub(
            r"One message exceeds the \d+chars per message limit\..+https:\/\/discord\.com\/invite\/\S+",
            '',
            part_response
        )
        part_response = re.sub(
            r"Rate limit \(\d+\/minute\) exceeded\. Join our discord for more: .+https:\/\/discord\.com\/invite\/\S+",
            '',
            part_response
        )
        return part_response

    @classmethod
    def _filter_response(cls, response: str) -> str:
        """
        Filters the full response to remove system errors and other unwanted text.
        """
        if "Model not found or too long input. Or any other error (xD)" in response:
            raise ValueError(response)

        filtered_response = re.sub(r"\[ERROR\] '\w{8}-\w{4}-\w{4}-\w{4}-\w{12}'", '', response)  # any-uncensored
        filtered_response = re.sub(r'<\|im_end\|>', '', filtered_response)  # remove <|im_end|> token
        filtered_response = re.sub(r'</s>', '', filtered_response)  # neural-chat-7b-v3-1
        filtered_response = re.sub(r'^(Assistant: |AI: |ANSWER: |Output: )', '', filtered_response)  # phi-2
        filtered_response = cls._filter_content(filtered_response)
        return filtered_response
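
    # Illustrative example of the filtering above (assumed input, not executed here):
    #   Airforce._filter_response("Assistant: Hello<|im_end|>")  ->  "Hello"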

    @classmethod
    async def generate_image(
        cls,
        model: str,
        prompt: str,
        size: str,
        seed: int,
        proxy: str = None
    ) -> AsyncResult:
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:133.0) Gecko/20100101 Firefox/133.0",
            "Accept": "image/avif,image/webp,image/png,image/svg+xml,image/*;q=0.8,*/*;q=0.5",
            "Accept-Language": "en-US,en;q=0.5",
            "Accept-Encoding": "gzip, deflate, br",
            "Content-Type": "application/json",
        }
        params = {"model": model, "prompt": prompt, "size": size, "seed": seed}

        async with ClientSession(headers=headers) as session:
            async with session.get(cls.api_endpoint_imagine2, params=params, proxy=proxy) as response:
                if response.status == 200:
                    image_url = str(response.url)
                    yield ImageResponse(images=image_url, alt=prompt)
                else:
                    error_text = await response.text()
                    raise RuntimeError(f"Image generation failed: {response.status} - {error_text}")

    @classmethod
    async def generate_text(
        cls,
        model: str,
        messages: Messages,
        max_tokens: int,
        temperature: float,
        top_p: float,
        stream: bool,
        proxy: str = None
    ) -> AsyncResult:
        """
        Generates text, filters the response, and yields either streamed chunks or the complete result.
        """
        headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:133.0) Gecko/20100101 Firefox/133.0",
            "Accept": "application/json, text/event-stream",
            "Accept-Language": "en-US,en;q=0.5",
            "Accept-Encoding": "gzip, deflate, br",
            "Content-Type": "application/json",
        }

        # Split long messages into chunks the API will accept.
        final_messages = []
        for message in messages:
            message_chunks = split_message(message["content"], max_length=1000)
            final_messages.extend([{"role": message["role"], "content": chunk} for chunk in message_chunks])

        data = {
            "messages": final_messages,
            "model": model,
            "temperature": temperature,
            "top_p": top_p,
            "stream": stream,
        }
        if max_tokens != 512:
            data["max_tokens"] = max_tokens

        async with ClientSession(headers=headers) as session:
            async with session.post(cls.api_endpoint_completions, json=data, proxy=proxy) as response:
                await raise_for_status(response)

                if stream:
                    idx = 0
                    async for line in response.content:
                        line = line.decode('utf-8').strip()
                        if line.startswith('data: '):
                            try:
                                json_str = line[6:]  # Remove 'data: ' prefix
                                chunk = json.loads(json_str)
                                if 'choices' in chunk and chunk['choices']:
                                    delta = chunk['choices'][0].get('delta', {})
                                    if 'content' in delta:
                                        content = cls._filter_response(delta['content'])
                                        if content:
                                            yield content
                                            idx += 1
                            except json.JSONDecodeError:
                                continue
                    if idx == 512:
                        yield FinishReason("length")
                else:
                    # Non-streaming response
                    result = await response.json()
                    if "usage" in result:
                        yield Usage(**result["usage"])
                        if result["usage"]["completion_tokens"] == 512:
                            yield FinishReason("length")
                    if 'choices' in result and result['choices']:
                        message = result['choices'][0].get('message', {})
                        content = message.get('content', '')
                        filtered_response = cls._filter_response(content)
                        yield filtered_response

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        proxy: str = None,
        max_tokens: int = 512,
        temperature: float = 1,
        top_p: float = 1,
        stream: bool = True,
        size: str = "1:1",
        seed: int = None,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)

        if model in cls.image_models:
            if prompt is None:
                prompt = messages[-1]['content']
            if seed is None:
                seed = random.randint(0, 10000)
            async for result in cls.generate_image(model, prompt, size, seed, proxy):
                yield result
        else:
            async for result in cls.generate_text(model, messages, max_tokens, temperature, top_p, stream, proxy):
                yield result
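
if __name__ == "__main__":
    # Minimal usage sketch, assuming the module is executed as part of its package
    # (e.g. via `python -m ...`) so the relative imports above resolve; the model
    # name and message below are illustrative examples only.
    import asyncio

    async def _demo():
        messages = [{"role": "user", "content": "Hello!"}]
        async for chunk in Airforce.create_async_generator(model="llama-3.1-70b", messages=messages):
            print(chunk, end="", flush=True)

    asyncio.run(_demo())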