# HuggingFace.py
from __future__ import annotations

import json
import base64
import random
import requests

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, format_prompt
from ...errors import ModelNotFoundError, ModelNotSupportedError, ResponseError
from ...requests import StreamSession, raise_for_status
from ...providers.response import FinishReason
from ...image import ImageResponse
from ... import debug
from .HuggingChat import HuggingChat

class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co"
    login_url = "https://huggingface.co/settings/tokens"
    working = True
    supports_message_history = True
    default_model = HuggingChat.default_model
    default_image_model = HuggingChat.default_image_model
    model_aliases = HuggingChat.model_aliases
    extra_models = [
        "meta-llama/Llama-3.2-11B-Vision-Instruct",
        "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF",
        "NousResearch/Hermes-3-Llama-3.1-8B",
    ]
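
    # Model discovery: the text-generation list is fetched from the Hugging Face Hub
    # API (inference=warm), extended with the hard-coded extra_models above, and merged
    # with text-to-image models filtered by trendingScore. Results are cached on the
    # class, so the HTTP requests are only made on the first call.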
    @classmethod
    def get_models(cls) -> list[str]:
        if not cls.models:
            url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
            models = [model["id"] for model in requests.get(url).json()]
            models.extend(cls.extra_models)
            models.sort()
            if not cls.image_models:
                url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
                cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
                cls.image_models.sort()
            models.extend(cls.image_models)
            cls.models = list(set(models))
        return cls.models
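
    # Request flow: image models get a single non-streaming text-to-image call; text
    # models get a prompt built by get_inputs() from the model's metadata, shrunk if it
    # grows past 4096 characters, then posted to the Inference API. Streaming responses
    # yield text chunks followed by a FinishReason; non-streaming responses yield either
    # an ImageResponse or the generated text.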
    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_tokens: int = 1024,
        temperature: float = None,
        prompt: str = None,
        action: str = None,
        extra_data: dict = {},
        **kwargs
    ) -> AsyncResult:
        try:
            model = cls.get_model(model)
        except ModelNotSupportedError:
            # Pass unknown model names through unchanged; the API call below
            # returns 404 for models that really do not exist.
            pass
        headers = {
            'accept': '*/*',
            'accept-language': 'en',
            'cache-control': 'no-cache',
            'origin': 'https://huggingface.co',
            'pragma': 'no-cache',
            'priority': 'u=1, i',
            'referer': 'https://huggingface.co/chat/',
            'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        payload = None
        if cls.get_models() and model in cls.image_models:
            # Text-to-image: single request, no streaming; the prompt defaults to the
            # content of the last message.
            stream = False
            prompt = messages[-1]["content"] if prompt is None else prompt
            payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32), **extra_data}}
        else:
            params = {
                "return_full_text": False,
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                **extra_data
            }
            do_continue = action == "continue"
        async with StreamSession(
            headers=headers,
            proxy=proxy,
            timeout=600
        ) as session:
            if payload is None:
                # Text generation: look up the model's metadata to pick a prompt
                # template, then shrink the conversation if the prompt gets too long.
                async with session.get(f"https://huggingface.co/api/models/{model}") as response:
                    await raise_for_status(response)
                    model_data = await response.json()
                    model_type = None
                    if "config" in model_data and "model_type" in model_data["config"]:
                        model_type = model_data["config"]["model_type"]
                    debug.log(f"Model type: {model_type}")
                    inputs = get_inputs(messages, model_data, model_type, do_continue)
                    debug.log(f"Inputs len: {len(inputs)}")
                    if len(inputs) > 4096:
                        # Prompt too long: drop the middle of the conversation (or keep
                        # only the system messages plus the last message) and rebuild.
                        if len(messages) > 6:
                            messages = messages[:3] + messages[-3:]
                        else:
                            messages = [m for m in messages if m["role"] == "system"] + [messages[-1]]
                        inputs = get_inputs(messages, model_data, model_type, do_continue)
                        debug.log(f"New len: {len(inputs)}")
                    if model_type == "gpt2" and max_tokens >= 1024:
                        params["max_new_tokens"] = 512
                    payload = {"inputs": inputs, "parameters": params, "stream": stream}
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotFoundError(f"Model is not supported: {model}")
                await raise_for_status(response)
                if stream:
                    # Server-sent events: each "data:" line carries one generated token.
                    first = True
                    is_special = False
                    async for line in response.iter_lines():
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            if "error" in data:
                                raise ResponseError(data["error"])
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first and not do_continue:
                                    first = False
                                    chunk = chunk.lstrip()
                                if chunk:
                                    yield chunk
                            else:
                                is_special = True
                    debug.log(f"Special token: {is_special}")
                    # A special (end-of-sequence) token means the model stopped on its
                    # own; otherwise generation was cut off by the token limit.
                    yield FinishReason("stop" if is_special else "length", actions=["variant"] if is_special else ["continue", "variant"])
                else:
                    if response.headers["content-type"].startswith("image/"):
                        base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
                        url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
                        yield ImageResponse(url, prompt)
                    else:
                        yield (await response.json())[0]["generated_text"].strip()
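
# Prompt templates for models the generic format_prompt() helper does not cover. Each
# builder mirrors the chat template the corresponding tokenizer expects; with
# do_continue=True the trailing end-of-turn token is stripped so the model resumes the
# last assistant message instead of starting a new turn.
#
# For example, messages [system, user1, assistant1, user2] render for Mistral roughly as:
#   <s>[INST]user1 [/INST] assistant1</s>
#   <s>[INST] user2 system [/INST]
# (the system content is appended after the final user message).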
def format_prompt_mistral(messages: Messages, do_continue: bool = False) -> str:
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "\n".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    if do_continue:
        return history[:-len('</s>')]
    return f"{history}\n<s>[INST] {question} [/INST]"
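
# ChatML template (used for models whose eos_token is <|im_end|>, e.g. Qwen).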
def format_prompt_qwen(messages: Messages, do_continue: bool = False) -> str:
    prompt = "".join([
        f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" for message in messages
    ]) + ("" if do_continue else "<|im_start|>assistant\n")
    if do_continue:
        return prompt[:-len("\n<|im_end|>\n")]
    return prompt
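
# Llama 3.x chat template (used for models whose eos_token is <|eot_id|>).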
def format_prompt_llama(messages: Messages, do_continue: bool = False) -> str:
    prompt = "<|begin_of_text|>" + "".join([
        f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
    ]) + ("" if do_continue else "<|start_header_id|>assistant<|end_header_id|>\n\n")
    if do_continue:
        return prompt[:-len("\n<|eot_id|>\n")]
    return prompt
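
# Generic <|role|> template with a configurable end-of-sequence token.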
def format_prompt_custom(messages: Messages, end_token: str = "</s>", do_continue: bool = False) -> str:
    prompt = "".join([
        f"<|{message['role']}|>\n{message['content']}{end_token}\n" for message in messages
    ]) + ("" if do_continue else "<|assistant|>\n")
    if do_continue:
        return prompt[:-len(end_token + "\n")]
    return prompt
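
# Select a prompt template from the model's type and tokenizer configuration, falling
# back to the provider's generic format_prompt() when nothing matches.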
def get_inputs(messages: Messages, model_data: dict, model_type: str, do_continue: bool = False) -> str:
    if model_type in ("gpt2", "gpt_neo", "gemma", "gemma2"):
        inputs = format_prompt(messages, do_continue=do_continue)
    elif model_type == "mistral" and model_data.get("author") == "mistralai":
        inputs = format_prompt_mistral(messages, do_continue)
    elif "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
        eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
        if eos_token in ("<|endoftext|>", "<eos>", "</s>"):
            inputs = format_prompt_custom(messages, eos_token, do_continue)
        elif eos_token == "<|im_end|>":
            inputs = format_prompt_qwen(messages, do_continue)
        elif eos_token == "<|eot_id|>":
            inputs = format_prompt_llama(messages, do_continue)
        else:
            inputs = format_prompt(messages, do_continue=do_continue)
    else:
        inputs = format_prompt(messages, do_continue=do_continue)
    return inputs
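
# Usage sketch (assumption: this provider is importable as shown below in the package
# this file ships with; adjust the import path to your install). It calls the
# create_async_generator() classmethod defined above and prints the streamed chunks.
#
#     import asyncio
#     from g4f.Provider import HuggingFace
#
#     async def demo():
#         async for chunk in HuggingFace.create_async_generator(
#             model=HuggingFace.default_model,
#             messages=[{"role": "user", "content": "Say hello."}],
#             api_key="hf_...",  # Hugging Face access token, see login_url above
#         ):
#             if isinstance(chunk, str):  # the final item is a FinishReason, not text
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(demo())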