client.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. from __future__ import annotations
  2. import unittest
  3. from g4f.errors import ModelNotFoundError
  4. from g4f.client import Client, AsyncClient, ChatCompletion, ChatCompletionChunk, get_model_and_provider
  5. from g4f.Provider.Copilot import Copilot
  6. from g4f.models import gpt_4o
  7. from .mocks import AsyncGeneratorProviderMock, ModelProviderMock, YieldProviderMock
  8. DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
  9. class AsyncTestPassModel(unittest.IsolatedAsyncioTestCase):
  10. async def test_response(self):
  11. client = AsyncClient(provider=AsyncGeneratorProviderMock)
  12. response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
  13. self.assertIsInstance(response, ChatCompletion)
  14. self.assertEqual("Mock", response.choices[0].message.content)
  15. async def test_pass_model(self):
  16. client = AsyncClient(provider=ModelProviderMock)
  17. response = await client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
  18. self.assertIsInstance(response, ChatCompletion)
  19. self.assertEqual("Hello", response.choices[0].message.content)
  20. async def test_max_tokens(self):
  21. client = AsyncClient(provider=YieldProviderMock)
  22. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  23. response = await client.chat.completions.create(messages, "Hello", max_tokens=1)
  24. self.assertIsInstance(response, ChatCompletion)
  25. self.assertEqual("How ", response.choices[0].message.content)
  26. response = await client.chat.completions.create(messages, "Hello", max_tokens=2)
  27. self.assertIsInstance(response, ChatCompletion)
  28. self.assertEqual("How are ", response.choices[0].message.content)
  29. async def test_max_stream(self):
  30. client = AsyncClient(provider=YieldProviderMock)
  31. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  32. response = client.chat.completions.create(messages, "Hello", stream=True)
  33. async for chunk in response:
  34. chunk: ChatCompletionChunk = chunk
  35. self.assertIsInstance(chunk, ChatCompletionChunk)
  36. if chunk.choices[0].delta.content is not None:
  37. self.assertIsInstance(chunk.choices[0].delta.content, str)
  38. messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
  39. response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
  40. response_list = []
  41. async for chunk in response:
  42. response_list.append(chunk)
  43. self.assertEqual(len(response_list), 3)
  44. for chunk in response_list:
  45. if chunk.choices[0].delta.content is not None:
  46. self.assertEqual(chunk.choices[0].delta.content, "You ")
  47. async def test_stop(self):
  48. client = AsyncClient(provider=YieldProviderMock)
  49. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  50. response = await client.chat.completions.create(messages, "Hello", stop=["and"])
  51. self.assertIsInstance(response, ChatCompletion)
  52. self.assertEqual("How are you?", response.choices[0].message.content)
  53. class TestPassModel(unittest.TestCase):
  54. def test_response(self):
  55. client = Client(provider=AsyncGeneratorProviderMock)
  56. response = client.chat.completions.create(DEFAULT_MESSAGES, "")
  57. self.assertIsInstance(response, ChatCompletion)
  58. self.assertEqual("Mock", response.choices[0].message.content)
  59. def test_pass_model(self):
  60. client = Client(provider=ModelProviderMock)
  61. response = client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
  62. self.assertIsInstance(response, ChatCompletion)
  63. self.assertEqual("Hello", response.choices[0].message.content)
  64. def test_max_tokens(self):
  65. client = Client(provider=YieldProviderMock)
  66. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  67. response = client.chat.completions.create(messages, "Hello", max_tokens=1)
  68. self.assertIsInstance(response, ChatCompletion)
  69. self.assertEqual("How ", response.choices[0].message.content)
  70. response = client.chat.completions.create(messages, "Hello", max_tokens=2)
  71. self.assertIsInstance(response, ChatCompletion)
  72. self.assertEqual("How are ", response.choices[0].message.content)
  73. def test_max_stream(self):
  74. client = Client(provider=YieldProviderMock)
  75. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  76. response = client.chat.completions.create(messages, "Hello", stream=True)
  77. for chunk in response:
  78. self.assertIsInstance(chunk, ChatCompletionChunk)
  79. if chunk.choices[0].delta.content is not None:
  80. self.assertIsInstance(chunk.choices[0].delta.content, str)
  81. messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
  82. response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
  83. response_list = list(response)
  84. self.assertEqual(len(response_list), 3)
  85. for chunk in response_list:
  86. if chunk.choices[0].delta.content is not None:
  87. self.assertEqual(chunk.choices[0].delta.content, "You ")
  88. def test_stop(self):
  89. client = Client(provider=YieldProviderMock)
  90. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  91. response = client.chat.completions.create(messages, "Hello", stop=["and"])
  92. self.assertIsInstance(response, ChatCompletion)
  93. self.assertEqual("How are you?", response.choices[0].message.content)
  94. def test_model_not_found(self):
  95. def run_exception():
  96. client = Client()
  97. client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
  98. self.assertRaises(ModelNotFoundError, run_exception)
  99. def test_best_provider(self):
  100. not_default_model = "gpt-4o"
  101. model, provider = get_model_and_provider(not_default_model, None, False)
  102. self.assertTrue(hasattr(provider, "create_completion"))
  103. self.assertEqual(model, not_default_model)
  104. def test_default_model(self):
  105. default_model = ""
  106. model, provider = get_model_and_provider(default_model, None, False)
  107. self.assertTrue(hasattr(provider, "create_completion"))
  108. self.assertEqual(model, default_model)
  109. def test_provider_as_model(self):
  110. provider_as_model = Copilot.__name__
  111. model, provider = get_model_and_provider(provider_as_model, None, False)
  112. self.assertTrue(hasattr(provider, "create_completion"))
  113. self.assertIsInstance(model, str)
  114. self.assertEqual(model, Copilot.default_model)
  115. def test_get_model(self):
  116. model, provider = get_model_and_provider(gpt_4o.name, None, False)
  117. self.assertTrue(hasattr(provider, "create_completion"))
  118. self.assertEqual(model, gpt_4o.name)
  119. if __name__ == '__main__':
  120. unittest.main()