1234567891011121314151617181920212223242526272829303132333435363738394041 |
- from __future__ import annotations
- from typing import Any, Dict
- from langchain_community.chat_models import openai
- from langchain_community.chat_models.openai import ChatOpenAI, BaseMessage, convert_message_to_dict
- from pydantic import Field
- from g4f.client import AsyncClient, Client
- from g4f.client.stubs import ChatCompletionMessage
def new_convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a chat message into an OpenAI-style message dictionary.

    ``ChatCompletionMessage`` instances (imported from ``g4f.client.stubs``)
    are serialised here from their ``role``, ``content`` and ``tool_calls``
    attributes. Every other message is delegated to the
    ``convert_message_to_dict`` imported at the top of this file.

    Args:
        message: The message object to serialise.

    Returns:
        A dict with ``role`` and ``content`` keys, plus ``tool_calls`` when
        the message carries tool calls; or whatever the delegated converter
        returns for other message types.
    """
    # Anything that is not a g4f completion message takes the stock path.
    if not isinstance(message, ChatCompletionMessage):
        return convert_message_to_dict(message)

    payload: Dict[str, Any] = {"role": message.role, "content": message.content}
    if message.tool_calls is None:
        return payload

    serialised_calls = []
    for call in message.tool_calls:
        serialised_calls.append(
            {"id": call.id, "type": call.type, "function": call.function}
        )
    payload["tool_calls"] = serialised_calls

    # NOTE(review): the source's indentation was lost; the empty-content
    # normalisation is assumed to apply only when tool calls are present.
    # Confirm against the upstream file.
    if payload["content"] == "":
        payload["content"] = None
    return payload


# Install the wrapper on the imported ``openai`` chat-model module. The
# ``convert_message_to_dict`` name used inside the wrapper was bound by this
# file's own import before this assignment, so the fallback branch still
# reaches the original converter rather than recursing.
openai.convert_message_to_dict = new_convert_message_to_dict
class ChatAI(ChatOpenAI):
    """``ChatOpenAI`` subclass whose clients are g4f completion endpoints.

    ``validate_environment`` stores ``Client(...).chat.completions`` and
    ``AsyncClient(...).chat.completions`` under the ``client`` and
    ``async_client`` entries of the values dict.
    """

    # Model identifier; also accepted under the alias ``model``.
    model_name: str = Field(default="gpt-4o", alias="model")

    @classmethod
    def validate_environment(cls, values: dict) -> dict:
        """Attach g4f sync and async chat-completion clients to ``values``.

        Args:
            values: Field values. ``values["api_key"]`` and
                ``values["model_kwargs"]["provider"]`` are read when present;
                each falls back to ``None`` otherwise.

        Returns:
            The same ``values`` dict, with ``client`` and ``async_client``
            set.
        """
        # A missing or ``None`` ``model_kwargs`` means "no provider given"
        # instead of raising KeyError/TypeError.
        model_kwargs = values.get("model_kwargs") or {}
        client_params = {
            "api_key": values.get("api_key"),
            "provider": model_kwargs.get("provider"),
        }
        # Both clients are built from the same parameters.
        values["client"] = Client(**client_params).chat.completions
        values["async_client"] = AsyncClient(**client_params).chat.completions
        return values
|