__init__.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. from __future__ import annotations
  2. import os
  3. import re
  4. import io
  5. import base64
  6. from io import BytesIO
  7. from pathlib import Path
  8. try:
  9. from PIL.Image import open as open_image, new as new_image
  10. from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
  11. has_requirements = True
  12. except ImportError:
  13. has_requirements = False
  14. from ..providers.helper import filter_none
  15. from ..typing import ImageType, Union, Image
  16. from ..errors import MissingRequirementsError
  17. ALLOWED_EXTENSIONS = {
  18. # Image
  19. 'png', 'jpg', 'jpeg', 'gif', 'webp',
  20. # Audio
  21. 'wav', 'mp3', 'flac', 'opus', 'ogg',
  22. # Video
  23. 'mkv', 'webm', 'mp4'
  24. }
  25. MEDIA_TYPE_MAP: dict[str, str] = {
  26. "image/png": "png",
  27. "image/jpeg": "jpg",
  28. "image/gif": "gif",
  29. "image/webp": "webp",
  30. }
  31. EXTENSIONS_MAP: dict[str, str] = {
  32. # Image
  33. "png": "image/png",
  34. "jpg": "image/jpeg",
  35. "jpeg": "image/jpeg",
  36. "gif": "image/gif",
  37. "webp": "image/webp",
  38. # Audio
  39. "wav": "audio/wav",
  40. "mp3": "audio/mpeg",
  41. "flac": "audio/flac",
  42. "opus": "audio/opus",
  43. "ogg": "audio/ogg",
  44. # Video
  45. "mkv": "video/x-matroska",
  46. "webm": "video/webm",
  47. "mp4": "video/mp4",
  48. }
  49. def to_image(image: ImageType, is_svg: bool = False) -> Image:
  50. """
  51. Converts the input image to a PIL Image object.
  52. Args:
  53. image (Union[str, bytes, Image]): The input image.
  54. Returns:
  55. Image: The converted PIL Image object.
  56. """
  57. if not has_requirements:
  58. raise MissingRequirementsError('Install "pillow" package for images')
  59. if isinstance(image, str) and image.startswith("data:"):
  60. is_data_uri_an_image(image)
  61. image = extract_data_uri(image)
  62. if is_svg:
  63. try:
  64. import cairosvg
  65. except ImportError:
  66. raise MissingRequirementsError('Install "cairosvg" package for svg images')
  67. if not isinstance(image, bytes):
  68. image = image.read()
  69. buffer = BytesIO()
  70. cairosvg.svg2png(image, write_to=buffer)
  71. return open_image(buffer)
  72. if isinstance(image, bytes):
  73. is_accepted_format(image)
  74. return open_image(BytesIO(image))
  75. elif not isinstance(image, Image):
  76. image = open_image(image)
  77. image.load()
  78. return image
  79. return image
  80. def is_allowed_extension(filename: str) -> bool:
  81. """
  82. Checks if the given filename has an allowed extension.
  83. Args:
  84. filename (str): The filename to check.
  85. Returns:
  86. bool: True if the extension is allowed, False otherwise.
  87. """
  88. return '.' in filename and \
  89. filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  90. def is_data_an_media(data, filename: str = None) -> str:
  91. content_type = is_data_an_audio(data, filename)
  92. if content_type is not None:
  93. return content_type
  94. if isinstance(data, bytes):
  95. return is_accepted_format(data)
  96. return is_data_uri_an_image(data)
  97. def is_data_an_audio(data_uri: str, filename: str = None) -> str:
  98. if filename:
  99. if filename.endswith(".wav"):
  100. return "audio/wav"
  101. elif filename.endswith(".mp3"):
  102. return "audio/mpeg"
  103. if isinstance(data_uri, str):
  104. audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
  105. if audio_format:
  106. return audio_format.group(1)
  107. def is_data_uri_an_image(data_uri: str) -> bool:
  108. """
  109. Checks if the given data URI represents an image.
  110. Args:
  111. data_uri (str): The data URI to check.
  112. Raises:
  113. ValueError: If the data URI is invalid or the image format is not allowed.
  114. """
  115. # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
  116. if not re.match(r'data:image/(\w+);base64,', data_uri):
  117. raise ValueError("Invalid data URI image.")
  118. # Extract the image format from the data URI
  119. image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
  120. # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
  121. if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
  122. raise ValueError("Invalid image format (from mime file type).")
  123. def is_accepted_format(binary_data: bytes) -> str:
  124. """
  125. Checks if the given binary data represents an image with an accepted format.
  126. Args:
  127. binary_data (bytes): The binary data to check.
  128. Raises:
  129. ValueError: If the image format is not allowed.
  130. """
  131. if binary_data.startswith(b'\xFF\xD8\xFF'):
  132. return "image/jpeg"
  133. elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
  134. return "image/png"
  135. elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
  136. return "image/gif"
  137. elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
  138. return "image/jpeg"
  139. elif binary_data.startswith(b'\xFF\xD8'):
  140. return "image/jpeg"
  141. elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
  142. return "image/webp"
  143. else:
  144. raise ValueError("Invalid image format (from magic code).")
  145. def extract_data_uri(data_uri: str) -> bytes:
  146. """
  147. Extracts the binary data from the given data URI.
  148. Args:
  149. data_uri (str): The data URI.
  150. Returns:
  151. bytes: The extracted binary data.
  152. """
  153. data = data_uri.split(",")[-1]
  154. data = base64.b64decode(data)
  155. return data
  156. def get_orientation(image: Image) -> int:
  157. """
  158. Gets the orientation of the given image.
  159. Args:
  160. image (Image): The image.
  161. Returns:
  162. int: The orientation value.
  163. """
  164. exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
  165. if exif_data is not None:
  166. orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
  167. if orientation is not None:
  168. return orientation
  169. def process_image(image: Image, new_width: int, new_height: int) -> Image:
  170. """
  171. Processes the given image by adjusting its orientation and resizing it.
  172. Args:
  173. image (Image): The image to process.
  174. new_width (int): The new width of the image.
  175. new_height (int): The new height of the image.
  176. Returns:
  177. Image: The processed image.
  178. """
  179. # Fix orientation
  180. orientation = get_orientation(image)
  181. if orientation:
  182. if orientation > 4:
  183. image = image.transpose(FLIP_LEFT_RIGHT)
  184. if orientation in [3, 4]:
  185. image = image.transpose(ROTATE_180)
  186. if orientation in [5, 6]:
  187. image = image.transpose(ROTATE_270)
  188. if orientation in [7, 8]:
  189. image = image.transpose(ROTATE_90)
  190. # Resize image
  191. image.thumbnail((new_width, new_height))
  192. # Remove transparency
  193. if image.mode == "RGBA":
  194. image.load()
  195. white = new_image('RGB', image.size, (255, 255, 255))
  196. white.paste(image, mask=image.split()[-1])
  197. return white
  198. # Convert to RGB for jpg format
  199. elif image.mode != "RGB":
  200. image = image.convert("RGB")
  201. return image
  202. def to_bytes(image: ImageType) -> bytes:
  203. """
  204. Converts the given image to bytes.
  205. Args:
  206. image (ImageType): The image to convert.
  207. Returns:
  208. bytes: The image as bytes.
  209. """
  210. if isinstance(image, bytes):
  211. return image
  212. elif isinstance(image, str) and image.startswith("data:"):
  213. is_data_an_media(image)
  214. return extract_data_uri(image)
  215. elif isinstance(image, Image):
  216. bytes_io = BytesIO()
  217. image.save(bytes_io, image.format)
  218. image.seek(0)
  219. return bytes_io.getvalue()
  220. elif isinstance(image, (str, os.PathLike)):
  221. return Path(image).read_bytes()
  222. elif isinstance(image, Path):
  223. return image.read_bytes()
  224. else:
  225. try:
  226. image.seek(0)
  227. except (AttributeError, io.UnsupportedOperation):
  228. pass
  229. return image.read()
  230. def to_data_uri(image: ImageType, filename: str = None) -> str:
  231. if not isinstance(image, str):
  232. data = to_bytes(image)
  233. data_base64 = base64.b64encode(data).decode()
  234. return f"data:{is_data_an_media(data, filename)};base64,{data_base64}"
  235. return image
  236. def to_input_audio(audio: ImageType, filename: str = None) -> str:
  237. if not isinstance(audio, str):
  238. if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
  239. return {
  240. "data": base64.b64encode(to_bytes(audio)).decode(),
  241. "format": "wav" if filename.endswith(".wav") else "mp3"
  242. }
  243. raise ValueError("Invalid input audio")
  244. audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
  245. if audio:
  246. return {
  247. "data": audio.group(2),
  248. "format": audio.group(1).replace("mpeg", "mp3")
  249. }
  250. raise ValueError("Invalid input audio")
  251. def use_aspect_ratio(extra_data: dict, aspect_ratio: str) -> Image:
  252. extra_data = filter_none(**extra_data)
  253. if aspect_ratio == "1:1":
  254. extra_data = {
  255. "width": 1024,
  256. "height": 1024,
  257. **extra_data
  258. }
  259. elif aspect_ratio == "16:9":
  260. extra_data = {
  261. "width": 800,
  262. "height": 512,
  263. **extra_data
  264. }
  265. elif aspect_ratio == "9:16":
  266. extra_data = {
  267. "width": 512,
  268. "height": 800,
  269. **extra_data
  270. }
  271. return extra_data
  272. class ImageDataResponse():
  273. def __init__(
  274. self,
  275. images: Union[str, list],
  276. alt: str,
  277. ):
  278. self.images = images
  279. self.alt = alt
  280. def get_list(self) -> list[str]:
  281. return [self.images] if isinstance(self.images, str) else self.images
  282. class ImageRequest:
  283. def __init__(
  284. self,
  285. options: dict = {}
  286. ):
  287. self.options = options
  288. def get(self, key: str):
  289. return self.options.get(key)