media.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. from __future__ import annotations
  2. import os
  3. import base64
  4. from typing import Iterator, Union
  5. from pathlib import Path
  6. from ..typing import Messages
  7. from ..image import is_data_an_media, is_data_an_audio, to_input_audio, to_data_uri
  8. from .files import get_bucket_dir, read_bucket
  9. def render_media(bucket_id: str, name: str, url: str, as_path: bool = False, as_base64: bool = False) -> Union[str, Path]:
  10. if (as_base64 or as_path or url.startswith("/")):
  11. file = Path(get_bucket_dir(bucket_id, "media", name))
  12. if as_path:
  13. return file
  14. data = file.read_bytes()
  15. data_base64 = base64.b64encode(data).decode()
  16. if as_base64:
  17. return data_base64
  18. return f"data:{is_data_an_media(data, name)};base64,{data_base64}"
  19. return url
  20. def render_part(part: dict) -> dict:
  21. if "type" in part:
  22. return part
  23. filename = part.get("name")
  24. if (filename is None):
  25. bucket_dir = Path(get_bucket_dir(part.get("bucket_id")))
  26. return {
  27. "type": "text",
  28. "text": "".join(read_bucket(bucket_dir))
  29. }
  30. if is_data_an_audio(filename=filename):
  31. return {
  32. "type": "input_audio",
  33. "input_audio": {
  34. "data": render_media(**part, as_base64=True),
  35. "format": os.path.splitext(filename)[1][1:]
  36. }
  37. }
  38. return {
  39. "type": "image_url",
  40. "image_url": {"url": render_media(**part)}
  41. }
  42. def merge_media(media: list, messages: list) -> Iterator:
  43. buffer = []
  44. for message in messages:
  45. if message.get("role") == "user":
  46. content = message.get("content")
  47. if isinstance(content, list):
  48. for part in content:
  49. if "type" not in part and "name" in part:
  50. path = render_media(**part, as_path=True)
  51. buffer.append((path, os.path.basename(path)))
  52. elif part.get("type") == "image_url":
  53. buffer.append((part.get("image_url"), None))
  54. else:
  55. buffer = []
  56. yield from buffer
  57. if media is not None:
  58. yield from media
  59. def render_messages(messages: Messages, media: list = None) -> Iterator:
  60. for idx, message in enumerate(messages):
  61. if isinstance(message["content"], list):
  62. yield {
  63. **message,
  64. "content": [render_part(part) for part in message["content"] if part]
  65. }
  66. else:
  67. if media is not None and idx == len(messages) - 1:
  68. yield {
  69. **message,
  70. "content": [
  71. {
  72. "type": "input_audio",
  73. "input_audio": to_input_audio(media_data, filename)
  74. }
  75. if is_data_an_audio(media_data, filename) else {
  76. "type": "image_url",
  77. "image_url": {"url": to_data_uri(media_data)}
  78. }
  79. for media_data, filename in media
  80. ] + ([{"type": "text", "text": message["content"]}] if isinstance(message["content"], str) else message["content"])
  81. }
  82. else:
  83. yield message