# StableDiffusion35Large.py
  1. from __future__ import annotations
  2. import json
  3. from aiohttp import ClientSession
  4. from ...typing import AsyncResult, Messages
  5. from ...image import ImageResponse, ImagePreview
  6. from ...errors import ResponseError
  7. from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin
  8. class StableDiffusion35Large(AsyncGeneratorProvider, ProviderModelMixin):
  9. url = "https://stabilityai-stable-diffusion-3-5-large.hf.space"
  10. api_endpoint = "/gradio_api/call/infer"
  11. working = True
  12. default_model = 'stabilityai-stable-diffusion-3-5-large'
  13. default_image_model = default_model
  14. image_models = [default_model]
  15. models = image_models
  16. model_aliases = {"sd-3.5": default_model}
  17. @classmethod
  18. async def create_async_generator(
  19. cls, model: str, messages: Messages,
  20. prompt: str = None,
  21. negative_prompt: str = None,
  22. api_key: str = None,
  23. proxy: str = None,
  24. width: int = 1024,
  25. height: int = 1024,
  26. guidance_scale: float = 4.5,
  27. num_inference_steps: int = 50,
  28. seed: int = 0,
  29. randomize_seed: bool = True,
  30. **kwargs
  31. ) -> AsyncResult:
  32. headers = {
  33. "Content-Type": "application/json",
  34. "Accept": "application/json",
  35. }
  36. if api_key is not None:
  37. headers["Authorization"] = f"Bearer {api_key}"
  38. async with ClientSession(headers=headers) as session:
  39. prompt = messages[-1]["content"] if prompt is None else prompt
  40. data = {
  41. "data": [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps]
  42. }
  43. async with session.post(f"{cls.url}{cls.api_endpoint}", json=data, proxy=proxy) as response:
  44. response.raise_for_status()
  45. event_id = (await response.json()).get("event_id")
  46. async with session.get(f"{cls.url}{cls.api_endpoint}/{event_id}") as event_response:
  47. event_response.raise_for_status()
  48. event = None
  49. async for chunk in event_response.content:
  50. if chunk.startswith(b"event: "):
  51. event = chunk[7:].decode(errors="replace").strip()
  52. if chunk.startswith(b"data: "):
  53. if event == "error":
  54. raise ResponseError(f"GPU token limit exceeded: {chunk.decode(errors='replace')}")
  55. if event in ("complete", "generating"):
  56. try:
  57. data = json.loads(chunk[6:])
  58. if data is None:
  59. continue
  60. url = data[0]["url"]
  61. except (json.JSONDecodeError, KeyError, TypeError) as e:
  62. raise RuntimeError(f"Failed to parse image URL: {chunk.decode(errors='replace')}", e)
  63. if event == "generating":
  64. yield ImagePreview(url, prompt)
  65. else:
  66. yield ImageResponse(url, prompt)
  67. break