HuggingFaceAPI.py

from __future__ import annotations

from ...providers.types import Messages
from ...typing import ImagesType
from ...requests import StreamSession, raise_for_status
from ...errors import ModelNotSupportedError
from ...providers.helper import get_last_user_message
from ..template.OpenaiTemplate import OpenaiTemplate
from .models import model_aliases, vision_models, default_vision_model
from .HuggingChat import HuggingChat
from ... import debug

class HuggingFaceAPI(OpenaiTemplate):
    """OpenAI-compatible provider backed by the HuggingFace Inference API."""
    label = "HuggingFace (Inference API)"
    parent = "HuggingFace"
    url = "https://api-inference.huggingface.co"
    api_base = "https://api-inference.huggingface.co/v1"
    working = True
    needs_auth = True

    default_model = default_vision_model
    default_vision_model = default_vision_model
    vision_models = vision_models
    model_aliases = model_aliases

    # Cache mapping model name -> pipeline tag, filled lazily by get_pipeline_tag().
    pipeline_tags: dict[str, str] = {}

    @classmethod
    def get_models(cls, **kwargs):
        # Reuse HuggingChat's text-model list, then append any vision-capable
        # models that are not already included.
        if not cls.models:
            HuggingChat.get_models()
            cls.models = HuggingChat.text_models.copy()
            for model in cls.vision_models:
                if model not in cls.models:
                    cls.models.append(model)
        return cls.models
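
    # Hypothetical example of the merged result (model names illustrative,
    # not from the original file):
    #   HuggingFaceAPI.get_models()
    #   -> ["<HuggingChat text models...>", "<vision models not already listed>"]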

    @classmethod
    async def get_pipeline_tag(cls, model: str, api_key: str = None):
        # Fetch (and cache) the model's pipeline tag from the HuggingFace hub.
        if model in cls.pipeline_tags:
            return cls.pipeline_tags[model]
        async with StreamSession(
            timeout=30,
            headers=cls.get_headers(False, api_key),
        ) as session:
            async with session.get(f"https://huggingface.co/api/models/{model}") as response:
                await raise_for_status(response)
                model_data = await response.json()
                cls.pipeline_tags[model] = model_data.get("pipeline_tag")
        return cls.pipeline_tags[model]
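
    # Hypothetical sketch of the hub response consumed above, trimmed to the
    # single field this class reads:
    #   GET https://huggingface.co/api/models/<model>
    #   -> {"pipeline_tag": "text-generation", ...}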

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        api_base: str = None,
        api_key: str = None,
        max_tokens: int = 2048,
        max_inputs_length: int = 10000,
        images: ImagesType = None,
        **kwargs
    ):
        if model in cls.model_aliases:
            model = cls.model_aliases[model]
        # The per-model endpoint replaces any api_base passed by the caller.
        api_base = f"https://api-inference.huggingface.co/models/{model}/v1"
        pipeline_tag = await cls.get_pipeline_tag(model, api_key)
        if pipeline_tag not in ("text-generation", "image-text-to-text"):
            raise ModelNotSupportedError(f"Model is not supported: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
        elif images and pipeline_tag != "image-text-to-text":
            raise ModelNotSupportedError(f"Model does not support images: {model} in: {cls.__name__} pipeline_tag: {pipeline_tag}")
        # Trim oversized conversations in increasingly aggressive steps:
        # 1. keep only the first three and last three messages,
        # 2. keep only system messages plus the last user message,
        # 3. keep only the last user message.
        start = calculate_length(messages)
        if start > max_inputs_length:
            if len(messages) > 6:
                messages = messages[:3] + messages[-3:]
            if calculate_length(messages) > max_inputs_length:
                last_user_message = [{"role": "user", "content": get_last_user_message(messages)}]
                if len(messages) > 2:
                    messages = [m for m in messages if m["role"] == "system"] + last_user_message
                if len(messages) > 1 and calculate_length(messages) > max_inputs_length:
                    messages = last_user_message
            debug.log(f"Messages trimmed from: {start} to: {calculate_length(messages)}")
        async for chunk in super().create_async_generator(model, messages, api_base=api_base, api_key=api_key, max_tokens=max_tokens, images=images, **kwargs):
            yield chunk

def calculate_length(messages: Messages) -> int:
    # Rough input-size estimate: content length in characters plus a fixed
    # 16-character overhead per message for role and formatting.
    return sum(len(message["content"]) + 16 for message in messages)
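
# A minimal, hypothetical usage sketch (not part of the original module). It
# assumes a valid HuggingFace token and that this file is run as part of its
# package (e.g. via `python -m`), since the imports above are relative; the
# prompt and token are placeholders.
if __name__ == "__main__":
    import asyncio

    async def demo():
        messages = [{"role": "user", "content": "Hello!"}]
        async for chunk in HuggingFaceAPI.create_async_generator(
            HuggingFaceAPI.default_model,
            messages,
            api_key="hf_***",  # placeholder, replace with a real token
        ):
            print(chunk, end="")

    asyncio.run(demo())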