helper.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from __future__ import annotations
  2. import random
  3. import string
  4. from ..typing import Messages, Cookies, AsyncIterator, Iterator
  5. from .. import debug
  def format_prompt(messages: Messages, add_special_tokens: bool = False, do_continue: bool = False, include_system: bool = True) -> str:
      """
      Format a series of messages into a single "Role: content" transcript.

      Args:
          messages (Messages): A list of message dictionaries, each containing 'role' and 'content'.
          add_special_tokens (bool): When False and there is at most one message,
              that message's raw content is returned with no role prefix and no
              "Assistant:" suffix. When True, the transcript format is always used.
          do_continue (bool): When True, omit the trailing "Assistant:" line.
          include_system (bool): When False, messages with role "system" are left
              out of the transcript. Not applied in the single-message shortcut.

      Returns:
          str: The formatted transcript. An empty ``messages`` list yields ""
          when ``add_special_tokens`` is False.
      """
      if not add_special_tokens and len(messages) <= 1:
          # Single message: pass its content through untouched. The empty case
          # previously crashed with IndexError on messages[0].
          return messages[0]["content"] if messages else ""
      formatted = "\n".join(
          f'{message["role"].capitalize()}: {message["content"]}'
          for message in messages
          if include_system or message["role"] != "system"
      )
      if do_continue:
          return formatted
      return f"{formatted}\nAssistant:"
  def get_system_prompt(messages: Messages) -> str:
      """Join the content of every "system" message, one per line, in original order."""
      system_parts = (msg["content"] for msg in messages if msg["role"] == "system")
      return "\n".join(system_parts)
  def get_last_user_message(messages: Messages) -> str:
      """
      Return the trailing run of user messages, joined by newlines.

      Walks ``messages`` backwards from the end and collects consecutive
      messages whose role is "user". It stops at the first message with any
      other role. Non-string content is skipped, and each collected string is
      stripped. The result is in chronological order.

      The input list is left unmodified; an earlier version popped items off
      the caller's list.

      Args:
          messages (Messages): A list of message dictionaries with 'role' and 'content'.

      Returns:
          str: The joined user text, or "" if ``messages`` is empty or the last
          message is not from the user.
      """
      user_messages = []
      for message in reversed(messages):
          if message["role"] != "user":
              break
          if isinstance(message["content"], str):
              user_messages.append(message["content"].strip())
      # Collected newest-first; restore chronological order.
      return "\n".join(reversed(user_messages))
  def format_image_prompt(messages: Messages, prompt: str | None = None) -> str:
      """
      Choose the prompt text for an image request.

      Args:
          messages (Messages): Conversation used as the fallback prompt source.
          prompt (str | None): Explicit prompt. Returned unchanged whenever it is
              not None, including when it is an empty string.

      Returns:
          str: ``prompt`` if given, otherwise the result of
          ``get_last_user_message(messages)``.
      """
      if prompt is not None:
          return prompt
      return get_last_user_message(messages)
  def format_prompt_max_length(messages: Messages, max_lenght: int) -> str:
      """
      Format ``messages`` with ``format_prompt``, trimming history if it is too long.

      Each stage is tried only if the previous result is still longer than
      ``max_lenght`` characters:

      1. The full transcript of all messages.
      2. If there are more than 6 messages: the first 3 plus the last 3.
      3. If there are more than 2 messages: the system messages plus the last message.
      4. The raw content of the last message alone.

      The last stage is not truncated, so the result can still exceed
      ``max_lenght``. When any trimming happens, the before/after lengths are
      logged through ``debug.log``.

      Args:
          messages (Messages): A list of message dictionaries with 'role' and 'content'.
          max_lenght (int): Desired maximum prompt length in characters. The
              misspelled name is kept because callers may pass it by keyword.

      Returns:
          str: The formatted, possibly trimmed, prompt.
      """
      prompt = format_prompt(messages)
      start = len(prompt)
      if start > max_lenght:
          if len(messages) > 6:
              prompt = format_prompt(messages[:3] + messages[-3:])
          if len(prompt) > max_lenght:
              if len(messages) > 2:
                  # Filter system messages from everything except the last
                  # message, so that a trailing system message is not
                  # included twice.
                  system_messages = [m for m in messages[:-1] if m["role"] == "system"]
                  prompt = format_prompt(system_messages + messages[-1:])
              if len(prompt) > max_lenght:
                  prompt = messages[-1]["content"]
          debug.log(f"Messages trimmed from: {start} to: {len(prompt)}")
      return prompt
  def get_random_string(length: int = 10) -> str:
      """
      Build a random string of lowercase ASCII letters and digits.

      Uses the ``random`` module, so the result is not suitable for
      security-sensitive tokens.

      Args:
          length (int, optional): Number of characters to produce. Defaults to 10.

      Returns:
          str: ``length`` random characters drawn from ``[a-z0-9]``.
      """
      alphabet = string.ascii_lowercase + string.digits
      picked = [random.choice(alphabet) for _ in range(length)]
      return "".join(picked)
  def get_random_hex(length: int = 32) -> str:
      """
      Build a random lowercase hexadecimal string.

      Args:
          length (int, optional): Number of hex digits to produce. Defaults to 32.

      Returns:
          str: ``length`` random characters drawn from ``[a-f0-9]``.
      """
      hex_alphabet = "abcdef" + string.digits
      hex_digits = [random.choice(hex_alphabet) for _ in range(length)]
      return "".join(hex_digits)
  def filter_none(**kwargs) -> dict:
      """Return the keyword arguments as a dict, dropping entries whose value is None.

      Other falsy values (0, "", False) are kept. Argument order is preserved.
      """
      kept = ((name, value) for name, value in kwargs.items() if value is not None)
      return dict(kept)
  async def async_concat_chunks(chunks: AsyncIterator) -> str:
      """Drain an async iterator of chunks, then concatenate them with ``concat_chunks``."""
      collected = []
      async for chunk in chunks:
          collected.append(chunk)
      return concat_chunks(collected)
  def concat_chunks(chunks: Iterator) -> str:
      """
      Concatenate chunks into one string.

      Each kept chunk is converted with ``str()``. Falsy chunks (for example
      None, "" or 0) and chunks that are ``Exception`` instances are skipped.
      """
      pieces = []
      for piece in chunks:
          if not piece or isinstance(piece, Exception):
              continue
          pieces.append(str(piece))
      return "".join(pieces)
  def format_cookies(cookies: Cookies) -> str:
      """
      Join cookies into a single ``name=value; name2=value2`` string.

      This is the layout of an HTTP ``Cookie`` header. Names and values are
      inserted as-is via f-string formatting, with no quoting or escaping.
      """
      pairs = (f"{name}={value}" for name, value in cookies.items())
      return "; ".join(pairs)