# ARTA.py (8.0 KB)
  1. from __future__ import annotations
  2. import os
  3. import time
  4. import json
  5. import random
  6. from pathlib import Path
  7. from aiohttp import ClientSession
  8. import asyncio
  9. from ..typing import AsyncResult, Messages
  10. from ..providers.response import ImageResponse, Reasoning
  11. from ..errors import ResponseError
  12. from ..cookies import get_cookies_dir
  13. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  14. from .helper import format_image_prompt
  15. class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
  16. url = "https://ai-arta.com"
  17. auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
  18. token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
  19. image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"
  20. status_check_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image/{record_id}/status"
  21. working = True
  22. default_model = "Flux"
  23. default_image_model = default_model
  24. model_aliases = {
  25. "flux": "Flux",
  26. "medieval": "Medieval",
  27. "vincent_van_gogh": "Vincent Van Gogh",
  28. "f_dev": "F Dev",
  29. "low_poly": "Low Poly",
  30. "dreamshaper_xl": "Dreamshaper-xl",
  31. "anima_pencil_xl": "Anima-pencil-xl",
  32. "biomech": "Biomech",
  33. "trash_polka": "Trash Polka",
  34. "no_style": "No Style",
  35. "cheyenne_xl": "Cheyenne-xl",
  36. "chicano": "Chicano",
  37. "embroidery_tattoo": "Embroidery tattoo",
  38. "red_and_black": "Red and Black",
  39. "fantasy_art": "Fantasy Art",
  40. "watercolor": "Watercolor",
  41. "dotwork": "Dotwork",
  42. "old_school_colored": "Old school colored",
  43. "realistic_tattoo": "Realistic tattoo",
  44. "japanese_2": "Japanese_2",
  45. "realistic_stock_xl": "Realistic-stock-xl",
  46. "f_pro": "F Pro",
  47. "revanimated": "RevAnimated",
  48. "katayama_mix_xl": "Katayama-mix-xl",
  49. "sdxl_l": "SDXL L",
  50. "cor_epica_xl": "Cor-epica-xl",
  51. "anime_tattoo": "Anime tattoo",
  52. "new_school": "New School",
  53. "death_metal": "Death metal",
  54. "old_school": "Old School",
  55. "juggernaut_xl": "Juggernaut-xl",
  56. "photographic": "Photographic",
  57. "sdxl_1_0": "SDXL 1.0",
  58. "graffiti": "Graffiti",
  59. "mini_tattoo": "Mini tattoo",
  60. "surrealism": "Surrealism",
  61. "neo_traditional": "Neo-traditional",
  62. "on_limbs_black": "On limbs black",
  63. "yamers_realistic_xl": "Yamers-realistic-xl",
  64. "pony_xl": "Pony-xl",
  65. "playground_xl": "Playground-xl",
  66. "anything_xl": "Anything-xl",
  67. "flame_design": "Flame design",
  68. "kawaii": "Kawaii",
  69. "cinematic_art": "Cinematic Art",
  70. "professional": "Professional",
  71. "black_ink": "Black Ink"
  72. }
  73. image_models = list(model_aliases.keys())
  74. models = image_models
  75. @classmethod
  76. def get_auth_file(cls):
  77. path = Path(get_cookies_dir())
  78. path.mkdir(exist_ok=True)
  79. filename = f"auth_{cls.__name__}.json"
  80. return path / filename
  81. @classmethod
  82. async def create_token(cls, path: Path, proxy: str | None = None):
  83. async with ClientSession() as session:
  84. # Step 1: Generate Authentication Token
  85. auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
  86. async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
  87. auth_data = await auth_response.json()
  88. auth_token = auth_data.get("idToken")
  89. #refresh_token = auth_data.get("refreshToken")
  90. if not auth_token:
  91. raise ResponseError("Failed to obtain authentication token.")
  92. json.dump(auth_data, path.open("w"))
  93. return auth_data
  94. @classmethod
  95. async def refresh_token(cls, refresh_token: str, proxy: str = None) -> tuple[str, str]:
  96. async with ClientSession() as session:
  97. payload = {
  98. "grant_type": "refresh_token",
  99. "refresh_token": refresh_token,
  100. }
  101. async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
  102. response_data = await response.json()
  103. return response_data.get("id_token"), response_data.get("refresh_token")
  104. @classmethod
  105. async def read_and_refresh_token(cls, proxy: str | None = None) -> str:
  106. path = cls.get_auth_file()
  107. if path.is_file():
  108. auth_data = json.load(path.open("rb"))
  109. diff = time.time() - os.path.getmtime(path)
  110. expiresIn = int(auth_data.get("expiresIn"))
  111. if diff < expiresIn:
  112. if diff > expiresIn / 2:
  113. auth_data["idToken"], auth_data["refreshToken"] = await cls.refresh_token(auth_data.get("refreshToken"), proxy)
  114. json.dump(auth_data, path.open("w"))
  115. return auth_data
  116. return await cls.create_token(path, proxy)
  117. @classmethod
  118. async def create_async_generator(
  119. cls,
  120. model: str,
  121. messages: Messages,
  122. proxy: str = None,
  123. prompt: str = None,
  124. negative_prompt: str = "blurry, deformed hands, ugly",
  125. n: int = 1,
  126. guidance_scale: int = 7,
  127. num_inference_steps: int = 30,
  128. aspect_ratio: str = "1:1",
  129. seed: int = None,
  130. **kwargs
  131. ) -> AsyncResult:
  132. model = cls.get_model(model)
  133. prompt = format_image_prompt(messages, prompt)
  134. # Generate a random seed if not provided
  135. if seed is None:
  136. seed = random.randint(9999, 99999999) # Common range for random seeds
  137. # Step 1: Get Authentication Token
  138. auth_data = await cls.read_and_refresh_token(proxy)
  139. async with ClientSession() as session:
  140. # Step 2: Generate Images
  141. image_payload = {
  142. "prompt": prompt,
  143. "negative_prompt": negative_prompt,
  144. "style": model,
  145. "images_num": str(n),
  146. "cfg_scale": str(guidance_scale),
  147. "steps": str(num_inference_steps),
  148. "aspect_ratio": aspect_ratio,
  149. "seed": str(seed),
  150. }
  151. headers = {
  152. "Authorization": auth_data.get("idToken"),
  153. }
  154. async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
  155. image_data = await image_response.json()
  156. record_id = image_data.get("record_id")
  157. if not record_id:
  158. raise ResponseError(f"Failed to initiate image generation: {image_data}")
  159. # Step 3: Check Generation Status
  160. status_url = cls.status_check_url.format(record_id=record_id)
  161. counter = 4
  162. start_time = time.time()
  163. last_status = None
  164. while True:
  165. async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
  166. status_data = await status_response.json()
  167. status = status_data.get("status")
  168. if status == "DONE":
  169. image_urls = [image["url"] for image in status_data.get("response", [])]
  170. duration = time.time() - start_time
  171. yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s")
  172. yield ImageResponse(images=image_urls, alt=prompt)
  173. return
  174. elif status in ("IN_QUEUE", "IN_PROGRESS"):
  175. if last_status != status:
  176. last_status = status
  177. if status == "IN_QUEUE":
  178. yield Reasoning(label="Waiting")
  179. else:
  180. yield Reasoning(label="Generating")
  181. await asyncio.sleep(2) # Poll every 2 seconds
  182. else:
  183. raise ResponseError(f"Image generation failed with status: {status}")