VoodoohopFlux1Schnell.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. from __future__ import annotations
  2. from aiohttp import ClientSession
  3. import json
  4. from ...typing import AsyncResult, Messages
  5. from ...image import ImageResponse
  6. from ...errors import ResponseError
  7. from ...requests.raise_for_status import raise_for_status
  8. from ..helper import format_image_prompt
  9. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  10. class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
  11. url = "https://voodoohop-flux-1-schnell.hf.space"
  12. api_endpoint = "https://voodoohop-flux-1-schnell.hf.space/call/infer"
  13. working = True
  14. default_model = "voodoohop-flux-1-schnell"
  15. default_image_model = default_model
  16. model_aliases = {"flux-schnell": default_model, "flux": default_model}
  17. image_models = [default_image_model, *model_aliases.keys()]
  18. models = image_models
  19. @classmethod
  20. async def create_async_generator(
  21. cls,
  22. model: str,
  23. messages: Messages,
  24. proxy: str = None,
  25. prompt: str = None,
  26. width: int = 768,
  27. height: int = 768,
  28. num_inference_steps: int = 2,
  29. seed: int = 0,
  30. randomize_seed: bool = True,
  31. **kwargs
  32. ) -> AsyncResult:
  33. width = max(32, width - (width % 8))
  34. height = max(32, height - (height % 8))
  35. prompt = format_image_prompt(messages, prompt)
  36. payload = {
  37. "data": [
  38. prompt,
  39. seed,
  40. randomize_seed,
  41. width,
  42. height,
  43. num_inference_steps
  44. ]
  45. }
  46. async with ClientSession() as session:
  47. async with session.post(cls.api_endpoint, json=payload, proxy=proxy) as response:
  48. await raise_for_status(response)
  49. response_data = await response.json()
  50. event_id = response_data['event_id']
  51. while True:
  52. async with session.get(f"{cls.api_endpoint}/{event_id}", proxy=proxy) as status_response:
  53. await raise_for_status(status_response)
  54. while not status_response.content.at_eof():
  55. event = await status_response.content.readuntil(b'\n\n')
  56. if event.startswith(b'event:'):
  57. event_parts = event.split(b'\ndata: ')
  58. if len(event_parts) < 2:
  59. continue
  60. event_type = event_parts[0].split(b': ')[1]
  61. data = event_parts[1]
  62. if event_type == b'error':
  63. raise ResponseError(f"Error generating image: {data}")
  64. elif event_type == b'complete':
  65. json_data = json.loads(data)
  66. image_url = json_data[0]['url']
  67. yield ImageResponse(images=[image_url], alt=prompt)
  68. return