BlackForestLabsFlux1Dev.py

from __future__ import annotations

import json
from aiohttp import ClientSession

from ...typing import AsyncResult, Messages
from ...image import ImageResponse, ImagePreview
from ...errors import ResponseError
from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin


class BlackForestLabsFlux1Dev(AsyncGeneratorProvider, ProviderModelMixin):
    url = "https://black-forest-labs-flux-1-dev.hf.space"
    api_endpoint = "/gradio_api/call/infer"
    working = True

    default_model = 'black-forest-labs-flux-1-dev'
    default_image_model = default_model
    model_aliases = {"flux-dev": default_model, "flux": default_model}
    image_models = [default_image_model, *model_aliases.keys()]
    models = image_models

    @classmethod
    async def create_async_generator(
        cls,
        model: str,
        messages: Messages,
        prompt: str = None,
        api_key: str = None,
        proxy: str = None,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 3.5,
        num_inference_steps: int = 28,
        seed: int = 0,
        randomize_seed: bool = True,
        **kwargs
    ) -> AsyncResult:
        model = cls.get_model(model)
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"
        async with ClientSession(headers=headers) as session:
            # Use the last message as the prompt unless an explicit prompt is given.
            prompt = messages[-1]["content"] if prompt is None else prompt
            data = {
                "data": [prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
            }
            # Start the Gradio job; the response carries an event_id used to fetch results.
            async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
                response.raise_for_status()
                event_id = (await response.json()).get("event_id")
                # Stream server-sent events for this job until it completes.
                async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
                    event_response.raise_for_status()
                    event = None
                    async for chunk in event_response.content:
                        # "event: " lines name the current event; "data: " lines carry its payload.
                        if chunk.startswith(b"event: "):
                            event = chunk[7:].decode(errors="replace").strip()
                        if chunk.startswith(b"data: "):
                            if event == "error":
                                raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
                            if event in ("complete", "generating"):
                                try:
                                    data = json.loads(chunk[6:])
                                    if data is None:
                                        continue
                                    url = data[0]["url"]
                                except (json.JSONDecodeError, KeyError, TypeError) as e:
                                    raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}") from e
                                if event == "generating":
                                    # Intermediate preview while the image is still being generated.
                                    yield ImagePreview(url, prompt)
                                else:
                                    # Final image; stop reading the event stream.
                                    yield ImageResponse(url, prompt)
                                    break
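

# --- Usage sketch (illustrative, not part of the upstream provider) ----------
# A minimal example of how the async generator above might be consumed:
# ImagePreview items stream while the Space is generating, and the final
# ImageResponse ends the loop. The prompt text, "flux-dev" alias, and 768x768
# size are arbitrary example values. Because of the relative imports above,
# this block cannot run when the file is executed directly; it assumes the
# module is available inside its package.
if __name__ == "__main__":
    import asyncio

    async def _demo():
        messages = [{"role": "user", "content": "a watercolor fox in a snowy forest"}]
        async for item in BlackForestLabsFlux1Dev.create_async_generator(
            model="flux-dev",
            messages=messages,
            width=768,
            height=768,
        ):
            # Prints the item type (ImagePreview or ImageResponse) and its string form.
            print(type(item).__name__, item)

    asyncio.run(_demo())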