# BlackForestLabsFlux1Dev.py
from __future__ import annotations

import json

from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
from ..helper import format_image_prompt
  9. class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
  10. url = "https://black-forest-labs-flux-1-dev.hf.space"
  11. api_endpoint = "/gradio_api/call/infer"
  12. working = True
  13. default_model = 'black-forest-labs-flux-1-dev'
  14. default_image_model = default_model
  15. model_aliases = {"flux-dev": default_model, "flux": default_model}
  16. image_models = [default_image_model, *model_aliases.keys()]
  17. models = image_models
  18. @classmethod
  19. async def create_async_generator(
  20. cls,
  21. model: str,
  22. messages: Messages,
  23. prompt: str = None,
  24. api_key: str = None,
  25. proxy: str = None,
  26. width: int = 1024,
  27. height: int = 1024,
  28. guidance_scale: float = 3.5,
  29. num_inference_steps: int = 28,
  30. seed: int = 0,
  31. randomize_seed: bool = True,
  32. **kwargs
  33. ) -> AsyncResult:
  34. model = cls.get_model(model)
  35. headers = {
  36. "Content-Type": "application/json",
  37. "Accept": "application/json",
  38. }
  39. if api_key is not None:
  40. headers["Authorization"] = f"Bearer {api_key}"
  41. async with ClientSession(headers=headers) as session:
  42. prompt = format_image_prompt(messages, prompt)
  43. data = {
  44. "data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
  45. }
  46. async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
  47. response.raise_for_status()
  48. event_id = (await response.json()).get("event_id")
  49. async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
  50. event_response.raise_for_status()
  51. event = None
  52. async for chunk in event_response.content:
  53. if chunk.startswith(b"event: "):
  54. event = chunk[7:].decode(errors="replace").strip()
  55. if chunk.startswith(b"data: "):
  56. if event == "error":
  57. raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
  58. if event in ("complete", "generating"):
  59. try:
  60. data = json.loads(chunk[6:])
  61. if data is None:
  62. continue
  63. url = data[0]["url"]
  64. except (json.JSONDecodeError, KeyError, TypeError) as e:
  65. raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
  66. if event == "generating":
  67. yield ImagePreview(url, prompt)
  68. else:
  69. yield ImageResponse(url, prompt)
  70. break