__init__.py 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from __future__ import annotations
  2. import random
  3. from ...typing import AsyncResult, Messages, ImagesType
  4. from ...errors import ResponseError
  5. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  6. from .BlackForestLabsFlux1Dev import BlackForestLabsFlux1Dev
  7. from .BlackForestLabsFlux1Schnell import BlackForestLabsFlux1Schnell
  8. from .VoodoohopFlux1Schnell import VoodoohopFlux1Schnell
  9. from .CohereForAI import CohereForAI
  10. from .Janus_Pro_7B import Janus_Pro_7B
  11. from .Qwen_QVQ_72B import Qwen_QVQ_72B
  12. from .Qwen_Qwen_2_5M_Demo import Qwen_Qwen_2_5M_Demo
  13. from .Qwen_Qwen_2_72B_Instruct import Qwen_Qwen_2_72B_Instruct
  14. from .StableDiffusion35Large import StableDiffusion35Large
  15. class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
  16. url = "https://huggingface.co/spaces"
  17. parent = "HuggingFace"
  18. working = True
  19. default_model = Qwen_Qwen_2_72B_Instruct.default_model
  20. default_image_model = BlackForestLabsFlux1Dev.default_model
  21. default_vision_model = Qwen_QVQ_72B.default_model
  22. providers = [
  23. BlackForestLabsFlux1Dev, BlackForestLabsFlux1Schnell,
  24. VoodoohopFlux1Schnell,
  25. CohereForAI, Janus_Pro_7B,
  26. Qwen_QVQ_72B, Qwen_Qwen_2_5M_Demo, Qwen_Qwen_2_72B_Instruct,
  27. StableDiffusion35Large
  28. ]
  29. @classmethod
  30. def get_parameters(cls, **kwargs) -> dict:
  31. parameters = {}
  32. for provider in cls.providers:
  33. parameters = {**parameters, **provider.get_parameters(**kwargs)}
  34. return parameters
  35. @classmethod
  36. def get_models(cls, **kwargs) -> list[str]:
  37. if not cls.models:
  38. models = []
  39. image_models = []
  40. vision_models = []
  41. for provider in cls.providers:
  42. models.extend(provider.get_models(**kwargs))
  43. models.extend(provider.model_aliases.keys())
  44. image_models.extend(provider.image_models)
  45. vision_models.extend(provider.vision_models)
  46. models = list(set(models))
  47. models.sort()
  48. cls.models = models
  49. cls.image_models = list(set(image_models))
  50. cls.vision_models = list(set(vision_models))
  51. return cls.models
  52. @classmethod
  53. async def create_async_generator(
  54. cls, model: str, messages: Messages, images: ImagesType = None, **kwargs
  55. ) -> AsyncResult:
  56. if not model and images is not None:
  57. model = cls.default_vision_model
  58. is_started = False
  59. random.shuffle(cls.providers)
  60. for provider in cls.providers:
  61. if model in provider.model_aliases:
  62. async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, **kwargs):
  63. is_started = True
  64. yield chunk
  65. if is_started:
  66. return
  67. error = None
  68. for provider in cls.providers:
  69. if model in provider.get_models():
  70. try:
  71. async for chunk in provider.create_async_generator(model, messages, **kwargs):
  72. is_started = True
  73. yield chunk
  74. if is_started:
  75. break
  76. except ResponseError as e:
  77. if is_started:
  78. raise e
  79. error = e
  80. if not is_started and error is not None:
  81. raise error