ARTA.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241
  1. from __future__ import annotations
  2. import os
  3. import time
  4. import json
  5. import random
  6. from pathlib import Path
  7. from aiohttp import ClientSession
  8. import asyncio
  9. from ..typing import AsyncResult, Messages
  10. from ..providers.response import ImageResponse, Reasoning
  11. from ..errors import ResponseError
  12. from ..cookies import get_cookies_dir
  13. from .base_provider import AsyncGeneratorProvider, ProviderModelMixin
  14. from .helper import format_image_prompt
  15. class ARTA(AsyncGeneratorProvider, ProviderModelMixin):
  16. url = "https://ai-arta.com"
  17. auth_url = "https://www.googleapis.com/identitytoolkit/v3/relyingparty/signupNewUser?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
  18. token_refresh_url = "https://securetoken.googleapis.com/v1/token?key=AIzaSyB3-71wG0fIt0shj0ee4fvx1shcjJHGrrQ"
  19. image_generation_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image"
  20. status_check_url = "https://img-gen-prod.ai-arta.com/api/v1/text2image/{record_id}/status"
  21. working = True
  22. default_model = "Flux"
  23. default_image_model = default_model
  24. model_aliases = {
  25. default_image_model: default_image_model,
  26. "flux": default_image_model,
  27. "medieval": "Medieval",
  28. "vincent_van_gogh": "Vincent Van Gogh",
  29. "f_dev": "F Dev",
  30. "low_poly": "Low Poly",
  31. "dreamshaper_xl": "Dreamshaper-xl",
  32. "anima_pencil_xl": "Anima-pencil-xl",
  33. "biomech": "Biomech",
  34. "trash_polka": "Trash Polka",
  35. "no_style": "No Style",
  36. "cheyenne_xl": "Cheyenne-xl",
  37. "chicano": "Chicano",
  38. "embroidery_tattoo": "Embroidery tattoo",
  39. "red_and_black": "Red and Black",
  40. "fantasy_art": "Fantasy Art",
  41. "watercolor": "Watercolor",
  42. "dotwork": "Dotwork",
  43. "old_school_colored": "Old school colored",
  44. "realistic_tattoo": "Realistic tattoo",
  45. "japanese_2": "Japanese_2",
  46. "realistic_stock_xl": "Realistic-stock-xl",
  47. "f_pro": "F Pro",
  48. "revanimated": "RevAnimated",
  49. "katayama_mix_xl": "Katayama-mix-xl",
  50. "sdxl_l": "SDXL L",
  51. "cor_epica_xl": "Cor-epica-xl",
  52. "anime_tattoo": "Anime tattoo",
  53. "new_school": "New School",
  54. "death_metal": "Death metal",
  55. "old_school": "Old School",
  56. "juggernaut_xl": "Juggernaut-xl",
  57. "photographic": "Photographic",
  58. "sdxl_1_0": "SDXL 1.0",
  59. "graffiti": "Graffiti",
  60. "mini_tattoo": "Mini tattoo",
  61. "surrealism": "Surrealism",
  62. "neo_traditional": "Neo-traditional",
  63. "on_limbs_black": "On limbs black",
  64. "yamers_realistic_xl": "Yamers-realistic-xl",
  65. "pony_xl": "Pony-xl",
  66. "playground_xl": "Playground-xl",
  67. "anything_xl": "Anything-xl",
  68. "flame_design": "Flame design",
  69. "kawaii": "Kawaii",
  70. "cinematic_art": "Cinematic Art",
  71. "professional": "Professional",
  72. "black_ink": "Black Ink"
  73. }
  74. image_models = list(model_aliases.keys())
  75. models = image_models
  76. @classmethod
  77. def get_auth_file(cls):
  78. path = Path(get_cookies_dir())
  79. path.mkdir(exist_ok=True)
  80. filename = f"auth_{cls.__name__}.json"
  81. return path / filename
  82. @classmethod
  83. async def create_token(cls, path: Path, proxy: str | None = None):
  84. async with ClientSession() as session:
  85. # Step 1: Generate Authentication Token
  86. auth_payload = {"clientType": "CLIENT_TYPE_ANDROID"}
  87. async with session.post(cls.auth_url, json=auth_payload, proxy=proxy) as auth_response:
  88. if auth_response.status >= 400:
  89. error_text = await auth_response.text()
  90. raise ResponseError(f"Failed to obtain authentication token. Status: {auth_response.status}, Response: {error_text}")
  91. try:
  92. auth_data = await auth_response.json()
  93. except Exception as e:
  94. error_text = await auth_response.text()
  95. content_type = auth_response.headers.get('Content-Type', 'unknown')
  96. raise ResponseError(f"Failed to parse auth response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
  97. auth_token = auth_data.get("idToken")
  98. #refresh_token = auth_data.get("refreshToken")
  99. if not auth_token:
  100. raise ResponseError("Failed to obtain authentication token.")
  101. json.dump(auth_data, path.open("w"))
  102. return auth_data
  103. @classmethod
  104. async def refresh_token(cls, refresh_token: str, proxy: str = None) -> tuple[str, str]:
  105. async with ClientSession() as session:
  106. payload = {
  107. "grant_type": "refresh_token",
  108. "refresh_token": refresh_token,
  109. }
  110. async with session.post(cls.token_refresh_url, data=payload, proxy=proxy) as response:
  111. if response.status >= 400:
  112. error_text = await response.text()
  113. raise ResponseError(f"Failed to refresh token. Status: {response.status}, Response: {error_text}")
  114. try:
  115. response_data = await response.json()
  116. except Exception as e:
  117. error_text = await response.text()
  118. content_type = response.headers.get('Content-Type', 'unknown')
  119. raise ResponseError(f"Failed to parse token refresh response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
  120. return response_data.get("id_token"), response_data.get("refresh_token")
  121. @classmethod
  122. async def read_and_refresh_token(cls, proxy: str | None = None) -> str:
  123. path = cls.get_auth_file()
  124. if path.is_file():
  125. auth_data = json.load(path.open("rb"))
  126. diff = time.time() - os.path.getmtime(path)
  127. expiresIn = int(auth_data.get("expiresIn"))
  128. if diff < expiresIn:
  129. if diff > expiresIn / 2:
  130. auth_data["idToken"], auth_data["refreshToken"] = await cls.refresh_token(auth_data.get("refreshToken"), proxy)
  131. json.dump(auth_data, path.open("w"))
  132. return auth_data
  133. return await cls.create_token(path, proxy)
  134. @classmethod
  135. async def create_async_generator(
  136. cls,
  137. model: str,
  138. messages: Messages,
  139. proxy: str = None,
  140. prompt: str = None,
  141. negative_prompt: str = "blurry, deformed hands, ugly",
  142. n: int = 1,
  143. guidance_scale: int = 7,
  144. num_inference_steps: int = 30,
  145. aspect_ratio: str = "1:1",
  146. seed: int = None,
  147. **kwargs
  148. ) -> AsyncResult:
  149. model = cls.get_model(model)
  150. prompt = format_image_prompt(messages, prompt)
  151. # Generate a random seed if not provided
  152. if seed is None:
  153. seed = random.randint(9999, 99999999) # Common range for random seeds
  154. # Step 1: Get Authentication Token
  155. auth_data = await cls.read_and_refresh_token(proxy)
  156. async with ClientSession() as session:
  157. # Step 2: Generate Images
  158. image_payload = {
  159. "prompt": prompt,
  160. "negative_prompt": negative_prompt,
  161. "style": model,
  162. "images_num": str(n),
  163. "cfg_scale": str(guidance_scale),
  164. "steps": str(num_inference_steps),
  165. "aspect_ratio": aspect_ratio,
  166. "seed": str(seed),
  167. }
  168. headers = {
  169. "Authorization": auth_data.get("idToken"),
  170. }
  171. async with session.post(cls.image_generation_url, data=image_payload, headers=headers, proxy=proxy) as image_response:
  172. if image_response.status >= 400:
  173. error_text = await image_response.text()
  174. raise ResponseError(f"Failed to initiate image generation. Status: {image_response.status}, Response: {error_text}")
  175. try:
  176. image_data = await image_response.json()
  177. except Exception as e:
  178. error_text = await image_response.text()
  179. content_type = image_response.headers.get('Content-Type', 'unknown')
  180. raise ResponseError(f"Failed to parse response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
  181. record_id = image_data.get("record_id")
  182. if not record_id:
  183. raise ResponseError(f"Failed to initiate image generation: {image_data}")
  184. # Step 3: Check Generation Status
  185. status_url = cls.status_check_url.format(record_id=record_id)
  186. counter = 4
  187. start_time = time.time()
  188. last_status = None
  189. while True:
  190. async with session.get(status_url, headers=headers, proxy=proxy) as status_response:
  191. if status_response.status >= 400:
  192. error_text = await status_response.text()
  193. raise ResponseError(f"Failed to check image generation status. Status: {status_response.status}, Response: {error_text}")
  194. try:
  195. status_data = await status_response.json()
  196. except Exception as e:
  197. error_text = await status_response.text()
  198. content_type = status_response.headers.get('Content-Type', 'unknown')
  199. raise ResponseError(f"Failed to parse status response as JSON. Content-Type: {content_type}, Error: {str(e)}, Response: {error_text}")
  200. status = status_data.get("status")
  201. if status == "DONE":
  202. image_urls = [image["url"] for image in status_data.get("response", [])]
  203. duration = time.time() - start_time
  204. yield Reasoning(label="Generated", status=f"{n} image(s) in {duration:.2f}s")
  205. yield ImageResponse(urls=image_urls, alt=prompt)
  206. return
  207. elif status in ("IN_QUEUE", "IN_PROGRESS"):
  208. if last_status != status:
  209. last_status = status
  210. if status == "IN_QUEUE":
  211. yield Reasoning(label="Waiting")
  212. else:
  213. yield Reasoning(label="Generating")
  214. await asyncio.sleep(2) # Poll every 2 seconds
  215. else:
  216. raise ResponseError(f"Image generation failed with status: {status}")