HuggingFaceInference.py

from __future__ import annotations

import json
import base64
import random
import requests

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ...errors import ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason, ImageResponse
from ..helper import format_image_prompt, get_last_user_message
from .models import default_model, default_image_model, model_aliases, fallback_models, image_models
from ... import debug

class HuggingFaceInference(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co"
    parent = "HuggingFace"
    working = True

    default_model = default_model
    default_image_model = default_image_model
    model_aliases = model_aliases
    image_models = image_models
    model_data: dict[str, dict] = {}

    @classmethod
    def get_models(cls) -> list[str]:
        # Build the model list once: start from the static fallback list, then add
        # warm text-generation models and trending text-to-image models from the Hub API.
        if not cls.models:
            models = fallback_models.copy()
            url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
            response = requests.get(url)
            if response.ok:
                extra_models = [model["id"] for model in response.json()]
                extra_models.sort()
                models.extend([model for model in extra_models if model not in models])
            url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
            response = requests.get(url)
            if response.ok:
                cls.image_models = [model["id"] for model in response.json() if model.get("trendingScore", 0) >= 20]
                cls.image_models.sort()
                models.extend([model for model in cls.image_models if model not in models])
            cls.models = models
        return cls.models

    @classmethod
    async def get_model_data(cls, session: StreamSession, model: str) -> dict:
        # Fetch and cache model metadata (pipeline_tag, config, ...) from the Hub API.
        if model in cls.model_data:
            return cls.model_data[model]
        async with session.get(f"https://huggingface.co/api/models/{model}") as response:
            if response.status == 404:
                raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__}")
            await raise_for_status(response)
            cls.model_data[model] = await response.json()
        return cls.model_data[model]

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_tokens: int = 1024,
        temperature: float = None,
        prompt: str = None,
        action: str = None,
        extra_data: dict = {},
        seed: int = None,
        **kwargs
    ) -> AsyncResult:
        try:
            model = cls.get_model(model)
        except ModelNotSupportedError:
            pass
        headers = {
            'accept': '*/*',
            'accept-language': 'en',
            'cache-control': 'no-cache',
            'origin': 'https://huggingface.co',
            'pragma': 'no-cache',
            'priority': 'u=1, i',
            'referer': 'https://huggingface.co/chat/',
            'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        payload = None
        params = {
            "return_full_text": False,
            "max_new_tokens": max_tokens,
            "temperature": temperature,
            **extra_data
        }
        do_continue = action == "continue"
        async with StreamSession(
            headers=headers,
            proxy=proxy,
            timeout=600
        ) as session:
            if payload is None:
                model_data = await cls.get_model_data(session, model)
                pipeline_tag = model_data.get("pipeline_tag")
                if pipeline_tag == "text-to-image":
                    # Image models return a single image, so streaming does not apply.
                    stream = False
                    inputs = format_image_prompt(messages, prompt)
                    payload = {"inputs": inputs, "parameters": {"seed": random.randint(0, 2**32) if seed is None else seed, **extra_data}}
                elif pipeline_tag in ("text-generation", "image-text-to-text"):
                    model_type = None
                    if "config" in model_data and "model_type" in model_data["config"]:
                        model_type = model_data["config"]["model_type"]
                    debug.log(f"Model type: {model_type}")
                    inputs = get_inputs(messages, model_data, model_type, do_continue)
                    debug.log(f"Inputs len: {len(inputs)}")
                    if len(inputs) > 4096:
                        # The rendered prompt is too long: trim the history and rebuild it.
                        if len(messages) > 6:
                            messages = messages[:3] + messages[-3:]
                        else:
                            messages = [m for m in messages if m["role"] == "system"] + [{"role": "user", "content": get_last_user_message(messages)}]
                        inputs = get_inputs(messages, model_data, model_type, do_continue)
                        debug.log(f"New len: {len(inputs)}")
                    if model_type == "gpt2" and max_tokens >= 1024:
                        params["max_new_tokens"] = 512
                    if seed is not None:
                        params["seed"] = seed
                    payload = {"inputs": inputs, "parameters": params, "stream": stream}
                else:
                    raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotSupportedError(f"Model is not supported: {model}")
                await raise_for_status(response)
                if stream:
                    first = True
                    is_special = False
                    # Each streamed line looks roughly like:
                    # data:{"token": {"text": "...", "special": false, ...}, ...}
                    async for line in response.iter_lines():
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            if "error" in data:
                                raise ResponseError(data["error"])
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first and not do_continue:
                                    first = False
                                    chunk = chunk.lstrip()
                                if chunk:
                                    yield chunk
                            else:
                                is_special = True
                    debug.log(f"Special token: {is_special}")
                    yield FinishReason("stop" if is_special else "length")
                else:
                    if response.headers["content-type"].startswith("image/"):
                        # Image responses are returned inline as a data: URL.
                        base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
                        url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
                        yield ImageResponse(url, inputs)
                    else:
                        yield (await response.json())[0]["generated_text"].strip()

def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "\n".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    if do_continue:
        return history[:-len('</s>')]
    return f"{history}\n<s>[INST] {question} [/INST]"
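
# Example: with the history [user: "Hi", assistant: "Hello", user: "How are you?"],
# format_prompt_mistral renders roughly:
#     <s>[INST]Hi [/INST] Hello</s>
#     <s>[INST] How are you? [/INST]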

def format_prompt_qwen(messages: Messages, do_continue: bool = False) -> str:
    prompt = "".join([
        f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" for message in messages
    ]) + ("" if do_continue else "<|im_start|>assistant\n")
    if do_continue:
        return prompt[:-len("\n<|im_end|>\n")]
    return prompt
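
# Example: for a single user message "Hi", format_prompt_qwen renders roughly the
# ChatML form:
#     <|im_start|>user
#     Hi
#     <|im_end|>
#     <|im_start|>assistant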

def format_prompt_qwen2(messages: Messages, do_continue: bool = False) -> str:
    # "\u003C" and "\u003E" are the literal "<" and ">" characters, so each turn is
    # wrapped in <|Role|>...<|end▁of▁sentence|> markers.
    prompt = "".join([
        f"\u003C|{message['role'].capitalize()}|\u003E{message['content']}\u003C|end▁of▁sentence|\u003E" for message in messages
    ]) + ("" if do_continue else "\u003C|Assistant|\u003E")
    if do_continue:
        return prompt[:-len("\u003C|Assistant|\u003E")]
    return prompt

def format_prompt_llama(messages: Messages, do_continue: bool = False) -> str:
    prompt = "<|begin_of_text|>" + "".join([
        f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
    ]) + ("" if do_continue else "<|start_header_id|>assistant<|end_header_id|>\n\n")
    if do_continue:
        return prompt[:-len("\n<|eot_id|>\n")]
    return prompt

def format_prompt_custom(messages: Messages, end_token: str = "</s>", do_continue: bool = False) -> str:
    prompt = "".join([
        f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
    ]) + ("" if do_continue else "<|assistant|>\n")
    if do_continue:
        return prompt[:-len(end_token + "\n")]
    return prompt

def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str:
    # Choose a prompt template from the model type or the tokenizer's eos_token,
    # falling back to the generic format_prompt helper.
    if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
        inputs = format_prompt(messages, do_continue=do_continue)
    elif model_type == "mistral" and model_data.get("author") == "mistralai":
        inputs = format_prompt_mistral(messages, do_continue)
    elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
        eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
        if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
            inputs = format_prompt_custom(messages, eos_token, do_continue)
        elif eos_token == "<|im_end|>":
            inputs = format_prompt_qwen(messages, do_continue)
        elif "content" in eos_token and eos_token["content"] == "\u003C|end▁of▁sentence|\u003E":
            inputs = format_prompt_qwen2(messages, do_continue)
        elif eos_token == "<|eot_id|>":
            inputs = format_prompt_llama(messages, do_continue)
        else:
            inputs = format_prompt(messages, do_continue=do_continue)
    else:
        inputs = format_prompt(messages, do_continue=do_continue)
    return inputs
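
# Usage sketch, assuming this module is imported from within the g4f package
# (the exact import path below is an assumption):
#
#     import asyncio
#     from g4f.Provider.hf.HuggingFaceInference import HuggingFaceInference
#
#     async def demo():
#         async for chunk in HuggingFaceInference.create_async_generator(
#             model=HuggingFaceInference.default_model,
#             messages=[{"role": "user", "content": "Hello"}],
#             api_key="hf_...",  # a Hugging Face access token
#         ):
#             if isinstance(chunk, str):  # skip FinishReason / ImageResponse objects
#                 print(chunk, end="")
#
#     asyncio.run(demo())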