image_client.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344
  1. from __future__ import annotations
  2. import asyncio
  3. import unittest
  4. from g4f.client import AsyncClient, ImagesResponse
  5. from g4f.providers.retry_provider import IterListProvider
  6. from .mocks import (
  7. YieldImageResponseProviderMock,
  8. MissingAuthProviderMock,
  9. AsyncRaiseExceptionProviderMock,
  10. YieldNoneProviderMock
  11. )
  12. DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
  13. class TestIterListProvider(unittest.IsolatedAsyncioTestCase):
  14. async def test_skip_provider(self):
  15. client = AsyncClient(image_provider=IterListProvider([MissingAuthProviderMock, YieldImageResponseProviderMock], False))
  16. response = await client.images.generate("Hello", "", response_format="orginal")
  17. self.assertIsInstance(response, ImagesResponse)
  18. self.assertEqual("Hello", response.data[0].url)
  19. async def test_only_one_result(self):
  20. client = AsyncClient(image_provider=IterListProvider([YieldImageResponseProviderMock, YieldImageResponseProviderMock], False))
  21. response = await client.images.generate("Hello", "", response_format="orginal")
  22. self.assertIsInstance(response, ImagesResponse)
  23. self.assertEqual("Hello", response.data[0].url)
  24. async def test_skip_none(self):
  25. client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, YieldImageResponseProviderMock], False))
  26. response = await client.images.generate("Hello", "", response_format="orginal")
  27. self.assertIsInstance(response, ImagesResponse)
  28. self.assertEqual("Hello", response.data[0].url)
  29. def test_raise_exception(self):
  30. async def run_exception():
  31. client = AsyncClient(image_provider=IterListProvider([YieldNoneProviderMock, AsyncRaiseExceptionProviderMock], False))
  32. await client.images.generate("Hello", "")
  33. self.assertRaises(RuntimeError, asyncio.run, run_exception())
  34. if __name__ == '__main__':
  35. unittest.main()