# StableDiffusion35Large.py
  1. from __future__ import annotations
  2. import json
  3. from aiohttp import ClientSession
  4. from ...typing import AsyncResult, Messages
  5. from ...image import ImageResponse, ImagePreview
  6. from ...errors import ResponseError
  7. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  8. from ..helper import format_image_prompt
  9. class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
  10. url = "https://stabilityai-stable-diffusion-3-5-large.hf.space"
  11. api_endpoint = "/gradio_api/call/infer"
  12. working = True
  13. default_model = 'stabilityai-stable-diffusion-3-5-large'
  14. default_image_model = default_model
  15. image_models = [default_model]
  16. models = image_models
  17. model_aliases = {"sd-3.5": default_model}
  18. @classmethod
  19. async def create_async_generator(
  20. cls, model: str, messages: Messages,
  21. prompt: str = None,
  22. negative_prompt: str = None,
  23. api_key: str = None,
  24. proxy: str = None,
  25. width: int = 1024,
  26. height: int = 1024,
  27. guidance_scale: float = 4.5,
  28. num_inference_steps: int = 50,
  29. seed: int = 0,
  30. randomize_seed: bool = True,
  31. **kwargs
  32. ) -> AsyncResult:
  33. headers = {
  34. "Content-Type": "application/json",
  35. "Accept": "application/json",
  36. }
  37. if api_key is not None:
  38. headers["Authorization"] = f"Bearer {api_key}"
  39. async with ClientSession(headers=headers) as session:
  40. prompt = format_image_prompt(messages, prompt)
  41. data = {
  42. "data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
  43. }
  44. async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
  45. response.raise_for_status()
  46. event_id = (await response.json()).get("event_id")
  47. async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
  48. event_response.raise_for_status()
  49. event = None
  50. async for chunk in event_response.content:
  51. if chunk.startswith(b"event: "):
  52. event = chunk[7:].decode(errors="replace").strip()
  53. if chunk.startswith(b"data: "):
  54. if event == "error":
  55. raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
  56. if event in ("complete", "generating"):
  57. try:
  58. data = json.loads(chunk[6:])
  59. if data is None:
  60. continue
  61. url = data[0]["url"]
  62. except (json.JSONDecodeError, KeyError, TypeError) as e:
  63. raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
  64. if event == "generating":
  65. yield ImagePreview(url, prompt)
  66. else:
  67. yield ImageResponse(url, prompt)
  68. break