__init__.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. from __future__ import annotations
  2. import logging
  3. import json
  4. import uvicorn
  5. import secrets
  6. import os
  7. import shutil
  8. import os.path
  9. from fastapi import FastAPI, Response, Request, UploadFile
  10. from fastapi.responses import StreamingResponse, RedirectResponse, HTMLResponse, JSONResponse
  11. from fastapi.exceptions import RequestValidationError
  12. from fastapi.security import APIKeyHeader
  13. from starlette.exceptions import HTTPException
  14. from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY, HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
  15. from fastapi.encoders import jsonable_encoder
  16. from fastapi.middleware.cors import CORSMiddleware
  17. from starlette.responses import FileResponse
  18. from pydantic import BaseModel
  19. from typing import Union, Optional, List
  20. import g4f
  21. import g4f.debug
  22. from g4f.client import AsyncClient, ChatCompletion, convert_to_provider
  23. from g4f.providers.response import BaseConversation
  24. from g4f.client.helper import filter_none
  25. from g4f.image import is_accepted_format, images_dir
  26. from g4f.typing import Messages
  27. from g4f.errors import ProviderNotFoundError
  28. from g4f.cookies import read_cookie_files, get_cookies_dir
  29. from g4f.Provider import ProviderType, ProviderUtils, __providers__
  30. logger = logging.getLogger(__name__)
  31. def create_app(g4f_api_key: str = None):
  32. app = FastAPI()
  33. # Add CORS middleware
  34. app.add_middleware(
  35. CORSMiddleware,
  36. allow_origin_regex=".*",
  37. allow_credentials=True,
  38. allow_methods=["*"],
  39. allow_headers=["*"],
  40. )
  41. api = Api(app, g4f_api_key=g4f_api_key)
  42. api.register_routes()
  43. api.register_authorization()
  44. api.register_validation_exception_handler()
  45. # Read cookie files if not ignored
  46. if not AppConfig.ignore_cookie_files:
  47. read_cookie_files()
  48. return app
  49. def create_app_debug(g4f_api_key: str = None):
  50. g4f.debug.logging = True
  51. return create_app(g4f_api_key)
  52. class ChatCompletionsConfig(BaseModel):
  53. messages: Messages
  54. model: str
  55. provider: Optional[str] = None
  56. stream: bool = False
  57. temperature: Optional[float] = None
  58. max_tokens: Optional[int] = None
  59. stop: Union[list[str], str, None] = None
  60. api_key: Optional[str] = None
  61. web_search: Optional[bool] = None
  62. proxy: Optional[str] = None
  63. conversation_id: str = None
  64. class ImageGenerationConfig(BaseModel):
  65. prompt: str
  66. model: Optional[str] = None
  67. provider: Optional[str] = None
  68. response_format: str = "url"
  69. api_key: Optional[str] = None
  70. proxy: Optional[str] = None
  71. class ProviderResponseModel(BaseModel):
  72. id: str
  73. object: str = "provider"
  74. created: int
  75. owned_by: Optional[str]
  76. url: Optional[str]
  77. label: Optional[str]
  78. class ProviderResponseModelDetail(ProviderResponseModel):
  79. models: list[str]
  80. image_models: list[str]
  81. vision_models: list[str]
  82. params: list[str]
  83. class ModelResponseModel(BaseModel):
  84. id: str
  85. object: str = "model"
  86. created: int
  87. owned_by: Optional[str]
  88. class AppConfig:
  89. ignored_providers: Optional[list[str]] = None
  90. g4f_api_key: Optional[str] = None
  91. ignore_cookie_files: bool = False
  92. model: str = None,
  93. provider: str = None
  94. image_provider: str = None
  95. proxy: str = None
  96. @classmethod
  97. def set_config(cls, **data):
  98. for key, value in data.items():
  99. setattr(cls, key, value)
  100. list_ignored_providers: list[str] = None
  101. def set_list_ignored_providers(ignored: list[str]):
  102. global list_ignored_providers
  103. list_ignored_providers = ignored
  104. class Api:
  105. def __init__(self, app: FastAPI, g4f_api_key=None) -> None:
  106. self.app = app
  107. self.client = AsyncClient()
  108. self.g4f_api_key = g4f_api_key
  109. self.get_g4f_api_key = APIKeyHeader(name="g4f-api-key")
  110. self.conversations: dict[str, dict[str, BaseConversation]] = {}
  111. def register_authorization(self):
  112. @self.app.middleware("http")
  113. async def authorization(request: Request, call_next):
  114. if self.g4f_api_key and request.url.path not in ("/", "/v1"):
  115. try:
  116. user_g4f_api_key = await self.get_g4f_api_key(request)
  117. except HTTPException as e:
  118. if e.status_code == 403:
  119. return JSONResponse(
  120. status_code=HTTP_401_UNAUTHORIZED,
  121. content=jsonable_encoder({"detail": "G4F API key required"}),
  122. )
  123. if not secrets.compare_digest(self.g4f_api_key, user_g4f_api_key):
  124. return JSONResponse(
  125. status_code=HTTP_403_FORBIDDEN,
  126. content=jsonable_encoder({"detail": "Invalid G4F API key"}),
  127. )
  128. return await call_next(request)
  129. def register_validation_exception_handler(self):
  130. @self.app.exception_handler(RequestValidationError)
  131. async def validation_exception_handler(request: Request, exc: RequestValidationError):
  132. details = exc.errors()
  133. modified_details = []
  134. for error in details:
  135. modified_details.append({
  136. "loc": error["loc"],
  137. "message": error["msg"],
  138. "type": error["type"],
  139. })
  140. return JSONResponse(
  141. status_code=HTTP_422_UNPROCESSABLE_ENTITY,
  142. content=jsonable_encoder({"detail": modified_details}),
  143. )
  144. def register_routes(self):
  145. @self.app.get("/")
  146. async def read_root():
  147. return RedirectResponse("/v1", 302)
  148. @self.app.get("/v1")
  149. async def read_root_v1():
  150. return HTMLResponse('g4f API: Go to '
  151. '<a href="/v1/models">models</a>, '
  152. '<a href="/v1/chat/completions">chat/completions</a>, or '
  153. '<a href="/v1/images/generate">images/generate</a> <br><br>'
  154. 'Open Swagger UI at: '
  155. '<a href="/docs">/docs</a>')
  156. @self.app.get("/v1/models")
  157. async def models() -> list[ModelResponseModel]:
  158. model_list = dict(
  159. (model, g4f.models.ModelUtils.convert[model])
  160. for model in g4f.Model.__all__()
  161. )
  162. return [{
  163. 'id': model_id,
  164. 'object': 'model',
  165. 'created': 0,
  166. 'owned_by': model.base_provider
  167. } for model_id, model in model_list.items()]
  168. @self.app.get("/v1/models/{model_name}")
  169. async def model_info(model_name: str):
  170. if model_name in g4f.models.ModelUtils.convert:
  171. model_info = g4f.models.ModelUtils.convert[model_name]
  172. return JSONResponse({
  173. 'id': model_name,
  174. 'object': 'model',
  175. 'created': 0,
  176. 'owned_by': model_info.base_provider
  177. })
  178. return JSONResponse({"error": "The model does not exist."}, 404)
  179. @self.app.post("/v1/chat/completions")
  180. async def chat_completions(config: ChatCompletionsConfig, request: Request = None, provider: str = None):
  181. try:
  182. config.provider = provider if config.provider is None else config.provider
  183. if config.provider is None:
  184. config.provider = AppConfig.provider
  185. if config.api_key is None and request is not None:
  186. auth_header = request.headers.get("Authorization")
  187. if auth_header is not None:
  188. api_key = auth_header.split(None, 1)[-1]
  189. if api_key and api_key != "Bearer":
  190. config.api_key = api_key
  191. conversation = return_conversation = None
  192. if config.conversation_id is not None and config.provider is not None:
  193. return_conversation = True
  194. if config.conversation_id in self.conversations:
  195. if config.provider in self.conversations[config.conversation_id]:
  196. conversation = self.conversations[config.conversation_id][config.provider]
  197. # Create the completion response
  198. response = self.client.chat.completions.create(
  199. **filter_none(
  200. **{
  201. "model": AppConfig.model,
  202. "provider": AppConfig.provider,
  203. "proxy": AppConfig.proxy,
  204. **config.dict(exclude_none=True),
  205. **{
  206. "conversation_id": None,
  207. "return_conversation": return_conversation,
  208. "conversation": conversation
  209. }
  210. },
  211. ignored=AppConfig.ignored_providers
  212. ),
  213. )
  214. if not config.stream:
  215. response: ChatCompletion = await response
  216. return JSONResponse(response.to_json())
  217. async def streaming():
  218. try:
  219. async for chunk in response:
  220. if isinstance(chunk, BaseConversation):
  221. if config.conversation_id is not None and config.provider is not None:
  222. if config.conversation_id not in self.conversations:
  223. self.conversations[config.conversation_id] = {}
  224. self.conversations[config.conversation_id][config.provider] = chunk
  225. else:
  226. yield f"data: {json.dumps(chunk.to_json())}\n\n"
  227. except GeneratorExit:
  228. pass
  229. except Exception as e:
  230. logger.exception(e)
  231. yield f'data: {format_exception(e, config)}\n\n'
  232. yield "data: [DONE]\n\n"
  233. return StreamingResponse(streaming(), media_type="text/event-stream")
  234. except Exception as e:
  235. logger.exception(e)
  236. return Response(content=format_exception(e, config), status_code=500, media_type="application/json")
  237. @self.app.post("/v1/images/generate")
  238. @self.app.post("/v1/images/generations")
  239. async def generate_image(config: ImageGenerationConfig, request: Request):
  240. if config.api_key is None:
  241. auth_header = request.headers.get("Authorization")
  242. if auth_header is not None:
  243. api_key = auth_header.split(None, 1)[-1]
  244. if api_key and api_key != "Bearer":
  245. config.api_key = api_key
  246. try:
  247. response = await self.client.images.generate(
  248. prompt=config.prompt,
  249. model=config.model,
  250. provider=AppConfig.image_provider if config.provider is None else config.provider,
  251. **filter_none(
  252. response_format = config.response_format,
  253. api_key = config.api_key,
  254. proxy = config.proxy
  255. )
  256. )
  257. for image in response.data:
  258. if hasattr(image, "url") and image.url.startswith("/"):
  259. image.url = f"{request.base_url}{image.url.lstrip('/')}"
  260. return JSONResponse(response.to_json())
  261. except Exception as e:
  262. logger.exception(e)
  263. return Response(content=format_exception(e, config, True), status_code=500, media_type="application/json")
  264. @self.app.get("/v1/providers")
  265. async def providers() -> list[ProviderResponseModel]:
  266. return [{
  267. 'id': provider.__name__,
  268. 'object': 'provider',
  269. 'created': 0,
  270. 'url': provider.url,
  271. 'label': getattr(provider, "label", None),
  272. } for provider in __providers__ if provider.working]
  273. @self.app.get("/v1/providers/{provider}")
  274. async def providers_info(provider: str) -> ProviderResponseModelDetail:
  275. if provider not in ProviderUtils.convert:
  276. return JSONResponse({"error": "The provider does not exist."}, 404)
  277. provider: ProviderType = ProviderUtils.convert[provider]
  278. def safe_get_models(provider: ProviderType) -> list[str]:
  279. try:
  280. return provider.get_models() if hasattr(provider, "get_models") else []
  281. except:
  282. return []
  283. return {
  284. 'id': provider.__name__,
  285. 'object': 'provider',
  286. 'created': 0,
  287. 'url': provider.url,
  288. 'label': getattr(provider, "label", None),
  289. 'models': safe_get_models(provider),
  290. 'image_models': getattr(provider, "image_models", []) or [],
  291. 'vision_models': [model for model in [getattr(provider, "default_vision_model", None)] if model],
  292. 'params': [*provider.get_parameters()] if hasattr(provider, "get_parameters") else []
  293. }
  294. @self.app.post("/v1/upload_cookies")
  295. def upload_cookies(files: List[UploadFile]):
  296. response_data = []
  297. for file in files:
  298. try:
  299. if file and file.filename.endswith(".json") or file.filename.endswith(".har"):
  300. filename = os.path.basename(file.filename)
  301. with open(os.path.join(get_cookies_dir(), filename), 'wb') as f:
  302. shutil.copyfileobj(file.file, f)
  303. response_data.append({"filename": filename})
  304. finally:
  305. file.file.close()
  306. return response_data
  307. @self.app.get("/v1/synthesize/{provider}")
  308. async def synthesize(request: Request, provider: str):
  309. try:
  310. provider_handler = convert_to_provider(provider)
  311. except ProviderNotFoundError:
  312. return Response("Provider not found", 404)
  313. if not hasattr(provider_handler, "synthesize"):
  314. return Response("Provider doesn't support synthesize", 500)
  315. if len(request.query_params) == 0:
  316. return Response("Missing query params", 500)
  317. response_data = provider_handler.synthesize({**request.query_params})
  318. content_type = getattr(provider_handler, "synthesize_content_type", "application/octet-stream")
  319. return StreamingResponse(response_data, media_type=content_type)
  320. @self.app.get("/images/{filename}")
  321. async def get_image(filename) -> FileResponse:
  322. target = os.path.join(images_dir, filename)
  323. if not os.path.isfile(target):
  324. return Response(status_code=404)
  325. with open(target, "rb") as f:
  326. content_type = is_accepted_format(f.read(12))
  327. return FileResponse(target, media_type=content_type)
  328. def format_exception(e: Exception, config: Union[ChatCompletionsConfig, ImageGenerationConfig], image: bool = False) -> str:
  329. last_provider = {} if not image else g4f.get_last_provider(True)
  330. provider = (AppConfig.image_provider if image else AppConfig.provider) if config.provider is None else config.provider
  331. model = AppConfig.model if config.model is None else config.model
  332. return json.dumps({
  333. "error": {"message": f"{e.__class__.__name__}: {e}"},
  334. "model": last_provider.get("model") if model is None else model,
  335. **filter_none(
  336. provider=last_provider.get("name") if provider is None else provider
  337. )
  338. })
  339. def run_api(
  340. host: str = '0.0.0.0',
  341. port: int = 1337,
  342. bind: str = None,
  343. debug: bool = False,
  344. workers: int = None,
  345. use_colors: bool = None,
  346. reload: bool = False
  347. ) -> None:
  348. print(f'Starting server... [g4f v-{g4f.version.utils.current_version}]' + (" (debug)" if debug else ""))
  349. if use_colors is None:
  350. use_colors = debug
  351. if bind is not None:
  352. host, port = bind.split(":")
  353. uvicorn.run(
  354. f"g4f.api:create_app{'_debug' if debug else ''}",
  355. host=host,
  356. port=int(port),
  357. workers=workers,
  358. use_colors=use_colors,
  359. factory=True,
  360. reload=reload
  361. )