pydantic_ai.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. from __future__ import annotations
  2. from typing import Optional
  3. from functools import partial
  4. from dataclasses import dataclass, field
  5. from pydantic_ai.models import Model, KnownModelName, infer_model
  6. from pydantic_ai.models.openai import OpenAIModel, OpenAISystemPromptRole
  7. import pydantic_ai.models.openai
  8. pydantic_ai.models.openai.NOT_GIVEN = None
  9. from ..client import AsyncClient
  @dataclass(init=False)
  class AIModel(OpenAIModel):
      """A pydantic_ai model that sends its requests through the G4F client.

      Subclasses ``OpenAIModel`` and overrides only construction and ``name``.
      The OpenAI SDK client is replaced by a G4F ``AsyncClient``. Request
      handling is inherited from ``OpenAIModel``, which presumably relies on
      ``AsyncClient`` exposing an OpenAI-compatible interface (TODO confirm).
      """

      client: AsyncClient = field(repr=False)
      system_prompt_role: OpenAISystemPromptRole | None = field(default=None)
      _model_name: str = field(repr=False)
      # ``None`` when no provider was requested; ``name()`` then omits it.
      _provider: str | None = field(repr=False)
      _system: Optional[str] = field(repr=False)

      def __init__(
          self,
          model_name: str,
          provider: str | None = None,
          *,
          system_prompt_role: OpenAISystemPromptRole | None = None,
          system: str | None = 'openai',
          **kwargs
      ):
          """Initialize a G4F-backed model.

          ``OpenAIModel.__init__`` is deliberately not called. This initializer
          sets every field itself and builds a G4F client instead of an OpenAI one.

          Args:
              model_name: Name of the model to request through G4F.
              provider: Optional G4F provider name. It is forwarded to
                  ``AsyncClient`` and included in ``name()``. ``None`` (the
                  default) is passed through unchanged.
              system_prompt_role: Role to use for the system prompt message,
                  stored as-is. If not provided, the inherited ``OpenAIModel``
                  behaviour applies (upstream documents a ``'system'`` fallback).
              system: Provider label kept for observability; defaults to
                  ``'openai'``. It is only stored here. Requests are routed by
                  ``provider``, not by this value.
              **kwargs: Extra keyword arguments forwarded verbatim to
                  ``AsyncClient`` (for example ``api_key``).
          """
          self._model_name = model_name
          self._provider = provider
          self.client = AsyncClient(provider=provider, **kwargs)
          self.system_prompt_role = system_prompt_role
          self._system = system

      def name(self) -> str:
          """Return ``'g4f:<provider>:<model>'``, or ``'g4f:<model>'`` without a provider.

          An empty-string provider is treated like ``None`` and omitted.
          """
          if self._provider:
              return f'g4f:{self._provider}:{self._model_name}'
          return f'g4f:{self._model_name}'
  def new_infer_model(model: Model | KnownModelName, api_key: str | None = None) -> Model:
      """Resolve ``model`` to a ``Model``, building an ``AIModel`` for ``"g4f:..."`` names.

      Accepted G4F forms:
        * ``"g4f:<model>"`` -> ``AIModel(<model>)``
        * ``"g4f:<provider>:<model>"`` -> ``AIModel(<model>, provider=<provider>)``.
          Only the first ``:`` after the prefix separates the provider, so the
          model part may itself contain colons.

      Args:
          model: An existing ``Model`` instance, which is returned unchanged, or
              a model name.
          api_key: Forwarded to every ``AIModel`` built here, with or without a
              provider. ``AIModel`` passes it on to ``AsyncClient``.

      Returns:
          The resolved model. Names without the ``g4f:`` prefix are delegated
          to ``infer_model``.
      """
      if isinstance(model, Model):
          return model

      prefix = "g4f:"
      if not model.startswith(prefix):
          # ``infer_model`` is the function bound by this module's import, not the
          # module attribute that ``patch_infer_model`` rebinds. Delegating
          # therefore does not loop back into this function.
          return infer_model(model)

      spec = model[len(prefix):]
      provider, sep, model_name = spec.partition(":")
      if sep:
          return AIModel(model_name, provider=provider, api_key=api_key)
      # Previously this branch dropped ``api_key``; it is now forwarded as well.
      return AIModel(spec, api_key=api_key)
  def patch_infer_model(api_key: str | None = None):
      """Make pydantic_ai resolve ``"g4f:..."`` model names via ``new_infer_model``.

      Rebinds ``pydantic_ai.models.infer_model`` to ``new_infer_model`` with
      ``api_key`` pre-bound. It also publishes ``AIModel`` as
      ``pydantic_ai.models.AIModel``. Only callers that look up
      ``infer_model`` on the module at call time see the replacement. Names
      imported earlier with ``from ... import infer_model`` keep the old function.

      Args:
          api_key: Key forwarded to every G4F model the patched resolver builds.
      """
      from pydantic_ai import models as pai_models

      resolver = partial(new_infer_model, api_key=api_key)
      pai_models.infer_model = resolver
      pai_models.AIModel = AIModel