BlackForestLabsFlux1Schnell.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import annotations
  2. from aiohttp import ClientSession
  3. import json
  4. from ...typing import AsyncResult, Messages
  5. from ...image import ImageResponse
  6. from ...errors import ResponseError
  7. from ...requests.raise_for_status import raise_for_status
  8. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  9. from ..helper import format_image_prompt
  10. class BlackForestLabsFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
  11. url = "https://black-forest-labs-flux-1-schnell.hf.space"
  12. api_endpoint = "https://black-forest-labs-flux-1-schnell.hf.space/call/infer"
  13. working = True
  14. default_model = "black-forest-labs-flux-1-schnell"
  15. default_image_model = default_model
  16. model_aliases = {"flux-schnell": default_model, "flux": default_model}
  17. image_models = [default_image_model, *model_aliases.keys()]
  18. models = image_models
  19. @classmethod
  20. async def create_async_generator(
  21. cls,
  22. model: str,
  23. messages: Messages,
  24. proxy: str = None,
  25. prompt: str = None,
  26. width: int = 768,
  27. height: int = 768,
  28. num_inference_steps: int = 2,
  29. seed: int = 0,
  30. randomize_seed: bool = True,
  31. **kwargs
  32. ) -> AsyncResult:
  33. model = cls.get_model(model)
  34. width = max(32, width - (width % 8))
  35. height = max(32, height - (height % 8))
  36. if prompt is None:
  37. prompt = format_image_prompt(messages)
  38. payload = {
  39. "data": [
  40. prompt,
  41. seed,
  42. randomize_seed,
  43. width,
  44. height,
  45. num_inference_steps
  46. ]
  47. }
  48. async with ClientSession() as session:
  49. async with session.post(cls.api_endpoint, json=payload, proxy=proxy) as response:
  50. await raise_for_status(response)
  51. response_data = await response.json()
  52. event_id = response_data['event_id']
  53. while True:
  54. async with session.get(f"{cls.api_endpoint}/{event_id}", proxy=proxy) as status_response:
  55. await raise_for_status(status_response)
  56. while not status_response.content.at_eof():
  57. event = await status_response.content.readuntil(b'\n\n')
  58. if event.startswith(b'event:'):
  59. event_parts = event.split(b'\ndata: ')
  60. if len(event_parts) < 2:
  61. continue
  62. event_type = event_parts[0].split(b': ')[1]
  63. data = event_parts[1]
  64. if event_type == b'error':
  65. raise ResponseError(f"Error generating image: {data}")
  66. elif event_type == b'complete':
  67. json_data = json.loads(data)
  68. image_url = json_data[0]['url']
  69. yield ImageResponse(images=[image_url], alt=prompt)
  70. return