copy_images.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. from __future__ import annotations
  2. import os
  3. import time
  4. import uuid
  5. import asyncio
  6. import hashlib
  7. import re
  8. from urllib.parse import quote, unquote
  9. from aiohttp import ClientSession, ClientError
  10. from ..typing import Optional, Cookies
  11. from ..requests.aiohttp import get_connector, StreamResponse
  12. from ..image import EXTENSIONS_MAP, ALLOWED_EXTENSIONS
  13. from ..tools.files import get_bucket_dir
  14. from ..providers.response import ImageResponse, AudioResponse, VideoResponse
  15. from ..Provider.template import BackendApi
  16. from . import is_accepted_format, extract_data_uri
  17. from .. import debug
# Directory (relative to the current working directory) in which copy_media
# stores the images it downloads or decodes from data URIs.
images_dir = "./generated_images"
  20. def get_media_extension(media: str) -> str:
  21. """Extract media file extension from URL or filename"""
  22. match = re.search(r"\.(jpe?g|png|gif|svg|webp|webm|mp4|mp3|wav|flac|opus|ogg|mkv)(?:\?|$)", media, re.IGNORECASE)
  23. return f".{match.group(1).lower()}" if match else ""
  24. def ensure_images_dir():
  25. """Create images directory if it doesn't exist"""
  26. os.makedirs(images_dir, exist_ok=True)
  27. def get_source_url(image: str, default: str = None) -> str:
  28. """Extract original URL from image parameter if present"""
  29. if "url=" in image:
  30. decoded_url = unquote(image.split("url=", 1)[1])
  31. if decoded_url.startswith(("http://", "https://")):
  32. return decoded_url
  33. return default
  34. def secure_filename(filename: str) -> str:
  35. if filename is None:
  36. return None
  37. # Keep letters, numbers, basic punctuation and all Unicode chars
  38. filename = re.sub(
  39. r'[^\w.,_-]+',
  40. '_',
  41. unquote(filename).strip(),
  42. flags=re.UNICODE
  43. )
  44. filename = filename[:100].strip(".,_-")
  45. return filename
  46. async def save_response_media(response: StreamResponse, prompt: str):
  47. content_type = response.headers["content-type"]
  48. if content_type in EXTENSIONS_MAP or content_type.startswith("audio/") or content_type.startswith("video/"):
  49. extension = EXTENSIONS_MAP[content_type] if content_type in EXTENSIONS_MAP else content_type[6:].replace("mpeg", "mp3")
  50. if extension not in ALLOWED_EXTENSIONS:
  51. raise ValueError(f"Unsupported media type: {content_type}")
  52. bucket_id = str(uuid.uuid4())
  53. dirname = str(int(time.time()))
  54. bucket_dir = get_bucket_dir(bucket_id, dirname)
  55. media_dir = os.path.join(bucket_dir, "media")
  56. os.makedirs(media_dir, exist_ok=True)
  57. filename = secure_filename(f"{content_type[0:5] if prompt is None else prompt}.{extension}")
  58. newfile = os.path.join(media_dir, filename)
  59. with open(newfile, 'wb') as f:
  60. async for chunk in response.iter_content() if hasattr(response, "iter_content") else response.content.iter_any():
  61. f.write(chunk)
  62. media_url = f"/files/{dirname}/{bucket_id}/media/{filename}"
  63. if response.method == "GET":
  64. media_url = f"{media_url}?url={str(response.url)}"
  65. if content_type.startswith("audio/"):
  66. yield AudioResponse(media_url)
  67. elif content_type.startswith("video/"):
  68. yield VideoResponse(media_url, prompt)
  69. else:
  70. yield ImageResponse(media_url, prompt)
  71. async def copy_media(
  72. images: list[str],
  73. cookies: Optional[Cookies] = None,
  74. headers: Optional[dict] = None,
  75. proxy: Optional[str] = None,
  76. alt: str = None,
  77. add_url: bool = True,
  78. target: str = None,
  79. ssl: bool = None
  80. ) -> list[str]:
  81. """
  82. Download and store images locally with Unicode-safe filenames
  83. Returns list of relative image URLs
  84. """
  85. if add_url:
  86. add_url = not cookies
  87. ensure_images_dir()
  88. async with ClientSession(
  89. connector=get_connector(proxy=proxy),
  90. cookies=cookies,
  91. headers=headers,
  92. ) as session:
  93. async def copy_image(image: str, target: str = None) -> str:
  94. """Process individual image and return its local URL"""
  95. # Skip if image is already local
  96. if image.startswith("/"):
  97. return image
  98. target_path = target
  99. if target_path is None:
  100. # Build safe filename with full Unicode support
  101. filename = secure_filename("".join((
  102. f"{int(time.time())}_",
  103. (f"{alt}_" if alt else ""),
  104. f"{hashlib.sha256(image.encode()).hexdigest()[:16]}",
  105. f"{get_media_extension(image)}"
  106. )))
  107. target_path = os.path.join(images_dir, filename)
  108. try:
  109. # Handle different image types
  110. if image.startswith("data:"):
  111. with open(target_path, "wb") as f:
  112. f.write(extract_data_uri(image))
  113. else:
  114. # Apply BackendApi settings if needed
  115. if BackendApi.working and image.startswith(BackendApi.url):
  116. request_headers = BackendApi.headers if headers is None else headers
  117. request_ssl = BackendApi.ssl
  118. else:
  119. request_headers = headers
  120. request_ssl = ssl
  121. async with session.get(image, ssl=request_ssl, headers=request_headers) as response:
  122. response.raise_for_status()
  123. with open(target_path, "wb") as f:
  124. async for chunk in response.content.iter_chunked(4096):
  125. f.write(chunk)
  126. # Verify file format
  127. if target is None and not os.path.splitext(target_path)[1]:
  128. with open(target_path, "rb") as f:
  129. file_header = f.read(12)
  130. try:
  131. detected_type = is_accepted_format(file_header)
  132. if detected_type:
  133. new_ext = f".{detected_type.split('/')[-1]}"
  134. os.rename(target_path, f"{target_path}{new_ext}")
  135. target_path = f"{target_path}{new_ext}"
  136. except ValueError:
  137. pass
  138. # Build URL with safe encoding
  139. url_filename = quote(os.path.basename(target_path))
  140. return f"/images/{url_filename}" + (('?url=' + quote(image)) if add_url and not image.startswith('data:') else '')
  141. except (ClientError, IOError, OSError) as e:
  142. debug.error(f"Image copying failed: {type(e).__name__}: {e}")
  143. if target_path and os.path.exists(target_path):
  144. os.unlink(target_path)
  145. return get_source_url(image, image)
  146. return await asyncio.gather(*[copy_image(img, target) for img in images])