from __future__ import annotations

import json
import base64
import random
import requests

from ...typing import AsyncResult, Messages
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ...errors import ModelNotFoundError, ModelNotSupportedError
from ...requests import StreamSession, raise_for_status
from ...image import ImageResponse
from .HuggingChat import HuggingChat

class HuggingFace(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://huggingface.co"
    working = True
    supports_message_history = True
    default_model = HuggingChat.default_model
    default_image_model = HuggingChat.default_image_model
    model_aliases = HuggingChat.model_aliases

    @classmethod
    def get_models(cls) -> list[str]:
        # Lazily fetch and cache the model lists from the Hugging Face Hub API:
        # warm text-generation models (plus a few hard-coded extras) and
        # trending text-to-image models.
        if not cls.models:
            url = "https://huggingface.co/api/models?inference=warm&pipeline_tag=text-generation"
            cls.models = [model["id"] for model in requests.get(url).json()]
            cls.models.append("meta-llama/Llama-3.2-11B-Vision-Instruct")
            cls.models.append("nvidia/Llama-3.1-Nemotron-70B-Instruct-HF")
        if not cls.image_models:
            url = "https://huggingface.co/api/models?pipeline_tag=text-to-image"
            cls.image_models = [model["id"] for model in requests.get(url).json() if model["trendingScore"] >= 20]
            cls.models.extend(cls.image_models)
        return cls.models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        stream: bool = True,
        proxy: str = None,
        api_base: str = "https://api-inference.huggingface.co",
        api_key: str = None,
        max_new_tokens: int = 1024,
        temperature: float = 0.7,
        prompt: str = None,
        **kwargs
    ) -> AsyncResult:
        try:
            model = cls.get_model(model)
        except ModelNotSupportedError:
            # Fall back to the raw model name if it is not in the known model lists.
            pass
        headers = {
            'accept': '*/*',
            'accept-language': 'en',
            'cache-control': 'no-cache',
            'origin': 'https://huggingface.co',
            'pragma': 'no-cache',
            'priority': 'u=1, i',
            'referer': 'https://huggingface.co/chat/',
            'sec-ch-ua': '"Not)A;Brand";v="99", "Google Chrome";v="127", "Chromium";v="127"',
            'sec-ch-ua-mobile': '?0',
            'sec-ch-ua-platform': '"macOS"',
            'sec-fetch-dest': 'empty',
            'sec-fetch-mode': 'cors',
            'sec-fetch-site': 'same-origin',
            'user-agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/127.0.0.0 Safari/537.36',
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        payload = None
        if model in cls.image_models:
            # Text-to-image: send the prompt directly with a random seed; no streaming.
            stream = False
            prompt = messages[-1]["content"] if prompt is None else prompt
            payload = {"inputs": prompt, "parameters": {"seed": random.randint(0, 2**32)}}
        else:
            # Text generation: the inputs string is built later from the chat history.
            params = {
                "return_full_text": False,
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
                **kwargs
            }
        async with StreamSession(
            headers=headers,
            proxy=proxy,
            timeout=600
        ) as session:
            if payload is None:
                # Pick a chat template based on the eos_token in the model's tokenizer config.
                async with session.get(f"https://huggingface.co/api/models/{model}") as response:
                    model_data = await response.json()
                if "config" in model_data and "tokenizer_config" in model_data["config"] and "eos_token" in model_data["config"]["tokenizer_config"]:
                    eos_token = model_data["config"]["tokenizer_config"]["eos_token"]
                    if eos_token == "</s>":
                        inputs = format_prompt_mistral(messages)
                    elif eos_token == "<|im_end|>":
                        inputs = format_prompt_qwen(messages)
                    elif eos_token == "<|eot_id|>":
                        inputs = format_prompt_llama(messages)
                    else:
                        inputs = format_prompt(messages)
                else:
                    inputs = format_prompt(messages)
                payload = {"inputs": inputs, "parameters": params, "stream": stream}
            async with session.post(f"{api_base.rstrip('/')}/models/{model}", json=payload) as response:
                if response.status == 404:
                    raise ModelNotFoundError(f"Model is not supported: {model}")
                await raise_for_status(response)
                if stream:
                    # Server-sent events: emit each non-special token as it arrives.
                    first = True
                    async for line in response.iter_lines():
                        if line.startswith(b"data:"):
                            data = json.loads(line[5:])
                            if not data["token"]["special"]:
                                chunk = data["token"]["text"]
                                if first:
                                    first = False
                                    chunk = chunk.lstrip()
                                if chunk:
                                    yield chunk
                else:
                    if response.headers["content-type"].startswith("image/"):
                        # Image responses are wrapped as a base64 data URI.
                        base64_data = base64.b64encode(b"".join([chunk async for chunk in response.iter_content()]))
                        url = f"data:{response.headers['content-type']};base64,{base64_data.decode()}"
                        yield ImageResponse(url, prompt)
                    else:
                        yield (await response.json())[0]["generated_text"].strip()

def format_prompt(messages: Messages) -> str:
    system_messages = [message["content"] for message in messages if message["role"] == "system"]
    question = " ".join([messages[-1]["content"], *system_messages])
    history = "".join([
        f"<s>[INST]{messages[idx-1]['content']} [/INST] {message['content']}</s>"
        for idx, message in enumerate(messages)
        if message["role"] == "assistant"
    ])
    return f"{history}<s>[INST] {question} [/INST]"
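
# Illustration (comment added for clarity, not in the upstream file): for the messages
#   [{"role": "user", "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"},
#    {"role": "user", "content": "How are you?"}]
# format_prompt() returns:
#   "<s>[INST]Hi [/INST] Hello!</s><s>[INST] How are you? [/INST]"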

def format_prompt_qwen(messages: Messages) -> str:
    return "".join([
        f"<|im_start|>{message['role']}\n{message['content']}\n<|im_end|>\n" for message in messages
    ]) + "<|im_start|>assistant\n"
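
# Illustration (added comment): for [{"role": "user", "content": "Hi"}],
# format_prompt_qwen() returns:
#   "<|im_start|>user\nHi\n<|im_end|>\n<|im_start|>assistant\n"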

def format_prompt_llama(messages: Messages) -> str:
    return "<|begin_of_text|>" + "".join([
        f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}\n<|eot_id|>\n" for message in messages
    ]) + "<|start_header_id|>assistant<|end_header_id|>\n\n"
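
# Illustration (added comment): for [{"role": "user", "content": "Hi"}],
# format_prompt_llama() returns:
#   "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHi\n<|eot_id|>\n"
#   "<|start_header_id|>assistant<|end_header_id|>\n\n"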

def format_prompt_mistral(messages: Messages) -> str:
    return "".join([
        f"<|{message['role']}|>\n{message['content']}</s>\n" for message in messages
    ]) + "<|assistant|>\n"
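
# Usage sketch (illustrative only, not part of the provider): assumes this module is
# imported as part of its package so the relative imports resolve, and that the
# Hugging Face Inference API is reachable; pass an api_key if the chosen model requires auth.
#
#     import asyncio
#
#     async def main() -> None:
#         async for chunk in HuggingFace.create_async_generator(
#             model=HuggingFace.default_model,
#             messages=[{"role": "user", "content": "Hello!"}],
#             api_key=None,  # or a Hugging Face access token
#         ):
#             if isinstance(chunk, str):
#                 print(chunk, end="", flush=True)
#
#     asyncio.run(main())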