helper.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from __future__ import annotations
  2. import random
  3. import string
  4. from pathlib import Path
  5. from ..typing import Messages, Cookies, AsyncIterator, Iterator
  6. from ..tools.files import get_bucket_dir, read_bucket
  7. from .. import debug
  8. def to_string(value) -> str:
  9. if isinstance(value, str):
  10. return value
  11. elif isinstance(value, dict):
  12. if "name" in value:
  13. return ""
  14. elif "bucket_id" in value:
  15. bucket_dir = Path(get_bucket_dir(value.get("bucket_id")))
  16. return "".join(read_bucket(bucket_dir))
  17. elif value.get("type") == "text":
  18. return value.get("text")
  19. return ""
  20. elif isinstance(value, list):
  21. return "".join([to_string(v) for v in value if v.get("type", "text") == "text"])
  22. elif value is None:
  23. return ""
  24. return str(value)
  25. def render_messages(messages: Messages) -> Iterator:
  26. for idx, message in enumerate(messages):
  27. if isinstance(message, dict) and isinstance(message.get("content"), list):
  28. yield {
  29. **message,
  30. "content": to_string(message["content"]),
  31. }
  32. else:
  33. yield message
  34. def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str:
  35. """
  36. Format a series of messages into a single string, optionally adding special tokens.
  37. Args:
  38. messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
  39. add_special_tokens (bool): Whether to add special formatting tokens.
  40. Returns:
  41. str: A formatted string containing all messages.
  42. """
  43. if not add_special_tokens and len(messages) <= 1:
  44. return to_string(messages[0]["content"])
  45. messages = [
  46. (message["role"], to_string(message["content"]))
  47. for message in messages
  48. if include_system or message.get("role") != "system"
  49. ]
  50. formatted = "\n".join([
  51. f'{role.capitalize()}: {content}'
  52. for role, content in messages
  53. if content.strip()
  54. ])
  55. if do_continue:
  56. return formatted
  57. return f"{formatted}\nAssistant:"
  58. def get_system_prompt(messages: Messages) -> str:
  59. return "\n".join([m["content"] for m in messages if m["role"] == "system"])
  60. def get_last_user_message(messages: Messages) -> str:
  61. user_messages = []
  62. last_message = None if len(messages) == 0 else messages[-1]
  63. messages = messages.copy()
  64. while last_message is not None and messages:
  65. last_message = messages.pop()
  66. if last_message["role"] == "user":
  67. content = to_string(last_message.get("content")).strip()
  68. if content:
  69. user_messages.append(content)
  70. else:
  71. return "\n".join(user_messages[::-1])
  72. return "\n".join(user_messages[::-1])
  73. def get_last_message(messages: Messages, prompt: str = None) -> str:
  74. if prompt is None:
  75. for message in messages[::-1]:
  76. content = to_string(message.get("content")).strip()
  77. if content:
  78. prompt = content
  79. return prompt
  80. def format_image_prompt(messages, prompt: str = None) -> str:
  81. if prompt is None:
  82. return get_last_user_message(messages)
  83. return prompt
  84. def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
  85. prompt = format_prompt(messages)
  86. start = len(prompt)
  87. if start > max_lenght:
  88. if len(messages) > 6:
  89. prompt = format_prompt(messages[:3] + messages[-3:])
  90. if len(prompt) > max_lenght:
  91. if len(messages) > 2:
  92. prompt = format_prompt([m for m in messages if m["role"] == "system"] + messages[-1:])
  93. if len(prompt) > max_lenght:
  94. prompt = messages[-1]["content"]
  95. debug.log(f"Messages trimmed from: {start} to: {len(prompt)}")
  96. return prompt
  97. def get_random_string(length: int = 10) -> str:
  98. """
  99. Generate a random string of specified length, containing lowercase letters and digits.
  100. Args:
  101. length (int, optional): Length of the random string to generate. Defaults to 10.
  102. Returns:
  103. str: A random string of the specified length.
  104. """
  105. return ''.join(
  106. random.choice(string.ascii_lowercase + string.digits)
  107. for _ in range(length)
  108. )
  109. def get_random_hex(length: int = 32) -> str:
  110. """
  111. Generate a random hexadecimal string with n length.
  112. Returns:
  113. str: A random hexadecimal string of n characters.
  114. """
  115. return ''.join(
  116. random.choice("abcdef" + string.digits)
  117. for _ in range(length)
  118. )
  119. def filter_none(**kwargs) -> dict:
  120. return {
  121. key: value
  122. for key, value in kwargs.items()
  123. if value is not None
  124. }
  125. async def async_concat_chunks(chunks: AsyncIterator) -> str:
  126. return concat_chunks([chunk async for chunk in chunks])
  127. def concat_chunks(chunks: Iterator) -> str:
  128. return "".join([
  129. str(chunk) for chunk in chunks
  130. if chunk and not isinstance(chunk, Exception)
  131. ])
  132. def format_cookies(cookies: Cookies) -> str:
  133. return "; ".join([f"{k}={v}" for k, v in cookies.items()])