copy_images.py

from __future__ import annotations

import os
import time
import uuid
import asyncio
import hashlib
import re
from typing import AsyncIterator
from urllib.parse import quote, unquote

from aiohttp import ClientSession, ClientError

from ..typing import Optional, Cookies
from ..requests.aiohttp import get_connector, StreamResponse
from ..image import MEDIA_TYPE_MAP, EXTENSIONS_MAP
from ..providers.response import ImageResponse, AudioResponse, VideoResponse
from ..Provider.template import BackendApi
from . import is_accepted_format, extract_data_uri
from .. import debug

# Directory for storing generated images
images_dir = "./generated_images"

def get_media_extension(media: str) -> str:
    """Extract media file extension from URL or filename"""
    match = re.search(r"\.(j?[a-z]{3})(?:\?|$)", media, re.IGNORECASE)
    extension = match.group(1).lower() if match else ""
    if extension not in EXTENSIONS_MAP:
        raise ValueError(f"Unsupported media extension: {extension}")
    return f".{extension}"
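
# Illustrative note (added, not in the original source): the regex accepts a
# three-letter extension with an optional leading "j", so both ".png" and
# ".jpeg" match, e.g. get_media_extension("https://example.com/cat.jpeg?w=512")
# returns ".jpeg", assuming "jpeg" is listed in EXTENSIONS_MAP.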

def ensure_images_dir():
    """Create images directory if it doesn't exist"""
    os.makedirs(images_dir, exist_ok=True)

def get_source_url(image: str, default: str = None) -> str:
    """Extract original URL from image parameter if present"""
    if "url=" in image:
        decoded_url = unquote(image.split("url=", 1)[1])
        if decoded_url.startswith(("http://", "https://")):
            return decoded_url
    return default
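
# Illustrative note (added): given a local media URL of the form produced by
# copy_media below, e.g.
# get_source_url("/media/cat.png?url=https%3A//example.com/cat.png"),
# this returns "https://example.com/cat.png"; otherwise the default is returned.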

def secure_filename(filename: str) -> str:
    if filename is None:
        return None
    # Keep letters, numbers, basic punctuation and all Unicode chars
    filename = re.sub(
        r'[^\w.,_-]+',
        '_',
        unquote(filename).strip(),
        flags=re.UNICODE
    )
    filename = filename[:100].strip(".,_-")
    return filename
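
# Illustrative note (added): runs of characters outside \w, ".", ",", "_", "-"
# collapse to a single underscore, the result is capped at 100 characters, and
# leading/trailing punctuation is stripped,
# e.g. secure_filename("my cat: photo!.png") -> "my_cat_photo_.png".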

def is_valid_media_type(content_type: str) -> bool:
    return content_type in MEDIA_TYPE_MAP or content_type.startswith("audio/") or content_type.startswith("video/")

async def save_response_media(response: StreamResponse, prompt: str, tags: list[str]) -> AsyncIterator:
    """Save media from response to local file and return URL"""
    content_type = response.headers["content-type"]
    if is_valid_media_type(content_type):
        # For audio/video types missing from MEDIA_TYPE_MAP, derive the extension
        # from the MIME subtype ("audio/" and "video/" are both six characters);
        # "mpeg" is mapped to "mp3".
        extension = MEDIA_TYPE_MAP[content_type] if content_type in MEDIA_TYPE_MAP else content_type[6:].replace("mpeg", "mp3")
        if extension not in EXTENSIONS_MAP:
            raise ValueError(f"Unsupported media type: {content_type}")
        filename = get_filename(tags, prompt, f".{extension}", prompt)
        target_path = os.path.join(images_dir, filename)
        with open(target_path, 'wb') as f:
            async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
                f.write(chunk)
        media_url = f"/media/{filename}"
        if response.method == "GET":
            media_url = f"{media_url}?url={str(response.url)}"
        if content_type.startswith("audio/"):
            yield AudioResponse(media_url)
        elif content_type.startswith("video/"):
            yield VideoResponse(media_url, prompt)
        else:
            yield ImageResponse(media_url, prompt)

def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
    return secure_filename("".join((
        f"{int(time.time())}_",
        (f"{'_'.join([tag for tag in tags if tag])}_" if tags else ""),
        (f"{alt}_" if alt else ""),
        f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
        f"{extension}"
    )))
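
# Illustrative note (added): the result follows the pattern
# "<unix timestamp>_<tags>_<alt>_<first 16 hex chars of sha256(image)><extension>"
# run through secure_filename, e.g. "1717171717_demo_cat_0a1b2c3d4e5f6a7b.png".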

async def copy_media(
    images: list[str],
    cookies: Optional[Cookies] = None,
    headers: Optional[dict] = None,
    proxy: Optional[str] = None,
    alt: str = None,
    tags: list[str] = None,
    add_url: bool = True,
    target: str = None,
    ssl: bool = None
) -> list[str]:
    """
    Download and store images locally with Unicode-safe filenames
    Returns list of relative image URLs
    """
    if add_url:
        add_url = not cookies
    ensure_images_dir()

    async with ClientSession(
        connector=get_connector(proxy=proxy),
        cookies=cookies,
        headers=headers,
    ) as session:
        async def copy_image(image: str, target: str = None) -> str:
            """Process individual image and return its local URL"""
            # Skip if image is already local
            if image.startswith("/"):
                return image
            target_path = target
            if target_path is None:
                # Build safe filename with full Unicode support
                filename = get_filename(tags, alt, get_media_extension(image), image)
                target_path = os.path.join(images_dir, filename)
            try:
                # Handle different image types
                if image.startswith("data:"):
                    with open(target_path, "wb") as f:
                        f.write(extract_data_uri(image))
                else:
                    # Apply BackendApi settings if needed
                    if BackendApi.working and image.startswith(BackendApi.url):
                        request_headers = BackendApi.headers if headers is None else headers
                        request_ssl = BackendApi.ssl
                    else:
                        request_headers = headers
                        request_ssl = ssl

                    async with session.get(image, ssl=request_ssl, headers=request_headers) as response:
                        response.raise_for_status()
                        media_type = response.headers.get("content-type", "application/octet-stream")
                        if media_type not in ("application/octet-stream", "binary/octet-stream"):
                            if not is_valid_media_type(media_type):
                                raise ValueError(f"Unsupported media type: {media_type}")
                        with open(target_path, "wb") as f:
                            async for chunk in response.content.iter_any():
                                f.write(chunk)

                # Verify file format
                if target is None and not os.path.splitext(target_path)[1]:
                    with open(target_path, "rb") as f:
                        file_header = f.read(12)
                    try:
                        detected_type = is_accepted_format(file_header)
                        if detected_type:
                            new_ext = f".{detected_type.split('/')[-1]}"
                            os.rename(target_path, f"{target_path}{new_ext}")
                            target_path = f"{target_path}{new_ext}"
                    except ValueError:
                        pass

                # Build URL with safe encoding
                url_filename = quote(os.path.basename(target_path))
                return f"/media/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')

            except (ClientError, IOError, OSError, ValueError) as e:
                debug.error(f"Image copying failed: {type(e).__name__}: {e}")
                if target_path and os.path.exists(target_path):
                    os.unlink(target_path)
                return get_source_url(image, image)

        return await asyncio.gather(*[copy_image(img, target) for img in images])
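
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (added for illustration, not part of the module).
# copy_media is a coroutine, so callers elsewhere in the package are expected
# to await it, for example:
#
#     urls = await copy_media(
#         ["https://example.com/cat.png"],  # remote images to mirror locally
#         alt="cat",
#         tags=["demo"],
#     )
#     # -> ["/media/<timestamp>_demo_cat_<hash>.png?url=https%3A//example.com/cat.png"]
#
# Each entry is a relative /media/ URL on success; when a download fails, the
# original source URL is returned instead so the caller can still reference it.
# ---------------------------------------------------------------------------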