__init__.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import annotations
  2. from ...typing import AsyncResult, Messages
  3. from ...errors import ResponseError
  4. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  5. from .BlackForestLabsFlux1Dev import BlackForestLabsFlux1Dev
  6. from .BlackForestLabsFlux1Schnell import BlackForestLabsFlux1Schnell
  7. from .VoodoohopFlux1Schnell import VoodoohopFlux1Schnell
  8. class HuggingSpace(AsyncGeneratorProvider, ProviderModelMixin):
  9. url = "https://huggingface.co/spaces"
  10. working = True
  11. default_model = BlackForestLabsFlux1Dev.default_model
  12. providers = [BlackForestLabsFlux1Dev, BlackForestLabsFlux1Schnell, VoodoohopFlux1Schnell]
  13. @classmethod
  14. def get_parameters(cls, **kwargs) -> dict:
  15. parameters = {}
  16. for provider in cls.providers:
  17. parameters = {**parameters, **provider.get_parameters(**kwargs)}
  18. return parameters
  19. @classmethod
  20. def get_models(cls, **kwargs) -> list[str]:
  21. if not cls.models:
  22. for provider in cls.providers:
  23. cls.models.extend(provider.get_models(**kwargs))
  24. cls.models.extend(provider.model_aliases.keys())
  25. cls.models = list(set(cls.models))
  26. cls.models.sort()
  27. return cls.models
  28. @classmethod
  29. async def create_async_generator(
  30. cls, model: str, messages: Messages, **kwargs
  31. ) -> AsyncResult:
  32. is_started = False
  33. for provider in cls.providers:
  34. if model in provider.model_aliases:
  35. async for chunk in provider.create_async_generator(provider.model_aliases[model], messages, **kwargs):
  36. is_started = True
  37. yield chunk
  38. if is_started:
  39. return
  40. error = None
  41. for provider in cls.providers:
  42. if model in provider.get_models():
  43. try:
  44. async for chunk in provider.create_async_generator(model, messages, **kwargs):
  45. is_started = True
  46. yield chunk
  47. if is_started:
  48. break
  49. except ResponseError as e:
  50. if is_started:
  51. raise e
  52. error = e
  53. if not is_started and error is not None:
  54. raise error