Flux.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. from __future__ import annotations
  2. import json
  3. from aiohttp import ClientSession
  4. from ..typing import AsyncResult, Messages
  5. from ..image import ImageResponse, ImagePreview
  6. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  7. class Flux(AsyncGeneratorProvider, ProviderModelMixin):
  8. label = "Flux (HuggingSpace)"
  9. url = "https://black-forest-labs-flux-1-dev.hf.space"
  10. api_endpoint = "/gradio_api/call/infer"
  11. working = True
  12. default_model = 'flux-dev'
  13. models = [default_model]
  14. image_models = [default_model]
  15. @classmethod
  16. async def create_async_generator(
  17. cls, model: str, messages: Messages, prompt: str = None, api_key: str = None, proxy: str = None, **kwargs
  18. ) -> AsyncResult:
  19. headers = {
  20. "Content-Type": "application/json",
  21. "Accept": "application/json",
  22. }
  23. if api_key is not None:
  24. headers["Authorization"] = f"Bearer {api_key}"
  25. async with ClientSession(headers=headers) as session:
  26. prompt = messages[-1]["content"] if prompt is None else prompt
  27. data = {
  28. "data": [prompt, 0, True, 1024, 1024, 3.5, 28]
  29. }
  30. async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
  31. response.raise_for_status()
  32. event_id = (await response.json()).get("event_id")
  33. async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
  34. event_response.raise_for_status()
  35. event = None
  36. async for chunk in event_response.content:
  37. if chunk.startswith(b"event: "):
  38. event = chunk[7:].decode(errors="replace").strip()
  39. if chunk.startswith(b"data: "):
  40. if event == "error":
  41. raise RuntimeError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
  42. if event in ("complete", "generating"):
  43. try:
  44. data = json.loads(chunk[6:])
  45. if data is None:
  46. continue
  47. url = data[0]["url"]
  48. except (json.JSONDecodeError, KeyError, TypeError) as e:
  49. raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
  50. if event == "generating":
  51. yield ImagePreview(url, prompt)
  52. else:
  53. yield ImageResponse(url, prompt)
  54. break