models.py 1.2 KB

1234567891011121314151617181920212223242526272829
  1. import unittest
  2. from typing import Type
  3. import asyncio
  4. from g4f.models import __models__
  5. from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
  6. from g4f.errors import MissingRequirementsError, MissingAuthError
  7. class TestProviderHasModel(unittest.IsolatedAsyncioTestCase):
  8. cache: dict = {}
  9. async def test_provider_has_model(self):
  10. for model, providers in __models__.values():
  11. for provider in providers:
  12. if issubclass(provider, ProviderModelMixin):
  13. if model.name in provider.model_aliases:
  14. model_name = provider.model_aliases[model.name]
  15. else:
  16. model_name = model.name
  17. await asyncio.wait_for(self.provider_has_model(provider, model_name), 10)
  18. async def provider_has_model(self, provider: Type[BaseProvider], model: str):
  19. if provider.__name__ not in self.cache:
  20. try:
  21. self.cache[provider.__name__] = provider.get_models()
  22. except (MissingRequirementsError, MissingAuthError):
  23. return
  24. if self.cache[provider.__name__]:
  25. self.assertIn(model, self.cache[provider.__name__], provider.__name__)