create_images.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. from __future__ import annotations
  2. import re
  3. import asyncio
  4. from .. import debug
  5. from ..typing import CreateResult, Messages
  6. from ..base_provider import BaseProvider, ProviderType
  7. system_message = """
  8. You can generate images, pictures, photos or img with the DALL-E 3 image generator.
  9. To generate an image with a prompt, do this:
  10. <img data-prompt=\"keywords for the image\">
  11. Never use own image links. Don't wrap it in backticks.
  12. It is important to use a only a img tag with a prompt.
  13. <img data-prompt=\"image caption\">
  14. """
  15. class CreateImagesProvider(BaseProvider):
  16. """
  17. Provider class for creating images based on text prompts.
  18. This provider handles image creation requests embedded within message content,
  19. using provided image creation functions.
  20. Attributes:
  21. provider (ProviderType): The underlying provider to handle non-image related tasks.
  22. create_images (callable): A function to create images synchronously.
  23. create_images_async (callable): A function to create images asynchronously.
  24. system_message (str): A message that explains the image creation capability.
  25. include_placeholder (bool): Flag to determine whether to include the image placeholder in the output.
  26. __name__ (str): Name of the provider.
  27. url (str): URL of the provider.
  28. working (bool): Indicates if the provider is operational.
  29. supports_stream (bool): Indicates if the provider supports streaming.
  30. """
  31. def __init__(
  32. self,
  33. provider: ProviderType,
  34. create_images: callable,
  35. create_async: callable,
  36. system_message: str = system_message,
  37. include_placeholder: bool = True
  38. ) -> None:
  39. """
  40. Initializes the CreateImagesProvider.
  41. Args:
  42. provider (ProviderType): The underlying provider.
  43. create_images (callable): Function to create images synchronously.
  44. create_async (callable): Function to create images asynchronously.
  45. system_message (str, optional): System message to be prefixed to messages. Defaults to a predefined message.
  46. include_placeholder (bool, optional): Whether to include image placeholders in the output. Defaults to True.
  47. """
  48. self.provider = provider
  49. self.create_images = create_images
  50. self.create_images_async = create_async
  51. self.system_message = system_message
  52. self.include_placeholder = include_placeholder
  53. self.__name__ = provider.__name__
  54. self.url = provider.url
  55. self.working = provider.working
  56. self.supports_stream = provider.supports_stream
  57. def create_completion(
  58. self,
  59. model: str,
  60. messages: Messages,
  61. stream: bool = False,
  62. **kwargs
  63. ) -> CreateResult:
  64. """
  65. Creates a completion result, processing any image creation prompts found within the messages.
  66. Args:
  67. model (str): The model to use for creation.
  68. messages (Messages): The messages to process, which may contain image prompts.
  69. stream (bool, optional): Indicates whether to stream the results. Defaults to False.
  70. **kwargs: Additional keywordarguments for the provider.
  71. Yields:
  72. CreateResult: Yields chunks of the processed messages, including image data if applicable.
  73. Note:
  74. This method processes messages to detect image creation prompts. When such a prompt is found,
  75. it calls the synchronous image creation function and includes the resulting image in the output.
  76. """
  77. messages.insert(0, {"role": "system", "content": self.system_message})
  78. buffer = ""
  79. for chunk in self.provider.create_completion(model, messages, stream, **kwargs):
  80. if isinstance(chunk, str) and buffer or "<" in chunk:
  81. buffer += chunk
  82. if ">" in buffer:
  83. match = re.search(r'<img data-prompt="(.*?)">', buffer)
  84. if match:
  85. placeholder, prompt = match.group(0), match.group(1)
  86. start, append = buffer.split(placeholder, 1)
  87. if start:
  88. yield start
  89. if self.include_placeholder:
  90. yield placeholder
  91. if debug.logging:
  92. print(f"Create images with prompt: {prompt}")
  93. yield from self.create_images(prompt)
  94. if append:
  95. yield append
  96. else:
  97. yield buffer
  98. buffer = ""
  99. else:
  100. yield chunk
  101. async def create_async(
  102. self,
  103. model: str,
  104. messages: Messages,
  105. **kwargs
  106. ) -> str:
  107. """
  108. Asynchronously creates a response, processing any image creation prompts found within the messages.
  109. Args:
  110. model (str): The model to use for creation.
  111. messages (Messages): The messages to process, which may contain image prompts.
  112. **kwargs: Additional keyword arguments for the provider.
  113. Returns:
  114. str: The processed response string, including asynchronously generated image data if applicable.
  115. Note:
  116. This method processes messages to detect image creation prompts. When such a prompt is found,
  117. it calls the asynchronous image creation function and includes the resulting image in the output.
  118. """
  119. messages.insert(0, {"role": "system", "content": self.system_message})
  120. response = await self.provider.create_async(model, messages, **kwargs)
  121. matches = re.findall(r'(<img data-prompt="(.*?)">)', response)
  122. results = []
  123. placeholders = []
  124. for placeholder, prompt in matches:
  125. if placeholder not in placeholders:
  126. if debug.logging:
  127. print(f"Create images with prompt: {prompt}")
  128. results.append(self.create_images_async(prompt))
  129. placeholders.append(placeholder)
  130. results = await asyncio.gather(*results)
  131. for idx, result in enumerate(results):
  132. placeholder = placeholder[idx]
  133. if self.include_placeholder:
  134. result = placeholder + result
  135. response = response.replace(placeholder, result)
  136. return response