client.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. from __future__ import annotations
  2. import unittest
  3. from g4f.errors import ModelNotFoundError
  4. from g4f.client import Client, AsyncClient, ChatCompletion, ChatCompletionChunk
  5. from g4f.client.service import get_model_and_provider
  6. from g4f.Provider.Copilot import Copilot
  7. from g4f.models import gpt_4o
  8. from .mocks import AsyncGeneratorProviderMock, ModelProviderMock, YieldProviderMock
  9. DEFAULT_MESSAGES = [{'role': 'user', 'content': 'Hello'}]
  10. class AsyncTestPassModel(unittest.IsolatedAsyncioTestCase):
  11. async def test_response(self):
  12. client = AsyncClient(provider=AsyncGeneratorProviderMock)
  13. response = await client.chat.completions.create(DEFAULT_MESSAGES, "")
  14. self.assertIsInstance(response, ChatCompletion)
  15. self.assertEqual("Mock", response.choices[0].message.content)
  16. async def test_pass_model(self):
  17. client = AsyncClient(provider=ModelProviderMock)
  18. response = await client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
  19. self.assertIsInstance(response, ChatCompletion)
  20. self.assertEqual("Hello", response.choices[0].message.content)
  21. async def test_max_tokens(self):
  22. client = AsyncClient(provider=YieldProviderMock)
  23. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  24. response = await client.chat.completions.create(messages, "Hello", max_tokens=1)
  25. self.assertIsInstance(response, ChatCompletion)
  26. self.assertEqual("How ", response.choices[0].message.content)
  27. response = await client.chat.completions.create(messages, "Hello", max_tokens=2)
  28. self.assertIsInstance(response, ChatCompletion)
  29. self.assertEqual("How are ", response.choices[0].message.content)
  30. async def test_max_stream(self):
  31. client = AsyncClient(provider=YieldProviderMock)
  32. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  33. response = client.chat.completions.create(messages, "Hello", stream=True)
  34. async for chunk in response:
  35. chunk: ChatCompletionChunk = chunk
  36. self.assertIsInstance(chunk, ChatCompletionChunk)
  37. if chunk.choices[0].delta.content is not None:
  38. self.assertIsInstance(chunk.choices[0].delta.content, str)
  39. messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
  40. response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
  41. response_list = []
  42. async for chunk in response:
  43. response_list.append(chunk)
  44. self.assertEqual(len(response_list), 3)
  45. for chunk in response_list:
  46. if chunk.choices[0].delta.content is not None:
  47. self.assertEqual(chunk.choices[0].delta.content, "You ")
  48. async def test_stop(self):
  49. client = AsyncClient(provider=YieldProviderMock)
  50. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  51. response = await client.chat.completions.create(messages, "Hello", stop=["and"])
  52. self.assertIsInstance(response, ChatCompletion)
  53. self.assertEqual("How are you?", response.choices[0].message.content)
  54. class TestPassModel(unittest.TestCase):
  55. def test_response(self):
  56. client = Client(provider=AsyncGeneratorProviderMock)
  57. response = client.chat.completions.create(DEFAULT_MESSAGES, "")
  58. self.assertIsInstance(response, ChatCompletion)
  59. self.assertEqual("Mock", response.choices[0].message.content)
  60. def test_pass_model(self):
  61. client = Client(provider=ModelProviderMock)
  62. response = client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
  63. self.assertIsInstance(response, ChatCompletion)
  64. self.assertEqual("Hello", response.choices[0].message.content)
  65. def test_max_tokens(self):
  66. client = Client(provider=YieldProviderMock)
  67. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  68. response = client.chat.completions.create(messages, "Hello", max_tokens=1)
  69. self.assertIsInstance(response, ChatCompletion)
  70. self.assertEqual("How ", response.choices[0].message.content)
  71. response = client.chat.completions.create(messages, "Hello", max_tokens=2)
  72. self.assertIsInstance(response, ChatCompletion)
  73. self.assertEqual("How are ", response.choices[0].message.content)
  74. def test_max_stream(self):
  75. client = Client(provider=YieldProviderMock)
  76. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  77. response = client.chat.completions.create(messages, "Hello", stream=True)
  78. for chunk in response:
  79. self.assertIsInstance(chunk, ChatCompletionChunk)
  80. if chunk.choices[0].delta.content is not None:
  81. self.assertIsInstance(chunk.choices[0].delta.content, str)
  82. messages = [{'role': 'user', 'content': chunk} for chunk in ["You ", "You ", "Other", "?"]]
  83. response = client.chat.completions.create(messages, "Hello", stream=True, max_tokens=2)
  84. response_list = list(response)
  85. self.assertEqual(len(response_list), 3)
  86. for chunk in response_list:
  87. if chunk.choices[0].delta.content is not None:
  88. self.assertEqual(chunk.choices[0].delta.content, "You ")
  89. def test_stop(self):
  90. client = Client(provider=YieldProviderMock)
  91. messages = [{'role': 'user', 'content': chunk} for chunk in ["How ", "are ", "you", "?"]]
  92. response = client.chat.completions.create(messages, "Hello", stop=["and"])
  93. self.assertIsInstance(response, ChatCompletion)
  94. self.assertEqual("How are you?", response.choices[0].message.content)
  95. def test_model_not_found(self):
  96. def run_exception():
  97. client = Client()
  98. client.chat.completions.create(DEFAULT_MESSAGES, "Hello")
  99. self.assertRaises(ModelNotFoundError, run_exception)
  100. def test_best_provider(self):
  101. not_default_model = "gpt-4o"
  102. model, provider = get_model_and_provider(not_default_model, None, False)
  103. self.assertTrue(hasattr(provider, "create_completion"))
  104. self.assertEqual(model, not_default_model)
  105. def test_default_model(self):
  106. default_model = ""
  107. model, provider = get_model_and_provider(default_model, None, False)
  108. self.assertTrue(hasattr(provider, "create_completion"))
  109. self.assertEqual(model, default_model)
  110. def test_provider_as_model(self):
  111. provider_as_model = Copilot.__name__
  112. model, provider = get_model_and_provider(provider_as_model, None, False)
  113. self.assertTrue(hasattr(provider, "create_completion"))
  114. self.assertIsInstance(model, str)
  115. self.assertEqual(model, Copilot.default_model)
  116. def test_get_model(self):
  117. model, provider = get_model_and_provider(gpt_4o.name, None, False)
  118. self.assertTrue(hasattr(provider, "create_completion"))
  119. self.assertEqual(model, gpt_4o.name)
  120. if __name__ == '__main__':
  121. unittest.main()