image.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. from __future__ import annotations
  2. import re
  3. from io import BytesIO
  4. import base64
  5. from .typing import ImageType, Union, Image
  6. try:
  7. from PIL.Image import open as open_image, new as new_image
  8. from PIL.Image import FLIP_LEFT_RIGHT, ROTATE_180, ROTATE_270, ROTATE_90
  9. has_requirements = True
  10. except ImportError:
  11. has_requirements = False
  12. from .errors import MissingRequirementsError
  13. ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif', 'webp', 'svg'}
  14. def to_image(image: ImageType, is_svg: bool = False) -> Image:
  15. """
  16. Converts the input image to a PIL Image object.
  17. Args:
  18. image (Union[str, bytes, Image]): The input image.
  19. Returns:
  20. Image: The converted PIL Image object.
  21. """
  22. if not has_requirements:
  23. raise MissingRequirementsError('Install "pillow" package for images')
  24. if isinstance(image, str):
  25. is_data_uri_an_image(image)
  26. image = extract_data_uri(image)
  27. if is_svg:
  28. try:
  29. import cairosvg
  30. except ImportError:
  31. raise MissingRequirementsError('Install "cairosvg" package for svg images')
  32. if not isinstance(image, bytes):
  33. image = image.read()
  34. buffer = BytesIO()
  35. cairosvg.svg2png(image, write_to=buffer)
  36. return open_image(buffer)
  37. if isinstance(image, bytes):
  38. is_accepted_format(image)
  39. return open_image(BytesIO(image))
  40. elif not isinstance(image, Image):
  41. image = open_image(image)
  42. image.load()
  43. return image
  44. return image
  45. def is_allowed_extension(filename: str) -> bool:
  46. """
  47. Checks if the given filename has an allowed extension.
  48. Args:
  49. filename (str): The filename to check.
  50. Returns:
  51. bool: True if the extension is allowed, False otherwise.
  52. """
  53. return '.' in filename and \
  54. filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
  55. def is_data_uri_an_image(data_uri: str) -> bool:
  56. """
  57. Checks if the given data URI represents an image.
  58. Args:
  59. data_uri (str): The data URI to check.
  60. Raises:
  61. ValueError: If the data URI is invalid or the image format is not allowed.
  62. """
  63. # Check if the data URI starts with 'data:image' and contains an image format (e.g., jpeg, png, gif)
  64. if not re.match(r'data:image/(\w+);base64,', data_uri):
  65. raise ValueError("Invalid data URI image.")
  66. # Extract the image format from the data URI
  67. image_format = re.match(r'data:image/(\w+);base64,', data_uri).group(1).lower()
  68. # Check if the image format is one of the allowed formats (jpg, jpeg, png, gif)
  69. if image_format not in ALLOWED_EXTENSIONS and image_format != "svg+xml":
  70. raise ValueError("Invalid image format (from mime file type).")
  71. def is_accepted_format(binary_data: bytes) -> str:
  72. """
  73. Checks if the given binary data represents an image with an accepted format.
  74. Args:
  75. binary_data (bytes): The binary data to check.
  76. Raises:
  77. ValueError: If the image format is not allowed.
  78. """
  79. if binary_data.startswith(b'\xFF\xD8\xFF'):
  80. return "image/jpeg"
  81. elif binary_data.startswith(b'\x89PNG\r\n\x1a\n'):
  82. return "image/png"
  83. elif binary_data.startswith(b'GIF87a') or binary_data.startswith(b'GIF89a'):
  84. return "image/gif"
  85. elif binary_data.startswith(b'\x89JFIF') or binary_data.startswith(b'JFIF\x00'):
  86. return "image/jpeg"
  87. elif binary_data.startswith(b'\xFF\xD8'):
  88. return "image/jpeg"
  89. elif binary_data.startswith(b'RIFF') and binary_data[8:12] == b'WEBP':
  90. return "image/webp"
  91. else:
  92. raise ValueError("Invalid image format (from magic code).")
  93. def extract_data_uri(data_uri: str) -> bytes:
  94. """
  95. Extracts the binary data from the given data URI.
  96. Args:
  97. data_uri (str): The data URI.
  98. Returns:
  99. bytes: The extracted binary data.
  100. """
  101. data = data_uri.split(",")[1]
  102. data = base64.b64decode(data)
  103. return data
  104. def get_orientation(image: Image) -> int:
  105. """
  106. Gets the orientation of the given image.
  107. Args:
  108. image (Image): The image.
  109. Returns:
  110. int: The orientation value.
  111. """
  112. exif_data = image.getexif() if hasattr(image, 'getexif') else image._getexif()
  113. if exif_data is not None:
  114. orientation = exif_data.get(274) # 274 corresponds to the orientation tag in EXIF
  115. if orientation is not None:
  116. return orientation
  117. def process_image(image: Image, new_width: int, new_height: int) -> Image:
  118. """
  119. Processes the given image by adjusting its orientation and resizing it.
  120. Args:
  121. image (Image): The image to process.
  122. new_width (int): The new width of the image.
  123. new_height (int): The new height of the image.
  124. Returns:
  125. Image: The processed image.
  126. """
  127. # Fix orientation
  128. orientation = get_orientation(image)
  129. if orientation:
  130. if orientation > 4:
  131. image = image.transpose(FLIP_LEFT_RIGHT)
  132. if orientation in [3, 4]:
  133. image = image.transpose(ROTATE_180)
  134. if orientation in [5, 6]:
  135. image = image.transpose(ROTATE_270)
  136. if orientation in [7, 8]:
  137. image = image.transpose(ROTATE_90)
  138. # Resize image
  139. image.thumbnail((new_width, new_height))
  140. # Remove transparency
  141. if image.mode == "RGBA":
  142. image.load()
  143. white = new_image('RGB', image.size, (255, 255, 255))
  144. white.paste(image, mask=image.split()[-1])
  145. return white
  146. # Convert to RGB for jpg format
  147. elif image.mode != "RGB":
  148. image = image.convert("RGB")
  149. return image
  150. def to_base64_jpg(image: Image, compression_rate: float) -> str:
  151. """
  152. Converts the given image to a base64-encoded string.
  153. Args:
  154. image (Image.Image): The image to convert.
  155. compression_rate (float): The compression rate (0.0 to 1.0).
  156. Returns:
  157. str: The base64-encoded image.
  158. """
  159. output_buffer = BytesIO()
  160. image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
  161. return base64.b64encode(output_buffer.getvalue()).decode()
  162. def format_images_markdown(images: Union[str, list], alt: str, preview: Union[str, list] = None) -> str:
  163. """
  164. Formats the given images as a markdown string.
  165. Args:
  166. images: The images to format.
  167. alt (str): The alt for the images.
  168. preview (str, optional): The preview URL format. Defaults to "{image}?w=200&h=200".
  169. Returns:
  170. str: The formatted markdown string.
  171. """
  172. if isinstance(images, str):
  173. result = f"[![{alt}]({preview.replace('{image}', images) if preview else images})]({images})"
  174. else:
  175. if not isinstance(preview, list):
  176. preview = [preview.replace('{image}', image) if preview else image for image in images]
  177. result = "\n".join(
  178. f"[![#{idx+1} {alt}]({preview[idx]})]({image})"
  179. #f'[<img src="{preview[idx]}" width="200" alt="#{idx+1} {alt}">]({image})'
  180. for idx, image in enumerate(images)
  181. )
  182. start_flag = "<!-- generated images start -->\n"
  183. end_flag = "<!-- generated images end -->\n"
  184. return f"\n{start_flag}{result}\n{end_flag}\n"
  185. def to_bytes(image: ImageType) -> bytes:
  186. """
  187. Converts the given image to bytes.
  188. Args:
  189. image (ImageType): The image to convert.
  190. Returns:
  191. bytes: The image as bytes.
  192. """
  193. if isinstance(image, bytes):
  194. return image
  195. elif isinstance(image, str):
  196. is_data_uri_an_image(image)
  197. return extract_data_uri(image)
  198. elif isinstance(image, Image):
  199. bytes_io = BytesIO()
  200. image.save(bytes_io, image.format)
  201. image.seek(0)
  202. return bytes_io.getvalue()
  203. else:
  204. return image.read()
  205. def to_data_uri(image: ImageType) -> str:
  206. if not isinstance(image, str):
  207. data = to_bytes(image)
  208. data_base64 = base64.b64encode(data).decode()
  209. return f"data:{is_accepted_format(data)};base64,{data_base64}"
  210. return image
  211. class ImageResponse:
  212. def __init__(
  213. self,
  214. images: Union[str, list],
  215. alt: str,
  216. options: dict = {}
  217. ):
  218. self.images = images
  219. self.alt = alt
  220. self.options = options
  221. def __str__(self) -> str:
  222. return format_images_markdown(self.images, self.alt, self.get("preview"))
  223. def get(self, key: str):
  224. return self.options.get(key)
  225. def get_list(self) -> list[str]:
  226. return [self.images] if isinstance(self.images, str) else self.images
  227. class ImagePreview(ImageResponse):
  228. def __str__(self):
  229. return ""
  230. def to_string(self):
  231. return super().__str__()
  232. class ImageDataResponse():
  233. def __init__(
  234. self,
  235. images: Union[str, list],
  236. alt: str,
  237. ):
  238. self.images = images
  239. self.alt = alt
  240. def get_list(self) -> list[str]:
  241. return [self.images] if isinstance(self.images, str) else self.images
  242. class ImageRequest:
  243. def __init__(
  244. self,
  245. options: dict = {}
  246. ):
  247. self.options = options
  248. def get(self, key: str):
  249. return self.options.get(key)