retry_provider.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from __future__ import annotations
  2. import asyncio
  3. import random
  4. from ..typing import CreateResult, Messages
  5. from ..base_provider import BaseRetryProvider
  6. from .. import debug
  7. from ..errors import RetryProviderError, RetryNoProviderError
  8. class RetryProvider(BaseRetryProvider):
  9. """
  10. A provider class to handle retries for creating completions with different providers.
  11. Attributes:
  12. providers (list): A list of provider instances.
  13. shuffle (bool): A flag indicating whether to shuffle providers before use.
  14. exceptions (dict): A dictionary to store exceptions encountered during retries.
  15. last_provider (BaseProvider): The last provider that was used.
  16. """
  17. def create_completion(
  18. self,
  19. model: str,
  20. messages: Messages,
  21. stream: bool = False,
  22. **kwargs
  23. ) -> CreateResult:
  24. """
  25. Create a completion using available providers, with an option to stream the response.
  26. Args:
  27. model (str): The model to be used for completion.
  28. messages (Messages): The messages to be used for generating completion.
  29. stream (bool, optional): Flag to indicate if the response should be streamed. Defaults to False.
  30. Yields:
  31. CreateResult: Tokens or results from the completion.
  32. Raises:
  33. Exception: Any exception encountered during the completion process.
  34. """
  35. providers = [p for p in self.providers if stream and p.supports_stream] if stream else self.providers
  36. if self.shuffle:
  37. random.shuffle(providers)
  38. self.exceptions = {}
  39. started: bool = False
  40. for provider in providers:
  41. self.last_provider = provider
  42. try:
  43. if debug.logging:
  44. print(f"Using {provider.__name__} provider")
  45. for token in provider.create_completion(model, messages, stream, **kwargs):
  46. yield token
  47. started = True
  48. if started:
  49. return
  50. except Exception as e:
  51. self.exceptions[provider.__name__] = e
  52. if debug.logging:
  53. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  54. if started:
  55. raise e
  56. self.raise_exceptions()
  57. async def create_async(
  58. self,
  59. model: str,
  60. messages: Messages,
  61. **kwargs
  62. ) -> str:
  63. """
  64. Asynchronously create a completion using available providers.
  65. Args:
  66. model (str): The model to be used for completion.
  67. messages (Messages): The messages to be used for generating completion.
  68. Returns:
  69. str: The result of the asynchronous completion.
  70. Raises:
  71. Exception: Any exception encountered during the asynchronous completion process.
  72. """
  73. providers = self.providers
  74. if self.shuffle:
  75. random.shuffle(providers)
  76. self.exceptions = {}
  77. for provider in providers:
  78. self.last_provider = provider
  79. try:
  80. return await asyncio.wait_for(
  81. provider.create_async(model, messages, **kwargs),
  82. timeout=kwargs.get("timeout", 60)
  83. )
  84. except Exception as e:
  85. self.exceptions[provider.__name__] = e
  86. if debug.logging:
  87. print(f"{provider.__name__}: {e.__class__.__name__}: {e}")
  88. self.raise_exceptions()
  89. def raise_exceptions(self) -> None:
  90. """
  91. Raise a combined exception if any occurred during retries.
  92. Raises:
  93. RetryProviderError: If any provider encountered an exception.
  94. RetryNoProviderError: If no provider is found.
  95. """
  96. if self.exceptions:
  97. raise RetryProviderError("RetryProvider failed:\n" + "\n".join([
  98. f"{p}: {exception.__class__.__name__}: {exception}" for p, exception in self.exceptions.items()
  99. ]))
  100. raise RetryNoProviderError("No provider found")