__init__.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312
  1. from __future__ import annotations
  2. import os
  3. import re
  4. import io
  5. import base64
  6. from io import BytesIO
  7. from pathlib import Path
  8. try:
  9. from PIL.Image import open as open_image, new as new_image
  10. from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
  11. has_requirements = True
  12. except ImportError:
  13. has_requirements = False
  14. from ..providers.helper import filter_none
  15. from ..typing import ImageType, Union, Image
  16. from ..errors import MissingRequirementsError
  17. ALLOWED_EXTENSIONS = {
  18. # Image
  19. 'png', 'jpg', 'jpeg', 'gif', 'webp',
  20. # Audio
  21. 'wav', 'mp3', 'flac', 'opus', 'ogg',
  22. # Video
  23. 'mkv', 'webm', 'mp4'
  24. }
  25. EXTENSIONS_MAP: dict[str, str] = {
  26. "image/png": "png",
  27. "image/jpeg": "jpg",
  28. "image/gif": "gif",
  29. "image/webp": "webp",
  30. }
  31. def to_image(image: ImageType, is_svg: bool = False) -> Image:
  32. """
  33. Converts the input image to a PIL Image object.
  34. Args:
  35. image (Union[str, bytes, Image]): The input image.
  36. Returns:
  37. Image: The converted PIL Image object.
  38. """
  39. if not has_requirements:
  40. raise MissingRequirementsError('Install "pillow" package for images')
  41. if isinstance(image, str) and image.startswith("data:"):
  42. is_data_uri_an_image(image)
  43. image = extract_data_uri(image)
  44. if is_svg:
  45. try:
  46. import cairosvg
  47. except ImportError:
  48. raise MissingRequirementsError('Install "cairosvg" package for svg images')
  49. if not isinstance(image, bytes):
  50. image = image.read()
  51. buffer = BytesIO()
  52. cairosvg.svg2png(image, write_to=buffer)
  53. return open_image(buffer)
  54. if isinstance(image, bytes):
  55. is_accepted_format(image)
  56. return open_image(BytesIO(image))
  57. elif not isinstance(image, Image):
  58. image = open_image(image)
  59. image.load()
  60. return image
  61. return image
  62. def is_allowed_extension(filename: str) -> bool:
  63. """
  64. Checks if the given filename has an allowed extension.
  65. Args:
  66. filename (str): The filename to check.
  67. Returns:
  68. bool: True if the extension is allowed, False otherwise.
  69. """
  70. return '.' in filename and \
  71. filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  72. def is_data_an_media(data, filename: str = None) -> str:
  73. content_type = is_data_an_audio(data, filename)
  74. if content_type is not None:
  75. return content_type
  76. if isinstance(data, bytes):
  77. return is_accepted_format(data)
  78. return is_data_uri_an_image(data)
  79. def is_data_an_audio(data_uri: str, filename: str = None) -> str:
  80. if filename:
  81. if filename.endswith(".wav"):
  82. return "audio/wav"
  83. elif filename.endswith(".mp3"):
  84. return "audio/mpeg"
  85. if isinstance(data_uri, str):
  86. audio_format = re.match(r'^data:(audio/\w+);base64,', data_uri)
  87. if audio_format:
  88. return audio_format.group(1)
  89. def is_data_uri_an_image(data_uri: str) -> bool:
  90. """
  91. Checks if the given data URI represents an image.
  92. Args:
  93. data_uri (str): The data URI to check.
  94. Raises:
  95. ValueError: If the data URI is invalid or the image format is not allowed.
  96. """
  97. # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
  98. if not re.match(r'data:image/(\w+);base64,', data_uri):
  99. raise ValueError("Invalid data URI image.")
  100. # Extract the image format from the data URI
  101. image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
  102. # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
  103. if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
  104. raise ValueError("Invalid image format (from mime file type).")
  105. def is_accepted_format(binary_data: bytes) -> str:
  106. """
  107. Checks if the given binary data represents an image with an accepted format.
  108. Args:
  109. binary_data (bytes): The binary data to check.
  110. Raises:
  111. ValueError: If the image format is not allowed.
  112. """
  113. if binary_data.startswith(b'\xFF\xD8\xFF'):
  114. return "image/jpeg"
  115. elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
  116. return "image/png"
  117. elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
  118. return "image/gif"
  119. elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
  120. return "image/jpeg"
  121. elif binary_data.startswith(b'\xFF\xD8'):
  122. return "image/jpeg"
  123. elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
  124. return "image/webp"
  125. else:
  126. raise ValueError("Invalid image format (from magic code).")
  127. def extract_data_uri(data_uri: str) -> bytes:
  128. """
  129. Extracts the binary data from the given data URI.
  130. Args:
  131. data_uri (str): The data URI.
  132. Returns:
  133. bytes: The extracted binary data.
  134. """
  135. data = data_uri.split(",")[-1]
  136. data = base64.b64decode(data)
  137. return data
  138. def get_orientation(image: Image) -> int:
  139. """
  140. Gets the orientation of the given image.
  141. Args:
  142. image (Image): The image.
  143. Returns:
  144. int: The orientation value.
  145. """
  146. exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
  147. if exif_data is not None:
  148. orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
  149. if orientation is not None:
  150. return orientation
  151. def process_image(image: Image, new_width: int, new_height: int) -> Image:
  152. """
  153. Processes the given image by adjusting its orientation and resizing it.
  154. Args:
  155. image (Image): The image to process.
  156. new_width (int): The new width of the image.
  157. new_height (int): The new height of the image.
  158. Returns:
  159. Image: The processed image.
  160. """
  161. # Fix orientation
  162. orientation = get_orientation(image)
  163. if orientation:
  164. if orientation > 4:
  165. image = image.transpose(FLIP_LEFT_RIGHT)
  166. if orientation in [3, 4]:
  167. image = image.transpose(ROTATE_180)
  168. if orientation in [5, 6]:
  169. image = image.transpose(ROTATE_270)
  170. if orientation in [7, 8]:
  171. image = image.transpose(ROTATE_90)
  172. # Resize image
  173. image.thumbnail((new_width, new_height))
  174. # Remove transparency
  175. if image.mode == "RGBA":
  176. image.load()
  177. white = new_image('RGB', image.size, (255, 255, 255))
  178. white.paste(image, mask=image.split()[-1])
  179. return white
  180. # Convert to RGB for jpg format
  181. elif image.mode != "RGB":
  182. image = image.convert("RGB")
  183. return image
  184. def to_bytes(image: ImageType) -> bytes:
  185. """
  186. Converts the given image to bytes.
  187. Args:
  188. image (ImageType): The image to convert.
  189. Returns:
  190. bytes: The image as bytes.
  191. """
  192. if isinstance(image, bytes):
  193. return image
  194. elif isinstance(image, str) and image.startswith("data:"):
  195. is_data_an_media(image)
  196. return extract_data_uri(image)
  197. elif isinstance(image, Image):
  198. bytes_io = BytesIO()
  199. image.save(bytes_io, image.format)
  200. image.seek(0)
  201. return bytes_io.getvalue()
  202. elif isinstance(image, (str, os.PathLike)):
  203. return Path(image).read_bytes()
  204. elif isinstance(image, Path):
  205. return image.read_bytes()
  206. else:
  207. try:
  208. image.seek(0)
  209. except (AttributeError, io.UnsupportedOperation):
  210. pass
  211. return image.read()
  212. def to_data_uri(image: ImageType, filename: str = None) -> str:
  213. if not isinstance(image, str):
  214. data = to_bytes(image)
  215. data_base64 = base64.b64encode(data).decode()
  216. return f"data:{is_data_an_media(data, filename)};base64,{data_base64}"
  217. return image
  218. def to_input_audio(audio: ImageType, filename: str = None) -> str:
  219. if not isinstance(audio, str):
  220. if filename is not None and (filename.endswith(".wav") or filename.endswith(".mp3")):
  221. return {
  222. "data": base64.b64encode(to_bytes(audio)).decode(),
  223. "format": "wav" if filename.endswith(".wav") else "mp3"
  224. }
  225. raise ValueError("Invalid input audio")
  226. audio = re.match(r'^data:audio/(\w+);base64,(.+?)', audio)
  227. if audio:
  228. return {
  229. "data": audio.group(2),
  230. "format": audio.group(1).replace("mpeg", "mp3")
  231. }
  232. raise ValueError("Invalid input audio")
  233. def use_aspect_ratio(extra_data: dict, aspect_ratio: str) -> Image:
  234. extra_data = filter_none(**extra_data)
  235. if aspect_ratio == "1:1":
  236. extra_data = {
  237. "width": 1024,
  238. "height": 1024,
  239. **extra_data
  240. }
  241. elif aspect_ratio == "16:9":
  242. extra_data = {
  243. "width": 800,
  244. "height": 512,
  245. **extra_data
  246. }
  247. elif aspect_ratio == "9:16":
  248. extra_data = {
  249. "width": 512,
  250. "height": 800,
  251. **extra_data
  252. }
  253. return extra_data
  254. class ImageDataResponse():
  255. def __init__(
  256. self,
  257. images: Union[str, list],
  258. alt: str,
  259. ):
  260. self.images = images
  261. self.alt = alt
  262. def get_list(self) -> list[str]:
  263. return [self.images] if isinstance(self.images, str) else self.images
  264. class ImageRequest:
  265. def __init__(
  266. self,
  267. options: dict = {}
  268. ):
  269. self.options = options
  270. def get(self, key: str):
  271. return self.options.get(key)