ImageLabs.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from __future__ import annotations
  2. from aiohttp import ClientSession
  3. import time
  4. import asyncio
  5. from ..typing import AsyncResult, Messages
  6. from ..image import ImageResponse
  7. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  8. class ImageLabs(AsyncGeneratorProvider, ProviderModelMixin):
  9. url = "https://editor.imagelabs.net"
  10. api_endpoint = "https://editor.imagelabs.net/txt2img"
  11. working = True
  12. supports_stream = False
  13. supports_system_message = False
  14. supports_message_history = False
  15. default_model = 'general'
  16. default_image_model = default_model
  17. image_models = [default_image_model]
  18. models = image_models
  19. model_aliases = {"sdxl-turbo": default_model}
  20. @classmethod
  21. async def create_async_generator(
  22. cls,
  23. model: str,
  24. messages: Messages,
  25. proxy: str = None,
  26. # Image
  27. prompt: str = None,
  28. negative_prompt: str = "",
  29. width: int = 1152,
  30. height: int = 896,
  31. **kwargs
  32. ) -> AsyncResult:
  33. headers = {
  34. 'accept': '*/*',
  35. 'accept-language': 'en-US,en;q=0.9',
  36. 'cache-control': 'no-cache',
  37. 'content-type': 'application/json',
  38. 'origin': cls.url,
  39. 'referer': f'{cls.url}/',
  40. 'x-requested-with': 'XMLHttpRequest',
  41. 'user-agent': 'Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36'
  42. }
  43. async with ClientSession(headers=headers) as session:
  44. prompt = messages[-1]["content"] if prompt is None else prompt
  45. # Generate image
  46. payload = {
  47. "prompt": prompt,
  48. "seed": str(int(time.time())),
  49. "subseed": str(int(time.time() * 1000)),
  50. "attention": 0,
  51. "width": width,
  52. "height": height,
  53. "tiling": False,
  54. "negative_prompt": negative_prompt,
  55. "reference_image": "",
  56. "reference_image_type": None,
  57. "reference_strength": 30
  58. }
  59. async with session.post(f'{cls.url}/txt2img', json=payload, proxy=proxy) as generate_response:
  60. generate_data = await generate_response.json()
  61. task_id = generate_data.get('task_id')
  62. # Poll for progress
  63. while True:
  64. async with session.post(f'{cls.url}/progress', json={"task_id": task_id}, proxy=proxy) as progress_response:
  65. progress_data = await progress_response.json()
  66. # Check for completion or error states
  67. if progress_data.get('status') == 'Done' or progress_data.get('final_image_url'):
  68. # Yield ImageResponse with the final image URL
  69. yield ImageResponse(
  70. images=[progress_data.get('final_image_url')],
  71. alt=prompt
  72. )
  73. break
  74. # Check for queue or error states
  75. if 'error' in progress_data.get('status', '').lower():
  76. raise Exception(f"Image generation error: {progress_data}")
  77. # Wait between polls
  78. await asyncio.sleep(1)
  79. @classmethod
  80. def get_model(cls, model: str) -> str:
  81. return cls.default_model