VoodoohopFlux1Schnell.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. from __future__ import annotations
  2. from aiohttp import ClientSession
  3. import json
  4. from ...typing import AsyncResult, Messages
  5. from ...image import ImageResponse
  6. from ...errors import ResponseError
  7. from ...requests.raise_for_status import raise_for_status
  8. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  9. class VoodoohopFlux1Schnell(AsyncGeneratorProvider, ProviderModelMixin):
  10. url = "https://voodoohop-flux-1-schnell.hf.space"
  11. api_endpoint = "https://voodoohop-flux-1-schnell.hf.space/call/infer"
  12. working = True
  13. default_model = "voodoohop-flux-1-schnell"
  14. default_image_model = default_model
  15. image_models = [default_image_model]
  16. models = image_models
  17. model_aliases = {"flux-schnell": default_model}
  18. @classmethod
  19. async def create_async_generator(
  20. cls,
  21. model: str,
  22. messages: Messages,
  23. proxy: str = None,
  24. prompt: str = None,
  25. width: int = 768,
  26. height: int = 768,
  27. num_inference_steps: int = 2,
  28. seed: int = 0,
  29. randomize_seed: bool = True,
  30. **kwargs
  31. ) -> AsyncResult:
  32. width = max(32, width - (width % 8))
  33. height = max(32, height - (height % 8))
  34. if prompt is None:
  35. prompt = messages[-1]["content"]
  36. payload = {
  37. "data": [
  38. prompt,
  39. seed,
  40. randomize_seed,
  41. width,
  42. height,
  43. num_inference_steps
  44. ]
  45. }
  46. async with ClientSession() as session:
  47. async with session.post(cls.api_endpoint, json=payload, proxy=proxy) as response:
  48. await raise_for_status(response)
  49. response_data = await response.json()
  50. event_id = response_data['event_id']
  51. while True:
  52. async with session.get(f"{cls.api_endpoint}/{event_id}", proxy=proxy) as status_response:
  53. await raise_for_status(status_response)
  54. while not status_response.content.at_eof():
  55. event = await status_response.content.readuntil(b'\n\n')
  56. if event.startswith(b'event:'):
  57. event_parts = event.split(b'\ndata: ')
  58. if len(event_parts) < 2:
  59. continue
  60. event_type = event_parts[0].split(b': ')[1]
  61. data = event_parts[1]
  62. if event_type == b'error':
  63. raise ResponseError(f"Error generating image: {data}")
  64. elif event_type == b'complete':
  65. json_data = json.loads(data)
  66. image_url = json_data[0]['url']
  67. yield ImageResponse(images=[image_url], alt=prompt)
  68. return