langchain.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. from __future__ import annotations
  2. from typing import Any, Dict
  3. from langchain_community.chat_models import openai
  4. from langchain_community.chat_models.openai import ChatOpenAI, BaseMessage, convert_message_to_dict
  5. from pydantic import Field
  6. from g4f.client import AsyncClient, Client
  7. from g4f.client.stubs import ChatCompletionMessage
  def new_convert_message_to_dict(message: BaseMessage) -> dict:
      """Convert a chat message into an OpenAI-style message dict.

      g4f ``ChatCompletionMessage`` instances are translated by hand. Any other
      message is delegated to ``convert_message_to_dict`` as imported from
      ``langchain_community.chat_models.openai``. That from-import binds the
      original function in this module, so the fallback never recurses into
      the patched module attribute.

      Args:
          message: A LangChain ``BaseMessage`` or a g4f ``ChatCompletionMessage``.

      Returns:
          A dict with ``role`` and ``content`` keys. A ``tool_calls`` list is
          added when the g4f message carries tool calls.
      """
      if not isinstance(message, ChatCompletionMessage):
          return convert_message_to_dict(message)

      result: Dict[str, Any] = {"role": message.role, "content": message.content}
      calls = message.tool_calls
      if calls is None:
          return result

      result["tool_calls"] = [
          {"id": call.id, "type": call.type, "function": call.function}
          for call in calls
      ]
      # Alongside tool calls, an empty text body is sent as null content.
      if result["content"] == "":
          result["content"] = None
      return result
  # Route LangChain's OpenAI message conversion through the g4f-aware converter
  # defined above. This replaces the attribute on the imported
  # ``langchain_community.chat_models.openai`` module at import time.
  setattr(openai, "convert_message_to_dict", new_convert_message_to_dict)
  class ChatAI(ChatOpenAI):
      """``ChatOpenAI`` subclass that swaps in g4f completion clients.

      ``validate_environment`` replaces ``client`` / ``async_client`` with the
      ``chat.completions`` endpoints of ``g4f.client.Client`` and
      ``g4f.client.AsyncClient``.
      """

      # Defaults to "gpt-4o"; the alias also accepts it as ``model=...``.
      model_name: str = Field(default="gpt-4o", alias="model")

      @classmethod
      def validate_environment(cls, values: dict) -> dict:
          """Build g4f sync/async completion clients and store them in *values*.

          The optional ``api_key`` entry of *values* and the optional
          ``provider`` entry of ``values["model_kwargs"]`` are forwarded to
          both clients. Each defaults to ``None`` when absent.

          Args:
              values: The field-value dict this hook receives. It is mutated
                  in place and returned.

          Returns:
              The same dict, with ``client`` and ``async_client`` set to the
              ``chat.completions`` endpoints of newly built g4f clients.
          """
          # ``model_kwargs`` may be missing or explicitly None. Indexing it
          # directly raised KeyError / TypeError in those cases.
          model_kwargs = values.get("model_kwargs") or {}
          client_params = {
              "api_key": values.get("api_key"),
              "provider": model_kwargs.get("provider"),
          }
          values["client"] = Client(**client_params).chat.completions
          values["async_client"] = AsyncClient(**client_params).chat.completions
          return values